Inital Commit
This commit is contained in:
@@ -0,0 +1,378 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def iso_now() -> str:
|
||||
return utc_now().isoformat()
|
||||
|
||||
|
||||
def iso_now_in_zone(zone: ZoneInfo) -> str:
|
||||
return utc_now().astimezone(zone).isoformat()
|
||||
|
||||
|
||||
def parse_iso(value: str) -> datetime:
|
||||
parsed = datetime.fromisoformat(value)
|
||||
if parsed.tzinfo is None:
|
||||
return parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
|
||||
|
||||
def time_since(value: str, now: datetime | None = None) -> str:
|
||||
then = parse_iso(value)
|
||||
current = now or utc_now()
|
||||
if current.tzinfo is None:
|
||||
current = current.replace(tzinfo=timezone.utc)
|
||||
seconds = max(0, int((current - then).total_seconds()))
|
||||
if seconds < 60:
|
||||
return f"{seconds} seconds ago"
|
||||
minutes = seconds // 60
|
||||
if minutes < 60:
|
||||
return _plural(minutes, "minute") + " ago"
|
||||
hours = minutes // 60
|
||||
if hours < 24:
|
||||
return _plural(hours, "hour") + " ago"
|
||||
days = hours // 24
|
||||
return _plural(days, "day") + " ago"
|
||||
|
||||
|
||||
def _plural(value: int, unit: str) -> str:
|
||||
suffix = "" if value == 1 else "s"
|
||||
return f"{value} {unit}{suffix}"
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = Path(path)
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_db()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
connection = sqlite3.connect(self.path)
|
||||
connection.row_factory = sqlite3.Row
|
||||
return connection
|
||||
|
||||
def _init_db(self) -> None:
|
||||
with self._connect() as db:
|
||||
db.executescript(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
kind TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
importance INTEGER NOT NULL DEFAULT 3,
|
||||
metadata TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
|
||||
content,
|
||||
kind UNINDEXED,
|
||||
content='memories',
|
||||
content_rowid='id'
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
|
||||
INSERT INTO memories_fts(rowid, content, kind) VALUES (new.id, new.content, new.kind);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
|
||||
INSERT INTO memories_fts(memories_fts, rowid, content, kind)
|
||||
VALUES('delete', old.id, old.content, old.kind);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
|
||||
INSERT INTO memories_fts(memories_fts, rowid, content, kind)
|
||||
VALUES('delete', old.id, old.content, old.kind);
|
||||
INSERT INTO memories_fts(rowid, content, kind) VALUES (new.id, new.content, new.kind);
|
||||
END;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_profile (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS scheduled_jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
prompt TEXT NOT NULL,
|
||||
trigger_type TEXT NOT NULL,
|
||||
trigger_value TEXT NOT NULL,
|
||||
next_run_at TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
last_run_at TEXT,
|
||||
enabled INTEGER NOT NULL DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS outbox (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
content TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
delivered_at TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def add_conversation(self, role: str, content: str) -> None:
|
||||
with self._connect() as db:
|
||||
db.execute(
|
||||
"INSERT INTO conversations(role, content, created_at) VALUES (?, ?, ?)",
|
||||
(role, content, iso_now()),
|
||||
)
|
||||
|
||||
def last_interaction(self) -> dict[str, Any] | None:
|
||||
with self._connect() as db:
|
||||
row = db.execute(
|
||||
"SELECT role, content, created_at FROM conversations ORDER BY id DESC LIMIT 1"
|
||||
).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def recent_conversation(self, limit: int = 8) -> list[dict[str, Any]]:
|
||||
with self._connect() as db:
|
||||
rows = db.execute(
|
||||
"SELECT role, content, created_at FROM conversations ORDER BY id DESC LIMIT ?",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [dict(row) for row in reversed(rows)]
|
||||
|
||||
def remember(self, kind: str, content: str, importance: int = 3, metadata: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
now = iso_now()
|
||||
with self._connect() as db:
|
||||
cursor = db.execute(
|
||||
"""
|
||||
INSERT INTO memories(kind, content, importance, metadata, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(kind, content, importance, json.dumps(metadata or {}), now, now),
|
||||
)
|
||||
memory_id = cursor.lastrowid
|
||||
return {"id": memory_id, "kind": kind, "content": content, "importance": importance, "created_at": now}
|
||||
|
||||
def recall(self, query: str, limit: int = 6) -> list[dict[str, Any]]:
|
||||
if not query.strip():
|
||||
return self.top_memories(limit)
|
||||
|
||||
with self._connect() as db:
|
||||
try:
|
||||
rows = db.execute(
|
||||
"""
|
||||
SELECT m.id, m.kind, m.content, m.importance, m.metadata, m.created_at,
|
||||
bm25(memories_fts) AS rank
|
||||
FROM memories_fts
|
||||
JOIN memories m ON m.id = memories_fts.rowid
|
||||
WHERE memories_fts MATCH ?
|
||||
ORDER BY rank, m.importance DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(self._fts_query(query), limit),
|
||||
).fetchall()
|
||||
except sqlite3.OperationalError:
|
||||
rows = db.execute(
|
||||
"""
|
||||
SELECT id, kind, content, importance, metadata, created_at
|
||||
FROM memories
|
||||
WHERE content LIKE ?
|
||||
ORDER BY importance DESC, id DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(f"%{query}%", limit),
|
||||
).fetchall()
|
||||
return [self._memory_row(row) for row in rows]
|
||||
|
||||
def top_memories(self, limit: int = 6) -> list[dict[str, Any]]:
|
||||
with self._connect() as db:
|
||||
rows = db.execute(
|
||||
"""
|
||||
SELECT id, kind, content, importance, metadata, created_at
|
||||
FROM memories
|
||||
ORDER BY importance DESC, updated_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [self._memory_row(row) for row in rows]
|
||||
|
||||
def inspect(self, limit: int = 50) -> dict[str, Any]:
|
||||
with self._connect() as db:
|
||||
memories = db.execute(
|
||||
"""
|
||||
SELECT id, kind, content, importance, metadata, created_at, updated_at
|
||||
FROM memories
|
||||
ORDER BY importance DESC, updated_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
conversations = db.execute(
|
||||
"""
|
||||
SELECT id, role, content, created_at
|
||||
FROM conversations
|
||||
ORDER BY id DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
profile_rows = db.execute(
|
||||
"SELECT key, value, updated_at FROM user_profile ORDER BY key"
|
||||
).fetchall()
|
||||
jobs = db.execute(
|
||||
"SELECT * FROM scheduled_jobs ORDER BY enabled DESC, next_run_at"
|
||||
).fetchall()
|
||||
outbox = db.execute(
|
||||
"SELECT id, content, created_at, delivered_at FROM outbox ORDER BY id DESC LIMIT ?",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
|
||||
profile = []
|
||||
for row in profile_rows:
|
||||
item = dict(row)
|
||||
try:
|
||||
item["value"] = json.loads(item["value"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
profile.append(item)
|
||||
|
||||
return {
|
||||
"path": str(self.path),
|
||||
"memories": [self._memory_row(row) for row in memories],
|
||||
"conversations": [dict(row) for row in conversations],
|
||||
"profile": profile,
|
||||
"scheduled_jobs": [dict(row) for row in jobs],
|
||||
"outbox": [dict(row) for row in outbox],
|
||||
}
|
||||
|
||||
def clear(
|
||||
self,
|
||||
include_memories: bool = True,
|
||||
include_conversations: bool = True,
|
||||
include_profile: bool = False,
|
||||
include_jobs: bool = False,
|
||||
include_outbox: bool = True,
|
||||
) -> dict[str, int]:
|
||||
deleted: dict[str, int] = {}
|
||||
with self._connect() as db:
|
||||
if include_memories:
|
||||
deleted["memories"] = db.execute("DELETE FROM memories").rowcount
|
||||
db.execute("INSERT INTO memories_fts(memories_fts) VALUES('rebuild')")
|
||||
if include_conversations:
|
||||
deleted["conversations"] = db.execute("DELETE FROM conversations").rowcount
|
||||
if include_profile:
|
||||
deleted["profile"] = db.execute("DELETE FROM user_profile").rowcount
|
||||
if include_jobs:
|
||||
deleted["scheduled_jobs"] = db.execute("DELETE FROM scheduled_jobs").rowcount
|
||||
if include_outbox:
|
||||
deleted["outbox"] = db.execute("DELETE FROM outbox").rowcount
|
||||
return deleted
|
||||
|
||||
def set_profile(self, key: str, value: Any) -> None:
|
||||
with self._connect() as db:
|
||||
db.execute(
|
||||
"""
|
||||
INSERT INTO user_profile(key, value, updated_at) VALUES (?, ?, ?)
|
||||
ON CONFLICT(key) DO UPDATE SET value=excluded.value, updated_at=excluded.updated_at
|
||||
""",
|
||||
(key, json.dumps(value), iso_now()),
|
||||
)
|
||||
|
||||
def get_profile(self) -> dict[str, Any]:
|
||||
with self._connect() as db:
|
||||
rows = db.execute("SELECT key, value FROM user_profile").fetchall()
|
||||
profile = {}
|
||||
for row in rows:
|
||||
try:
|
||||
profile[row["key"]] = json.loads(row["value"])
|
||||
except json.JSONDecodeError:
|
||||
profile[row["key"]] = row["value"]
|
||||
return profile
|
||||
|
||||
def add_job(
|
||||
self,
|
||||
job_id: str,
|
||||
prompt: str,
|
||||
trigger_type: str,
|
||||
trigger_value: str,
|
||||
next_run_at: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
with self._connect() as db:
|
||||
db.execute(
|
||||
"""
|
||||
INSERT INTO scheduled_jobs(id, prompt, trigger_type, trigger_value, next_run_at, created_at, enabled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 1)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
prompt=excluded.prompt,
|
||||
trigger_type=excluded.trigger_type,
|
||||
trigger_value=excluded.trigger_value,
|
||||
next_run_at=excluded.next_run_at,
|
||||
enabled=1
|
||||
""",
|
||||
(job_id, prompt, trigger_type, trigger_value, next_run_at, iso_now()),
|
||||
)
|
||||
return {
|
||||
"id": job_id,
|
||||
"prompt": prompt,
|
||||
"trigger_type": trigger_type,
|
||||
"trigger_value": trigger_value,
|
||||
"next_run_at": next_run_at,
|
||||
}
|
||||
|
||||
def list_jobs(self) -> list[dict[str, Any]]:
|
||||
with self._connect() as db:
|
||||
rows = db.execute(
|
||||
"SELECT * FROM scheduled_jobs WHERE enabled = 1 ORDER BY next_run_at IS NULL, next_run_at"
|
||||
).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def mark_job_run(self, job_id: str, next_run_at: str | None = None) -> None:
|
||||
with self._connect() as db:
|
||||
db.execute(
|
||||
"UPDATE scheduled_jobs SET last_run_at = ?, next_run_at = ? WHERE id = ?",
|
||||
(iso_now(), next_run_at, job_id),
|
||||
)
|
||||
|
||||
def add_outbox(self, content: str) -> None:
|
||||
with self._connect() as db:
|
||||
db.execute("INSERT INTO outbox(content, created_at) VALUES (?, ?)", (content, iso_now()))
|
||||
|
||||
def undelivered_outbox(self) -> list[dict[str, Any]]:
|
||||
now = iso_now()
|
||||
with self._connect() as db:
|
||||
rows = db.execute(
|
||||
"SELECT id, content, created_at FROM outbox WHERE delivered_at IS NULL ORDER BY id"
|
||||
).fetchall()
|
||||
db.execute(
|
||||
"UPDATE outbox SET delivered_at = ? WHERE delivered_at IS NULL",
|
||||
(now,),
|
||||
)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
@staticmethod
|
||||
def _fts_query(query: str) -> str:
|
||||
tokens = [token.replace('"', "") for token in query.split() if token.strip()]
|
||||
return " OR ".join(f'"{token}"' for token in tokens) or '""'
|
||||
|
||||
@staticmethod
|
||||
def _memory_row(row: sqlite3.Row) -> dict[str, Any]:
|
||||
data = dict(row)
|
||||
if "metadata" in data:
|
||||
try:
|
||||
data["metadata"] = json.loads(data["metadata"])
|
||||
except json.JSONDecodeError:
|
||||
data["metadata"] = {}
|
||||
return data
|
||||
Reference in New Issue
Block a user