import unittest
from typing import cast

from server.Dataset import Dataset
from server.GameBoard import GameBoard


class DummySnake:
    """Minimal snake stand-in that exposes a fixed two-turn history."""

    def get_history(self):
        # One entry per turn; the stub score simply equals the turn number.
        return [
            {"turn": turn_no, "data": [{"score": turn_no}]}
            for turn_no in (1, 2)
        ]


class DummyGameBoard:
    """Minimal game-board stand-in carrying the attributes set up for Dataset.

    Only the winner list varies between instances; every other field is a
    fixed fixture value.
    """

    def __init__(self, winners):
        self.id = "game-1"
        self.map = "standard"
        self.winner_snake_names = winners
        self.snake_class = DummySnake()
        # Two recorded turns; each entry gets its own board-size dict.
        self.turns = [
            {
                "turn": turn_no,
                "move": direction,
                "game_board": {"width": 11, "height": 11},
            }
            for turn_no, direction in ((1, "up"), (2, "left"))
        ]

    def get_type_of_game(self):
        return {"name": "standard", "is_ladder": False}


class TestDataset(unittest.TestCase):
    """Checks Dataset.build and Dataset.labels_by_turn against stub boards."""

    @staticmethod
    def _dataset_for(winners):
        # The stub is not a real GameBoard; cast() only satisfies the
        # static type checker and has no runtime effect.
        board = DummyGameBoard(winners)
        return Dataset(cast(GameBoard, board))

    def test_build_only_good_moves_for_wins(self):
        # Winner list ["me"]: both turns are expected back, all marked good.
        result = self._dataset_for(["me"]).build(only_good_moves=True)

        self.assertTrue(result["did_win"])
        self.assertEqual(result["total_samples"], 2)
        self.assertTrue(all(entry["is_good_move"] for entry in result["samples"]))

    def test_build_returns_no_samples_for_losses_when_only_good(self):
        # Winner list ["enemy"]: no samples are expected when filtering.
        result = self._dataset_for(["enemy"]).build(only_good_moves=True)

        self.assertFalse(result["did_win"])
        self.assertEqual(result["total_samples"], 0)

    def test_labels_by_turn(self):
        # Compute both label maps before asserting on either of them.
        labels_when_won = self._dataset_for(["me"]).labels_by_turn()
        labels_when_lost = self._dataset_for(["enemy"]).labels_by_turn()

        self.assertEqual(labels_when_won, {1: True, 2: True})
        self.assertEqual(labels_when_lost, {1: False, 2: False})


if __name__ == "__main__":
    unittest.main()