feat: chat sidebar and inbox, feat: saved chats, fix: wake jobs, fix: sandbox sends, ux: negotiation replies and draft box
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
from traderai.agent import OllamaAgent, SYSTEM_PROMPT
|
||||
from traderai.memory import MemoryStore
|
||||
@@ -12,6 +13,117 @@ class EmptyTools:
|
||||
return {}
|
||||
|
||||
|
||||
class WakeTools(EmptyTools):
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def execute(self, name, arguments):
|
||||
self.calls.append((name, arguments))
|
||||
return {"count": 1, "notifications": [{"message": "Buyer replied"}]}
|
||||
|
||||
|
||||
class WakeAgent(OllamaAgent):
|
||||
def __init__(self, memory):
|
||||
super().__init__("http://127.0.0.1:1", "missing-model", WakeTools(), memory=memory)
|
||||
self.responses = [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "check_uex_notifications",
|
||||
"arguments": {},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
{"message": {"role": "assistant", "content": "I checked notifications: Buyer replied."}},
|
||||
]
|
||||
|
||||
async def ensure_available(self):
|
||||
return None
|
||||
|
||||
async def _ollama_chat(self, *args, **kwargs):
|
||||
return self.responses.pop(0)
|
||||
|
||||
|
||||
class TitleAgent(OllamaAgent):
|
||||
def __init__(self, memory):
|
||||
super().__init__("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)
|
||||
|
||||
async def ensure_available(self):
|
||||
return None
|
||||
|
||||
async def _generate_chat_title(self, first_message):
|
||||
return "UEX Market Check"
|
||||
|
||||
async def _ollama_chat(self, *args, **kwargs):
|
||||
return {"message": {"role": "assistant", "content": "Done"}}
|
||||
|
||||
|
||||
class SlowToolTools(EmptyTools):
|
||||
schemas = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "slow_tool",
|
||||
"description": "Slow fake tool.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def execute(self, name, arguments):
|
||||
self.calls += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return {"status": "ok", "value": "slow result"}
|
||||
|
||||
|
||||
class SlowStreamingAgent(OllamaAgent):
|
||||
def __init__(self, memory):
|
||||
super().__init__("http://127.0.0.1:1", "missing-model", SlowToolTools(), memory=memory)
|
||||
self.stream_calls = 0
|
||||
|
||||
async def health(self):
|
||||
return {"online": True, "model": "test", "base_url": self.base_url}
|
||||
|
||||
async def _ollama_chat_stream(self, *args, **kwargs):
|
||||
self.stream_calls += 1
|
||||
if self.stream_calls == 1:
|
||||
yield {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [{"function": {"name": "slow_tool", "arguments": {}}}],
|
||||
},
|
||||
"done": True,
|
||||
}
|
||||
return
|
||||
yield {"message": {"role": "assistant", "content": ""}, "done": True}
|
||||
|
||||
|
||||
class FailingAfterToolAgent(SlowStreamingAgent):
|
||||
async def _ollama_chat_stream(self, *args, **kwargs):
|
||||
self.stream_calls += 1
|
||||
if self.stream_calls == 1:
|
||||
yield {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [{"function": {"name": "slow_tool", "arguments": {}}}],
|
||||
},
|
||||
"done": True,
|
||||
}
|
||||
return
|
||||
raise RuntimeError("ollama timed out")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_events_warns_when_ollama_offline():
|
||||
agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools())
|
||||
@@ -90,3 +202,56 @@ def test_ollama_options_include_num_ctx():
|
||||
agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), num_ctx=64000)
|
||||
|
||||
assert agent._ollama_options() == {"num_ctx": 64000}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wake_response_executes_tool_calls(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
agent = WakeAgent(memory)
|
||||
|
||||
response = await agent.generate_wake_response("Scheduled wake job fired. Check notifications.")
|
||||
|
||||
assert response == "I checked notifications: Buyer replied."
|
||||
assert agent.tools.calls == [("check_uex_notifications", {})]
|
||||
wake_rows = memory.recent_conversation(thread_id="wake")
|
||||
assert wake_rows[-1]["content"] == "I checked notifications: Buyer replied."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_chat_message_generates_thread_title(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
thread = memory.create_thread()
|
||||
agent = TitleAgent(memory)
|
||||
|
||||
result = await agent.chat("Check UEX market listings", thread_id=thread["id"])
|
||||
|
||||
assert result["message"] == "Done"
|
||||
assert memory.get_thread(thread["id"])["title"] == "UEX Market Check"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_events_returns_fallback_after_slow_tool_and_empty_final_response(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
agent = SlowStreamingAgent(memory)
|
||||
|
||||
events = [event async for event in agent.chat_events("run a slow tool")]
|
||||
text = "".join(event.get("content", "") for event in events if event["type"] == "token")
|
||||
|
||||
assert agent.tools.calls == 1
|
||||
assert "I completed the tool call" in text
|
||||
assert "slow result" in text
|
||||
assert events[-1]["type"] == "done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_events_returns_tool_result_when_model_fails_after_slow_tool(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
agent = FailingAfterToolAgent(memory)
|
||||
|
||||
events = [event async for event in agent.chat_events("run a slow tool")]
|
||||
text = "".join(event.get("content", "") for event in events if event["type"] == "token")
|
||||
|
||||
assert agent.tools.calls == 1
|
||||
assert "local model stopped after the tool call" in text
|
||||
assert "slow result" in text
|
||||
assert events[-1]["type"] == "done"
|
||||
|
||||
@@ -25,3 +25,33 @@ def test_memory_store_clear_selected_sections(tmp_path):
|
||||
assert snapshot["memories"] == []
|
||||
assert snapshot["conversations"] == []
|
||||
assert snapshot["profile"][0]["key"] == "configured_name"
|
||||
|
||||
|
||||
def test_memory_store_separates_chat_threads_but_keeps_shared_memories(tmp_path):
|
||||
store = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
first = store.create_thread("First")
|
||||
second = store.create_thread("Second")
|
||||
store.add_conversation("user", "first thread message", first["id"])
|
||||
store.add_conversation("user", "second thread message", second["id"])
|
||||
store.remember("preference", "Shared trading preference", importance=5)
|
||||
|
||||
first_rows = store.recent_conversation(thread_id=first["id"])
|
||||
second_rows = store.recent_conversation(thread_id=second["id"])
|
||||
|
||||
assert [row["content"] for row in first_rows] == ["first thread message"]
|
||||
assert [row["content"] for row in second_rows] == ["second thread message"]
|
||||
assert store.recall("trading preference")[0]["content"] == "Shared trading preference"
|
||||
|
||||
|
||||
def test_memory_store_renames_threads_and_deletes_outbox_items(tmp_path):
|
||||
store = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
thread = store.create_thread("New chat")
|
||||
store.add_outbox("Wake job result")
|
||||
inbox_id = store.list_outbox()[0]["id"]
|
||||
|
||||
renamed = store.rename_thread(thread["id"], " Market Check ")
|
||||
deleted = store.delete_outbox(inbox_id)
|
||||
|
||||
assert renamed["title"] == "Market Check"
|
||||
assert deleted is True
|
||||
assert store.list_outbox() == []
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from traderai.agent import OllamaAgent
|
||||
from traderai.memory import MemoryStore
|
||||
from traderai.scheduler import WakeScheduler
|
||||
|
||||
@@ -31,6 +32,102 @@ class FakeUEXNotifications:
|
||||
}
|
||||
|
||||
|
||||
class FailingUEXNotifications:
|
||||
async def get_user_notifications(self):
|
||||
raise RuntimeError("bad token")
|
||||
|
||||
|
||||
class FakeWakeAgent:
|
||||
async def generate_wake_response(self, wake_message):
|
||||
return f"Wake output: {wake_message}"
|
||||
|
||||
|
||||
class ListingWakeTools:
|
||||
schemas = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_marketplace_listings",
|
||||
"description": "Search active UEX marketplace listings.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
self.pending_actions = {}
|
||||
|
||||
async def execute(self, name, arguments):
|
||||
self.calls.append((name, arguments))
|
||||
return {
|
||||
"count": 2,
|
||||
"listings": [
|
||||
{
|
||||
"id": 100,
|
||||
"title": "Wikelo Favor",
|
||||
"operation": "sell",
|
||||
"price": 500_000_000,
|
||||
"currency": "UEC",
|
||||
"in_stock": 9,
|
||||
"advertiser": "pilot_a",
|
||||
},
|
||||
{
|
||||
"id": 101,
|
||||
"title": "Wikelo Favor stack",
|
||||
"operation": "sell",
|
||||
"price": 1_000_000_000,
|
||||
"currency": "UEC",
|
||||
"in_stock": 5,
|
||||
"advertiser": "pilot_b",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class ListingWakeAgent(OllamaAgent):
|
||||
def __init__(self, memory):
|
||||
self.listing_tools = ListingWakeTools()
|
||||
super().__init__("http://127.0.0.1:1", "missing-model", self.listing_tools, memory=memory)
|
||||
self.responses = [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "search_marketplace_listings",
|
||||
"arguments": {
|
||||
"query": "Wikelo Favor",
|
||||
"operation": "sell",
|
||||
"limit": 5,
|
||||
},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": (
|
||||
"Listing check complete: found 2 active Wikelo Favor sell listings. "
|
||||
"Cheapest listing is 500,000,000 UEC with 9 in stock; the next listing is "
|
||||
"1,000,000,000 UEC. Suggested next action: price near 500,000,000 UEC "
|
||||
"if you want to move yours quickly."
|
||||
),
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
async def ensure_available(self):
|
||||
return None
|
||||
|
||||
async def _ollama_chat(self, *args, **kwargs):
|
||||
return self.responses.pop(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_uex_notifications_adds_unread_once(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
@@ -45,3 +142,52 @@ async def test_poll_uex_notifications_adds_unread_once(tmp_path):
|
||||
assert second == []
|
||||
assert len(outbox) == 1
|
||||
assert "A buyer replied to your listing." in outbox[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_uex_notifications_reports_failures_to_outbox(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
scheduler = WakeScheduler(memory)
|
||||
scheduler.bind_uex_notifications(FailingUEXNotifications())
|
||||
|
||||
result = await scheduler.poll_uex_notifications()
|
||||
|
||||
assert result == []
|
||||
assert "bad token" in memory.inspect()["outbox"][0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wake_job_writes_agent_output_to_outbox_and_disables_one_shot(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
scheduler = WakeScheduler(memory)
|
||||
scheduler.bind_agent(FakeWakeAgent())
|
||||
memory.add_job("wake-test", "check notifications", "date", "2099-01-01T00:00:00+00:00")
|
||||
|
||||
await scheduler._run_job("wake-test", "check notifications")
|
||||
snapshot = memory.inspect()
|
||||
|
||||
assert "Wake output:" in snapshot["outbox"][0]["content"]
|
||||
assert snapshot["scheduled_jobs"][0]["enabled"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wake_job_checks_listings_and_writes_analysis_to_outbox(tmp_path):
|
||||
memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
|
||||
scheduler = WakeScheduler(memory)
|
||||
agent = ListingWakeAgent(memory)
|
||||
scheduler.bind_agent(agent)
|
||||
memory.add_job("wake-listings", "check Wikelo Favor listings and analyze the market", "date", "2099-01-01T00:00:00+00:00")
|
||||
|
||||
await scheduler._run_job("wake-listings", "check Wikelo Favor listings and analyze the market")
|
||||
snapshot = memory.inspect()
|
||||
content = snapshot["outbox"][0]["content"]
|
||||
|
||||
assert agent.listing_tools.calls == [
|
||||
(
|
||||
"search_marketplace_listings",
|
||||
{"query": "Wikelo Favor", "operation": "sell", "limit": 5},
|
||||
)
|
||||
]
|
||||
assert "Listing check complete" in content
|
||||
assert "500,000,000 UEC" in content
|
||||
assert "Suggested next action" in content
|
||||
|
||||
@@ -7,6 +7,9 @@ from traderai.uex_client import UEXClient
|
||||
|
||||
|
||||
class FakeUEX:
|
||||
def __init__(self):
|
||||
self.posts = []
|
||||
|
||||
async def get(self, path, params=None, authenticated=False):
|
||||
if path == "commodities_prices_history":
|
||||
return {
|
||||
@@ -113,6 +116,10 @@ class FakeUEX:
|
||||
async def delete(self, path, params=None, authenticated=True):
|
||||
return {"status": "ok", "deleted": {"path": path, "params": params, "authenticated": authenticated}}
|
||||
|
||||
async def post(self, path, payload, authenticated=True):
|
||||
self.posts.append({"path": path, "payload": payload, "authenticated": authenticated})
|
||||
return {"status": "ok", "posted": self.posts[-1]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_marketplace_listings_filters_locally():
|
||||
@@ -145,6 +152,19 @@ async def test_decline_pending_action_removes_without_sending():
|
||||
assert action_id not in registry.pending_actions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_negotiation_message_forces_production_send():
|
||||
fake = FakeUEX()
|
||||
registry = ToolRegistry(fake)
|
||||
result = await registry.draft_negotiation_message(hash="abc", message="Ready to close", is_production=0)
|
||||
action_id = result["pending_action"]["id"]
|
||||
|
||||
approved = await registry.approve(action_id)
|
||||
|
||||
assert approved["posted"]["path"] == "marketplace_negotiations_messages"
|
||||
assert approved["posted"]["payload"]["is_production"] == 1
|
||||
|
||||
|
||||
def test_uex_client_uses_bearer_and_secret_headers():
|
||||
client = UEXClient("https://api.uexcorp.space/2.0", secret_key="secret", bearer_token="bearer")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user