Files
TraderAI/tests/test_agent.py
T
HRiggs 00cf6f8747
Build Release EXE / build-windows-exe (release) Successful in 58s
feat: infrance
2026-06-08 20:28:06 -04:00

372 lines
12 KiB
Python

import pytest
import asyncio
from traderai.agent import OllamaAgent, SYSTEM_PROMPT
from traderai.memory import MemoryStore
class EmptyTools:
    """Minimal tool-registry stand-in that advertises no tools.

    Provides the two attributes used by the agent doubles in this file:
    ``schemas`` (an empty list) and ``pending_actions`` (an empty dict).
    """

    # No tool schemas are advertised by default; subclasses may override.
    schemas = []

    @property
    def pending_actions(self):
        """Return a new empty dict on every access."""
        nothing_pending = {}
        return nothing_pending
class WakeTools(EmptyTools):
    """Tool double that records every invocation and answers with one canned notification."""

    def __init__(self):
        # (name, arguments) tuples, in the order execute() was awaited.
        self.calls = []

    async def execute(self, name, arguments):
        """Record the call, then return a fixed payload containing a single notification."""
        invocation = (name, arguments)
        self.calls.append(invocation)
        canned_notifications = [{"message": "Buyer replied"}]
        return {"count": 1, "notifications": canned_notifications}
class WakeAgent(OllamaAgent):
    """Agent double with scripted model replies: one tool-call request, then a final answer."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", WakeTools(), memory=memory)
        notification_call = {
            "function": {
                "name": "check_uex_notifications",
                "arguments": {},
            }
        }
        tool_request = {
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [notification_call],
            }
        }
        final_reply = {"message": {"role": "assistant", "content": "I checked notifications: Buyer replied."}}
        # Consumed front-to-back by _ollama_chat, one entry per call.
        self.responses = [tool_request, final_reply]

    async def ensure_available(self):
        """Override that does nothing and returns None."""
        return None

    async def _ollama_chat(self, *args, **kwargs):
        """Ignore all arguments and return the next scripted response.

        Raises IndexError once the script is exhausted.
        """
        next_response = self.responses.pop(0)
        return next_response
class TitleAgent(OllamaAgent):
    """Agent double with a fixed generated title and a fixed one-word chat reply."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)

    async def ensure_available(self):
        """Override that does nothing and returns None."""
        return None

    async def _generate_chat_title(self, first_message):
        """Return a constant title regardless of ``first_message``."""
        fixed_title = "UEX Market Check"
        return fixed_title

    async def _ollama_chat(self, *args, **kwargs):
        """Ignore all arguments and return a canned assistant message."""
        canned_message = {"role": "assistant", "content": "Done"}
        return {"message": canned_message}
class ImageCaptureAgent(OllamaAgent):
    """Agent double that captures the ``messages`` passed to ``_chat_once``."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)
        # Stays None until _chat_once has been awaited at least once.
        self.last_messages = None

    async def ensure_available(self):
        """Override that does nothing and returns None."""
        return None

    async def _chat_once(self, query="", messages=None, **kwargs):
        """Remember ``messages`` for later inspection and return a canned reply."""
        self.last_messages = messages
        canned_message = {"role": "assistant", "content": "Seen"}
        return {"message": canned_message}
class SlowToolTools(EmptyTools):
    """Tool double exposing one ``slow_tool`` whose execution awaits briefly before answering."""

    # Overrides the empty schema list with a single parameterless function tool.
    schemas = [
        {
            "type": "function",
            "function": {
                "name": "slow_tool",
                "description": "Slow fake tool.",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    def __init__(self):
        # Number of times execute() has been entered.
        self.calls = 0

    async def execute(self, name, arguments):
        """Count the call, sleep 10 ms, and return a fixed success payload.

        ``name`` and ``arguments`` are accepted but not inspected.
        """
        self.calls = self.calls + 1
        # The counter is bumped before the await, so it is visible while the tool is "running".
        await asyncio.sleep(0.01)
        outcome = {"status": "ok", "value": "slow result"}
        return outcome
class SlowStreamingAgent(OllamaAgent):
    """Streaming agent double: first stream requests ``slow_tool``, later streams are empty."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", SlowToolTools(), memory=memory)
        # Number of times _ollama_chat_stream has started iterating.
        self.stream_calls = 0

    async def health(self):
        """Report the backend as online without contacting it."""
        status = {"online": True, "model": "test"}
        status["base_url"] = self.base_url
        return status

    async def _ollama_chat_stream(self, *args, **kwargs):
        """Yield exactly one terminal chunk per call, ignoring all arguments.

        The first call yields an assistant message carrying a ``slow_tool``
        tool call; every later call yields an assistant message with empty
        content.
        """
        self.stream_calls += 1
        if self.stream_calls != 1:
            # Follow-up turn: the model says nothing.
            yield {"message": {"role": "assistant", "content": ""}, "done": True}
            return
        slow_tool_call = {"function": {"name": "slow_tool", "arguments": {}}}
        yield {
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [slow_tool_call],
            },
            "done": True,
        }
