374 lines
13 KiB
Python
374 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import urllib.error
|
|
import urllib.request
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
@dataclass
class TestResult:
    """Outcome of a single smoke test, as printed in the final summary."""

    # The class name starts with "Test", but this is a plain record with an
    # __init__, not a test case; tell pytest not to try collecting it.
    __test__ = False

    name: str  # identifier printed in the summary line
    passed: bool  # True when the test's checks succeeded
    details: str  # short diagnostic text (model reply or parsed tool arguments)
|
|
|
|
|
|
def main() -> int:
    """Run the KoboldCpp model smoke test end to end.

    Validates the model and executable paths, launches KoboldCpp with its
    output captured under ``data/logs`` (relative to the current working
    directory), waits for ``/v1/models`` to report a model, runs the four chat
    smoke tests, and prints a PASS/FAIL summary. Unless ``--keep-running`` is
    given, the KoboldCpp process started here is stopped before returning.

    Returns:
        0 if every test passed; 1 if any test failed or KoboldCpp never became
        ready; 2 if the model file or the KoboldCpp executable does not exist.
    """
    args = parse_args()
    repo_root = Path.cwd()
    model_path = Path(args.model_path).expanduser().resolve()
    koboldcpp_path = Path(args.koboldcpp_path).expanduser().resolve()

    if not model_path.exists():
        print(f"FAIL: model does not exist: {model_path}")
        return 2
    if not koboldcpp_path.exists():
        print(f"FAIL: koboldcpp executable does not exist: {koboldcpp_path}")
        return 2

    logs_dir = repo_root / "data" / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)
    # Log names encode model and context size so differently-configured runs
    # do not overwrite each other's logs.
    safe_name = model_path.stem.replace(" ", "_")
    stdout_path = logs_dir / f"model-smoke-{safe_name}-ctx{args.context_size}.out.log"
    stderr_path = logs_dir / f"model-smoke-{safe_name}-ctx{args.context_size}.err.log"

    process: subprocess.Popen[Any] | None = None
    try:
        if args.stop_existing:
            stop_existing_koboldcpp()

        process = start_koboldcpp(
            koboldcpp_path=koboldcpp_path,
            model_path=model_path,
            port=args.port,
            context_size=args.context_size,
            extra_args=args.kobold_arg,
            stdout_path=stdout_path,
            stderr_path=stderr_path,
        )
        base_url = f"http://127.0.0.1:{args.port}"
        try:
            model_id = wait_for_model(base_url, process, args.startup_timeout_seconds)
        except (RuntimeError, TimeoutError) as exc:
            # Early exit or startup timeout: point at the captured logs rather
            # than ending with a bare traceback.
            print(f"FAIL: {exc}")
            print(f"KoboldCpp logs: {stdout_path} | {stderr_path}")
            return 1
        print(f"Model ready: {model_id}")

        # NOTE(review): --request-timeout-seconds is parsed but not forwarded to
        # these tests; their HTTP requests use chat()'s own timeout.
        smoke_tests = (
            run_dialogue_test,
            run_auto_tool_test,
            run_forced_tool_test,
            run_agent_action_prompt_test,
        )
        results = [_run_smoke_test(test, base_url, model_id) for test in smoke_tests]

        passed = all(result.passed for result in results)
        print()
        print("KoboldCpp model smoke test summary")
        print(f"Model: {model_path}")
        print(f"Context size: {args.context_size}")
        print(f"KoboldCpp logs: {stdout_path} | {stderr_path}")
        for result in results:
            status = "PASS" if result.passed else "FAIL"
            print(f"{status}: {result.name} - {result.details}")
        print()
        print("PASS" if passed else "FAIL")
        return 0 if passed else 1
    finally:
        if process is not None and process.poll() is None and not args.keep_running:
            _stop_process(process)


def _run_smoke_test(test: Any, base_url: str, model_id: str) -> TestResult:
    """Run one smoke test, converting an unexpected exception into a FAIL result.

    This keeps one failed request (HTTP error, timeout, malformed response)
    from hiding the results of the remaining tests. Such failures are reported
    under the test function's ``__name__``, since the test never produced its
    own result name.
    """
    try:
        return test(base_url, model_id)
    except Exception as exc:  # noqa: BLE001 - report the failure in the summary instead.
        return TestResult(test.__name__, False, f"{type(exc).__name__}: {truncate(str(exc))}")


