From 2e1f91355bd79c062f184aaab0f5f71a33899437 Mon Sep 17 00:00:00 2001
From: Daniel Dolezal
Date: Fri, 3 Apr 2026 10:35:21 +0200
Subject: [PATCH] add Dataset Class and Tests

---
Review notes (not part of the commit message):

  * The line structure of this patch was restored from a copy whose
    newlines had been collapsed. Python indentation is assumed to be four
    spaces per level; confirm against the original commit.
  * Dataset._did_we_win() hard-codes the winner name "me".
  * Dataset._is_good_move() only checks membership in VALID_MOVES, so every
    valid move of a won game is labelled good and no move of a lost game is.
  * DatasetExporter.export_jsonl() opens the output file with mode "w"
    before any input file is read and does not catch json.load() errors, so
    a single malformed input file aborts the export and leaves a partial
    output file behind.
  * DatasetExporter._extract_samples() calls .get() on the payload and on
    payload["dataset"]; a JSON document that is not an object, or one with
    "dataset": null, raises AttributeError.

 server/Dataset.py             | 56 ++++++++++++++++++++++++++++++
 server/DatasetExporter.py     | 64 +++++++++++++++++++++++++++++++++++
 tests/test_Dataset.py         | 50 +++++++++++++++++++++++++++
 tests/test_DatasetExporter.py | 47 +++++++++++++++++++++++++
 4 files changed, 217 insertions(+)
 create mode 100644 server/Dataset.py
 create mode 100644 server/DatasetExporter.py
 create mode 100644 tests/test_Dataset.py
 create mode 100644 tests/test_DatasetExporter.py

