# TraderAI/traderai/server.py — FastAPI application for the TraderAI web UI and HTTP API.
from __future__ import annotations

import json
import logging
from pathlib import Path

from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from traderai.agent import OllamaAgent, OllamaUnavailable
from traderai.config import get_settings
from traderai.memory import MemoryStore
from traderai.scheduler import WakeScheduler
from traderai.tools import ToolRegistry
from traderai.uex_client import UEXClient
class ChatRequest(BaseModel):
    """Request body for ``POST /api/chat`` and ``POST /api/chat/stream``."""

    # The user's message; the endpoints pass it unchanged to the agent.
    message: str
class ClearMemoryRequest(BaseModel):
    """Request body for ``POST /api/memory/clear``: which stores to wipe.

    Each flag is forwarded unchanged to ``MemoryStore.clear``. By default,
    memories, conversations and the notification outbox are cleared, while
    the user profile and scheduled jobs are preserved.
    """

    include_memories: bool = True
    include_conversations: bool = True
    # Off by default so the stored user profile survives a routine clear.
    include_profile: bool = False
    # Off by default; when True the endpoint also stops the scheduler around the clear.
    include_jobs: bool = False
    include_outbox: bool = True
def create_app() -> FastAPI:
    """Build the TraderAI FastAPI application and wire its collaborators together.

    Reads settings via ``get_settings()`` and constructs the memory store, the
    wake scheduler, the UEX client, the tool registry and the Ollama agent. The
    scheduler is bound to the agent and to UEX notification polling. The ``web``
    directory one level above this package is mounted at ``/static``, and the
    HTTP API routes are registered.

    Returns:
        The configured ``FastAPI`` application.
    """
    log = logging.getLogger(__name__)

    settings = get_settings()
    memory = MemoryStore(settings.traderai_memory_path)
    scheduler = WakeScheduler(memory)
    uex = UEXClient(settings.uex_base_url, settings.uex_secret_key, settings.uex_bearer_token)
    tools = ToolRegistry(uex, settings.require_write_approval, memory=memory, scheduler=scheduler)
    agent = OllamaAgent(
        settings.ollama_base_url,
        settings.ollama_model,
        tools,
        memory=memory,
        user_name=settings.traderai_user_name,
        num_ctx=settings.ollama_num_ctx,
    )
    scheduler.bind_agent(agent)
    scheduler.bind_uex_notifications(uex, settings.uex_notification_poll_seconds)

    app = FastAPI(title="TraderAI")
    # Front-end assets live in a ``web`` directory next to the traderai package.
    static_dir = Path(__file__).resolve().parent.parent / "web"
    app.mount("/static", StaticFiles(directory=static_dir), name="static")

    # NOTE(review): ``on_event`` hooks are deprecated in recent FastAPI releases
    # in favour of a ``lifespan`` handler; kept for compatibility with the
    # installed FastAPI version.
    @app.on_event("startup")
    async def startup() -> None:
        """Refresh the user profile, then start the wake scheduler."""
        await refresh_user_profile()
        scheduler.start()

    @app.on_event("shutdown")
    async def shutdown() -> None:
        """Stop the wake scheduler."""
        scheduler.shutdown()

    async def refresh_user_profile() -> None:
        """Best-effort sync of the user's identity into memory and the agent.

        Records the configured user name, if any. Then it asks UEX for the
        authenticated user; if that fails, it records the error and falls back
        to looking the user up by the configured name. Lookup failures never
        propagate. A found user payload is stored under the ``uex_user`` profile
        key and, when it carries a username, becomes the agent's ``user_name``.
        """
        configured_name = settings.traderai_user_name
        if configured_name:
            memory.set_profile("configured_name", configured_name)
            # Only fill in a name if the agent does not already have one.
            agent.user_name = agent.user_name or configured_name

        try:
            response = await uex.get_user(authenticated=True)
        except Exception as exc:
            memory.set_profile("uex_user_error", str(exc))
            if not configured_name:
                return
            try:
                response = await uex.get_user(username=configured_name)
            except Exception:
                # Still best-effort, but no longer silent: startup continues
                # with the configured name only.
                log.warning(
                    "UEX user lookup by configured name %r failed",
                    configured_name,
                    exc_info=True,
                )
                return

        data = response.get("user")
        if not data:
            return
        memory.set_profile("uex_user", data)
        # Guard against non-mapping payloads so an unexpected response shape
        # cannot abort startup before the scheduler is started.
        if not isinstance(data, dict):
            return
        username = data.get("username") or data.get("user_username") or data.get("name")
        if username:
            agent.user_name = username

    @app.get("/")
    async def index() -> FileResponse:
        """Serve the single-page web UI (``web/index.html``)."""
        return FileResponse(static_dir / "index.html")

    @app.get("/api/health")
    async def health() -> dict:
        """Report Ollama health, the stored user profile and scheduled jobs."""
        return {
            "ollama": await agent.health(),
            "user": memory.get_profile(),
            "jobs": scheduler.list_jobs(),
        }

    @app.post("/api/chat")
    async def chat(request: ChatRequest) -> dict:
        """Run one chat turn and return the agent's result.

        Raises:
            HTTPException: 503 when the Ollama backend is unavailable.
        """
        try:
            return await agent.chat(request.message)
        except OllamaUnavailable as exc:
            raise HTTPException(status_code=503, detail=str(exc)) from exc

    @app.post("/api/chat/stream")
    async def chat_stream(request: ChatRequest) -> StreamingResponse:
        """Stream chat events as Server-Sent Events (``data: <json>`` frames).

        NOTE(review): unlike ``/api/chat``, ``OllamaUnavailable`` is not mapped
        to a 503 here. If ``chat_events`` raises, the stream ends abruptly.
        Confirm how the agent reports backend errors while streaming.
        """

        async def events():
            async for event in agent.chat_events(request.message):
                # One SSE frame per event; the blank line terminates the frame.
                yield f"data: {json.dumps(event)}\n\n"

        return StreamingResponse(events(), media_type="text/event-stream")

    @app.get("/api/pending-actions")
    async def pending_actions() -> dict:
        """List actions awaiting user approval."""
        # NOTE(review): reaches into the agent's private ``_pending_payloads``;
        # consider exposing a public accessor on OllamaAgent.
        return {"pending_actions": agent._pending_payloads()}

    @app.get("/api/notifications")
    async def notifications() -> dict:
        """Return outbox entries the memory store reports as undelivered."""
        return {"notifications": memory.undelivered_outbox()}

    @app.get("/api/wake-jobs")
    async def wake_jobs() -> dict:
        """List the scheduler's wake jobs."""
        return {"scheduled_jobs": scheduler.list_jobs()}

    @app.get("/api/memory")
    async def inspect_memory(limit: int = 50) -> dict:
        """Return a memory snapshot; ``limit`` is clamped to the range 1..200."""
        return memory.inspect(max(1, min(limit, 200)))

    @app.post("/api/memory/clear")
    async def clear_memory(request: ClearMemoryRequest) -> dict:
        """Clear the selected memory stores and return what was deleted.

        When scheduled jobs are included, the scheduler is stopped for the
        duration of the clear. It is always restarted afterwards, even if
        clearing fails, so a storage error cannot leave wake jobs permanently
        stopped.
        """
        if request.include_jobs:
            scheduler.shutdown()
        try:
            deleted = memory.clear(
                include_memories=request.include_memories,
                include_conversations=request.include_conversations,
                include_profile=request.include_profile,
                include_jobs=request.include_jobs,
                include_outbox=request.include_outbox,
            )
        finally:
            if request.include_jobs:
                scheduler.start()
        return {"deleted": deleted, "memory": memory.inspect(50)}

    @app.post("/api/approve/{action_id}")
    async def approve(action_id: str) -> dict:
        """Approve the pending action ``action_id`` via the tool registry."""
        return await tools.approve(action_id)

    return app
# Module-level application object, built at import time. Importing this module
# therefore reads settings and constructs the memory store, scheduler and clients
# (see create_app). Presumably served by an ASGI server, e.g.
# ``uvicorn traderai.server:app`` — confirm against the deployment config.
app = create_app()