288 lines
9.5 KiB
Python
288 lines
9.5 KiB
Python
import pytest
|
|
import asyncio
|
|
|
|
from traderai.agent import OllamaAgent, SYSTEM_PROMPT
|
|
from traderai.memory import MemoryStore
|
|
|
|
|
|
class EmptyTools:
    """Minimal tools stand-in that advertises no tool schemas.

    ``pending_actions`` evaluates to a brand-new empty dict on every access.
    """

    # Class-level list: shared by all instances and by any subclass that
    # does not override it.
    schemas = []

    @property
    def pending_actions(self):
        """Return a fresh, empty mapping."""
        return dict()
|
|
|
|
|
|
class WakeTools(EmptyTools):
    """Tools double that logs every ``execute`` call and answers with a canned payload."""

    def __init__(self):
        # Ordered log of the (name, arguments) pairs handed to execute().
        self.calls = []

    async def execute(self, name, arguments):
        """Record the invocation, then return a fixed one-notification result."""
        invocation = (name, arguments)
        self.calls.append(invocation)
        notification = {"message": "Buyer replied"}
        return {"count": 1, "notifications": [notification]}
|
|
|
|
|
|
class WakeAgent(OllamaAgent):
    """Agent double that replays two scripted chat responses in order.

    The first scripted response requests the ``check_uex_notifications`` tool
    with empty arguments; the second is a plain assistant message.
    """

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", WakeTools(), memory=memory)
        tool_call = {"function": {"name": "check_uex_notifications", "arguments": {}}}
        tool_request = {
            "message": {"role": "assistant", "content": "", "tool_calls": [tool_call]},
        }
        final_answer = {
            "message": {"role": "assistant", "content": "I checked notifications: Buyer replied."},
        }
        # Consumed front-to-back by _ollama_chat().
        self.responses = [tool_request, final_answer]

    async def ensure_available(self):
        """Do nothing and return ``None``."""
        return None

    async def _ollama_chat(self, *args, **kwargs):
        """Ignore all arguments and hand back the next scripted response."""
        next_response = self.responses.pop(0)
        return next_response
