diff --git a/justfile b/justfile
index 765d7c3..bf493c0 100644
--- a/justfile
+++ b/justfile
@@ -159,7 +159,7 @@ test-local-4 mode="standard" map="standard" base_port="9101" snake="BestBattleSn
         -g "{{mode}}" --map "{{map}}" --seed "{{seed}}" $BROWSER_FLAG
 
 # ------------------------------------------------------------------------------
-# Fataset helpers
+# Dataset helpers
 # ------------------------------------------------------------------------------
 
 export-dataset input="data" output="data/dataset/good_moves.jsonl":
@@ -170,3 +170,9 @@ curate-dataset input="good_moves-*.jsonl" output="data/dataset/best_moves.jsonl"
 
 analyze-dataset input="good_moves-*.jsonl" output="":
     if [ -n "{{output}}" ]; then python -m server.DatasetStats --input "{{input}}" --output "{{output}}"; else python -m server.DatasetStats --input "{{input}}"; fi
+
+train-ai input="data/dataset/best_moves.jsonl" rl_input="data/dataset/rl_bootstrap.jsonl" output="models/battlesnake_softmax_v2.json" eval_split="0.2" seed="42" epochs="14" lr="0.08":
+    if [ -f "{{rl_input}}" ]; then python -m server.TrainBattleSnakeAI --input "{{input}}" --input "{{rl_input}}" --output "{{output}}" --eval-split "{{eval_split}}" --seed "{{seed}}" --epochs "{{epochs}}" --lr "{{lr}}"; else python -m server.TrainBattleSnakeAI --input "{{input}}" --output "{{output}}" --eval-split "{{eval_split}}" --seed "{{seed}}" --epochs "{{epochs}}" --lr "{{lr}}"; fi
+
+run-trained model="models/battlesnake_softmax_v2.json" port="8000":
+    TRAINED_SNAKE_MODEL="{{model}}" SNAKE="TrainedBattleSnake" PORT="{{port}}" "{{justfile_directory()}}/main.py"
diff --git a/server/TrainBattleSnakeAI.py b/server/TrainBattleSnakeAI.py
new file mode 100644
index 0000000..469e13a
--- /dev/null
+++ b/server/TrainBattleSnakeAI.py
@@ -0,0 +1,332 @@
+import argparse, random, glob, json, math
+from collections import Counter
+from pathlib import Path
+
+MOVES = ["up", "down", "left", "right"]
+
+def resolve_input_files(inputs:list[str]) -> list[Path]:
+    """Expand directories, glob patterns and plain paths into a de-duplicated file list.
+
+    A directory contributes every `*.jsonl` beneath it (recursively), an item containing
+    glob characters is expanded with `glob.glob`, and anything else is taken as a path.
+    Only existing regular files are kept. Duplicates are detected by resolved path and the
+    first occurrence wins, so the order given by `inputs` is preserved.
+    """
+    resolved:list[Path] = []
+    seen:set[str] = set()
+
+    for item in inputs:
+        path = Path(item)
+        if path.is_dir():
+            candidates = sorted(path.rglob("*.jsonl"))
+        elif any(ch in item for ch in "*?[]"):
+            candidates = [Path(match) for match in sorted(glob.glob(item))]
+        else:
+            candidates = [path]
+
+        for file_path in candidates:
+            # A directory that happens to be named `*.jsonl` would crash read_rows later.
+            if not file_path.is_file():
+                continue
+            key = str(file_path.resolve())
+            if key in seen:
+                continue
+            seen.add(key)
+            resolved.append(file_path)
+
+    return resolved
+
+def _neighbors(x:int, y:int) -> list[tuple[int, int, str]]:
+    """Return the four cells adjacent to (x, y), each tagged with the move that reaches it."""
+    return [
+        (x, y + 1, "up"),
+        (x, y - 1, "down"),
+        (x - 1, y, "left"),
+        (x + 1, y, "right"),
+    ]
+
+def _safe_neighbor_count(point:tuple[int, int], blocked:set[tuple[int, int]], width:int, height:int) -> int:
+    """Count the neighbours of `point` that are on the board and not in `blocked`."""
+    count = 0
+    for nx, ny, _ in _neighbors(point[0], point[1]):
+        if not (0 <= nx < width and 0 <= ny < height):
+            continue
+        if (nx, ny) in blocked:
+            continue
+        count += 1
+    return count
+
+def _manhattan_to_nearest_food(point: tuple[int, int], food: set[tuple[int, int]]) -> int:
+    """Return the Manhattan distance from `point` to the closest food, or 25 when there is none."""
+    if not food:
+        return 25
+    return min(abs(point[0] - fx) + abs(point[1] - fy) for fx, fy in food)
+
+def extract_feature_values(row:dict) -> dict[str, float]:
+    """Turn one dataset row (or live board snapshot) into a flat name -> value feature map.
+
+    Reads `row["game_board"]` and `row["turn"]`. Returns `{}` when the board has no snakes
+    or the first snake has no body; callers treat that as "cannot score this row".
+    Values are normalised so that none exceeds 1.0.
+    """
+    # `or` also covers an explicit null: .get()'s default only applies to a missing key.
+    board = row.get("game_board") or {}
+    snakes = board.get("snakes") or []
+    if not snakes:
+        return {}
+
+    # NOTE(review): assumes snakes[0] is the snake whose move is being learned - confirm
+    # against what get_game_board_as_dict() and the dataset exporter actually emit.
+    me = snakes[0]
+    body = me.get("body", [])
+    if not body:
+        return {}
+
+    width = int(board.get("width", 0))
+    height = int(board.get("height", 0))
+    head = body[0]
+    hx = int(head.get("x", 0))
+    hy = int(head.get("y", 0))
+    health = int(me.get("health", 100))
+    length = int(me.get("length", len(body)))
+
+    food_set = {(int(f.get("x", 0)), int(f.get("y", 0))) for f in board.get("food", [])}
+    hazard_set = {(int(h.get("x", 0)), int(h.get("y", 0))) for h in board.get("hazards", [])}
+    # Every body segment of every snake (own body included) counts as an obstacle.
+    blocked = {
+        (int(seg.get("x", 0)), int(seg.get("y", 0)))
+        for snake in snakes
+        for seg in snake.get("body", [])
+    }
+
+    features:dict[str, float] = {
+        "bias": 1.0,
+        "health_norm": max(0.0, min(1.0, health / 100.0)),
+        "length_norm": min(1.0, length / max(1.0, width * height)),
+        "turn_norm": min(1.0, int(row.get("turn", 0)) / 100.0),
+        "food_count_norm": min(1.0, len(food_set) / 10.0),
+        "hazard_count_norm": min(1.0, len(hazard_set) / 20.0),
+        "opponent_count_norm": min(1.0, max(0, len(snakes) - 1) / 7.0),
+    }
+
+    safe_total = 0
+    for nx, ny, move in _neighbors(hx, hy):
+        in_bounds = 1.0 if (0 <= nx < width and 0 <= ny < height) else 0.0
+        blocked_next = 1.0 if (nx, ny) in blocked else 0.0
+        food_next = 1.0 if (nx, ny) in food_set else 0.0
+        hazard_next = 1.0 if (nx, ny) in hazard_set else 0.0
+
+        if in_bounds and not blocked_next:
+            safe_total += 1
+            open_next = float(_safe_neighbor_count((nx, ny), blocked, width, height))
+            dist_food = float(_manhattan_to_nearest_food((nx, ny), food_set))
+        else:
+            # Unusable cell: no exits, and the same "far away" value used when there is no food.
+            open_next = 0.0
+            dist_food = 25.0
+
+        prefix = f"m:{move}:"
+        features[prefix + "in_bounds"] = in_bounds
+        features[prefix + "blocked"] = blocked_next
+        features[prefix + "food"] = food_next
+        features[prefix + "hazard"] = hazard_next
+        features[prefix + "open_next"] = min(4.0, open_next) / 4.0
+        features[prefix + "food_dist"] = min(25.0, dist_food) / 25.0
+
+    features["safe_total_norm"] = safe_total / 4.0
+    return features
+
+class SoftmaxMoveModel:
+    """Multinomial logistic regression over `MOVES`, trained with per-example SGD.
+
+    `weights[move][feature]` and `bias[move]` are plain dicts so the model serialises to JSON.
+    """
+
+    def __init__(self):
+        self.weights = {move: {} for move in MOVES}
+        self.bias = {move: 0.0 for move in MOVES}
+
+    def _score(self, move:str, features:dict[str, float]) -> float:
+        """Return the unnormalised score (logit) of `move`; features without a weight count as 0."""
+        weight_map = self.weights[move]
+        value = self.bias[move]
+        for name, feat in features.items():
+            value += weight_map.get(name, 0.0) * feat
+        return value
+
+    def fit(self, rows:list[dict], epochs:int=14, lr:float=0.08, l2:float=1e-6) -> None:
+        """Train in place on `rows`, skipping rows without a valid move or without features.
+
+        Runs `epochs` passes of per-example gradient steps on the softmax log-likelihood with
+        learning rate `lr` and L2 shrinkage `l2` on the weights (not on the biases). Examples
+        are reshuffled every epoch with the global `random` state, so seed it beforehand for
+        reproducible models.
+        """
+        examples = []
+        for row in rows:
+            label = row.get("move")
+            if label not in MOVES:
+                continue
+            features = extract_feature_values(row)
+            if not features:
+                continue
+            examples.append((features, label))
+
+        if not examples:
+            return
+
+        for _ in range(epochs):
+            random.shuffle(examples)
+            for features, label in examples:
+                scores = {move: self._score(move, features) for move in MOVES}
+                # Subtract the max before exponentiating so exp() cannot overflow.
+                max_score = max(scores.values())
+                exp_scores = {
+                    move: math.exp(scores[move] - max_score) for move in MOVES
+                }
+                z = sum(exp_scores.values())
+                probs = {move: exp_scores[move] / z for move in MOVES}
+
+                for move in MOVES:
+                    target = 1.0 if move == label else 0.0
+                    gradient = target - probs[move]
+                    self.bias[move] += lr * gradient
+
+                    w = self.weights[move]
+                    for name, feat in features.items():
+                        current = w.get(name, 0.0)
+                        update = lr * ((gradient * feat) - (l2 * current))
+                        w[name] = current + update
+
+    def predict_scores(self, row:dict) -> dict[str, float]:
+        """Return the raw score of every move for `row`; all zeros when no features can be extracted."""
+        features = extract_feature_values(row)
+        if not features:
+            return {move: 0.0 for move in MOVES}
+        return {move: self._score(move, features) for move in MOVES}
+
+    def predict(self, row:dict) -> str:
+        """Return the highest-scoring move for `row`; ties go to the earliest entry in `MOVES`."""
+        scores = self.predict_scores(row)
+        return max(scores, key=lambda move: scores[move])
+
+    def evaluate(self, rows:list[dict]) -> dict:
+        """Score the model on labelled `rows`.
+
+        Returns the row count, top-1 hits, top-1 and top-2 accuracy (rounded to 4 places)
+        and a confusion table keyed `expected -> predicted -> count`.
+        """
+        total = 0
+        correct = 0
+        top2 = 0
+        confusion = {move: Counter() for move in MOVES}
+
+        for row in rows:
+            expected = row.get("move")
+            if expected not in MOVES:
+                continue
+
+            scores = self.predict_scores(row)
+            ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
+            predicted = ranked[0][0]
+
+            total += 1
+            if predicted == expected:
+                correct += 1
+            # predict_scores always yields all four moves, so ranked[:2] is the top two.
+            if expected in {name for name, _ in ranked[:2]}:
+                top2 += 1
+
+            confusion[expected][predicted] += 1
+
+        return {
+            "total": total,
+            "correct": correct,
+            "accuracy": round((correct / total) if total else 0.0, 4),
+            "top2_accuracy": round((top2 / total) if total else 0.0, 4),
+            "confusion": {label: dict(confusion[label]) for label in MOVES},
+        }
+
+    def to_dict(self) -> dict:
+        """Return the JSON-serialisable model: type tag, move list, weights and biases."""
+        return {
+            "model_type": "softmax_moves_v2",
+            "moves": MOVES,
+            "weights": self.weights,
+            "bias": self.bias,
+        }
+
+def read_rows(paths:list[Path]) -> list[dict]:
+    """Load training rows from JSONL files, keeping only rows with a known move and a board.
+
+    Blank lines, lines that are not valid JSON and JSON values that are not objects are
+    skipped instead of aborting the run: the bootstrap file is appended to while games
+    are being played, so a truncated last line has to be expected.
+    """
+    rows: list[dict] = []
+    for path in paths:
+        with path.open("r", encoding="utf-8") as handle:
+            for line in handle:
+                if not line.strip():
+                    continue
+                try:
+                    row = json.loads(line)
+                except json.JSONDecodeError:
+                    continue
+                if not isinstance(row, dict):
+                    continue
+                if row.get("move") not in MOVES:
+                    continue
+                if not row.get("game_board"):
+                    continue
+                rows.append(row)
+    return rows
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Train Battlesnake move model")
+    parser.add_argument("--input", action="append", required=True)
+    parser.add_argument("--output", required=True)
+    parser.add_argument("--eval-split", type=float, default=0.2)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--epochs", type=int, default=14)
+    parser.add_argument("--lr", type=float, default=0.08)
+    args = parser.parse_args()
+
+    paths = resolve_input_files(args.input)
+    if not paths:
+        raise SystemExit("No input files found")
+
+    rows = read_rows(paths)
+    if len(rows) < 50:
+        raise SystemExit("Need at least 50 rows for training")
+
+    random.seed(args.seed)
+    random.shuffle(rows)
+    eval_count = int(len(rows) * max(0.0, min(0.5, args.eval_split)))
+    eval_rows = rows[:eval_count]
+    train_rows = rows[eval_count:]
+
+    model = SoftmaxMoveModel()
+    model.fit(train_rows, epochs=max(1, args.epochs), lr=max(1e-4, args.lr))
+    metrics = model.evaluate(eval_rows)
+
+    output = Path(args.output)
+    output.parent.mkdir(parents=True, exist_ok=True)
+    payload = {
+        "input_files": [str(p) for p in paths],
+        "train_rows": len(train_rows),
+        "eval_rows": len(eval_rows),
+        "eval_metrics": metrics,
+        "model": model.to_dict(),
+    }
+    output.write_text(json.dumps(payload, indent=2), encoding="utf-8")
+
+    print(json.dumps({
+        "output": str(output),
+        "train_rows": len(train_rows),
+        "eval_rows": len(eval_rows),
+        "accuracy": metrics.get("accuracy"),
+        "top2_accuracy": metrics.get("top2_accuracy"),
+        },
+        indent=2,
+    ))
diff --git a/snakes/BestBattleSnake.py b/snakes/BestBattleSnake.py
index 0fb45a1..bc71add 100644
--- a/snakes/BestBattleSnake.py
+++ b/snakes/BestBattleSnake.py
@@ -1,10 +1,12 @@
 from collections.abc import Iterator
 from collections import deque
 from typing import Any, cast