class FailingAfterToolAgent(SlowStreamingAgent):
    """Streaming agent double whose model fails on every stream after the first."""

    async def _ollama_chat_stream(self, *args, **kwargs):
        """Yield a ``slow_tool`` tool-call chunk on the first call; raise on later calls.

        Raises:
            RuntimeError: with message "ollama timed out" on the second and
                subsequent calls, before anything is yielded.
        """
        self.stream_calls += 1
        if self.stream_calls != 1:
            raise RuntimeError("ollama timed out")
        slow_tool_call = {"function": {"name": "slow_tool", "arguments": {}}}
        yield {
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [slow_tool_call],
            },
            "done": True,
        }
@pytest.mark.asyncio
async def test_chat_events_warns_when_ollama_offline():
    """chat_events should open with an offline warning and close with a done event."""
    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools())
    events = [event async for event in agent.chat_events("hello")]
    first_event = events[0]
    assert first_event["type"] == "warning"
    assert "Ollama is offline" in first_event["message"]
    last_event = events[-1]
    assert last_event["type"] == "done"
def test_runtime_context_uses_previous_interaction_not_current_message(tmp_path):
    """The runtime context must cite the supplied previous interaction, not the newest row."""
    db_path = tmp_path / "memory.sqlite3"
    store = MemoryStore(str(db_path))

    # Snapshot the last interaction before the current user message is stored.
    store.add_conversation("assistant", "older answer")
    previous = store.last_interaction()
    assert previous is not None

    store.add_conversation("user", "current question")
    current = store.last_interaction()

    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=store)
    context = agent._runtime_context("current question", previous_interaction=previous)

    assert f"Previous interaction before this message: {previous['created_at']}" in context
    assert f"Previous interaction before this message: {current['created_at']}" not in context
    assert "Current local date/time:" in context
def test_runtime_context_includes_uex_user_identity(tmp_path):
    """The runtime context names the UEX user and some profile fields, but not the email."""
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    uex_profile = {
        "username": "pilot_hudson",
        "name": "Hudson",
        "email": "hudson@example.test",
        "timezone": "America/New_York",
        "specializations": "trading,hauling",
    }
    store.set_profile("uex_user", uex_profile)

    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=store)
    context = agent._runtime_context("")

    assert "You are speaking with UEX user pilot_hudson (Hudson)." in context
    assert "timezone: America/New_York" in context
    assert "specializations: trading,hauling" in context
    # The email address must not appear in the context.
    assert "hudson@example.test" not in context
def test_stream_metrics_include_reading_and_writing_rates():
    """_stream_metrics reports token counts and per-second rates for prompt and generation."""
    # 20 tokens over 2_000_000_000 and 30 over 3_000_000_000 both work out to a rate of 10.
    raw_stats = {
        "prompt_eval_count": 20,
        "prompt_eval_duration": 2_000_000_000,
        "eval_count": 30,
        "eval_duration": 3_000_000_000,
    }
    metrics = OllamaAgent._stream_metrics(raw_stats)

    assert metrics["reading_tokens"] == 20
    assert metrics["reading_tokens_per_second"] == 10
    assert metrics["writing_tokens"] == 30
    assert metrics["writing_tokens_per_second"] == 10
def test_system_prompt_prefers_current_marketplace_data():
    """SYSTEM_PROMPT must contain each of the required guidance phrases."""
    required_phrases = (
        "open/current",
        "Do not use historical sale data",
        "aUEC/UEC credits",
        "never real-world dollars",
    )
    for phrase in required_phrases:
        assert phrase in SYSTEM_PROMPT
def test_ollama_options_include_num_ctx():
    """A ``num_ctx`` given at construction is the only entry in the generated options."""
    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), num_ctx=64000)
    options = agent._ollama_options()
    assert options == {"num_ctx": 64000}
def test_codex_prompt_mentions_tools_and_images(tmp_path):
    """The Codex CLI prompt mentions available tools, image attachments, and tool traffic."""
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = OllamaAgent("codex", "gpt-5.3-codex", EmptyTools(), memory=store, provider="codex")

    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message_with_image = {
        "role": "user",
        "content": "Look at this",
        "images": ["ZmFrZQ=="],
        "image_content_types": ["image/png"],
    }
    assistant_tool_request = {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {"name": "search_marketplace_listings", "arguments": "{\"commodity\":\"gold\"}"},
            }
        ],
    }
    tool_reply = {
        "role": "tool",
        "tool_name": "search_marketplace_listings",
        "tool_call_id": "call_123",
        "content": "{\"ok\":true}",
    }
    history = [system_message, user_message_with_image, assistant_tool_request, tool_reply]

    prompt = agent._codex_cli_prompt("check listing", history)

    assert "Available tools" in prompt
    assert "attached images: 1" in prompt
    assert "search_marketplace_listings" in prompt
    assert "tool search_marketplace_listings" in prompt