|
|
|
|
|
|
class TitleAgent(OllamaAgent):
    """Agent double with a fixed chat title and a fixed ``"Done"`` chat reply."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)

    async def ensure_available(self):
        """Do nothing and return ``None``."""
        return None

    async def _generate_chat_title(self, first_message):
        """Return the same title regardless of ``first_message``."""
        title = "UEX Market Check"
        return title

    async def _ollama_chat(self, *args, **kwargs):
        """Ignore all arguments and return a constant assistant message."""
        reply = {"role": "assistant", "content": "Done"}
        return {"message": reply}
|
|
|
|
|
|
class ImageCaptureAgent(OllamaAgent):
    """Agent double that remembers the ``messages`` passed to ``_chat_once``."""

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)
        # Holds the `messages` argument of the most recent _chat_once() call.
        self.last_messages = None

    async def ensure_available(self):
        """Do nothing and return ``None``."""
        return None

    async def _chat_once(self, query="", messages=None, **kwargs):
        """Store ``messages`` for later inspection and reply with ``"Seen"``."""
        self.last_messages = messages
        reply = {"role": "assistant", "content": "Seen"}
        return {"message": reply}
|
|
|
|
|
|
class SlowToolTools(EmptyTools):
    """Tools double exposing a single ``slow_tool`` schema.

    ``execute`` counts its invocations, awaits a 0.01 s sleep, and returns a
    fixed result.
    """

    schemas = [
        {
            "type": "function",
            "function": {
                "name": "slow_tool",
                "description": "Slow fake tool.",
                "parameters": {"type": "object", "properties": {}},
            },
        },
    ]

    def __init__(self):
        # Number of times execute() has been entered.
        self.calls = 0

    async def execute(self, name, arguments):
        """Bump the call counter, pause briefly, and return the canned result."""
        self.calls = self.calls + 1
        await asyncio.sleep(0.01)
        outcome = {"status": "ok", "value": "slow result"}
        return outcome
|
|
|
|
|
|
class SlowStreamingAgent(OllamaAgent):
    """Streaming agent double: one ``slow_tool`` request, then empty replies.

    The first ``_ollama_chat_stream`` call yields a single chunk asking for
    ``slow_tool``; every later call yields a single chunk with empty content.
    """

    def __init__(self, memory):
        super().__init__("http://127.0.0.1:1", "missing-model", SlowToolTools(), memory=memory)
        # Number of times _ollama_chat_stream() has started iterating.
        self.stream_calls = 0

    async def health(self):
        """Report the agent as online with model name ``"test"``."""
        status = {"online": True, "model": "test"}
        status["base_url"] = self.base_url
        return status

    async def _ollama_chat_stream(self, *args, **kwargs):
        """Yield exactly one final chunk per call, chosen by the call count."""
        self.stream_calls += 1
        if self.stream_calls != 1:
            # Any call other than the first: an empty assistant message.
            yield {"message": {"role": "assistant", "content": ""}, "done": True}
        else:
            slow_tool_call = {"function": {"name": "slow_tool", "arguments": {}}}
            yield {
                "message": {"role": "assistant", "content": "", "tool_calls": [slow_tool_call]},
                "done": True,
            }
|
|
|
|
|
|
class FailingAfterToolAgent(SlowStreamingAgent):
    """Streaming agent double that requests ``slow_tool`` once, then raises."""

    async def _ollama_chat_stream(self, *args, **kwargs):
        """Yield a ``slow_tool`` request on the first call; raise on any other.

        Raises:
            RuntimeError: with message ``"ollama timed out"`` whenever the
                call count after incrementing is not 1.
        """
        self.stream_calls += 1
        if self.stream_calls != 1:
            raise RuntimeError("ollama timed out")
        slow_tool_call = {"function": {"name": "slow_tool", "arguments": {}}}
        # The yield below is what makes this an async generator, so the raise
        # above only fires once the caller starts iterating.
        yield {
            "message": {"role": "assistant", "content": "", "tool_calls": [slow_tool_call]},
            "done": True,
        }
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_events_warns_when_ollama_offline():
    """Expect a leading "Ollama is offline" warning and a trailing done event.

    The agent is pointed at http://127.0.0.1:1 with model "missing-model".
    """
    offline_agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools())

    emitted = [event async for event in offline_agent.chat_events("hello")]

    first_event = emitted[0]
    assert first_event["type"] == "warning"
    assert "Ollama is offline" in first_event["message"]
    last_event = emitted[-1]
    assert last_event["type"] == "done"
|
|
|
|
|
|
def test_runtime_context_uses_previous_interaction_not_current_message(tmp_path):
    """The runtime context reports the interaction passed in, not the newest row.

    Two conversation rows are stored. The interaction captured after the first
    row is handed to ``_runtime_context`` as ``previous_interaction``, and the
    context must quote that row's ``created_at`` rather than the timestamp of
    the row added afterwards.
    """
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    memory.add_conversation("assistant", "older answer")
    previous = memory.last_interaction()
    assert previous is not None

    memory.add_conversation("user", "current question")
    current = memory.last_interaction()
    # Same guard as for `previous`: fail with a clear assertion instead of a
    # TypeError when `current` is subscripted below.
    assert current is not None

    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)
    context = agent._runtime_context("current question", previous_interaction=previous)

    assert f"Previous interaction before this message: {previous['created_at']}" in context
    # NOTE(review): if both rows ever share an identical created_at value, this
    # negative check contradicts the positive one above -- confirm the
    # timestamp resolution used by MemoryStore.
    assert f"Previous interaction before this message: {current['created_at']}" not in context
    assert "Current local date/time:" in context
|
|
|
|
|
|
def test_runtime_context_includes_uex_user_identity(tmp_path):
    """The runtime context names the stored UEX user but leaves out the e-mail.

    Username, display name, timezone and specializations must appear in the
    context text; the e-mail address stored in the same profile must not.
    """
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    uex_profile = {
        "username": "pilot_hudson",
        "name": "Hudson",
        "email": "hudson@example.test",
        "timezone": "America/New_York",
        "specializations": "trading,hauling",
    }
    store.set_profile("uex_user", uex_profile)

    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=store)
    context = agent._runtime_context("")

    assert "You are speaking with UEX user pilot_hudson (Hudson)." in context
    assert "timezone: America/New_York" in context
    assert "specializations: trading,hauling" in context
    assert "hudson@example.test" not in context
|
|
|
|
|
|
def test_stream_metrics_include_reading_and_writing_rates():
    """``_stream_metrics`` reports token counts and per-second rates.

    With 20 prompt tokens over a duration of 2_000_000_000 and 30 eval tokens
    over 3_000_000_000, both rates are expected to equal 10.
    """
    final_chunk = {
        "prompt_eval_count": 20,
        "prompt_eval_duration": 2_000_000_000,
        "eval_count": 30,
        "eval_duration": 3_000_000_000,
    }

    metrics = OllamaAgent._stream_metrics(final_chunk)

    assert metrics["reading_tokens"] == 20
    assert metrics["reading_tokens_per_second"] == 10
    assert metrics["writing_tokens"] == 30
    assert metrics["writing_tokens_per_second"] == 10
|
|
|
|
|
|
def test_system_prompt_prefers_current_marketplace_data():
    """``SYSTEM_PROMPT`` must contain each of the required guidance phrases."""
    required_phrases = (
        "open/current",
        "Do not use historical sale data",
        "aUEC/UEC credits",
        "never real-world dollars",
    )
    # Checked one at a time, in the listed order, so the first missing phrase
    # is the one reported.
    for phrase in required_phrases:
        assert phrase in SYSTEM_PROMPT
|
|
|
|
|
|
def test_ollama_options_include_num_ctx():
    """An agent built with ``num_ctx=64000`` reports exactly that option."""
    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), num_ctx=64000)

    options = agent._ollama_options()

    assert options == {"num_ctx": 64000}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_wake_response_executes_tool_calls(tmp_path):
    """A wake response runs the scripted tool call and stores the final answer.

    Checks the returned text, the single recorded ``check_uex_notifications``
    call, and the last row of the ``"wake"`` thread.
    """
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = WakeAgent(store)
    expected_answer = "I checked notifications: Buyer replied."

    answer = await agent.generate_wake_response("Scheduled wake job fired. Check notifications.")

    assert answer == expected_answer
    assert agent.tools.calls == [("check_uex_notifications", {})]
    wake_rows = store.recent_conversation(thread_id="wake")
    newest_row = wake_rows[-1]
    assert newest_row["content"] == expected_answer
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_first_chat_message_generates_thread_title(tmp_path):
    """Chatting in a newly created thread gives that thread the generated title."""
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    thread = store.create_thread()
    agent = TitleAgent(store)

    reply = await agent.chat("Check UEX market listings", thread_id=thread["id"])

    assert reply["message"] == "Done"
    stored_thread = store.get_thread(thread["id"])
    assert stored_thread["title"] == "UEX Market Check"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_includes_pasted_images_and_memory_note(tmp_path):
    """An image-only chat forwards the image data and records an attachment note.

    The empty user text is expected to become "Please analyze the attached
    image." in the messages handed to ``_chat_once``, and the second-to-last
    stored conversation row should mention the pasted image.
    """
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = ImageCaptureAgent(store)
    pasted_image = {"name": "listing.png", "content_type": "image/png", "image_data": "ZmFrZS1pbWFnZQ=="}

    result = await agent.chat("", images=[pasted_image])

    assert result["message"] == "Seen"
    # Newest-first scan for the user message that was sent to the model.
    user_candidates = (entry for entry in reversed(agent.last_messages) if entry.get("role") == "user")
    user_message = next(user_candidates)
    assert user_message["images"] == ["ZmFrZS1pbWFnZQ=="]
    assert user_message["content"] == "Please analyze the attached image."
    stored_rows = store.recent_conversation()
    assert "[Attached 1 pasted image]" in stored_rows[-2]["content"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_events_returns_fallback_after_slow_tool_and_empty_final_response(tmp_path):
    """An empty model reply after a tool call still produces fallback text.

    The streamed tokens must mention the completed tool call and include the
    tool's "slow result" value, the tool must have run exactly once, and the
    stream must end with a done event.
    """
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = SlowStreamingAgent(store)

    events = [event async for event in agent.chat_events("run a slow tool")]
    token_parts = [event.get("content", "") for event in events if event["type"] == "token"]
    streamed_text = "".join(token_parts)

    assert agent.tools.calls == 1
    assert "I completed the tool call" in streamed_text
    assert "slow result" in streamed_text
    assert events[-1]["type"] == "done"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_events_returns_tool_result_when_model_fails_after_slow_tool(tmp_path):
    """A model failure after the tool call still surfaces the tool result.

    The streamed tokens must say the local model stopped after the tool call
    and include the tool's "slow result" value, the tool must have run exactly
    once, and the stream must end with a done event.
    """
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    agent = FailingAfterToolAgent(store)

    events = [event async for event in agent.chat_events("run a slow tool")]
    token_parts = [event.get("content", "") for event in events if event["type"] == "token"]
    streamed_text = "".join(token_parts)

    assert agent.tools.calls == 1
    assert "local model stopped after the tool call" in streamed_text
    assert "slow result" in streamed_text
    assert events[-1]["type"] == "done"
|