93 lines
3.6 KiB
Python
93 lines
3.6 KiB
Python
import unittest, sqlite3, os
|
|
from pathlib import Path
|
|
|
|
from server.database import GameplayDatabase
|
|
|
|
# Absolute location of the sqlite-zstd loadable extension. Override with the
# SQLITE_ZSTD_EXT environment variable; "~" is expanded before resolving.
EXT_PATH = (
    Path(os.environ.get("SQLITE_ZSTD_EXT", "lib/libsqlite_zstd.so"))
    .expanduser()
    .resolve()
)
|
|
|
|
def _open(db_path: str) -> sqlite3.Connection:
    """Open *db_path*, loading the sqlite-zstd extension when it is present.

    The connection uses ``sqlite3.Row`` as its row factory, so rows can be
    indexed by position or by column name. If a file exists at ``EXT_PATH``
    it is loaded as a SQLite extension; otherwise the connection is returned
    without it.

    Args:
        db_path: Path to the database file, or ``":memory:"``.

    Returns:
        An open connection. The caller is responsible for closing it.

    Raises:
        sqlite3.Error: If the database cannot be opened or the extension
            fails to load. When setup fails after the connection was opened,
            the connection is closed before the error propagates.
        AttributeError: Presumably when this Python's ``sqlite3`` was built
            without loadable-extension support (no ``enable_load_extension``).
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = sqlite3.Row
        if EXT_PATH.exists():
            conn.enable_load_extension(True)
            try:
                conn.load_extension(str(EXT_PATH))
            finally:
                # Turn extension loading back off even if the load failed, so
                # the connection is never left in the more permissive state.
                conn.enable_load_extension(False)
    except BaseException:
        # Setup failed: don't leak the half-initialised connection.
        conn.close()
        raise
    return conn
|
|
|
|
@unittest.skipUnless(EXT_PATH.exists(), f"sqlite-zstd extension not found at {EXT_PATH}")
class TestZstdCompression(unittest.TestCase):
    """Tests for the sqlite-zstd extension and the compressed gameplay database.

    The whole class is skipped when no extension library exists at
    ``EXT_PATH``. Tests that inspect the gameplay database also skip when the
    database file (``$GAMEPLAY_DB_PATH``, else ``_DEFAULT_DB_PATH``) is
    missing. Every connection is closed via ``addCleanup``, so a failing or
    skipped test never leaks an open connection.
    """

    # Database location used when $GAMEPLAY_DB_PATH is not set.
    _DEFAULT_DB_PATH = "data/database/gameplay.sqlite3"

    # Columns that must each have a sqlite-zstd compression config.
    _EXPECTED_COMPRESSED_COLUMNS = (
        "board_state_json",
        "snakes_json",
        "you_json",
        "food_json",
        "hazards_json",
        "body_json",
    )

    def _connect(self, db_path: str) -> sqlite3.Connection:
        """Open *db_path* via ``_open`` and close it automatically after the test."""
        conn = _open(db_path)
        # Cleanups run after assertion failures and skips too, unlike a
        # trailing conn.close() at the end of the test body.
        self.addCleanup(conn.close)
        return conn

    def _connect_gameplay_db(self) -> sqlite3.Connection:
        """Open the gameplay database, skipping the current test if it is absent."""
        db_path = os.environ.get("GAMEPLAY_DB_PATH", self._DEFAULT_DB_PATH)
        # Check first: sqlite3.connect would silently create an empty file.
        if not Path(db_path).exists():
            self.skipTest(f"Database not found at {db_path}")
        return self._connect(db_path)

    @staticmethod
    def _schema_names(conn: sqlite3.Connection, obj_type: str) -> set:
        """Return the names of schema objects whose type is *obj_type* (e.g. 'table', 'view')."""
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = ?", (obj_type,)
        ).fetchall()
        return {row[0] for row in rows}

    def test_extension_loads(self):
        """zstd_compress is callable and returns non-empty bytes once the extension is loaded."""
        conn = self._connect(":memory:")
        result = conn.execute("SELECT zstd_compress('hello world', 6)").fetchone()[0]
        self.assertIsInstance(result, bytes)
        self.assertGreater(len(result), 0)

    def test_compress_decompress_roundtrip(self):
        """Compression shrinks repetitive text, and decompression restores it exactly."""
        # Use a long repetitive string so compression actually reduces size
        original = '{"board": {"width": 11, "height": 11, "snakes": []}, "turn": 42} ' * 20
        conn = self._connect(":memory:")

        compressed = conn.execute("SELECT zstd_compress(?, 6)", (original,)).fetchone()[0]
        self.assertIsInstance(compressed, bytes)
        self.assertLess(len(compressed), len(original))

        # Second argument is sqlite-zstd's is_txt flag: ask for TEXT back so
        # the result compares equal to the original str.
        restored = conn.execute("SELECT zstd_decompress(?, 1)", (compressed,)).fetchone()[0]
        self.assertEqual(restored, original)

    def test_transparent_compression_enabled(self):
        """The gameplay DB has one sqlite-zstd config per expected compressed column."""
        conn = self._connect_gameplay_db()

        tables = self._schema_names(conn, "table")
        self.assertIn("_zstd_configs", tables, "_zstd_configs table missing — compression not enabled")

        compressed_columns = [row["config"] for row in conn.execute("SELECT config FROM _zstd_configs")]
        self.assertEqual(
            len(compressed_columns),
            len(self._EXPECTED_COMPRESSED_COLUMNS),
            compressed_columns,
        )

        # Each config is matched by substring on the column name.
        for col in self._EXPECTED_COMPRESSED_COLUMNS:
            self.assertTrue(
                any(col in cfg for cfg in compressed_columns),
                f"{col} not found in _zstd_configs",
            )

    def test_turns_and_snake_turns_are_views(self):
        """``turns`` and ``snake_turns`` are exposed as views in the gameplay DB."""
        conn = self._connect_gameplay_db()
        views = self._schema_names(conn, "view")
        self.assertIn("turns", views)
        self.assertIn("snake_turns", views)

    def test_board_state_json_is_stored_compressed(self):
        """A row in the backing ``_turns_zstd`` table stores board_state_json as bytes."""
        conn = self._connect_gameplay_db()
        if "_turns_zstd" not in self._schema_names(conn, "table"):
            self.skipTest("_turns_zstd table not found")

        # NOTE(review): this samples one arbitrary row. If compression is
        # applied incrementally, an unprocessed row may still be TEXT —
        # confirm maintenance has run on the DB under test.
        row = conn.execute("SELECT board_state_json FROM _turns_zstd LIMIT 1").fetchone()
        if row is None:
            self.skipTest("No rows in _turns_zstd")

        self.assertIsInstance(row[0], (bytes, bytearray), "board_state_json should be stored as compressed bytes")
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests directly, e.g. `python path/to/this_test.py`.
    # `module="__main__"` is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")
|