47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
import unittest
|
|
|
|
import tempfile, gzip, json
|
|
from pathlib import Path
|
|
|
|
from server.dataset.DatasetIO import DatasetIO
|
|
|
|
class TestDatasetIO(unittest.TestCase):
    """Tests for DatasetIO's handling of plain and gzip-compressed JSONL files."""

    def test_resolve_input_files_includes_jsonl_and_gz(self):
        """A directory argument resolves to both its ``.jsonl`` and ``.jsonl.gz`` files."""
        with tempfile.TemporaryDirectory() as tmp:
            base = Path(tmp)
            (base / "a.jsonl").write_text('{"x":1}\n', encoding="utf-8")
            with gzip.open(base / "b.jsonl.gz", "wt", encoding="utf-8") as handle:
                handle.write('{"x":2}\n')

            paths = DatasetIO.resolve_input_files([str(base)])
            # Compare by basename because the temp directory differs per run.
            # NOTE(review): assumes the returned items are Path-like (expose
            # ``.name``) — confirm against DatasetIO.resolve_input_files.
            names = {path.name for path in paths}
            self.assertIn("a.jsonl", names)
            self.assertIn("b.jsonl.gz", names)

    def test_iter_jsonl_rows_reads_gz_file(self):
        """Rows from a ``.jsonl.gz`` file are yielded in order; blank lines are skipped."""
        with tempfile.TemporaryDirectory() as tmp:
            path = Path(tmp) / "rows.jsonl.gz"
            with gzip.open(path, "wt", encoding="utf-8") as handle:
                handle.write('{"turn":1}\n')
                # Deliberate empty line: the reader must not yield a row for it.
                handle.write("\n")
                handle.write('{"turn":2}\n')

            rows = list(DatasetIO.iter_jsonl_rows(path))
            self.assertEqual(len(rows), 2)
            self.assertEqual(rows[0]["turn"], 1)
            self.assertEqual(rows[1]["turn"], 2)

    def test_append_jsonl_row_supports_gz_output(self):
        """Repeated appends to a ``.jsonl.gz`` path yield one JSON object per line, in call order."""
        with tempfile.TemporaryDirectory() as tmp:
            path = Path(tmp) / "append.jsonl.gz"
            DatasetIO.append_jsonl_row(path, {"move": "up"})
            DatasetIO.append_jsonl_row(path, {"move": "left"})

            # Read back with the stdlib gzip module, independent of DatasetIO,
            # so the on-disk format itself is what gets checked.
            with gzip.open(path, "rt", encoding="utf-8") as handle:
                lines = handle.read().strip().splitlines()
            self.assertEqual(len(lines), 2)
            # Check every appended row, not just the first, so a lost,
            # corrupted or reordered second append is caught.
            moves = [json.loads(line)["move"] for line in lines]
            self.assertEqual(moves, ["up", "left"])
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script; discovers the
    # TestCase classes defined in this module and runs them.
    unittest.main(module="__main__")
|