rework dataset function and class structure

This commit is contained in:
2026-04-05 02:21:15 +02:00
parent 066a93f755
commit 332e86e3cc
9 changed files with 318 additions and 248 deletions
+9 -54
View File
@@ -1,6 +1,8 @@
import argparse, hashlib, shutil, glob, json import argparse, hashlib, shutil, json
from pathlib import Path from pathlib import Path
from server.dataset.DatasetIO import DatasetIO
class DatasetCurator: class DatasetCurator:
def __init__(self, input_files:list[str], output_file:str, min_turn:int=6, late_turn:int=20, max_safe_options:int=2, min_score:int=3, append:bool=False, archive_input:bool=False, archive_dir:str|None=None): def __init__(self, input_files:list[str], output_file:str, min_turn:int=6, late_turn:int=20, max_safe_options:int=2, min_score:int=3, append:bool=False, archive_input:bool=False, archive_dir:str|None=None):
self.input_files = input_files self.input_files = input_files
@@ -11,45 +13,7 @@ class DatasetCurator:
self.min_score = min_score self.min_score = min_score
self.append = append self.append = append
self.archive_input = archive_input self.archive_input = archive_input
self.archive_dir = ( self.archive_dir = Path(archive_dir) if archive_dir else self.output_file.parent / "archive"
Path(archive_dir) if archive_dir else self.output_file.parent / "archive"
)
def _resolve_input_files(self):
resolved = []
seen = set()
for item in self.input_files:
path = Path(item)
if path.is_dir():
for file_path in sorted(path.rglob("*.jsonl")):
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if any(ch in item for ch in "*?[]"):
for match in sorted(glob.glob(item)):
file_path = Path(match)
if not file_path.is_file():
continue
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if path.is_file():
key = str(path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(path)
return resolved
def _safe_options_count(self, row:dict): def _safe_options_count(self, row:dict):
history = row.get("history", {}) history = row.get("history", {})
@@ -98,7 +62,7 @@ class DatasetCurator:
def curate(self): def curate(self):
self.output_file.parent.mkdir(parents=True, exist_ok=True) self.output_file.parent.mkdir(parents=True, exist_ok=True)
input_paths = self._resolve_input_files() input_paths = DatasetIO.resolve_input_files(self.input_files)
total = 0 total = 0
kept = 0 kept = 0
@@ -108,24 +72,15 @@ class DatasetCurator:
seen_states = set() seen_states = set()
if self.append and self.output_file.exists(): if self.append and self.output_file.exists():
with self.output_file.open("r", encoding="utf-8") as existing: for row in DatasetIO.iter_jsonl_rows(self.output_file):
for line in existing:
if not line.strip():
continue
row = json.loads(line)
state_key = self._state_hash(row) state_key = self._state_hash(row)
seen_states.add((state_key, row.get("move"))) seen_states.add((state_key, row.get("move")))
mode = "a" if self.append else "w" mode = "a" if self.append else "w"
with self.output_file.open(mode, encoding="utf-8") as dst: with DatasetIO.open_text(self.output_file, mode) as dst:
for input_path in input_paths: for input_path in input_paths:
with input_path.open("r", encoding="utf-8") as src: for row in DatasetIO.iter_jsonl_rows(input_path):
for line in src:
if not line.strip():
continue
total += 1 total += 1
row = json.loads(line)
if not row.get("is_good_move", False): if not row.get("is_good_move", False):
skipped_quality += 1 skipped_quality += 1
@@ -220,7 +175,7 @@ if __name__ == "__main__":
"--input", "--input",
action="append", action="append",
required=True, required=True,
help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.", help="Input JSONL/JSONL.GZ file, directory, or glob pattern. Repeat for multiple inputs.",
) )
parser.add_argument("--output", required=True, help="Output JSONL file") parser.add_argument("--output", required=True, help="Output JSONL file")
parser.add_argument("--min-turn", type=int, default=6) parser.add_argument("--min-turn", type=int, default=6)
+4 -4
View File
@@ -1,15 +1,15 @@
from pathlib import Path from pathlib import Path
import argparse, json import argparse, json
from server.dataset.DatasetIO import DatasetIO
class DatasetExporter: class DatasetExporter:
def __init__(self, input_dir:str, output_file:str): def __init__(self, input_dir:str, output_file:str):
self.input_dir = Path(input_dir) self.input_dir = Path(input_dir)
self.output_file = Path(output_file) self.output_file = Path(output_file)
def _iter_game_files(self): def _iter_game_files(self):
if not self.input_dir.exists(): return DatasetIO.list_directory_files(self.input_dir, directory_pattern="*.json")
return []
return sorted(self.input_dir.rglob("*.json"))
def _extract_samples(self, payload:dict, source_file:Path): def _extract_samples(self, payload:dict, source_file:Path):
dataset = payload.get("dataset", {}) dataset = payload.get("dataset", {})
@@ -37,7 +37,7 @@ class DatasetExporter:
self.output_file.parent.mkdir(parents=True, exist_ok=True) self.output_file.parent.mkdir(parents=True, exist_ok=True)
sample_count = 0 sample_count = 0
with self.output_file.open("w", encoding="utf-8") as output: with DatasetIO.open_text(self.output_file, "w") as output:
for game_file in game_files: for game_file in game_files:
with game_file.open("r", encoding="utf-8") as source: with game_file.open("r", encoding="utf-8") as source:
payload = json.load(source) payload = json.load(source)
+134
View File
@@ -0,0 +1,134 @@
from typing import Any, TextIO, cast
from datetime import UTC, datetime
from pathlib import Path
import gzip, glob, json
class DatasetIO:
@staticmethod
def resolve_input_files(input_files: list[str], directory_pattern: str | tuple[str, ...] = ("*.jsonl", "*.jsonl.gz")) -> list[Path]:
patterns = (
(directory_pattern,)
if isinstance(directory_pattern, str)
else directory_pattern
)
resolved = []
seen = set()
for item in input_files:
path = Path(item)
if path.is_dir():
for pattern in patterns:
for file_path in sorted(path.rglob(pattern)):
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if any(ch in item for ch in "*?[]"):
for match in sorted(glob.glob(item)):
file_path = Path(match)
if not file_path.is_file():
continue
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if path.is_file():
key = str(path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(path)
return resolved
@staticmethod
def list_directory_files(input_dir:Path, directory_pattern:str) -> list[Path]:
if not input_dir.exists():
return []
return sorted(input_dir.rglob(directory_pattern))
@staticmethod
def open_text(path:Path, mode:str="r"):
text_mode = mode if "t" in mode else f"{mode}t"
open_fn = gzip.open if path.suffix == ".gz" else open
return cast(TextIO, open_fn(path, text_mode, encoding="utf-8"))
@staticmethod
def iter_jsonl_rows(path:Path):
with DatasetIO.open_text(path, "r") as handle:
for line in handle:
if line.strip():
yield json.loads(line)
@staticmethod
def append_jsonl_row(path:Path, row:dict[str, Any]):
with DatasetIO.open_text(path, "a") as raw_handle:
handle = cast(TextIO, raw_handle)
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
@staticmethod
def count_jsonl_rows(path:Path) -> int:
count = 0
candidates = [path]
if path.suffix == ".gz":
candidates.append(path.with_suffix(""))
else:
candidates.append(Path(f"{path}.gz"))
seen = set()
for candidate in candidates:
if candidate in seen:
continue
seen.add(candidate)
if not candidate.exists() or not candidate.is_file():
continue
try:
with DatasetIO.open_text(candidate, "r") as handle:
for line in handle:
if line.strip():
count += 1
except (OSError, UnicodeError):
continue
return count
@staticmethod
def rotate_and_gzip_if_size_reached(path:Path, max_bytes:int) -> bool:
if max_bytes <= 0:
return False
if not path.exists() or not path.is_file():
return False
if path.stat().st_size < max_bytes:
return False
timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
if path.suffix == ".jsonl":
rotated_name = f"{path.stem}.{timestamp}.jsonl"
else:
rotated_name = f"{path.name}.{timestamp}"
rotated_path = path.with_name(rotated_name)
suffix = 1
while rotated_path.exists():
suffix += 1
if path.suffix == ".jsonl":
rotated_name = f"{path.stem}.{timestamp}.{suffix}.jsonl"
else:
rotated_name = f"{path.name}.{timestamp}.{suffix}"
rotated_path = path.with_name(rotated_name)
path.rename(rotated_path)
with rotated_path.open("rb") as src:
with gzip.open(f"{rotated_path}.gz", "wb") as dst:
dst.writelines(src)
rotated_path.unlink()
return True
+6 -45
View File
@@ -1,50 +1,16 @@
from collections import Counter, defaultdict from collections import Counter, defaultdict
import argparse, glob, json, re import argparse, json, re
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from server.dataset.DatasetIO import DatasetIO
class DatasetStats: class DatasetStats:
DAY_PATTERN = re.compile(r"(\d{4}-\d{2}-\d{2})") DAY_PATTERN = re.compile(r"(\d{4}-\d{2}-\d{2})")
def __init__(self, input_files:list[str]): def __init__(self, input_files:list[str]):
self.input_files = input_files self.input_files = input_files
def _resolve_input_files(self):
resolved = []
seen = set()
for item in self.input_files:
path = Path(item)
if path.is_dir():
for file_path in sorted(path.rglob("*.jsonl")):
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if any(ch in item for ch in "*?[]"):
for match in sorted(glob.glob(item)):
file_path = Path(match)
if not file_path.is_file():
continue
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if path.is_file():
key = str(path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(path)
return resolved
def _infer_day(self, file_path:Path): def _infer_day(self, file_path:Path):
match = self.DAY_PATTERN.search(file_path.name) match = self.DAY_PATTERN.search(file_path.name)
if match: if match:
@@ -83,7 +49,7 @@ class DatasetStats:
return None return None
def analyze(self): def analyze(self):
files = self._resolve_input_files() files = DatasetIO.resolve_input_files(self.input_files)
totals = { totals = {
"rows": 0, "rows": 0,
@@ -99,12 +65,7 @@ class DatasetStats:
for file_path in files: for file_path in files:
day = self._infer_day(file_path) day = self._infer_day(file_path)
with file_path.open("r", encoding="utf-8") as source: for row in DatasetIO.iter_jsonl_rows(file_path):
for line in source:
if not line.strip():
continue
row = json.loads(line)
game_id = row.get("game_id") game_id = row.get("game_id")
if not game_id: if not game_id:
continue continue
@@ -225,7 +186,7 @@ if __name__ == "__main__":
"--input", "--input",
action="append", action="append",
required=True, required=True,
help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.", help="Input JSONL/JSONL.GZ file, directory, or glob pattern. Repeat for multiple inputs.",
) )
parser.add_argument( parser.add_argument(
"--output", "--output",
+7 -69
View File
@@ -1,8 +1,8 @@
from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import os
import gzip, json, os from server.dataset.DatasetIO import DatasetIO
class RLBootstrapDataset: class RLBootstrapDataset:
def __init__(self): def __init__(self):
@@ -30,76 +30,15 @@ class RLBootstrapDataset:
except ValueError: except ValueError:
return default return default
@staticmethod
def count_jsonl_rows(path:Path) -> int:
count = 0
candidates = [path]
if path.suffix == ".gz":
candidates.append(path.with_suffix(""))
else:
candidates.append(Path(f"{path}.gz"))
seen = set()
for candidate in candidates:
if candidate in seen:
continue
seen.add(candidate)
if not candidate.exists() or not candidate.is_file():
continue
open_fn = gzip.open if candidate.suffix == ".gz" else open
try:
with open_fn(candidate, "rt", encoding="utf-8") as handle:
for line in handle:
if line.strip():
count += 1
except (OSError, UnicodeError):
continue
return count
@staticmethod
def rotate_and_gzip_if_size_reached(path:Path, max_bytes:int) -> bool:
if max_bytes <= 0:
return False
if not path.exists() or not path.is_file():
return False
if path.stat().st_size < max_bytes:
return False
timestamp = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
if path.suffix == ".jsonl":
rotated_name = f"{path.stem}.{timestamp}.jsonl"
else:
rotated_name = f"{path.name}.{timestamp}"
rotated_path = path.with_name(rotated_name)
suffix = 1
while rotated_path.exists():
suffix += 1
if path.suffix == ".jsonl":
rotated_name = f"{path.stem}.{timestamp}.{suffix}.jsonl"
else:
rotated_name = f"{path.name}.{timestamp}.{suffix}"
rotated_path = path.with_name(rotated_name)
path.rename(rotated_path)
with rotated_path.open("rb") as src:
with gzip.open(f"{rotated_path}.gz", "wb") as dst:
dst.writelines(src)
rotated_path.unlink()
return True
def refresh_state(self): def refresh_state(self):
if not self.enabled: if not self.enabled:
self.needs_more_data = False self.needs_more_data = False
return return
base_rows = self.count_jsonl_rows(self.base_dataset_path)
base_rows = DatasetIO.count_jsonl_rows(self.base_dataset_path)
self.needs_more_data = base_rows < self.min_base_rows self.needs_more_data = base_rows < self.min_base_rows
def record_sample(self, game_data:Any, move:str, safe_moves:dict[str, dict[str, int]], reason:str, scores:dict[str, float]|None=None,): def record_sample(self, game_data:Any, move:str, safe_moves:dict[str, dict[str, int]], reason:str, scores:dict[str, float]|None=None):
if not self.enabled or not self.needs_more_data: if not self.enabled or not self.needs_more_data:
return return
@@ -117,8 +56,7 @@ class RLBootstrapDataset:
if scores: if scores:
row["scores"] = {k: round(v, 5) for k, v in scores.items()} row["scores"] = {k: round(v, 5) for k, v in scores.items()}
with self.output_path.open("a", encoding="utf-8") as handle: DatasetIO.append_jsonl_row(self.output_path, row)
handle.write(json.dumps(row, ensure_ascii=False) + "\n") DatasetIO.rotate_and_gzip_if_size_reached(self.output_path, self.max_bytes)
self.rotate_and_gzip_if_size_reached(self.output_path, self.max_bytes)
except Exception: except Exception:
return return
+1
View File
@@ -1,4 +1,5 @@
from .Dataset import Dataset from .Dataset import Dataset
from .DatasetIO import DatasetIO
from .DatasetExporter import DatasetExporter from .DatasetExporter import DatasetExporter
from .DatasetCurator import DatasetCurator from .DatasetCurator import DatasetCurator
from .DatasetStats import DatasetStats from .DatasetStats import DatasetStats
+38 -2
View File
@@ -1,6 +1,6 @@
import json
import tempfile
import unittest import unittest
import tempfile, json, gzip
from pathlib import Path from pathlib import Path
from server.dataset.DatasetExporter import DatasetExporter from server.dataset.DatasetExporter import DatasetExporter
@@ -43,5 +43,41 @@ class TestDatasetExporter(unittest.TestCase):
self.assertEqual(first["move"], "up") self.assertEqual(first["move"], "up")
self.assertTrue(first["is_good_move"]) self.assertTrue(first["is_good_move"])
def test_export_jsonl_gz(self):
    """Export to a ``.jsonl.gz`` target and read the result back through gzip."""
    sample = {
        "turn": 1,
        "move": "up",
        "is_good_move": True,
        "game_board": {"width": 11, "height": 11},
        "history": {"data": []},
    }
    game_payload = {
        "dataset": {
            "game": {"id": "g-1", "map": "standard", "type": {"name": "duel"}},
            "snake": {"type": "BestBattleSnake"},
            "samples": [sample],
        }
    }

    with tempfile.TemporaryDirectory() as workdir:
        root = Path(workdir)
        input_dir = root / "data"
        output_file = root / "out" / "dataset.jsonl.gz"

        # Lay down a single recorded game for the exporter to pick up.
        input_dir.mkdir(parents=True, exist_ok=True)
        (input_dir / "game-1.json").write_text(json.dumps(game_payload), encoding="utf-8")

        exporter = DatasetExporter(str(input_dir), str(output_file))
        report = exporter.export_jsonl()

        self.assertEqual(report["games_scanned"], 1)
        self.assertEqual(report["samples_exported"], 1)
        self.assertTrue(output_file.exists())

        # The output must be a real gzip stream holding exactly one JSON line.
        with gzip.open(output_file, "rt", encoding="utf-8") as packed:
            exported_lines = packed.read().strip().splitlines()

        self.assertEqual(len(exported_lines), 1)
        exported_row = json.loads(exported_lines[0])
        self.assertEqual(exported_row["game_id"], "g-1")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
+46
View File
@@ -0,0 +1,46 @@
import unittest
import tempfile, gzip, json
from pathlib import Path
from server.dataset.DatasetIO import DatasetIO
class TestDatasetIO(unittest.TestCase):
    """Checks DatasetIO against plain and gzip-compressed JSONL files."""

    def test_resolve_input_files_includes_jsonl_and_gz(self):
        """A directory input yields both the plain and the gzip dataset file."""
        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            plain_file = root / "a.jsonl"
            plain_file.write_text('{"x":1}\n', encoding="utf-8")
            with gzip.open(root / "b.jsonl.gz", "wt", encoding="utf-8") as packed:
                packed.write('{"x":2}\n')

            discovered = DatasetIO.resolve_input_files([str(root)])
            found_names = {entry.name for entry in discovered}

            self.assertIn("a.jsonl", found_names)
            self.assertIn("b.jsonl.gz", found_names)

    def test_iter_jsonl_rows_reads_gz_file(self):
        """Rows are decoded from a gzip file and the blank line is skipped."""
        with tempfile.TemporaryDirectory() as workdir:
            archive = Path(workdir) / "rows.jsonl.gz"
            with gzip.open(archive, "wt", encoding="utf-8") as packed:
                packed.writelines(['{"turn":1}\n', "\n", '{"turn":2}\n'])

            decoded = list(DatasetIO.iter_jsonl_rows(archive))

            self.assertEqual(len(decoded), 2)
            first_row, second_row = decoded
            self.assertEqual(first_row["turn"], 1)
            self.assertEqual(second_row["turn"], 2)

    def test_append_jsonl_row_supports_gz_output(self):
        """Two appends to a ``.gz`` target produce two readable JSON lines."""
        with tempfile.TemporaryDirectory() as workdir:
            archive = Path(workdir) / "append.jsonl.gz"
            for payload in ({"move": "up"}, {"move": "left"}):
                DatasetIO.append_jsonl_row(archive, payload)

            with gzip.open(archive, "rt", encoding="utf-8") as packed:
                written = packed.read().strip().splitlines()

            self.assertEqual(len(written), 2)
            self.assertEqual(json.loads(written[0])["move"], "up")
if __name__ == "__main__":
unittest.main()
+4 -5
View File
@@ -5,6 +5,7 @@ from pathlib import Path
import tempfile, gzip import tempfile, gzip
from server.dataset.RLBootstrapDataset import RLBootstrapDataset from server.dataset.RLBootstrapDataset import RLBootstrapDataset
from server.dataset.DatasetIO import DatasetIO
class TestRLBootstrapDataset(unittest.TestCase): class TestRLBootstrapDataset(unittest.TestCase):
def test_count_jsonl_rows_reads_gzip_dataset(self): def test_count_jsonl_rows_reads_gzip_dataset(self):
@@ -15,17 +16,15 @@ class TestRLBootstrapDataset(unittest.TestCase):
handle.write("\n") handle.write("\n")
handle.write('{"turn":2}\n') handle.write('{"turn":2}\n')
self.assertEqual(RLBootstrapDataset.count_jsonl_rows(dataset_path), 2) self.assertEqual(DatasetIO.count_jsonl_rows(dataset_path), 2)
self.assertEqual(RLBootstrapDataset.count_jsonl_rows(Path(tmp) / "base.jsonl"), 2) self.assertEqual(DatasetIO.count_jsonl_rows(Path(tmp) / "base.jsonl"), 2)
def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self): def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self):
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = Path(tmp) / "rl_bootstrap.jsonl" path = Path(tmp) / "rl_bootstrap.jsonl"
path.write_text("x" * 200, encoding="utf-8") path.write_text("x" * 200, encoding="utf-8")
rotated = RLBootstrapDataset.rotate_and_gzip_if_size_reached( rotated = DatasetIO.rotate_and_gzip_if_size_reached(path, max_bytes=50)
path, max_bytes=50
)
self.assertTrue(rotated) self.assertTrue(rotated)
self.assertFalse(path.exists()) self.assertFalse(path.exists())