"""Scheduler that wakes the agent for persisted jobs, continual plans and UEX polls.

``WakeScheduler`` wraps an APScheduler ``AsyncIOScheduler`` running in the
machine's local timezone and registers three kinds of work on it:

* wake jobs (one-shot date or cron), persisted in ``MemoryStore``, whose
  output is posted to the memory outbox;
* continual plans owned by a bound plan runner, scheduled on their cron
  cadence, plus a one-time catch-up run when a plan was overdue at startup;
* an optional interval poll of UEX user notifications.
"""

from __future__ import annotations

from datetime import datetime, timedelta
from typing import Any
from uuid import uuid4

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.date import DateTrigger
from apscheduler.triggers.interval import IntervalTrigger
from tzlocal import get_localzone

from traderai.memory import MemoryStore, iso_now, parse_iso, time_since, utc_now

UEX_NOTIFICATION_JOB_ID = "uex-notification-poll"

# Cadence used for a plan that has no "cadence" of its own: every 6 hours.
DEFAULT_PLAN_CADENCE = "0 */6 * * *"
# Lower bound applied to the configured UEX poll interval.
MIN_NOTIFICATION_POLL_SECONDS = 15
# Delay between startup and the one-time catch-up run of an overdue plan.
PLAN_CATCHUP_DELAY_SECONDS = 5


