Files
snake-python/tests/test_Dataset.py
T

53 lines
1.7 KiB
Python

import unittest
from typing import cast
from server.dataset.Dataset import Dataset
from server.GameBoard import GameBoard
class DummySnake:
    """Minimal stand-in for a snake; only ``get_history`` is provided."""

    def get_history(self):
        """Return a fixed two-entry history.

        Each entry has the form ``{"turn": t, "data": [{"score": t}]}``,
        i.e. the score recorded for a turn equals its turn number.
        """
        return [{"turn": turn, "data": [{"score": turn}]} for turn in (1, 2)]
class DummyGameBoard:
    """Minimal stand-in for ``GameBoard`` exposing only what the tests touch.

    Args:
        winners: value stored as ``winner_snake_names``.
    """

    def __init__(self, winners):
        self.id = "game-1"
        self.map = "standard"
        self.winner_snake_names = winners
        self.snake_class = DummySnake()
        # Two recorded turns on an 11x11 board; each turn gets its own
        # game_board dict (built fresh per iteration).
        self.turns = [
            {"turn": number, "move": move, "game_board": {"width": 11, "height": 11}}
            for number, move in enumerate(("up", "left"), start=1)
        ]

    def get_type_of_game(self):
        """Return a fixed, non-ladder "standard" game-type descriptor."""
        return dict(name="standard", is_ladder=False)
class TestDataset(unittest.TestCase):
    """Tests for ``Dataset.build`` and ``Dataset.labels_by_turn`` on stub boards."""

    def _make_dataset(self, winners):
        """Build a ``Dataset`` over a ``DummyGameBoard`` whose winners are ``winners``."""
        # cast() only satisfies the type checker; the stub is duck-typed.
        return Dataset(cast(GameBoard, DummyGameBoard(winners)))

    def test_build_only_good_moves_for_wins(self):
        payload = self._make_dataset(["me"]).build(only_good_moves=True)
        self.assertTrue(payload["did_win"])
        self.assertEqual(payload["total_samples"], 2)
        samples = payload["samples"]
        # all() over an empty list is vacuously True, so require at least one
        # sample before checking the per-sample labels.
        self.assertTrue(samples, "expected samples for a winning game")
        for index, sample in enumerate(samples):
            with self.subTest(sample=index):
                self.assertTrue(sample["is_good_move"])

    def test_build_returns_no_samples_for_losses_when_only_good(self):
        payload = self._make_dataset(["enemy"]).build(only_good_moves=True)
        self.assertFalse(payload["did_win"])
        self.assertEqual(payload["total_samples"], 0)

    def test_labels_by_turn(self):
        winner_labels = self._make_dataset(["me"]).labels_by_turn()
        loser_labels = self._make_dataset(["enemy"]).labels_by_turn()
        self.assertEqual(winner_labels, {1: True, 2: True})
        self.assertEqual(loser_labels, {1: False, 2: False})
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main(module="__main__")