add Dataset Class and Tests

This commit is contained in:
2026-04-03 10:35:21 +02:00
parent 6b69d133b6
commit 2e1f91355b
4 changed files with 217 additions and 0 deletions
+56
View File
@@ -0,0 +1,56 @@
from server.GameBoard import GameBoard
class Dataset:
    """Turn one recorded game into labelled per-turn training samples.

    A move is labelled "good" only when the game was won (see
    ``_did_we_win``) and the move itself is one of ``VALID_MOVES``.
    """

    # Moves that may be labelled as good.
    VALID_MOVES = {"up", "down", "left", "right"}
    # Name searched for in the winner list.
    # NOTE(review): presumably the name our own snake is registered under -
    # confirm against the game server.
    OWN_SNAKE_NAME = "me"

    def __init__(self, game_board: GameBoard):
        """Store the game board whose turns and result are read later."""
        self.game_board = game_board

    def _did_we_win(self) -> bool:
        """Return True when ``OWN_SNAKE_NAME`` is among the winner names.

        A missing (``None``) winner list counts as "not won".
        """
        winners = self.game_board.winner_snake_names or []
        return self.OWN_SNAKE_NAME in winners

    def _is_good_move(self, move: str) -> bool:
        """Return True when *move* is one of ``VALID_MOVES``.

        Non-string values (``None``, lists, ...) are rejected up front so an
        unhashable value cannot raise ``TypeError`` in the set lookup.
        """
        return isinstance(move, str) and move in self.VALID_MOVES

    def build(self, only_good_moves: bool = True):
        """Assemble the dataset dictionary for the stored game.

        Args:
            only_good_moves: When True (default), turns whose move is not
                labelled good are left out of ``samples``.

        Returns:
            A dict with the keys ``game`` (id, map, type), ``snake`` (class
            name of ``game_board.snake_class``), ``did_win``,
            ``total_samples`` and ``samples``. Each sample carries ``turn``,
            ``move``, ``game_board``, ``is_good_move`` and ``history``.
        """
        game_type = self.game_board.get_type_of_game()
        did_win = self._did_we_win()
        # Treat a missing history / turn list like an empty one, mirroring
        # the None guard already applied to the winner list.
        history = self.game_board.snake_class.get_history() or []
        turns = self.game_board.turns or []

        samples = []
        for index, turn in enumerate(turns):
            move = turn.get("move")
            is_good_move = did_win and self._is_good_move(move)
            if only_good_moves and not is_good_move:
                continue
            samples.append({
                "turn": turn.get("turn"),
                "move": move,
                "game_board": turn.get("game_board"),
                "is_good_move": is_good_move,
                # History is matched to turns by position; turns beyond the
                # end of the history get an empty entry.
                "history": history[index] if index < len(history) else {},
            })

        return {
            "game": {
                "id": self.game_board.id,
                "map": self.game_board.map,
                "type": game_type,
            },
            "snake": {
                "type": self.game_board.snake_class.__class__.__name__,
            },
            "did_win": did_win,
            "total_samples": len(samples),
            "samples": samples,
        }

    def labels_by_turn(self):
        """Map each turn's ``turn`` value to its good-move label.

        Returns:
            A dict ``{turn: bool}``; empty when the board has no turns.
        """
        did_win = self._did_we_win()
        return {
            turn.get("turn"): did_win and self._is_good_move(turn.get("move"))
            for turn in (self.game_board.turns or [])
        }
+64
View File
@@ -0,0 +1,64 @@
import argparse
import json
from pathlib import Path
class DatasetExporter:
    """Flatten stored game JSON files into one JSONL file of samples."""

    def __init__(self, input_dir:str, output_file:str):
        """Remember the directory to scan and the JSONL file to write."""
        self.input_dir = Path(input_dir)
        self.output_file = Path(output_file)

    @staticmethod
    def _as_dict(value):
        """Return *value* if it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    def _iter_game_files(self):
        """Return the sorted ``*.json`` files found below ``input_dir``.

        Returns an empty list when the directory does not exist. Directories
        that happen to match the pattern are skipped, and so is the output
        file itself: it is truncated when opened for writing, so parsing it
        as a game file would fail.
        """
        if not self.input_dir.exists():
            return []
        output_path = self.output_file.resolve()
        return sorted(
            path
            for path in self.input_dir.rglob("*.json")
            if path.is_file() and path.resolve() != output_path
        )

    def _extract_samples(self, payload:dict, source_file:Path):
        """Flatten the samples of one game payload into export rows.

        Game and snake metadata are read from ``payload["dataset"]`` when
        that section has the key, otherwise from the top level of
        *payload*. Missing, null or wrongly-typed sections are treated as
        empty, and sample entries that are not dicts are skipped.

        Args:
            payload: Parsed JSON content of one game file.
            source_file: Path of that file, recorded in every row.

        Returns:
            A list of flat dicts, one per sample.
        """
        payload = self._as_dict(payload)
        dataset = self._as_dict(payload.get("dataset"))
        game_info = self._as_dict(dataset.get("game", payload.get("game")))
        snake_info = self._as_dict(dataset.get("snake", payload.get("snake")))

        raw_samples = dataset.get("samples")
        if not isinstance(raw_samples, list):
            raw_samples = []

        return [
            {
                "game_id": game_info.get("id"),
                "game_map": game_info.get("map"),
                "game_type": game_info.get("type"),
                "snake_type": snake_info.get("type"),
                "turn": sample.get("turn"),
                "move": sample.get("move"),
                "is_good_move": sample.get("is_good_move", False),
                "game_board": sample.get("game_board"),
                "history": sample.get("history"),
                "source_file": str(source_file),
            }
            for sample in raw_samples
            if isinstance(sample, dict)
        ]

    def export_jsonl(self):
        """Write every sample of every game file as one JSON line.

        The parent directory of the output file is created if needed and an
        existing output file is overwritten. A game file that is not valid
        JSON makes ``json.load`` raise; the error is not caught here.

        Returns:
            A report dict with ``games_scanned``, ``samples_exported`` and
            ``output_file``.
        """
        # Collect the inputs before the output file is opened (and thereby
        # created or truncated).
        game_files = self._iter_game_files()
        self.output_file.parent.mkdir(parents=True, exist_ok=True)
        sample_count = 0
        with self.output_file.open("w", encoding="utf-8") as output:
            for game_file in game_files:
                with game_file.open("r", encoding="utf-8") as source:
                    payload = json.load(source)
                for sample in self._extract_samples(payload, game_file):
                    output.write(json.dumps(sample, ensure_ascii=False) + "\n")
                    sample_count += 1
        return {
            "games_scanned": len(game_files),
            "samples_exported": sample_count,
            "output_file": str(self.output_file),
        }
if __name__ == "__main__":
    # Command-line entry point: read the --input / --output options, run the
    # export and print the resulting report as indented JSON.
    cli = argparse.ArgumentParser(description="Export Battlesnake dataset to JSONL")
    cli.add_argument(
        "--input",
        default="data",
        help="Input directory with stored game JSON files",
    )
    cli.add_argument(
        "--output",
        default="data/dataset/good_moves.jsonl",
        help="Output JSONL file",
    )
    options = cli.parse_args()

    exporter = DatasetExporter(options.input, options.output)
    summary = exporter.export_jsonl()
    print(json.dumps(summary, indent=2))