Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
643f4b468e
|
|||
|
332e86e3cc
|
|||
|
066a93f755
|
|||
|
eb290dd634
|
|||
|
2b8f0396e3
|
|||
|
043d7654f9
|
|||
|
a38a600bdc
|
|||
|
9d33c6fded
|
+1
-1
@@ -12,5 +12,5 @@ data/
|
||||
dbschema/migrations/
|
||||
|
||||
*.jsonl
|
||||
dataset/
|
||||
/dataset/
|
||||
models/
|
||||
|
||||
+43
-10
@@ -1,15 +1,17 @@
|
||||
from quart_common.web.logger import build_logger, await_log
|
||||
from server.Files import read_file
|
||||
|
||||
from server.game_board_stats import GameBoardStoreBuilder
|
||||
from server.game_state_store import GameStateStoreBuilder
|
||||
from server.GameBoard import GameBoard
|
||||
|
||||
from snakes import SnakeBuilder
|
||||
from quart_common.web.logger import await_log
|
||||
from quart_common.web.logger import build_logger
|
||||
|
||||
from server.metrics.MetricsManager import MetricsManager
|
||||
from server.metrics.ServerMetricsCollector import ServerMetricsCollector
|
||||
from server.storage.StorageLoader import StorageLoader
|
||||
from server.storage import StorageLoader
|
||||
|
||||
from server.metrics import (
|
||||
MetricsStoreBuilder,
|
||||
MetricsCollector,
|
||||
)
|
||||
|
||||
from quart import Quart, request, jsonify
|
||||
import logging, json, os, re, time
|
||||
@@ -25,7 +27,7 @@ class Server:
|
||||
'version': '1.0.0',
|
||||
}
|
||||
|
||||
def __init__(self, data_path:str, snake_type:str, storage_type:str, debug:bool=False, check_tls_security:bool=False, game_state_backend:str='memory', game_state_redis_url:str='redis://localhost:6379/0', game_state_ttl_sec:int=900, game_state_local_cache:bool=True, metrics_backend:str='memory', metrics_redis_url:str='redis://localhost:6379/0', metrics_ttl_sec:int=None):
|
||||
def __init__(self, data_path:str, snake_type:str, storage_type:str, debug:bool=False, check_tls_security:bool=False, game_state_backend:str='memory', game_state_redis_url:str='redis://localhost:6379/0', game_state_ttl_sec:int=900, game_state_local_cache:bool=True, metrics_backend:str='memory', metrics_redis_url:str='redis://localhost:6379/0', metrics_ttl_sec:int|None=None):
|
||||
self.debug = debug
|
||||
self.snake_type = snake_type
|
||||
self.storage_type = storage_type
|
||||
@@ -37,18 +39,20 @@ class Server:
|
||||
self.store_game_state = False
|
||||
normalized_backend = (game_state_backend or 'memory').strip().lower()
|
||||
self.game_state_local_cache = (game_state_local_cache and normalized_backend != 'memory')
|
||||
self.game_state_store = GameBoardStoreBuilder.build(
|
||||
self.game_state_store = GameStateStoreBuilder.build(
|
||||
backend=game_state_backend,
|
||||
redis_url=game_state_redis_url,
|
||||
ttl_seconds=game_state_ttl_sec,
|
||||
)
|
||||
metrics_backend_normalized = (metrics_backend or 'memory').strip().lower()
|
||||
self.stale_game_timeout_sec = self._get_stale_game_timeout_sec()
|
||||
|
||||
self.running_games:dict[str, GameBoard] = {}
|
||||
self.game_move_counts:dict[str, int] = {}
|
||||
self.game_last_seen_unix:dict[str, int] = {}
|
||||
self.metrics_collector = ServerMetricsCollector(
|
||||
metrics_manager=MetricsManager(
|
||||
|
||||
self.metrics_collector = MetricsCollector(
|
||||
metrics_manager=MetricsStoreBuilder.build(
|
||||
backend=metrics_backend_normalized,
|
||||
redis_url=metrics_redis_url,
|
||||
ttl_seconds=metrics_ttl_sec,
|
||||
@@ -61,6 +65,10 @@ class Server:
|
||||
game_last_seen_unix=self.game_last_seen_unix,
|
||||
game_move_counts=self.game_move_counts,
|
||||
)
|
||||
self.clear_worker_metrics_on_startup = self._env_bool('METRICS_CLEAR_WORKERS_ON_STARTUP', True)
|
||||
self.worker_metrics_startup_lock_ttl_sec = self._env_int('METRICS_STARTUP_CLEANUP_LOCK_TTL_SEC', 300)
|
||||
self._startup_worker_metrics_cleared = False
|
||||
|
||||
self.logger = build_logger('Battlesnake', debug_env_var='DEBUG_SERVER')
|
||||
self.snake_version = self._get_snake_version()
|
||||
|
||||
@@ -136,6 +144,16 @@ class Server:
|
||||
response.headers.set('server', 'battlesnake/gitea/snake-python')
|
||||
return response
|
||||
|
||||
@self.app.before_serving
async def clear_startup_worker_metrics_once():
    """Clear worker metrics at most once, before the app starts serving.

    The `_startup_worker_metrics_cleared` flag is set on the first call, so
    later invocations return immediately. Clearing only happens when
    `clear_worker_metrics_on_startup` is truthy and the metrics collector
    reports that this process should perform the cleanup.
    """
    if self._startup_worker_metrics_cleared:
        return
    # Mark as handled up front so a repeated hook invocation never re-enters.
    self._startup_worker_metrics_cleared = True

    if not self.clear_worker_metrics_on_startup:
        return

    lock_ttl = self.worker_metrics_startup_lock_ttl_sec
    if await self.metrics_collector.should_clear_worker_metrics_on_startup(lock_ttl):
        await self.metrics_collector.clear_worker_metrics()
|
||||
|
||||
@self.app.after_serving
|
||||
async def shutdown_state_storage():
|
||||
await self.game_state_store.close()
|
||||
@@ -210,6 +228,21 @@ class Server:
|
||||
except ValueError:
|
||||
return 180
|
||||
|
||||
def _env_bool(self, name: str, default: bool = False) -> bool:
    """Interpret environment variable `name` as a boolean flag.

    Returns `default` when the variable is unset. Otherwise the value is
    trimmed and lower-cased, and only '1', 'true', 'yes' and 'on' count as
    True; every other value is False.
    """
    raw = os.getenv(name)
    if raw is not None:
        normalized = raw.strip().lower()
        return normalized in {'1', 'true', 'yes', 'on'}
    return default
|
||||
|
||||
def _env_int(self, name: str, default: int) -> int:
    """Interpret environment variable `name` as an integer.

    Returns `default` when the variable is unset or when its value cannot
    be parsed by `int()`.
    """
    raw = os.getenv(name)
    if raw is not None:
        try:
            return int(raw)
        except ValueError:
            # Malformed value: fall through to the default below.
            pass
    return default
|
||||
|
||||
async def _create_game_board(self, game_state:dict) -> GameBoard:
|
||||
game_id = game_state['game']['id']
|
||||
new_game_board = GameBoard(
|
||||
|
||||
@@ -3,17 +3,17 @@ from server.GameBoard import GameBoard
|
||||
class Dataset:
|
||||
VALID_MOVES = {"up", "down", "left", "right"}
|
||||
|
||||
def __init__(self, game_board: GameBoard):
|
||||
def __init__(self, game_board:GameBoard):
|
||||
self.game_board = game_board
|
||||
|
||||
def _did_we_win(self):
|
||||
winners = self.game_board.winner_snake_names or []
|
||||
return "me" in winners
|
||||
|
||||
def _is_good_move(self, move: str):
|
||||
def _is_good_move(self, move:str):
|
||||
return move in self.VALID_MOVES
|
||||
|
||||
def build(self, only_good_moves: bool = True):
|
||||
def build(self, only_good_moves:bool=True):
|
||||
game_type = self.game_board.get_type_of_game()
|
||||
did_win = self._did_we_win()
|
||||
|
||||
@@ -1,23 +1,10 @@
|
||||
import argparse
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import shutil
|
||||
import argparse, hashlib, shutil, json
|
||||
from pathlib import Path
|
||||
|
||||
from server.dataset.DatasetIO import DatasetIO
|
||||
|
||||
class DatasetCurator:
|
||||
def __init__(
|
||||
self,
|
||||
input_files: list[str],
|
||||
output_file: str,
|
||||
min_turn: int = 6,
|
||||
late_turn: int = 20,
|
||||
max_safe_options: int = 2,
|
||||
min_score: int = 3,
|
||||
append: bool = False,
|
||||
archive_input: bool = False,
|
||||
archive_dir: str | None = None,
|
||||
):
|
||||
def __init__(self, input_files:list[str], output_file:str, min_turn:int=6, late_turn:int=20, max_safe_options:int=2, min_score:int=3, append:bool=False, archive_input:bool=False, archive_dir:str|None=None):
|
||||
self.input_files = input_files
|
||||
self.output_file = Path(output_file)
|
||||
self.min_turn = min_turn
|
||||
@@ -26,82 +13,38 @@ class DatasetCurator:
|
||||
self.min_score = min_score
|
||||
self.append = append
|
||||
self.archive_input = archive_input
|
||||
self.archive_dir = (
|
||||
Path(archive_dir) if archive_dir else self.output_file.parent / "archive"
|
||||
)
|
||||
self.archive_dir = Path(archive_dir) if archive_dir else self.output_file.parent / "archive"
|
||||
|
||||
def _resolve_input_files(self):
|
||||
resolved = []
|
||||
seen = set()
|
||||
|
||||
for item in self.input_files:
|
||||
path = Path(item)
|
||||
if path.is_dir():
|
||||
for file_path in sorted(path.rglob("*.jsonl")):
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if any(ch in item for ch in "*?[]"):
|
||||
for match in sorted(glob.glob(item)):
|
||||
file_path = Path(match)
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
key = str(path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(path)
|
||||
|
||||
return resolved
|
||||
|
||||
def _safe_options_count(self, row: dict):
|
||||
def _safe_options_count(self, row:dict):
|
||||
history = row.get("history", {})
|
||||
for item in history.get("data", []):
|
||||
if item.get("function") == "get_possible_moves":
|
||||
return len(item.get("safe_positions", {}))
|
||||
return None
|
||||
|
||||
def _state_hash(self, row: dict):
|
||||
def _state_hash(self, row:dict):
|
||||
board = row.get("game_board", {})
|
||||
snakes = board.get("snakes", [])
|
||||
|
||||
snakes_key = []
|
||||
for snake in snakes:
|
||||
snakes_key.append(
|
||||
(
|
||||
snake.get("id"),
|
||||
snake.get("health"),
|
||||
tuple(
|
||||
(seg.get("x"), seg.get("y")) for seg in snake.get("body", [])
|
||||
),
|
||||
)
|
||||
)
|
||||
snakes_key.append((
|
||||
snake.get("id"),
|
||||
snake.get("health"),
|
||||
tuple((seg.get("x"), seg.get("y")) for seg in snake.get("body", [])),
|
||||
))
|
||||
|
||||
key = {
|
||||
"width": board.get("width"),
|
||||
"height": board.get("height"),
|
||||
"snakes": sorted(snakes_key),
|
||||
"food": sorted((f.get("x"), f.get("y")) for f in board.get("food", [])),
|
||||
"hazards": sorted(
|
||||
(h.get("x"), h.get("y")) for h in board.get("hazards", [])
|
||||
),
|
||||
"hazards": sorted((h.get("x"), h.get("y")) for h in board.get("hazards", [])),
|
||||
}
|
||||
raw = json.dumps(key, sort_keys=True, separators=(",", ":"))
|
||||
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
def _score(self, row: dict):
|
||||
def _score(self, row:dict):
|
||||
score = 0
|
||||
turn = int(row.get("turn", 0))
|
||||
safe_options = self._safe_options_count(row)
|
||||
@@ -119,7 +62,7 @@ class DatasetCurator:
|
||||
|
||||
def curate(self):
|
||||
self.output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
input_paths = self._resolve_input_files()
|
||||
input_paths = DatasetIO.resolve_input_files(self.input_files)
|
||||
|
||||
total = 0
|
||||
kept = 0
|
||||
@@ -129,56 +72,47 @@ class DatasetCurator:
|
||||
seen_states = set()
|
||||
|
||||
if self.append and self.output_file.exists():
|
||||
with self.output_file.open("r", encoding="utf-8") as existing:
|
||||
for line in existing:
|
||||
if not line.strip():
|
||||
continue
|
||||
row = json.loads(line)
|
||||
state_key = self._state_hash(row)
|
||||
seen_states.add((state_key, row.get("move")))
|
||||
for row in DatasetIO.iter_jsonl_rows(self.output_file):
|
||||
state_key = self._state_hash(row)
|
||||
seen_states.add((state_key, row.get("move")))
|
||||
|
||||
mode = "a" if self.append else "w"
|
||||
with self.output_file.open(mode, encoding="utf-8") as dst:
|
||||
with DatasetIO.open_text(self.output_file, mode) as dst:
|
||||
for input_path in input_paths:
|
||||
with input_path.open("r", encoding="utf-8") as src:
|
||||
for line in src:
|
||||
if not line.strip():
|
||||
continue
|
||||
for row in DatasetIO.iter_jsonl_rows(input_path):
|
||||
total += 1
|
||||
|
||||
total += 1
|
||||
row = json.loads(line)
|
||||
if not row.get("is_good_move", False):
|
||||
skipped_quality += 1
|
||||
continue
|
||||
|
||||
if not row.get("is_good_move", False):
|
||||
skipped_quality += 1
|
||||
continue
|
||||
if int(row.get("turn", 0)) < self.min_turn:
|
||||
skipped_turn += 1
|
||||
continue
|
||||
|
||||
if int(row.get("turn", 0)) < self.min_turn:
|
||||
skipped_turn += 1
|
||||
continue
|
||||
quality_score, safe_options = self._score(row)
|
||||
if quality_score < self.min_score:
|
||||
skipped_quality += 1
|
||||
continue
|
||||
|
||||
quality_score, safe_options = self._score(row)
|
||||
if quality_score < self.min_score:
|
||||
skipped_quality += 1
|
||||
continue
|
||||
state_key = self._state_hash(row)
|
||||
dedupe_key = (state_key, row.get("move"))
|
||||
if dedupe_key in seen_states:
|
||||
skipped_duplicate += 1
|
||||
continue
|
||||
seen_states.add(dedupe_key)
|
||||
|
||||
state_key = self._state_hash(row)
|
||||
dedupe_key = (state_key, row.get("move"))
|
||||
if dedupe_key in seen_states:
|
||||
skipped_duplicate += 1
|
||||
continue
|
||||
seen_states.add(dedupe_key)
|
||||
|
||||
compact_row = {
|
||||
"game_id": row.get("game_id"),
|
||||
"turn": row.get("turn"),
|
||||
"move": row.get("move"),
|
||||
"game_type": row.get("game_type"),
|
||||
"quality_score": quality_score,
|
||||
"safe_options": safe_options,
|
||||
"game_board": row.get("game_board"),
|
||||
}
|
||||
dst.write(json.dumps(compact_row, ensure_ascii=False) + "\n")
|
||||
kept += 1
|
||||
compact_row = {
|
||||
"game_id": row.get("game_id"),
|
||||
"turn": row.get("turn"),
|
||||
"move": row.get("move"),
|
||||
"game_type": row.get("game_type"),
|
||||
"quality_score": quality_score,
|
||||
"safe_options": safe_options,
|
||||
"game_board": row.get("game_board"),
|
||||
}
|
||||
dst.write(json.dumps(compact_row, ensure_ascii=False) + "\n")
|
||||
kept += 1
|
||||
|
||||
archived_files = []
|
||||
if self.archive_input:
|
||||
@@ -197,7 +131,7 @@ class DatasetCurator:
|
||||
"output_file": str(self.output_file),
|
||||
}
|
||||
|
||||
def _archive_processed_files(self, input_paths: list[Path]):
|
||||
def _archive_processed_files(self, input_paths:list[Path]):
|
||||
self.archive_dir.mkdir(parents=True, exist_ok=True)
|
||||
archived = []
|
||||
|
||||
@@ -235,14 +169,13 @@ class DatasetCurator:
|
||||
|
||||
return archived
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Create curated best-moves dataset")
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
action="append",
|
||||
required=True,
|
||||
help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.",
|
||||
help="Input JSONL/JSONL.GZ file, directory, or glob pattern. Repeat for multiple inputs.",
|
||||
)
|
||||
parser.add_argument("--output", required=True, help="Output JSONL file")
|
||||
parser.add_argument("--min-turn", type=int, default=6)
|
||||
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
import argparse, json
|
||||
|
||||
from server.dataset.DatasetIO import DatasetIO
|
||||
|
||||
class DatasetExporter:
|
||||
def __init__(self, input_dir:str, output_file:str):
|
||||
@@ -9,9 +9,7 @@ class DatasetExporter:
|
||||
self.output_file = Path(output_file)
|
||||
|
||||
def _iter_game_files(self):
|
||||
if not self.input_dir.exists():
|
||||
return []
|
||||
return sorted(self.input_dir.rglob("*.json"))
|
||||
return DatasetIO.list_directory_files(self.input_dir, directory_pattern="*.json")
|
||||
|
||||
def _extract_samples(self, payload:dict, source_file:Path):
|
||||
dataset = payload.get("dataset", {})
|
||||
@@ -39,7 +37,7 @@ class DatasetExporter:
|
||||
self.output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
sample_count = 0
|
||||
with self.output_file.open("w", encoding="utf-8") as output:
|
||||
with DatasetIO.open_text(self.output_file, "w") as output:
|
||||
for game_file in game_files:
|
||||
with game_file.open("r", encoding="utf-8") as source:
|
||||
payload = json.load(source)
|
||||
@@ -56,8 +54,12 @@ class DatasetExporter:
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Export Battlesnake dataset to JSONL")
|
||||
parser.add_argument("--input", default="data", help="Input directory with stored game JSON files")
|
||||
parser.add_argument("--output", default="data/dataset/good_moves.jsonl", help="Output JSONL file")
|
||||
parser.add_argument(
|
||||
"--input", default="data", help="Input directory with stored game JSON files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", default="data/dataset/good_moves.jsonl", help="Output JSONL file"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
report = DatasetExporter(args.input, args.output).export_jsonl()
|
||||
@@ -0,0 +1,134 @@
|
||||
from typing import Any, TextIO, cast
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
import gzip, glob, json
|
||||
|
||||
class DatasetIO:
|
||||
@staticmethod
|
||||
def resolve_input_files(input_files: list[str], directory_pattern: str | tuple[str, ...] = ("*.jsonl", "*.jsonl.gz")) -> list[Path]:
|
||||
patterns = (
|
||||
(directory_pattern,)
|
||||
if isinstance(directory_pattern, str)
|
||||
else directory_pattern
|
||||
)
|
||||
resolved = []
|
||||
seen = set()
|
||||
|
||||
for item in input_files:
|
||||
path = Path(item)
|
||||
if path.is_dir():
|
||||
for pattern in patterns:
|
||||
for file_path in sorted(path.rglob(pattern)):
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if any(ch in item for ch in "*?[]"):
|
||||
for match in sorted(glob.glob(item)):
|
||||
file_path = Path(match)
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
key = str(path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(path)
|
||||
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def list_directory_files(input_dir:Path, directory_pattern:str) -> list[Path]:
|
||||
if not input_dir.exists():
|
||||
return []
|
||||
return sorted(input_dir.rglob(directory_pattern))
|
||||
|
||||
@staticmethod
|
||||
def open_text(path:Path, mode:str="r"):
|
||||
text_mode = mode if "t" in mode else f"{mode}t"
|
||||
open_fn = gzip.open if path.suffix == ".gz" else open
|
||||
return cast(TextIO, open_fn(path, text_mode, encoding="utf-8"))
|
||||
|
||||
@staticmethod
|
||||
def iter_jsonl_rows(path:Path):
|
||||
with DatasetIO.open_text(path, "r") as handle:
|
||||
for line in handle:
|
||||
if line.strip():
|
||||
yield json.loads(line)
|
||||
|
||||
@staticmethod
|
||||
def append_jsonl_row(path:Path, row:dict[str, Any]):
|
||||
with DatasetIO.open_text(path, "a") as raw_handle:
|
||||
handle = cast(TextIO, raw_handle)
|
||||
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
|
||||
@staticmethod
|
||||
def count_jsonl_rows(path:Path) -> int:
|
||||
count = 0
|
||||
|
||||
candidates = [path]
|
||||
if path.suffix == ".gz":
|
||||
candidates.append(path.with_suffix(""))
|
||||
else:
|
||||
candidates.append(Path(f"{path}.gz"))
|
||||
|
||||
seen = set()
|
||||
for candidate in candidates:
|
||||
if candidate in seen:
|
||||
continue
|
||||
seen.add(candidate)
|
||||
|
||||
if not candidate.exists() or not candidate.is_file():
|
||||
continue
|
||||
|
||||
try:
|
||||
with DatasetIO.open_text(candidate, "r") as handle:
|
||||
for line in handle:
|
||||
if line.strip():
|
||||
count += 1
|
||||
except (OSError, UnicodeError):
|
||||
continue
|
||||
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def rotate_and_gzip_if_size_reached(path:Path, max_bytes:int) -> bool:
|
||||
if max_bytes <= 0:
|
||||
return False
|
||||
if not path.exists() or not path.is_file():
|
||||
return False
|
||||
if path.stat().st_size < max_bytes:
|
||||
return False
|
||||
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
|
||||
if path.suffix == ".jsonl":
|
||||
rotated_name = f"{path.stem}.{timestamp}.jsonl"
|
||||
else:
|
||||
rotated_name = f"{path.name}.{timestamp}"
|
||||
|
||||
rotated_path = path.with_name(rotated_name)
|
||||
suffix = 1
|
||||
while rotated_path.exists():
|
||||
suffix += 1
|
||||
if path.suffix == ".jsonl":
|
||||
rotated_name = f"{path.stem}.{timestamp}.{suffix}.jsonl"
|
||||
else:
|
||||
rotated_name = f"{path.name}.{timestamp}.{suffix}"
|
||||
rotated_path = path.with_name(rotated_name)
|
||||
|
||||
path.rename(rotated_path)
|
||||
with rotated_path.open("rb") as src:
|
||||
with gzip.open(f"{rotated_path}.gz", "wb") as dst:
|
||||
dst.writelines(src)
|
||||
rotated_path.unlink()
|
||||
return True
|
||||
@@ -1,68 +1,30 @@
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import re
|
||||
from collections import Counter, defaultdict
|
||||
import argparse, json, re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from server.dataset.DatasetIO import DatasetIO
|
||||
|
||||
class DatasetStats:
|
||||
DAY_PATTERN = re.compile(r"(\d{4}-\d{2}-\d{2})")
|
||||
|
||||
def __init__(self, input_files: list[str]):
|
||||
def __init__(self, input_files:list[str]):
|
||||
self.input_files = input_files
|
||||
|
||||
def _resolve_input_files(self):
|
||||
resolved = []
|
||||
seen = set()
|
||||
|
||||
for item in self.input_files:
|
||||
path = Path(item)
|
||||
if path.is_dir():
|
||||
for file_path in sorted(path.rglob("*.jsonl")):
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if any(ch in item for ch in "*?[]"):
|
||||
for match in sorted(glob.glob(item)):
|
||||
file_path = Path(match)
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
key = str(path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(path)
|
||||
|
||||
return resolved
|
||||
|
||||
def _infer_day(self, file_path: Path):
|
||||
def _infer_day(self, file_path:Path):
|
||||
match = self.DAY_PATTERN.search(file_path.name)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return datetime.fromtimestamp(file_path.stat().st_mtime).strftime("%Y-%m-%d")
|
||||
|
||||
def _game_score(self, game: dict):
|
||||
def _game_score(self, game:dict):
|
||||
max_turn = game["max_turn"]
|
||||
rows = game["rows"]
|
||||
avg_safe = game["avg_safe_options"]
|
||||
pressure_bonus = 0 if avg_safe is None else max(0.0, 4.0 - avg_safe)
|
||||
return round(max_turn * 2.0 + rows + pressure_bonus, 3)
|
||||
|
||||
def _pressure_score(self, game: dict):
|
||||
def _pressure_score(self, game:dict):
|
||||
max_turn = game["max_turn"]
|
||||
rows = max(1, game["rows"])
|
||||
pressure_turns = game["pressure_turns"]
|
||||
@@ -72,7 +34,7 @@ class DatasetStats:
|
||||
safe_tightness = 0.0 if avg_safe is None else max(0.0, 3.0 - avg_safe)
|
||||
return round(max_turn * 1.2 + pressure_ratio * 120.0 + safe_tightness * 20.0, 3)
|
||||
|
||||
def _extract_safe_options(self, row: dict):
|
||||
def _extract_safe_options(self, row:dict):
|
||||
top_level = row.get("safe_options")
|
||||
if isinstance(top_level, int):
|
||||
return top_level
|
||||
@@ -87,7 +49,7 @@ class DatasetStats:
|
||||
return None
|
||||
|
||||
def analyze(self):
|
||||
files = self._resolve_input_files()
|
||||
files = DatasetIO.resolve_input_files(self.input_files)
|
||||
|
||||
totals = {
|
||||
"rows": 0,
|
||||
@@ -103,57 +65,52 @@ class DatasetStats:
|
||||
|
||||
for file_path in files:
|
||||
day = self._infer_day(file_path)
|
||||
with file_path.open("r", encoding="utf-8") as source:
|
||||
for line in source:
|
||||
if not line.strip():
|
||||
continue
|
||||
for row in DatasetIO.iter_jsonl_rows(file_path):
|
||||
game_id = row.get("game_id")
|
||||
if not game_id:
|
||||
continue
|
||||
|
||||
row = json.loads(line)
|
||||
game_id = row.get("game_id")
|
||||
if not game_id:
|
||||
continue
|
||||
turn = int(row.get("turn", 0))
|
||||
safe_options = self._extract_safe_options(row)
|
||||
snake_type = row.get("snake_type", "unknown")
|
||||
move = row.get("move", "unknown")
|
||||
|
||||
turn = int(row.get("turn", 0))
|
||||
safe_options = self._extract_safe_options(row)
|
||||
snake_type = row.get("snake_type", "unknown")
|
||||
move = row.get("move", "unknown")
|
||||
game_type = row.get("game_type", {})
|
||||
if isinstance(game_type, dict):
|
||||
game_type_name = game_type.get("name", "unknown")
|
||||
else:
|
||||
game_type_name = str(game_type)
|
||||
|
||||
game_type = row.get("game_type", {})
|
||||
if isinstance(game_type, dict):
|
||||
game_type_name = game_type.get("name", "unknown")
|
||||
else:
|
||||
game_type_name = str(game_type)
|
||||
totals["rows"] += 1
|
||||
totals["games"].add(game_id)
|
||||
totals["snake_types"][snake_type] += 1
|
||||
totals["game_types"][game_type_name] += 1
|
||||
totals["moves"][move] += 1
|
||||
totals["days"][day] += 1
|
||||
|
||||
totals["rows"] += 1
|
||||
totals["games"].add(game_id)
|
||||
totals["snake_types"][snake_type] += 1
|
||||
totals["game_types"][game_type_name] += 1
|
||||
totals["moves"][move] += 1
|
||||
totals["days"][day] += 1
|
||||
if game_id not in games:
|
||||
games[game_id] = {
|
||||
"game_id": game_id,
|
||||
"day": day,
|
||||
"snake_type": snake_type,
|
||||
"game_type": game_type_name,
|
||||
"rows": 0,
|
||||
"max_turn": -1,
|
||||
"safe_options_sum": 0,
|
||||
"safe_options_count": 0,
|
||||
"pressure_turns": 0,
|
||||
}
|
||||
|
||||
if game_id not in games:
|
||||
games[game_id] = {
|
||||
"game_id": game_id,
|
||||
"day": day,
|
||||
"snake_type": snake_type,
|
||||
"game_type": game_type_name,
|
||||
"rows": 0,
|
||||
"max_turn": -1,
|
||||
"safe_options_sum": 0,
|
||||
"safe_options_count": 0,
|
||||
"pressure_turns": 0,
|
||||
}
|
||||
game = games[game_id]
|
||||
game["rows"] += 1
|
||||
game["max_turn"] = max(game["max_turn"], turn)
|
||||
if isinstance(safe_options, int):
|
||||
game["safe_options_sum"] += safe_options
|
||||
game["safe_options_count"] += 1
|
||||
if safe_options <= 2:
|
||||
game["pressure_turns"] += 1
|
||||
|
||||
game = games[game_id]
|
||||
game["rows"] += 1
|
||||
game["max_turn"] = max(game["max_turn"], turn)
|
||||
if isinstance(safe_options, int):
|
||||
game["safe_options_sum"] += safe_options
|
||||
game["safe_options_count"] += 1
|
||||
if safe_options <= 2:
|
||||
game["pressure_turns"] += 1
|
||||
|
||||
day_games[day].add(game_id)
|
||||
day_games[day].add(game_id)
|
||||
|
||||
game_summaries = []
|
||||
for game in games.values():
|
||||
@@ -229,7 +186,7 @@ if __name__ == "__main__":
|
||||
"--input",
|
||||
action="append",
|
||||
required=True,
|
||||
help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.",
|
||||
help="Input JSONL/JSONL.GZ file, directory, or glob pattern. Repeat for multiple inputs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
@@ -0,0 +1,62 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import os
|
||||
|
||||
from server.dataset.DatasetIO import DatasetIO
|
||||
|
||||
class RLBootstrapDataset:
|
||||
def __init__(self):
|
||||
self.enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False)
|
||||
self.min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000)
|
||||
self.base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl"))
|
||||
self.output_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl"))
|
||||
self.max_bytes = int(float(os.getenv("RL_BOOTSTRAP_MAX_MB", "50")) * 1024 * 1024)
|
||||
self.needs_more_data = False
|
||||
|
||||
@staticmethod
|
||||
def _env_bool(name:str, default:bool=False) -> bool:
|
||||
value = os.getenv(name)
|
||||
if value is None:
|
||||
return default
|
||||
return value.lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
@staticmethod
|
||||
def _env_int(name:str, default:int) -> int:
|
||||
value = os.getenv(name)
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
def refresh_state(self):
|
||||
if not self.enabled:
|
||||
self.needs_more_data = False
|
||||
return
|
||||
|
||||
base_rows = DatasetIO.count_jsonl_rows(self.base_dataset_path)
|
||||
self.needs_more_data = base_rows < self.min_base_rows
|
||||
|
||||
def record_sample(self, game_data:Any, move:str, safe_moves:dict[str, dict[str, int]], reason:str, scores:dict[str, float]|None=None):
|
||||
if not self.enabled or not self.needs_more_data:
|
||||
return
|
||||
|
||||
try:
|
||||
self.output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
row = {
|
||||
"source": "best_battlesnake_bootstrap",
|
||||
"game_id": getattr(game_data, "id", None),
|
||||
"turn": game_data.get_turn(),
|
||||
"move": move,
|
||||
"safe_moves": list(safe_moves.keys()),
|
||||
"reason": reason,
|
||||
"game_board": game_data.get_game_board_as_dict(),
|
||||
}
|
||||
if scores:
|
||||
row["scores"] = {k: round(v, 5) for k, v in scores.items()}
|
||||
|
||||
DatasetIO.append_jsonl_row(self.output_path, row)
|
||||
DatasetIO.rotate_and_gzip_if_size_reached(self.output_path, self.max_bytes)
|
||||
except Exception:
|
||||
return
|
||||
@@ -0,0 +1,6 @@
|
||||
from .Dataset import Dataset
|
||||
from .DatasetIO import DatasetIO
|
||||
from .DatasetExporter import DatasetExporter
|
||||
from .DatasetCurator import DatasetCurator
|
||||
from .DatasetStats import DatasetStats
|
||||
from .RLBootstrapDataset import RLBootstrapDataset
|
||||
+1
-2
@@ -1,6 +1,5 @@
|
||||
from server.GameBoard import GameBoard
|
||||
import inspect
|
||||
import pickle
|
||||
import inspect, pickle
|
||||
|
||||
class RedisGameBoardStore:
|
||||
def __init__(self, redis_url:str="redis://localhost:6379/0", key_prefix:str="snake:gameboard", ttl_seconds:int=900, **kwargs):
|
||||
@@ -1,7 +1,7 @@
|
||||
from server.game_board_stats.MemoryGameBoardStore import MemoryGameBoardStore
|
||||
from server.game_board_stats.RedisGameBoardStore import RedisGameBoardStore
|
||||
from .MemoryGameBoardStore import MemoryGameBoardStore
|
||||
from .RedisGameBoardStore import RedisGameBoardStore
|
||||
|
||||
class GameBoardStoreBuilder:
|
||||
class GameStateStoreBuilder:
|
||||
@classmethod
|
||||
def build(self, backend:str="memory", **kwargs) -> MemoryGameBoardStore|RedisGameBoardStore:
|
||||
selected = (backend or "memory").strip().lower()
|
||||
@@ -1,9 +1,9 @@
|
||||
from server.metrics.backends.Template import StoreTemplate
|
||||
|
||||
import time
|
||||
|
||||
from server.metrics.MetricsManager import MetricsManager
|
||||
|
||||
class ServerMetricsCollector:
|
||||
def __init__(self, metrics_manager:MetricsManager, game_state_local_cache:bool, metrics_backend:str, game_state_backend:str, stale_game_timeout_sec:int, game_last_seen_unix:dict, game_move_counts:dict,):
|
||||
class MetricsCollector:
|
||||
def __init__(self, metrics_manager:StoreTemplate, game_state_local_cache:bool, metrics_backend:str, game_state_backend:str, stale_game_timeout_sec:int, game_last_seen_unix:dict, game_move_counts:dict):
|
||||
self._manager = metrics_manager
|
||||
self._stale_game_timeout_sec = stale_game_timeout_sec
|
||||
self._game_last_seen_unix = game_last_seen_unix
|
||||
@@ -167,6 +167,12 @@ class ServerMetricsCollector:
|
||||
local_snapshot = self.build_local_snapshot(game_last_seen_unix, game_move_counts)
|
||||
return await self._manager.snapshot(local_snapshot)
|
||||
|
||||
async def clear_worker_metrics(self) -> None:
    """Delegate to the metrics manager's `clear_all_workers()`."""
    manager = self._manager
    await manager.clear_all_workers()
|
||||
|
||||
async def should_clear_worker_metrics_on_startup(self, lock_ttl_seconds: int = 300) -> bool:
    """Return the result of the manager's `acquire_startup_cleanup_lock(lock_ttl_seconds)`."""
    acquired = await self._manager.acquire_startup_cleanup_lock(lock_ttl_seconds)
    return acquired
|
||||
|
||||
def build_prometheus_metrics(self, snapshot:dict) -> str:
|
||||
lines = [
|
||||
'# HELP snake_games_started_total Total games started by snake server.',
|
||||
@@ -1,9 +1,10 @@
|
||||
from server.metrics.MemoryMetricsStore import MemoryMetricsStore
|
||||
from server.metrics.RedisMetricsStore import RedisMetricsStore
|
||||
from .backends import StoreTemplate, MemoryMetricsStore, RedisMetricsStore
|
||||
|
||||
from .MetricsCollector import MetricsCollector
|
||||
|
||||
class MetricsStoreBuilder:
|
||||
@classmethod
|
||||
def build(self, backend:str="memory", **kwargs) -> MemoryMetricsStore|RedisMetricsStore:
|
||||
def build(self, backend:str="memory", **kwargs) -> StoreTemplate:
|
||||
selected = (backend or "memory").strip().lower()
|
||||
if selected == "redis":
|
||||
return RedisMetricsStore(**kwargs)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
class MemoryMetricsStore:
|
||||
from server.metrics.backends.Template import StoreTemplate
|
||||
|
||||
class MemoryMetricsStore(StoreTemplate):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(backend="memory", **kwargs)
|
||||
self._snapshots:dict[str, dict] = {}
|
||||
|
||||
async def publish(self, worker_id:str, snapshot:dict) -> None:
|
||||
@@ -8,5 +11,11 @@ class MemoryMetricsStore:
|
||||
async def load_all(self) -> list[dict]:
|
||||
return [dict(value) for value in self._snapshots.values()]
|
||||
|
||||
async def clear_all(self) -> None:
|
||||
self._snapshots.clear()
|
||||
|
||||
async def _acquire_startup_cleanup_lock(self, lock_key:str, ttl_seconds:int=300) -> bool:
|
||||
return True
|
||||
|
||||
async def close(self) -> None:
|
||||
return None
|
||||
@@ -1,8 +1,10 @@
|
||||
import inspect
|
||||
import json
|
||||
from server.metrics.backends.Template import StoreTemplate
|
||||
|
||||
class RedisMetricsStore:
|
||||
def __init__(self, redis_url:str="redis://localhost:6379/0", key_prefix:str="snake:metrics:worker", ttl_seconds:int=None, **kwargs):
|
||||
import inspect, json
|
||||
|
||||
class RedisMetricsStore(StoreTemplate):
|
||||
def __init__(self, redis_url:str="redis://localhost:6379/0", key_prefix:str="snake:metrics:worker", ttl_seconds:int|None=None, **kwargs):
|
||||
super().__init__(backend="redis", key_prefix=key_prefix, **kwargs)
|
||||
self.redis_url = redis_url
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl_seconds = ttl_seconds
|
||||
@@ -41,6 +43,17 @@ class RedisMetricsStore:
|
||||
continue
|
||||
return snapshots
|
||||
|
||||
async def clear_all(self) -> None:
    """Delete every worker snapshot stored under ``key_prefix``.

    The startup cleanup lock is stored as ``<key_prefix>:startup_cleanup_lock``,
    i.e. inside the same ``<key_prefix>:*`` namespace that is wiped here.
    Deleting it would release the lock in the middle of the cleanup it is
    meant to guard, letting a later worker acquire it again and clear the
    metrics a second time, so that key is deliberately left in place and
    expires through its own TTL.
    """
    redis = await self._get_redis()
    lock_key = f"{self.key_prefix}:startup_cleanup_lock"
    # NOTE(review): KEYS walks the whole keyspace; SCAN would be gentler on a
    # large Redis instance - kept as-is so existing fakes/clients keep working.
    keys = await redis.keys(f"{self.key_prefix}:*")

    def _as_text(key) -> str:
        # Keys may arrive as bytes or str depending on how the client is
        # configured (presumably decode_responses - confirm in _get_redis).
        return key.decode("utf-8", "replace") if isinstance(key, bytes) else str(key)

    doomed = [key for key in keys if _as_text(key) != lock_key]
    if doomed:
        await redis.delete(*doomed)
|
||||
|
||||
async def _acquire_startup_cleanup_lock(self, lock_key:str, ttl_seconds:int=300) -> bool:
|
||||
redis = await self._get_redis()
|
||||
locked = await redis.set(lock_key, '1', ex=max(1, int(ttl_seconds)), nx=True)
|
||||
return bool(locked)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._redis is None:
|
||||
return
|
||||
@@ -1,17 +1,18 @@
|
||||
from server.metrics import MetricsStoreBuilder
|
||||
from typing import Any, Awaitable, cast
|
||||
import inspect, time, os
|
||||
|
||||
import time, os
|
||||
|
||||
class MetricsManager:
|
||||
def __init__(self, backend:str="memory", redis_url:str="redis://localhost:6379/0", ttl_seconds:int=90, key_prefix:str="snake:metrics:worker", worker_id:str|None=None):
|
||||
class StoreTemplate:
|
||||
def __init__(self, backend:str="memory", key_prefix:str="snake:metrics:worker", worker_id:str|None=None, **kwargs):
|
||||
self.backend = (backend or "memory").strip().lower()
|
||||
self.key_prefix = key_prefix
|
||||
self.worker_id = worker_id or f"{os.getpid()}-{int(time.time() * 1000)}"
|
||||
self.store = MetricsStoreBuilder.build(
|
||||
backend=self.backend,
|
||||
redis_url=redis_url,
|
||||
ttl_seconds=ttl_seconds,
|
||||
key_prefix=key_prefix,
|
||||
)
|
||||
self.store = self
|
||||
|
||||
async def publish(self, worker_id:str, snapshot:dict) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def load_all(self) -> list[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def publish_only(self, snapshot:dict) -> None:
|
||||
await self.store.publish(self.worker_id, snapshot)
|
||||
@@ -28,7 +29,37 @@ class MetricsManager:
|
||||
return self._merge_snapshots(snapshots)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.store.close()
|
||||
if self.store is self:
|
||||
return
|
||||
close_store = getattr(self.store, "close", None)
|
||||
if not callable(close_store):
|
||||
return
|
||||
maybe_result = close_store()
|
||||
if inspect.isawaitable(maybe_result):
|
||||
await cast(Awaitable[Any], maybe_result)
|
||||
|
||||
async def clear_all_workers(self) -> None:
    """Invoke the store's ``clear_all`` hook, if it has one.

    The hook is looked up dynamically on ``self.store``; a missing or
    non-callable attribute makes this a no-op. The hook may be synchronous
    or return an awaitable, which is awaited before returning.
    """
    hook = getattr(self.store, "clear_all", None)
    if not callable(hook):
        return

    outcome = hook()
    if not inspect.isawaitable(outcome):
        return
    await cast(Awaitable[Any], outcome)
|
||||
|
||||
async def acquire_startup_cleanup_lock(self, ttl_seconds:int=300) -> bool:
    """Decide whether this worker may run the startup cleanup.

    Any backend other than ``"redis"`` is always allowed. For Redis the
    store's ``_acquire_startup_cleanup_lock`` hook is preferred, with
    ``acquire_startup_cleanup_lock`` as fallback; if neither is callable
    the cleanup is allowed. The hook is called with the key
    ``<key_prefix>:startup_cleanup_lock`` and ``ttl_seconds``; its result
    (awaited when awaitable) is returned as a bool.
    """
    if self.backend != "redis":
        return True

    locker = None
    for hook_name in ("_acquire_startup_cleanup_lock", "acquire_startup_cleanup_lock"):
        candidate = getattr(self.store, hook_name, None)
        if callable(candidate):
            locker = candidate
            break
    if locker is None:
        return True

    outcome = locker(f"{self.key_prefix}:startup_cleanup_lock", ttl_seconds)
    if inspect.isawaitable(outcome):
        outcome = await cast(Awaitable[Any], outcome)
    return bool(outcome)
|
||||
|
||||
def _merge_snapshots(self, snapshots:list[dict]) -> dict:
|
||||
merged = {
|
||||
@@ -86,11 +117,11 @@ class MetricsManager:
|
||||
merged["max_turn"] = max(merged["max_turn"], int(worker.get("max_turn", 0)))
|
||||
merged["active_games_peak"] = max(merged["active_games_peak"], int(worker.get("active_games_peak", 0)))
|
||||
merged["move_response_time_ms_max"] = max(merged["move_response_time_ms_max"], float(worker.get("move_response_time_ms_max", 0.0)))
|
||||
merged["last_game_start_unix"] = max(merged["last_game_start_unix"], int(worker.get("last_game_start_unix", 0)),)
|
||||
merged["last_game_start_unix"] = max(merged["last_game_start_unix"], int(worker.get("last_game_start_unix", 0)))
|
||||
merged["last_game_end_unix"] = max(merged["last_game_end_unix"], int(worker.get("last_game_end_unix", 0)))
|
||||
merged["last_move_unix"] = max(merged["last_move_unix"], int(worker.get("last_move_unix", 0)))
|
||||
merged["oldest_active_game_age_sec"] = max(merged["oldest_active_game_age_sec"], int(worker.get("oldest_active_game_age_sec", 0)),)
|
||||
merged["stale_game_timeout_sec"] = max(merged["stale_game_timeout_sec"], int(worker.get("stale_game_timeout_sec", 0)),)
|
||||
merged["oldest_active_game_age_sec"] = max(merged["oldest_active_game_age_sec"], int(worker.get("oldest_active_game_age_sec", 0)))
|
||||
merged["stale_game_timeout_sec"] = max(merged["stale_game_timeout_sec"], int(worker.get("stale_game_timeout_sec", 0)))
|
||||
merged["game_state_local_cache_enabled"] = merged["game_state_local_cache_enabled"] or bool(worker.get("game_state_local_cache_enabled", False))
|
||||
|
||||
for endpoint in merged["http_requests_by_endpoint"]:
|
||||
@@ -0,0 +1,3 @@
|
||||
from .Template import StoreTemplate
|
||||
from .Memory import MemoryMetricsStore
|
||||
from .Redis import RedisMetricsStore
|
||||
@@ -1,5 +1,5 @@
|
||||
from server.GameBoard import GameBoard
|
||||
from server.Dataset import Dataset
|
||||
from server.dataset.Dataset import Dataset
|
||||
|
||||
from datetime import datetime
|
||||
import json, time
|
||||
@@ -55,8 +55,8 @@ class EdgeDB:
|
||||
calculations = snake_calulations[i] if i < len(snake_calulations) else []
|
||||
calculations.append({
|
||||
"dataset": {
|
||||
"is_good_move": labels_by_turn.get(moves[i]["turn"], False)
|
||||
}
|
||||
"is_good_move": labels_by_turn.get(moves[i]["turn"], False)
|
||||
}
|
||||
})
|
||||
data.append({
|
||||
"turn": moves[i]["turn"],
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from server.dataset.Dataset import Dataset
|
||||
from server.GameBoard import GameBoard
|
||||
from server.Dataset import Dataset
|
||||
from server.Files import save_file
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
import gzip
|
||||
import json, os
|
||||
import gzip, json, os
|
||||
|
||||
class LocalStorage:
|
||||
def __init__(self, file_path:str, **kwargs):
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
|
||||
class StorageLoader:
|
||||
@classmethod
|
||||
def build(self, selected_storage:str):
|
||||
storage_module = __import__(f'server.storage.{selected_storage}', fromlist=[selected_storage])
|
||||
storage_class = getattr(storage_module, selected_storage)
|
||||
return storage_class
|
||||
@@ -0,0 +1,6 @@
|
||||
class StorageLoader:
    """Resolve a storage backend class from its name."""

    @classmethod
    def build(cls, selected_storage: str):
        """Return the storage class called ``selected_storage``.

        The class is read from the module ``server.storage.<selected_storage>``
        under the attribute of the same name. The class itself is returned,
        not an instance; instantiating it is left to the caller.

        Raises:
            ImportError: if the module cannot be imported.
            AttributeError: if the module has no attribute named
                ``selected_storage``.
        """
        # Imported locally so this block carries its own dependency;
        # importlib.import_module is the documented replacement for calling
        # __import__ directly and returns the leaf module without a fromlist.
        import importlib

        storage_module = importlib.import_module(f"server.storage.{selected_storage}")
        return getattr(storage_module, selected_storage)
|
||||
+11
-60
@@ -2,14 +2,15 @@ from collections.abc import Iterator
|
||||
from collections import deque
|
||||
from typing import Any, cast
|
||||
from time import perf_counter
|
||||
from pathlib import Path
|
||||
import random, json, os
|
||||
import random, os
|
||||
|
||||
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
|
||||
|
||||
from snakes.TemplateSnake import TemplateSnake
|
||||
from server.GameBoard import GameBoard
|
||||
|
||||
class BestBattleSnake(TemplateSnake):
|
||||
VERSION = "2.6.0"
|
||||
VERSION = "2.6.1"
|
||||
Point = tuple[int, int]
|
||||
Coord = dict[str, int]
|
||||
SnakeState = dict[str, Any]
|
||||
@@ -40,11 +41,7 @@ class BestBattleSnake(TemplateSnake):
|
||||
self.previous_hazards = set()
|
||||
self.duel_style = self._get_duel_style()
|
||||
self.timeout_buffer_ms = self._get_timeout_buffer_ms()
|
||||
self.rl_bootstrap_enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False)
|
||||
self.rl_min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000)
|
||||
self.rl_base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl"))
|
||||
self.rl_bootstrap_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl"))
|
||||
self.rl_needs_more_data = False
|
||||
self.rl_bootstrap = RLBootstrapDataset()
|
||||
self.future_planning_depth = max(1, min(4, self._env_int("BATTLE_FUTURE_PLANNING_DEPTH", default=2)))
|
||||
self.future_planning_branch = max(1, min(3, self._env_int("BATTLE_FUTURE_PLANNING_BRANCH", default=2)))
|
||||
self.future_planning_min_time_ms = max(25, self._env_int("BATTLE_FUTURE_PLANNING_MIN_MS", default=70))
|
||||
@@ -88,12 +85,6 @@ class BestBattleSnake(TemplateSnake):
|
||||
except ValueError:
|
||||
return 120
|
||||
|
||||
def _env_bool(self, name:str, default:bool=False) -> bool:
|
||||
value = os.getenv(name)
|
||||
if value is None:
|
||||
return default
|
||||
return value.lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
def _env_int(self, name:str, default:int) -> int:
|
||||
value = os.getenv(name)
|
||||
if value is None:
|
||||
@@ -103,46 +94,6 @@ class BestBattleSnake(TemplateSnake):
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
def _count_jsonl_rows(self, path:Path) -> int:
|
||||
if not path.exists() or not path.is_file():
|
||||
return 0
|
||||
count = 0
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if line.strip():
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _refresh_rl_bootstrap_state(self):
|
||||
if not self.rl_bootstrap_enabled:
|
||||
self.rl_needs_more_data = False
|
||||
return
|
||||
base_rows = self._count_jsonl_rows(self.rl_base_dataset_path)
|
||||
self.rl_needs_more_data = base_rows < self.rl_min_base_rows
|
||||
|
||||
def _record_rl_bootstrap_sample(self, game_data:GameBoard, move:str, safe_moves:MoveMap, reason:str, scores:dict[str, float]|None=None):
|
||||
if not self.rl_bootstrap_enabled or not self.rl_needs_more_data:
|
||||
return
|
||||
|
||||
try:
|
||||
self.rl_bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
row = {
|
||||
"source": "best_battlesnake_bootstrap",
|
||||
"game_id": getattr(game_data, "id", None),
|
||||
"turn": game_data.get_turn(),
|
||||
"move": move,
|
||||
"safe_moves": list(safe_moves.keys()),
|
||||
"reason": reason,
|
||||
"game_board": game_data.get_game_board_as_dict(),
|
||||
}
|
||||
if scores:
|
||||
row["scores"] = {k: round(v, 5) for k, v in scores.items()}
|
||||
|
||||
with self.rl_bootstrap_path.open("a", encoding="utf-8") as handle:
|
||||
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def choose_move(self, game_data:GameBoard) -> str:
|
||||
"""Pick the next move from a Battlesnake move request.
|
||||
|
||||
@@ -162,7 +113,7 @@ class BestBattleSnake(TemplateSnake):
|
||||
self.last_move = None
|
||||
self.previous_hazards = set()
|
||||
self.last_game_id = game_id
|
||||
self._refresh_rl_bootstrap_state()
|
||||
self.rl_bootstrap.refresh_state()
|
||||
|
||||
my_snake = cast(dict[str, Any], game_data.get_my_snake())
|
||||
my_head = my_snake["head"]
|
||||
@@ -215,7 +166,7 @@ class BestBattleSnake(TemplateSnake):
|
||||
"reason": "no_safe_moves",
|
||||
}
|
||||
)
|
||||
self._record_rl_bootstrap_sample(game_data, fallback, safe_moves, "no_safe_moves")
|
||||
self.rl_bootstrap.record_sample(game_data, fallback, safe_moves, "no_safe_moves")
|
||||
self.previous_hazards = set(hazard_set)
|
||||
return fallback
|
||||
|
||||
@@ -246,7 +197,7 @@ class BestBattleSnake(TemplateSnake):
|
||||
self.recent_heads.append(current_head_point)
|
||||
self.last_move = best_move
|
||||
self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
|
||||
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "constrictor", scores)
|
||||
self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "constrictor", scores)
|
||||
self.previous_hazards = set(hazard_set)
|
||||
return best_move
|
||||
|
||||
@@ -270,7 +221,7 @@ class BestBattleSnake(TemplateSnake):
|
||||
self.recent_heads.append(current_head_point)
|
||||
self.last_move = best_move
|
||||
self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
|
||||
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "duel", scores)
|
||||
self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "duel", scores)
|
||||
self.previous_hazards = set(hazard_set)
|
||||
return best_move
|
||||
|
||||
@@ -418,7 +369,7 @@ class BestBattleSnake(TemplateSnake):
|
||||
self.recent_heads.append(current_head_point)
|
||||
self.last_move = quick_move
|
||||
self.add_to_history({"turn": turn, "move": quick_move, "reason": "timeout_budget"})
|
||||
self._record_rl_bootstrap_sample(game_data, quick_move, safe_moves, "timeout_budget")
|
||||
self.rl_bootstrap.record_sample(game_data, quick_move, safe_moves, "timeout_budget")
|
||||
self.previous_hazards = set(hazard_set)
|
||||
return quick_move
|
||||
|
||||
@@ -462,7 +413,7 @@ class BestBattleSnake(TemplateSnake):
|
||||
self.recent_heads.append(current_head_point)
|
||||
self.last_move = best_move
|
||||
self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
|
||||
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "multi", scores)
|
||||
self.rl_bootstrap.record_sample(game_data, best_move, safe_moves, "multi", scores)
|
||||
self.previous_hazards = set(hazard_set)
|
||||
return best_move
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,7 @@ SNAKE_REGISTRY = {
|
||||
"BetterMasterSnake": "1.3.0",
|
||||
"BestBattleSnake": "2.6.0",
|
||||
"TrainedBattleSnake": "0.1.0",
|
||||
"UltimateBattleSnake": "4.1.0",
|
||||
}
|
||||
|
||||
def build_snake(selected_snake: str):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from typing import cast
|
||||
|
||||
from server.Dataset import Dataset
|
||||
from server.dataset.Dataset import Dataset
|
||||
from server.GameBoard import GameBoard
|
||||
|
||||
class DummySnake:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import tempfile, json, gzip
|
||||
from pathlib import Path
|
||||
|
||||
from server.DatasetExporter import DatasetExporter
|
||||
from server.dataset.DatasetExporter import DatasetExporter
|
||||
|
||||
class TestDatasetExporter(unittest.TestCase):
|
||||
def test_export_jsonl(self):
|
||||
@@ -43,5 +43,41 @@ class TestDatasetExporter(unittest.TestCase):
|
||||
self.assertEqual(first["move"], "up")
|
||||
self.assertTrue(first["is_good_move"])
|
||||
|
||||
def test_export_jsonl_gz(self):
    # Exporting to a ``.jsonl.gz`` target must yield a gzip file holding one
    # JSON line per exported sample.
    with tempfile.TemporaryDirectory() as tmp:
        workdir = Path(tmp)
        input_dir = workdir / "data"
        output_file = workdir / "out" / "dataset.jsonl.gz"

        sample = {
            "turn": 1,
            "move": "up",
            "is_good_move": True,
            "game_board": {"width": 11, "height": 11},
            "history": {"data": []},
        }
        game_payload = {
            "dataset": {
                "game": {"id": "g-1", "map": "standard", "type": {"name": "duel"}},
                "snake": {"type": "BestBattleSnake"},
                "samples": [sample],
            }
        }

        game_file = input_dir / "game-1.json"
        game_file.parent.mkdir(parents=True, exist_ok=True)
        game_file.write_text(json.dumps(game_payload), encoding="utf-8")

        exporter = DatasetExporter(str(input_dir), str(output_file))
        report = exporter.export_jsonl()

        self.assertEqual(report["games_scanned"], 1)
        self.assertEqual(report["samples_exported"], 1)
        self.assertTrue(output_file.exists())

        with gzip.open(output_file, "rt", encoding="utf-8") as handle:
            content = handle.read()
        lines = content.strip().splitlines()
        self.assertEqual(len(lines), 1)
        first = json.loads(lines[0])
        self.assertEqual(first["game_id"], "g-1")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import unittest
|
||||
|
||||
import tempfile, gzip, json
|
||||
from pathlib import Path
|
||||
|
||||
from server.dataset.DatasetIO import DatasetIO
|
||||
|
||||
class TestDatasetIO(unittest.TestCase):
    # Exercises DatasetIO against plain and gzip-compressed JSONL files.

    def test_resolve_input_files_includes_jsonl_and_gz(self):
        # A directory input must resolve to both ``.jsonl`` and ``.jsonl.gz`` files.
        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            plain_file = root / "a.jsonl"
            plain_file.write_text('{"x":1}\n', encoding="utf-8")
            with gzip.open(root / "b.jsonl.gz", "wt", encoding="utf-8") as gz_out:
                gz_out.write('{"x":2}\n')

            resolved = DatasetIO.resolve_input_files([str(root)])
            found = set()
            for entry in resolved:
                found.add(entry.name)
            self.assertIn("a.jsonl", found)
            self.assertIn("b.jsonl.gz", found)

    def test_iter_jsonl_rows_reads_gz_file(self):
        # Blank lines inside a gzip JSONL file must not produce rows.
        with tempfile.TemporaryDirectory() as workdir:
            gz_path = Path(workdir) / "rows.jsonl.gz"
            with gzip.open(gz_path, "wt", encoding="utf-8") as gz_out:
                for chunk in ('{"turn":1}\n', "\n", '{"turn":2}\n'):
                    gz_out.write(chunk)

            parsed = list(DatasetIO.iter_jsonl_rows(gz_path))
            self.assertEqual(len(parsed), 2)
            self.assertEqual(parsed[0]["turn"], 1)
            self.assertEqual(parsed[1]["turn"], 2)

    def test_append_jsonl_row_supports_gz_output(self):
        # Two appends to a ``.gz`` target must read back as two JSON lines.
        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "append.jsonl.gz"
            DatasetIO.append_jsonl_row(target, {"move": "up"})
            DatasetIO.append_jsonl_row(target, {"move": "left"})

            with gzip.open(target, "rt", encoding="utf-8") as gz_in:
                content = gz_in.read()
            written = content.strip().splitlines()
            self.assertEqual(len(written), 2)
            self.assertEqual(json.loads(written[0])["move"], "up")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -2,9 +2,7 @@ import unittest
|
||||
from typing import Any, cast
|
||||
|
||||
from server.GameBoard import GameBoard
|
||||
from server.game_board_stats import GameBoardStoreBuilder
|
||||
from server.game_board_stats.MemoryGameBoardStore import MemoryGameBoardStore
|
||||
from server.game_board_stats.RedisGameBoardStore import RedisGameBoardStore
|
||||
from server.game_state_store import GameStateStoreBuilder, MemoryGameBoardStore, RedisGameBoardStore
|
||||
from snakes.TemplateSnake import TemplateSnake
|
||||
|
||||
class _FakeRedis:
|
||||
@@ -73,9 +71,9 @@ class TestGameStateStore(unittest.IsolatedAsyncioTestCase):
|
||||
return board
|
||||
|
||||
def test_builder_selects_store_backend(self):
|
||||
memory_store = GameBoardStoreBuilder.build(backend="memory")
|
||||
redis_store = GameBoardStoreBuilder.build(backend="redis")
|
||||
default_store = GameBoardStoreBuilder.build(backend="unknown")
|
||||
memory_store = GameStateStoreBuilder.build(backend="memory")
|
||||
redis_store = GameStateStoreBuilder.build(backend="redis")
|
||||
default_store = GameStateStoreBuilder.build(backend="unknown")
|
||||
|
||||
self.assertIsInstance(memory_store, MemoryGameBoardStore)
|
||||
self.assertIsInstance(redis_store, RedisGameBoardStore)
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import unittest
|
||||
|
||||
from server.metrics.MetricsManager import MetricsManager
|
||||
from typing import Any, cast
|
||||
|
||||
class TestMetricsManager(unittest.IsolatedAsyncioTestCase):
|
||||
from server.metrics import (
|
||||
MetricsStoreBuilder,
|
||||
MemoryMetricsStore,
|
||||
)
|
||||
|
||||
class TestMetricsStoreTemplate(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_memory_backend_returns_local_snapshot(self):
|
||||
manager = MetricsManager(backend="memory")
|
||||
manager = MetricsStoreBuilder.build(backend="memory")
|
||||
local = {
|
||||
"games_started": 2,
|
||||
"games_ended": 1,
|
||||
@@ -48,7 +53,7 @@ class TestMetricsManager(unittest.IsolatedAsyncioTestCase):
|
||||
await manager.close()
|
||||
|
||||
async def test_merge_snapshots_aggregates_totals(self):
|
||||
manager = MetricsManager(backend="memory")
|
||||
manager = MemoryMetricsStore()
|
||||
merged = manager._merge_snapshots(
|
||||
[
|
||||
{
|
||||
@@ -139,5 +144,27 @@ class TestMetricsManager(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(merged["metrics_backend"], "redis")
|
||||
await manager.close()
|
||||
|
||||
async def test_acquire_startup_cleanup_lock_uses_store_for_redis_backend(self):
    # With the redis backend the lock request must be delegated to the
    # attached store, using the prefixed lock key and the caller's TTL.
    class FakeStore:
        def __init__(self):
            self.calls = []

        async def acquire_startup_cleanup_lock(self, lock_key:str, ttl_seconds:int=300):
            # Record the request so the test can inspect what was delegated.
            self.calls.append((lock_key, ttl_seconds))
            return True

        async def close(self):
            return None

    manager = MetricsStoreBuilder.build(backend="redis", key_prefix="snake:metrics:worker")
    fake_store = FakeStore()
    manager.store = cast(Any, fake_store)

    allowed = await manager.acquire_startup_cleanup_lock(180)
    expected_calls = [("snake:metrics:worker:startup_cleanup_lock", 180)]
    self.assertTrue(allowed)
    self.assertEqual(fake_store.calls, expected_calls)

    await manager.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,100 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from pathlib import Path
|
||||
import tempfile, gzip
|
||||
|
||||
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
|
||||
from server.dataset.DatasetIO import DatasetIO
|
||||
|
||||
class TestRLBootstrapDataset(unittest.TestCase):
    # Covers gzip-aware row counting, size-based rotation and the
    # environment-driven RLBootstrapDataset configuration.

    def test_count_jsonl_rows_reads_gzip_dataset(self):
        # Rows are counted through the ``.gz`` path and through the
        # uncompressed sibling name; blank lines are not counted.
        with tempfile.TemporaryDirectory() as tmp:
            workdir = Path(tmp)
            dataset_path = workdir / "base.jsonl.gz"
            with gzip.open(dataset_path, "wt", encoding="utf-8") as gz_out:
                for chunk in ('{"turn":1}\n', "\n", '{"turn":2}\n'):
                    gz_out.write(chunk)

            self.assertEqual(DatasetIO.count_jsonl_rows(dataset_path), 2)
            self.assertEqual(DatasetIO.count_jsonl_rows(workdir / "base.jsonl"), 2)

    def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self):
        # A file above the size limit is replaced by a rotated gzip archive.
        with tempfile.TemporaryDirectory() as tmp:
            workdir = Path(tmp)
            path = workdir / "rl_bootstrap.jsonl"
            path.write_text("x" * 200, encoding="utf-8")

            rotated = DatasetIO.rotate_and_gzip_if_size_reached(path, max_bytes=50)

            self.assertTrue(rotated)
            self.assertFalse(path.exists())
            archives = list(workdir.glob("rl_bootstrap.*.jsonl.gz"))
            self.assertGreaterEqual(len(archives), 1)

    def test_rl_bootstrap_dataset_reads_env_and_gzip_base(self):
        # One base row against a minimum of two must flag "needs more data".
        with tempfile.TemporaryDirectory() as tmp:
            base_path = Path(tmp) / "base.jsonl.gz"
            with gzip.open(base_path, "wt", encoding="utf-8") as gz_out:
                gz_out.write('{"turn":1}\n')

            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                "RL_MIN_BASE_ROWS": "2",
                "RL_BASE_DATASET": str(base_path),
            }
            with patch.dict("os.environ", env_overrides, clear=False):
                dataset = RLBootstrapDataset()
                dataset.refresh_state()

            self.assertTrue(dataset.needs_more_data)

    def test_rl_bootstrap_dataset_autocompresses_output(self):
        # With a tiny size budget, recording samples must leave at least one
        # rotated gzip archive next to the output file.
        class DummyBoard:
            id = "game-1"

            def get_turn(self):
                return 1

            def get_game_board_as_dict(self):
                board = {
                    "width": 11,
                    "height": 11,
                    "food": [],
                    "hazards": [],
                    "snakes": [],
                }
                you = {
                    "id": "me",
                    "head": {"x": 5, "y": 5},
                    "body": [{"x": 5, "y": 5}],
                }
                return {"turn": 1, "board": board, "you": you}

        with tempfile.TemporaryDirectory() as tmp:
            workdir = Path(tmp)
            output_path = workdir / "rl_bootstrap.jsonl"
            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                "RL_BOOTSTRAP_MAX_MB": "0.00001",
                "RL_BOOTSTRAP_OUTPUT": str(output_path),
            }
            with patch.dict("os.environ", env_overrides, clear=False):
                dataset = RLBootstrapDataset()
                dataset.needs_more_data = True
                dataset.record_sample(DummyBoard(), "up", {"up": {"x": 5, "y": 6}}, "test")
                dataset.record_sample(DummyBoard(), "left", {"left": {"x": 4, "y": 5}}, "test")

            archives = list(workdir.glob("rl_bootstrap.*.jsonl.gz"))
            self.assertGreaterEqual(len(archives), 1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user