add more flags and allow to read more input file then one

This commit is contained in:
2026-04-03 15:44:19 +02:00
parent 37de34cc5e
commit 49f2e0b008
+142 -6
View File
@@ -1,16 +1,70 @@
import argparse import argparse
import glob
import hashlib import hashlib
import json import json
import shutil
from pathlib import Path from pathlib import Path
class DatasetCurator: class DatasetCurator:
def __init__(self, input_file: str, output_file: str, min_turn: int = 6, late_turn: int = 20, max_safe_options: int = 2, min_score: int = 3,): def __init__(
self.input_file = Path(input_file) self,
input_files: list[str],
output_file: str,
min_turn: int = 6,
late_turn: int = 20,
max_safe_options: int = 2,
min_score: int = 3,
append: bool = False,
archive_input: bool = False,
archive_dir: str | None = None,
):
self.input_files = input_files
self.output_file = Path(output_file) self.output_file = Path(output_file)
self.min_turn = min_turn self.min_turn = min_turn
self.late_turn = late_turn self.late_turn = late_turn
self.max_safe_options = max_safe_options self.max_safe_options = max_safe_options
self.min_score = min_score self.min_score = min_score
self.append = append
self.archive_input = archive_input
self.archive_dir = (
Path(archive_dir) if archive_dir else self.output_file.parent / "archive"
)
def _resolve_input_files(self):
resolved = []
seen = set()
for item in self.input_files:
path = Path(item)
if path.is_dir():
for file_path in sorted(path.rglob("*.jsonl")):
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if any(ch in item for ch in "*?[]"):
for match in sorted(glob.glob(item)):
file_path = Path(match)
if not file_path.is_file():
continue
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if path.is_file():
key = str(path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(path)
return resolved
def _safe_options_count(self, row: dict): def _safe_options_count(self, row: dict):
history = row.get("history", {}) history = row.get("history", {})
@@ -65,6 +119,7 @@ class DatasetCurator:
def curate(self): def curate(self):
self.output_file.parent.mkdir(parents=True, exist_ok=True) self.output_file.parent.mkdir(parents=True, exist_ok=True)
input_paths = self._resolve_input_files()
total = 0 total = 0
kept = 0 kept = 0
@@ -73,8 +128,19 @@ class DatasetCurator:
skipped_duplicate = 0 skipped_duplicate = 0
seen_states = set() seen_states = set()
with self.input_file.open("r", encoding="utf-8") as src: if self.append and self.output_file.exists():
with self.output_file.open("w", encoding="utf-8") as dst: with self.output_file.open("r", encoding="utf-8") as existing:
for line in existing:
if not line.strip():
continue
row = json.loads(line)
state_key = self._state_hash(row)
seen_states.add((state_key, row.get("move")))
mode = "a" if self.append else "w"
with self.output_file.open(mode, encoding="utf-8") as dst:
for input_path in input_paths:
with input_path.open("r", encoding="utf-8") as src:
for line in src: for line in src:
if not line.strip(): if not line.strip():
continue continue
@@ -114,31 +180,101 @@ class DatasetCurator:
dst.write(json.dumps(compact_row, ensure_ascii=False) + "\n") dst.write(json.dumps(compact_row, ensure_ascii=False) + "\n")
kept += 1 kept += 1
archived_files = []
if self.archive_input:
archived_files = self._archive_processed_files(input_paths)
return { return {
"input_files": [str(path) for path in input_paths],
"total_rows": total, "total_rows": total,
"kept_rows": kept, "kept_rows": kept,
"skipped_turn": skipped_turn, "skipped_turn": skipped_turn,
"skipped_quality": skipped_quality, "skipped_quality": skipped_quality,
"skipped_duplicate": skipped_duplicate, "skipped_duplicate": skipped_duplicate,
"append_mode": self.append,
"archive_input": self.archive_input,
"archived_files": archived_files,
"output_file": str(self.output_file), "output_file": str(self.output_file),
} }
def _archive_processed_files(self, input_paths: list[Path]):
self.archive_dir.mkdir(parents=True, exist_ok=True)
archived = []
output_resolved = (
self.output_file.resolve()
if self.output_file.exists()
else self.output_file
)
archive_resolved = self.archive_dir.resolve()
for source_path in input_paths:
if not source_path.exists():
continue
source_resolved = source_path.resolve()
if source_resolved == output_resolved:
continue
if source_resolved.parent == archive_resolved:
continue
destination = self.archive_dir / source_path.name
if destination.exists():
stem = destination.stem
suffix = destination.suffix
index = 1
while True:
candidate = self.archive_dir / f"{stem}.{index}{suffix}"
if not candidate.exists():
destination = candidate
break
index += 1
shutil.move(str(source_path), str(destination))
archived.append(str(destination))
return archived
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Create curated best-moves dataset") parser = argparse.ArgumentParser(description="Create curated best-moves dataset")
parser.add_argument("--input", required=True, help="Input JSONL file") parser.add_argument(
"--input",
action="append",
required=True,
help="Input JSONL file, directory, or glob pattern. Repeat for multiple inputs.",
)
parser.add_argument("--output", required=True, help="Output JSONL file") parser.add_argument("--output", required=True, help="Output JSONL file")
parser.add_argument("--min-turn", type=int, default=6) parser.add_argument("--min-turn", type=int, default=6)
parser.add_argument("--late-turn", type=int, default=20) parser.add_argument("--late-turn", type=int, default=20)
parser.add_argument("--max-safe-options", type=int, default=2) parser.add_argument("--max-safe-options", type=int, default=2)
parser.add_argument("--min-score", type=int, default=3) parser.add_argument("--min-score", type=int, default=3)
parser.add_argument(
"--append",
action="store_true",
help="Append to existing output and dedupe against existing rows",
)
parser.add_argument(
"--archive-input",
action="store_true",
help="Move processed input files to archive directory after successful curation",
)
parser.add_argument(
"--archive-dir",
default=None,
help="Archive directory for processed input files (default: <output-dir>/archive)",
)
args = parser.parse_args() args = parser.parse_args()
report = DatasetCurator( report = DatasetCurator(
input_file=args.input, input_files=args.input,
output_file=args.output, output_file=args.output,
min_turn=args.min_turn, min_turn=args.min_turn,
late_turn=args.late_turn, late_turn=args.late_turn,
max_safe_options=args.max_safe_options, max_safe_options=args.max_safe_options,
min_score=args.min_score, min_score=args.min_score,
append=args.append,
archive_input=args.archive_input,
archive_dir=args.archive_dir,
).curate() ).curate()
print(json.dumps(report, indent=2)) print(json.dumps(report, indent=2))