-import random, os
 from time import perf_counter
+from pathlib import Path
+import random, json, os
 
 from snakes.TemplateSnake import TemplateSnake
+from server.GameBoard import GameBoard
 
 class BestBattleSnake(TemplateSnake):
     VERSION = "2.6.0"
@@ -38,6 +40,11 @@ class BestBattleSnake(TemplateSnake):
         self.previous_hazards = set()
         self.duel_style = self._get_duel_style()
         self.timeout_buffer_ms = self._get_timeout_buffer_ms()
+        self.rl_bootstrap_enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False)
+        self.rl_min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000)
+        self.rl_base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl"))
+        self.rl_bootstrap_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl"))
+        self.rl_needs_more_data = False
 
     def _get_duel_style(self) -> str:
         """Resolve duel tuning style from `BATTLE_SNAKE_DUEL_STYLE` or `DUEL_STYLE`."""
@@ -78,7 +85,62 @@ class BestBattleSnake(TemplateSnake):
         except ValueError:
             return 120
 
-    def choose_move(self, game_data:dict) -> str:
+    def _env_bool(self, name:str, default:bool=False) -> bool:
+        """Read a boolean flag from env var `name`; unset yields `default`, 1/true/yes/on mean True."""
+        value = os.getenv(name)
+        if value is None:
+            return default
+        return value.strip().lower() in {'1', 'true', 'yes', 'on'}
+
+    def _env_int(self, name:str, default:int) -> int:
+        """Read an integer from env var `name`; unset or non-numeric yields `default`."""
+        value = os.getenv(name)
+        if value is None:
+            return default
+        try:
+            return int(value)
+        except ValueError:
+            return default
+
+    def _count_jsonl_rows(self, path:Path) -> int:
+        """Count the non-blank lines in `path`; anything that is not a regular file counts as 0."""
+        if not path.is_file():  # is_file() is already False for a path that does not exist
+            return 0
+        with path.open("r", encoding="utf-8") as handle:
+            return sum(1 for line in handle if line.strip())
+
+    def _refresh_rl_bootstrap_state(self):
+        """Set `rl_needs_more_data`: True only when enabled and the base dataset is below `rl_min_base_rows`."""
+        if not self.rl_bootstrap_enabled:
+            self.rl_needs_more_data = False
+            return
+        self.rl_needs_more_data = self._count_jsonl_rows(self.rl_base_dataset_path) < self.rl_min_base_rows
+
+    def _record_rl_bootstrap_sample(self, game_data:GameBoard, move:str, safe_moves:MoveMap, reason:str, scores:dict[str, float]|None=None):
+        """Append the chosen move, its reason and the board as one JSONL row to `rl_bootstrap_path`."""
+        if not self.rl_bootstrap_enabled or not self.rl_needs_more_data:
+            return
+
+        try:
+            self.rl_bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
+            row = {
+                "source": "best_battlesnake_bootstrap",
+                "game_id": getattr(game_data, "id", None),
+                "turn": game_data.get_turn(),
+                "move": move,
+                "safe_moves": list(safe_moves.keys()),
+                "reason": reason,
+                "game_board": game_data.get_game_board_as_dict(),
+            }
+            if scores:
+                row["scores"] = {k: round(v, 5) for k, v in scores.items()}
+
+            with self.rl_bootstrap_path.open("a", encoding="utf-8") as handle:
+                handle.write(json.dumps(row, ensure_ascii=False) + "\n")
+        except Exception:  # deliberately broad: dataset logging must never break move selection
+            return
+
+    def choose_move(self, game_data:GameBoard) -> str:
         """Pick the next move from a Battlesnake move request.
 
Docs: https://docs.battlesnake.com/api/example-move @@ -97,6 +159,7 @@ class BestBattleSnake(TemplateSnake): self.last_move = None self.previous_hazards = set() self.last_game_id = game_id + self._refresh_rl_bootstrap_state() my_snake = cast(dict[str, Any], game_data.get_my_snake()) my_head = my_snake["head"] @@ -149,6 +212,7 @@ class BestBattleSnake(TemplateSnake): "reason": "no_safe_moves", } ) + self._record_rl_bootstrap_sample(game_data, fallback, safe_moves, "no_safe_moves") self.previous_hazards = set(hazard_set) return fallback @@ -179,6 +243,7 @@ class BestBattleSnake(TemplateSnake): self.recent_heads.append(current_head_point) self.last_move = best_move self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) + self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "constrictor", scores) self.previous_hazards = set(hazard_set) return best_move @@ -202,6 +267,7 @@ class BestBattleSnake(TemplateSnake): self.recent_heads.append(current_head_point) self.last_move = best_move self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) + self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "duel", scores) self.previous_hazards = set(hazard_set) return best_move @@ -333,6 +399,7 @@ class BestBattleSnake(TemplateSnake): self.recent_heads.append(current_head_point) self.last_move = quick_move self.add_to_history({"turn": turn, "move": quick_move, "reason": "timeout_budget"}) + self._record_rl_bootstrap_sample(game_data, quick_move, safe_moves, "timeout_budget") self.previous_hazards = set(hazard_set) return quick_move @@ -367,6 +434,7 @@ class BestBattleSnake(TemplateSnake): self.recent_heads.append(current_head_point) self.last_move = best_move self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) + self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "multi", scores) self.previous_hazards = set(hazard_set) return best_move @@ -807,7 +875,7 @@ class 
BestBattleSnake(TemplateSnake):
             return True
         return False
 
-    def _hazard_damage_per_turn(self, game_data:dict) -> int:
+    def _hazard_damage_per_turn(self, game_data:GameBoard) -> int:
         """Read royale hazard damage from ruleset settings.
 
         Docs: https://docs.battlesnake.com/maps/royale
