Inital Commit
This commit is contained in:
@@ -0,0 +1,344 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from tzlocal import get_localzone
|
||||
|
||||
from traderai.memory import MemoryStore, iso_now, iso_now_in_zone, time_since
|
||||
from traderai.tools import ToolRegistry
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """You are TraderAI, a local assistant for UEX marketplace work.
|
||||
Use tools when the user asks about listings, negotiations, messages, offers, or posting ads.
|
||||
For marketplace writes, draft the exact pending action and tell the user what will be sent; never claim it was sent until approval succeeds.
|
||||
Keep prices, listing ids, slugs, users, and UEX status codes precise. If data is missing, say what you need next."""
|
||||
|
||||
|
||||
class OllamaAgent:
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
model: str,
|
||||
tools: ToolRegistry,
|
||||
memory: MemoryStore | None = None,
|
||||
user_name: str | None = None,
|
||||
) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model = model
|
||||
self.tools = tools
|
||||
self.memory = memory
|
||||
self.user_name = user_name
|
||||
self.messages: list[dict[str, Any]] = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
|
||||
async def health(self) -> dict[str, Any]:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=3) as client:
|
||||
response = await client.get(f"{self.base_url}/api/tags")
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
except (httpx.HTTPError, ValueError) as exc:
|
||||
return {
|
||||
"online": False,
|
||||
"model": self.model,
|
||||
"base_url": self.base_url,
|
||||
"message": f"Ollama is offline or unreachable at {self.base_url}. Start Ollama and make sure the model is pulled.",
|
||||
"detail": str(exc),
|
||||
}
|
||||
|
||||
models = [model.get("name") or model.get("model") for model in body.get("models", [])]
|
||||
return {
|
||||
"online": True,
|
||||
"model": self.model,
|
||||
"base_url": self.base_url,
|
||||
"model_available": self.model in models,
|
||||
"models": models,
|
||||
"message": "Ollama is online.",
|
||||
}
|
||||
|
||||
async def ensure_available(self) -> None:
|
||||
health = await self.health()
|
||||
if not health["online"]:
|
||||
raise OllamaUnavailable(health["message"])
|
||||
|
||||
async def chat(self, content: str) -> dict[str, Any]:
|
||||
await self.ensure_available()
|
||||
previous_interaction = self.memory.last_interaction() if self.memory else None
|
||||
if self.memory:
|
||||
self.memory.add_conversation("user", content)
|
||||
self.messages.append({"role": "user", "content": content})
|
||||
for _ in range(5):
|
||||
response = await self._ollama_chat(content, previous_interaction=previous_interaction)
|
||||
message = response.get("message") or {}
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if not tool_calls:
|
||||
self.messages.append({"role": "assistant", "content": message.get("content", "")})
|
||||
if self.memory:
|
||||
self.memory.add_conversation("assistant", message.get("content", ""))
|
||||
return {"message": message.get("content", ""), "pending_actions": self._pending_payloads()}
|
||||
|
||||
self.messages.append(message)
|
||||
for call in tool_calls:
|
||||
name, arguments = self._extract_call(call)
|
||||
result = await self.tools.execute(name, arguments)
|
||||
self.messages.append({"role": "tool", "tool_name": name, "content": json.dumps(result)})
|
||||
|
||||
fallback = "I hit the tool-call limit while working on that. Try narrowing the request or approve any pending action first."
|
||||
self.messages.append({"role": "assistant", "content": fallback})
|
||||
if self.memory:
|
||||
self.memory.add_conversation("assistant", fallback)
|
||||
return {"message": fallback, "pending_actions": self._pending_payloads()}
|
||||
|
||||
async def chat_events(self, content: str) -> AsyncIterator[dict[str, Any]]:
|
||||
health = await self.health()
|
||||
if not health["online"]:
|
||||
yield {"type": "warning", "message": health["message"]}
|
||||
yield {"type": "done", "pending_actions": self._pending_payloads()}
|
||||
return
|
||||
|
||||
previous_interaction = self.memory.last_interaction() if self.memory else None
|
||||
if self.memory:
|
||||
self.memory.add_conversation("user", content)
|
||||
self.messages.append({"role": "user", "content": content})
|
||||
yield {"type": "status", "message": "Thinking"}
|
||||
|
||||
for _ in range(5):
|
||||
assistant_message: dict[str, Any] = {"role": "assistant", "content": ""}
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
async for event in self._ollama_chat_stream(content, previous_interaction=previous_interaction):
|
||||
message = event.get("message") or {}
|
||||
chunk = message.get("content") or ""
|
||||
if chunk:
|
||||
assistant_message["content"] += chunk
|
||||
yield {"type": "token", "content": chunk}
|
||||
if message.get("tool_calls"):
|
||||
tool_calls.extend(message["tool_calls"])
|
||||
|
||||
if not tool_calls:
|
||||
self.messages.append(assistant_message)
|
||||
if self.memory:
|
||||
self.memory.add_conversation("assistant", assistant_message.get("content", ""))
|
||||
yield {"type": "done", "pending_actions": self._pending_payloads()}
|
||||
return
|
||||
|
||||
assistant_message["tool_calls"] = tool_calls
|
||||
self.messages.append(assistant_message)
|
||||
for call in tool_calls:
|
||||
name, arguments = self._extract_call(call)
|
||||
yield {"type": "status", "message": self._tool_status(name)}
|
||||
result = await self.tools.execute(name, arguments)
|
||||
self.messages.append({"role": "tool", "tool_name": name, "content": json.dumps(result)})
|
||||
|
||||
yield {"type": "status", "message": "Writing response"}
|
||||
|
||||
fallback = "I hit the tool-call limit while working on that. Try narrowing the request or approve any pending action first."
|
||||
self.messages.append({"role": "assistant", "content": fallback})
|
||||
if self.memory:
|
||||
self.memory.add_conversation("assistant", fallback)
|
||||
yield {"type": "token", "content": fallback}
|
||||
yield {"type": "done", "pending_actions": self._pending_payloads()}
|
||||
|
||||
async def generate_wake_response(self, wake_message: str) -> str:
|
||||
await self.ensure_available()
|
||||
self.messages.append({"role": "user", "content": wake_message})
|
||||
response = await self._ollama_chat(wake_message)
|
||||
message = response.get("message") or {}
|
||||
content = message.get("content", "")
|
||||
self.messages.append({"role": "assistant", "content": content})
|
||||
if self.memory:
|
||||
self.memory.add_conversation("system", wake_message)
|
||||
self.memory.add_conversation("assistant", content)
|
||||
return content or wake_message
|
||||
|
||||
async def _ollama_chat(self, query: str = "", previous_interaction: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/chat",
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": self._messages_with_context(query, previous_interaction=previous_interaction),
|
||||
"tools": self.tools.schemas,
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _ollama_chat_stream(
|
||||
self,
|
||||
query: str = "",
|
||||
previous_interaction: dict[str, Any] | None = None,
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/api/chat",
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": self._messages_with_context(query, previous_interaction=previous_interaction),
|
||||
"tools": self.tools.schemas,
|
||||
"stream": True,
|
||||
},
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line:
|
||||
yield json.loads(line)
|
||||
|
||||
def _messages_with_context(
|
||||
self,
|
||||
query: str,
|
||||
previous_interaction: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
context = self._runtime_context(query, previous_interaction=previous_interaction)
|
||||
if not context:
|
||||
return self.messages
|
||||
return [self.messages[0], {"role": "system", "content": context}, *self.messages[1:]]
|
||||
|
||||
def _runtime_context(self, query: str, previous_interaction: dict[str, Any] | None = None) -> str:
|
||||
local_zone = get_localzone()
|
||||
parts = [
|
||||
f"Current local date/time: {iso_now()} UTC; {iso_now_in_zone(local_zone)} {local_zone}.",
|
||||
]
|
||||
if self.user_name:
|
||||
parts.append(f"Known user name/handle: {self.user_name}.")
|
||||
|
||||
if self.memory is None:
|
||||
return "\n".join(parts)
|
||||
|
||||
profile = self.memory.get_profile()
|
||||
if profile:
|
||||
identity = self._profile_identity(profile)
|
||||
if identity:
|
||||
parts.append(identity)
|
||||
parts.append(f"Known user profile JSON: {json.dumps(self._profile_for_prompt(profile), ensure_ascii=True)}.")
|
||||
|
||||
last = previous_interaction if previous_interaction is not None else self.memory.last_interaction()
|
||||
if last:
|
||||
parts.append(
|
||||
f"Previous interaction before this message: {last['created_at']} "
|
||||
f"({time_since(last['created_at'])}, role {last['role']})."
|
||||
)
|
||||
else:
|
||||
parts.append("Previous interaction before this message: none recorded.")
|
||||
|
||||
memories = self.memory.recall(query, limit=6)
|
||||
if memories:
|
||||
memory_text = "\n".join(
|
||||
f"- [{item['kind']}, importance {item['importance']}] {item['content']}"
|
||||
for item in memories
|
||||
)
|
||||
parts.append(f"Relevant long-term memories:\n{memory_text}")
|
||||
|
||||
recent = self.memory.recent_conversation(limit=6)
|
||||
if recent:
|
||||
recent_text = "\n".join(
|
||||
f"- {item['created_at']} {item['role']}: {item['content'][:500]}"
|
||||
for item in recent
|
||||
)
|
||||
parts.append(f"Recent conversation excerpts:\n{recent_text}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _pending_payloads(self) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"id": action.id,
|
||||
"label": action.label,
|
||||
"endpoint": action.endpoint,
|
||||
"payload": action.payload,
|
||||
}
|
||||
for action in self.tools.pending_actions.values()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _tool_status(name: str) -> str:
|
||||
labels = {
|
||||
"search_marketplace_listings": "Searching UEX listings",
|
||||
"get_marketplace_listing": "Fetching listing details",
|
||||
"list_marketplace_negotiations": "Checking negotiations",
|
||||
"get_negotiation_messages": "Reading negotiation messages",
|
||||
"draft_negotiation_message": "Drafting message for approval",
|
||||
"draft_marketplace_listing": "Drafting listing for approval",
|
||||
}
|
||||
return labels.get(name, f"Running {name}")
|
||||
|
||||
@staticmethod
|
||||
def _profile_identity(profile: dict[str, Any]) -> str:
|
||||
user = profile.get("uex_user")
|
||||
if not isinstance(user, dict):
|
||||
configured = profile.get("configured_name")
|
||||
return f"You are speaking with {configured}." if configured else ""
|
||||
|
||||
username = user.get("username") or user.get("user_username")
|
||||
name = user.get("name")
|
||||
fields = []
|
||||
if username and name and username != name:
|
||||
fields.append(f"You are speaking with UEX user {username} ({name}).")
|
||||
elif username or name:
|
||||
fields.append(f"You are speaking with UEX user {username or name}.")
|
||||
|
||||
details = []
|
||||
for key, label in [
|
||||
("timezone", "timezone"),
|
||||
("language", "preferred language"),
|
||||
("specializations", "specializations"),
|
||||
("languages", "languages"),
|
||||
("archetypes", "archetypes"),
|
||||
]:
|
||||
value = user.get(key)
|
||||
if value:
|
||||
details.append(f"{label}: {value}")
|
||||
if details:
|
||||
fields.append("UEX profile details: " + "; ".join(details) + ".")
|
||||
return " ".join(fields)
|
||||
|
||||
@staticmethod
|
||||
def _profile_for_prompt(profile: dict[str, Any]) -> dict[str, Any]:
|
||||
user = profile.get("uex_user")
|
||||
if not isinstance(user, dict):
|
||||
return profile
|
||||
|
||||
useful_user_fields = [
|
||||
"id",
|
||||
"name",
|
||||
"username",
|
||||
"avatar",
|
||||
"bio",
|
||||
"website_url",
|
||||
"timezone",
|
||||
"language",
|
||||
"day_availability",
|
||||
"time_availability",
|
||||
"specializations",
|
||||
"languages",
|
||||
"archetypes",
|
||||
"is_datarunner",
|
||||
"is_staff",
|
||||
"is_away_game",
|
||||
"date_rsi_verified",
|
||||
"date_twitch_verified",
|
||||
]
|
||||
prompt_profile = dict(profile)
|
||||
prompt_profile["uex_user"] = {
|
||||
key: user[key]
|
||||
for key in useful_user_fields
|
||||
if key in user and user[key] not in (None, "")
|
||||
}
|
||||
return prompt_profile
|
||||
|
||||
@staticmethod
|
||||
def _extract_call(call: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||
function = call.get("function") or {}
|
||||
name = function.get("name") or call.get("name")
|
||||
arguments = function.get("arguments") or call.get("arguments") or {}
|
||||
if isinstance(arguments, str):
|
||||
arguments = json.loads(arguments or "{}")
|
||||
return name, arguments
|
||||
|
||||
|
||||
class OllamaUnavailable(RuntimeError):
|
||||
pass
|
||||
Reference in New Issue
Block a user