51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
import unittest
|
|
|
|
from server.Dataset import Dataset
|
|
|
|
class DummySnake:
    """Minimal snake stand-in that exposes a fixed two-turn history."""

    def get_history(self):
        """Return a newly built history list with one entry per turn.

        Each entry stores its turn number under ``"turn"`` and a one-element
        ``"data"`` list whose ``"score"`` equals that same turn number.
        """
        return [
            {"turn": turn_no, "data": [{"score": turn_no}]}
            for turn_no in (1, 2)
        ]
|
|
|
|
class DummyGameBoard:
    """Game-board stub carrying fixed metadata and two recorded turns."""

    def __init__(self, winners):
        """Fill the stub with constant game data.

        Args:
            winners: Stored unchanged as ``winner_snake_names``.
        """
        recorded_moves = ("up", "left")
        # A separate ``game_board`` dict is created for every turn entry.
        self.turns = [
            {
                "turn": turn_no,
                "move": direction,
                "game_board": {"width": 11, "height": 11},
            }
            for turn_no, direction in enumerate(recorded_moves, start=1)
        ]
        self.snake_class = DummySnake()
        self.winner_snake_names = winners
        self.map = "standard"
        self.id = "game-1"

    def get_type_of_game(self):
        """Return the fixed game-type description (standard, not ladder)."""
        return dict(name="standard", is_ladder=False)
|
|
|
|
class TestDataset(unittest.TestCase):
    """Checks ``Dataset`` output for a board stub that won and one that lost.

    The winning case passes ``["me"]`` as the winner list and the losing case
    passes ``["enemy"]``.
    """

    def test_build_only_good_moves_for_wins(self):
        """With ``["me"]`` as winners, two samples come back, all good moves."""
        winning_board = DummyGameBoard(["me"])
        payload = Dataset(winning_board).build(only_good_moves=True)

        self.assertTrue(payload["did_win"])
        self.assertEqual(payload["total_samples"], 2)
        every_sample_good = all(
            entry["is_good_move"] for entry in payload["samples"]
        )
        self.assertTrue(every_sample_good)

    def test_build_returns_no_samples_for_losses_when_only_good(self):
        """With ``["enemy"]`` as winners, the good-moves-only build is empty."""
        losing_board = DummyGameBoard(["enemy"])
        payload = Dataset(losing_board).build(only_good_moves=True)

        self.assertFalse(payload["did_win"])
        self.assertEqual(payload["total_samples"], 0)

    def test_labels_by_turn(self):
        """Both turns are labelled True for the win and False for the loss."""
        labels_for = {
            outcome: Dataset(DummyGameBoard(winners)).labels_by_turn()
            for outcome, winners in (("win", ["me"]), ("loss", ["enemy"]))
        }

        self.assertEqual(labels_for["win"], {1: True, 2: True})
        self.assertEqual(labels_for["loss"], {1: False, 2: False})
|
|
|
|
# Run the test cases in this module only when the file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    unittest.main()
|