101 lines
3.0 KiB
Python
101 lines
3.0 KiB
Python
import unittest
|
|
from unittest.mock import patch
|
|
|
|
from pathlib import Path
|
|
import tempfile, gzip
|
|
|
|
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
|
|
from server.dataset.DatasetIO import DatasetIO
|
|
|
|
class TestRLBootstrapDataset(unittest.TestCase):
    """Tests for gzip-aware ``DatasetIO`` helpers and ``RLBootstrapDataset``."""

    @staticmethod
    def _write_gzip_text(target, chunks):
        """Write every string in *chunks* to *target* as UTF-8 gzip text."""
        with gzip.open(target, "wt", encoding="utf-8") as stream:
            for chunk in chunks:
                stream.write(chunk)

    @staticmethod
    def _rotated_archives(directory):
        """Return the paths in *directory* matching ``rl_bootstrap.*.jsonl.gz``."""
        return list(Path(directory).glob("rl_bootstrap.*.jsonl.gz"))

    def test_count_jsonl_rows_reads_gzip_dataset(self):
        """A gzip file holding two JSON rows and one blank line counts as 2."""
        with tempfile.TemporaryDirectory() as tmp:
            workdir = Path(tmp)
            compressed = workdir / "base.jsonl.gz"
            self._write_gzip_text(
                compressed,
                ['{"turn":1}\n', "\n", '{"turn":2}\n'],
            )

            self.assertEqual(DatasetIO.count_jsonl_rows(compressed), 2)
            # Only the ``.gz`` file is created above; the plain ``.jsonl``
            # path is still expected to report 2 rows.
            # NOTE(review): presumably DatasetIO falls back to the ``.gz``
            # sibling when the plain file is missing -- confirm in DatasetIO.
            self.assertEqual(
                DatasetIO.count_jsonl_rows(workdir / "base.jsonl"), 2
            )

    def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self):
        """A 200-character file with ``max_bytes=50`` is rotated away."""
        with tempfile.TemporaryDirectory() as tmp:
            live_file = Path(tmp) / "rl_bootstrap.jsonl"
            live_file.write_text("x" * 200, encoding="utf-8")

            did_rotate = DatasetIO.rotate_and_gzip_if_size_reached(
                live_file, max_bytes=50
            )

            self.assertTrue(did_rotate)
            # The original file must be gone and at least one gzip archive
            # with the rotated naming pattern must exist next to it.
            self.assertFalse(live_file.exists())
            self.assertGreaterEqual(len(self._rotated_archives(tmp)), 1)

    def test_rl_bootstrap_dataset_reads_env_and_gzip_base(self):
        """One base row with ``RL_MIN_BASE_ROWS=2`` leaves ``needs_more_data`` set."""
        with tempfile.TemporaryDirectory() as tmp:
            base_file = Path(tmp) / "base.jsonl.gz"
            self._write_gzip_text(base_file, ['{"turn":1}\n'])

            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                "RL_MIN_BASE_ROWS": "2",
                "RL_BASE_DATASET": str(base_file),
            }
            # clear=False keeps the ambient environment and only overlays
            # the three keys above for the duration of the block.
            with patch.dict("os.environ", env_overrides, clear=False):
                dataset = RLBootstrapDataset()
                dataset.refresh_state()

                self.assertTrue(dataset.needs_more_data)

    def test_rl_bootstrap_dataset_autocompresses_output(self):
        """Recording two samples under a tiny size cap yields a gzip archive."""

        class DummyBoard:
            """Minimal board stand-in passed to ``record_sample``."""

            id = "game-1"

            def get_turn(self):
                """Return the fixed turn number 1."""
                return 1

            def get_game_board_as_dict(self):
                """Return a fixed 11x11 empty-board state with one snake 'me'."""
                arena = {
                    "width": 11,
                    "height": 11,
                    "food": [],
                    "hazards": [],
                    "snakes": [],
                }
                me = {
                    "id": "me",
                    "head": {"x": 5, "y": 5},
                    "body": [{"x": 5, "y": 5}],
                }
                return {"turn": 1, "board": arena, "you": me}

        with tempfile.TemporaryDirectory() as tmp:
            output_file = Path(tmp) / "rl_bootstrap.jsonl"
            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                # NOTE(review): a very small cap, presumably so the output
                # exceeds it after a single sample -- confirm the unit
                # handling in RLBootstrapDataset.
                "RL_BOOTSTRAP_MAX_MB": "0.00001",
                "RL_BOOTSTRAP_OUTPUT": str(output_file),
            }
            with patch.dict("os.environ", env_overrides, clear=False):
                dataset = RLBootstrapDataset()
                # Force recording regardless of what the constructor decided.
                dataset.needs_more_data = True

                planned_moves = (
                    ("up", {"x": 5, "y": 6}),
                    ("left", {"x": 4, "y": 5}),
                )
                for direction, destination in planned_moves:
                    dataset.record_sample(
                        DummyBoard(), direction, {direction: destination}, "test"
                    )

                self.assertGreaterEqual(len(self._rotated_archives(tmp)), 1)
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|