156 lines
5.1 KiB
Python
156 lines
5.1 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
import json
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi import HTTPException
|
|
from fastapi.responses import FileResponse, StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
|
|
from traderai.agent import OllamaAgent, OllamaUnavailable
|
|
from traderai.config import get_settings
|
|
from traderai.memory import MemoryStore
|
|
from traderai.scheduler import WakeScheduler
|
|
from traderai.tools import ToolRegistry
|
|
from traderai.uex_client import UEXClient
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body shared by ``POST /api/chat`` and ``POST /api/chat/stream``."""

    # The user's text, passed to the agent unchanged.
    message: str
|
|
|
|
|
|
class ClearMemoryRequest(BaseModel):
    """Request body for ``POST /api/memory/clear``: which stored categories to wipe.

    Every flag is forwarded unchanged as the same-named keyword argument of
    ``MemoryStore.clear``. The defaults clear memories, conversations and the
    notification outbox, while keeping the user profile and scheduled jobs.
    """

    include_memories: bool = True
    include_conversations: bool = True
    # Opt-in: the profile holds the configured name and the fetched UEX user.
    include_profile: bool = False
    # Opt-in: when set, the endpoint also stops and restarts the wake scheduler.
    include_jobs: bool = False
    include_outbox: bool = True
|
|
|
|
|
|
def create_app() -> FastAPI:
    """Build the TraderAI FastAPI application and wire its services together.

    Builds the memory store, wake scheduler, UEX client, tool registry and
    Ollama agent from ``get_settings()``. It then binds the agent and UEX
    notification polling to the scheduler, mounts the ``web`` directory under
    ``/static``, and registers the HTTP routes. The scheduler is started on
    application startup and shut down on application shutdown.

    Returns:
        The configured ``FastAPI`` instance.
    """
    settings = get_settings()
    memory = MemoryStore(settings.traderai_memory_path)
    scheduler = WakeScheduler(memory)
    uex = UEXClient(settings.uex_base_url, settings.uex_secret_key, settings.uex_bearer_token)
    tools = ToolRegistry(uex, settings.require_write_approval, memory=memory, scheduler=scheduler)
    agent = OllamaAgent(
        settings.ollama_base_url,
        settings.ollama_model,
        tools,
        memory=memory,
        user_name=settings.traderai_user_name,
        num_ctx=settings.ollama_num_ctx,
    )
    # The agent is built on a tool registry that already holds the scheduler,
    # so the scheduler can only be handed the agent after construction.
    scheduler.bind_agent(agent)
    scheduler.bind_uex_notifications(uex, settings.uex_notification_poll_seconds)

    app = FastAPI(title="TraderAI")
    # Front-end assets live in a ``web`` directory beside this module's package directory.
    static_dir = Path(__file__).resolve().parent.parent / "web"
    app.mount("/static", StaticFiles(directory=static_dir), name="static")

    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of a ``lifespan`` handler; kept here for compatibility.
    @app.on_event("startup")
    async def startup() -> None:
        """Load the user profile, then start the wake scheduler."""
        await refresh_user_profile()
        scheduler.start()

    @app.on_event("shutdown")
    async def shutdown() -> None:
        """Stop the wake scheduler."""
        scheduler.shutdown()

    async def refresh_user_profile() -> None:
        """Record the configured user name and fetch the user's UEX profile.

        It first tries the authenticated UEX user. On failure, it stores the
        error under the ``uex_user_error`` profile key. If a user name is
        configured, it then falls back to looking that user up by name.
        Failures from either UEX call are swallowed, so UEX problems never
        abort startup.

        If a user record is found, it is stored as ``uex_user``. Its user name,
        when present, becomes the agent's user name.
        """
        configured_name = settings.traderai_user_name
        if configured_name:
            memory.set_profile("configured_name", configured_name)
            agent.user_name = agent.user_name or configured_name

        try:
            response = await uex.get_user(authenticated=True)
        except Exception as exc:
            memory.set_profile("uex_user_error", str(exc))
            if not configured_name:
                return
            try:
                response = await uex.get_user(username=configured_name)
            except Exception:
                # Best effort: the configured name recorded above still applies.
                return

        # This runs during startup: guard against unexpected payload shapes
        # instead of letting an AttributeError abort the whole application.
        data = response.get("user") if isinstance(response, dict) else None
        if not isinstance(data, dict) or not data:
            return
        memory.set_profile("uex_user", data)
        username = data.get("username") or data.get("user_username") or data.get("name")
        if username:
            agent.user_name = username

    @app.get("/")
    async def index() -> FileResponse:
        """Serve the single-page front end."""
        return FileResponse(static_dir / "index.html")

    @app.get("/api/health")
    async def health() -> dict:
        """Report Ollama health, the stored user profile and scheduled jobs."""
        return {
            "ollama": await agent.health(),
            "user": memory.get_profile(),
            "jobs": scheduler.list_jobs(),
        }

    @app.post("/api/chat")
    async def chat(request: ChatRequest) -> dict:
        """Run one chat turn; an unreachable Ollama is reported as HTTP 503."""
        try:
            return await agent.chat(request.message)
        except OllamaUnavailable as exc:
            raise HTTPException(status_code=503, detail=str(exc)) from exc

    @app.post("/api/chat/stream")
    async def chat_stream(request: ChatRequest) -> StreamingResponse:
        """Stream agent events for one chat turn as Server-Sent Events.

        Each event from ``agent.chat_events`` is sent as one JSON ``data:``
        frame. Unlike ``/api/chat``, the HTTP status has already been sent
        once streaming begins. An ``OllamaUnavailable`` error is therefore
        reported as a final error frame instead of a 503.
        """

        async def events():
            try:
                async for event in agent.chat_events(request.message):
                    yield f"data: {json.dumps(event)}\n\n"
            except OllamaUnavailable as exc:
                # NOTE(review): confirm the web client recognises this event shape.
                error_event = {"type": "error", "detail": str(exc)}
                yield f"data: {json.dumps(error_event)}\n\n"

        return StreamingResponse(
            events(),
            media_type="text/event-stream",
            # Event streams must not be cached by browsers or intermediaries.
            headers={"Cache-Control": "no-cache"},
        )

    @app.get("/api/pending-actions")
    async def pending_actions() -> dict:
        """List the agent's pending actions (presumably those /api/approve and /api/decline act on)."""
        # NOTE(review): reaches into the agent's private ``_pending_payloads``;
        # consider exposing a public accessor on OllamaAgent.
        return {"pending_actions": agent._pending_payloads()}

    @app.get("/api/notifications")
    async def notifications() -> dict:
        """Return the outbox entries that ``MemoryStore`` reports as undelivered."""
        return {"notifications": memory.undelivered_outbox()}

    @app.get("/api/wake-jobs")
    async def wake_jobs() -> dict:
        """List the wake scheduler's jobs."""
        return {"scheduled_jobs": scheduler.list_jobs()}

    @app.get("/api/memory")
    async def inspect_memory(limit: int = 50) -> dict:
        """Return a memory snapshot; ``limit`` is clamped to the range 1..200."""
        return memory.inspect(max(1, min(limit, 200)))

    @app.post("/api/memory/clear")
    async def clear_memory(request: ClearMemoryRequest) -> dict:
        """Delete the selected categories of stored data.

        When scheduled jobs are included, the scheduler is stopped while the
        delete runs. It is restarted afterwards even if the delete raises, so
        a failed clear never leaves the app without a running scheduler.

        Returns:
            ``{"deleted": ..., "memory": ...}``: the result of
            ``memory.clear`` and of ``memory.inspect(50)``.
        """
        pause_scheduler = request.include_jobs
        if pause_scheduler:
            scheduler.shutdown()
        try:
            deleted = memory.clear(
                include_memories=request.include_memories,
                include_conversations=request.include_conversations,
                include_profile=request.include_profile,
                include_jobs=request.include_jobs,
                include_outbox=request.include_outbox,
            )
        finally:
            if pause_scheduler:
                scheduler.start()
        return {"deleted": deleted, "memory": memory.inspect(50)}

    @app.post("/api/approve/{action_id}")
    async def approve(action_id: str) -> dict:
        """Approve the pending action ``action_id`` via the tool registry."""
        return await tools.approve(action_id)

    @app.post("/api/decline/{action_id}")
    async def decline(action_id: str) -> dict:
        """Decline the pending action ``action_id`` via the tool registry."""
        return await tools.decline(action_id)

    return app
|
|
|
|
|
|
# Module-level application instance for ASGI servers to import. Importing this
# module therefore builds the whole service stack (settings, memory store,
# scheduler, UEX client, agent) as a side effect.
app: FastAPI = create_app()
|