Compare commits

...

8 Commits

31 changed files with 2105 additions and 350 deletions
+1 -1
View File
@@ -12,5 +12,5 @@ data/
dbschema/migrations/
*.jsonl
dataset/
/dataset/
models/
+43 -10
View File
@@ -1,15 +1,17 @@
from quart_common.web.logger import build_logger, await_log
from server.Files import read_file
from server.game_board_stats import GameBoardStoreBuilder
from server.game_state_store import GameStateStoreBuilder
from server.GameBoard import GameBoard
from snakes import SnakeBuilder
from quart_common.web.logger import await_log
from quart_common.web.logger import build_logger
from server.metrics.MetricsManager import MetricsManager
from server.metrics.ServerMetricsCollector import ServerMetricsCollector
from server.storage.StorageLoader import StorageLoader
from server.storage import StorageLoader
from server.metrics import (
MetricsStoreBuilder,
MetricsCollector,
)
from quart import Quart, request, jsonify
import logging, json, os, re, time
@@ -25,7 +27,7 @@ class Server:
'version': '1.0.0',
}
def __init__(self, data_path:str, snake_type:str, storage_type:str, debug:bool=False, check_tls_security:bool=False, game_state_backend:str='memory', game_state_redis_url:str='redis://localhost:6379/0', game_state_ttl_sec:int=900, game_state_local_cache:bool=True, metrics_backend:str='memory', metrics_redis_url:str='redis://localhost:6379/0', metrics_ttl_sec:int=None):
def __init__(self, data_path:str, snake_type:str, storage_type:str, debug:bool=False, check_tls_security:bool=False, game_state_backend:str='memory', game_state_redis_url:str='redis://localhost:6379/0', game_state_ttl_sec:int=900, game_state_local_cache:bool=True, metrics_backend:str='memory', metrics_redis_url:str='redis://localhost:6379/0', metrics_ttl_sec:int|None=None):
self.debug = debug
self.snake_type = snake_type
self.storage_type = storage_type
@@ -37,18 +39,20 @@ class Server:
self.store_game_state = False
normalized_backend = (game_state_backend or 'memory').strip().lower()
self.game_state_local_cache = (game_state_local_cache and normalized_backend != 'memory')
self.game_state_store = GameBoardStoreBuilder.build(
self.game_state_store = GameStateStoreBuilder.build(
backend=game_state_backend,
redis_url=game_state_redis_url,
ttl_seconds=game_state_ttl_sec,
)
metrics_backend_normalized = (metrics_backend or 'memory').strip().lower()
self.stale_game_timeout_sec = self._get_stale_game_timeout_sec()
self.running_games:dict[str, GameBoard] = {}
self.game_move_counts:dict[str, int] = {}
self.game_last_seen_unix:dict[str, int] = {}
self.metrics_collector = ServerMetricsCollector(
metrics_manager=MetricsManager(
self.metrics_collector = MetricsCollector(
metrics_manager=MetricsStoreBuilder.build(
backend=metrics_backend_normalized,
redis_url=metrics_redis_url,
ttl_seconds=metrics_ttl_sec,
@@ -61,6 +65,10 @@ class Server:
game_last_seen_unix=self.game_last_seen_unix,
game_move_counts=self.game_move_counts,
)
self.clear_worker_metrics_on_startup = self._env_bool('METRICS_CLEAR_WORKERS_ON_STARTUP', True)
self.worker_metrics_startup_lock_ttl_sec = self._env_int('METRICS_STARTUP_CLEANUP_LOCK_TTL_SEC', 300)
self._startup_worker_metrics_cleared = False
self.logger = build_logger('Battlesnake', debug_env_var='DEBUG_SERVER')
self.snake_version = self._get_snake_version()
@@ -136,6 +144,16 @@ class Server:
response.headers.set('server', 'battlesnake/gitea/snake-python')
return response
@self.app.before_serving
async def clear_startup_worker_metrics_once():
if self._startup_worker_metrics_cleared:
return
self._startup_worker_metrics_cleared = True
if self.clear_worker_metrics_on_startup:
should_clear = await self.metrics_collector.should_clear_worker_metrics_on_startup(self.worker_metrics_startup_lock_ttl_sec)
if should_clear:
await self.metrics_collector.clear_worker_metrics()
@self.app.after_serving
async def shutdown_state_storage():
await self.game_state_store.close()
@@ -210,6 +228,21 @@ class Server:
except ValueError:
return 180
def _env_bool(self, name:str, default:bool=False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.strip().lower() in {'1', 'true', 'yes', 'on'}
def _env_int(self, name: str, default: int) -> int:
value = os.getenv(name)
if value is None:
return default
try:
return int(value)
except ValueError:
return default
async def _create_game_board(self, game_state:dict) -> GameBoard:
game_id = game_state['game']['id']
new_game_board = GameBoard(
@@ -3,17 +3,17 @@ from server.GameBoard import GameBoard
class Dataset:
VALID_MOVES = {"up", "down", "left", "right"}
def __init__(self, game_board: GameBoard):
def __init__(self, game_board:GameBoard):
self.game_board = game_board
def _did_we_win(self):
winners = self.game_board.winner_snake_names or []
return "me" in winners
def _is_good_move(self, move: str):
def _is_good_move(self, move:str):
return move in self.VALID_MOVES
def build(self, only_good_moves: bool = True):
def build(self, only_good_moves:bool=True):
game_type = self.game_board.get_type_of_game()
did_win = self._did_we_win()
@@ -1,23 +1,10 @@
import argparse
import glob
import hashlib
import json
import shutil
import argparse, hashlib, shutil, json
from pathlib import Path
from server.dataset.DatasetIO import DatasetIO
class DatasetCurator:
def __init__(
self,
input_files: list[str],
output_file: str,
min_turn: int = 6,
late_turn: int = 20,
max_safe_options: int = 2,
min_score: int = 3,
append: bool = False,
archive_input: bool = False,
archive_dir: str | None = None,
):
def __init__(self, input_files:list[str], output_file:str, min_turn:int=6, late_turn:int=20, max_safe_options:int=2, min_score:int=3, append:bool=False, archive_input:bool=False, archive_dir:str|None=None):
self.input_files = input_files
self.output_file = Path(output_file)
self.min_turn = min_turn
@@ -26,82 +13,38 @@ class DatasetCurator:
self.min_score = min_score
self.append = append
self.archive_input = archive_input
self.archive_dir = (
Path(archive_dir) if archive_dir else self.output_file.parent / "archive"
)
self.archive_dir = Path(archive_dir) if archive_dir else self.output_file.parent / "archive"
def _resolve_input_files(self):
resolved = []
seen = set()
for item in self.input_files:
path = Path(item)
if path.is_dir():
for file_path in sorted(path.rglob("*.jsonl")):
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if any(ch in item for ch in "*?[]"):
for match in sorted(glob.glob(item)):
file_path = Path(match)
if not file_path.is_file():
continue
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if path.is_file():
key = str(path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(path)
return resolved
def _safe_options_count(self, row: dict):
def _safe_options_count(self, row:dict):
history = row.get("history", {})
for item in history.get("data", []):
if item.get("function") == "get_possible_moves":
return len(item.get("safe_positions", {}))
return None
def _state_hash(self, row: dict):
def _state_hash(self, row:dict):
board = row.get("game_board", {})
snakes = board.get("snakes", [])
snakes_key = []
for snake in snakes:
snakes_key.append(
(
snake.get("id"),
snake.get("health"),
tuple(
(seg.get("x"), seg.get("y")) for seg in snake.get("body", [])
),
)
)
snakes_key.append((
snake.get("id"),
snake.get("health"),
tuple((seg.get("x"), seg.get("y")) for seg in snake.get("body", [])),
))
key = {
"width": board.get("width"),
"height": board.get("height"),
"snakes": sorted(snakes_key),
"food": sorted((f.get("x"), f.get("y")) for f in board.get("food", [])),
"hazards": sorted(
(h.get("x"), h.get("y")) for h in board.get("hazards", [])
),
"hazards": sorted((h.get("x"), h.get("y")) for h in board.get("hazards", [])),
}
raw = json.dumps(key, sort_keys=True, separators=(",", ":"))
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
def _score(self, row: dict):
def _score(self, row:dict):
score = 0
turn = int(row.get("turn", 0))
safe_options = self._safe_options_count(row)
@@ -119,7 +62,7 @@ class DatasetCurator:
def curate(self):
self.output_file.parent.mkdir(parents=True, exist_ok=True)
input_paths = self._resolve_input_files()
input_paths = DatasetIO.resolve_input_files(self.input_files)
total = 0
kept = 0
@@ -129,56 +72,47 @@ class DatasetCurator:
seen_states = set()
if self.append and self.output_file.exists():
with self.output_file.open("r", encoding="utf-8") as existing:
for line in existing:
if not line.strip():
continue
row = json.loads(line)
state_key = self._state_hash(row)
seen_states.add((state_key, row.get("move")))
for row in DatasetIO.iter_jsonl_rows(self.output_file):
state_key = self._state_hash(row)
seen_states.add((state_key, row.get("move")))
mode = "a" if self.append else "w"
with self.output_file.open(mode, encoding="utf-8") as dst:
with DatasetIO.open_text(self.output_file, mode) as dst:
for input_path in input_paths:
with input_path.open("r", encoding="utf-8") as src:
for line in src:
if not line.strip():
continue
for row in DatasetIO.iter_jsonl_rows(input_path):
total += 1
total += 1
row = json.loads(line)
if not row.get("is_good_move", False):
skipped_quality += 1
continue
if not row.get("is_good_move", False):
skipped_quality += 1
continue
if int(row.get("turn", 0)) < self.min_turn:
skipped_turn += 1
continue
if int(row.get("turn", 0)) < self.min_turn:
skipped_turn += 1
continue
quality_score, safe_options = self._score(row)
if quality_score < self.min_score:
skipped_quality += 1
continue
quality_score, safe_options = self._score(row)
if quality_score < self.min_score:
skipped_quality += 1
continue
state_key = self._state_hash(row)
dedupe_key = (state_key, row.get("move"))
if dedupe_key in seen_states:
skipped_duplicate += 1
continue
seen_states.add(dedupe_key)
state_key = self._state_hash(row)
dedupe_key = (state_key, row.get("move"))
if dedupe_key in seen_states:
skipped_duplicate += 1
continue
seen_states.add(dedupe_key)
compact_row = {
"game_id": row.get("game_id"),
"turn": row.get("turn"),
"move": row.get("move"),
"game_type": row.get("game_type"),
"quality_score": quality_score,
"safe_options": safe_options,
"game_board": row.get("game_board"),
}
dst.write(json.dumps(compact_row, ensure_ascii=False) + "\n")
kept += 1
compact_row = {
"game_id": row.get("game_id"),
"turn": row.get("turn"),
"move": row.get("move"),
"game_type": row.get("game_type"),
"quality_score": quality_score,
"safe_options": safe_options,
"game_board": row.get("game_board"),
}
dst.write(json.dumps(compact_row, ensure_ascii=False) + "\n")
kept += 1
archived_files = []
if self.archive_input:
@@ -197,7 +131,7 @@ class DatasetCurator:
"output_file": str(self.output_file),
}
def _archive_processed_files(self, input_paths: list[Path]):
def _archive_processed_files(self, input_paths:list[Path]):
self.archive_dir.mkdir(parents=True, exist_ok=True)
archived = []
@@ -235,14 +169,13 @@ class DatasetCurator:
return archived
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Create curated best-moves dataset")
parser.add_argument(
"--input",
action="append",
required=True,
help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.",
help="Input JSONL/JSONL.GZ file, directory, or glob pattern. Repeat for multiple inputs.",
)
parser.add_argument("--output", required=True, help="Output JSONL file")
parser.add_argument("--min-turn", type=int, default=6)
@@ -1,7 +1,7 @@
import argparse
import json
from pathlib import Path
import argparse, json
from server.dataset.DatasetIO import DatasetIO
class DatasetExporter:
def __init__(self, input_dir:str, output_file:str):
@@ -9,9 +9,7 @@ class DatasetExporter:
self.output_file = Path(output_file)
def _iter_game_files(self):
if not self.input_dir.exists():
return []
return sorted(self.input_dir.rglob("*.json"))
return DatasetIO.list_directory_files(self.input_dir, directory_pattern="*.json")
def _extract_samples(self, payload:dict, source_file:Path):
dataset = payload.get("dataset", {})
@@ -39,7 +37,7 @@ class DatasetExporter:
self.output_file.parent.mkdir(parents=True, exist_ok=True)
sample_count = 0
with self.output_file.open("w", encoding="utf-8") as output:
with DatasetIO.open_text(self.output_file, "w") as output:
for game_file in game_files:
with game_file.open("r", encoding="utf-8") as source:
payload = json.load(source)
@@ -56,8 +54,12 @@ class DatasetExporter:
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Export Battlesnake dataset to JSONL")
parser.add_argument("--input", default="data", help="Input directory with stored game JSON files")
parser.add_argument("--output", default="data/dataset/good_moves.jsonl", help="Output JSONL file")
parser.add_argument(
"--input", default="data", help="Input directory with stored game JSON files"
)
parser.add_argument(
"--output", default="data/dataset/good_moves.jsonl", help="Output JSONL file"
)
args = parser.parse_args()
report = DatasetExporter(args.input, args.output).export_jsonl()
+134
View File
@@ -0,0 +1,134 @@
from typing import Any, TextIO, cast
from datetime import UTC, datetime
from pathlib import Path
import gzip, glob, json
class DatasetIO:
@staticmethod
def resolve_input_files(input_files: list[str], directory_pattern: str | tuple[str, ...] = ("*.jsonl", "*.jsonl.gz")) -> list[Path]:
patterns = (
(directory_pattern,)
if isinstance(directory_pattern, str)
else directory_pattern
)
resolved = []
seen = set()
for item in input_files:
path = Path(item)
if path.is_dir():
for pattern in patterns:
for file_path in sorted(path.rglob(pattern)):
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if any(ch in item for ch in "*?[]"):
for match in sorted(glob.glob(item)):
file_path = Path(match)
if not file_path.is_file():
continue
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if path.is_file():
key = str(path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(path)
return resolved
@staticmethod
def list_directory_files(input_dir:Path, directory_pattern:str) -> list[Path]:
if not input_dir.exists():
return []
return sorted(input_dir.rglob(directory_pattern))
@staticmethod
def open_text(path:Path, mode:str="r"):
text_mode = mode if "t" in mode else f"{mode}t"
open_fn = gzip.open if path.suffix == ".gz" else open
return cast(TextIO, open_fn(path, text_mode, encoding="utf-8"))
@staticmethod
def iter_jsonl_rows(path:Path):
with DatasetIO.open_text(path, "r") as handle:
for line in handle:
if line.strip():
yield json.loads(line)
@staticmethod
def append_jsonl_row(path:Path, row:dict[str, Any]):
with DatasetIO.open_text(path, "a") as raw_handle:
handle = cast(TextIO, raw_handle)
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
@staticmethod
def count_jsonl_rows(path:Path) -> int:
count = 0
candidates = [path]
if path.suffix == ".gz":
candidates.append(path.with_suffix(""))
else:
candidates.append(Path(f"{path}.gz"))
seen = set()
for candidate in candidates:
if candidate in seen:
continue
seen.add(candidate)
if not candidate.exists() or not candidate.is_file():
continue
try:
with DatasetIO.open_text(candidate, "r") as handle:
for line in handle:
if line.strip():
count += 1
except (OSError, UnicodeError):
continue
return count
@staticmethod
def rotate_and_gzip_if_size_reached(path:Path, max_bytes:int) -> bool:
if max_bytes <= 0:
return False
if not path.exists() or not path.is_file():
return False
if path.stat().st_size < max_bytes:
return False
timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
if path.suffix == ".jsonl":
rotated_name = f"{path.stem}.{timestamp}.jsonl"
else:
rotated_name = f"{path.name}.{timestamp}"
rotated_path = path.with_name(rotated_name)
suffix = 1
while rotated_path.exists():
suffix += 1
if path.suffix == ".jsonl":
rotated_name = f"{path.stem}.{timestamp}.{suffix}.jsonl"
else:
rotated_name = f"{path.name}.{timestamp}.{suffix}"
rotated_path = path.with_name(rotated_name)
path.rename(rotated_path)
with rotated_path.open("rb") as src:
with gzip.open(f"{rotated_path}.gz", "wb") as dst:
dst.writelines(src)
rotated_path.unlink()
return True
@@ -1,68 +1,30 @@
import argparse
import glob
import json
import re
from collections import Counter, defaultdict
import argparse, json, re
from datetime import datetime
from pathlib import Path
from server.dataset.DatasetIO import DatasetIO
class DatasetStats:
DAY_PATTERN = re.compile(r"(\d{4}-\d{2}-\d{2})")
def __init__(self, input_files: list[str]):
def __init__(self, input_files:list[str]):
self.input_files = input_files
def _resolve_input_files(self):
resolved = []
seen = set()
for item in self.input_files:
path = Path(item)
if path.is_dir():
for file_path in sorted(path.rglob("*.jsonl")):
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if any(ch in item for ch in "*?[]"):
for match in sorted(glob.glob(item)):
file_path = Path(match)
if not file_path.is_file():
continue
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if path.is_file():
key = str(path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(path)
return resolved
def _infer_day(self, file_path: Path):
def _infer_day(self, file_path:Path):
match = self.DAY_PATTERN.search(file_path.name)
if match:
return match.group(1)
return datetime.fromtimestamp(file_path.stat().st_mtime).strftime("%Y-%m-%d")
def _game_score(self, game: dict):
def _game_score(self, game:dict):
max_turn = game["max_turn"]
rows = game["rows"]
avg_safe = game["avg_safe_options"]
pressure_bonus = 0 if avg_safe is None else max(0.0, 4.0 - avg_safe)
return round(max_turn * 2.0 + rows + pressure_bonus, 3)
def _pressure_score(self, game: dict):
def _pressure_score(self, game:dict):
max_turn = game["max_turn"]
rows = max(1, game["rows"])
pressure_turns = game["pressure_turns"]
@@ -72,7 +34,7 @@ class DatasetStats:
safe_tightness = 0.0 if avg_safe is None else max(0.0, 3.0 - avg_safe)
return round(max_turn * 1.2 + pressure_ratio * 120.0 + safe_tightness * 20.0, 3)
def _extract_safe_options(self, row: dict):
def _extract_safe_options(self, row:dict):
top_level = row.get("safe_options")
if isinstance(top_level, int):
return top_level
@@ -87,7 +49,7 @@ class DatasetStats:
return None
def analyze(self):
files = self._resolve_input_files()
files = DatasetIO.resolve_input_files(self.input_files)
totals = {
"rows": 0,
@@ -103,57 +65,52 @@ class DatasetStats:
for file_path in files:
day = self._infer_day(file_path)
with file_path.open("r", encoding="utf-8") as source:
for line in source:
if not line.strip():
continue
for row in DatasetIO.iter_jsonl_rows(file_path):
game_id = row.get("game_id")
if not game_id:
continue
row = json.loads(line)
game_id = row.get("game_id")
if not game_id:
continue
turn = int(row.get("turn", 0))
safe_options = self._extract_safe_options(row)
snake_type = row.get("snake_type", "unknown")
move = row.get("move", "unknown")
turn = int(row.get("turn", 0))
safe_options = self._extract_safe_options(row)
snake_type = row.get("snake_type", "unknown")
move = row.get("move", "unknown")
game_type = row.get("game_type", {})
if isinstance(game_type, dict):
game_type_name = game_type.get("name", "unknown")
else:
game_type_name = str(game_type)
game_type = row.get("game_type", {})
if isinstance(game_type, dict):
game_type_name = game_type.get("name", "unknown")
else:
game_type_name = str(game_type)
totals["rows"] += 1
totals["games"].add(game_id)
totals["snake_types"][snake_type] += 1
totals["game_types"][game_type_name] += 1
totals["moves"][move] += 1
totals["days"][day] += 1
totals["rows"] += 1
totals["games"].add(game_id)
totals["snake_types"][snake_type] += 1
totals["game_types"][game_type_name] += 1
totals["moves"][move] += 1
totals["days"][day] += 1
if game_id not in games:
games[game_id] = {
"game_id": game_id,
"day": day,
"snake_type": snake_type,
"game_type": game_type_name,
"rows": 0,
"max_turn": -1,
"safe_options_sum": 0,
"safe_options_count": 0,
"pressure_turns": 0,
}
if game_id not in games:
games[game_id] = {
"game_id": game_id,
"day": day,
"snake_type": snake_type,
"game_type": game_type_name,
"rows": 0,
"max_turn": -1,
"safe_options_sum": 0,
"safe_options_count": 0,
"pressure_turns": 0,
}
game = games[game_id]
game["rows"] += 1
game["max_turn"] = max(game["max_turn"], turn)
if isinstance(safe_options, int):
game["safe_options_sum"] += safe_options
game["safe_options_count"] += 1
if safe_options <= 2:
game["pressure_turns"] += 1
game = games[game_id]
game["rows"] += 1
game["max_turn"] = max(game["max_turn"], turn)
if isinstance(safe_options, int):
game["safe_options_sum"] += safe_options
game["safe_options_count"] += 1
if safe_options <= 2:
game["pressure_turns"] += 1
day_games[day].add(game_id)
day_games[day].add(game_id)
game_summaries = []
for game in games.values():
@@ -229,7 +186,7 @@ if __name__ == "__main__":
"--input",
action="append",
required=True,
help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.",
help="Input JSONL/JSONL.GZ file, directory, or glob pattern. Repeat for multiple inputs.",
)
parser.add_argument(
"--output",
+62
View File
@@ -0,0 +1,62 @@
from pathlib import Path
from typing import Any
import os
from server.dataset.DatasetIO import DatasetIO
class RLBootstrapDataset:
def __init__(self):
self.enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False)
self.min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000)
self.base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl"))
self.output_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl"))
self.max_bytes = int(float(os.getenv("RL_BOOTSTRAP_MAX_MB", "50")) * 1024 * 1024)
self.needs_more_data = False
@staticmethod
def _env_bool(name:str, default:bool=False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.lower() in {"1", "true", "yes", "on"}
@staticmethod
def _env_int(name:str, default:int) -> int:
value = os.getenv(name)
if value is None:
return default
try:
return int(value)
except ValueError:
return default
def refresh_state(self):
if not self.enabled:
self.needs_more_data = False
return
base_rows = DatasetIO.count_jsonl_rows(self.base_dataset_path)
self.needs_more_data = base_rows < self.min_base_rows
def record_sample(self, game_data:Any, move:str, safe_moves:dict[str, dict[str, int]], reason:str, scores:dict[str, float]|None=None):
if not self.enabled or not self.needs_more_data:
return
try:
self.output_path.parent.mkdir(parents=True, exist_ok=True)
row = {
"source": "best_battlesnake_bootstrap",
"game_id": getattr(game_data, "id", None),
"turn": game_data.get_turn(),
"move": move,
"safe_moves": list(safe_moves.keys()),
"reason": reason,
"game_board": game_data.get_game_board_as_dict(),
}
if scores:
row["scores"] = {k: round(v, 5) for k, v in scores.items()}
DatasetIO.append_jsonl_row(self.output_path, row)
DatasetIO.rotate_and_gzip_if_size_reached(self.output_path, self.max_bytes)
except Exception:
return
+6
View File
@@ -0,0 +1,6 @@
from .Dataset import Dataset
from .DatasetIO import DatasetIO
from .DatasetExporter import DatasetExporter
from .DatasetCurator import DatasetCurator
from .DatasetStats import DatasetStats
from .RLBootstrapDataset import RLBootstrapDataset
@@ -1,6 +1,5 @@
from server.GameBoard import GameBoard
import inspect
import pickle
import inspect, pickle
class RedisGameBoardStore:
def __init__(self, redis_url:str="redis://localhost:6379/0", key_prefix:str="snake:gameboard", ttl_seconds:int=900, **kwargs):
@@ -1,7 +1,7 @@
from server.game_board_stats.MemoryGameBoardStore import MemoryGameBoardStore
from server.game_board_stats.RedisGameBoardStore import RedisGameBoardStore
from .MemoryGameBoardStore import MemoryGameBoardStore
from .RedisGameBoardStore import RedisGameBoardStore
class GameBoardStoreBuilder:
class GameStateStoreBuilder:
@classmethod
def build(self, backend:str="memory", **kwargs) -> MemoryGameBoardStore|RedisGameBoardStore:
selected = (backend or "memory").strip().lower()
@@ -1,9 +1,9 @@
from server.metrics.backends.Template import StoreTemplate
import time
from server.metrics.MetricsManager import MetricsManager
class ServerMetricsCollector:
def __init__(self, metrics_manager:MetricsManager, game_state_local_cache:bool, metrics_backend:str, game_state_backend:str, stale_game_timeout_sec:int, game_last_seen_unix:dict, game_move_counts:dict,):
class MetricsCollector:
def __init__(self, metrics_manager:StoreTemplate, game_state_local_cache:bool, metrics_backend:str, game_state_backend:str, stale_game_timeout_sec:int, game_last_seen_unix:dict, game_move_counts:dict):
self._manager = metrics_manager
self._stale_game_timeout_sec = stale_game_timeout_sec
self._game_last_seen_unix = game_last_seen_unix
@@ -167,6 +167,12 @@ class ServerMetricsCollector:
local_snapshot = self.build_local_snapshot(game_last_seen_unix, game_move_counts)
return await self._manager.snapshot(local_snapshot)
async def clear_worker_metrics(self) -> None:
await self._manager.clear_all_workers()
async def should_clear_worker_metrics_on_startup(self, lock_ttl_seconds:int=300) -> bool:
return await self._manager.acquire_startup_cleanup_lock(lock_ttl_seconds)
def build_prometheus_metrics(self, snapshot:dict) -> str:
lines = [
'# HELP snake_games_started_total Total games started by snake server.',
+4 -3
View File
@@ -1,9 +1,10 @@
from server.metrics.MemoryMetricsStore import MemoryMetricsStore
from server.metrics.RedisMetricsStore import RedisMetricsStore
from .backends import StoreTemplate, MemoryMetricsStore, RedisMetricsStore
from .MetricsCollector import MetricsCollector
class MetricsStoreBuilder:
@classmethod
def build(self, backend:str="memory", **kwargs) -> MemoryMetricsStore|RedisMetricsStore:
def build(self, backend:str="memory", **kwargs) -> StoreTemplate:
selected = (backend or "memory").strip().lower()
if selected == "redis":
return RedisMetricsStore(**kwargs)
@@ -1,5 +1,8 @@
class MemoryMetricsStore:
from server.metrics.backends.Template import StoreTemplate
class MemoryMetricsStore(StoreTemplate):
def __init__(self, **kwargs):
super().__init__(backend="memory", **kwargs)
self._snapshots:dict[str, dict] = {}
async def publish(self, worker_id:str, snapshot:dict) -> None:
@@ -8,5 +11,11 @@ class MemoryMetricsStore:
async def load_all(self) -> list[dict]:
return [dict(value) for value in self._snapshots.values()]
async def clear_all(self) -> None:
self._snapshots.clear()
async def _acquire_startup_cleanup_lock(self, lock_key:str, ttl_seconds:int=300) -> bool:
return True
async def close(self) -> None:
return None
@@ -1,8 +1,10 @@
import inspect
import json
from server.metrics.backends.Template import StoreTemplate
class RedisMetricsStore:
def __init__(self, redis_url:str="redis://localhost:6379/0", key_prefix:str="snake:metrics:worker", ttl_seconds:int=None, **kwargs):
import inspect, json
class RedisMetricsStore(StoreTemplate):
def __init__(self, redis_url:str="redis://localhost:6379/0", key_prefix:str="snake:metrics:worker", ttl_seconds:int|None=None, **kwargs):
super().__init__(backend="redis", key_prefix=key_prefix, **kwargs)
self.redis_url = redis_url
self.key_prefix = key_prefix
self.ttl_seconds = ttl_seconds
@@ -41,6 +43,17 @@ class RedisMetricsStore:
continue
return snapshots
async def clear_all(self) -> None:
redis = await self._get_redis()
keys = await redis.keys(f"{self.key_prefix}:*")
if keys:
await redis.delete(*keys)
async def _acquire_startup_cleanup_lock(self, lock_key:str, ttl_seconds:int=300) -> bool:
redis = await self._get_redis()
locked = await redis.set(lock_key, '1', ex=max(1, int(ttl_seconds)), nx=True)
return bool(locked)
async def close(self) -> None:
if self._redis is None:
return
@@ -1,17 +1,18 @@
from server.metrics import MetricsStoreBuilder
from typing import Any, Awaitable, cast
import inspect, time, os
import time, os
class MetricsManager:
def __init__(self, backend:str="memory", redis_url:str="redis://localhost:6379/0", ttl_seconds:int=90, key_prefix:str="snake:metrics:worker", worker_id:str|None=None):
class StoreTemplate:
def __init__(self, backend:str="memory", key_prefix:str="snake:metrics:worker", worker_id:str|None=None, **kwargs):
self.backend = (backend or "memory").strip().lower()
self.key_prefix = key_prefix
self.worker_id = worker_id or f"{os.getpid()}-{int(time.time() * 1000)}"
self.store = MetricsStoreBuilder.build(
backend=self.backend,
redis_url=redis_url,
ttl_seconds=ttl_seconds,
key_prefix=key_prefix,
)
self.store = self
async def publish(self, worker_id:str, snapshot:dict) -> None:
raise NotImplementedError
async def load_all(self) -> list[dict]:
raise NotImplementedError
async def publish_only(self, snapshot:dict) -> None:
await self.store.publish(self.worker_id, snapshot)
@@ -28,7 +29,37 @@ class MetricsManager:
return self._merge_snapshots(snapshots)
async def close(self) -> None:
await self.store.close()
if self.store is self:
return
close_store = getattr(self.store, "close", None)
if not callable(close_store):
return
maybe_result = close_store()
if inspect.isawaitable(maybe_result):
await cast(Awaitable[Any], maybe_result)
async def clear_all_workers(self) -> None:
clear_all = getattr(self.store, "clear_all", None)
if callable(clear_all):
maybe_result = clear_all()
if inspect.isawaitable(maybe_result):
await cast(Awaitable[Any], maybe_result)
async def acquire_startup_cleanup_lock(self, ttl_seconds:int=300) -> bool:
if self.backend != "redis":
return True
acquire_lock = getattr(self.store, "_acquire_startup_cleanup_lock", None)
if not callable(acquire_lock):
acquire_lock = getattr(self.store, "acquire_startup_cleanup_lock", None)
if not callable(acquire_lock):
return True
lock_key = f"{self.key_prefix}:startup_cleanup_lock"
maybe_result = acquire_lock(lock_key, ttl_seconds)
if inspect.isawaitable(maybe_result):
return bool(await cast(Awaitable[Any], maybe_result))
return bool(maybe_result)
def _merge_snapshots(self, snapshots:list[dict]) -> dict:
merged = {
@@ -86,11 +117,11 @@ class MetricsManager:
merged["max_turn"] = max(merged["max_turn"], int(worker.get("max_turn", 0)))
merged["active_games_peak"] = max(merged["active_games_peak"], int(worker.get("active_games_peak", 0)))
merged["move_response_time_ms_max"] = max(merged["move_response_time_ms_max"], float(worker.get("move_response_time_ms_max", 0.0)))
merged["last_game_start_unix"] = max(merged["last_game_start_unix"], int(worker.get("last_game_start_unix", 0)),)
merged["last_game_start_unix"] = max(merged["last_game_start_unix"], int(worker.get("last_game_start_unix", 0)))
merged["last_game_end_unix"] = max(merged["last_game_end_unix"], int(worker.get("last_game_end_unix", 0)))
merged["last_move_unix"] = max(merged["last_move_unix"], int(worker.get("last_move_unix", 0)))
merged["oldest_active_game_age_sec"] = max(merged["oldest_active_game_age_sec"], int(worker.get("oldest_active_game_age_sec", 0)),)
merged["stale_game_timeout_sec"] = max(merged["stale_game_timeout_sec"], int(worker.get("stale_game_timeout_sec", 0)),)
merged["oldest_active_game_age_sec"] = max(merged["oldest_active_game_age_sec"], int(worker.get("oldest_active_game_age_sec", 0)))
merged["stale_game_timeout_sec"] = max(merged["stale_game_timeout_sec"], int(worker.get("stale_game_timeout_sec", 0)))
merged["game_state_local_cache_enabled"] = merged["game_state_local_cache_enabled"] or bool(worker.get("game_state_local_cache_enabled", False))
for endpoint in merged["http_requests_by_endpoint"]:
+3
View File
@@ -0,0 +1,3 @@
from .Template import StoreTemplate
from .Memory import MemoryMetricsStore
from .Redis import RedisMetricsStore
+3 -3
View File
@@ -1,5 +1,5 @@
from server.GameBoard import GameBoard
from server.Dataset import Dataset
from server.dataset.Dataset import Dataset
from datetime import datetime
import json, time
@@ -55,8 +55,8 @@ class EdgeDB:
calculations = snake_calulations[i] if i < len(snake_calulations) else []
calculations.append({
"dataset": {
"is_good_move": labels_by_turn.get(moves[i]["turn"], False)
}
"is_good_move": labels_by_turn.get(moves[i]["turn"], False)
}
})
data.append({
"turn": moves[i]["turn"],
+2 -3
View File
@@ -1,11 +1,10 @@
from server.dataset.Dataset import Dataset
from server.GameBoard import GameBoard
from server.Dataset import Dataset
from server.Files import save_file
import aiofiles
import aiofiles.os
import gzip
import json, os
import gzip, json, os
class LocalStorage:
def __init__(self, file_path:str, **kwargs):
-7
View File
@@ -1,7 +0,0 @@
class StorageLoader:
@classmethod
def build(self, selected_storage:str):
storage_module = __import__(f'server.storage.{selected_storage}', fromlist=[selected_storage])
storage_class = getattr(storage_module, selected_storage)
return storage_class
+6
View File
@@ -0,0 +1,6 @@
class StorageLoader:
@classmethod
def build(self, selected_storage: str):
storage_module = __import__(f"server.storage.{selected_storage}", fromlist=[selected_storage])
storage_class = getattr(storage_module, selected_storage)
return storage_class
+11 -60
View File
@@ -2,14 +2,15 @@ from collections.abc import Iterator
from collections import deque
from typing import Any, cast
from time import perf_counter
from pathlib import Path
import random, json, os
import random, os
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
from snakes.TemplateSnake import TemplateSnake
from server.GameBoard import GameBoard
class BestBattleSnake(TemplateSnake):
VERSION = "2.6.0"
VERSION = "2.6.1"
Point = tuple[int, int]
Coord = dict[str, int]
SnakeState = dict[str, Any]
@@ -40,11 +41,7 @@ class BestBattleSnake(TemplateSnake):
self.previous_hazards = set()
self.duel_style = self._get_duel_style()
self.timeout_buffer_ms = self._get_timeout_buffer_ms()
self.rl_bootstrap_enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False)
self.rl_min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000)
self.rl_base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl"))
self.rl_bootstrap_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl"))
self.rl_needs_more_data = False
self.rl_bootstrap = RLBootstrapDataset()
self.future_planning_depth = max(1, min(4, self._env_int("BATTLE_FUTURE_PLANNING_DEPTH", default=2)))
self.future_planning_branch = max(1, min(3, self._env_int("BATTLE_FUTURE_PLANNING_BRANCH", default=2)))
self.future_planning_min_time_ms = max(25, self._env_int("BATTLE_FUTURE_PLANNING_MIN_MS", default=70))
@@ -88,12 +85,6 @@ class BestBattleSnake(TemplateSnake):
except ValueError:
return 120
def _env_bool(self, name:str, default:bool=False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.lower() in {"1", "true", "yes", "on"}
def _env_int(self, name:str, default:int) -> int:
value = os.getenv(name)
if value is None:
@@ -103,46 +94,6 @@ class BestBattleSnake(TemplateSnake):
except ValueError:
return default
def _count_jsonl_rows(self, path:Path) -> int:
if not path.exists() or not path.is_file():
return 0
count = 0
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if line.strip():
count += 1
return count
def _refresh_rl_bootstrap_state(self):
if not self.rl_bootstrap_enabled:
self.rl_needs_more_data = False
return
base_rows = self._count_jsonl_rows(self.rl_base_dataset_path)
self.rl_needs_more_data = base_rows < self.rl_min_base_rows
def _record_rl_bootstrap_sample(self, game_data:GameBoard, move:str, safe_moves:MoveMap, reason:str, scores:dict[str, float]|None=None):
if not self.rl_bootstrap_enabled or not self.rl_needs_more_data:
return
try:
self.rl_bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
row = {
"source": "best_battlesnake_bootstrap",
"game_id": getattr(game_data, "id", None),
"turn": game_data.get_turn(),
"move": move,
"safe_moves": list(safe_moves.keys()),
"reason": reason,
"game_board": game_data.get_game_board_as_dict(),
}
if scores:
row["scores"] = {k: round(v, 5) for k, v in scores.items()}
with self.rl_bootstrap_path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
except Exception:
return
def choose_move(self, game_data:GameBoard) -> str:
"""Pick the next move from a Battlesnake move request.
@@ -162,7 +113,7 @@ class BestBattleSnake(TemplateSnake):
self.last_move = None
self.previous_hazards = set()
self.last_game_id = game_id
self._refresh_rl_bootstrap_state()
self.rl_bootstrap.refresh_state()
my_snake = cast(dict[str, Any], game_data.get_my_snake())
my_head = my_snake["head"]
@@ -215,7 +166,7 @@ class BestBattleSnake(TemplateSnake):
"reason": "no_safe_moves",
}
)
self._record_rl_bootstrap_sample(game_data, fallback, safe_moves, "no_safe_moves")
self.rl_bootstrap.record_sample(game_data, fallback, safe_moves, "no_safe_moves")
self.previous_hazards = set(hazard_set)
return fallback
@@ -246,7 +197,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point)
self.last_move = best_move
self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "constrictor", scores)
self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "constrictor", scores)
self.previous_hazards = set(hazard_set)
return best_move
@@ -270,7 +221,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point)
self.last_move = best_move
self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "duel", scores)
self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "duel", scores)
self.previous_hazards = set(hazard_set)
return best_move
@@ -418,7 +369,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point)
self.last_move = quick_move
self.add_to_history({"turn": turn, "move": quick_move, "reason": "timeout_budget"})
self._record_rl_bootstrap_sample(game_data, quick_move, safe_moves, "timeout_budget")
self.rl_bootstrap.record_sample(game_data, quick_move, safe_moves, "timeout_budget")
self.previous_hazards = set(hazard_set)
return quick_move
@@ -462,7 +413,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point)
self.last_move = best_move
self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "multi", scores)
self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "multi", scores)
self.previous_hazards = set(hazard_set)
return best_move
File diff suppressed because it is too large Load Diff
+1
View File
@@ -8,6 +8,7 @@ SNAKE_REGISTRY = {
"BetterMasterSnake": "1.3.0",
"BestBattleSnake": "2.6.0",
"TrainedBattleSnake": "0.1.0",
"UltimateBattleSnake": "4.1.0",
}
def build_snake(selected_snake: str):
+1 -1
View File
@@ -1,7 +1,7 @@
import unittest
from typing import cast
from server.Dataset import Dataset
from server.dataset.Dataset import Dataset
from server.GameBoard import GameBoard
class DummySnake:
+39 -3
View File
@@ -1,9 +1,9 @@
import json
import tempfile
import unittest
import tempfile, json, gzip
from pathlib import Path
from server.DatasetExporter import DatasetExporter
from server.dataset.DatasetExporter import DatasetExporter
class TestDatasetExporter(unittest.TestCase):
def test_export_jsonl(self):
@@ -43,5 +43,41 @@ class TestDatasetExporter(unittest.TestCase):
self.assertEqual(first["move"], "up")
self.assertTrue(first["is_good_move"])
def test_export_jsonl_gz(self):
with tempfile.TemporaryDirectory() as tmp:
input_dir = Path(tmp) / "data"
output_file = Path(tmp) / "out" / "dataset.jsonl.gz"
game_file = input_dir / "game-1.json"
game_file.parent.mkdir(parents=True, exist_ok=True)
game_payload = {
"dataset": {
"game": {"id": "g-1", "map": "standard", "type": {"name": "duel"}},
"snake": {"type": "BestBattleSnake"},
"samples": [
{
"turn": 1,
"move": "up",
"is_good_move": True,
"game_board": {"width": 11, "height": 11},
"history": {"data": []},
}
],
}
}
game_file.write_text(json.dumps(game_payload), encoding="utf-8")
report = DatasetExporter(str(input_dir), str(output_file)).export_jsonl()
self.assertEqual(report["games_scanned"], 1)
self.assertEqual(report["samples_exported"], 1)
self.assertTrue(output_file.exists())
with gzip.open(output_file, "rt", encoding="utf-8") as handle:
lines = handle.read().strip().splitlines()
self.assertEqual(len(lines), 1)
first = json.loads(lines[0])
self.assertEqual(first["game_id"], "g-1")
if __name__ == "__main__":
unittest.main()
+46
View File
@@ -0,0 +1,46 @@
import unittest
import tempfile, gzip, json
from pathlib import Path
from server.dataset.DatasetIO import DatasetIO
class TestDatasetIO(unittest.TestCase):
def test_resolve_input_files_includes_jsonl_and_gz(self):
with tempfile.TemporaryDirectory() as tmp:
base = Path(tmp)
(base / "a.jsonl").write_text('{"x":1}\n', encoding="utf-8")
with gzip.open(base / "b.jsonl.gz", "wt", encoding="utf-8") as handle:
handle.write('{"x":2}\n')
paths = DatasetIO.resolve_input_files([str(base)])
names = {path.name for path in paths}
self.assertIn("a.jsonl", names)
self.assertIn("b.jsonl.gz", names)
def test_iter_jsonl_rows_reads_gz_file(self):
with tempfile.TemporaryDirectory() as tmp:
path = Path(tmp) / "rows.jsonl.gz"
with gzip.open(path, "wt", encoding="utf-8") as handle:
handle.write('{"turn":1}\n')
handle.write("\n")
handle.write('{"turn":2}\n')
rows = list(DatasetIO.iter_jsonl_rows(path))
self.assertEqual(len(rows), 2)
self.assertEqual(rows[0]["turn"], 1)
self.assertEqual(rows[1]["turn"], 2)
def test_append_jsonl_row_supports_gz_output(self):
with tempfile.TemporaryDirectory() as tmp:
path = Path(tmp) / "append.jsonl.gz"
DatasetIO.append_jsonl_row(path, {"move": "up"})
DatasetIO.append_jsonl_row(path, {"move": "left"})
with gzip.open(path, "rt", encoding="utf-8") as handle:
lines = handle.read().strip().splitlines()
self.assertEqual(len(lines), 2)
self.assertEqual(json.loads(lines[0])["move"], "up")
if __name__ == "__main__":
unittest.main()
+4 -6
View File
@@ -2,9 +2,7 @@ import unittest
from typing import Any, cast
from server.GameBoard import GameBoard
from server.game_board_stats import GameBoardStoreBuilder
from server.game_board_stats.MemoryGameBoardStore import MemoryGameBoardStore
from server.game_board_stats.RedisGameBoardStore import RedisGameBoardStore
from server.game_state_store import GameStateStoreBuilder, MemoryGameBoardStore, RedisGameBoardStore
from snakes.TemplateSnake import TemplateSnake
class _FakeRedis:
@@ -73,9 +71,9 @@ class TestGameStateStore(unittest.IsolatedAsyncioTestCase):
return board
def test_builder_selects_store_backend(self):
memory_store = GameBoardStoreBuilder.build(backend="memory")
redis_store = GameBoardStoreBuilder.build(backend="redis")
default_store = GameBoardStoreBuilder.build(backend="unknown")
memory_store = GameStateStoreBuilder.build(backend="memory")
redis_store = GameStateStoreBuilder.build(backend="redis")
default_store = GameStateStoreBuilder.build(backend="unknown")
self.assertIsInstance(memory_store, MemoryGameBoardStore)
self.assertIsInstance(redis_store, RedisGameBoardStore)
@@ -1,10 +1,15 @@
import unittest
from server.metrics.MetricsManager import MetricsManager
from typing import Any, cast
class TestMetricsManager(unittest.IsolatedAsyncioTestCase):
from server.metrics import (
MetricsStoreBuilder,
MemoryMetricsStore,
)
class TestMetricsStoreTemplate(unittest.IsolatedAsyncioTestCase):
async def test_memory_backend_returns_local_snapshot(self):
manager = MetricsManager(backend="memory")
manager = MetricsStoreBuilder.build(backend="memory")
local = {
"games_started": 2,
"games_ended": 1,
@@ -48,7 +53,7 @@ class TestMetricsManager(unittest.IsolatedAsyncioTestCase):
await manager.close()
async def test_merge_snapshots_aggregates_totals(self):
manager = MetricsManager(backend="memory")
manager = MemoryMetricsStore()
merged = manager._merge_snapshots(
[
{
@@ -139,5 +144,27 @@ class TestMetricsManager(unittest.IsolatedAsyncioTestCase):
self.assertEqual(merged["metrics_backend"], "redis")
await manager.close()
async def test_acquire_startup_cleanup_lock_uses_store_for_redis_backend(self):
class FakeStore:
def __init__(self):
self.calls = []
async def acquire_startup_cleanup_lock(self, lock_key:str, ttl_seconds:int=300):
self.calls.append((lock_key, ttl_seconds))
return True
async def close(self):
return None
manager = MetricsStoreBuilder.build(backend="redis", key_prefix="snake:metrics:worker")
fake_store = FakeStore()
manager.store = cast(Any, fake_store)
allowed = await manager.acquire_startup_cleanup_lock(180)
self.assertTrue(allowed)
self.assertEqual(fake_store.calls, [("snake:metrics:worker:startup_cleanup_lock", 180)])
await manager.close()
if __name__ == "__main__":
unittest.main()
+100
View File
@@ -0,0 +1,100 @@
import unittest
from unittest.mock import patch
from pathlib import Path
import tempfile, gzip
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
from server.dataset.DatasetIO import DatasetIO
class TestRLBootstrapDataset(unittest.TestCase):
def test_count_jsonl_rows_reads_gzip_dataset(self):
with tempfile.TemporaryDirectory() as tmp:
dataset_path = Path(tmp) / "base.jsonl.gz"
with gzip.open(dataset_path, "wt", encoding="utf-8") as handle:
handle.write('{"turn":1}\n')
handle.write("\n")
handle.write('{"turn":2}\n')
self.assertEqual(DatasetIO.count_jsonl_rows(dataset_path), 2)
self.assertEqual(DatasetIO.count_jsonl_rows(Path(tmp) / "base.jsonl"), 2)
def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self):
with tempfile.TemporaryDirectory() as tmp:
path = Path(tmp) / "rl_bootstrap.jsonl"
path.write_text("x" * 200, encoding="utf-8")
rotated = DatasetIO.rotate_and_gzip_if_size_reached(path, max_bytes=50)
self.assertTrue(rotated)
self.assertFalse(path.exists())
self.assertGreaterEqual(len(list(Path(tmp).glob("rl_bootstrap.*.jsonl.gz"))), 1)
def test_rl_bootstrap_dataset_reads_env_and_gzip_base(self):
with tempfile.TemporaryDirectory() as tmp:
base_path = Path(tmp) / "base.jsonl.gz"
with gzip.open(base_path, "wt", encoding="utf-8") as handle:
handle.write('{"turn":1}\n')
with patch.dict(
"os.environ",
{
"RL_BOOTSTRAP_ENABLED": "1",
"RL_MIN_BASE_ROWS": "2",
"RL_BASE_DATASET": str(base_path),
},
clear=False,
):
dataset = RLBootstrapDataset()
dataset.refresh_state()
self.assertTrue(dataset.needs_more_data)
def test_rl_bootstrap_dataset_autocompresses_output(self):
class DummyBoard:
id = "game-1"
def get_turn(self):
return 1
def get_game_board_as_dict(self):
return {
"turn": 1,
"board": {
"width": 11,
"height": 11,
"food": [],
"hazards": [],
"snakes": [],
},
"you": {
"id": "me",
"head": {"x": 5, "y": 5},
"body": [{"x": 5, "y": 5}],
},
}
with tempfile.TemporaryDirectory() as tmp:
output_path = Path(tmp) / "rl_bootstrap.jsonl"
with patch.dict(
"os.environ",
{
"RL_BOOTSTRAP_ENABLED": "1",
"RL_BOOTSTRAP_MAX_MB": "0.00001",
"RL_BOOTSTRAP_OUTPUT": str(output_path),
},
clear=False,
):
dataset = RLBootstrapDataset()
dataset.needs_more_data = True
dataset.record_sample(
DummyBoard(), "up", {"up": {"x": 5, "y": 6}}, "test"
)
dataset.record_sample(
DummyBoard(), "left", {"left": {"x": 4, "y": 5}}, "test"
)
self.assertGreaterEqual(len(list(Path(tmp).glob("rl_bootstrap.*.jsonl.gz"))), 1)
if __name__ == "__main__":
unittest.main()