Files
snake-python/tests/test_RLBootstrapDataset.py

101 lines
3.0 KiB
Python

import unittest
from unittest.mock import patch
from pathlib import Path
import tempfile, gzip
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
from server.dataset.DatasetIO import DatasetIO
class TestRLBootstrapDataset(unittest.TestCase):
    """Exercise ``DatasetIO`` helpers and ``RLBootstrapDataset`` on temp files."""

    def test_count_jsonl_rows_reads_gzip_dataset(self):
        """Expect two rows from a gzip JSONL file holding two records and a blank line.

        The second assertion passes the sibling ``base.jsonl`` path, which this
        test never creates, and expects the same count.
        """
        with tempfile.TemporaryDirectory() as workdir:
            gz_file = Path(workdir) / "base.jsonl.gz"
            with gzip.open(gz_file, "wt", encoding="utf-8") as out:
                # Two JSON records separated by an empty line.
                for line in ('{"turn":1}\n', "\n", '{"turn":2}\n'):
                    out.write(line)

            self.assertEqual(DatasetIO.count_jsonl_rows(gz_file), 2)
            plain_file = Path(workdir) / "base.jsonl"
            self.assertEqual(DatasetIO.count_jsonl_rows(plain_file), 2)

    def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self):
        """Expect a 200-byte file to be rotated away when ``max_bytes`` is 50."""
        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            live_file = root / "rl_bootstrap.jsonl"
            live_file.write_text("x" * 200, encoding="utf-8")

            did_rotate = DatasetIO.rotate_and_gzip_if_size_reached(
                live_file, max_bytes=50
            )

            self.assertTrue(did_rotate)
            self.assertFalse(live_file.exists())
            archives = list(root.glob("rl_bootstrap.*.jsonl.gz"))
            self.assertGreaterEqual(len(archives), 1)

    def test_rl_bootstrap_dataset_reads_env_and_gzip_base(self):
        """Expect ``needs_more_data`` to be truthy with one base row and ``RL_MIN_BASE_ROWS`` of 2."""
        with tempfile.TemporaryDirectory() as workdir:
            base_file = Path(workdir) / "base.jsonl.gz"
            with gzip.open(base_file, "wt", encoding="utf-8") as out:
                out.write('{"turn":1}\n')

            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                "RL_MIN_BASE_ROWS": "2",
                "RL_BASE_DATASET": str(base_file),
            }
            # The dataset is built while the overrides are active.
            with patch.dict("os.environ", env_overrides, clear=False):
                bootstrap = RLBootstrapDataset()
                bootstrap.refresh_state()
                self.assertTrue(bootstrap.needs_more_data)

    def test_rl_bootstrap_dataset_autocompresses_output(self):
        """Expect at least one ``rl_bootstrap.*.jsonl.gz`` after recording two samples.

        ``RL_BOOTSTRAP_MAX_MB`` is set to ``0.00001`` for the duration of the
        test.
        """

        class DummyBoard:
            """Minimal stand-in for the board object passed to ``record_sample``."""

            id = "game-1"

            def get_turn(self):
                return 1

            def get_game_board_as_dict(self):
                # A fresh dict is built on every call, as in a real payload.
                board_state = {
                    "width": 11,
                    "height": 11,
                    "food": [],
                    "hazards": [],
                    "snakes": [],
                }
                me = {
                    "id": "me",
                    "head": {"x": 5, "y": 5},
                    "body": [{"x": 5, "y": 5}],
                }
                return {"turn": 1, "board": board_state, "you": me}

        with tempfile.TemporaryDirectory() as workdir:
            root = Path(workdir)
            sink = root / "rl_bootstrap.jsonl"
            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                "RL_BOOTSTRAP_MAX_MB": "0.00001",
                "RL_BOOTSTRAP_OUTPUT": str(sink),
            }
            with patch.dict("os.environ", env_overrides, clear=False):
                bootstrap = RLBootstrapDataset()
                bootstrap.needs_more_data = True

                samples = (
                    ("up", {"up": {"x": 5, "y": 6}}),
                    ("left", {"left": {"x": 4, "y": 5}}),
                )
                for move, targets in samples:
                    bootstrap.record_sample(DummyBoard(), move, targets, "test")

                archives = list(root.glob("rl_bootstrap.*.jsonl.gz"))
                self.assertGreaterEqual(len(archives), 1)
# Run the test suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    unittest.main()