Files

374 lines
13 KiB
Python

from __future__ import annotations
import argparse
import json
import subprocess
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@dataclass
class TestResult:
    """Outcome of one smoke-test case, as printed in the final summary."""

    # Short identifier printed next to PASS/FAIL.
    name: str
    # True when the case met its acceptance criteria.
    passed: bool
    # Free-form evidence: model output, parsed arguments, or a failure reason.
    details: str
def main() -> int:
    """Start KoboldCpp for one model, run every smoke test and print a summary.

    Returns:
        0 when every test passed; 1 when at least one test failed or KoboldCpp
        never became ready; 2 when the model file or the KoboldCpp executable
        does not exist.
    """
    args = parse_args()
    repo_root = Path.cwd()
    model_path = Path(args.model_path).expanduser().resolve()
    koboldcpp_path = Path(args.koboldcpp_path).expanduser().resolve()
    if not model_path.exists():
        print(f"FAIL: model does not exist: {model_path}")
        return 2
    if not koboldcpp_path.exists():
        print(f"FAIL: koboldcpp executable does not exist: {koboldcpp_path}")
        return 2

    # Logs are written relative to the current working directory.
    logs_dir = repo_root / "data" / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)
    safe_name = model_path.stem.replace(" ", "_")
    stdout_path = logs_dir / f"model-smoke-{safe_name}-ctx{args.context_size}.out.log"
    stderr_path = logs_dir / f"model-smoke-{safe_name}-ctx{args.context_size}.err.log"

    # NOTE(review): args.request_timeout_seconds is parsed but not forwarded
    # anywhere; the per-request timeout is decided inside chat().
    process: subprocess.Popen[Any] | None = None
    try:
        if args.stop_existing:
            stop_existing_koboldcpp()
        process = start_koboldcpp(
            koboldcpp_path=koboldcpp_path,
            model_path=model_path,
            port=args.port,
            context_size=args.context_size,
            extra_args=args.kobold_arg,
            stdout_path=stdout_path,
            stderr_path=stderr_path,
        )
        base_url = f"http://127.0.0.1:{args.port}"
        try:
            model_id = wait_for_model(base_url, process, args.startup_timeout_seconds)
        except (RuntimeError, TimeoutError) as exc:
            # Startup failures are an expected outcome of a smoke test: report
            # them like any other failure instead of dumping a traceback.
            print(f"FAIL: {exc}")
            print(f"KoboldCpp logs: {stdout_path} | {stderr_path}")
            return 1
        print(f"Model ready: {model_id}")

        tests = (
            run_dialogue_test,
            run_auto_tool_test,
            run_forced_tool_test,
            run_agent_action_prompt_test,
        )
        results = [_run_guarded(test, base_url, model_id) for test in tests]
        passed = all(result.passed for result in results)

        print()
        print("KoboldCpp model smoke test summary")
        print(f"Model: {model_path}")
        print(f"Context size: {args.context_size}")
        print(f"KoboldCpp logs: {stdout_path} | {stderr_path}")
        for result in results:
            status = "PASS" if result.passed else "FAIL"
            print(f"{status}: {result.name} - {result.details}")
        print()
        print("PASS" if passed else "FAIL")
        return 0 if passed else 1
    finally:
        # Only stop the process we started, and only if it is still running
        # and the caller did not ask to keep it alive.
        if process is not None and process.poll() is None and not args.keep_running:
            process.terminate()
            try:
                process.wait(timeout=20)
            except subprocess.TimeoutExpired:
                process.kill()
                # Reap the killed process so it does not linger as a zombie.
                process.wait()


