Files
HRiggs 454bb57484
Build Release EXE / build-windows-exe (release) Successful in 1m2s
feat: deepseek
2026-06-08 23:41:46 -04:00

433 lines
15 KiB
Python

import pytest
import asyncio
import itertools
from traderai.agent import OllamaAgent, SYSTEM_PROMPT
from traderai.memory import MemoryStore
class EmptyTools:
    """Tool-registry stand-in that offers no tools and has nothing pending."""

    # Tool schemas advertised to the model; deliberately empty for this stub.
    schemas = []

    @property
    def pending_actions(self):
        """Always an empty mapping: this stub never queues any actions."""
        no_actions = {}
        return no_actions
class WakeTools(EmptyTools):
    """Tool stub that records every call and always reports one notification."""

    def __init__(self):
        # (name, arguments) pairs, in the order execute() was awaited.
        self.calls = []

    async def execute(self, name, arguments):
        """Record the call and return one canned "Buyer replied" notification, whatever the tool name."""
        self.calls.append((name, arguments))
        notification = {"message": "Buyer replied"}
        return {"count": 1, "notifications": [notification]}
class WakeAgent(OllamaAgent):
    """Agent double that replays a scripted tool-call turn followed by a final answer."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", WakeTools(), memory=memory)
        tool_call_turn = {
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"function": {"name": "check_uex_notifications", "arguments": {}}},
                ],
            }
        }
        final_turn = {"message": {"role": "assistant", "content": "I checked notifications: Buyer replied."}}
        # Consumed front-to-back by _ollama_chat: one entry per model call.
        self.responses = [tool_call_turn, final_turn]

    async def ensure_available(self):
        """Skip the availability check entirely."""
        return None

    async def _ollama_chat(self, *args, **kwargs):
        """Return the next scripted response, ignoring whatever was sent."""
        return self.responses.pop(0)
class TitleAgent(OllamaAgent):
    """Agent double with a constant thread title and a constant "Done" reply."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)

    async def ensure_available(self):
        """Skip the availability check entirely."""
        return None

    async def _generate_chat_title(self, first_message):
        """Ignore the first message and hand back a fixed title."""
        return "UEX Market Check"

    async def _ollama_chat(self, *args, **kwargs):
        """Answer every model call with the same assistant message."""
        reply = {"role": "assistant", "content": "Done"}
        return {"message": reply}
class ImageCaptureAgent(OllamaAgent):
    """Agent double that remembers the message list handed to ``_chat_once``."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)
        # Set by _chat_once; stays None until the first model call.
        self.last_messages = None

    async def ensure_available(self):
        """Skip the availability check entirely."""
        return None

    async def _chat_once(self, query="", messages=None, **kwargs):
        """Capture ``messages`` for later inspection and reply "Seen"."""
        self.last_messages = messages
        reply = {"role": "assistant", "content": "Seen"}
        return {"message": reply}
class SlowToolTools(EmptyTools):
    """Tool stub exposing a single ``slow_tool`` whose execution briefly sleeps."""

    # One function-style tool schema with no parameters.
    schemas = [
        {
            "type": "function",
            "function": {
                "name": "slow_tool",
                "description": "Slow fake tool.",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    def __init__(self):
        # Number of times execute() has been awaited.
        self.calls = 0

    async def execute(self, name, arguments):
        """Count the call, sleep for 10 ms, then return a fixed "slow result" payload."""
        self.calls = self.calls + 1
        await asyncio.sleep(0.01)
        return {"status": "ok", "value": "slow result"}
class SlowStreamingAgent(OllamaAgent):
    """Streaming double: the first stream requests ``slow_tool``, later streams reply with empty text."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", SlowToolTools(), memory=memory)
        # How many times _ollama_chat_stream has been started.
        self.stream_calls = 0

    async def health(self):
        """Report the backend as online (presumably consulted by chat_events — confirm)."""
        return {"online": True, "model": "test", "base_url": self.base_url}

    async def _ollama_chat_stream(self, *args, **kwargs):
        """Yield exactly one final chunk per stream."""
        self.stream_calls += 1
        if self.stream_calls != 1:
            # Follow-up stream: a completed turn with no text at all.
            yield {"message": {"role": "assistant", "content": ""}, "done": True}
            return
        # First stream: ask for the slow tool and finish the turn.
        yield {
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"function": {"name": "slow_tool", "arguments": {}}}],
            },
            "done": True,
        }
