312 lines
11 KiB
Python
312 lines
11 KiB
Python
import asyncio
|
|
import json
|
|
import threading
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.db.models import MemoryFact, SessionSummary, UserProfile
|
|
from app.memory.parse import normalize_text, parse_identity, texts_are_similar
|
|
|
|
# Baseline profile: the keys every profile returned by MemoryService exposes.
# Stored profile JSON is merged over a copy of this mapping, so keys absent
# from the database fall back to these values. Insertion order is preserved
# and is therefore the key order of serialized profiles.
DEFAULT_PROFILE: dict[str, Any] = dict(
    name="",
    age="",
    timezone="",
    language="ru",  # ISO 639-1 code; Russian unless the user says otherwise
    notes="",
)
|
|
|
|
|
|
class MemoryService:
    """Long-term memory for one user: profile, remembered facts, session summaries.

    Every public mutator commits on ``self.db`` before returning. Search-index
    (RAG) maintenance is fire-and-forget: coroutines from ``app.rag.ingest`` are
    started on a background thread and never awaited by the caller.

    Args:
        db: SQLAlchemy session used for all reads and writes.
        user_id: Owner whose rows this instance reads and writes.
    """

    # Truncation limits applied before writing; presumably these match the
    # column sizes declared in app.db.models — TODO confirm.
    _MAX_CONTENT_LEN = 2000
    _MAX_CATEGORY_LEN = 64
    _MAX_SOURCE_LEN = 32
    _MAX_SUMMARY_LEN = 4000
    # Hard ceiling on how many facts recall_memories returns, whatever `limit` is.
    _MAX_RECALL = 50

    def __init__(self, db: Session, user_id: int):
        self.db = db
        self.user_id = user_id

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _logger() -> "logging.Logger":
        """Return this module's logger (imported lazily, like other deps in this file)."""
        import logging

        return logging.getLogger(__name__)

    @staticmethod
    def _run_async(coro):
        """Run ``coro`` to completion from synchronous code and return its result.

        With no event loop running on this thread, ``asyncio.run`` is used
        directly. Inside a running loop (e.g. a sync helper called from an async
        endpoint) ``asyncio.run`` would raise, so the coroutine is run on a fresh
        loop in a one-off worker thread instead. The calling thread — and any
        loop running on it — blocks until the coroutine finishes.
        NOTE(review): the coroutine must not depend on objects bound to the
        caller's loop, since it runs on a different one.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    @staticmethod
    def _schedule_rag(coro) -> None:
        """Start ``coro`` on its own event loop in a daemon thread and return at once.

        Nothing waits for the result. Failures are logged rather than left to
        the thread's default excepthook. Because the thread is a daemon, work
        still in flight at interpreter shutdown is dropped.
        """

        def runner() -> None:
            try:
                asyncio.run(coro)
            except Exception:
                MemoryService._logger().exception("Background RAG task failed")

        threading.Thread(target=runner, daemon=True).start()

    @staticmethod
    def _clamp_importance(value: int) -> int:
        """Clamp an importance score into the accepted 1..5 range."""
        return min(5, max(1, value))

    @staticmethod
    def _fact_to_dict(fact: MemoryFact) -> dict[str, Any]:
        """Serialize a fact into the plain dict shape handed back to callers."""
        return {
            "id": fact.id,
            "category": fact.category,
            "content": fact.content,
            "importance": fact.importance,
            "source": fact.source,
            "updated_at": fact.updated_at.isoformat() if fact.updated_at else None,
        }

    def _owned_session(self, session_id: int):
        """Return ChatSession ``session_id`` if it exists and belongs to this user, else None."""
        from app.db.models import ChatSession

        session = self.db.get(ChatSession, session_id)
        if not session or session.user_id != self.user_id:
            return None
        return session

    # ------------------------------------------------------------------
    # Profile
    # ------------------------------------------------------------------

    def _profile_row(self) -> UserProfile | None:
        """Fetch this user's profile row, if one exists."""
        return self.db.scalar(
            select(UserProfile).where(UserProfile.user_id == self.user_id).limit(1)
        )

    @staticmethod
    def _decode_profile(raw: str | None) -> dict[str, Any]:
        """Merge stored profile JSON over a fresh copy of DEFAULT_PROFILE.

        Undecodable JSON, and JSON that is valid but not an object (``null``, a
        list, a bare string, ...), are treated as an empty profile instead of
        crashing ``dict.update``.
        """
        merged = dict(DEFAULT_PROFILE)
        try:
            data = json.loads(raw or "{}")
        except json.JSONDecodeError:
            return merged
        if isinstance(data, dict):
            merged.update(data)
        return merged

    def get_profile(self) -> dict[str, Any]:
        """Return the user's profile with DEFAULT_PROFILE filling any missing keys.

        Always a fresh dict; callers may mutate it freely.
        """
        row = self._profile_row()
        return self._decode_profile(row.data_json if row else None)

    def update_profile(self, updates: dict[str, Any]) -> dict[str, Any]:
        """Apply ``updates`` to the stored profile (creating it if needed) and commit.

        A ``None`` value removes that key from the stored JSON; keys that exist
        in DEFAULT_PROFILE come back with their default on the next read.

        Returns:
            ``{"ok": True, "profile": <full profile after the update>}``.
        """
        row = self._profile_row()
        if row is None:
            row = UserProfile(user_id=self.user_id, data_json="{}")
            self.db.add(row)
            self.db.flush()

        # Decode the row already in hand instead of re-querying via get_profile().
        current = self._decode_profile(row.data_json)
        for key, value in updates.items():
            if value is None:
                current.pop(key, None)
            else:
                current[key] = value

        row.data_json = json.dumps(current, ensure_ascii=False)
        row.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        return {"ok": True, "profile": current}

    # ------------------------------------------------------------------
    # Facts
    # ------------------------------------------------------------------

    def _find_similar_fact(self, text: str) -> MemoryFact | None:
        """Return the first active fact of this user that ``texts_are_similar`` matches."""
        active = select(MemoryFact).where(
            MemoryFact.user_id == self.user_id, MemoryFact.active.is_(True)
        )
        for fact in self.db.scalars(active):
            if texts_are_similar(fact.content, text):
                return fact
        return None

    def _sync_identity_to_profile(self, text: str) -> dict[str, Any] | None:
        """Copy identity fields that ``parse_identity`` finds in ``text`` into the profile.

        Returns the ``update_profile`` result, or None when nothing was parsed.
        """
        parsed = parse_identity(text)
        if not parsed:
            return None
        return self.update_profile(parsed)

    def remember_fact(
        self,
        content: str,
        *,
        category: str = "fact",
        source: str = "user",
        session_id: int | None = None,
        importance: int = 3,
    ) -> dict[str, Any]:
        """Store a fact about the user, merging it into a similar existing fact.

        Identity details parsed from the text are first written to the profile
        (a separate commit). If an active fact is similar per
        ``texts_are_similar`` it is updated in place: the longer text wins, a
        non-empty ``category`` replaces the old one, and importance never
        decreases. Otherwise a new fact is created. Either way
        ``index_memory_fact`` is scheduled in the background.

        Args:
            content: Fact text; surrounding whitespace is stripped.
            category: Category, truncated to 64 chars. Empty keeps the existing
                category on merge and falls back to "fact" on create.
            source: Origin label, truncated to 32 chars (used on create only).
            session_id: Chat session the fact came from, if any.
            importance: Score, clamped to 1..5.

        Returns:
            Dict with ``ok``, ``action`` ("created" or "updated"), ``memory_id``,
            ``content``, ``category`` and, when the profile changed, ``profile``.

        Raises:
            ValueError: If ``content`` is blank.
        """
        text = content.strip()
        if not text:
            raise ValueError("Пустой факт")

        profile_sync = self._sync_identity_to_profile(text)

        fact = self._find_similar_fact(text)
        if fact is not None:
            action = "updated"
            if len(text) > len(fact.content):
                fact.content = text[: self._MAX_CONTENT_LEN]
            if category:
                # Same limit as the create path, so both paths fit the column.
                fact.category = category[: self._MAX_CATEGORY_LEN]
            fact.importance = max(fact.importance, self._clamp_importance(importance))
            fact.updated_at = datetime.now(timezone.utc)
            if session_id:
                fact.session_id = session_id
        else:
            action = "created"
            fact = MemoryFact(
                user_id=self.user_id,
                category=(category or "fact")[: self._MAX_CATEGORY_LEN],
                content=text[: self._MAX_CONTENT_LEN],
                source=source[: self._MAX_SOURCE_LEN],
                session_id=session_id,
                importance=self._clamp_importance(importance),
            )
            self.db.add(fact)

        self.db.commit()
        # Reload committed state on THIS thread for both paths. With SQLAlchemy's
        # default expire_on_commit=True the attributes are expired by commit();
        # otherwise the background indexer's first attribute access would
        # refresh through self.db from another thread (Sessions are not
        # thread-safe), or fail once the request's session has been closed.
        self.db.refresh(fact)

        from app.rag.ingest import index_memory_fact

        self._schedule_rag(index_memory_fact(fact))

        result = {
            "ok": True,
            "action": action,
            "memory_id": fact.id,
            "content": fact.content,
            "category": fact.category,
        }
        if profile_sync:
            result["profile"] = profile_sync.get("profile")
        return result

    def recall_memories(
        self,
        *,
        query: str | None = None,
        category: str | None = None,
        limit: int = 20,
        active_only: bool = True,
    ) -> list[dict[str, Any]]:
        """Return up to ``limit`` (never more than 50) of the user's facts.

        Args:
            query: Optional substring filter, compared after ``normalize_text``
                against each fact's content and category.
            category: Optional exact category filter.
            limit: Maximum number of results; capped at 50, negatives give [].
            active_only: Exclude forgotten (inactive) facts when True.

        Returns:
            Fact dicts ordered by importance, then most recent update.
        """
        cap = max(0, min(limit, self._MAX_RECALL))
        if cap == 0:
            return []

        stmt = select(MemoryFact).where(MemoryFact.user_id == self.user_id).order_by(
            MemoryFact.importance.desc(),
            MemoryFact.updated_at.desc(),
        )
        if active_only:
            stmt = stmt.where(MemoryFact.active.is_(True))
        if category:
            stmt = stmt.where(MemoryFact.category == category)

        qnorm = normalize_text(query) if query else ""
        if not qnorm:
            # No effective text filter: let the database apply the limit.
            facts = self.db.scalars(stmt.limit(cap)).all()
        else:
            # normalize_text has no SQL equivalent, so the filter runs here.
            # Scan every candidate (stopping once `cap` matches are found) rather
            # than a fixed-size prefix, so lower-ranked matches are still found.
            facts = []
            for fact in self.db.scalars(stmt):
                if qnorm in normalize_text(fact.content) or qnorm in normalize_text(fact.category):
                    facts.append(fact)
                    if len(facts) >= cap:
                        break

        return [self._fact_to_dict(f) for f in facts]

    def forget_memory(self, memory_id: int) -> dict[str, Any]:
        """Soft-delete one of the user's facts (``active = False``) and commit.

        Also schedules ``deactivate_memory_fact(memory_id)`` in the background.

        Raises:
            ValueError: If the fact does not exist or belongs to another user.
        """
        fact = self.db.get(MemoryFact, memory_id)
        if not fact or fact.user_id != self.user_id:
            raise ValueError(f"Память #{memory_id} не найдена")
        fact.active = False
        fact.updated_at = datetime.now(timezone.utc)
        self.db.commit()

        from app.rag.ingest import deactivate_memory_fact

        self._schedule_rag(deactivate_memory_fact(memory_id))
        return {"ok": True, "memory_id": memory_id, "forgotten": fact.content}

    def get_active_facts(self, limit: int = 25) -> list[MemoryFact]:
        """Return up to ``limit`` active facts, most important / most recent first."""
        return list(
            self.db.scalars(
                select(MemoryFact)
                .where(MemoryFact.user_id == self.user_id, MemoryFact.active.is_(True))
                .order_by(MemoryFact.importance.desc(), MemoryFact.updated_at.desc())
                .limit(limit)
            ).all()
        )

    def _count_active_facts(self) -> int:
        """Count this user's active facts with SQL COUNT (no rows loaded, no cap)."""
        from sqlalchemy import func

        stmt = (
            select(func.count())
            .select_from(MemoryFact)
            .where(MemoryFact.user_id == self.user_id, MemoryFact.active.is_(True))
        )
        return int(self.db.scalar(stmt) or 0)

    # ------------------------------------------------------------------
    # Session summaries
    # ------------------------------------------------------------------

    def get_session_summary(self, session_id: int) -> SessionSummary | None:
        """Return the stored summary of one of the user's sessions, or None."""
        if self._owned_session(session_id) is None:
            return None
        return self.db.scalar(
            select(SessionSummary).where(SessionSummary.session_id == session_id)
        )

    def update_session_summary(
        self,
        session_id: int,
        summary: str,
        *,
        message_count: int = 0,
    ) -> dict[str, Any]:
        """Create or replace a session's summary (truncated to 4000 chars) and commit.

        Also schedules ``index_session_summary`` in the background.

        Raises:
            ValueError: If ``summary`` is blank, or the session does not exist
                or belongs to another user.
        """
        text = summary.strip()
        if not text:
            raise ValueError("Пустая сводка")
        if self._owned_session(session_id) is None:
            raise ValueError("Session not found")

        row = self.db.scalar(
            select(SessionSummary).where(SessionSummary.session_id == session_id)
        )
        if not row:
            row = SessionSummary(session_id=session_id)
            self.db.add(row)

        # Keep the stored text locally so nothing reads expired attributes after commit.
        stored = text[: self._MAX_SUMMARY_LEN]
        row.summary = stored
        row.message_count = message_count
        row.updated_at = datetime.now(timezone.utc)
        self.db.commit()

        from app.rag.ingest import index_session_summary

        self._schedule_rag(index_session_summary(session_id, stored))
        return {"ok": True, "session_id": session_id, "summary": stored}

    # ------------------------------------------------------------------
    # Context snapshot
    # ------------------------------------------------------------------

    def _rag_facts(self, svc: Any, query: str) -> list[dict[str, Any]]:
        """Best-effort semantic retrieval of facts for ``query``; [] on any failure.

        ``rag_top_k`` is read here on the calling thread, so the settings service
        (built on ``self.db``) is never touched from the worker thread that
        ``_run_async`` may hand the coroutine to.
        """
        try:
            from app.rag.retriever import retrieve_memory_facts

            top_k = int(svc.get_effective("rag_top_k"))
            found = self._run_async(
                retrieve_memory_facts(query, user_id=self.user_id, top_k=top_k)
            )
        except Exception:
            # Retrieval is an optimization; the caller falls back to plain facts.
            self._logger().warning(
                "RAG memory retrieval failed; falling back to plain facts", exc_info=True
            )
            return []
        return found or []

    def snapshot(self, session_id: int | None = None, query: str | None = None) -> dict[str, Any]:
        """Collect what the chat context needs about this user.

        Facts come from RAG retrieval when both ``svc.get_effective("rag_enabled")``
        and ``settings.rag_enabled`` are truthy and ``query`` is non-blank.
        Otherwise — or when retrieval fails or returns nothing — the top
        ``settings.memory_facts_in_context`` active facts are used.

        Returns:
            Dict with ``profile``, ``facts``, ``session_summary`` ("" when there
            is none) and ``total_facts`` (exact count of active facts).
        """
        from app.config import get_settings
        from app.settings.service import SettingsService

        settings = get_settings()
        svc = SettingsService(self.db)
        rag_on = bool(svc.get_effective("rag_enabled")) and settings.rag_enabled
        total_facts = self._count_active_facts()

        facts_payload: list[dict[str, Any]] = []
        if rag_on and (query or "").strip():
            facts_payload = self._rag_facts(svc, query or "")
        if not facts_payload:
            facts_payload = [
                self._fact_to_dict(f)
                for f in self.get_active_facts(limit=settings.memory_facts_in_context)
            ]

        summary_row = self.get_session_summary(session_id) if session_id else None
        return {
            "profile": self.get_profile(),
            "facts": facts_payload,
            "session_summary": summary_row.summary if summary_row else "",
            "total_facts": total_facts,
        }
|