add Dataset Class and Tests

This commit is contained in:
2026-04-03 10:35:21 +02:00
parent 6b69d133b6
commit 2e1f91355b
4 changed files with 217 additions and 0 deletions
+50
View File
@@ -0,0 +1,50 @@
import unittest
from server.Dataset import Dataset
class DummySnake:
def get_history(self):
return [
{"turn": 1, "data": [{"score": 1}]},
{"turn": 2, "data": [{"score": 2}]},
]
class DummyGameBoard:
def __init__(self, winners):
self.id = "game-1"
self.map = "standard"
self.winner_snake_names = winners
self.snake_class = DummySnake()
self.turns = [
{"turn": 1, "move": "up", "game_board": {"width": 11, "height": 11}},
{"turn": 2, "move": "left", "game_board": {"width": 11, "height": 11}},
]
def get_type_of_game(self):
return {"name": "standard", "is_ladder": False}
class TestDataset(unittest.TestCase):
def test_build_only_good_moves_for_wins(self):
dataset = Dataset(DummyGameBoard(["me"]))
payload = dataset.build(only_good_moves=True)
self.assertTrue(payload["did_win"])
self.assertEqual(payload["total_samples"], 2)
self.assertTrue(all(sample["is_good_move"] for sample in payload["samples"]))
def test_build_returns_no_samples_for_losses_when_only_good(self):
dataset = Dataset(DummyGameBoard(["enemy"]))
payload = dataset.build(only_good_moves=True)
self.assertFalse(payload["did_win"])
self.assertEqual(payload["total_samples"], 0)
def test_labels_by_turn(self):
winner_labels = Dataset(DummyGameBoard(["me"])).labels_by_turn()
loser_labels = Dataset(DummyGameBoard(["enemy"])).labels_by_turn()
self.assertEqual(winner_labels, {1: True, 2: True})
self.assertEqual(loser_labels, {1: False, 2: False})
if __name__ == "__main__":
unittest.main()
+47
View File
@@ -0,0 +1,47 @@
import json
import tempfile
import unittest
from pathlib import Path
from server.DatasetExporter import DatasetExporter
class TestDatasetExporter(unittest.TestCase):
def test_export_jsonl(self):
with tempfile.TemporaryDirectory() as tmp:
input_dir = Path(tmp) / "data"
output_file = Path(tmp) / "out" / "dataset.jsonl"
game_file = input_dir / "game-1.json"
game_file.parent.mkdir(parents=True, exist_ok=True)
game_payload = {
"dataset": {
"game": {"id": "g-1", "map": "standard", "type": {"name": "duel"}},
"snake": {"type": "BestBattleSnake"},
"samples": [
{
"turn": 1,
"move": "up",
"is_good_move": True,
"game_board": {"width": 11, "height": 11},
"history": {"data": []},
}
],
}
}
game_file.write_text(json.dumps(game_payload), encoding="utf-8")
report = DatasetExporter(str(input_dir), str(output_file)).export_jsonl()
self.assertEqual(report["games_scanned"], 1)
self.assertEqual(report["samples_exported"], 1)
self.assertTrue(output_file.exists())
lines = output_file.read_text(encoding="utf-8").strip().splitlines()
self.assertEqual(len(lines), 1)
first = json.loads(lines[0])
self.assertEqual(first["game_id"], "g-1")
self.assertEqual(first["move"], "up")
self.assertTrue(first["is_good_move"])
if __name__ == "__main__":
unittest.main()