class FailingAfterToolAgent(SlowStreamingAgent):
    """Variant whose follow-up stream raises, simulating a model failure after a tool call."""

    async def _ollama_chat_stream(self, *args, **kwargs):
        """Request ``slow_tool`` on the first stream; raise RuntimeError on every later one."""
        self.stream_calls += 1
        if self.stream_calls != 1:
            # Still an async generator (it contains a yield), so this raises on first iteration.
            raise RuntimeError("ollama timed out")
        yield {
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"function": {"name": "slow_tool", "arguments": {}}}],
            },
            "done": True,
        }
@pytest.mark.asyncio
async def test_chat_events_warns_when_ollama_offline():
    """An unreachable Ollama endpoint yields a leading warning and a closing done event."""
    # Presumably nothing answers on localhost port 1, so the agent should report Ollama offline.
    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools())
    events = [event async for event in agent.chat_events("hello")]
    # Guard before indexing so an empty stream fails with a clear message instead of IndexError.
    assert events, "chat_events produced no events"
    first, last = events[0], events[-1]
    assert first["type"] == "warning"
    assert "Ollama is offline" in first["message"]
    assert last["type"] == "done"
def test_runtime_context_uses_previous_interaction_not_current_message(tmp_path):
    """The runtime context cites the interaction before the current message, not the current one."""
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    memory.add_conversation("assistant", "older answer")
    previous = memory.last_interaction()
    assert previous is not None
    memory.add_conversation("user", "current question")
    current = memory.last_interaction()
    # Mirror the guard on `previous` so a missing row fails here rather than as a TypeError below.
    assert current is not None
    # The "not in" check below only means something if the two timestamps differ. With equal
    # timestamps the positive and negative assertions would contradict each other.
    assert current["created_at"] != previous["created_at"], (
        "both interactions share a created_at; the test cannot tell them apart"
    )
    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)
    context = agent._runtime_context("current question", previous_interaction=previous)
    assert f"Previous interaction before this message: {previous['created_at']}" in context
    assert f"Previous interaction before this message: {current['created_at']}" not in context
    assert "Current local date/time:" in context
def test_runtime_context_includes_uex_user_identity(tmp_path):
    """The stored UEX profile appears in the runtime context, but the email address does not."""
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    profile = {
        "username": "pilot_hudson",
        "name": "Hudson",
        "email": "hudson@example.test",
        "timezone": "America/New_York",
        "specializations": "trading,hauling",
    }
    store.set_profile("uex_user", profile)
    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=store)
    context = agent._runtime_context("")
    expected_fragments = (
        "You are speaking with UEX user pilot_hudson (Hudson).",
        "timezone: America/New_York",
        "specializations: trading,hauling",
    )
    for fragment in expected_fragments:
        assert fragment in context
    # The email is deliberately kept out of the context.
    assert "hudson@example.test" not in context
def test_stream_metrics_include_reading_and_writing_rates():
    """Token counts and per-second rates are derived from the stream's count/duration fields."""
    final_chunk = {
        "prompt_eval_count": 20,
        "prompt_eval_duration": 2_000_000_000,  # 20 tokens -> 10 tok/s implies nanoseconds
        "eval_count": 30,
        "eval_duration": 3_000_000_000,
    }
    metrics = OllamaAgent._stream_metrics(final_chunk)
    expected = {
        "reading_tokens": 20,
        "reading_tokens_per_second": 10,
        "writing_tokens": 30,
        "writing_tokens_per_second": 10,
    }
    for key, value in expected.items():
        assert metrics[key] == value
def test_system_prompt_prefers_current_marketplace_data():
    """SYSTEM_PROMPT steers the model toward current listings and aUEC/UEC pricing."""
    required_phrases = [
        "open/current",
        "Do not use historical sale data",
        "aUEC/UEC credits",
        "never real-world dollars",
    ]
    for phrase in required_phrases:
        assert phrase in SYSTEM_PROMPT
def test_ollama_options_include_num_ctx():
    """A num_ctx passed to the constructor is forwarded unchanged in the Ollama options."""
    context_window = 64000
    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), num_ctx=context_window)
    assert agent._ollama_options() == {"num_ctx": context_window}
def test_deepseek_tool_rounds_are_not_capped_at_ten():
    """The DeepSeek provider permits more than ten tool-calling rounds."""
    agent = OllamaAgent(
        "https://api.deepseek.com",
        "deepseek-v4-flash",
        EmptyTools(),
        provider="deepseek",
        api_key="test",
    )
    wanted = 12
    # Take at most `wanted` rounds so an unbounded iterator still terminates.
    first_rounds = list(itertools.islice(agent._tool_rounds(), wanted))
    assert len(first_rounds) == wanted