def _run_guarded(test: Any, base_url: str, model_id: str) -> TestResult:
    """Run one test callable, turning any raised error into a failed result.

    Without this guard a single request error would abort the whole run before
    the remaining tests execute and before the summary is printed.
    """
    try:
        return test(base_url, model_id)
    except Exception as exc:  # noqa: BLE001 - any error is reported as a test failure.
        return TestResult(test.__name__, False, f"raised {type(exc).__name__}: {truncate(str(exc))}")
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        The namespace produced by argparse, holding the positional
        ``model_path`` and every optional flag with its default applied.
    """
    summary = "Start KoboldCpp with a GGUF and run LocalDiplomacy tool-calling smoke tests."
    cli = argparse.ArgumentParser(description=summary)
    add = cli.add_argument

    # Positional: the model under test.
    add("model_path", help="Path to the GGUF model to test.")

    # Server configuration.
    add("--context-size", type=int, default=40960, help="KoboldCpp context size.")
    add("--koboldcpp-path", default="./koboldcpp.exe", help="Path to koboldcpp.exe.")
    add("--port", type=int, default=5001, help="KoboldCpp HTTP port.")

    # Timeouts, in seconds.
    add("--startup-timeout-seconds", type=float, default=240.0)
    add("--request-timeout-seconds", type=float, default=180.0)

    # Process lifecycle switches.
    add("--keep-running", action="store_true", help="Leave the started KoboldCpp process running.")
    add("--stop-existing", action="store_true", help="Stop existing koboldcpp.exe processes first.")

    # Repeatable pass-through arguments.
    add(
        "--kobold-arg",
        action="append",
        default=[],
        help="Extra argument passed to KoboldCpp. Repeat for multiple args.",
    )
    return cli.parse_args()
def start_koboldcpp(
    *,
    koboldcpp_path: Path,
    model_path: Path,
    port: int,
    context_size: int,
    extra_args: list[str],
    stdout_path: Path,
    stderr_path: Path,
) -> subprocess.Popen[Any]:
    """Launch KoboldCpp as a child process with its output redirected to log files.

    Args:
        koboldcpp_path: Executable to run.
        model_path: Value passed after ``--model``.
        port: Value passed after ``--port``.
        context_size: Value passed after ``--contextsize``.
        extra_args: Additional arguments appended verbatim to the command line.
        stdout_path: File that receives the child's stdout (truncated on open).
        stderr_path: File that receives the child's stderr (truncated on open).

    Returns:
        The running ``subprocess.Popen`` handle.
    """
    command = [
        str(koboldcpp_path),
        "--model",
        str(model_path),
        "--port",
        str(port),
        "--contextsize",
        str(context_size),
        "--jinja",
        "--jinjatools",
        *extra_args,
    ]
    print("Starting KoboldCpp:")
    print(" ".join(quote_arg(part) for part in command))

    # The child receives its own copies of the log handles when it is spawned,
    # so the parent's file objects are closed as soon as Popen returns (or
    # raises) instead of leaking for the lifetime of the script.
    with stdout_path.open("w", encoding="utf-8", errors="replace") as stdout_file:
        with stderr_path.open("w", encoding="utf-8", errors="replace") as stderr_file:
            return subprocess.Popen(
                command,
                stdout=stdout_file,
                stderr=stderr_file,
                # Suppress the console window on Windows; no flags elsewhere.
                creationflags=subprocess.CREATE_NO_WINDOW if sys.platform == "win32" else 0,
            )
def wait_for_model(base_url: str, process: subprocess.Popen[Any], timeout_seconds: float) -> str:
    """Poll ``/v1/models`` until the server lists a model or the deadline passes.

    Args:
        base_url: Root URL of the server, without a trailing path.
        process: The server process; its exit is checked before every poll.
        timeout_seconds: Maximum time to keep polling.

    Returns:
        The ``id`` of the first listed model, or ``"local-model"`` when that
        entry has no usable id.

    Raises:
        RuntimeError: The process exited before a model was listed.
        TimeoutError: No model was listed before the deadline.
    """
    models_url = f"{base_url}/v1/models"
    give_up_at = time.monotonic() + timeout_seconds
    latest_failure = ""

    while True:
        if time.monotonic() >= give_up_at:
            break
        if process.poll() is not None:
            raise RuntimeError(f"KoboldCpp exited early with code {process.returncode}.")
        try:
            listed = get_json(models_url, timeout=10).get("data") or []
            if listed:
                return str(listed[0].get("id") or "local-model")
        except Exception as exc:  # noqa: BLE001 - report the last startup error.
            latest_failure = str(exc)
        # Pause between polls, whether the request failed or the list was empty.
        time.sleep(2)

    raise TimeoutError(f"KoboldCpp did not become ready within {timeout_seconds}s. Last error: {latest_failure}")
def run_dialogue_test(base_url: str, model_id: str) -> TestResult:
    """Check that a plain dialogue reply is non-empty and free of leaked markup.

    The case passes when the reply has visible text and contains none of the
    reasoning or tool-call markers listed below.
    """
    system_prompt = "/no_think You are Derthert, a Bannerlord lord. Answer in one sentence. Never expose reasoning."
    user_prompt = "/no_think Greetings. What news from the border?"
    payload = {
        "model": model_id,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.3,
        "max_tokens": 120,
    }
    reply = first_message(chat(base_url, payload)).get("content") or ""
    visible = reply.strip()

    # Substrings that indicate reasoning or raw tool-call text leaked into the reply.
    forbidden = ("<think>", "</think>", '"arguments"', "<tool_call")
    if visible:
        passed = all(marker not in reply for marker in forbidden)
    else:
        passed = False
    return TestResult("dialogue_no_visible_reasoning", passed, truncate(visible))
def run_auto_tool_test(base_url: str, model_id: str) -> TestResult:
    """Check that the model emits a tool call when ``tool_choice`` is ``"auto"``.

    The response is judged by ``validate_tool_call``.
    """
    conversation = [
        {
            "role": "system",
            "content": "/no_think Use tools through the API when a game action is needed. Never expose reasoning.",
        },
        {"role": "user", "content": "/no_think Derthert should propose peace with Battania."},
    ]
    offered_tools = [action_tool_schema(["propose_peace", "declare_war"])]
    payload = {
        "model": model_id,
        "messages": conversation,
        "tools": offered_tools,
        "tool_choice": "auto",
        "temperature": 0.0,
        "max_tokens": 300,
    }
    response = chat(base_url, payload)
    return validate_tool_call("auto_tool_call", response)
def run_forced_tool_test(base_url: str, model_id: str) -> TestResult:
    """Check the tool call produced when ``tool_choice`` names the function explicitly.

    The response is judged by ``validate_tool_call``.
    """
    system_prompt = "/no_think Use tools through the API when a game action is needed. Never expose reasoning."
    user_prompt = (
        "/no_think Use propose_game_action now: action_type propose_peace, "
        "actor_id lord_derthert, target_id kingdom_battania, reason end the border raids."
    )
    forced_choice = {"type": "function", "function": {"name": "propose_game_action"}}
    payload = {
        "model": model_id,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "tools": [action_tool_schema(["accept_peace", "propose_peace", "reject_peace"])],
        "tool_choice": forced_choice,
        "temperature": 0.0,
        "max_tokens": 300,
    }
    response = chat(base_url, payload)
    return validate_tool_call("forced_tool_call", response)
def run_agent_action_prompt_test(base_url: str, model_id: str) -> TestResult:
    """Check a forced tool call when the user message embeds a JSON game context.

    A mock campaign context is serialized into the user message ahead of the
    command; the response is judged by ``validate_tool_call``.
    """
    context = {
        "campaign_id": "mock-campaign",
        "player_id": "player",
        "npc_id": "lord_derthert",
        "npc_name": "Derthert",
        "npc_kingdom_id": "kingdom_vlandia",
        "player_kingdom_id": "kingdom_vlandia",
        "kingdom_state": {"wars": [{"enemy_kingdom_id": "kingdom_battania", "days": 42}], "war_fatigue": 0.67},
        "nearby_parties": [{"id": "party_battania_raiders", "faction_id": "kingdom_battania"}],
        "nearby_settlements": [{"id": "town_sargot", "owner_kingdom_id": "kingdom_vlandia"}],
        "recent_events": [{"id": "event_border_raids", "summary": "Border raids have strained Vlandia and Battania."}],
    }
    context_json = json.dumps(context, ensure_ascii=False)

    system_prompt = (
        "/no_think You are LocalDiplomacy's Bannerlord action planner. "
        "Call the provided tool through the API. Do not answer in prose. Use only IDs present in the request."
    )
    user_prompt = (
        "/no_think "
        f"Context: {context_json}\n"
        "Command: Use propose_game_action now: action_type propose_peace, actor_id lord_derthert, "
        "target_id kingdom_battania, reason end the border raids."
    )
    payload = {
        "model": model_id,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "tools": [action_tool_schema(["accept_peace", "propose_peace", "reject_peace"])],
        "tool_choice": {"type": "function", "function": {"name": "propose_game_action"}},
        "temperature": 0.0,
        "max_tokens": 400,
    }
    response = chat(base_url, payload)
    return validate_tool_call("agent_action_prompt_tool_call", response)
def action_tool_schema(action_types: list[str]) -> dict[str, Any]:
    """Build the ``propose_game_action`` function-tool definition.

    Args:
        action_types: Allowed values for the ``action_type`` enum; the list is
            embedded in the schema as-is.

    Returns:
        A tool entry of type ``"function"`` with a JSON-schema ``parameters`` object.
    """
    properties: dict[str, Any] = {
        "action_type": {"type": "string", "enum": action_types},
        "actor_id": {"type": "string"},
        "target_id": {"type": "string"},
        "args": {"type": "object"},
        "confidence": {"type": "number", "minimum": 0, "maximum": 1},
        "reason": {"type": "string"},
        "requires_player_confirmation": {"type": "boolean"},
    }
    parameters: dict[str, Any] = {
        "type": "object",
        "properties": properties,
        "required": ["action_type", "actor_id", "target_id", "reason"],
    }
    function_spec: dict[str, Any] = {
        "name": "propose_game_action",
        "description": "Queue one Bannerlord game action proposal for mod-side validation.",
        "parameters": parameters,
    }
    return {"type": "function", "function": function_spec}
def validate_tool_call(name: str, data: dict[str, Any]) -> TestResult:
    """Judge whether a chat response contains the expected ``propose_peace`` tool call.

    Only the first entry of ``tool_calls`` is inspected. The case passes when
    that call targets ``propose_game_action`` with ``action_type`` equal to
    ``"propose_peace"`` and non-empty ``actor_id`` and ``target_id``.

    Malformed responses (missing or oddly shaped ``tool_calls``, arguments that
    are not valid JSON or not a JSON object) yield a failed result instead of
    raising, so one bad reply cannot abort the caller.

    Args:
        name: Test name recorded in the result.
        data: Decoded chat-completions response.

    Returns:
        A ``TestResult`` whose details hold the parsed arguments or the reason
        for failure.
    """
    message = first_message(data)
    tool_calls = message.get("tool_calls") or []
    if not tool_calls:
        return TestResult(name, False, f"no tool_calls; content={truncate(str(message.get('content') or ''))}")
    if not isinstance(tool_calls, list):
        return TestResult(name, False, f"tool_calls is not a list: {truncate(str(tool_calls))}")

    first_call = tool_calls[0]
    function = first_call.get("function") if isinstance(first_call, dict) else None
    if not isinstance(function, dict):
        function = {}

    arguments_raw = function.get("arguments") or "{}"
    if isinstance(arguments_raw, dict):
        # Some servers return the arguments already decoded instead of as a
        # JSON string; accept that form rather than crashing in json.loads.
        arguments = arguments_raw
    else:
        try:
            arguments = json.loads(arguments_raw)
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; TypeError covers
            # non-string payloads such as lists or numbers.
            return TestResult(name, False, f"invalid arguments JSON: {truncate(str(arguments_raw))}")
    if not isinstance(arguments, dict):
        return TestResult(name, False, f"arguments is not a JSON object: {truncate(str(arguments_raw))}")

    passed = (
        function.get("name") == "propose_game_action"
        and arguments.get("action_type") == "propose_peace"
        and bool(arguments.get("actor_id"))
        and bool(arguments.get("target_id"))
    )
    return TestResult(name, passed, json.dumps(arguments, ensure_ascii=False))
def chat(base_url: str, payload: dict[str, Any], timeout: float = 180.0) -> dict[str, Any]:
    """POST a chat-completions request and return the decoded JSON response.

    Args:
        base_url: Root URL of the server, without a trailing path.
        payload: Request body sent to ``/v1/chat/completions``.
        timeout: Request timeout in seconds. Defaults to the previously
            hard-coded 180 seconds, so existing callers are unaffected.

    Returns:
        The decoded response object from ``post_json``.
    """
    return post_json(f"{base_url}/v1/chat/completions", payload, timeout=timeout)
def first_message(data: dict[str, Any]) -> dict[str, Any]:
    """Return the ``message`` of the first choice, or an empty dict if absent.

    Missing, empty or null ``choices``, a null first choice and a missing or
    null ``message`` all yield ``{}``.
    """
    choices = data.get("choices")
    if not choices:
        return {}
    head = choices[0]
    if not head:
        return {}
    message = head.get("message")
    return message if message else {}
def get_json(url: str, timeout: float) -> dict[str, Any]:
    """Fetch ``url`` and decode the UTF-8 response body as JSON.

    Args:
        url: Address to request.
        timeout: Timeout in seconds handed to ``urlopen``.

    Returns:
        The decoded JSON value.
    """
    with urllib.request.urlopen(url, timeout=timeout) as reply:
        raw_body = reply.read()
    text = raw_body.decode("utf-8")
    return json.loads(text)
def post_json(url: str, payload: dict[str, Any], timeout: float) -> dict[str, Any]:
    """POST ``payload`` as JSON to ``url`` and decode the JSON response.

    Args:
        url: Address to post to.
        payload: Object serialized as the UTF-8 JSON request body.
        timeout: Timeout in seconds handed to ``urlopen``.

    Returns:
        The decoded JSON value of the response body.

    Raises:
        RuntimeError: The server answered with an HTTP error status; the
            message carries the status code and the response body.
    """
    encoded = json.dumps(payload).encode("utf-8")
    outgoing = urllib.request.Request(
        url,
        data=encoded,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(outgoing, timeout=timeout) as reply:
            raw_body = reply.read()
    except urllib.error.HTTPError as exc:
        # Surface the server's error body; it is more useful than the bare status.
        detail = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {exc.code}: {detail}") from exc
    return json.loads(raw_body.decode("utf-8"))
def stop_existing_koboldcpp() -> None:
    """Force-stop processes named ``koboldcpp`` via PowerShell; Windows only.

    On any other platform this does nothing. The PowerShell exit status is
    ignored.
    """
    if sys.platform == "win32":
        kill_script = "Get-Process koboldcpp -ErrorAction SilentlyContinue | Stop-Process -Force"
        powershell_command = ["powershell", "-NoProfile", "-Command", kill_script]
        subprocess.run(powershell_command, check=False)
def quote_arg(value: str) -> str:
    """Wrap ``value`` in double quotes when it contains a space or a tab.

    Used only to print the command line; embedded quotes are not escaped.
    """
    needs_quotes = any(separator in value for separator in (" ", "\t"))
    return f'"{value}"' if needs_quotes else value
def truncate(value: str, limit: int = 220) -> str:
    """Collapse whitespace runs and shorten the text to at most ``limit`` characters.

    Args:
        value: Text to normalize; every whitespace run becomes one space and
            leading/trailing whitespace is dropped.
        limit: Maximum length of the result.

    Returns:
        The normalized text, unchanged if it already fits. Longer text is cut
        and suffixed with ``"..."`` so the total length equals ``limit``. When
        ``limit`` is below 3 there is no room for the ellipsis, so the text is
        cut to ``limit`` characters (empty for non-positive limits).
    """
    collapsed = " ".join(value.split())
    if len(collapsed) <= limit:
        return collapsed
    if limit < 3:
        # A negative slice bound here would return more text than allowed.
        return collapsed[: max(limit, 0)]
    return f"{collapsed[: limit - 3]}..."
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())