add Training for AI and AI Model and allow to collect rl data from BestBattleSnake
Build and Push Docker Container / build-and-push (push) Successful in 1m36s

This commit is contained in:
2026-04-03 23:19:09 +02:00
parent d3b0488e0f
commit 9e826afa5f
5 changed files with 460 additions and 4 deletions
+7 -1
View File
@@ -159,7 +159,7 @@ test-local-4 mode="standard" map="standard" base_port="9101" snake="BestBattleSn
-g "{{mode}}" --map "{{map}}" --seed "{{seed}}" $BROWSER_FLAG -g "{{mode}}" --map "{{map}}" --seed "{{seed}}" $BROWSER_FLAG
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Fataset helpers # Dataset helpers
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
export-dataset input="data" output="data/dataset/good_moves.jsonl": export-dataset input="data" output="data/dataset/good_moves.jsonl":
@@ -170,3 +170,9 @@ curate-dataset input="good_moves-*.jsonl" output="data/dataset/best_moves.jsonl"
analyze-dataset input="good_moves-*.jsonl" output="": analyze-dataset input="good_moves-*.jsonl" output="":
if [ -n "{{output}}" ]; then python -m server.DatasetStats --input "{{input}}" --output "{{output}}"; else python -m server.DatasetStats --input "{{input}}"; fi if [ -n "{{output}}" ]; then python -m server.DatasetStats --input "{{input}}" --output "{{output}}"; else python -m server.DatasetStats --input "{{input}}"; fi
train-ai input="data/dataset/best_moves.jsonl" rl_input="data/dataset/rl_bootstrap.jsonl" output="models/battlesnake_softmax_v2.json" eval_split="0.2" seed="42" epochs="14" lr="0.08":
if [ -f "{{rl_input}}" ]; then python -m server.TrainBattleSnakeAI --input "{{input}}" --input "{{rl_input}}" --output "{{output}}" --eval-split "{{eval_split}}" --seed "{{seed}}" --epochs "{{epochs}}" --lr "{{lr}}"; else python -m server.TrainBattleSnakeAI --input "{{input}}" --output "{{output}}" --eval-split "{{eval_split}}" --seed "{{seed}}" --epochs "{{epochs}}" --lr "{{lr}}"; fi
run-trained model="models/battlesnake_softmax_v2.json" port="8000":
TRAINED_SNAKE_MODEL="{{model}}" SNAKE="TrainedBattleSnake" PORT="{{port}}" "{{justfile_directory()}}/main.py"
+290
View File
@@ -0,0 +1,290 @@
import argparse, random, glob, json, math
from collections import Counter
from pathlib import Path
MOVES = ["up", "down", "left", "right"]
def resolve_input_files(inputs:list[str]) -> list[Path]:
resolved:list[Path] = []
seen:set[str] = set()
for item in inputs:
path = Path(item)
if path.is_dir():
for file_path in sorted(path.rglob("*.jsonl")):
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if any(ch in item for ch in "*?[]"):
for match in sorted(glob.glob(item)):
file_path = Path(match)
if not file_path.is_file():
continue
key = str(file_path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(file_path)
continue
if path.is_file():
key = str(path.resolve())
if key in seen:
continue
seen.add(key)
resolved.append(path)
return resolved
def _neighbors(x:int, y:int) -> list[tuple[int, int, str]]:
return [
(x, y + 1, "up"),
(x, y - 1, "down"),
(x - 1, y, "left"),
(x + 1, y, "right"),
]
def _safe_neighbor_count(point:tuple[int, int], blocked:set[tuple[int, int]], width:int, height:int) -> int:
count = 0
for nx, ny, _ in _neighbors(point[0], point[1]):
if not (0 <= nx < width and 0 <= ny < height):
continue
if (nx, ny) in blocked:
continue
count += 1
return count
def _manhattan_to_nearest_food(point: tuple[int, int], food: set[tuple[int, int]]) -> int:
if not food:
return 25
return min(abs(point[0] - fx) + abs(point[1] - fy) for fx, fy in food)
def extract_feature_values(row:dict) -> dict[str, float]:
board = row.get("game_board", {})
snakes = board.get("snakes", [])
if not snakes:
return {}
me = snakes[0]
body = me.get("body", [])
if not body:
return {}
width = int(board.get("width", 0))
height = int(board.get("height", 0))
head = body[0]
hx = int(head.get("x", 0))
hy = int(head.get("y", 0))
health = int(me.get("health", 100))
length = int(me.get("length", len(body)))
food_set = {(int(f.get("x", 0)), int(f.get("y", 0))) for f in board.get("food", [])}
hazard_set = {(int(h.get("x", 0)), int(h.get("y", 0))) for h in board.get("hazards", [])}
blocked = set()
for snake in snakes:
for seg in snake.get("body", []):
blocked.add((int(seg.get("x", 0)), int(seg.get("y", 0))))
features:dict[str, float] = {
"bias": 1.0,
"health_norm": max(0.0, min(1.0, health / 100.0)),
"length_norm": min(1.0, length / max(1.0, width * height)),
"turn_norm": min(1.0, int(row.get("turn", 0)) / 100.0),
"food_count_norm": min(1.0, len(food_set) / 10.0),
"hazard_count_norm": min(1.0, len(hazard_set) / 20.0),
"opponent_count_norm": min(1.0, max(0, len(snakes) - 1) / 7.0),
}
safe_total = 0
for nx, ny, move in _neighbors(hx, hy):
in_bounds = 1.0 if (0 <= nx < width and 0 <= ny < height) else 0.0
blocked_next = 1.0 if (nx, ny) in blocked else 0.0
food_next = 1.0 if (nx, ny) in food_set else 0.0
hazard_next = 1.0 if (nx, ny) in hazard_set else 0.0
if in_bounds and not blocked_next:
safe_total += 1
open_next = float(_safe_neighbor_count((nx, ny), blocked, width, height))
dist_food = float(_manhattan_to_nearest_food((nx, ny), food_set))
else:
open_next = 0.0
dist_food = 25.0
prefix = f"m:{move}:"
features[prefix + "in_bounds"] = in_bounds
features[prefix + "blocked"] = blocked_next
features[prefix + "food"] = food_next
features[prefix + "hazard"] = hazard_next
features[prefix + "open_next"] = min(4.0, open_next) / 4.0
features[prefix + "food_dist"] = min(25.0, dist_food) / 25.0
features["safe_total_norm"] = safe_total / 4.0
return features
class SoftmaxMoveModel:
def __init__(self):
self.weights = {move: {} for move in MOVES}
self.bias = {move: 0.0 for move in MOVES}
def _score(self, move:str, features:dict[str, float]) -> float:
weight_map = self.weights[move]
value = self.bias[move]
for name, feat in features.items():
value += weight_map.get(name, 0.0) * feat
return value
def fit(self, rows:list[dict], epochs:int=14, lr:float=0.08, l2:float=1e-6) -> None:
examples = []
for row in rows:
label = row.get("move")
if label not in MOVES:
continue
features = extract_feature_values(row)
if not features:
continue
examples.append((features, label))
if not examples:
return
for _ in range(epochs):
random.shuffle(examples)
for features, label in examples:
scores = {move: self._score(move, features) for move in MOVES}
max_score = max(scores.values())
exp_scores = {
move: math.exp(scores[move] - max_score) for move in MOVES
}
z = sum(exp_scores.values())
probs = {move: exp_scores[move] / z for move in MOVES}
for move in MOVES:
target = 1.0 if move == label else 0.0
gradient = target - probs[move]
self.bias[move] += lr * gradient
w = self.weights[move]
for name, feat in features.items():
current = w.get(name, 0.0)
update = lr * ((gradient * feat) - (l2 * current))
w[name] = current + update
def predict_scores(self, row:dict) -> dict[str, float]:
features = extract_feature_values(row)
if not features:
return {move: 0.0 for move in MOVES}
return {move: self._score(move, features) for move in MOVES}
def predict(self, row:dict) -> str:
scores = self.predict_scores(row)
return max(scores, key=lambda move: scores[move])
def evaluate(self, rows:list[dict]) -> dict:
total = 0
correct = 0
top2 = 0
confusion = {move: Counter() for move in MOVES}
for row in rows:
expected = row.get("move")
if expected not in MOVES:
continue
scores = self.predict_scores(row)
ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
predicted = ranked[0][0]
total += 1
if predicted == expected:
correct += 1
if expected in {
ranked[0][0],
ranked[1][0] if len(ranked) > 1 else ranked[0][0],
}:
top2 += 1
confusion[expected][predicted] += 1
return {
"total": total,
"correct": correct,
"accuracy": round((correct / total) if total else 0.0, 4),
"top2_accuracy": round((top2 / total) if total else 0.0, 4),
"confusion": {label: dict(confusion[label]) for label in MOVES},
}
def to_dict(self) -> dict:
return {
"model_type": "softmax_moves_v2",
"moves": MOVES,
"weights": self.weights,
"bias": self.bias,
}
def read_rows(paths:list[Path]) -> list[dict]:
rows: list[dict] = []
for path in paths:
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if not line.strip():
continue
row = json.loads(line)
if row.get("move") not in MOVES:
continue
if not row.get("game_board"):
continue
rows.append(row)
return rows
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train Battlesnake move model")
parser.add_argument("--input", action="append", required=True)
parser.add_argument("--output", required=True)
parser.add_argument("--eval-split", type=float, default=0.2)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--epochs", type=int, default=14)
parser.add_argument("--lr", type=float, default=0.08)
args = parser.parse_args()
paths = resolve_input_files(args.input)
if not paths:
raise SystemExit("No input files found")
rows = read_rows(paths)
if len(rows) < 50:
raise SystemExit("Need at least 50 rows for training")
random.seed(args.seed)
random.shuffle(rows)
eval_count = int(len(rows) * max(0.0, min(0.5, args.eval_split)))
eval_rows = rows[:eval_count]
train_rows = rows[eval_count:]
model = SoftmaxMoveModel()
model.fit(train_rows, epochs=max(1, args.epochs), lr=max(1e-4, args.lr))
metrics = model.evaluate(eval_rows)
output = Path(args.output)
output.parent.mkdir(parents=True, exist_ok=True)
payload = {
"input_files": [str(p) for p in paths],
"train_rows": len(train_rows),
"eval_rows": len(eval_rows),
"eval_metrics": metrics,
"model": model.to_dict(),
}
output.write_text(json.dumps(payload, indent=2), encoding="utf-8")
print(json.dumps({
"output": str(output),
"train_rows": len(train_rows),
"eval_rows": len(eval_rows),
"accuracy": metrics.get("accuracy"),
"top2_accuracy": metrics.get("top2_accuracy"),
},
indent=2,
))
+71 -3
View File
@@ -1,10 +1,12 @@
from collections.abc import Iterator from collections.abc import Iterator
from collections import deque from collections import deque
from typing import Any, cast from typing import Any, cast
import random, os
from time import perf_counter from time import perf_counter
from pathlib import Path
import random, json, os
from snakes.TemplateSnake import TemplateSnake from snakes.TemplateSnake import TemplateSnake
from server.GameBoard import GameBoard
class BestBattleSnake(TemplateSnake): class BestBattleSnake(TemplateSnake):
VERSION = "2.6.0" VERSION = "2.6.0"
@@ -38,6 +40,11 @@ class BestBattleSnake(TemplateSnake):
self.previous_hazards = set() self.previous_hazards = set()
self.duel_style = self._get_duel_style() self.duel_style = self._get_duel_style()
self.timeout_buffer_ms = self._get_timeout_buffer_ms() self.timeout_buffer_ms = self._get_timeout_buffer_ms()
self.rl_bootstrap_enabled = self._env_bool("RL_BOOTSTRAP_ENABLED", default=False)
self.rl_min_base_rows = self._env_int("RL_MIN_BASE_ROWS", default=5000)
self.rl_base_dataset_path = Path(os.getenv("RL_BASE_DATASET", "data/dataset/best_moves.jsonl"))
self.rl_bootstrap_path = Path(os.getenv("RL_BOOTSTRAP_OUTPUT", "data/dataset/rl_bootstrap.jsonl"))
self.rl_needs_more_data = False
def _get_duel_style(self) -> str: def _get_duel_style(self) -> str:
"""Resolve duel tuning style from `BATTLE_SNAKE_DUEL_STYLE` or `DUEL_STYLE`.""" """Resolve duel tuning style from `BATTLE_SNAKE_DUEL_STYLE` or `DUEL_STYLE`."""
@@ -78,7 +85,62 @@ class BestBattleSnake(TemplateSnake):
except ValueError: except ValueError:
return 120 return 120
def choose_move(self, game_data:dict) -> str: def _env_bool(self, name:str, default:bool=False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.lower() in {'1', 'true', 'yes', 'on'}
def _env_int(self, name:str, default:int) -> int:
value = os.getenv(name)
if value is None:
return default
try:
return int(value)
except ValueError:
return default
def _count_jsonl_rows(self, path:Path) -> int:
if not path.exists() or not path.is_file():
return 0
count = 0
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if line.strip():
count += 1
return count
def _refresh_rl_bootstrap_state(self):
if not self.rl_bootstrap_enabled:
self.rl_needs_more_data = False
return
base_rows = self._count_jsonl_rows(self.rl_base_dataset_path)
self.rl_needs_more_data = base_rows < self.rl_min_base_rows
def _record_rl_bootstrap_sample(self, game_data:GameBoard, move:str, safe_moves:MoveMap, reason:str, scores:dict[str, float]|None=None):
if not self.rl_bootstrap_enabled or not self.rl_needs_more_data:
return
try:
self.rl_bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
row = {
"source": "best_battlesnake_bootstrap",
"game_id": getattr(game_data, "id", None),
"turn": game_data.get_turn(),
"move": move,
"safe_moves": list(safe_moves.keys()),
"reason": reason,
"game_board": game_data.get_game_board_as_dict(),
}
if scores:
row["scores"] = {k: round(v, 5) for k, v in scores.items()}
with self.rl_bootstrap_path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
except Exception:
return
def choose_move(self, game_data:GameBoard) -> str:
"""Pick the next move from a Battlesnake move request. """Pick the next move from a Battlesnake move request.
Docs: https://docs.battlesnake.com/api/example-move Docs: https://docs.battlesnake.com/api/example-move
@@ -97,6 +159,7 @@ class BestBattleSnake(TemplateSnake):
self.last_move = None self.last_move = None
self.previous_hazards = set() self.previous_hazards = set()
self.last_game_id = game_id self.last_game_id = game_id
self._refresh_rl_bootstrap_state()
my_snake = cast(dict[str, Any], game_data.get_my_snake()) my_snake = cast(dict[str, Any], game_data.get_my_snake())
my_head = my_snake["head"] my_head = my_snake["head"]
@@ -149,6 +212,7 @@ class BestBattleSnake(TemplateSnake):
"reason": "no_safe_moves", "reason": "no_safe_moves",
} }
) )
self._record_rl_bootstrap_sample(game_data, fallback, safe_moves, "no_safe_moves")
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return fallback return fallback
@@ -179,6 +243,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point) self.recent_heads.append(current_head_point)
self.last_move = best_move self.last_move = best_move
self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "constrictor", scores)
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return best_move return best_move
@@ -202,6 +267,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point) self.recent_heads.append(current_head_point)
self.last_move = best_move self.last_move = best_move
self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "duel", scores)
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return best_move return best_move
@@ -333,6 +399,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point) self.recent_heads.append(current_head_point)
self.last_move = quick_move self.last_move = quick_move
self.add_to_history({"turn": turn, "move": quick_move, "reason": "timeout_budget"}) self.add_to_history({"turn": turn, "move": quick_move, "reason": "timeout_budget"})
self._record_rl_bootstrap_sample(game_data, quick_move, safe_moves, "timeout_budget")
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return quick_move return quick_move
@@ -367,6 +434,7 @@ class BestBattleSnake(TemplateSnake):
self.recent_heads.append(current_head_point) self.recent_heads.append(current_head_point)
self.last_move = best_move self.last_move = best_move
self.add_to_history({"turn": turn, "move": best_move, "scores": scores}) self.add_to_history({"turn": turn, "move": best_move, "scores": scores})
self._record_rl_bootstrap_sample(game_data, best_move, safe_moves, "multi", scores)
self.previous_hazards = set(hazard_set) self.previous_hazards = set(hazard_set)
return best_move return best_move
@@ -807,7 +875,7 @@ class BestBattleSnake(TemplateSnake):
return True return True
return False return False
def _hazard_damage_per_turn(self, game_data:dict) -> int: def _hazard_damage_per_turn(self, game_data:GameBoard) -> int:
"""Read royale hazard damage from ruleset settings. """Read royale hazard damage from ruleset settings.
Docs: https://docs.battlesnake.com/maps/royale Docs: https://docs.battlesnake.com/maps/royale
+91
View File
@@ -0,0 +1,91 @@
from pathlib import Path
from typing import Any
import random, json, os
from server.TrainBattleSnakeAI import MOVES, extract_feature_values
from snakes.TemplateSnake import TemplateSnake
class TrainedBattleSnake(TemplateSnake):
VERSION = "0.1.0"
def __init__(self):
super().__init__()
self.name = "TrainedBattleSnake"
self.version = self.VERSION
self._model_path:Path|None=None
self._model_data:dict[str, Any]|None=None
def choose_move(self, game_data) -> str:
self.game_board = game_data
self.calculations = []
safe_positions = self.find_safe_positions(add_to_calculations=True)
if not safe_positions:
self.add_to_history({"turn": game_data.get_turn(), "reason": "no_safe_moves"})
return "up"
model = self._load_model()
if not model:
move = random.choice(list(safe_positions.keys()))
self.add_to_history({
"turn": game_data.get_turn(),
"move": move,
"reason": "model_missing",
"safe_moves": list(safe_positions.keys()),
})
return move
row = {
"turn": game_data.get_turn(),
"game_board": game_data.get_game_board_as_dict(),
}
scores = self._predict_scores(model, row)
best_safe_move = max(safe_positions.keys(), key=lambda move: scores.get(move, float("-inf")))
self.add_to_history({
"turn": game_data.get_turn(),
"move": best_safe_move,
"safe_moves": list(safe_positions.keys()),
"scores": {move: round(scores.get(move, 0.0), 5) for move in MOVES},
})
return best_safe_move
def _load_model(self) -> dict[str, Any] | None:
env_path = os.getenv("TRAINED_SNAKE_MODEL", "models/battlesnake_softmax_v2.json")
path = Path(env_path)
if self._model_path == path and self._model_data is not None:
return self._model_data
if not path.exists() or not path.is_file():
self._model_path = path
self._model_data = None
return None
payload = json.loads(path.read_text(encoding="utf-8"))
model = payload.get("model")
if not isinstance(model, dict):
self._model_path = path
self._model_data = None
return None
self._model_path = path
self._model_data = model
return model
def _predict_scores(self, model:dict[str, Any], row:dict[str, Any]) -> dict[str, float]:
return self._predict_scores_softmax_v2(model, row)
def _predict_scores_softmax_v2(self, model:dict[str, Any], row:dict[str, Any]) -> dict[str, float]:
features = extract_feature_values(row)
weights = model.get("weights", {})
bias = model.get("bias", {})
scores:dict[str, float] = {}
for move in MOVES:
move_weights = weights.get(move, {})
score = float(bias.get(move, 0.0))
for name, value in features.items():
score += float(move_weights.get(name, 0.0)) * float(value)
scores[move] = score
return scores
+1
View File
@@ -7,6 +7,7 @@ SNAKE_REGISTRY = {
"MasterSnake": "1.2.0", "MasterSnake": "1.2.0",
"BetterMasterSnake": "1.3.0", "BetterMasterSnake": "1.3.0",
"BestBattleSnake": "2.6.0", "BestBattleSnake": "2.6.0",
"TrainedBattleSnake": "0.1.0",
} }
def build_snake(selected_snake: str): def build_snake(selected_snake: str):