Move RLBootstrapDataset into its own class with its own test file

This commit is contained in:
2026-04-05 00:54:15 +02:00
parent eb290dd634
commit 066a93f755
4 changed files with 237 additions and 60 deletions
+124
View File
@@ -0,0 +1,124 @@
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
import gzip, json, os
class RLBootstrapDataset:
def __init__(self):
self.enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False)
self.min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000)
self.base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl"))
self.output_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl"))
self.max_bytes = int(float(os.getenv("RL_BOOTSTRAP_MAX_MB", "50")) * 1024 * 1024)
self.needs_more_data = False
@staticmethod
def _env_bool(name:str, default:bool=False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.lower() in {"1", "true", "yes", "on"}
@staticmethod
def _env_int(name:str, default:int) -> int:
value = os.getenv(name)
if value is None:
return default
try:
return int(value)
except ValueError:
return default
@staticmethod
def count_jsonl_rows(path:Path) -> int:
count = 0
candidates = [path]
if path.suffix == ".gz":
candidates.append(path.with_suffix(""))
else:
candidates.append(Path(f"{path}.gz"))
seen = set()
for candidate in candidates:
if candidate in seen:
continue
seen.add(candidate)
if not candidate.exists() or not candidate.is_file():
continue
open_fn = gzip.open if candidate.suffix == ".gz" else open
try:
with open_fn(candidate, "rt", encoding="utf-8") as handle:
for line in handle:
if line.strip():
count += 1
except (OSError, UnicodeError):
continue
return count
@staticmethod
def rotate_and_gzip_if_size_reached(path:Path, max_bytes:int) -> bool:
if max_bytes <= 0:
return False
if not path.exists() or not path.is_file():
return False
if path.stat().st_size < max_bytes:
return False
timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
if path.suffix == ".jsonl":
rotated_name = f"{path.stem}.{timestamp}.jsonl"
else:
rotated_name = f"{path.name}.{timestamp}"
rotated_path = path.with_name(rotated_name)
suffix = 1
while rotated_path.exists():
suffix += 1
if path.suffix == ".jsonl":
rotated_name = f"{path.stem}.{timestamp}.{suffix}.jsonl"
else:
rotated_name = f"{path.name}.{timestamp}.{suffix}"
rotated_path = path.with_name(rotated_name)
path.rename(rotated_path)
with rotated_path.open("rb") as src:
with gzip.open(f"{rotated_path}.gz", "wb") as dst:
dst.writelines(src)
rotated_path.unlink()
return True
def refresh_state(self):
if not self.enabled:
self.needs_more_data = False
return
base_rows = self.count_jsonl_rows(self.base_dataset_path)
self.needs_more_data = base_rows < self.min_base_rows
def record_sample(self, game_data:Any, move:str, safe_moves:dict[str, dict[str, int]], reason:str, scores:dict[str, float]|None=None,):
if not self.enabled or not self.needs_more_data:
return
try:
self.output_path.parent.mkdir(parents=True, exist_ok=True)
row = {
"source": "best_battlesnake_bootstrap",
"game_id": getattr(game_data, "id", None),
"turn": game_data.get_turn(),
"move": move,
"safe_moves": list(safe_moves.keys()),
"reason": reason,
"game_board": game_data.get_game_board_as_dict(),
}
if scores:
row["scores"] = {k: round(v, 5) for k, v in scores.items()}
with self.output_path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
self.rotate_and_gzip_if_size_reached(self.output_path, self.max_bytes)
except Exception:
return
+1
View File
@@ -2,3 +2,4 @@ from .Dataset import Dataset
from .DatasetExporter import DatasetExporter from .DatasetExporter import DatasetExporter
from .DatasetCurator import DatasetCurator from .DatasetCurator import DatasetCurator
from .DatasetStats import DatasetStats from .DatasetStats import DatasetStats
from .RLBootstrapDataset import RLBootstrapDataset
+11 -60
View File
@@ -2,14 +2,15 @@ from collections.abc import Iterator
from collections import deque from collections import deque
from typing import Any, cast from typing import Any, cast
from time import perf_counter from time import perf_counter
from pathlib import Path import random, os
import random, json, os
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
from snakes.TemplateSnake import TemplateSnake from snakes.TemplateSnake import TemplateSnake
from server.GameBoard import GameBoard from server.GameBoard import GameBoard
class BestBattleSnake(TemplateSnake): class BestBattleSnake(TemplateSnake):
VERSION = "2.6.0" VERSION = "2.6.1"
Point = tuple[int, int] Point = tuple[int, int]
Coord = dict[str, int] Coord = dict[str, int]
SnakeState = dict[str, Any] SnakeState = dict[str, Any]
@@ -40,11 +41,7 @@ class BestBattleSnake(TemplateSnake):
self.previous_hazards = set() self.previous_hazards = set()
self.duel_style = self._get_duel_style() self.duel_style = self._get_duel_style()
self.timeout_buffer_ms = self._get_timeout_buffer_ms() self.timeout_buffer_ms = self._get_timeout_buffer_ms()
self.rl_bootstrap_enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False) self.rl_bootstrap = RLBootstrapDataset()
self.rl_min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000)
self.rl_base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl"))
self.rl_bootstrap_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl"))
self.rl_needs_more_data = False
self.future_planning_depth = max(1, min(4, self._env_int("BATTLE_FUTURE_PLANNING_DEPTH", default=2))) self.future_planning_depth = max(1, min(4, self._env_int("BATTLE_FUTURE_PLANNING_DEPTH", default=2)))
self.future_planning_branch = max(1, min(3, self._env_int("BATTLE_FUTURE_PLANNING_BRANCH", default=2))) self.future_planning_branch = max(1, min(3, self._env_int("BATTLE_FUTURE_PLANNING_BRANCH", default=2)))
self.future_planning_min_time_ms = max(25, self._env_int("BATTLE_FUTURE_PLANNING_MIN_MS", default=70)) self.future_planning_min_time_ms = max(25, self._env_int("BATTLE_FUTURE_PLANNING_MIN_MS", default=70))
@@ -88,12 +85,6 @@ class BestBattleSnake(TemplateSnake):
except ValueError: except ValueError:
return 120 return 120
def _env_bool(self, name:str, default:bool=False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.lower() in {"1", "true", "yes", "on"}
def _env_int(self, name:str, default:int) -> int: def _env_int(self, name:str, default:int) -> int:
value = os.getenv(name) value = os.getenv(name)
if value is None: if value is None:
@@ -103,46 +94,6 @@ class BestBattleSnake(TemplateSnake):
except ValueError: except ValueError:
return default return default
def _count_jsonl_rows(self, path:Path) -> int:
if not path.exists() or not path.is_file():
return 0
count = 0
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if line.strip():
count += 1
return count
def _refresh_rl_bootstrap_state(self):
if not self.rl_bootstrap_enabled:
self.rl_needs_more_data = False
return
base_rows = self._count_jsonl_rows(self.rl_base_dataset_path)
self.rl_needs_more_data = base_rows < self.rl_min_base_rows
def _record_rl_bootstrap_sample(self, game_data:GameBoard, move:str, safe_moves:MoveMap, reason:str, scores:dict[str, float]|None=None):
if not self.rl_bootstrap_enabled or not self.rl_needs_more_data:
return
try:
self.rl_bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
row = {
"source": "best_battlesnake_bootstrap",
"game_id": getattr(game_data, "id", None),
"turn": game_data.get_turn(),
"move": move,
"safe_moves": list(safe_moves.keys()),
"reason": reason,
"game_board": game_data.get_game_board_as_dict(),
}
if scores:
row["scores"] = {k: round(v, 5) for k, v in scores.items()}
with self.rl_bootstrap_path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
except Exception:
return
def choose_move(self, game_data:GameBoard) -> str: def choose_move(self, game_data:GameBoard) -> str:
"""Pick the next move from a Battlesnake move request. """Pick the next move from a Battlesnake move request.
@@ -162,7 +113,7 @@ class BestBattleSnake(TemplateSnake):
self.last_move = None self.last_move = None
self.previous_hazards = set() self.previous_hazards = set()
self.last_game_id = game_id self.last_game_id = game_id
self._refresh_rl_bootstrap_state() self.rl_bootstrap.refresh_state()
my_snake = cast(dict[str, Any], game_data.get_my_snake()) my_snake = cast(dict[str, Any], game_data.get_my_snake())
my_head = my_snake["head"] my_head = my_snake["head"]
@@ -215,7 +166,7 @@ class BestBattleSnake(TemplateSnake):
"reason": "no_safe_moves", "reason": "no_safe_moves",
} }
) )
self._record_rl_bootstrap_sample(game_data, fallback, safe_moves, "no_safe_moves") self.rl_bootstrap.record_sample(game_data, fallback, safe_moves, "no_safe_moves")
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return fallback return fallback
@@ -246,7 +197,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point) self.recent_heads.append(current_head_point)
self.last_move = best_move self.last_move = best_move
self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "constrictor", scores) self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "constrictor", scores)
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return best_move return best_move
@@ -270,7 +221,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point) self.recent_heads.append(current_head_point)
self.last_move = best_move self.last_move = best_move
self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "duel", scores) self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "duel", scores)
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return best_move return best_move
@@ -418,7 +369,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point) self.recent_heads.append(current_head_point)
self.last_move = quick_move self.last_move = quick_move
self.add_to_history({"turn": turn, "move": quick_move, "reason": "timeout_budget"}) self.add_to_history({"turn": turn, "move": quick_move, "reason": "timeout_budget"})
self._record_rl_bootstrap_sample(game_data, quick_move, safe_moves, "timeout_budget") self.rl_bootstrap.record_sample(game_data, quick_move, safe_moves, "timeout_budget")
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return quick_move return quick_move
@@ -462,7 +413,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point) self.recent_heads.append(current_head_point)
self.last_move = best_move self.last_move = best_move
self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "multi", scores) self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "multi", scores)
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return best_move return best_move
+101
View File
@@ -0,0 +1,101 @@
import unittest
from unittest.mock import patch
from pathlib import Path
import tempfile, gzip
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
class TestRLBootstrapDataset(unittest.TestCase):
    """Tests for RLBootstrapDataset: row counting, rotation, env config, recording."""

    def test_count_jsonl_rows_reads_gzip_dataset(self):
        """A gzip dataset is counted via either path spelling; blank lines are ignored."""
        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            gz_file = root / "base.jsonl.gz"
            with gzip.open(gz_file, "wt", encoding="utf-8") as sink:
                sink.writelines(['{"turn":1}\n', "\n", '{"turn":2}\n'])

            for queried in (gz_file, root / "base.jsonl"):
                self.assertEqual(RLBootstrapDataset.count_jsonl_rows(queried), 2)

    def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self):
        """An oversized .jsonl file is replaced by a timestamped .jsonl.gz archive."""
        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            live_file = root / "rl_bootstrap.jsonl"
            live_file.write_text("x" * 200, encoding="utf-8")

            did_rotate = RLBootstrapDataset.rotate_and_gzip_if_size_reached(live_file, max_bytes=50)

            self.assertTrue(did_rotate)
            self.assertFalse(live_file.exists())
            archives = list(root.glob("rl_bootstrap.*.jsonl.gz"))
            self.assertGreaterEqual(len(archives), 1)

    def test_rl_bootstrap_dataset_reads_env_and_gzip_base(self):
        """With one base row and a threshold of two, more data is needed."""
        with tempfile.TemporaryDirectory() as workdir:
            gz_base = Path(workdir) / "base.jsonl.gz"
            with gzip.open(gz_base, "wt", encoding="utf-8") as sink:
                sink.write('{"turn":1}\n')

            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                "RL_MIN_BASE_ROWS": "2",
                "RL_BASE_DATASET": str(gz_base),
            }
            with patch.dict("os.environ", env_overrides, clear=False):
                dataset = RLBootstrapDataset()
                dataset.refresh_state()
                self.assertTrue(dataset.needs_more_data)

    def test_rl_bootstrap_dataset_autocompresses_output(self):
        """Recording with a tiny size limit leaves at least one gzip archive behind."""

        class StubBoard:
            # Minimal stand-in exposing only what record_sample reads.
            id = "game-1"

            def get_turn(self):
                return 1

            def get_game_board_as_dict(self):
                board = {
                    "width": 11,
                    "height": 11,
                    "food": [],
                    "hazards": [],
                    "snakes": [],
                }
                you = {
                    "id": "me",
                    "head": {"x": 5, "y": 5},
                    "body": [{"x": 5, "y": 5}],
                }
                return {"turn": 1, "board": board, "you": you}

        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                "RL_BOOTSTRAP_MAX_MB": "0.00001",
                "RL_BOOTSTRAP_OUTPUT": str(root / "rl_bootstrap.jsonl"),
            }
            with patch.dict("os.environ", env_overrides, clear=False):
                dataset = RLBootstrapDataset()
                # Set directly so no base dataset is needed for this test.
                dataset.needs_more_data = True
                dataset.record_sample(StubBoard(), "up", {"up": {"x": 5, "y": 6}}, "test")
                dataset.record_sample(StubBoard(), "left", {"left": {"x": 4, "y": 5}}, "test")

            archives = list(root.glob("rl_bootstrap.*.jsonl.gz"))
            self.assertGreaterEqual(len(archives), 1)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()