add Dataset Class and Tests
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
from server.GameBoard import GameBoard
|
||||
|
||||
class Dataset:
    """Turn the turns recorded on a game board into labelled move samples.

    A move is labelled "good" only when the name ``"me"`` is listed in
    ``game_board.winner_snake_names`` and the move is one of
    ``VALID_MOVES``.
    """

    VALID_MOVES = {"up", "down", "left", "right"}

    def __init__(self, game_board: "GameBoard"):
        """Keep a reference to the game board that will be exported.

        The annotation is a string so that defining the class does not
        require ``GameBoard`` to be evaluated at definition time.
        """
        self.game_board = game_board

    def _did_we_win(self):
        """Return True when ``"me"`` is among the recorded winner names."""
        winners = self.game_board.winner_snake_names or []
        # A bare string would turn ``in`` into a substring test
        # ("me" in "awesome" is True), so treat it as a single name.
        if isinstance(winners, str):
            winners = [winners]
        return "me" in winners

    def _is_good_move(self, move: str):
        """Return True when ``move`` is one of ``VALID_MOVES``.

        Non-string values (``None``, lists, ...) are never valid; checking
        the type first avoids a ``TypeError`` for unhashable values.
        """
        return isinstance(move, str) and move in self.VALID_MOVES

    def _label_move(self, move, did_win):
        """Single source of truth for the label used by both exports."""
        return bool(did_win) and self._is_good_move(move)

    def build(self, only_good_moves: bool = True):
        """Build the dataset payload for the game.

        Args:
            only_good_moves: When True, turns whose move is not labelled
                good are left out of ``samples``.

        Returns:
            A dict with ``game`` (id, map, type), ``snake`` (class name of
            ``game_board.snake_class``), ``did_win``, ``total_samples`` and
            ``samples``. Each sample carries ``turn``, ``move``,
            ``game_board``, ``is_good_move`` and ``history``.
        """
        board = self.game_board
        game_type = board.get_type_of_game()
        did_win = self._did_we_win()

        samples = []
        # A missing history or turn list is treated as empty instead of
        # raising.
        history = board.snake_class.get_history() or []
        for index, turn in enumerate(board.turns or []):
            move = turn.get("move")
            is_good_move = self._label_move(move, did_win)
            if only_good_moves and not is_good_move:
                continue

            samples.append({
                "turn": turn.get("turn"),
                "move": move,
                "game_board": turn.get("game_board"),
                "is_good_move": is_good_move,
                # History is paired with turns by list position, not by
                # turn number; turns without an entry get an empty dict.
                "history": history[index] if index < len(history) else {},
            })

        return {
            "game": {
                "id": board.id,
                "map": board.map,
                "type": game_type,
            },
            "snake": {
                "type": board.snake_class.__class__.__name__,
            },
            "did_win": did_win,
            "total_samples": len(samples),
            "samples": samples,
        }

    def labels_by_turn(self):
        """Return a dict mapping each turn's ``"turn"`` value to its label."""
        did_win = self._did_we_win()
        return {
            turn.get("turn"): self._label_move(turn.get("move"), did_win)
            for turn in self.game_board.turns or []
        }
