From fe999c11f4f4f4e6501cb4ec1f3f0f539277ac50 Mon Sep 17 00:00:00 2001 From: Daniel Dolezal Date: Mon, 6 Apr 2026 16:34:15 +0200 Subject: [PATCH] add compression to GameplayDatabase and test if compression works with the sqlite_zstd extension --- server/database/GameplayDatabase.py | 56 ++++++++++++++++-- tests/test_ZstdCompression.py | 92 +++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 6 deletions(-) create mode 100644 tests/test_ZstdCompression.py diff --git a/server/database/GameplayDatabase.py b/server/database/GameplayDatabase.py index 63ee9d8..f74a660 100644 --- a/server/database/GameplayDatabase.py +++ b/server/database/GameplayDatabase.py @@ -1,7 +1,9 @@ from datetime import datetime, timezone -import asyncio, sqlite3, json +import asyncio, sqlite3, json, os from pathlib import Path +_ZSTD_EXT = Path(os.environ.get("SQLITE_ZSTD_EXT", "/usr/local/lib/libsqlite_zstd.so")).expanduser().resolve() + class GameplayDatabase: def __init__(self, db_path:str, busy_timeout_ms:int=5000): self.db_path = db_path @@ -14,6 +16,12 @@ class GameplayDatabase: timeout=max(1, self.busy_timeout_ms // 1000), isolation_level=None, ) + + if Path(_ZSTD_EXT).exists(): + connection.enable_load_extension(True) + connection.load_extension(str(_ZSTD_EXT)) + connection.enable_load_extension(False) + connection.row_factory = sqlite3.Row connection.execute("PRAGMA foreign_keys = ON") connection.execute("PRAGMA journal_mode = WAL") @@ -26,7 +34,10 @@ class GameplayDatabase: def _initialize_database(self) -> None: Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) with self._connect() as connection: - connection.execute("PRAGMA auto_vacuum = INCREMENTAL") + current_vacuum = connection.execute("PRAGMA auto_vacuum").fetchone()[0] + if current_vacuum != 1: + connection.execute("PRAGMA auto_vacuum = FULL") + connection.execute("VACUUM") connection.executescript(""" CREATE TABLE IF NOT EXISTS games ( game_id TEXT PRIMARY KEY, @@ -81,17 +92,31 @@ class GameplayDatabase: FOREIGN KEY (game_id) REFERENCES games(game_id) ON DELETE CASCADE ); - CREATE INDEX IF NOT EXISTS idx_turns_game_turn ON turns(game_id, turn); - CREATE INDEX IF NOT EXISTS idx_games_status ON games(status); - CREATE INDEX IF NOT EXISTS idx_snake_turns_game_turn ON snake_turns(game_id, turn); """) + self._create_indexes_if_tables(connection) self._ensure_column_exists(connection, "turns", "my_thinking_json", "TEXT") self._ensure_column_exists(connection, "games", "your_snake_type", "TEXT") self._ensure_column_exists(connection, "games", "your_snake_version", "TEXT") self._ensure_column_exists(connection, "games", "game_type", "TEXT") self._ensure_column_exists(connection, "snake_turns", "latency", "TEXT") + self._enable_zstd_compression(connection) connection.execute("PRAGMA optimize") + def _create_indexes_if_tables(self, connection: sqlite3.Connection) -> None: + real_tables = { + row[0] for row in connection.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ).fetchall() + } + indexes = [ + ("idx_turns_game_turn", "turns", "game_id, turn"), + ("idx_games_status", "games", "status"), + ("idx_snake_turns_game_turn", "snake_turns", "game_id, turn"), + ] + for idx_name, table, cols in indexes: + if table in real_tables: + connection.execute(f"CREATE INDEX IF NOT EXISTS {idx_name} ON {table}({cols})") + def _ensure_column_exists(self, connection:sqlite3.Connection, table_name:str, column_name:str, column_type:str) -> None: existing = connection.execute(f"PRAGMA table_info({table_name})").fetchall() if any(row["name"] == column_name for row in existing): @@ -99,6 +124,26 @@ class GameplayDatabase: connection.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}") + def _enable_zstd_compression(self, connection: sqlite3.Connection) -> None: + compressed_columns = [ + ("turns", "board_state_json"), + ("turns", "snakes_json"), + ("turns", "you_json"), + ("turns", "food_json"), + ("turns", "hazards_json"), + ("snake_turns", "body_json"), + ] + for table, column in compressed_columns: + try: + connection.execute( + "SELECT zstd_enable_transparent(?)", + [json.dumps({"table": table, "column": column, "compression_level": 6, "dict_chooser": "'a'"})], + ) + except sqlite3.OperationalError: + pass # already enabled + + connection.execute("SELECT zstd_incremental_maintenance(null, 1)") + def _utc_now(self) -> str: return datetime.now(timezone.utc).isoformat() @@ -200,7 +245,6 @@ class GameplayDatabase: ), ) connection.execute("PRAGMA wal_checkpoint(PASSIVE)") - connection.execute("PRAGMA incremental_vacuum(200)") connection.execute("PRAGMA optimize") def _record_turn_sync(self, game_state:dict, my_move:str|None, my_thinking:dict|None) -> None: diff --git a/tests/test_ZstdCompression.py b/tests/test_ZstdCompression.py new file mode 100644 index 0000000..592c182 --- /dev/null +++ b/tests/test_ZstdCompression.py @@ -0,0 +1,92 @@ +import unittest, sqlite3, os +from pathlib import Path + +from server.database import GameplayDatabase + +EXT_PATH = Path(os.environ.get("SQLITE_ZSTD_EXT", "lib/libsqlite_zstd.so")).expanduser().resolve() + +def _open(db_path: str) -> sqlite3.Connection: + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + if EXT_PATH.exists(): + conn.enable_load_extension(True) + conn.load_extension(str(EXT_PATH)) + conn.enable_load_extension(False) + return conn + +@unittest.skipUnless(EXT_PATH.exists(), f"sqlite-zstd extension not found at {EXT_PATH}") +class TestZstdCompression(unittest.TestCase): + + def test_extension_loads(self): + conn = sqlite3.connect(":memory:") + conn.enable_load_extension(True) + conn.load_extension(str(EXT_PATH)) + conn.enable_load_extension(False) + result = conn.execute("SELECT zstd_compress('hello world', 6)").fetchone()[0] + self.assertIsInstance(result, bytes) + self.assertGreater(len(result), 0) + conn.close() + + def test_compress_decompress_roundtrip(self): + # Use a long repetitive string so compression actually reduces size + original = '{"board": {"width": 11, "height": 11, "snakes": []}, "turn": 42} ' * 20 + conn = sqlite3.connect(":memory:") + conn.enable_load_extension(True) + conn.load_extension(str(EXT_PATH)) + conn.enable_load_extension(False) + compressed = conn.execute("SELECT zstd_compress(?, 6)", (original,)).fetchone()[0] + self.assertIsInstance(compressed, bytes) + self.assertLess(len(compressed), len(original)) + conn.close() + + def test_transparent_compression_enabled(self): + db_path = os.environ.get("GAMEPLAY_DB_PATH", "data/database/gameplay.sqlite3") + if not Path(db_path).exists(): + self.skipTest(f"Database not found at {db_path}") + + conn = _open(db_path) + tables = {r[0] for r in conn.execute("SELECT name FROM sqlite_master").fetchall()} + self.assertIn("_zstd_configs", tables, "_zstd_configs table missing — compression not enabled") + + configs = conn.execute("SELECT config FROM _zstd_configs").fetchall() + compressed_columns = [row["config"] for row in configs] + self.assertEqual(len(compressed_columns), 6) + + expected = ["board_state_json", "snakes_json", "you_json", "food_json", "hazards_json", "body_json"] + for col in expected: + self.assertTrue( + any(col in cfg for cfg in compressed_columns), + f"{col} not found in _zstd_configs", + ) + conn.close() + + def test_turns_and_snake_turns_are_views(self): + db_path = os.environ.get("GAMEPLAY_DB_PATH", "data/database/gameplay.sqlite3") + if not Path(db_path).exists(): + self.skipTest(f"Database not found at {db_path}") + + conn = _open(db_path) + views = {r[0] for r in conn.execute("SELECT name FROM sqlite_master WHERE type='view'").fetchall()} + self.assertIn("turns", views) + self.assertIn("snake_turns", views) + conn.close() + + def test_board_state_json_is_stored_compressed(self): + db_path = os.environ.get("GAMEPLAY_DB_PATH", "data/database/gameplay.sqlite3") + if not Path(db_path).exists(): + self.skipTest(f"Database not found at {db_path}") + + conn = _open(db_path) + tables = {r[0] for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()} + if "_turns_zstd" not in tables: + self.skipTest("_turns_zstd table not found") + + row = conn.execute("SELECT board_state_json FROM _turns_zstd LIMIT 1").fetchone() + if row is None: + self.skipTest("No rows in _turns_zstd") + + self.assertIsInstance(row[0], (bytes, bytearray), "board_state_json should be stored as compressed bytes") + conn.close() + +if __name__ == "__main__": + unittest.main()