diff --git a/server/dataset/DatasetCurator.py b/server/dataset/DatasetCurator.py index a6a0f35..cd76911 100644 --- a/server/dataset/DatasetCurator.py +++ b/server/dataset/DatasetCurator.py @@ -1,6 +1,8 @@ -import argparse, hashlib, shutil, glob, json +import argparse, hashlib, shutil, json from pathlib import Path +from server.dataset.DatasetIO import DatasetIO + class DatasetCurator: def __init__(self, input_files:list[str], output_file:str, min_turn:int=6, late_turn:int=20, max_safe_options:int=2, min_score:int=3, append:bool=False, archive_input:bool=False, archive_dir:str|None=None): self.input_files = input_files @@ -11,45 +13,7 @@ class DatasetCurator: self.min_score = min_score self.append = append self.archive_input = archive_input - self.archive_dir = ( - Path(archive_dir) if archive_dir else self.output_file.parent / "archive" - ) - - def _resolve_input_files(self): - resolved = [] - seen = set() - - for item in self.input_files: - path = Path(item) - if path.is_dir(): - for file_path in sorted(path.rglob("*.jsonl")): - key = str(file_path.resolve()) - if key in seen: - continue - seen.add(key) - resolved.append(file_path) - continue - - if any(ch in item for ch in "*?[]"): - for match in sorted(glob.glob(item)): - file_path = Path(match) - if not file_path.is_file(): - continue - key = str(file_path.resolve()) - if key in seen: - continue - seen.add(key) - resolved.append(file_path) - continue - - if path.is_file(): - key = str(path.resolve()) - if key in seen: - continue - seen.add(key) - resolved.append(path) - - return resolved + self.archive_dir = Path(archive_dir) if archive_dir else self.output_file.parent / "archive" def _safe_options_count(self, row:dict): history = row.get("history", {}) @@ -98,7 +62,7 @@ class DatasetCurator: def curate(self): self.output_file.parent.mkdir(parents=True, exist_ok=True) - input_paths = self._resolve_input_files() + input_paths = DatasetIO.resolve_input_files(self.input_files) total = 0 kept = 0 @@ -108,56 +72,47 @@ class DatasetCurator: seen_states = set() if self.append and self.output_file.exists(): - with self.output_file.open("r", encoding="utf-8") as existing: - for line in existing: - if not line.strip(): - continue - row = json.loads(line) - state_key = self._state_hash(row) - seen_states.add((state_key, row.get("move"))) + for row in DatasetIO.iter_jsonl_rows(self.output_file): + state_key = self._state_hash(row) + seen_states.add((state_key, row.get("move"))) mode = "a" if self.append else "w" - with self.output_file.open(mode, encoding="utf-8") as dst: + with DatasetIO.open_text(self.output_file, mode) as dst: for input_path in input_paths: - with input_path.open("r", encoding="utf-8") as src: - for line in src: - if not line.strip(): - continue + for row in DatasetIO.iter_jsonl_rows(input_path): + total += 1 - total += 1 - row = json.loads(line) + if not row.get("is_good_move", False): + skipped_quality += 1 + continue - if not row.get("is_good_move", False): - skipped_quality += 1 - continue + if int(row.get("turn", 0)) < self.min_turn: + skipped_turn += 1 + continue - if int(row.get("turn", 0)) < self.min_turn: - skipped_turn += 1 - continue + quality_score, safe_options = self._score(row) + if quality_score < self.min_score: + skipped_quality += 1 + continue - quality_score, safe_options = self._score(row) - if quality_score < self.min_score: - skipped_quality += 1 - continue + state_key = self._state_hash(row) + dedupe_key = (state_key, row.get("move")) + if dedupe_key in seen_states: + skipped_duplicate += 1 + continue + seen_states.add(dedupe_key) - state_key = self._state_hash(row) - dedupe_key = (state_key, row.get("move")) - if dedupe_key in seen_states: - skipped_duplicate += 1 - continue - seen_states.add(dedupe_key) - - compact_row = { - "game_id": row.get("game_id"), - "turn": row.get("turn"), - "move": row.get("move"), - "game_type": row.get("game_type"), - "quality_score": quality_score, - "safe_options": safe_options, - "game_board": row.get("game_board"), - } - dst.write(json.dumps(compact_row, ensure_ascii=False) + "\n") - kept += 1 + compact_row = { + "game_id": row.get("game_id"), + "turn": row.get("turn"), + "move": row.get("move"), + "game_type": row.get("game_type"), + "quality_score": quality_score, + "safe_options": safe_options, + "game_board": row.get("game_board"), + } + dst.write(json.dumps(compact_row, ensure_ascii=False) + "\n") + kept += 1 archived_files = [] if self.archive_input: @@ -220,7 +175,7 @@ if __name__ == "__main__": "--input", action="append", required=True, - help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.", + help="Input JSONL/JSONL.GZ file, directory, or glob pattern. Repeat for multiple inputs.", ) parser.add_argument("--output", required=True, help="Output JSONL file") parser.add_argument("--min-turn", type=int, default=6) diff --git a/server/dataset/DatasetExporter.py b/server/dataset/DatasetExporter.py index c9aee09..bc03387 100644 --- a/server/dataset/DatasetExporter.py +++ b/server/dataset/DatasetExporter.py @@ -1,15 +1,15 @@ from pathlib import Path import argparse, json +from server.dataset.DatasetIO import DatasetIO + class DatasetExporter: def __init__(self, input_dir:str, output_file:str): self.input_dir = Path(input_dir) self.output_file = Path(output_file) def _iter_game_files(self): - if not self.input_dir.exists(): - return [] - return sorted(self.input_dir.rglob("*.json")) + return DatasetIO.list_directory_files(self.input_dir, directory_pattern="*.json") def _extract_samples(self, payload:dict, source_file:Path): dataset = payload.get("dataset", {}) @@ -37,7 +37,7 @@ class DatasetExporter: self.output_file.parent.mkdir(parents=True, exist_ok=True) sample_count = 0 - with self.output_file.open("w", encoding="utf-8") as output: + with DatasetIO.open_text(self.output_file, "w") as output: for game_file in game_files: with game_file.open("r", encoding="utf-8") as source: payload = json.load(source) diff --git a/server/dataset/DatasetIO.py b/server/dataset/DatasetIO.py new file mode 100644 index 0000000..bbf9d78 --- /dev/null +++ b/server/dataset/DatasetIO.py @@ -0,0 +1,134 @@ +from typing import Any, TextIO, cast +from datetime import UTC, datetime +from pathlib import Path +import gzip, glob, json + +class DatasetIO: + @staticmethod + def resolve_input_files(input_files: list[str], directory_pattern: str | tuple[str, ...] = ("*.jsonl", "*.jsonl.gz")) -> list[Path]: + patterns = ( + (directory_pattern,) + if isinstance(directory_pattern, str) + else directory_pattern + ) + resolved = [] + seen = set() + + for item in input_files: + path = Path(item) + if path.is_dir(): + for pattern in patterns: + for file_path in sorted(path.rglob(pattern)): + key = str(file_path.resolve()) + if key in seen: + continue + seen.add(key) + resolved.append(file_path) + continue + + if any(ch in item for ch in "*?[]"): + for match in sorted(glob.glob(item)): + file_path = Path(match) + if not file_path.is_file(): + continue + key = str(file_path.resolve()) + if key in seen: + continue + seen.add(key) + resolved.append(file_path) + continue + + if path.is_file(): + key = str(path.resolve()) + if key in seen: + continue + seen.add(key) + resolved.append(path) + + return resolved + + @staticmethod + def list_directory_files(input_dir:Path, directory_pattern:str) -> list[Path]: + if not input_dir.exists(): + return [] + return sorted(input_dir.rglob(directory_pattern)) + + @staticmethod + def open_text(path:Path, mode:str="r"): + text_mode = mode if "t" in mode else f"{mode}t" + open_fn = gzip.open if path.suffix == ".gz" else open + return cast(TextIO, open_fn(path, text_mode, encoding="utf-8")) + + @staticmethod + def iter_jsonl_rows(path:Path): + with DatasetIO.open_text(path, "r") as handle: + for line in handle: + if line.strip(): + yield json.loads(line) + + @staticmethod + def append_jsonl_row(path:Path, row:dict[str, Any]): + with DatasetIO.open_text(path, "a") as raw_handle: + handle = cast(TextIO, raw_handle) + handle.write(json.dumps(row, ensure_ascii=False) + "\n") + + @staticmethod + def count_jsonl_rows(path:Path) -> int: + count = 0 + + candidates = [path] + if path.suffix == ".gz": + candidates.append(path.with_suffix("")) + else: + candidates.append(Path(f"{path}.gz")) + + seen = set() + for candidate in candidates: + if candidate in seen: + continue + seen.add(candidate) + + if not candidate.exists() or not candidate.is_file(): + continue + + try: + with DatasetIO.open_text(candidate, "r") as handle: + for line in handle: + if line.strip(): + count += 1 + except (OSError, UnicodeError): + continue + + return count + + @staticmethod + def rotate_and_gzip_if_size_reached(path:Path, max_bytes:int) -> bool: + if max_bytes <= 0: + return False + if not path.exists() or not path.is_file(): + return False + if path.stat().st_size < max_bytes: + return False + + timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S") + if path.suffix == ".jsonl": + rotated_name = f"{path.stem}.{timestamp}.jsonl" + else: + rotated_name = f"{path.name}.{timestamp}" + + rotated_path = path.with_name(rotated_name) + suffix = 1 + while rotated_path.exists(): + suffix += 1 + if path.suffix == ".jsonl": + rotated_name = f"{path.stem}.{timestamp}.{suffix}.jsonl" + else: + rotated_name = f"{path.name}.{timestamp}.{suffix}" + rotated_path = path.with_name(rotated_name) + + path.rename(rotated_path) + with rotated_path.open("rb") as src: + with gzip.open(f"{rotated_path}.gz", "wb") as dst: + dst.writelines(src) + rotated_path.unlink() + return True diff --git a/server/dataset/DatasetStats.py b/server/dataset/DatasetStats.py index db3d747..e4b8008 100644 --- a/server/dataset/DatasetStats.py +++ b/server/dataset/DatasetStats.py @@ -1,50 +1,16 @@ from collections import Counter, defaultdict -import argparse, glob, json, re +import argparse, json, re from datetime import datetime from pathlib import Path +from server.dataset.DatasetIO import DatasetIO + class DatasetStats: DAY_PATTERN = re.compile(r"(\d{4}-\d{2}-\d{2})") def __init__(self, input_files:list[str]): self.input_files = input_files - def _resolve_input_files(self): - resolved = [] - seen = set() - - for item in self.input_files: - path = Path(item) - if path.is_dir(): - for file_path in sorted(path.rglob("*.jsonl")): - key = str(file_path.resolve()) - if key in seen: - continue - seen.add(key) - resolved.append(file_path) - continue - - if any(ch in item for ch in "*?[]"): - for match in sorted(glob.glob(item)): - file_path = Path(match) - if not file_path.is_file(): - continue - key = str(file_path.resolve()) - if key in seen: - continue - seen.add(key) - resolved.append(file_path) - continue - - if path.is_file(): - key = str(path.resolve()) - if key in seen: - continue - seen.add(key) - resolved.append(path) - - return resolved - def _infer_day(self, file_path:Path): match = self.DAY_PATTERN.search(file_path.name) if match: @@ -83,7 +49,7 @@ class DatasetStats: return None def analyze(self): - files = self._resolve_input_files() + files = DatasetIO.resolve_input_files(self.input_files) totals = { "rows": 0, @@ -99,57 +65,52 @@ class DatasetStats: for file_path in files: day = self._infer_day(file_path) - with file_path.open("r", encoding="utf-8") as source: - for line in source: - if not line.strip(): - continue + for row in DatasetIO.iter_jsonl_rows(file_path): + game_id = row.get("game_id") + if not game_id: + continue - row = json.loads(line) - game_id = row.get("game_id") - if not game_id: - continue + turn = int(row.get("turn", 0)) + safe_options = self._extract_safe_options(row) + snake_type = row.get("snake_type", "unknown") + move = row.get("move", "unknown") - turn = int(row.get("turn", 0)) - safe_options = self._extract_safe_options(row) - snake_type = row.get("snake_type", "unknown") - move = row.get("move", "unknown") + game_type = row.get("game_type", {}) + if isinstance(game_type, dict): + game_type_name = game_type.get("name", "unknown") + else: + game_type_name = str(game_type) - game_type = row.get("game_type", {}) - if isinstance(game_type, dict): - game_type_name = game_type.get("name", "unknown") - else: - game_type_name = str(game_type) + totals["rows"] += 1 + totals["games"].add(game_id) + totals["snake_types"][snake_type] += 1 + totals["game_types"][game_type_name] += 1 + totals["moves"][move] += 1 + totals["days"][day] += 1 - totals["rows"] += 1 - totals["games"].add(game_id) - totals["snake_types"][snake_type] += 1 - totals["game_types"][game_type_name] += 1 - totals["moves"][move] += 1 - totals["days"][day] += 1 + if game_id not in games: + games[game_id] = { + "game_id": game_id, + "day": day, + "snake_type": snake_type, + "game_type": game_type_name, + "rows": 0, + "max_turn": -1, + "safe_options_sum": 0, + "safe_options_count": 0, + "pressure_turns": 0, + } - if game_id not in games: - games[game_id] = { - "game_id": game_id, - "day": day, - "snake_type": snake_type, - "game_type": game_type_name, - "rows": 0, - "max_turn": -1, - "safe_options_sum": 0, - "safe_options_count": 0, - "pressure_turns": 0, - } + game = games[game_id] + game["rows"] += 1 + game["max_turn"] = max(game["max_turn"], turn) + if isinstance(safe_options, int): + game["safe_options_sum"] += safe_options + game["safe_options_count"] += 1 + if safe_options <= 2: + game["pressure_turns"] += 1 - game = games[game_id] - game["rows"] += 1 - game["max_turn"] = max(game["max_turn"], turn) - if isinstance(safe_options, int): - game["safe_options_sum"] += safe_options - game["safe_options_count"] += 1 - if safe_options <= 2: - game["pressure_turns"] += 1 - - day_games[day].add(game_id) + day_games[day].add(game_id) game_summaries = [] for game in games.values(): @@ -225,7 +186,7 @@ if __name__ == "__main__": "--input", action="append", required=True, - help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.", + help="Input JSONL/JSONL.GZ file, directory, or glob pattern. Repeat for multiple inputs.", ) parser.add_argument( "--output", diff --git a/server/dataset/RLBootstrapDataset.py b/server/dataset/RLBootstrapDataset.py index 76a5e29..3b06136 100644 --- a/server/dataset/RLBootstrapDataset.py +++ b/server/dataset/RLBootstrapDataset.py @@ -1,8 +1,8 @@ -from datetime import UTC, datetime from pathlib import Path from typing import Any +import os -import gzip, json, os +from server.dataset.DatasetIO import DatasetIO class RLBootstrapDataset: def __init__(self): @@ -30,76 +30,15 @@ class RLBootstrapDataset: except ValueError: return default - @staticmethod - def count_jsonl_rows(path:Path) -> int: - count = 0 - - candidates = [path] - if path.suffix == ".gz": - candidates.append(path.with_suffix("")) - else: - candidates.append(Path(f"{path}.gz")) - - seen = set() - for candidate in candidates: - if candidate in seen: - continue - seen.add(candidate) - - if not candidate.exists() or not candidate.is_file(): - continue - - open_fn = gzip.open if candidate.suffix == ".gz" else open - try: - with open_fn(candidate, "rt", encoding="utf-8") as handle: - for line in handle: - if line.strip(): - count += 1 - except (OSError, UnicodeError): - continue - - return count - - @staticmethod - def rotate_and_gzip_if_size_reached(path:Path, max_bytes:int) -> bool: - if max_bytes <= 0: - return False - if not path.exists() or not path.is_file(): - return False - if path.stat().st_size < max_bytes: - return False - - timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S") - if path.suffix == ".jsonl": - rotated_name = f"{path.stem}.{timestamp}.jsonl" - else: - rotated_name = f"{path.name}.{timestamp}" - - rotated_path = path.with_name(rotated_name) - suffix = 1 - while rotated_path.exists(): - suffix += 1 - if path.suffix == ".jsonl": - rotated_name = f"{path.stem}.{timestamp}.{suffix}.jsonl" - else: - rotated_name = f"{path.name}.{timestamp}.{suffix}" - rotated_path = path.with_name(rotated_name) - - path.rename(rotated_path) - with rotated_path.open("rb") as src: - with gzip.open(f"{rotated_path}.gz", "wb") as dst: - dst.writelines(src) - rotated_path.unlink() - return True - def refresh_state(self): if not self.enabled: self.needs_more_data = False return - base_rows = self.count_jsonl_rows(self.base_dataset_path) + + base_rows = DatasetIO.count_jsonl_rows(self.base_dataset_path) self.needs_more_data = base_rows < self.min_base_rows - def record_sample(self, game_data:Any, move:str, safe_moves:dict[str, dict[str, int]], reason:str, scores:dict[str, float]|None=None,): + def record_sample(self, game_data:Any, move:str, safe_moves:dict[str, dict[str, int]], reason:str, scores:dict[str, float]|None=None): if not self.enabled or not self.needs_more_data: return @@ -117,8 +56,7 @@ class RLBootstrapDataset: if scores: row["scores"] = {k: round(v, 5) for k, v in scores.items()} - with self.output_path.open("a", encoding="utf-8") as handle: - handle.write(json.dumps(row, ensure_ascii=False) + "\n") - self.rotate_and_gzip_if_size_reached(self.output_path, self.max_bytes) + DatasetIO.append_jsonl_row(self.output_path, row) + DatasetIO.rotate_and_gzip_if_size_reached(self.output_path, self.max_bytes) except Exception: return diff --git a/server/dataset/__init__.py b/server/dataset/__init__.py index f34e69f..6cfa3e1 100644 --- a/server/dataset/__init__.py +++ b/server/dataset/__init__.py @@ -1,4 +1,5 @@ from .Dataset import Dataset +from .DatasetIO import DatasetIO from .DatasetExporter import DatasetExporter from .DatasetCurator import DatasetCurator from .DatasetStats import DatasetStats diff --git a/tests/test_DatasetExporter.py b/tests/test_DatasetExporter.py index 2f307de..6877467 100644 --- a/tests/test_DatasetExporter.py +++ b/tests/test_DatasetExporter.py @@ -1,6 +1,6 @@ -import json -import tempfile import unittest + +import tempfile, json, gzip from pathlib import Path from server.dataset.DatasetExporter import DatasetExporter @@ -43,5 +43,41 @@ class TestDatasetExporter(unittest.TestCase): self.assertEqual(first["move"], "up") self.assertTrue(first["is_good_move"]) + def test_export_jsonl_gz(self): + with tempfile.TemporaryDirectory() as tmp: + input_dir = Path(tmp) / "data" + output_file = Path(tmp) / "out" / "dataset.jsonl.gz" + game_file = input_dir / "game-1.json" + game_file.parent.mkdir(parents=True, exist_ok=True) + + game_payload = { + "dataset": { + "game": {"id": "g-1", "map": "standard", "type": {"name": "duel"}}, + "snake": {"type": "BestBattleSnake"}, + "samples": [ + { + "turn": 1, + "move": "up", + "is_good_move": True, + "game_board": {"width": 11, "height": 11}, + "history": {"data": []}, + } + ], + } + } + game_file.write_text(json.dumps(game_payload), encoding="utf-8") + + report = DatasetExporter(str(input_dir), str(output_file)).export_jsonl() + + self.assertEqual(report["games_scanned"], 1) + self.assertEqual(report["samples_exported"], 1) + self.assertTrue(output_file.exists()) + + with gzip.open(output_file, "rt", encoding="utf-8") as handle: + lines = handle.read().strip().splitlines() + self.assertEqual(len(lines), 1) + first = json.loads(lines[0]) + self.assertEqual(first["game_id"], "g-1") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_DatasetIO.py b/tests/test_DatasetIO.py new file mode 100644 index 0000000..79d50dc --- /dev/null +++ b/tests/test_DatasetIO.py @@ -0,0 +1,46 @@ +import unittest + +import tempfile, gzip, json +from pathlib import Path + +from server.dataset.DatasetIO import DatasetIO + +class TestDatasetIO(unittest.TestCase): + def test_resolve_input_files_includes_jsonl_and_gz(self): + with tempfile.TemporaryDirectory() as tmp: + base = Path(tmp) + (base / "a.jsonl").write_text('{"x":1}\n', encoding="utf-8") + with gzip.open(base / "b.jsonl.gz", "wt", encoding="utf-8") as handle: + handle.write('{"x":2}\n') + + paths = DatasetIO.resolve_input_files([str(base)]) + names = {path.name for path in paths} + self.assertIn("a.jsonl", names) + self.assertIn("b.jsonl.gz", names) + + def test_iter_jsonl_rows_reads_gz_file(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "rows.jsonl.gz" + with gzip.open(path, "wt", encoding="utf-8") as handle: + handle.write('{"turn":1}\n') + handle.write("\n") + handle.write('{"turn":2}\n') + + rows = list(DatasetIO.iter_jsonl_rows(path)) + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0]["turn"], 1) + self.assertEqual(rows[1]["turn"], 2) + + def test_append_jsonl_row_supports_gz_output(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "append.jsonl.gz" + DatasetIO.append_jsonl_row(path, {"move": "up"}) + DatasetIO.append_jsonl_row(path, {"move": "left"}) + + with gzip.open(path, "rt", encoding="utf-8") as handle: + lines = handle.read().strip().splitlines() + self.assertEqual(len(lines), 2) + self.assertEqual(json.loads(lines[0])["move"], "up") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_RLBootstrapDataset.py b/tests/test_RLBootstrapDataset.py index 32ab18c..dec9404 100644 --- a/tests/test_RLBootstrapDataset.py +++ b/tests/test_RLBootstrapDataset.py @@ -5,6 +5,7 @@ from pathlib import Path import tempfile, gzip from server.dataset.RLBootstrapDataset import RLBootstrapDataset +from server.dataset.DatasetIO import DatasetIO class TestRLBootstrapDataset(unittest.TestCase): def test_count_jsonl_rows_reads_gzip_dataset(self): @@ -15,17 +16,15 @@ class TestRLBootstrapDataset(unittest.TestCase): handle.write("\n") handle.write('{"turn":2}\n') - self.assertEqual(RLBootstrapDataset.count_jsonl_rows(dataset_path), 2) - self.assertEqual(RLBootstrapDataset.count_jsonl_rows(Path(tmp) / "base.jsonl"), 2) + self.assertEqual(DatasetIO.count_jsonl_rows(dataset_path), 2) + self.assertEqual(DatasetIO.count_jsonl_rows(Path(tmp) / "base.jsonl"), 2) def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self): with tempfile.TemporaryDirectory() as tmp: path = Path(tmp) / "rl_bootstrap.jsonl" path.write_text("x" * 200, encoding="utf-8") - rotated = RLBootstrapDataset.rotate_and_gzip_if_size_reached( - path, max_bytes=50 - ) + rotated = DatasetIO.rotate_and_gzip_if_size_reached(path, max_bytes=50) self.assertTrue(rotated) self.assertFalse(path.exists())