Files
snake-python/server/dataset/Dataset.py
T

60 lines
1.6 KiB
Python

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from server.GameBoard import GameBoard
class Dataset:
    """Turn a finished game on a ``GameBoard`` into labelled training samples.

    A turn's move is labelled "good" only when the snake named ``"me"`` is
    among the game's winners AND the move is one of ``VALID_MOVES``; so in a
    lost game every move is labelled bad.
    """

    # Accepted move directions. A frozenset so this shared class attribute
    # cannot be mutated accidentally through one instance.
    VALID_MOVES = frozenset({"up", "down", "left", "right"})

    def __init__(self, game_board: 'GameBoard'):
        """Wrap ``game_board``; no work happens until ``build``/``labels_by_turn``."""
        self.game_board = game_board

    def _did_we_win(self) -> bool:
        """Return True if the snake named ``"me"`` is among the winners.

        A missing/empty ``winner_snake_names`` means no winners. A bare string
        is wrapped in a list so ``in`` checks name equality rather than
        substring containment (``"me" in "name"`` would otherwise be True).
        """
        winners = self.game_board.winner_snake_names or []
        if isinstance(winners, str):
            winners = [winners]
        return "me" in winners

    def _is_good_move(self, move: str) -> bool:
        """Return True if ``move`` is one of ``VALID_MOVES``.

        Non-string values (None, dicts, ...) are rejected instead of raising
        ``TypeError`` for unhashable input.
        """
        return isinstance(move, str) and move in self.VALID_MOVES

    def _label(self, did_win: bool, move) -> bool:
        """Labelling rule shared by ``build`` and ``labels_by_turn``."""
        return bool(did_win) and self._is_good_move(move)

    def _snake_type_name(self) -> str:
        """Return the class name of ``game_board.snake_class``.

        Works whether ``snake_class`` holds an instance or a class object; for
        a class object ``__class__.__name__`` would wrongly yield ``"type"``.
        """
        snake = self.game_board.snake_class
        if isinstance(snake, type):
            return snake.__name__
        return snake.__class__.__name__

    def build(self, only_good_moves: bool = True) -> dict:
        """Build a dataset dict describing the whole game.

        Args:
            only_good_moves: when True (default), skip every turn whose label
                is False, which yields no samples at all for a lost game.

        Returns:
            A dict with ``game`` (``id``, ``map``, ``type``), ``snake``
            (``type`` = snake class name), ``did_win``, ``total_samples`` and
            ``samples``. Each sample has ``turn``, ``move``, ``game_board``,
            ``is_good_move`` and ``history`` (the history entry at the same
            position as the turn, or ``{}`` when history is shorter).
        """
        game_type = self.game_board.get_type_of_game()
        did_win = self._did_we_win()
        # A missing history is treated as empty instead of failing on len(None).
        history = self.game_board.snake_class.get_history() or []
        samples = []
        # Enumerate over ALL turns before filtering so history[index] stays
        # aligned with the turn's position, not the sample's position.
        for index, turn in enumerate(self.game_board.turns or []):
            move = turn.get("move")
            is_good_move = self._label(did_win, move)
            if only_good_moves and not is_good_move:
                continue
            samples.append({
                "turn": turn.get("turn"),
                "move": move,
                "game_board": turn.get("game_board"),
                "is_good_move": is_good_move,
                "history": history[index] if index < len(history) else {},
            })
        return {
            "game": {
                "id": self.game_board.id,
                "map": self.game_board.map,
                "type": game_type,
            },
            "snake": {
                "type": self._snake_type_name(),
            },
            "did_win": did_win,
            "total_samples": len(samples),
            "samples": samples,
        }

    def labels_by_turn(self) -> dict:
        """Map each turn number to its good/bad label (same rule as ``build``).

        If several turns share a turn number, the last one wins.
        """
        did_win = self._did_we_win()
        return {
            turn.get("turn"): self._label(did_win, turn.get("move"))
            for turn in self.game_board.turns or []
        }