diff --git a/snakes/TrainedBattleSnake.py b/snakes/TrainedBattleSnake.py
new file mode 100644
index 0000000..32679ed
--- /dev/null
+++ b/snakes/TrainedBattleSnake.py
@@ -0,0 +1,107 @@
+from pathlib import Path
+from typing import Any
+import random, json, os
+
+from server.TrainBattleSnakeAI import MOVES, extract_feature_values
+from snakes.TemplateSnake import TemplateSnake
+
+class TrainedBattleSnake(TemplateSnake):
+    """Snake that ranks the currently safe moves with a model written by `server.TrainBattleSnakeAI`."""
+
+    VERSION = "0.1.0"
+
+    def __init__(self):
+        super().__init__()
+        self.name = "TrainedBattleSnake"
+        self.version = self.VERSION
+        # Path of the last load attempt and the model it produced (None until one succeeds).
+        self._model_path:Path|None=None
+        self._model_data:dict[str, Any]|None=None
+
+    def choose_move(self, game_data) -> str:
+        """Return the safe move the model scores highest.
+
+        Falls back to "up" when no move is safe and to a random safe move when no model
+        can be loaded. Every decision is recorded through `add_to_history`.
+        """
+        self.game_board = game_data
+        self.calculations = []
+
+        safe_positions = self.find_safe_positions(add_to_calculations=True)
+        if not safe_positions:
+            self.add_to_history({"turn": game_data.get_turn(), "move": "up", "reason": "no_safe_moves"})
+            return "up"
+
+        model = self._load_model()
+        if not model:
+            move = random.choice(list(safe_positions.keys()))
+            self.add_to_history({
+                "turn": game_data.get_turn(),
+                "move": move,
+                "reason": "model_missing",
+                "safe_moves": list(safe_positions.keys()),
+            })
+            return move
+
+        row = {
+            "turn": game_data.get_turn(),
+            "game_board": game_data.get_game_board_as_dict(),
+        }
+        scores = self._predict_scores(model, row)
+
+        # Only safe moves compete; a move the model did not score ranks last.
+        best_safe_move = max(safe_positions.keys(), key=lambda move: scores.get(move, float("-inf")))
+        self.add_to_history({
+            "turn": game_data.get_turn(),
+            "move": best_safe_move,
+            "safe_moves": list(safe_positions.keys()),
+            "scores": {move: round(scores.get(move, 0.0), 5) for move in MOVES},
+        })
+        return best_safe_move
+
+    def _load_model(self) -> dict[str, Any] | None:
+        """Return the `model` object from the JSON file named by `TRAINED_SNAKE_MODEL`, or None.
+
+        A successfully loaded model is cached per path. A missing, unreadable or malformed
+        file yields None and is retried on the next call, so the snake keeps playing (with
+        random safe moves) instead of raising inside a move request.
+        """
+        path = Path(os.getenv("TRAINED_SNAKE_MODEL", "models/battlesnake_softmax_v2.json"))
+
+        if self._model_path == path and self._model_data is not None:
+            return self._model_data
+
+        self._model_path = path
+        self._model_data = None
+        if not path.is_file():
+            return None
+
+        try:
+            payload = json.loads(path.read_text(encoding="utf-8"))
+        except (OSError, ValueError):
+            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
+            return None
+
+        model = payload.get("model") if isinstance(payload, dict) else None
+        if isinstance(model, dict):
+            self._model_data = model
+        return self._model_data
+
+    def _predict_scores(self, model:dict[str, Any], row:dict[str, Any]) -> dict[str, float]:
+        """Score every move for `row`; currently always delegates to the softmax v2 scorer."""
+        return self._predict_scores_softmax_v2(model, row)
+
+    def _predict_scores_softmax_v2(self, model:dict[str, Any], row:dict[str, Any]) -> dict[str, float]:
+        """Compute `bias[move] + sum(weight * feature)` for every move in `MOVES`."""
+        features = extract_feature_values(row)
+        weights = model.get("weights", {})
+        bias = model.get("bias", {})
+        scores:dict[str, float] = {}
+
+        for move in MOVES:
+            move_weights = weights.get(move, {})
+            score = float(bias.get(move, 0.0))
+            for name, value in features.items():
+                score += float(move_weights.get(name, 0.0)) * float(value)
+            scores[move] = score
+        return scores
diff --git a/snakes/__init__.py b/snakes/__init__.py
index f340b72..b893c1d 100644
--- a/snakes/__init__.py
+++ b/snakes/__init__.py
@@ -7,6 +7,7 @@ SNAKE_REGISTRY = {
     "MasterSnake": "1.2.0",
     "BetterMasterSnake": "1.3.0",
     "BestBattleSnake": "2.6.0",
+    "TrainedBattleSnake": "0.1.0",
 }
 
 def build_snake(selected_snake: str):