class WakeScheduler:
    """Own the APScheduler instance and every job the app registers on it.

    Collaborators (agent, plan runner, UEX client) are attached after
    construction through the ``bind_*`` methods. Each is optional: the
    corresponding feature is skipped while it is unbound.
    """

    def __init__(self, memory: MemoryStore) -> None:
        self.memory = memory
        # All scheduling happens in the machine's local timezone.
        self.scheduler = AsyncIOScheduler(timezone=get_localzone())
        self.agent: Any = None  # set by bind_agent(); None => raw wake message is posted
        self.uex: Any = None  # set by bind_uex_notifications(); None => no polling
        self.plan_runner: Any = None  # set by bind_plan_runner(); None => plans ignored
        self.notification_poll_seconds = 60

    def bind_agent(self, agent: Any) -> None:
        """Attach the agent whose ``generate_wake_response`` handles wake jobs."""
        self.agent = agent

    def bind_plan_runner(self, plan_runner: Any) -> None:
        """Attach the runner that executes continual plans and exposes their store."""
        self.plan_runner = plan_runner

    def bind_uex_notifications(self, uex: Any, poll_seconds: int = 60) -> None:
        """Attach a UEX client and set the poll interval.

        Args:
            uex: Client exposing an awaitable ``get_user_notifications()``.
            poll_seconds: Poll interval; values below
                ``MIN_NOTIFICATION_POLL_SECONDS`` are raised to that minimum.

        The poll job itself is (re)registered by ``start()``, so a new interval
        applies the next time ``start()`` runs.
        """
        self.uex = uex
        self.notification_poll_seconds = max(MIN_NOTIFICATION_POLL_SECONDS, poll_seconds)

    def start(self) -> None:
        """Start the scheduler if needed and register all persisted work.

        Registers the UEX poll (if bound), every persisted wake job, and every
        active plan (if a plan runner is bound). Jobs are added with
        ``replace_existing=True``, so re-registration replaces by job id.
        """
        if not self.scheduler.running:
            self.scheduler.start()
        self._schedule_notification_poll()
        for job in self.memory.list_jobs():
            self._schedule_existing(job)
        if self.plan_runner is not None:
            for plan in self.plan_runner.store.list_plans(include_inactive=False):
                self.schedule_plan(plan)

    def shutdown(self) -> None:
        """Stop the scheduler without waiting for running jobs to finish."""
        if self.scheduler.running:
            self.scheduler.shutdown(wait=False)

    def schedule_date(self, run_at: str, prompt: str, job_id: str | None = None) -> dict[str, Any]:
        """Schedule a one-shot wake job at ``run_at`` and persist it.

        Args:
            run_at: ISO-8601 timestamp (parsed with ``datetime.fromisoformat``).
            prompt: Instruction handed to the agent when the job fires.
            job_id: Optional id; a ``wake-<uuid4>`` id is generated otherwise.

        Returns:
            The persisted job record from ``MemoryStore.add_job``.

        Raises:
            ValueError: If ``run_at`` is not a valid ISO-8601 timestamp.
        """
        parsed = datetime.fromisoformat(run_at)
        job_id = job_id or f"wake-{uuid4()}"
        self.scheduler.add_job(
            self._run_job,
            trigger=DateTrigger(run_date=parsed),
            id=job_id,
            args=[job_id, prompt],
            replace_existing=True,
        )
        return self.memory.add_job(job_id, prompt, "date", run_at, parsed.isoformat())

    def schedule_cron(self, cron: str, prompt: str, job_id: str | None = None) -> dict[str, Any]:
        """Schedule a recurring wake job from a 5-field crontab string and persist it.

        Args:
            cron: Crontab expression accepted by ``CronTrigger.from_crontab``.
            prompt: Instruction handed to the agent each time the job fires.
            job_id: Optional id; a ``wake-<uuid4>`` id is generated otherwise.

        Returns:
            The persisted job record, including the next fire time when known.

        Raises:
            ValueError: If ``cron`` is not a valid crontab expression.
        """
        job_id = job_id or f"wake-{uuid4()}"
        trigger = CronTrigger.from_crontab(cron)
        self.scheduler.add_job(
            self._run_job, trigger=trigger, id=job_id, args=[job_id, prompt], replace_existing=True
        )
        return self.memory.add_job(job_id, prompt, "cron", cron, self._next_run_iso(job_id))

    def list_jobs(self) -> list[dict[str, Any]]:
        """Return the persisted wake jobs."""
        return self.memory.list_jobs()

    def schedule_plan(self, plan: dict[str, Any]) -> dict[str, Any]:
        """Register an active plan on its cron cadence and persist its next run.

        If the plan's stored ``next_run_at`` is already in the past, a one-time
        catch-up run is queued ``PLAN_CATCHUP_DELAY_SECONDS`` from now and a
        ``catchup_scheduled`` event is recorded on the plan.

        Args:
            plan: Plan record with at least ``id`` and ``status``; ``cadence``
                (crontab) and ``next_run_at`` (ISO-8601) are optional.

        Returns:
            ``plan`` unchanged when no plan runner is bound or the plan is not
            active; otherwise the refreshed plan from the store (or ``plan``
            if the store returns nothing).
        """
        if self.plan_runner is None or plan.get("status") != "active":
            return plan
        plan_id = plan["id"]
        store = self.plan_runner.store
        job_id = self._plan_job_id(plan_id)
        previous_next_run = plan.get("next_run_at")
        trigger = CronTrigger.from_crontab(plan.get("cadence") or DEFAULT_PLAN_CADENCE)
        self.scheduler.add_job(
            self._run_plan, trigger=trigger, id=job_id, args=[plan_id], replace_existing=True
        )
        store.update_schedule(plan_id, self._next_run_iso(job_id))
        if self._plan_is_overdue(previous_next_run):
            self.scheduler.add_job(
                self._run_plan,
                trigger=DateTrigger(
                    run_date=datetime.now() + timedelta(seconds=PLAN_CATCHUP_DELAY_SECONDS)
                ),
                id=self._plan_catchup_job_id(plan_id),
                args=[plan_id],
                replace_existing=True,
            )
            store.add_event(
                plan_id,
                "catchup_scheduled",
                "Plan was overdue while the app was closed, so a one-time catch-up run was scheduled after startup.",
                {"previous_next_run_at": previous_next_run},
            )
        return store.get_plan(plan_id) or plan

    def unschedule_plan(self, plan_id: str) -> None:
        """Remove a plan's recurring and catch-up jobs and clear its next run time."""
        for job_id in (self._plan_job_id(plan_id), self._plan_catchup_job_id(plan_id)):
            if self.scheduler.get_job(job_id):
                self.scheduler.remove_job(job_id)
        if self.plan_runner is not None:
            self.plan_runner.store.update_schedule(plan_id, None)

    async def _run_plan(self, plan_id: str) -> None:
        """Job callback: run a plan, then persist the next fire time if still active."""
        if self.plan_runner is None:
            return
        result = await self.plan_runner.run_plan(plan_id)
        # Prefer the plan returned by the run; fall back to a fresh store read.
        plan = (result or {}).get("plan") or self.plan_runner.store.get_plan(plan_id)
        if plan and plan.get("status") == "active":
            next_run = self._next_run_iso(self._plan_job_id(plan_id))
            self.plan_runner.store.update_schedule(plan_id, next_run)

    def _next_run_iso(self, job_id: str) -> str | None:
        """Return the scheduler's next fire time for ``job_id`` as ISO-8601, or None.

        None means there is no such job or it has no next fire time. ``getattr``
        also covers jobs added before ``start()``: APScheduler 3.x leaves those
        pending without a ``next_run_time`` attribute until the scheduler starts.
        """
        next_run = getattr(self.scheduler.get_job(job_id), "next_run_time", None)
        return next_run.isoformat() if next_run else None

    @staticmethod
    def _plan_job_id(plan_id: str) -> str:
        """Scheduler id of a plan's recurring job."""
        return f"continual-{plan_id}"

    @staticmethod
    def _plan_catchup_job_id(plan_id: str) -> str:
        """Scheduler id of a plan's one-time catch-up job."""
        return f"continual-catchup-{plan_id}"

    @staticmethod
    def _plan_is_overdue(next_run_at: str | None) -> bool:
        """Return True when ``next_run_at`` parses to a time at or before now.

        Missing or unparseable values count as "not overdue". ``TypeError`` is
        caught as well: comparing a naive datetime with an aware one raises it,
        and one malformed stored timestamp must not abort ``start()``.
        """
        if not next_run_at:
            return False
        try:
            return parse_iso(next_run_at) <= utc_now()
        except (TypeError, ValueError):
            return False

    def _schedule_existing(self, job: dict[str, Any]) -> None:
        """Re-register one persisted wake job; unknown trigger types are skipped.

        NOTE(review): an invalid stored crontab/timestamp raises ValueError
        here and aborts ``start()`` for all remaining jobs and plans.
        """
        trigger_type = job["trigger_type"]
        if trigger_type == "cron":
            trigger = CronTrigger.from_crontab(job["trigger_value"])
        elif trigger_type == "date":
            trigger = DateTrigger(run_date=datetime.fromisoformat(job["trigger_value"]))
        else:
            return
        self.scheduler.add_job(
            self._run_job,
            trigger=trigger,
            id=job["id"],
            args=[job["id"], job["prompt"]],
            replace_existing=True,
        )

    async def _run_job(self, job_id: str, prompt: str) -> None:
        """Job callback: fire a wake job and post the result to the outbox.

        The wake message tells the agent the current time, when the last chat
        interaction happened, and the job's instruction. With no agent bound,
        the wake message itself is posted. An agent error is posted as a
        failure message so the job's run is still recorded.
        """
        last = self.memory.last_interaction()
        last_text = f"{last['created_at']} ({time_since(last['created_at'])})" if last else "never"
        wake_message = (
            f"Scheduled wake job fired. Current time is {iso_now()}. "
            f"The last chat interaction was {last_text}. Job instruction: {prompt}"
        )
        if self.agent is None:
            text = wake_message
        else:
            try:
                text = await self.agent.generate_wake_response(wake_message)
            except Exception as exc:  # job boundary: report instead of losing the wake-up
                text = f"Wake job failed: {exc}. Job instruction: {prompt}"
        self.memory.add_outbox(text)
        self._mark_job_finished(job_id)

    def _mark_job_finished(self, job_id: str) -> None:
        """Record a run; the job stays enabled only while it has a next fire time."""
        next_run = self._next_run_iso(job_id)
        self.memory.mark_job_run(job_id, next_run, enabled=next_run is not None)

    def _schedule_notification_poll(self) -> None:
        """Register the interval UEX poll; no-op when no UEX client is bound."""
        if self.uex is None:
            return
        self.scheduler.add_job(
            self.poll_uex_notifications,
            trigger=IntervalTrigger(seconds=self.notification_poll_seconds),
            id=UEX_NOTIFICATION_JOB_ID,
            replace_existing=True,
            # First poll runs right away instead of after one full interval.
            next_run_time=datetime.now(),
        )

    async def poll_uex_notifications(self) -> list[dict[str, Any]]:
        """Fetch UEX notifications and announce unread ones not seen before.

        Each new unread notification is posted to the outbox once, and its key
        is added to the ``uex_seen_notification_keys`` profile entry so later
        polls skip it. Every successful poll records
        ``uex_last_notification_check``.

        A failed fetch is stored in ``uex_last_notification_error`` and posted
        to the outbox only when its text differs from the stored error, so an
        outage produces one message rather than one per poll. The stored error
        is cleared by the next successful poll.

        Returns:
            The newly announced notifications (empty on failure, when no UEX
            client is bound, or when nothing is new).
        """
        if self.uex is None:
            return []
        try:
            response = await self.uex.get_user_notifications()
        except Exception as exc:  # poll boundary: record and report, never crash the job
            error = str(exc)
            if self.memory.get_profile().get("uex_last_notification_error") != error:
                self.memory.add_outbox(f"UEX notification poll failed: {exc}")
            self.memory.set_profile("uex_last_notification_error", error)
            return []

        # Read the profile after the await so the seen-set read/update below
        # happens without yielding to the event loop in between.
        profile = self.memory.get_profile()
        if profile.get("uex_last_notification_error"):
            self.memory.set_profile("uex_last_notification_error", "")

        notifications = (response or {}).get("notifications") or []
        seen = set(profile.get("uex_seen_notification_keys") or [])
        new_pending = [
            item
            for item in notifications
            if not item.get("date_read") and self._notification_key(item) not in seen
        ]
        for item in new_pending:
            self.memory.add_outbox(self._notification_text(item))
        if new_pending:
            # NOTE(review): this set only grows; consider pruning old keys.
            seen.update(self._notification_key(item) for item in new_pending)
            self.memory.set_profile("uex_seen_notification_keys", sorted(seen))
        self.memory.set_profile("uex_last_notification_check", iso_now())
        return new_pending

    @staticmethod
    def _notification_key(item: dict[str, Any]) -> str:
        """Stable dedup key: ``code``, else ``id``, else added-date plus message."""
        for key in ("code", "id"):
            value = item.get(key)
            if value not in (None, ""):
                return f"{key}:{value}"
        return f"notification:{item.get('date_added')}:{item.get('message')}"

    @staticmethod
    def _notification_text(item: dict[str, Any]) -> str:
        """Render a notification as an outbox line, appending code/path when present."""
        message = item.get("message") or "You have a pending UEX notification."
        redir = item.get("redir")
        code = item.get("code")
        details = []
        if code:
            details.append(f"code `{code}`")
        if redir:
            details.append(f"path `{redir}`")
        suffix = f" ({', '.join(details)})" if details else ""
        return f"UEX notification: {message}{suffix}"