diff --git a/server/Dataset.py b/server/Dataset.py
new file mode 100644
index 0000000..02be287
--- /dev/null
+++ b/server/Dataset.py
@@ -0,0 +1,56 @@
+from server.GameBoard import GameBoard
+
+class Dataset:
+    VALID_MOVES = {"up", "down", "left", "right"}
+
+    def __init__(self, game_board: GameBoard):
+        self.game_board = game_board
+
+    def _did_we_win(self):
+        winners = self.game_board.winner_snake_names or []
+        return "me" in winners
+
+    def _is_good_move(self, move: str):
+        return move in self.VALID_MOVES
+
+    def build(self, only_good_moves: bool = True):
+        game_type = self.game_board.get_type_of_game()
+        did_win = self._did_we_win()
+
+        samples = []
+        history = self.game_board.snake_class.get_history()
+        for index, turn in enumerate(self.game_board.turns):
+            move = turn.get("move")
+            is_good_move = did_win and self._is_good_move(move)
+            if only_good_moves and not is_good_move:
+                continue
+
+            samples.append({
+                "turn": turn.get("turn"),
+                "move": move,
+                "game_board": turn.get("game_board"),
+                "is_good_move": is_good_move,
+                "history": history[index] if index < len(history) else {},
+            })
+
+        return {
+            "game": {
+                "id": self.game_board.id,
+                "map": self.game_board.map,
+                "type": game_type,
+            },
+            "snake": {
+                "type": self.game_board.snake_class.__class__.__name__,
+            },
+            "did_win": did_win,
+            "total_samples": len(samples),
+            "samples": samples,
+        }
+
+    def labels_by_turn(self):
+        did_win = self._did_we_win()
+        labels = {}
+        for turn in self.game_board.turns:
+            move = turn.get("move")
+            labels[turn.get("turn")] = did_win and self._is_good_move(move)
+        return labels
diff --git a/server/DatasetExporter.py b/server/DatasetExporter.py
new file mode 100644
index 0000000..b936a0a
--- /dev/null
+++ b/server/DatasetExporter.py
@@ -0,0 +1,64 @@
+import argparse
+import json
+from pathlib import Path
+
+
+class DatasetExporter:
+    def __init__(self, input_dir:str, output_file:str):
+        self.input_dir = Path(input_dir)
+        self.output_file = Path(output_file)
+
+    def _iter_game_files(self):
+        if not self.input_dir.exists():
+            return []
+        return sorted(self.input_dir.rglob("*.json"))
+
+    def _extract_samples(self, payload:dict, source_file:Path):
+        dataset = payload.get("dataset", {})
+        game_info = dataset.get("game", payload.get("game", {}))
+        snake_info = dataset.get("snake", payload.get("snake", {}))
+
+        samples = []
+        for sample in dataset.get("samples", []):
+            samples.append({
+                "game_id": game_info.get("id"),
+                "game_map": game_info.get("map"),
+                "game_type": game_info.get("type"),
+                "snake_type": snake_info.get("type"),
+                "turn": sample.get("turn"),
+                "move": sample.get("move"),
+                "is_good_move": sample.get("is_good_move", False),
+                "game_board": sample.get("game_board"),
+                "history": sample.get("history"),
+                "source_file": str(source_file),
+            })
+        return samples
+
+    def export_jsonl(self):
+        game_files = self._iter_game_files()
+        self.output_file.parent.mkdir(parents=True, exist_ok=True)
+
+        sample_count = 0
+        with self.output_file.open("w", encoding="utf-8") as output:
+            for game_file in game_files:
+                with game_file.open("r", encoding="utf-8") as source:
+                    payload = json.load(source)
+
+                for sample in self._extract_samples(payload, game_file):
+                    output.write(json.dumps(sample, ensure_ascii=False) + "\n")
+                    sample_count += 1
+
+        return {
+            "games_scanned": len(game_files),
+            "samples_exported": sample_count,
+            "output_file": str(self.output_file),
+        }
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Export Battlesnake dataset to JSONL")
+    parser.add_argument("--input", default="data", help="Input directory with stored game JSON files")
+    parser.add_argument("--output", default="data/dataset/good_moves.jsonl", help="Output JSONL file")
+    args = parser.parse_args()
+
+    report = DatasetExporter(args.input, args.output).export_jsonl()
+    print(json.dumps(report, indent=2))
diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py
new file mode 100644
index 0000000..da45ac9
--- /dev/null
+++ b/tests/test_Dataset.py
@@ -0,0 +1,50 @@
+import unittest
+
+from server.Dataset import Dataset
+
+class DummySnake:
+    def get_history(self):
+        return [
+            {"turn": 1, "data": [{"score": 1}]},
+            {"turn": 2, "data": [{"score": 2}]},
+        ]
+
+class DummyGameBoard:
+    def __init__(self, winners):
+        self.id = "game-1"
+        self.map = "standard"
+        self.winner_snake_names = winners
+        self.snake_class = DummySnake()
+        self.turns = [
+            {"turn": 1, "move": "up", "game_board": {"width": 11, "height": 11}},
+            {"turn": 2, "move": "left", "game_board": {"width": 11, "height": 11}},
+        ]
+
+    def get_type_of_game(self):
+        return {"name": "standard", "is_ladder": False}
+
+class TestDataset(unittest.TestCase):
+    def test_build_only_good_moves_for_wins(self):
+        dataset = Dataset(DummyGameBoard(["me"]))
+        payload = dataset.build(only_good_moves=True)
+
+        self.assertTrue(payload["did_win"])
+        self.assertEqual(payload["total_samples"], 2)
+        self.assertTrue(all(sample["is_good_move"] for sample in payload["samples"]))
+
+    def test_build_returns_no_samples_for_losses_when_only_good(self):
+        dataset = Dataset(DummyGameBoard(["enemy"]))
+        payload = dataset.build(only_good_moves=True)
+
+        self.assertFalse(payload["did_win"])
+        self.assertEqual(payload["total_samples"], 0)
+
+    def test_labels_by_turn(self):
+        winner_labels = Dataset(DummyGameBoard(["me"])).labels_by_turn()
+        loser_labels = Dataset(DummyGameBoard(["enemy"])).labels_by_turn()
+
+        self.assertEqual(winner_labels, {1: True, 2: True})
+        self.assertEqual(loser_labels, {1: False, 2: False})
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_DatasetExporter.py b/tests/test_DatasetExporter.py
new file mode 100644
index 0000000..56bf361
--- /dev/null
+++ b/tests/test_DatasetExporter.py
@@ -0,0 +1,47 @@
+import json
+import tempfile
+import unittest
+from pathlib import Path
+
+from server.DatasetExporter import DatasetExporter
+
+class TestDatasetExporter(unittest.TestCase):
+    def test_export_jsonl(self):
+        with tempfile.TemporaryDirectory() as tmp:
+            input_dir = Path(tmp) / "data"
+            output_file = Path(tmp) / "out" / "dataset.jsonl"
+            game_file = input_dir / "game-1.json"
+            game_file.parent.mkdir(parents=True, exist_ok=True)
+
+            game_payload = {
+                "dataset": {
+                    "game": {"id": "g-1", "map": "standard", "type": {"name": "duel"}},
+                    "snake": {"type": "BestBattleSnake"},
+                    "samples": [
+                        {
+                            "turn": 1,
+                            "move": "up",
+                            "is_good_move": True,
+                            "game_board": {"width": 11, "height": 11},
+                            "history": {"data": []},
+                        }
+                    ],
+                }
+            }
+            game_file.write_text(json.dumps(game_payload), encoding="utf-8")
+
+            report = DatasetExporter(str(input_dir), str(output_file)).export_jsonl()
+
+            self.assertEqual(report["games_scanned"], 1)
+            self.assertEqual(report["samples_exported"], 1)
+            self.assertTrue(output_file.exists())
+
+            lines = output_file.read_text(encoding="utf-8").strip().splitlines()
+            self.assertEqual(len(lines), 1)
+            first = json.loads(lines[0])
+            self.assertEqual(first["game_id"], "g-1")
+            self.assertEqual(first["move"], "up")
+            self.assertTrue(first["is_good_move"])
+
+if __name__ == "__main__":
+    unittest.main()