From 066a93f75563a08f2cbd7cf0144364f7f63d5b31 Mon Sep 17 00:00:00 2001 From: Daniel Dolezal Date: Sun, 5 Apr 2026 00:54:15 +0200 Subject: [PATCH] move RLBootstrapDataset into a own class with its own test file --- server/dataset/RLBootstrapDataset.py | 124 +++++++++++++++++++++++++++ server/dataset/__init__.py | 1 + snakes/BestBattleSnake.py | 71 +++------------ tests/test_RLBootstrapDataset.py | 101 ++++++++++++++++++++++ 4 files changed, 237 insertions(+), 60 deletions(-) create mode 100644 server/dataset/RLBootstrapDataset.py create mode 100644 tests/test_RLBootstrapDataset.py diff --git a/server/dataset/RLBootstrapDataset.py b/server/dataset/RLBootstrapDataset.py new file mode 100644 index 0000000..76a5e29 --- /dev/null +++ b/server/dataset/RLBootstrapDataset.py @@ -0,0 +1,124 @@ +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import gzip, json, os + +class RLBootstrapDataset: + def __init__(self): + self.enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False) + self.min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000) + self.base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl")) + self.output_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl")) + self.max_bytes = int(float(os.getenv("RL_BOOTSTRAP_MAX_MB", "50")) * 1024 * 1024) + self.needs_more_data = False + + @staticmethod + def _env_bool(name:str, default:bool=False) -> bool: + value = os.getenv(name) + if value is None: + return default + return value.lower() in {"1", "true", "yes", "on"} + + @staticmethod + def _env_int(name:str, default:int) -> int: + value = os.getenv(name) + if value is None: + return default + try: + return int(value) + except ValueError: + return default + + @staticmethod + def count_jsonl_rows(path:Path) -> int: + count = 0 + + candidates = [path] + if path.suffix == ".gz": + candidates.append(path.with_suffix("")) + else: + candidates.append(Path(f"{path}.gz")) + + seen = set() + for candidate in candidates: + if candidate in seen: + continue + seen.add(candidate) + + if not candidate.exists() or not candidate.is_file(): + continue + + open_fn = gzip.open if candidate.suffix == ".gz" else open + try: + with open_fn(candidate, "rt", encoding="utf-8") as handle: + for line in handle: + if line.strip(): + count += 1 + except (OSError, UnicodeError): + continue + + return count + + @staticmethod + def rotate_and_gzip_if_size_reached(path:Path, max_bytes:int) -> bool: + if max_bytes <= 0: + return False + if not path.exists() or not path.is_file(): + return False + if path.stat().st_size < max_bytes: + return False + + timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S") + if path.suffix == ".jsonl": + rotated_name = f"{path.stem}.{timestamp}.jsonl" + else: + rotated_name = f"{path.name}.{timestamp}" + + rotated_path = path.with_name(rotated_name) + suffix = 1 + while rotated_path.exists(): + suffix += 1 + if path.suffix == ".jsonl": + rotated_name = f"{path.stem}.{timestamp}.{suffix}.jsonl" + else: + rotated_name = f"{path.name}.{timestamp}.{suffix}" + rotated_path = path.with_name(rotated_name) + + path.rename(rotated_path) + with rotated_path.open("rb") as src: + with gzip.open(f"{rotated_path}.gz", "wb") as dst: + dst.writelines(src) + rotated_path.unlink() + return True + + def refresh_state(self): + if not self.enabled: + self.needs_more_data = False + return + base_rows = self.count_jsonl_rows(self.base_dataset_path) + self.needs_more_data = base_rows < self.min_base_rows + + def record_sample(self, game_data:Any, move:str, safe_moves:dict[str, dict[str, int]], reason:str, scores:dict[str, float]|None=None,): + if not self.enabled or not self.needs_more_data: + return + + try: + self.output_path.parent.mkdir(parents=True, exist_ok=True) + row = { + "source": "best_battlesnake_bootstrap", + "game_id": getattr(game_data, "id", None), + "turn": game_data.get_turn(), + "move": move, + "safe_moves": list(safe_moves.keys()), + "reason": reason, + "game_board": game_data.get_game_board_as_dict(), + } + if scores: + row["scores"] = {k: round(v, 5) for k, v in scores.items()} + + with self.output_path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(row, ensure_ascii=False) + "\n") + self.rotate_and_gzip_if_size_reached(self.output_path, self.max_bytes) + except Exception: + return diff --git a/server/dataset/__init__.py b/server/dataset/__init__.py index 56431c5..f34e69f 100644 --- a/server/dataset/__init__.py +++ b/server/dataset/__init__.py @@ -2,3 +2,4 @@ from .Dataset import Dataset from .DatasetExporter import DatasetExporter from .DatasetCurator import DatasetCurator from .DatasetStats import DatasetStats +from .RLBootstrapDataset import RLBootstrapDataset diff --git a/snakes/BestBattleSnake.py b/snakes/BestBattleSnake.py index 07a2704..83e391d 100644 --- a/snakes/BestBattleSnake.py +++ b/snakes/BestBattleSnake.py @@ -2,14 +2,15 @@ from collections.abc import Iterator from collections import deque from typing import Any, cast from time import perf_counter -from pathlib import Path -import random, json, os +import random, os + +from server.dataset.RLBootstrapDataset import RLBootstrapDataset from snakes.TemplateSnake import TemplateSnake from server.GameBoard import GameBoard class BestBattleSnake(TemplateSnake): - VERSION = "2.6.0" + VERSION = "2.6.1" Point = tuple[int, int] Coord = dict[str, int] SnakeState = dict[str, Any] @@ -40,11 +41,7 @@ class BestBattleSnake(TemplateSnake): self.previous_hazards = set() self.duel_style = self._get_duel_style() self.timeout_buffer_ms = self._get_timeout_buffer_ms() - self.rl_bootstrap_enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False) - self.rl_min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000) - self.rl_base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl")) - self.rl_bootstrap_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl")) - self.rl_needs_more_data = False + self.rl_bootstrap = RLBootstrapDataset() self.future_planning_depth = max(1, min(4, self._env_int("BATTLE_FUTURE_PLANNING_DEPTH", default=2))) self.future_planning_branch = max(1, min(3, self._env_int("BATTLE_FUTURE_PLANNING_BRANCH", default=2))) self.future_planning_min_time_ms = max(25, self._env_int("BATTLE_FUTURE_PLANNING_MIN_MS", default=70)) @@ -88,12 +85,6 @@ class BestBattleSnake(TemplateSnake): except ValueError: return 120 - def _env_bool(self, name:str, default:bool=False) -> bool: - value = os.getenv(name) - if value is None: - return default - return value.lower() in {"1", "true", "yes", "on"} - def _env_int(self, name:str, default:int) -> int: value = os.getenv(name) if value is None: @@ -103,46 +94,6 @@ class BestBattleSnake(TemplateSnake): except ValueError: return default - def _count_jsonl_rows(self, path:Path) -> int: - if not path.exists() or not path.is_file(): - return 0 - count = 0 - with path.open("r", encoding="utf-8") as handle: - for line in handle: - if line.strip(): - count += 1 - return count - - def _refresh_rl_bootstrap_state(self): - if not self.rl_bootstrap_enabled: - self.rl_needs_more_data = False - return - base_rows = self._count_jsonl_rows(self.rl_base_dataset_path) - self.rl_needs_more_data = base_rows < self.rl_min_base_rows - - def _record_rl_bootstrap_sample(self, game_data:GameBoard, move:str, safe_moves:MoveMap, reason:str, scores:dict[str, float]|None=None): - if not self.rl_bootstrap_enabled or not self.rl_needs_more_data: - return - - try: - self.rl_bootstrap_path.parent.mkdir(parents=True, exist_ok=True) - row = { - "source": "best_battlesnake_bootstrap", - "game_id": getattr(game_data, "id", None), - "turn": game_data.get_turn(), - "move": move, - "safe_moves": list(safe_moves.keys()), - "reason": reason, - "game_board": game_data.get_game_board_as_dict(), - } - if scores: - row["scores"] = {k: round(v, 5) for k, v in scores.items()} - - with self.rl_bootstrap_path.open("a", encoding="utf-8") as handle: - handle.write(json.dumps(row, ensure_ascii=False) + "\n") - except Exception: - return - def choose_move(self, game_data:GameBoard) -> str: """Pick the next move from a Battlesnake move request. @@ -162,7 +113,7 @@ class BestBattleSnake(TemplateSnake): self.last_move = None self.previous_hazards = set() self.last_game_id = game_id - self._refresh_rl_bootstrap_state() + self.rl_bootstrap.refresh_state() my_snake = cast(dict[str, Any], game_data.get_my_snake()) my_head = my_snake["head"] @@ -215,7 +166,7 @@ class BestBattleSnake(TemplateSnake): "reason": "no_safe_moves", } ) - self._record_rl_bootstrap_sample(game_data, fallback, safe_moves, "no_safe_moves") + self.rl_bootstrap.record_sample(game_data, fallback, safe_moves, "no_safe_moves") self.previous_hazards = set(hazard_set) return fallback @@ -246,7 +197,7 @@ class BestBattleSnake(TemplateSnake): self.recent_heads.append(current_head_point) self.last_move = best_move self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) - self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "constrictor", scores) + self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "constrictor", scores) self.previous_hazards = set(hazard_set) return best_move @@ -270,7 +221,7 @@ class BestBattleSnake(TemplateSnake): self.recent_heads.append(current_head_point) self.last_move = best_move self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) - self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "duel", scores) + self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "duel", scores) self.previous_hazards = set(hazard_set) return best_move @@ -418,7 +369,7 @@ class BestBattleSnake(TemplateSnake): self.recent_heads.append(current_head_point) self.last_move = quick_move self.add_to_history({"turn": turn, "move": quick_move, "reason": "timeout_budget"}) - self._record_rl_bootstrap_sample(game_data, quick_move, safe_moves, "timeout_budget") + self.rl_bootstrap.record_sample(game_data, quick_move, safe_moves, "timeout_budget") self.previous_hazards = set(hazard_set) return quick_move @@ -462,7 +413,7 @@ class BestBattleSnake(TemplateSnake): self.recent_heads.append(current_head_point) self.last_move = best_move self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) - self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "multi", scores) + self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "multi", scores) self.previous_hazards = set(hazard_set) return best_move diff --git a/tests/test_RLBootstrapDataset.py b/tests/test_RLBootstrapDataset.py new file mode 100644 index 0000000..32ab18c --- /dev/null +++ b/tests/test_RLBootstrapDataset.py @@ -0,0 +1,101 @@ +import unittest +from unittest.mock import patch + +from pathlib import Path +import tempfile, gzip + +from server.dataset.RLBootstrapDataset import RLBootstrapDataset + +class TestRLBootstrapDataset(unittest.TestCase): + def test_count_jsonl_rows_reads_gzip_dataset(self): + with tempfile.TemporaryDirectory() as tmp: + dataset_path = Path(tmp) / "base.jsonl.gz" + with gzip.open(dataset_path, "wt", encoding="utf-8") as handle: + handle.write('{"turn":1}\n') + handle.write("\n") + handle.write('{"turn":2}\n') + + self.assertEqual(RLBootstrapDataset.count_jsonl_rows(dataset_path), 2) + self.assertEqual(RLBootstrapDataset.count_jsonl_rows(Path(tmp) / "base.jsonl"), 2) + + def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "rl_bootstrap.jsonl" + path.write_text("x" * 200, encoding="utf-8") + + rotated = RLBootstrapDataset.rotate_and_gzip_if_size_reached( + path, max_bytes=50 + ) + + self.assertTrue(rotated) + self.assertFalse(path.exists()) + self.assertGreaterEqual(len(list(Path(tmp).glob("rl_bootstrap.*.jsonl.gz"))), 1) + + def test_rl_bootstrap_dataset_reads_env_and_gzip_base(self): + with tempfile.TemporaryDirectory() as tmp: + base_path = Path(tmp) / "base.jsonl.gz" + with gzip.open(base_path, "wt", encoding="utf-8") as handle: + handle.write('{"turn":1}\n') + + with patch.dict( + "os.environ", + { + "RL_BOOTSTRAP_ENABLED": "1", + "RL_MIN_BASE_ROWS": "2", + "RL_BASE_DATASET": str(base_path), + }, + clear=False, + ): + dataset = RLBootstrapDataset() + dataset.refresh_state() + + self.assertTrue(dataset.needs_more_data) + + def test_rl_bootstrap_dataset_autocompresses_output(self): + class DummyBoard: + id = "game-1" + + def get_turn(self): + return 1 + + def get_game_board_as_dict(self): + return { + "turn": 1, + "board": { + "width": 11, + "height": 11, + "food": [], + "hazards": [], + "snakes": [], + }, + "you": { + "id": "me", + "head": {"x": 5, "y": 5}, + "body": [{"x": 5, "y": 5}], + }, + } + + with tempfile.TemporaryDirectory() as tmp: + output_path = Path(tmp) / "rl_bootstrap.jsonl" + with patch.dict( + "os.environ", + { + "RL_BOOTSTRAP_ENABLED": "1", + "RL_BOOTSTRAP_MAX_MB": "0.00001", + "RL_BOOTSTRAP_OUTPUT": str(output_path), + }, + clear=False, + ): + dataset = RLBootstrapDataset() + dataset.needs_more_data = True + dataset.record_sample( + DummyBoard(), "up", {"up": {"x": 5, "y": 6}}, "test" + ) + dataset.record_sample( + DummyBoard(), "left", {"left": {"x": 4, "y": 5}}, "test" + ) + + self.assertGreaterEqual(len(list(Path(tmp).glob("rl_bootstrap.*.jsonl.gz"))), 1) + +if __name__ == "__main__": + unittest.main()