move RLBootstrapDataset into its own class with its own test file
This commit is contained in:
@@ -0,0 +1,101 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from pathlib import Path
|
||||
import tempfile, gzip
|
||||
|
||||
from server.dataset.RLBootstrapDataset import RLBootstrapDataset
|
||||
|
||||
class TestRLBootstrapDataset(unittest.TestCase):
    """Tests for ``RLBootstrapDataset``: row counting, rotation and env-driven setup.

    Every test works inside a fresh temporary directory so nothing is left
    behind on disk.
    """

    def test_count_jsonl_rows_reads_gzip_dataset(self):
        """``count_jsonl_rows`` must report 2 rows for the gzip fixture.

        The fixture holds two JSON lines with an empty line between them.
        The count is checked for the ``.gz`` path and for the ``base.jsonl``
        spelling, which this test never creates on disk.
        """
        with tempfile.TemporaryDirectory() as tmp:
            root = Path(tmp)
            gz_file = root / "base.jsonl.gz"
            with gzip.open(gz_file, "wt", encoding="utf-8") as out:
                out.writelines(['{"turn":1}\n', "\n", '{"turn":2}\n'])

            # Same expectation for both spellings of the dataset path.
            for candidate in (gz_file, root / "base.jsonl"):
                self.assertEqual(RLBootstrapDataset.count_jsonl_rows(candidate), 2)

    def test_rotate_and_gzip_if_size_reached_rotates_jsonl(self):
        """A 200-character file with ``max_bytes=50`` must be rotated.

        Expected: a truthy return value, the original path gone, and at
        least one ``rl_bootstrap.*.jsonl.gz`` file in the directory.
        """
        with tempfile.TemporaryDirectory() as tmp:
            root = Path(tmp)
            live_file = root / "rl_bootstrap.jsonl"
            live_file.write_text("x" * 200, encoding="utf-8")

            was_rotated = RLBootstrapDataset.rotate_and_gzip_if_size_reached(
                live_file,
                max_bytes=50,
            )

            self.assertTrue(was_rotated)
            self.assertFalse(live_file.exists())
            archives = list(root.glob("rl_bootstrap.*.jsonl.gz"))
            self.assertGreaterEqual(len(archives), 1)

    def test_rl_bootstrap_dataset_reads_env_and_gzip_base(self):
        """``needs_more_data`` must be truthy after ``refresh_state()``.

        The environment sets ``RL_MIN_BASE_ROWS`` to 2 and points
        ``RL_BASE_DATASET`` at a gzip file holding a single row.
        """
        with tempfile.TemporaryDirectory() as tmp:
            gz_file = Path(tmp) / "base.jsonl.gz"
            with gzip.open(gz_file, "wt", encoding="utf-8") as out:
                out.write('{"turn":1}\n')

            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                "RL_MIN_BASE_ROWS": "2",
                "RL_BASE_DATASET": str(gz_file),
            }
            with patch.dict("os.environ", env_overrides, clear=False):
                bootstrap = RLBootstrapDataset()
                bootstrap.refresh_state()

                self.assertTrue(bootstrap.needs_more_data)

    def test_rl_bootstrap_dataset_autocompresses_output(self):
        """Two recorded samples must leave at least one rotated archive.

        ``RL_BOOTSTRAP_MAX_MB`` is set to ``0.00001`` and
        ``RL_BOOTSTRAP_OUTPUT`` to a file in the temporary directory; after
        two ``record_sample`` calls the directory must contain at least one
        ``rl_bootstrap.*.jsonl.gz`` file.
        """

        class DummyBoard:
            """Minimal stand-in exposing only what this test provides to ``record_sample``."""

            id = "game-1"

            def get_turn(self):
                return 1

            def get_game_board_as_dict(self):
                # Build a fresh dict on every call so callers never share state.
                board_section = {
                    "width": 11,
                    "height": 11,
                    "food": [],
                    "hazards": [],
                    "snakes": [],
                }
                you_section = {
                    "id": "me",
                    "head": {"x": 5, "y": 5},
                    "body": [{"x": 5, "y": 5}],
                }
                return {"turn": 1, "board": board_section, "you": you_section}

        with tempfile.TemporaryDirectory() as tmp:
            root = Path(tmp)
            target_file = root / "rl_bootstrap.jsonl"
            env_overrides = {
                "RL_BOOTSTRAP_ENABLED": "1",
                "RL_BOOTSTRAP_MAX_MB": "0.00001",
                "RL_BOOTSTRAP_OUTPUT": str(target_file),
            }
            with patch.dict("os.environ", env_overrides, clear=False):
                bootstrap = RLBootstrapDataset()
                # Set the flag directly instead of deriving it from a base dataset.
                bootstrap.needs_more_data = True
                bootstrap.record_sample(
                    DummyBoard(), "up", {"up": {"x": 5, "y": 6}}, "test"
                )
                bootstrap.record_sample(
                    DummyBoard(), "left", {"left": {"x": 4, "y": 5}}, "test"
                )

            archives = list(root.glob("rl_bootstrap.*.jsonl.gz"))
            self.assertGreaterEqual(len(archives), 1)
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user