|
||||
@@ -0,0 +1,64 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class DatasetExporter:
    """Flatten stored game JSON files into one JSONL file of samples."""

    def __init__(self, input_dir: str, output_file: str):
        """Remember where to read game files from and where to write.

        Args:
            input_dir: Directory searched recursively for ``*.json`` files.
            output_file: Path of the JSONL file to (over)write.
        """
        self.input_dir = Path(input_dir)
        self.output_file = Path(output_file)

    def _iter_game_files(self):
        """Return the sorted list of ``*.json`` files under ``input_dir``.

        Returns an empty list when the directory does not exist. The output
        file itself is excluded: if it lives under ``input_dir`` with a
        ``.json`` suffix it would otherwise be read back as a game file
        right after being truncated for writing.
        """
        if not self.input_dir.exists():
            return []
        output_path = self.output_file.resolve()
        return sorted(
            path
            for path in self.input_dir.rglob("*.json")
            if path.is_file() and path.resolve() != output_path
        )

    @staticmethod
    def _as_dict(value):
        """Return ``value`` when it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    def _extract_samples(self, payload: dict, source_file: Path):
        """Flatten one parsed game payload into a list of sample rows.

        Game and snake info are read from ``payload["dataset"]`` and fall
        back to the top-level ``"game"`` / ``"snake"`` keys when the dataset
        does not contain them. Sections that are missing, ``null`` or of an
        unexpected type yield no data instead of raising.

        Args:
            payload: Parsed JSON content of a game file.
            source_file: Path of the file the payload came from; stored on
                every row as ``source_file``.

        Returns:
            One dict per sample found in ``payload["dataset"]["samples"]``.
        """
        payload = self._as_dict(payload)
        dataset = self._as_dict(payload.get("dataset"))
        game_info = self._as_dict(dataset.get("game", payload.get("game")))
        snake_info = self._as_dict(dataset.get("snake", payload.get("snake")))

        raw_samples = dataset.get("samples")
        if not isinstance(raw_samples, list):
            raw_samples = []

        return [
            {
                "game_id": game_info.get("id"),
                "game_map": game_info.get("map"),
                "game_type": game_info.get("type"),
                "snake_type": snake_info.get("type"),
                "turn": sample.get("turn"),
                "move": sample.get("move"),
                "is_good_move": sample.get("is_good_move", False),
                "game_board": sample.get("game_board"),
                "history": sample.get("history"),
                "source_file": str(source_file),
            }
            for sample in raw_samples
            if isinstance(sample, dict)
        ]

    def export_jsonl(self):
        """Write every sample of every game file as one JSON line.

        Files that cannot be read or do not contain valid JSON are skipped
        and listed in the report, so one bad file does not abort the export.

        Returns:
            A report dict with ``games_scanned``, ``games_skipped``,
            ``skipped_files``, ``samples_exported`` and ``output_file``.
        """
        game_files = self._iter_game_files()
        self.output_file.parent.mkdir(parents=True, exist_ok=True)

        sample_count = 0
        skipped_files = []
        with self.output_file.open("w", encoding="utf-8") as output:
            for game_file in game_files:
                try:
                    with game_file.open("r", encoding="utf-8") as source:
                        payload = json.load(source)
                except (OSError, ValueError):
                    # ValueError covers json.JSONDecodeError and
                    # UnicodeDecodeError.
                    skipped_files.append(str(game_file))
                    continue

                for sample in self._extract_samples(payload, game_file):
                    output.write(json.dumps(sample, ensure_ascii=False) + "\n")
                    sample_count += 1

        return {
            "games_scanned": len(game_files),
            "games_skipped": len(skipped_files),
            "skipped_files": skipped_files,
            "samples_exported": sample_count,
            "output_file": str(self.output_file),
        }
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse --input/--output, run the export and
    # print the resulting report as indented JSON.
    cli = argparse.ArgumentParser(description="Export Battlesnake dataset to JSONL")
    cli.add_argument("--input", default="data", help="Input directory with stored game JSON files")
    cli.add_argument("--output", default="data/dataset/good_moves.jsonl", help="Output JSONL file")
    options = cli.parse_args()

    exporter = DatasetExporter(options.input, options.output)
    summary = exporter.export_jsonl()
    print(json.dumps(summary, indent=2))
|
||||
@@ -0,0 +1,50 @@
|
||||
import unittest
|
||||
|
||||
from server.Dataset import Dataset
|
||||
|
||||
class DummySnake:
    """Test double for a snake that exposes a fixed two-entry history."""

    def get_history(self):
        """Return history entries for turns 1 and 2, scored 1 and 2."""
        return [
            {"turn": number, "data": [{"score": number}]}
            for number in (1, 2)
        ]
|
||||
|
||||
class DummyGameBoard:
    """Test double for a game board with two recorded turns."""

    def __init__(self, winners):
        """Create the board; ``winners`` is stored as ``winner_snake_names``."""
        self.id, self.map = "game-1", "standard"
        self.winner_snake_names = winners
        self.snake_class = DummySnake()
        # Turn 1 moves "up", turn 2 moves "left"; each turn gets its own
        # board dict.
        self.turns = [
            {"turn": number, "move": direction, "game_board": {"width": 11, "height": 11}}
            for number, direction in enumerate(("up", "left"), start=1)
        ]

    def get_type_of_game(self):
        """Return the fixed game-type description used by the tests."""
        game_type = {"name": "standard"}
        game_type["is_ladder"] = False
        return game_type
|
||||
|
||||
class TestDataset(unittest.TestCase):
    """Tests for Dataset.build and Dataset.labels_by_turn using the dummies."""

    def test_build_only_good_moves_for_wins(self):
        """A won game keeps both turns and labels each of them as good."""
        won_board = DummyGameBoard(["me"])
        result = Dataset(won_board).build(only_good_moves=True)

        self.assertTrue(result["did_win"])
        self.assertEqual(result["total_samples"], 2)
        labels = [entry["is_good_move"] for entry in result["samples"]]
        self.assertTrue(all(labels))

    def test_build_returns_no_samples_for_losses_when_only_good(self):
        """A lost game yields no samples when only good moves are kept."""
        lost_board = DummyGameBoard(["enemy"])
        result = Dataset(lost_board).build(only_good_moves=True)

        self.assertFalse(result["did_win"])
        self.assertEqual(result["total_samples"], 0)

    def test_labels_by_turn(self):
        """Every turn is labelled True after a win and False after a loss."""
        labels_after_win = Dataset(DummyGameBoard(["me"])).labels_by_turn()
        labels_after_loss = Dataset(DummyGameBoard(["enemy"])).labels_by_turn()

        self.assertEqual(labels_after_win, {1: True, 2: True})
        self.assertEqual(labels_after_loss, {1: False, 2: False})
|
||||
|
||||
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
|
||||
@@ -0,0 +1,47 @@
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from server.DatasetExporter import DatasetExporter
|
||||
|
||||
class TestDatasetExporter(unittest.TestCase):
    """End-to-end test of DatasetExporter.export_jsonl on a temp directory."""

    def test_export_jsonl(self):
        """One stored game with one sample becomes one JSONL line."""
        stored_sample = {
            "turn": 1,
            "move": "up",
            "is_good_move": True,
            "game_board": {"width": 11, "height": 11},
            "history": {"data": []},
        }
        stored_game = {
            "dataset": {
                "game": {"id": "g-1", "map": "standard", "type": {"name": "duel"}},
                "snake": {"type": "BestBattleSnake"},
                "samples": [stored_sample],
            }
        }

        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            source_dir = root / "data"
            target = root / "out" / "dataset.jsonl"
            source_path = source_dir / "game-1.json"
            source_path.parent.mkdir(parents=True, exist_ok=True)
            source_path.write_text(json.dumps(stored_game), encoding="utf-8")

            exporter = DatasetExporter(str(source_dir), str(target))
            summary = exporter.export_jsonl()

            self.assertEqual(summary["games_scanned"], 1)
            self.assertEqual(summary["samples_exported"], 1)
            self.assertTrue(target.exists())

            exported = target.read_text(encoding="utf-8").strip().splitlines()
            self.assertEqual(len(exported), 1)
            record = json.loads(exported[0])
            self.assertEqual(record["game_id"], "g-1")
            self.assertEqual(record["move"], "up")
            self.assertTrue(record["is_good_move"])
|
||||
|
||||
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
|
||||
Reference in New Issue
Block a user