import asyncio
import json
import logging
import threading
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from app.db.models import MemoryFact, SessionSummary, UserProfile
from app.memory.parse import normalize_text, parse_identity, texts_are_similar

logger = logging.getLogger(__name__)

# Keys always present in the dict returned by MemoryService.get_profile();
# values stored in the user's profile row override these.
DEFAULT_PROFILE: dict[str, Any] = {
    "name": "",
    "age": "",
    "timezone": "",
    "language": "ru",
    "notes": "",
}

# Length caps applied (by truncation) before values are written.
_MAX_FACT_CONTENT = 2000
_MAX_FACT_CATEGORY = 64
_MAX_FACT_SOURCE = 32
_MAX_SUMMARY = 4000
# Hard upper bound on how many facts recall_memories() returns.
_MAX_RECALL_RESULTS = 50


def _clamp_importance(value: int) -> int:
    """Clamp an importance score into the inclusive range 1..5."""
    return min(5, max(1, value))


def _fact_to_dict(fact: MemoryFact) -> dict[str, Any]:
    """Serialize a MemoryFact into the plain-dict shape returned to callers."""
    return {
        "id": fact.id,
        "category": fact.category,
        "content": fact.content,
        "importance": fact.importance,
        "source": fact.source,
        "updated_at": fact.updated_at.isoformat() if fact.updated_at else None,
    }