def _stop_process(process: subprocess.Popen[Any]) -> None:
    """Terminate *process*, escalate to kill after 20 seconds, and reap it."""
    process.terminate()
    try:
        process.wait(timeout=20)
    except subprocess.TimeoutExpired:
        process.kill()
        # Reap the killed child so it does not linger as a zombie on POSIX.
        process.wait()
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the smoke-test command line and parse ``sys.argv``.

    Only ``model_path`` is positional. ``--kobold-arg`` may be repeated; each
    occurrence is collected into a list that is forwarded to KoboldCpp.
    """
    cli = argparse.ArgumentParser(
        description="Start KoboldCpp with a GGUF and run LocalDiplomacy tool-calling smoke tests."
    )
    add = cli.add_argument

    # Model and server configuration.
    add("model_path", help="Path to the GGUF model to test.")
    add("--context-size", default=40960, type=int, help="KoboldCpp context size.")
    add("--koboldcpp-path", help="Path to koboldcpp.exe.", default="./koboldcpp.exe")
    add("--port", default=5001, type=int, help="KoboldCpp HTTP port.")

    # Timeouts, in seconds.
    add("--startup-timeout-seconds", default=240.0, type=float)
    add("--request-timeout-seconds", default=180.0, type=float)

    # Process lifecycle switches.
    add("--keep-running", help="Leave the started KoboldCpp process running.", action="store_true")
    add("--stop-existing", help="Stop existing koboldcpp.exe processes first.", action="store_true")

    # Passthrough flags for KoboldCpp itself.
    add(
        "--kobold-arg",
        help="Extra argument passed to KoboldCpp. Repeat for multiple args.",
        default=[],
        action="append",
    )

    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
|
def start_koboldcpp(
    *,
    koboldcpp_path: Path,
    model_path: Path,
    port: int,
    context_size: int,
    extra_args: list[str],
    stdout_path: Path,
    stderr_path: Path,
) -> subprocess.Popen[Any]:
    """Launch KoboldCpp in the background with Jinja tool-calling enabled.

    The command always passes ``--model``, ``--port``, ``--contextsize``,
    ``--jinja`` and ``--jinjatools``; *extra_args* are appended after them.
    The child's stdout and stderr are redirected to *stdout_path* and
    *stderr_path*, which are truncated first.

    Returns:
        The running ``Popen`` object; the caller is responsible for stopping it.
    """
    command = [
        str(koboldcpp_path),
        "--model",
        str(model_path),
        "--port",
        str(port),
        "--contextsize",
        str(context_size),
        "--jinja",
        "--jinjatools",
        *extra_args,
    ]
    print("Starting KoboldCpp:")
    print(" ".join(quote_arg(part) for part in command))
    # The child inherits its own copies of the log handles, so the parent's
    # copies can be closed as soon as Popen returns (or raises).
    with stdout_path.open("w", encoding="utf-8", errors="replace") as stdout_file:
        with stderr_path.open("w", encoding="utf-8", errors="replace") as stderr_file:
            return subprocess.Popen(
                command,
                stdout=stdout_file,
                stderr=stderr_file,
                # CREATE_NO_WINDOW exists only on Windows; the conditional keeps
                # the attribute lookup from running elsewhere.
                creationflags=subprocess.CREATE_NO_WINDOW if sys.platform == "win32" else 0,
            )
|
|
|
|
|
|
def wait_for_model(base_url: str, process: subprocess.Popen[Any], timeout_seconds: float) -> str:
    """Poll ``/v1/models`` until KoboldCpp lists a model, and return its id.

    Polls every 2 seconds, giving each request a 10-second timeout. Returns
    the first listed model's ``id``, or ``"local-model"`` when that id is
    empty.

    Raises:
        RuntimeError: if *process* exits before a model is listed.
        TimeoutError: if no model is listed within *timeout_seconds*; the
            message includes the most recent polling error.
    """
    models_url = f"{base_url}/v1/models"
    give_up_at = time.monotonic() + timeout_seconds
    most_recent_failure = ""
    while time.monotonic() < give_up_at:
        exit_code = process.poll()
        if exit_code is not None:
            raise RuntimeError(f"KoboldCpp exited early with code {exit_code}.")
        try:
            listing = get_json(models_url, timeout=10)
            loaded = listing.get("data") or []
            if loaded:
                return str(loaded[0].get("id") or "local-model")
        except Exception as exc:  # noqa: BLE001 - keep polling, remember why it failed.
            most_recent_failure = str(exc)
        time.sleep(2)
    raise TimeoutError(
        f"KoboldCpp did not become ready within {timeout_seconds}s. Last error: {most_recent_failure}"
    )
|
|
|
|
|
|
def run_dialogue_test(base_url: str, model_id: str) -> TestResult:
    """Request a one-sentence in-character reply and check it is clean prose.

    Passes when the stripped reply is non-empty and contains none of the
    reasoning or tool-call markers ``<think>``, ``</think>``, ``"arguments"``
    or ``<tool_call``. The details field holds the truncated reply.
    """
    conversation = [
        {
            "role": "system",
            "content": "/no_think You are Derthert, a Bannerlord lord. Answer in one sentence. Never expose reasoning.",
        },
        {"role": "user", "content": "/no_think Greetings. What news from the border?"},
    ]
    request_body = {"model": model_id, "messages": conversation, "temperature": 0.3, "max_tokens": 120}

    reply = first_message(chat(base_url, request_body)).get("content") or ""
    stripped = reply.strip()
    leaked = [marker for marker in ("<think>", "</think>", '"arguments"', "<tool_call") if marker in reply]
    ok = bool(stripped) and not leaked
    return TestResult("dialogue_no_visible_reasoning", ok, truncate(stripped))
|
|
|
|
|
|
def run_auto_tool_test(base_url: str, model_id: str) -> TestResult:
    """Check the model chooses on its own to call the action tool.

    Uses ``tool_choice="auto"`` with a natural-language request for peace,
    offering ``propose_peace`` and ``declare_war``; the reply is judged by
    ``validate_tool_call``.
    """
    system_prompt = "/no_think Use tools through the API when a game action is needed. Never expose reasoning."
    user_prompt = "/no_think Derthert should propose peace with Battania."
    response = chat(
        base_url,
        {
            "model": model_id,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "tools": [action_tool_schema(["propose_peace", "declare_war"])],
            "tool_choice": "auto",
            "temperature": 0.0,
            "max_tokens": 300,
        },
    )
    return validate_tool_call("auto_tool_call", response)
|
|
|
|
|
|
def run_forced_tool_test(base_url: str, model_id: str) -> TestResult:
    """Force a ``propose_game_action`` call via ``tool_choice`` and validate it.

    The user message spells out every argument explicitly, so this checks
    that forced tool calling produces well-formed arguments.
    """
    forced_choice = {"type": "function", "function": {"name": "propose_game_action"}}
    instruction = (
        "/no_think Use propose_game_action now: action_type propose_peace, "
        "actor_id lord_derthert, target_id kingdom_battania, reason end the border raids."
    )
    messages = [
        {
            "role": "system",
            "content": "/no_think Use tools through the API when a game action is needed. Never expose reasoning.",
        },
        {"role": "user", "content": instruction},
    ]
    response = chat(
        base_url,
        {
            "model": model_id,
            "messages": messages,
            "tools": [action_tool_schema(["accept_peace", "propose_peace", "reject_peace"])],
            "tool_choice": forced_choice,
            "temperature": 0.0,
            "max_tokens": 300,
        },
    )
    return validate_tool_call("forced_tool_call", response)
|
|
|
|
|
|
def run_agent_action_prompt_test(base_url: str, model_id: str) -> TestResult:
    """Force a ``propose_game_action`` call using a planner-style prompt.

    Like the forced tool test, but the system prompt frames the model as
    LocalDiplomacy's action planner, and the user message embeds a mock
    campaign context (serialized as JSON) ahead of the command.
    """
    mock_context = {
        "campaign_id": "mock-campaign",
        "player_id": "player",
        "npc_id": "lord_derthert",
        "npc_name": "Derthert",
        "npc_kingdom_id": "kingdom_vlandia",
        "player_kingdom_id": "kingdom_vlandia",
        "kingdom_state": {"wars": [{"enemy_kingdom_id": "kingdom_battania", "days": 42}], "war_fatigue": 0.67},
        "nearby_parties": [{"id": "party_battania_raiders", "faction_id": "kingdom_battania"}],
        "nearby_settlements": [{"id": "town_sargot", "owner_kingdom_id": "kingdom_vlandia"}],
        "recent_events": [{"id": "event_border_raids", "summary": "Border raids have strained Vlandia and Battania."}],
    }
    planner_instructions = (
        "/no_think You are LocalDiplomacy's Bannerlord action planner. "
        "Call the provided tool through the API. Do not answer in prose. Use only IDs present in the request."
    )
    command = (
        "Command: Use propose_game_action now: action_type propose_peace, actor_id lord_derthert, "
        "target_id kingdom_battania, reason end the border raids."
    )
    serialized_context = json.dumps(mock_context, ensure_ascii=False)
    user_message = f"/no_think Context: {serialized_context}\n{command}"

    request_body = {
        "model": model_id,
        "messages": [
            {"role": "system", "content": planner_instructions},
            {"role": "user", "content": user_message},
        ],
        "tools": [action_tool_schema(["accept_peace", "propose_peace", "reject_peace"])],
        "tool_choice": {"type": "function", "function": {"name": "propose_game_action"}},
        "temperature": 0.0,
        "max_tokens": 400,
    }
    return validate_tool_call("agent_action_prompt_tool_call", chat(base_url, request_body))
|
|
|
|
|
|
def action_tool_schema(action_types: list[str]) -> dict[str, Any]:
    """Build the OpenAI-style ``propose_game_action`` function-tool definition.

    ``action_type`` is restricted to *action_types* (the list object itself is
    placed in the schema). ``action_type``, ``actor_id``, ``target_id`` and
    ``reason`` are required; ``args``, ``confidence`` (0..1) and
    ``requires_player_confirmation`` are optional.
    """
    properties: dict[str, Any] = {
        "action_type": {"type": "string", "enum": action_types},
        "actor_id": {"type": "string"},
        "target_id": {"type": "string"},
        "args": {"type": "object"},
        "confidence": {"type": "number", "minimum": 0, "maximum": 1},
        "reason": {"type": "string"},
        "requires_player_confirmation": {"type": "boolean"},
    }
    parameters = {
        "type": "object",
        "properties": properties,
        "required": ["action_type", "actor_id", "target_id", "reason"],
    }
    function_spec = {
        "name": "propose_game_action",
        "description": "Queue one Bannerlord game action proposal for mod-side validation.",
        "parameters": parameters,
    }
    return {"type": "function", "function": function_spec}
|
|
|
|
|
|
def validate_tool_call(name: str, data: dict[str, Any]) -> TestResult:
    """Check that a chat completion's first tool call proposes peace.

    Passes when the first tool call names ``propose_game_action`` with
    ``action_type == "propose_peace"`` and non-empty ``actor_id`` and
    ``target_id``. Missing tool calls and unparseable or non-object arguments
    are reported as failures (never raised), with a diagnostic in *details*.
    Otherwise *details* is the parsed arguments re-serialized as JSON.
    """
    message = first_message(data)
    tool_calls = message.get("tool_calls") or []
    if not tool_calls:
        return TestResult(name, False, f"no tool_calls; content={truncate(str(message.get('content') or ''))}")

    function = (tool_calls[0] or {}).get("function") or {}
    arguments_raw = function.get("arguments") or "{}"
    if isinstance(arguments_raw, dict):
        # Tolerate arguments that arrive already decoded rather than as a JSON string.
        arguments = arguments_raw
    else:
        try:
            arguments = json.loads(arguments_raw)
        except (json.JSONDecodeError, TypeError):
            # TypeError: arguments was neither str nor bytes (e.g. a number or list).
            return TestResult(name, False, f"invalid arguments JSON: {truncate(str(arguments_raw))}")
    if not isinstance(arguments, dict):
        return TestResult(name, False, f"arguments is not a JSON object: {truncate(str(arguments_raw))}")

    passed = (
        function.get("name") == "propose_game_action"
        and arguments.get("action_type") == "propose_peace"
        and bool(arguments.get("actor_id"))
        and bool(arguments.get("target_id"))
    )
    return TestResult(name, passed, json.dumps(arguments, ensure_ascii=False))
|
|
|
|
|
|
def chat(base_url: str, payload: dict[str, Any], timeout: float = 180) -> dict[str, Any]:
    """POST *payload* to the OpenAI-compatible ``/v1/chat/completions`` endpoint.

    Args:
        base_url: Server root, e.g. ``http://127.0.0.1:5001``.
        payload: Chat-completions request body.
        timeout: Per-request timeout in seconds. The default of 180 keeps the
            previous fixed value and equals the ``--request-timeout-seconds``
            CLI default; callers may pass that option through.

    Returns:
        The decoded JSON response.
    """
    return post_json(f"{base_url}/v1/chat/completions", payload, timeout=timeout)
|
|
|
|
|
|
def first_message(data: dict[str, Any]) -> dict[str, Any]:
    """Return ``choices[0].message`` from a chat completion, or ``{}``.

    Missing, empty or null ``choices``, a null first choice, and a missing or
    null ``message`` all collapse to an empty dict.
    """
    leading_choice = (data.get("choices") or [{}])[0] or {}
    return leading_choice.get("message") or {}
|
|
|
|
|
|
def get_json(url: str, timeout: float) -> dict[str, Any]:
    """GET *url* and decode the response body as JSON.

    Raises:
        RuntimeError: on an HTTP error status, with message
            ``"HTTP <code>: <response body>"`` (the same format ``post_json``
            uses), so server-side error text is not lost.
        json.JSONDecodeError: if the body is not valid JSON.

    Connection failures and timeouts propagate unchanged.
    """
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        error_body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {exc.code}: {error_body}") from exc
|
|
|
|
|
|
def post_json(url: str, payload: dict[str, Any], timeout: float) -> dict[str, Any]:
    """POST *payload* as a UTF-8 JSON body to *url* and return the decoded reply.

    An HTTP error status is re-raised as ``RuntimeError("HTTP <code>: <body>")``
    so the server's error text is visible. Other failures (connection errors,
    timeouts, invalid JSON) propagate unchanged.
    """
    req = urllib.request.Request(
        url,
        method="POST",
        headers={"Content-Type": "application/json"},
        data=json.dumps(payload).encode("utf-8"),
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as reply:
            raw = reply.read()
    except urllib.error.HTTPError as http_err:
        detail = http_err.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {http_err.code}: {detail}") from http_err
    return json.loads(raw.decode("utf-8"))
|
|
|
|
|
|
def stop_existing_koboldcpp() -> None:
    """Force-stop running ``koboldcpp`` processes on Windows; no-op elsewhere.

    Runs PowerShell's ``Get-Process koboldcpp | Stop-Process -Force``. A
    missing process is silenced by ``-ErrorAction SilentlyContinue``, and the
    command's exit status is not checked.
    """
    if sys.platform == "win32":
        kill_script = "Get-Process koboldcpp -ErrorAction SilentlyContinue | Stop-Process -Force"
        subprocess.run(["powershell", "-NoProfile", "-Command", kill_script], check=False)
|
|
|
|
|
|
def quote_arg(value: str) -> str:
    """Return *value* quoted for display in the printed KoboldCpp command line.

    Arguments containing a space, tab or double quote are wrapped in double
    quotes, with embedded double quotes backslash-escaped. An empty argument is
    shown as ``""`` so that it stays visible. This is intended for display and
    does not implement the full Windows or POSIX quoting rules (for example,
    backslashes preceding a quote are not doubled).
    """
    if value == "":
        return '""'
    if any(ch in value for ch in (" ", "\t", '"')):
        escaped = value.replace('"', '\\"')
        return f'"{escaped}"'
    return value
|
|
|
|
|
|
def truncate(value: str, limit: int = 220) -> str:
    """Collapse whitespace in *value* and cap the result at *limit* characters.

    Runs of whitespace (including newlines) become single spaces. If the
    collapsed text is longer than *limit*, it is cut and ends in ``"..."``,
    with the total length equal to *limit*. When *limit* is below 3 there is
    no room for the ellipsis, so the text is simply cut to *limit* characters
    (empty for a limit of 0 or less).
    """
    collapsed = " ".join(value.split())
    if len(collapsed) <= limit:
        return collapsed
    if limit < 3:
        return collapsed[: max(limit, 0)]
    return f"{collapsed[: limit - 3]}..."
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (sys.exit raises SystemExit).
    sys.exit(main())
|