def test_codex_structured_response_extracts_text_and_tool_calls():
    """A structured ``tool_call`` payload becomes a message with one function tool call."""
    agent = OllamaAgent("codex", "gpt-5.3-codex", EmptyTools(), provider="codex")
    structured = {
        "kind": "tool_call",
        "message": "",
        "tool_name": "search_marketplace_listings",
        "arguments_json": "{\"commodity\":\"gold\"}",
    }

    result = agent._codex_structured_response(structured)

    assert result["message"]["content"] == ""
    # The id is read back from the result, so only the other fields are pinned.
    expected_call = {
        "id": result["message"]["tool_calls"][0]["id"],
        "type": "function",
        "function": {
            "name": "search_marketplace_listings",
            "arguments": "{\"commodity\":\"gold\"}",
        },
    }
    assert result["message"]["tool_calls"] == [expected_call]
def test_parse_codex_exec_output_reads_final_json():
    """The JSON text of the ``agent_message`` item is parsed into the returned dict."""
    agent = OllamaAgent("codex", "gpt-5.3-codex", EmptyTools(), provider="codex")
    agent_message_event = {
        "type": "item.completed",
        "item": {"type": "agent_message", "text": "{\"kind\":\"final\",\"message\":\"hello\",\"tool_name\":\"\",\"arguments_json\":\"{}\"}"},
    }
    exec_output = {
        "returncode": 0,
        "stdout": "",
        "stderr": "",
        "events": [
            {"type": "thread.started", "thread_id": "abc"},
            agent_message_event,
            {"type": "turn.completed"},
        ],
    }

    result = agent._parse_codex_exec_output(exec_output)

    expected = {"kind": "final", "message": "hello", "tool_name": "", "arguments_json": "{}"}
    assert result == expected
@pytest.mark.asyncio
async def test_wake_response_executes_tool_calls(tmp_path):
    """A wake response runs the requested tool and stores the final answer in the wake thread."""
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = WakeAgent(store)
    expected_answer = "I checked notifications: Buyer replied."

    response = await agent.generate_wake_response("Scheduled wake job fired. Check notifications.")

    assert response == expected_answer
    assert agent.tools.calls == [("check_uex_notifications", {})]
    wake_rows = store.recent_conversation(thread_id="wake")
    newest_row = wake_rows[-1]
    assert newest_row["content"] == expected_answer
@pytest.mark.asyncio
async def test_first_chat_message_generates_thread_title(tmp_path):
    """After the first chat message in a new thread, the thread carries the generated title."""
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    thread = store.create_thread()
    thread_id = thread["id"]
    agent = TitleAgent(store)

    result = await agent.chat("Check UEX market listings", thread_id=thread_id)

    assert result["message"] == "Done"
    stored_thread = store.get_thread(thread_id)
    assert stored_thread["title"] == "UEX Market Check"
@pytest.mark.asyncio
async def test_chat_includes_pasted_images_and_memory_note(tmp_path):
    """An image-only chat forwards the image data to the model and records an attachment note.

    With an empty text message and one pasted image, the user message sent to
    the model must carry the base64 image data and a default prompt, and the
    second-to-last stored conversation row must mention the attachment.
    """
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = ImageCaptureAgent(memory)
    result = await agent.chat(
        "",
        images=[{"name": "listing.png", "content_type": "image/png", "image_data": "ZmFrZS1pbWFnZQ=="}],
    )
    assert result["message"] == "Seen"
    # Fail with a readable message rather than a TypeError from reversed(None)
    # when the model call was never reached.
    assert agent.last_messages is not None, "_chat_once was never called"
    # next() without a default raises StopIteration when nothing matches; escaping
    # a coroutine, that is reported as an opaque RuntimeError. Use a default and
    # assert explicitly so a missing user message fails clearly.
    user_message = next(
        (message for message in reversed(agent.last_messages) if message.get("role") == "user"),
        None,
    )
    assert user_message is not None, "no user message was passed to the model"
    assert user_message["images"] == ["ZmFrZS1pbWFnZQ=="]
    assert user_message["content"] == "Please analyze the attached image."
    assert "[Attached 1 pasted image]" in memory.recent_conversation()[-2]["content"]
@pytest.mark.asyncio
async def test_chat_events_returns_fallback_after_slow_tool_and_empty_final_response(tmp_path):
    """When the model is silent after a tool call, the streamed text falls back to the tool result."""
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = SlowStreamingAgent(store)

    events = [event async for event in agent.chat_events("run a slow tool")]

    # Concatenate the content of every token event, in order.
    token_events = [event for event in events if event["type"] == "token"]
    text = "".join(event.get("content", "") for event in token_events)

    assert agent.tools.calls == 1
    assert "I completed the tool call" in text
    assert "slow result" in text
    assert events[-1]["type"] == "done"
@pytest.mark.asyncio
async def test_chat_events_returns_tool_result_when_model_fails_after_slow_tool(tmp_path):
    """When the model raises after a tool call, the streamed text still reports the tool result."""
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = FailingAfterToolAgent(store)

    events = [event async for event in agent.chat_events("run a slow tool")]

    # Concatenate the content of every token event, in order.
    token_events = [event for event in events if event["type"] == "token"]
    text = "".join(event.get("content", "") for event in token_events)

    assert agent.tools.calls == 1
    assert "local model stopped after the tool call" in text
    assert "slow result" in text
    assert events[-1]["type"] == "done"