class MemoryService:
    """Per-user access to long-term memory: profile, facts and session summaries.

    Every query is scoped to ``user_id``. Writes are committed immediately on
    ``db``; fact/summary changes are then handed to the ``app.rag.ingest``
    helpers, which run in a background thread (see ``_schedule_rag``).
    """

    def __init__(self, db: Session, user_id: int):
        self.db = db
        self.user_id = user_id

    @staticmethod
    def _schedule_rag(coro) -> None:
        """Run ``coro`` on a fresh event loop in a daemon thread (fire-and-forget).

        The caller does not wait for completion; any exception raised by the
        coroutine is logged rather than propagated.
        """
        # NOTE(review): the coroutine may receive ORM objects bound to the
        # caller's Session and touch them from this thread. Sessions are not
        # thread-safe; consider passing plain values to the ingest layer.

        def runner() -> None:
            try:
                asyncio.run(coro)
            except Exception:
                logger.exception("Background RAG task failed")

        threading.Thread(target=runner, daemon=True).start()

    # ------------------------------------------------------------------ profile

    def _get_profile_row(self) -> UserProfile | None:
        """Return this user's UserProfile row, or None if none exists yet."""
        return self.db.scalar(
            select(UserProfile).where(UserProfile.user_id == self.user_id).limit(1)
        )

    def get_profile(self) -> dict[str, Any]:
        """Return the stored profile merged over DEFAULT_PROFILE.

        Stored JSON that is corrupt, or valid but not a JSON object, is
        treated as an empty profile.
        """
        merged = dict(DEFAULT_PROFILE)
        row = self._get_profile_row()
        if row is None:
            return merged
        try:
            data = json.loads(row.data_json or "{}")
        except json.JSONDecodeError:
            data = {}
        # e.g. "null" or "[]" decode fine but would make dict.update() raise.
        if isinstance(data, dict):
            merged.update(data)
        return merged

    def update_profile(self, updates: dict[str, Any]) -> dict[str, Any]:
        """Apply ``updates`` to the user's profile and commit.

        A ``None`` value removes that key from the stored profile (any default
        from DEFAULT_PROFILE reappears on the next ``get_profile``). The
        profile row is created on first use.

        Returns:
            ``{"ok": True, "profile": <updated profile dict>}``.
        """
        row = self._get_profile_row()
        if row is None:
            row = UserProfile(user_id=self.user_id, data_json="{}")
            self.db.add(row)
            self.db.flush()
        current = self.get_profile()
        for key, value in updates.items():
            if value is None:
                current.pop(key, None)
            else:
                current[key] = value
        row.data_json = json.dumps(current, ensure_ascii=False)
        row.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        return {"ok": True, "profile": current}

    # -------------------------------------------------------------------- facts

    def _find_similar_fact(self, text: str) -> MemoryFact | None:
        """Return the first active fact whose content ``texts_are_similar`` to ``text``."""
        for fact in self.db.scalars(
            select(MemoryFact).where(
                MemoryFact.user_id == self.user_id, MemoryFact.active.is_(True)
            )
        ):
            if texts_are_similar(fact.content, text):
                return fact
        return None

    def _sync_identity_to_profile(self, text: str) -> dict[str, Any] | None:
        """Copy identity fields recognised by ``parse_identity`` into the profile.

        Returns the ``update_profile`` result, or None when nothing was parsed.
        """
        parsed = parse_identity(text)
        if not parsed:
            return None
        return self.update_profile(parsed)

    def remember_fact(
        self,
        content: str,
        *,
        category: str = "fact",
        source: str = "user",
        session_id: int | None = None,
        importance: int = 3,
    ) -> dict[str, Any]:
        """Store a fact, merging it into a similar existing fact when one exists.

        Args:
            content: Fact text; surrounding whitespace is stripped.
            category: Category label (truncated to 64 chars; falsy keeps the
                existing category on merge, or uses "fact" on create).
            source: Origin label for new facts (truncated to 32 chars).
            session_id: Chat session to associate with the fact, if any.
            importance: Priority, clamped to 1..5. On merge the higher of the
                stored and new importance is kept.

        Returns:
            Dict with ``ok``, ``action`` ("created"/"updated"), ``memory_id``,
            ``content``, ``category`` and, if the text also updated the
            profile, ``profile``.

        Raises:
            ValueError: if ``content`` is blank.
        """
        text = content.strip()
        if not text:
            raise ValueError("Пустой факт")

        # Identity statements also update the structured profile; that update
        # is committed on its own, before the fact is written.
        profile_sync = self._sync_identity_to_profile(text)

        existing = self._find_similar_fact(text)
        if existing:
            # Merge instead of storing a near-duplicate: keep the longer
            # wording and the higher importance.
            if len(text) > len(existing.content):
                existing.content = text[:_MAX_FACT_CONTENT]
            if category:
                existing.category = category[:_MAX_FACT_CATEGORY]
            existing.importance = max(existing.importance, _clamp_importance(importance))
            existing.updated_at = datetime.now(timezone.utc)
            if session_id:
                existing.session_id = session_id
            fact = existing
            action = "updated"
        else:
            fact = MemoryFact(
                user_id=self.user_id,
                category=(category or "fact")[:_MAX_FACT_CATEGORY],
                content=text[:_MAX_FACT_CONTENT],
                source=source[:_MAX_FACT_SOURCE],
                session_id=session_id,
                importance=_clamp_importance(importance),
            )
            self.db.add(fact)
            action = "created"

        self.db.commit()
        # Reload all columns now (commit expires them) so neither the result
        # below nor the background indexer has to lazy-load them later.
        self.db.refresh(fact)

        result: dict[str, Any] = {
            "ok": True,
            "action": action,
            "memory_id": fact.id,
            "content": fact.content,
            "category": fact.category,
        }
        if profile_sync:
            result["profile"] = profile_sync.get("profile")

        from app.rag.ingest import index_memory_fact

        self._schedule_rag(index_memory_fact(fact))
        return result

    def recall_memories(
        self,
        *,
        query: str | None = None,
        category: str | None = None,
        limit: int = 20,
        active_only: bool = True,
    ) -> list[dict[str, Any]]:
        """List the user's facts, highest importance then most recent first.

        Args:
            query: Optional substring (compared after ``normalize_text``)
                matched against each fact's content or category. All of the
                user's matching facts are searched, not just the top-ranked.
            category: Exact category filter.
            limit: Maximum results; clamped to 0..50.
            active_only: Exclude forgotten (inactive) facts.
        """
        cap = max(0, min(limit, _MAX_RECALL_RESULTS))
        if cap == 0:
            return []

        stmt = (
            select(MemoryFact)
            .where(MemoryFact.user_id == self.user_id)
            .order_by(MemoryFact.importance.desc(), MemoryFact.updated_at.desc())
        )
        if active_only:
            stmt = stmt.where(MemoryFact.active.is_(True))
        if category:
            stmt = stmt.where(MemoryFact.category == category)

        if query:
            qnorm = normalize_text(query)
            facts = [
                f
                for f in self.db.scalars(stmt).all()
                if qnorm in normalize_text(f.content) or qnorm in normalize_text(f.category)
            ][:cap]
        else:
            facts = self.db.scalars(stmt.limit(cap)).all()
        return [_fact_to_dict(f) for f in facts]

    def forget_memory(self, memory_id: int) -> dict[str, Any]:
        """Deactivate one of the user's facts (soft delete) and commit.

        Raises:
            ValueError: if the fact does not exist or belongs to another user.
        """
        fact = self.db.get(MemoryFact, memory_id)
        if fact is None or fact.user_id != self.user_id:
            raise ValueError(f"Память #{memory_id} не найдена")
        forgotten = fact.content
        fact.active = False
        fact.updated_at = datetime.now(timezone.utc)
        self.db.commit()

        from app.rag.ingest import deactivate_memory_fact

        self._schedule_rag(deactivate_memory_fact(memory_id))
        return {"ok": True, "memory_id": memory_id, "forgotten": forgotten}

    def get_active_facts(self, limit: int = 25) -> list[MemoryFact]:
        """Return up to ``limit`` active facts, highest importance then newest first."""
        return list(
            self.db.scalars(
                select(MemoryFact)
                .where(MemoryFact.user_id == self.user_id, MemoryFact.active.is_(True))
                .order_by(MemoryFact.importance.desc(), MemoryFact.updated_at.desc())
                .limit(limit)
            ).all()
        )

    def _count_active_facts(self) -> int:
        """Return the exact number of this user's active facts (SQL COUNT)."""
        count = self.db.scalar(
            select(func.count())
            .select_from(MemoryFact)
            .where(MemoryFact.user_id == self.user_id, MemoryFact.active.is_(True))
        )
        return int(count or 0)

    # ---------------------------------------------------------------- summaries

    def _get_owned_session(self, session_id: int) -> Any:
        """Return the ChatSession ``session_id`` if it belongs to this user, else None."""
        from app.db.models import ChatSession

        chat_session = self.db.get(ChatSession, session_id)
        if chat_session is None or chat_session.user_id != self.user_id:
            return None
        return chat_session

    def get_session_summary(self, session_id: int) -> SessionSummary | None:
        """Return the summary row for one of the user's sessions, or None."""
        if self._get_owned_session(session_id) is None:
            return None
        return self.db.scalar(
            select(SessionSummary).where(SessionSummary.session_id == session_id)
        )

    def update_session_summary(
        self,
        session_id: int,
        summary: str,
        *,
        message_count: int = 0,
    ) -> dict[str, Any]:
        """Create or replace the summary of one of the user's sessions and commit.

        The stripped summary is truncated to 4000 characters.

        Raises:
            ValueError: if ``summary`` is blank or the session is not the user's.
        """
        text = summary.strip()
        if not text:
            raise ValueError("Пустая сводка")
        if self._get_owned_session(session_id) is None:
            raise ValueError("Session not found")

        row = self.db.scalar(
            select(SessionSummary).where(SessionSummary.session_id == session_id)
        )
        if row is None:
            row = SessionSummary(session_id=session_id)
            self.db.add(row)
        stored = text[:_MAX_SUMMARY]
        row.summary = stored
        row.message_count = message_count
        row.updated_at = datetime.now(timezone.utc)
        self.db.commit()

        from app.rag.ingest import index_session_summary

        self._schedule_rag(index_session_summary(session_id, stored))
        return {"ok": True, "session_id": session_id, "summary": stored}

    # ----------------------------------------------------------------- snapshot

    def _retrieve_rag_facts(self, svc: Any, query: str) -> list[dict[str, Any]]:
        """Fetch query-relevant facts via RAG; returns [] on any failure."""

        async def _load() -> list[dict[str, Any]]:
            # Imported lazily so an import failure is handled like any other
            # retrieval error below.
            from app.rag.retriever import retrieve_memory_facts

            top_k = int(svc.get_effective("rag_top_k"))
            return await retrieve_memory_facts(query, user_id=self.user_id, top_k=top_k)

        try:
            return asyncio.run(_load()) or []
        except Exception:
            # Best effort: the caller falls back to the top-ranked facts.
            logger.warning(
                "RAG memory retrieval failed; falling back to active facts",
                exc_info=True,
            )
            return []

    def snapshot(self, session_id: int | None = None, query: str | None = None) -> dict[str, Any]:
        """Assemble the memory context: profile, facts, session summary, fact count.

        When RAG is enabled (both the effective setting ``rag_enabled`` and
        ``settings.rag_enabled``) and ``query`` is non-blank, facts come from
        RAG retrieval; if that yields nothing, the top
        ``settings.memory_facts_in_context`` active facts are used instead.
        ``total_facts`` is the exact count of the user's active facts.
        """
        from app.config import get_settings
        from app.settings.service import SettingsService

        settings = get_settings()
        svc = SettingsService(self.db)
        rag_on = bool(svc.get_effective("rag_enabled")) and settings.rag_enabled

        facts_payload: list[dict[str, Any]] = []
        if rag_on and (query or "").strip():
            facts_payload = self._retrieve_rag_facts(svc, query or "")
        if not facts_payload:
            facts_payload = [
                _fact_to_dict(f)
                for f in self.get_active_facts(limit=settings.memory_facts_in_context)
            ]

        summary_row = self.get_session_summary(session_id) if session_id else None
        return {
            "profile": self.get_profile(),
            "facts": facts_payload,
            "session_summary": summary_row.summary if summary_row else "",
            "total_facts": self._count_active_facts(),
        }