feat: chat sidebar and inbox, feat: saved chats, fix: wake jobs, fix: sandbox sends, ux: negotiation replies and draft box
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
from traderai.agent import OllamaAgent, SYSTEM_PROMPT
|
||||
from traderai.memory import MemoryStore
|
||||
@@ -12,6 +13,117 @@ class EmptyTools:
|
||||
return {}
|
||||
|
||||
|
||||
class WakeTools(EmptyTools):
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def execute(self, name, arguments):
|
||||
self.calls.append((name, arguments))
|
||||
return {"count": 1, "notifications": [{"message": "Buyer replied"}]}
|
||||
|
||||
|
||||
class WakeAgent(OllamaAgent):
|
||||
def __init__(self, memory):
|
||||
super().__init__("http://127.0.0.1:1", "missing-model", WakeTools(), memory=memory)
|
||||
self.responses = [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "check_uex_notifications",
|
||||
"arguments": {},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
{"message": {"role": "assistant", "content": "I checked notifications: Buyer replied."}},
|
||||
]
|
||||
|
||||
async def ensure_available(self):
|
||||
return None
|
||||
|
||||
async def _ollama_chat(self, *args, **kwargs):
|
||||
return self.responses.pop(0)
|
||||
|
||||
|
||||
class TitleAgent(OllamaAgent):
|
||||
def __init__(self, memory):
|
||||
super().__init__("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)
|
||||
|
||||
async def ensure_available(self):
|
||||
return None
|
||||
|
||||
async def _generate_chat_title(self, first_message):
|
||||
return "UEX Market Check"
|
||||
|
||||
async def _ollama_chat(self, *args, **kwargs):
|
||||
return {"message": {"role": "assistant", "content": "Done"}}
|
||||
|
||||
|
||||
class SlowToolTools(EmptyTools):
|
||||
schemas = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "slow_tool",
|
||||
"description": "Slow fake tool.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def execute(self, name, arguments):
|
||||
self.calls += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return {"status": "ok", "value": "slow result"}
|
||||
|
||||
|
||||
class SlowStreamingAgent(OllamaAgent):
|
||||
def __init__(self, memory):
|
||||
super().__init__("http://127.0.0.1:1", "missing-model", SlowToolTools(), memory=memory)
|
||||
self.stream_calls = 0
|
||||
|
||||
async def health(self):
|
||||
return {"online": True, "model": "test", "base_url": self.base_url}
|
||||
|
||||
async def _ollama_chat_stream(self, *args, **kwargs):
|
||||
self.stream_calls += 1
|
||||
if self.stream_calls == 1:
|
||||
yield {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [{"function": {"name": "slow_tool", "arguments": {}}}],
|
||||
},
|
||||
"done": True,
|
||||
}
|
||||
return
|
||||
yield {"message": {"role": "assistant", "content": ""}, "done": True}
|
||||
|
||||
|
||||
class FailingAfterToolAgent(SlowStreamingAgent):
|
||||
async def _ollama_chat_stream(self, *args, **kwargs):
|
||||
self.stream_calls += 1
|
||||
if self.stream_calls == 1:
|
||||
yield {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [{"function": {"name": "slow_tool", "arguments": {}}}],
|
||||
},
|
||||
"done": True,
|
||||
}
|
||||
return
|
||||
raise RuntimeError("ollama timed out")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_events_warns_when_ollama_offline():
|
||||
agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools())
|
||||
@@ -90,3 +202,56 @@ def test_ollama_options_include_num_ctx():
|
||||
agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), num_ctx=64000)
|
||||
|
||||
assert agent._ollama_options() == {"num_ctx": 64000}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wake_response_executes_tool_calls(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
agent = WakeAgent(memory)
|
||||
|
||||
response = await agent.generate_wake_response("Scheduled wake job fired. Check notifications.")
|
||||
|
||||
assert response == "I checked notifications: Buyer replied."
|
||||
assert agent.tools.calls == [("check_uex_notifications", {})]
|
||||
wake_rows = memory.recent_conversation(thread_id="wake")
|
||||
assert wake_rows[-1]["content"] == "I checked notifications: Buyer replied."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_chat_message_generates_thread_title(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
thread = memory.create_thread()
|
||||
agent = TitleAgent(memory)
|
||||
|
||||
result = await agent.chat("Check UEX market listings", thread_id=thread["id"])
|
||||
|
||||
assert result["message"] == "Done"
|
||||
assert memory.get_thread(thread["id"])["title"] == "UEX Market Check"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_events_returns_fallback_after_slow_tool_and_empty_final_response(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
agent = SlowStreamingAgent(memory)
|
||||
|
||||
events = [event async for event in agent.chat_events("run a slow tool")]
|
||||
text = "".join(event.get("content", "") for event in events if event["type"] == "token")
|
||||
|
||||
assert agent.tools.calls == 1
|
||||
assert "I completed the tool call" in text
|
||||
assert "slow result" in text
|
||||
assert events[-1]["type"] == "done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_events_returns_tool_result_when_model_fails_after_slow_tool(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
agent = FailingAfterToolAgent(memory)
|
||||
|
||||
events = [event async for event in agent.chat_events("run a slow tool")]
|
||||
text = "".join(event.get("content", "") for event in events if event["type"] == "token")
|
||||
|
||||
assert agent.tools.calls == 1
|
||||
assert "local model stopped after the tool call" in text
|
||||
assert "slow result" in text
|
||||
assert events[-1]["type"] == "done"
|
||||
|
||||
Reference in New Issue
Block a user