57 lines
1.5 KiB
Python
57 lines
1.5 KiB
Python
from server.GameBoard import GameBoard
|
|
|
|
class Dataset:
    """Turn a finished game board into a labelled, per-turn sample set.

    A turn's move is labelled "good" only when our snake is among the game's
    winners *and* the recorded move is one of ``VALID_MOVES``.
    """

    # Move strings accepted as valid (and therefore labellable as good).
    VALID_MOVES = {"up", "down", "left", "right"}

    # Name our snake is listed under in ``winner_snake_names``. Kept as a
    # class-level default so instances created without ``__init__`` still work.
    snake_name = "me"

    def __init__(self, game_board: "GameBoard", snake_name: str = "me"):
        """Wrap *game_board*.

        Args:
            game_board: Board whose turns, winners, map/id and snake history
                are read by :meth:`build` and :meth:`labels_by_turn`.
            snake_name: Name our snake appears under in
                ``game_board.winner_snake_names``. Defaults to ``"me"``, the
                value this class previously hard-coded.
        """
        self.game_board = game_board
        self.snake_name = snake_name

    def _did_we_win(self) -> bool:
        """Return True if our snake is listed among the game's winners.

        A falsy ``winner_snake_names`` (e.g. ``None``) is treated as
        "no winners".
        """
        winners = self.game_board.winner_snake_names or []
        return self.snake_name in winners

    def _is_good_move(self, move: object) -> bool:
        """Return True if *move* is one of ``VALID_MOVES``.

        Non-string values (``None``, or an unhashable list/dict coming from a
        malformed turn) are rejected instead of raising ``TypeError`` from the
        set lookup.
        """
        return isinstance(move, str) and move in self.VALID_MOVES

    def _label(self, move: object, did_win: bool) -> bool:
        """Label one move: good only if we won and the move is valid."""
        # Short-circuits: for a lost game the move is never inspected.
        return did_win and self._is_good_move(move)

    def build(self, only_good_moves: bool = True) -> dict:
        """Build a dict describing the game and its per-turn samples.

        Args:
            only_good_moves: When True (the default), turns whose move is not
                labelled good are skipped. If we did not win, every label is
                False, so the sample list is then empty.

        Returns:
            A dict with keys ``"game"`` (``id``, ``map``, ``type``),
            ``"snake"`` (class name of ``game_board.snake_class``),
            ``"did_win"``, ``"total_samples"`` and ``"samples"``. Each sample
            holds ``"turn"``, ``"move"``, ``"game_board"``,
            ``"is_good_move"`` and ``"history"``.
        """
        board = self.game_board
        game_type = board.get_type_of_game()
        did_win = self._did_we_win()

        # A missing history (e.g. None) is treated as empty, mirroring how
        # winner_snake_names is handled in _did_we_win.
        history = board.snake_class.get_history() or []
        # NOTE(review): history entries are paired with turns by position in
        # ``board.turns`` (not by turn number); this assumes get_history()
        # yields one entry per turn in the same order — confirm.

        samples = []
        for index, turn in enumerate(board.turns):
            move = turn.get("move")
            is_good_move = self._label(move, did_win)
            if only_good_moves and not is_good_move:
                continue
            samples.append({
                "turn": turn.get("turn"),
                "move": move,
                "game_board": turn.get("game_board"),
                "is_good_move": is_good_move,
                # Turns past the end of the history get an empty placeholder.
                "history": history[index] if index < len(history) else {},
            })

        return {
            "game": {
                "id": board.id,
                "map": board.map,
                "type": game_type,
            },
            "snake": {
                "type": board.snake_class.__class__.__name__,
            },
            "did_win": did_win,
            "total_samples": len(samples),
            "samples": samples,
        }

    def labels_by_turn(self) -> dict:
        """Map each turn's ``"turn"`` value to its good-move label.

        Applies the same rule as :meth:`build`. If two turns share a
        ``"turn"`` value, the later one overwrites the earlier.
        """
        did_win = self._did_we_win()
        return {
            turn.get("turn"): self._label(turn.get("move"), did_win)
            for turn in self.game_board.turns
        }