rework folder structure complitly
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
from server.GameBoard import GameBoard
|
||||
|
||||
class Dataset:
|
||||
VALID_MOVES = {"up", "down", "left", "right"}
|
||||
|
||||
def __init__(self, game_board:GameBoard):
|
||||
self.game_board = game_board
|
||||
|
||||
def _did_we_win(self):
|
||||
winners = self.game_board.winner_snake_names or []
|
||||
return "me" in winners
|
||||
|
||||
def _is_good_move(self, move:str):
|
||||
return move in self.VALID_MOVES
|
||||
|
||||
def build(self, only_good_moves:bool=True):
|
||||
game_type = self.game_board.get_type_of_game()
|
||||
did_win = self._did_we_win()
|
||||
|
||||
samples = []
|
||||
history = self.game_board.snake_class.get_history()
|
||||
for index, turn in enumerate(self.game_board.turns):
|
||||
move = turn.get("move")
|
||||
is_good_move = did_win and self._is_good_move(move)
|
||||
if only_good_moves and not is_good_move:
|
||||
continue
|
||||
|
||||
samples.append({
|
||||
"turn": turn.get("turn"),
|
||||
"move": move,
|
||||
"game_board": turn.get("game_board"),
|
||||
"is_good_move": is_good_move,
|
||||
"history": history[index] if index < len(history) else {},
|
||||
})
|
||||
|
||||
return {
|
||||
"game": {
|
||||
"id": self.game_board.id,
|
||||
"map": self.game_board.map,
|
||||
"type": game_type,
|
||||
},
|
||||
"snake": {
|
||||
"type": self.game_board.snake_class.__class__.__name__,
|
||||
},
|
||||
"did_win": did_win,
|
||||
"total_samples": len(samples),
|
||||
"samples": samples,
|
||||
}
|
||||
|
||||
def labels_by_turn(self):
|
||||
did_win = self._did_we_win()
|
||||
labels = {}
|
||||
for turn in self.game_board.turns:
|
||||
move = turn.get("move")
|
||||
labels[turn.get("turn")] = did_win and self._is_good_move(move)
|
||||
return labels
|
||||
@@ -0,0 +1,258 @@
|
||||
import argparse, hashlib, shutil, glob, json
|
||||
from pathlib import Path
|
||||
|
||||
class DatasetCurator:
|
||||
def __init__(self, input_files:list[str], output_file:str, min_turn:int=6, late_turn:int=20, max_safe_options:int=2, min_score:int=3, append:bool=False, archive_input:bool=False, archive_dir:str|None=None):
|
||||
self.input_files = input_files
|
||||
self.output_file = Path(output_file)
|
||||
self.min_turn = min_turn
|
||||
self.late_turn = late_turn
|
||||
self.max_safe_options = max_safe_options
|
||||
self.min_score = min_score
|
||||
self.append = append
|
||||
self.archive_input = archive_input
|
||||
self.archive_dir = (
|
||||
Path(archive_dir) if archive_dir else self.output_file.parent / "archive"
|
||||
)
|
||||
|
||||
def _resolve_input_files(self):
|
||||
resolved = []
|
||||
seen = set()
|
||||
|
||||
for item in self.input_files:
|
||||
path = Path(item)
|
||||
if path.is_dir():
|
||||
for file_path in sorted(path.rglob("*.jsonl")):
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if any(ch in item for ch in "*?[]"):
|
||||
for match in sorted(glob.glob(item)):
|
||||
file_path = Path(match)
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
key = str(path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(path)
|
||||
|
||||
return resolved
|
||||
|
||||
def _safe_options_count(self, row:dict):
|
||||
history = row.get("history", {})
|
||||
for item in history.get("data", []):
|
||||
if item.get("function") == "get_possible_moves":
|
||||
return len(item.get("safe_positions", {}))
|
||||
return None
|
||||
|
||||
def _state_hash(self, row:dict):
|
||||
board = row.get("game_board", {})
|
||||
snakes = board.get("snakes", [])
|
||||
|
||||
snakes_key = []
|
||||
for snake in snakes:
|
||||
snakes_key.append((
|
||||
snake.get("id"),
|
||||
snake.get("health"),
|
||||
tuple((seg.get("x"), seg.get("y")) for seg in snake.get("body", [])),
|
||||
))
|
||||
|
||||
key = {
|
||||
"width": board.get("width"),
|
||||
"height": board.get("height"),
|
||||
"snakes": sorted(snakes_key),
|
||||
"food": sorted((f.get("x"), f.get("y")) for f in board.get("food", [])),
|
||||
"hazards": sorted((h.get("x"), h.get("y")) for h in board.get("hazards", [])),
|
||||
}
|
||||
raw = json.dumps(key, sort_keys=True, separators=(",", ":"))
|
||||
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
def _score(self, row:dict):
|
||||
score = 0
|
||||
turn = int(row.get("turn", 0))
|
||||
safe_options = self._safe_options_count(row)
|
||||
snakes = row.get("game_board", {}).get("snakes", [])
|
||||
opponents = max(0, len(snakes) - 1)
|
||||
|
||||
if turn >= self.late_turn:
|
||||
score += 2
|
||||
if safe_options is not None and safe_options <= self.max_safe_options:
|
||||
score += 3
|
||||
if opponents >= 1:
|
||||
score += 1
|
||||
|
||||
return score, safe_options
|
||||
|
||||
def curate(self):
|
||||
self.output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
input_paths = self._resolve_input_files()
|
||||
|
||||
total = 0
|
||||
kept = 0
|
||||
skipped_turn = 0
|
||||
skipped_quality = 0
|
||||
skipped_duplicate = 0
|
||||
seen_states = set()
|
||||
|
||||
if self.append and self.output_file.exists():
|
||||
with self.output_file.open("r", encoding="utf-8") as existing:
|
||||
for line in existing:
|
||||
if not line.strip():
|
||||
continue
|
||||
row = json.loads(line)
|
||||
state_key = self._state_hash(row)
|
||||
seen_states.add((state_key, row.get("move")))
|
||||
|
||||
mode = "a" if self.append else "w"
|
||||
with self.output_file.open(mode, encoding="utf-8") as dst:
|
||||
for input_path in input_paths:
|
||||
with input_path.open("r", encoding="utf-8") as src:
|
||||
for line in src:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
total += 1
|
||||
row = json.loads(line)
|
||||
|
||||
if not row.get("is_good_move", False):
|
||||
skipped_quality += 1
|
||||
continue
|
||||
|
||||
if int(row.get("turn", 0)) < self.min_turn:
|
||||
skipped_turn += 1
|
||||
continue
|
||||
|
||||
quality_score, safe_options = self._score(row)
|
||||
if quality_score < self.min_score:
|
||||
skipped_quality += 1
|
||||
continue
|
||||
|
||||
state_key = self._state_hash(row)
|
||||
dedupe_key = (state_key, row.get("move"))
|
||||
if dedupe_key in seen_states:
|
||||
skipped_duplicate += 1
|
||||
continue
|
||||
seen_states.add(dedupe_key)
|
||||
|
||||
compact_row = {
|
||||
"game_id": row.get("game_id"),
|
||||
"turn": row.get("turn"),
|
||||
"move": row.get("move"),
|
||||
"game_type": row.get("game_type"),
|
||||
"quality_score": quality_score,
|
||||
"safe_options": safe_options,
|
||||
"game_board": row.get("game_board"),
|
||||
}
|
||||
dst.write(json.dumps(compact_row, ensure_ascii=False) + "\n")
|
||||
kept += 1
|
||||
|
||||
archived_files = []
|
||||
if self.archive_input:
|
||||
archived_files = self._archive_processed_files(input_paths)
|
||||
|
||||
return {
|
||||
"input_files": [str(path) for path in input_paths],
|
||||
"total_rows": total,
|
||||
"kept_rows": kept,
|
||||
"skipped_turn": skipped_turn,
|
||||
"skipped_quality": skipped_quality,
|
||||
"skipped_duplicate": skipped_duplicate,
|
||||
"append_mode": self.append,
|
||||
"archive_input": self.archive_input,
|
||||
"archived_files": archived_files,
|
||||
"output_file": str(self.output_file),
|
||||
}
|
||||
|
||||
def _archive_processed_files(self, input_paths:list[Path]):
|
||||
self.archive_dir.mkdir(parents=True, exist_ok=True)
|
||||
archived = []
|
||||
|
||||
output_resolved = (
|
||||
self.output_file.resolve()
|
||||
if self.output_file.exists()
|
||||
else self.output_file
|
||||
)
|
||||
archive_resolved = self.archive_dir.resolve()
|
||||
|
||||
for source_path in input_paths:
|
||||
if not source_path.exists():
|
||||
continue
|
||||
|
||||
source_resolved = source_path.resolve()
|
||||
if source_resolved == output_resolved:
|
||||
continue
|
||||
if source_resolved.parent == archive_resolved:
|
||||
continue
|
||||
|
||||
destination = self.archive_dir / source_path.name
|
||||
if destination.exists():
|
||||
stem = destination.stem
|
||||
suffix = destination.suffix
|
||||
index = 1
|
||||
while True:
|
||||
candidate = self.archive_dir / f"{stem}.{index}{suffix}"
|
||||
if not candidate.exists():
|
||||
destination = candidate
|
||||
break
|
||||
index += 1
|
||||
|
||||
shutil.move(str(source_path), str(destination))
|
||||
archived.append(str(destination))
|
||||
|
||||
return archived
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Create curated best-moves dataset")
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
action="append",
|
||||
required=True,
|
||||
help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.",
|
||||
)
|
||||
parser.add_argument("--output", required=True, help="Output JSONL file")
|
||||
parser.add_argument("--min-turn", type=int, default=6)
|
||||
parser.add_argument("--late-turn", type=int, default=20)
|
||||
parser.add_argument("--max-safe-options", type=int, default=2)
|
||||
parser.add_argument("--min-score", type=int, default=3)
|
||||
parser.add_argument(
|
||||
"--append",
|
||||
action="store_true",
|
||||
help="Append to existing output and dedupe against existing rows",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--archive-input",
|
||||
action="store_true",
|
||||
help="Move processed input files to archive directory after successful curation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--archive-dir",
|
||||
default=None,
|
||||
help="Archive directory for processed input files (default: <output-dir>/archive)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
report = DatasetCurator(
|
||||
input_files=args.input,
|
||||
output_file=args.output,
|
||||
min_turn=args.min_turn,
|
||||
late_turn=args.late_turn,
|
||||
max_safe_options=args.max_safe_options,
|
||||
min_score=args.min_score,
|
||||
append=args.append,
|
||||
archive_input=args.archive_input,
|
||||
archive_dir=args.archive_dir,
|
||||
).curate()
|
||||
print(json.dumps(report, indent=2))
|
||||
@@ -0,0 +1,66 @@
|
||||
from pathlib import Path
|
||||
import argparse, json
|
||||
|
||||
class DatasetExporter:
|
||||
def __init__(self, input_dir:str, output_file:str):
|
||||
self.input_dir = Path(input_dir)
|
||||
self.output_file = Path(output_file)
|
||||
|
||||
def _iter_game_files(self):
|
||||
if not self.input_dir.exists():
|
||||
return []
|
||||
return sorted(self.input_dir.rglob("*.json"))
|
||||
|
||||
def _extract_samples(self, payload:dict, source_file:Path):
|
||||
dataset = payload.get("dataset", {})
|
||||
game_info = dataset.get("game", payload.get("game", {}))
|
||||
snake_info = dataset.get("snake", payload.get("snake", {}))
|
||||
|
||||
samples = []
|
||||
for sample in dataset.get("samples", []):
|
||||
samples.append({
|
||||
"game_id": game_info.get("id"),
|
||||
"game_map": game_info.get("map"),
|
||||
"game_type": game_info.get("type"),
|
||||
"snake_type": snake_info.get("type"),
|
||||
"turn": sample.get("turn"),
|
||||
"move": sample.get("move"),
|
||||
"is_good_move": sample.get("is_good_move", False),
|
||||
"game_board": sample.get("game_board"),
|
||||
"history": sample.get("history"),
|
||||
"source_file": str(source_file),
|
||||
})
|
||||
return samples
|
||||
|
||||
def export_jsonl(self):
|
||||
game_files = self._iter_game_files()
|
||||
self.output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
sample_count = 0
|
||||
with self.output_file.open("w", encoding="utf-8") as output:
|
||||
for game_file in game_files:
|
||||
with game_file.open("r", encoding="utf-8") as source:
|
||||
payload = json.load(source)
|
||||
|
||||
for sample in self._extract_samples(payload, game_file):
|
||||
output.write(json.dumps(sample, ensure_ascii=False) + "\n")
|
||||
sample_count += 1
|
||||
|
||||
return {
|
||||
"games_scanned": len(game_files),
|
||||
"samples_exported": sample_count,
|
||||
"output_file": str(self.output_file),
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Export Battlesnake dataset to JSONL")
|
||||
parser.add_argument(
|
||||
"--input", default="data", help="Input directory with stored game JSON files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", default="data/dataset/good_moves.jsonl", help="Output JSONL file"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
report = DatasetExporter(args.input, args.output).export_jsonl()
|
||||
print(json.dumps(report, indent=2))
|
||||
@@ -0,0 +1,243 @@
|
||||
from collections import Counter, defaultdict
|
||||
import argparse, glob, json, re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
class DatasetStats:
|
||||
DAY_PATTERN = re.compile(r"(\d{4}-\d{2}-\d{2})")
|
||||
|
||||
def __init__(self, input_files:list[str]):
|
||||
self.input_files = input_files
|
||||
|
||||
def _resolve_input_files(self):
|
||||
resolved = []
|
||||
seen = set()
|
||||
|
||||
for item in self.input_files:
|
||||
path = Path(item)
|
||||
if path.is_dir():
|
||||
for file_path in sorted(path.rglob("*.jsonl")):
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if any(ch in item for ch in "*?[]"):
|
||||
for match in sorted(glob.glob(item)):
|
||||
file_path = Path(match)
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
key = str(file_path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(file_path)
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
key = str(path.resolve())
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
resolved.append(path)
|
||||
|
||||
return resolved
|
||||
|
||||
def _infer_day(self, file_path:Path):
|
||||
match = self.DAY_PATTERN.search(file_path.name)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return datetime.fromtimestamp(file_path.stat().st_mtime).strftime("%Y-%m-%d")
|
||||
|
||||
def _game_score(self, game:dict):
|
||||
max_turn = game["max_turn"]
|
||||
rows = game["rows"]
|
||||
avg_safe = game["avg_safe_options"]
|
||||
pressure_bonus = 0 if avg_safe is None else max(0.0, 4.0 - avg_safe)
|
||||
return round(max_turn * 2.0 + rows + pressure_bonus, 3)
|
||||
|
||||
def _pressure_score(self, game:dict):
|
||||
max_turn = game["max_turn"]
|
||||
rows = max(1, game["rows"])
|
||||
pressure_turns = game["pressure_turns"]
|
||||
avg_safe = game["avg_safe_options"]
|
||||
|
||||
pressure_ratio = pressure_turns / rows
|
||||
safe_tightness = 0.0 if avg_safe is None else max(0.0, 3.0 - avg_safe)
|
||||
return round(max_turn * 1.2 + pressure_ratio * 120.0 + safe_tightness * 20.0, 3)
|
||||
|
||||
def _extract_safe_options(self, row:dict):
|
||||
top_level = row.get("safe_options")
|
||||
if isinstance(top_level, int):
|
||||
return top_level
|
||||
|
||||
history = row.get("history", {})
|
||||
for item in history.get("data", []):
|
||||
if item.get("function") != "get_possible_moves":
|
||||
continue
|
||||
safe_positions = item.get("safe_positions", {})
|
||||
if isinstance(safe_positions, dict):
|
||||
return len(safe_positions)
|
||||
return None
|
||||
|
||||
def analyze(self):
|
||||
files = self._resolve_input_files()
|
||||
|
||||
totals = {
|
||||
"rows": 0,
|
||||
"games": set(),
|
||||
"snake_types": Counter(),
|
||||
"game_types": Counter(),
|
||||
"moves": Counter(),
|
||||
"days": Counter(),
|
||||
}
|
||||
|
||||
games = {}
|
||||
day_games = defaultdict(set)
|
||||
|
||||
for file_path in files:
|
||||
day = self._infer_day(file_path)
|
||||
with file_path.open("r", encoding="utf-8") as source:
|
||||
for line in source:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
row = json.loads(line)
|
||||
game_id = row.get("game_id")
|
||||
if not game_id:
|
||||
continue
|
||||
|
||||
turn = int(row.get("turn", 0))
|
||||
safe_options = self._extract_safe_options(row)
|
||||
snake_type = row.get("snake_type", "unknown")
|
||||
move = row.get("move", "unknown")
|
||||
|
||||
game_type = row.get("game_type", {})
|
||||
if isinstance(game_type, dict):
|
||||
game_type_name = game_type.get("name", "unknown")
|
||||
else:
|
||||
game_type_name = str(game_type)
|
||||
|
||||
totals["rows"] += 1
|
||||
totals["games"].add(game_id)
|
||||
totals["snake_types"][snake_type] += 1
|
||||
totals["game_types"][game_type_name] += 1
|
||||
totals["moves"][move] += 1
|
||||
totals["days"][day] += 1
|
||||
|
||||
if game_id not in games:
|
||||
games[game_id] = {
|
||||
"game_id": game_id,
|
||||
"day": day,
|
||||
"snake_type": snake_type,
|
||||
"game_type": game_type_name,
|
||||
"rows": 0,
|
||||
"max_turn": -1,
|
||||
"safe_options_sum": 0,
|
||||
"safe_options_count": 0,
|
||||
"pressure_turns": 0,
|
||||
}
|
||||
|
||||
game = games[game_id]
|
||||
game["rows"] += 1
|
||||
game["max_turn"] = max(game["max_turn"], turn)
|
||||
if isinstance(safe_options, int):
|
||||
game["safe_options_sum"] += safe_options
|
||||
game["safe_options_count"] += 1
|
||||
if safe_options <= 2:
|
||||
game["pressure_turns"] += 1
|
||||
|
||||
day_games[day].add(game_id)
|
||||
|
||||
game_summaries = []
|
||||
for game in games.values():
|
||||
avg_safe = None
|
||||
if game["safe_options_count"] > 0:
|
||||
avg_safe = round(
|
||||
game["safe_options_sum"] / game["safe_options_count"], 3
|
||||
)
|
||||
item = {
|
||||
"game_id": game["game_id"],
|
||||
"day": game["day"],
|
||||
"snake_type": game["snake_type"],
|
||||
"game_type": game["game_type"],
|
||||
"rows": game["rows"],
|
||||
"max_turn": game["max_turn"],
|
||||
"avg_safe_options": avg_safe,
|
||||
"pressure_turns": game["pressure_turns"],
|
||||
}
|
||||
item["score"] = self._game_score(item)
|
||||
item["pressure_score"] = self._pressure_score(item)
|
||||
game_summaries.append(item)
|
||||
|
||||
game_summaries.sort(
|
||||
key=lambda x: (x["score"], x["max_turn"], x["rows"]), reverse=True
|
||||
)
|
||||
|
||||
best_overall = game_summaries[0] if game_summaries else None
|
||||
pressure_sorted = sorted(
|
||||
game_summaries,
|
||||
key=lambda x: (x["pressure_score"], x["max_turn"], x["rows"]),
|
||||
reverse=True,
|
||||
)
|
||||
best_pressure_overall = pressure_sorted[0] if pressure_sorted else None
|
||||
|
||||
by_day = {}
|
||||
for day, game_ids in sorted(day_games.items()):
|
||||
day_list = [item for item in game_summaries if item["game_id"] in game_ids]
|
||||
day_list.sort(
|
||||
key=lambda x: (x["score"], x["max_turn"], x["rows"]), reverse=True
|
||||
)
|
||||
day_pressure = sorted(
|
||||
day_list,
|
||||
key=lambda x: (x["pressure_score"], x["max_turn"], x["rows"]),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
by_day[day] = {
|
||||
"rows": totals["days"][day],
|
||||
"games": len(game_ids),
|
||||
"best_game": day_list[0] if day_list else None,
|
||||
"best_pressure_game": day_pressure[0] if day_pressure else None,
|
||||
}
|
||||
|
||||
return {
|
||||
"files_scanned": [str(path) for path in files],
|
||||
"overall": {
|
||||
"rows": totals["rows"],
|
||||
"games": len(totals["games"]),
|
||||
"snake_types": dict(totals["snake_types"]),
|
||||
"game_types": dict(totals["game_types"]),
|
||||
"moves": dict(totals["moves"]),
|
||||
"best_game": best_overall,
|
||||
"best_pressure_game": best_pressure_overall,
|
||||
},
|
||||
"by_day": by_day,
|
||||
"top_games": game_summaries[:10],
|
||||
"top_pressure_games": pressure_sorted[:10],
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Analyze Battlesnake JSONL datasets")
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
action="append",
|
||||
required=True,
|
||||
help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=None,
|
||||
help="Optional path to write JSON report",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
report = DatasetStats(args.input).analyze()
|
||||
print(json.dumps(report, indent=2))
|
||||
|
||||
if args.output:
|
||||
output_path = Path(args.output)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
|
||||
@@ -0,0 +1,4 @@
|
||||
from .Dataset import Dataset
|
||||
from .DatasetExporter import DatasetExporter
|
||||
from .DatasetCurator import DatasetCurator
|
||||
from .DatasetStats import DatasetStats
|
||||
Reference in New Issue
Block a user