def test_plan_draft_normalization_extracts_json_and_defaults():
    """A JSON draft embedded after a text prefix is extracted and its fields are kept."""
    seed = {
        "title": "Wikelo Polaris",
        "objective": "Find parts",
        "kind": "buying",
        "constraints": {},
        "items": [],
    }
    # Model output with a non-JSON prefix ("draft:") before the object.
    raw = 'draft:\n{"title":"Wikelo Polaris Parts","objective":"Find and draft deals for the parts below","kind":"buying","cadence":"0 */3 * * *","constraints":{"message_tone":"casual","instructions":"Prioritize cheap listings first."},"items":[{"item_name":"RCMBNT-RGL-1","desired_quantity":2}]}'
    draft = OllamaAgent._normalize_plan_draft(raw, seed)
    assert draft["title"] == "Wikelo Polaris Parts"
    assert draft["cadence"] == "0 */3 * * *"
    assert draft["constraints"]["message_tone"] == "casual"
    first_item = draft["items"][0]
    assert first_item["item_name"] == "RCMBNT-RGL-1"
    assert first_item["desired_quantity"] == 2
def test_plan_draft_heuristic_fills_in_basic_instructions():
    """The heuristic plan draft supplies a cadence, a message tone and summary instructions."""
    seed = {
        "title": "Watch open negotiations",
        "objective": "",
        "kind": "custom",
        "constraints": {},
        "items": [],
    }
    draft = OllamaAgent._heuristic_plan_draft(seed)
    assert draft["kind"] == "custom"
    # Cron expression: minute 0 of every fourth hour.
    assert draft["cadence"] == "0 */4 * * *"
    instructions = draft["constraints"]["instructions"]
    assert "summarize" in instructions.casefold()
    assert draft["constraints"]["message_tone"] == "friendly and direct"
def test_codex_prompt_mentions_tools_and_images(tmp_path):
    """The Codex CLI prompt mentions the tool section, attached images and earlier tool traffic."""
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = OllamaAgent("codex", "gpt-5.3-codex", EmptyTools(), memory=memory, provider="codex")
    tool_call = {
        "id": "call_123",
        "type": "function",
        "function": {"name": "search_marketplace_listings", "arguments": "{\"commodity\":\"gold\"}"},
    }
    history = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": "Look at this",
            "images": ["ZmFrZQ=="],
            "image_content_types": ["image/png"],
        },
        {"role": "assistant", "content": "", "tool_calls": [tool_call]},
        {
            "role": "tool",
            "tool_name": "search_marketplace_listings",
            "tool_call_id": "call_123",
            "content": "{\"ok\":true}",
        },
    ]
    prompt = agent._codex_cli_prompt("check listing", history)
    expected_fragments = (
        "Available tools",
        "attached images: 1",
        "search_marketplace_listings",
        "tool search_marketplace_listings",
    )
    for fragment in expected_fragments:
        assert fragment in prompt
def test_deepseek_openai_messages_include_reasoning_content_for_tool_turns():
    """An assistant tool-call turn keeps its ``reasoning_content`` after OpenAI-style conversion."""
    agent = OllamaAgent(
        "https://api.deepseek.com",
        "deepseek-v4-flash",
        EmptyTools(),
        provider="deepseek",
        api_key="test",
    )
    reasoning = "I should check the current listing first."
    assistant_with_tool = {
        "role": "assistant",
        "content": "",
        "reasoning_content": reasoning,
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {"name": "search_marketplace_listings", "arguments": "{\"query\":\"panel\"}"},
            }
        ],
    }
    history = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "Check this listing"},
        assistant_with_tool,
        {
            "role": "tool",
            "tool_name": "search_marketplace_listings",
            "tool_call_id": "call_123",
            "content": "{\"ok\":true}",
        },
    ]
    converted = agent._openai_messages("check listing", history)
    assistant_turn = next(msg for msg in converted if msg["role"] == "assistant")
    assert assistant_turn["reasoning_content"] == reasoning
def test_codex_structured_response_extracts_text_and_tool_calls():
    """A ``tool_call`` structured reply becomes an assistant message carrying one function call."""
    agent = OllamaAgent("codex", "gpt-5.3-codex", EmptyTools(), provider="codex")
    result = agent._codex_structured_response(
        {
            "kind": "tool_call",
            "message": "",
            "tool_name": "search_marketplace_listings",
            "arguments_json": "{\"commodity\":\"gold\"}",
        }
    )
    message = result["message"]
    assert message["content"] == ""
    tool_calls = message["tool_calls"]
    # Check the count first so a missing call fails clearly instead of with IndexError.
    assert len(tool_calls) == 1
    # The exact id is not pinned here, but it must be a non-empty string: tool result messages
    # refer back to their call through `tool_call_id`. Previously this was compared to itself.
    call_id = tool_calls[0]["id"]
    assert isinstance(call_id, str) and call_id
    assert tool_calls == [
        {
            "id": call_id,
            "type": "function",
            "function": {
                "name": "search_marketplace_listings",
                "arguments": "{\"commodity\":\"gold\"}",
            },
        }
    ]
