93 lines
3.0 KiB
Python
93 lines
3.0 KiB
Python
import pytest
|
|
|
|
from traderai.agent import OllamaAgent, SYSTEM_PROMPT
|
|
from traderai.memory import MemoryStore
|
|
|
|
|
|
class EmptyTools:
    """Minimal stand-in for the agent's tool registry that exposes no tools.

    It offers the two members these tests use: ``schemas`` (an empty list)
    and a read-only ``pending_actions`` property (an empty dict).
    """

    # Class attribute, so it is shared by every instance; nothing in this
    # stub mutates it.
    schemas = []

    @property
    def pending_actions(self):
        """Return a newly created, empty dict on every access."""
        no_actions = dict()
        return no_actions
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_events_warns_when_ollama_offline():
    """An unreachable Ollama endpoint yields a leading warning and a final "done" event."""
    # Port 1 on loopback is presumably chosen so the connection fails,
    # simulating Ollama being offline.
    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools())

    events = [event async for event in agent.chat_events("hello")]

    # Fail with a clear message instead of an IndexError if nothing was yielded.
    assert events, "chat_events yielded no events"
    assert events[0]["type"] == "warning"
    assert "Ollama is offline" in events[0]["message"]
    assert events[-1]["type"] == "done"
|
|
|
|
|
|
def test_runtime_context_uses_previous_interaction_not_current_message(tmp_path):
    """The runtime context reports the interaction recorded before the current message.

    ``previous`` is captured before the current user message is stored and
    ``current`` afterwards. The context built with
    ``previous_interaction=previous`` must mention the former's timestamp and
    not the latter's.
    """
    memory = MemoryStore(str(tmp_path / "memory.sqlite3"))
    memory.add_conversation("assistant", "older answer")
    previous = memory.last_interaction()
    assert previous is not None

    memory.add_conversation("user", "current question")
    current = memory.last_interaction()
    # Mirror the check on ``previous`` so a missing row fails clearly instead
    # of with a TypeError when subscripting None below.
    assert current is not None
    # The "not in context" assertion at the end can only distinguish the two
    # interactions if their timestamps differ.
    # NOTE(review): if MemoryStore records created_at at coarse resolution,
    # the two may collide; this makes that failure explicit.
    assert current["created_at"] != previous["created_at"], (
        "previous and current interactions share a created_at timestamp; "
        "the test cannot tell them apart"
    )

    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=memory)
    context = agent._runtime_context("current question", previous_interaction=previous)

    assert f"Previous interaction before this message: {previous['created_at']}" in context
    assert f"Previous interaction before this message: {current['created_at']}" not in context
    assert "Current local date/time:" in context
|
|
|
|
|
|
def test_runtime_context_includes_uex_user_identity(tmp_path):
    """The runtime context names the stored UEX user but leaves out their email."""
    profile = {
        "username": "pilot_hudson",
        "name": "Hudson",
        "email": "hudson@example.test",
        "timezone": "America/New_York",
        "specializations": "trading,hauling",
    }
    store = MemoryStore(str(tmp_path / "memory.sqlite3"))
    store.set_profile("uex_user", profile)

    agent = OllamaAgent("http://127.0.0.1:1", "missing-model", EmptyTools(), memory=store)
    context = agent._runtime_context("")

    expected_fragments = (
        "You are speaking with UEX user pilot_hudson (Hudson).",
        "timezone: America/New_York",
        "specializations: trading,hauling",
    )
    for fragment in expected_fragments:
        assert fragment in context
    # The email is in the stored profile, but the context must not expose it.
    assert "hudson@example.test" not in context
|
|
|
|
|
|
def test_stream_metrics_include_reading_and_writing_rates():
    """Token counts and tokens/second are derived from prompt/eval counts and durations.

    The durations below are 2 s and 3 s expressed in nanoseconds, so both
    rates come out to 10 tokens per second.
    """
    one_second_ns = 1_000_000_000
    stats = {
        "prompt_eval_count": 20,
        "prompt_eval_duration": 2 * one_second_ns,
        "eval_count": 30,
        "eval_duration": 3 * one_second_ns,
    }
    metrics = OllamaAgent._stream_metrics(stats)

    # reading_* values come from prompt_eval_*, writing_* values from eval_*.
    assert metrics["reading_tokens"] == 20
    assert metrics["reading_tokens_per_second"] == 10
    assert metrics["writing_tokens"] == 30
    assert metrics["writing_tokens_per_second"] == 10
|
|
|
|
|
|
def test_system_prompt_prefers_current_marketplace_data():
    """The system prompt asks for current marketplace data and aUEC/UEC (not real-world) pricing."""
    required_phrases = [
        "open/current",
        "Do not use historical sale data",
        "aUEC/UEC credits",
        "never real-world dollars",
    ]
    for phrase in required_phrases:
        assert phrase in SYSTEM_PROMPT
|
|
|
|
|
|
def test_ollama_options_include_num_ctx():
    """A ``num_ctx`` given to the agent is the only entry in its Ollama options."""
    context_window = 64000
    agent = OllamaAgent(
        "http://127.0.0.1:1",
        "missing-model",
        EmptyTools(),
        num_ctx=context_window,
    )
    expected = {"num_ctx": context_window}
    assert agent._ollama_options() == expected
|