rework dataset function and class structure
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import tempfile, json, gzip
|
||||
from pathlib import Path
|
||||
|
||||
from server.dataset.DatasetExporter import DatasetExporter
|
||||
@@ -43,5 +43,41 @@ class TestDatasetExporter(unittest.TestCase):
|
||||
self.assertEqual(first["move"], "up")
|
||||
self.assertTrue(first["is_good_move"])
|
||||
|
||||
def test_export_jsonl_gz(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
input_dir = Path(tmp) / "data"
|
||||
output_file = Path(tmp) / "out" / "dataset.jsonl.gz"
|
||||
game_file = input_dir / "game-1.json"
|
||||
game_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
game_payload = {
|
||||
"dataset": {
|
||||
"game": {"id": "g-1", "map": "standard", "type": {"name": "duel"}},
|
||||
"snake": {"type": "BestBattleSnake"},
|
||||
"samples": [
|
||||
{
|
||||
"turn": 1,
|
||||
"move": "up",
|
||||
"is_good_move": True,
|
||||
"game_board": {"width": 11, "height": 11},
|
||||
"history": {"data": []},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
game_file.write_text(json.dumps(game_payload), encoding="utf-8")
|
||||
|
||||
report = DatasetExporter(str(input_dir), str(output_file)).export_jsonl()
|
||||
|
||||
self.assertEqual(report["games_scanned"], 1)
|
||||
self.assertEqual(report["samples_exported"], 1)
|
||||
self.assertTrue(output_file.exists())
|
||||
|
||||
with gzip.open(output_file, "rt", encoding="utf-8") as handle:
|
||||
lines = handle.read().strip().splitlines()
|
||||
self.assertEqual(len(lines), 1)
|
||||
first = json.loads(lines[0])
|
||||
self.assertEqual(first["game_id"], "g-1")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import unittest
|
||||
|
||||
import tempfile, gzip, json
|
||||
from pathlib import Path
|
||||
|
||||
from server.dataset.DatasetIO import DatasetIO
|
||||
|
||||
class TestDatasetIO(unittest.TestCase):
|
||||
def test_resolve_input_files_includes_jsonl_and_gz(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
base = Path(tmp)
|
||||
(base / "a.jsonl").write_text('{"x":1}\n', encoding="utf-8")
|
||||
with gzip.open(base / "b.jsonl.gz", "wt", encoding="utf-8") as handle:
|
||||
handle.write('{"x":2}\n')
|
||||
|
||||
paths = DatasetIO.resolve_input_files([str(base)])
|
||||
names = {path.name for path in paths}
|
||||
self.assertIn("a.jsonl", names)
|
||||
self.assertIn("b.jsonl.gz", names)
|
||||
|
||||
def test_iter_jsonl_rows_reads_gz_file(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp) / "rows.jsonl.gz"
|
||||
with gzip.open(path, "wt", encoding="utf-8") as handle:
|
||||
handle.write('{"turn":1}\n')
|
||||
handle.write("\n")
|
||||
handle.write('{"turn":2}\n')
|
||||
|
||||
rows = list(DatasetIO.iter_jsonl_rows(path))
|
||||
self.assertEqual(len(rows), 2)
|
||||
self.assertEqual(rows[0]["turn"], 1)
|
||||
self.assertEqual(rows[1]["turn"], 2)
|
||||
|
||||
def test_append_jsonl_row_supports_gz_output(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp) / "append.jsonl.gz"
|
||||
DatasetIO.append_jsonl_row(path, {"move": "up"})
|
||||
DatasetIO.append_jsonl_row(path, {"move": "left"})
|
||||
|
||||
with gzip.open(path, "rt", encoding="utf-8") as handle:
|
||||
lines = handle.read().strip().splitlines()
|
||||
self.assertEqual(len(lines), 2)
|
||||
self.assertEqual(json.loads(lines[0])["move"], "up")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
import tempfile, gzip
|
||||
|
||||
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
|
||||
from server.dataset.DatasetIO import DatasetIO
|
||||
|
||||
class TestRLBootstrapDataset(unittest.TestCase):
|
||||
def test_count_jsonl_rows_reads_gzip_dataset(self):
|
||||
@@ -15,17 +16,15 @@ class TestRLBootstrapDataset(unittest.TestCase):
|
||||
handle.write("\n")
|
||||
handle.write('{"turn":2}\n')
|
||||
|
||||
self.assertEqual(RLBootstrapDataset.count_jsonl_rows(dataset_path), 2)
|
||||
self.assertEqual(RLBootstrapDataset.count_jsonl_rows(Path(tmp) / "base.jsonl"), 2)
|
||||
self.assertEqual(DatasetIO.count_jsonl_rows(dataset_path), 2)
|
||||
self.assertEqual(DatasetIO.count_jsonl_rows(Path(tmp) / "base.jsonl"), 2)
|
||||
|
||||
def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp) / "rl_bootstrap.jsonl"
|
||||
path.write_text("x" * 200, encoding="utf-8")
|
||||
|
||||
rotated = RLBootstrapDataset.rotate_and_gzip_if_size_reached(
|
||||
path, max_bytes=50
|
||||
)
|
||||
rotated = DatasetIO.rotate_and_gzip_if_size_reached(path, max_bytes=50)
|
||||
|
||||
self.assertTrue(rotated)
|
||||
self.assertFalse(path.exists())
|
||||
|
||||
Reference in New Issue
Block a user