from server.GameBoard import GameBoard


class Dataset:
    """Build labelled move samples from the game held by a ``GameBoard``.

    A move is labelled "good" only when both conditions hold:

    * ``OUR_SNAKE_NAME`` appears in ``game_board.winner_snake_names``, and
    * the move is one of ``VALID_MOVES``.
    """

    # The only move strings that can ever be labelled as good.
    VALID_MOVES = {"up", "down", "left", "right"}

    # Name searched for in ``winner_snake_names`` to decide whether the game
    # counts as won.
    # NOTE(review): presumably the name our own snake plays under -- confirm.
    OUR_SNAKE_NAME = "me"

    def __init__(self, game_board: GameBoard):
        """Store the board whose turns will be converted into samples."""
        self.game_board = game_board

    def _did_we_win(self) -> bool:
        """Return True when ``OUR_SNAKE_NAME`` is listed among the winners.

        A falsy ``winner_snake_names`` (e.g. ``None``) is treated as
        "no winners".
        """
        winners = self.game_board.winner_snake_names or []
        return self.OUR_SNAKE_NAME in winners

    def _is_good_move(self, move: object) -> bool:
        """Return True when *move* is a string contained in ``VALID_MOVES``.

        Any non-string value (``None``, a list, a dict, ...) yields False.
        """
        # The isinstance guard must come first: set membership hashes its
        # operand, so an unhashable move would otherwise raise TypeError.
        return isinstance(move, str) and move in self.VALID_MOVES

    def build(self, only_good_moves: bool = True) -> dict:
        """Assemble the dataset dictionary for the stored game.

        Args:
            only_good_moves: When True (the default), turns whose move is not
                labelled good are left out of ``samples``. Because a good
                label requires a win, a lost game then yields no samples.

        Returns:
            A dict with the keys ``game`` (id, map, type), ``snake`` (class
            name of ``game_board.snake_class``), ``did_win``,
            ``total_samples`` and ``samples``. Each sample carries ``turn``,
            ``move``, ``game_board``, ``is_good_move`` and ``history``.
        """
        board = self.game_board
        game_type = board.get_type_of_game()
        did_win = self._did_we_win()

        history = board.snake_class.get_history()
        if history is None:
            # Treat a missing history like an empty one so every sample
            # falls back to ``{}`` instead of crashing on ``len(None)``.
            history = []

        samples = []
        for index, turn in enumerate(board.turns):
            move = turn.get("move")
            is_good_move = did_win and self._is_good_move(move)
            if only_good_moves and not is_good_move:
                continue
            samples.append({
                "turn": turn.get("turn"),
                "move": move,
                "game_board": turn.get("game_board"),
                "is_good_move": is_good_move,
                # History is paired with turns purely by position; turns
                # beyond the end of the history get an empty entry.
                "history": history[index] if index < len(history) else {},
            })

        return {
            "game": {
                "id": board.id,
                "map": board.map,
                "type": game_type,
            },
            "snake": {
                "type": board.snake_class.__class__.__name__,
            },
            "did_win": did_win,
            "total_samples": len(samples),
            "samples": samples,
        }

    def labels_by_turn(self) -> dict:
        """Map each turn's ``"turn"`` value to its good-move label.

        The label uses the same rule as ``build``: the game was won and the
        move is valid. If several turns share a ``"turn"`` value, the last
        one wins.
        """
        did_win = self._did_we_win()
        return {
            turn.get("turn"): did_win and self._is_good_move(turn.get("move"))
            for turn in self.game_board.turns
        }