def test_parse_codex_exec_output_reads_final_json():
    """The JSON text of the agent_message event is decoded and returned as a dict."""
    agent = OllamaAgent("codex", "gpt-5.3-codex", EmptyTools(), provider="codex")
    final_payload = "{\"kind\":\"final\",\"message\":\"hello\",\"tool_name\":\"\",\"arguments_json\":\"{}\"}"
    events = [
        {"type": "thread.started", "thread_id": "abc"},
        {"type": "item.completed", "item": {"type": "agent_message", "text": final_payload}},
        {"type": "turn.completed"},
    ]
    exec_output = {"returncode": 0, "stdout": "", "stderr": "", "events": events}
    parsed = agent._parse_codex_exec_output(exec_output)
    assert parsed == {"kind": "final", "message": "hello", "tool_name": "", "arguments_json": "{}"}
@pytest.mark.asyncio
async def test_wake_response_executes_tool_calls(tmp_path):
    """A wake prompt runs the requested tool and stores the final answer in the "wake" thread."""
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = WakeAgent(memory)
    expected_reply = "I checked notifications: Buyer replied."
    reply = await agent.generate_wake_response("Scheduled wake job fired. Check notifications.")
    assert reply == expected_reply
    assert agent.tools.calls == [("check_uex_notifications", {})]
    newest_row = memory.recent_conversation(thread_id="wake")[-1]
    assert newest_row["content"] == expected_reply
@pytest.mark.asyncio
async def test_first_chat_message_generates_thread_title(tmp_path):
    """The first chat message in a fresh thread gives that thread a generated title."""
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    thread_id = memory.create_thread()["id"]
    agent = TitleAgent(memory)
    result = await agent.chat("Check UEX market listings", thread_id=thread_id)
    assert result["message"] == "Done"
    stored_thread = memory.get_thread(thread_id)
    assert stored_thread["title"] == "UEX Market Check"
@pytest.mark.asyncio
async def test_chat_includes_pasted_images_and_memory_note(tmp_path):
    """Pasted images reach the model, and the stored conversation notes the attachment."""
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = ImageCaptureAgent(memory)
    result = await agent.chat(
        "",
        images=[{"name": "listing.png", "content_type": "image/png", "image_data": "ZmFrZS1pbWFnZQ=="}],
    )
    assert result["message"] == "Seen"
    # Explicit checks instead of reversed(None) (TypeError) or a bare next(), whose
    # StopIteration Python re-raises as an opaque RuntimeError inside a coroutine (PEP 479).
    assert agent.last_messages is not None, "_chat_once was never called"
    user_message = next(
        (message for message in reversed(agent.last_messages) if message.get("role") == "user"),
        None,
    )
    assert user_message is not None, "no user message was sent to the model"
    assert user_message["images"] == ["ZmFrZS1pbWFnZQ=="]
    # With an empty prompt and an attached image, the model receives a default instruction.
    assert user_message["content"] == "Please analyze the attached image."
    # Presumably [-1] is the assistant reply and [-2] the stored user turn — confirm ordering.
    assert "[Attached 1 pasted image]" in memory.recent_conversation()[-2]["content"]
@pytest.mark.asyncio
async def test_chat_events_returns_fallback_after_slow_tool_and_empty_final_response(tmp_path):
    """An empty model reply after a tool call is replaced by fallback text quoting the tool result."""
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = SlowStreamingAgent(memory)
    events = [event async for event in agent.chat_events("run a slow tool")]
    streamed = "".join(evt.get("content", "") for evt in events if evt["type"] == "token")
    assert agent.tools.calls == 1
    for expected in ("I completed the tool call", "slow result"):
        assert expected in streamed
    assert events[-1]["type"] == "done"
@pytest.mark.asyncio
async def test_chat_events_returns_tool_result_when_model_fails_after_slow_tool(tmp_path):
    """If the model errors after a tool call, the stream still surfaces the tool result and ends."""
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = FailingAfterToolAgent(memory)
    events = [event async for event in agent.chat_events("run a slow tool")]
    streamed = "".join(evt.get("content", "") for evt in events if evt["type"] == "token")
    assert agent.tools.calls == 1
    for expected in ("local model stopped after the tool call", "slow result"):
        assert expected in streamed
    assert events[-1]["type"] == "done"