import json
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.db.models import MemoryFact, SessionSummary, UserProfile
from app.memory.parse import normalize_text, parse_identity, texts_are_similar

# Baseline profile fields; the stored profile JSON is layered on top of these.
DEFAULT_PROFILE: dict[str, Any] = {
    "name": "",
    "age": "",
    "timezone": "",
    "language": "ru",
    "notes": "",
}

# Length caps applied before writing text to the database.
# NOTE(review): presumably these mirror column sizes in app.db.models — confirm.
_MAX_CATEGORY_LEN = 64
_MAX_SOURCE_LEN = 32
_MAX_FACT_LEN = 2000
_MAX_SUMMARY_LEN = 4000

# Inclusive range that fact importance is clamped into.
_MIN_IMPORTANCE = 1
_MAX_IMPORTANCE = 5

# Hard upper bound on how many memories recall_memories returns.
_MAX_RECALL = 50


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _clamp_importance(value: int) -> int:
    """Clamp *value* into the inclusive range [_MIN_IMPORTANCE, _MAX_IMPORTANCE]."""
    return min(_MAX_IMPORTANCE, max(_MIN_IMPORTANCE, value))


class MemoryService:
    """Read/write access to the user profile, remembered facts and session summaries.

    Every mutating method commits the SQLAlchemy session passed to the
    constructor before returning.
    """

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------ profile

    def _profile_row(self) -> UserProfile | None:
        """Return the first UserProfile row, or None if none exists yet."""
        return self.db.scalar(select(UserProfile).limit(1))

    def get_profile(self) -> dict[str, Any]:
        """Return the stored profile merged over DEFAULT_PROFILE.

        Stored JSON that fails to parse, or that parses to something other
        than an object (e.g. ``null`` or a list), is treated as empty, so the
        result always contains every DEFAULT_PROFILE key.
        """
        merged = dict(DEFAULT_PROFILE)
        row = self._profile_row()
        if not row:
            return merged
        try:
            data = json.loads(row.data_json or "{}")
        except json.JSONDecodeError:
            data = {}
        # json.loads can legitimately return a list/number/None; only a dict
        # can be merged (dict.update(None) would raise TypeError).
        if isinstance(data, dict):
            merged.update(data)
        return merged

    def update_profile(self, updates: dict[str, Any]) -> dict[str, Any]:
        """Apply *updates* to the stored profile and commit.

        A value of ``None`` removes that key; any other value overwrites it.
        The profile row is created on first use.

        Returns:
            ``{"ok": True, "profile": <profile dict after the update>}``
        """
        row = self._profile_row()
        if not row:
            row = UserProfile(data_json="{}")
            self.db.add(row)
            self.db.flush()
        current = self.get_profile()
        for key, value in updates.items():
            if value is None:
                current.pop(key, None)
            else:
                current[key] = value
        row.data_json = json.dumps(current, ensure_ascii=False)
        row.updated_at = _utcnow()
        self.db.commit()
        return {"ok": True, "profile": current}

    # -------------------------------------------------------------------- facts

    def _find_similar_fact(self, text: str) -> MemoryFact | None:
        """Return the first active fact whose content texts_are_similar to *text*."""
        for fact in self.db.scalars(
            select(MemoryFact).where(MemoryFact.active.is_(True))
        ):
            if texts_are_similar(fact.content, text):
                return fact
        return None

    def _sync_identity_to_profile(self, text: str) -> dict[str, Any] | None:
        """Copy identity data parsed from *text* into the profile.

        Returns the update_profile result, or None when parse_identity finds
        nothing.
        """
        parsed = parse_identity(text)
        if not parsed:
            return None
        return self.update_profile(parsed)

    @staticmethod
    def _fact_result(
        action: str,
        fact: MemoryFact,
        profile_sync: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Build the response dict shared by both remember_fact outcomes."""
        result: dict[str, Any] = {
            "ok": True,
            "action": action,
            "memory_id": fact.id,
            "content": fact.content,
            "category": fact.category,
        }
        if profile_sync:
            result["profile"] = profile_sync.get("profile")
        return result

    def remember_fact(
        self,
        content: str,
        *,
        category: str = "fact",
        source: str = "user",
        session_id: int | None = None,
        importance: int = 3,
    ) -> dict[str, Any]:
        """Store *content* as a fact, merging it into a similar active fact if one exists.

        Identity information in the text is also synced into the profile,
        which commits separately before the fact is written.

        Returns:
            A dict with ``action`` set to ``"updated"`` or ``"created"``, the
            fact's id/content/category, and ``profile`` when the profile changed.

        Raises:
            ValueError: if *content* is empty after stripping.
        """
        text = content.strip()
        if not text:
            raise ValueError("Пустой факт")

        profile_sync = self._sync_identity_to_profile(text)

        existing = self._find_similar_fact(text)
        if existing:
            # Keep the more detailed (longer) wording of the duplicate.
            if len(text) > len(existing.content):
                existing.content = text[:_MAX_FACT_LEN]
            # Truncate like the create path does; an empty category keeps the old one.
            if category:
                existing.category = category[:_MAX_CATEGORY_LEN]
            # Merging never lowers importance.
            existing.importance = max(existing.importance, _clamp_importance(importance))
            existing.updated_at = _utcnow()
            if session_id:
                existing.session_id = session_id
            self.db.commit()
            return self._fact_result("updated", existing, profile_sync)

        fact = MemoryFact(
            category=(category or "fact")[:_MAX_CATEGORY_LEN],
            content=text[:_MAX_FACT_LEN],
            source=source[:_MAX_SOURCE_LEN],
            session_id=session_id,
            importance=_clamp_importance(importance),
        )
        self.db.add(fact)
        self.db.commit()
        self.db.refresh(fact)
        return self._fact_result("created", fact, profile_sync)

    def recall_memories(
        self,
        *,
        query: str | None = None,
        category: str | None = None,
        limit: int = 20,
        active_only: bool = True,
    ) -> list[dict[str, Any]]:
        """Return facts ordered by importance, then recency.

        Args:
            query: optional substring matched (after normalize_text) against
                each fact's content or category.
            category: optional exact category filter.
            limit: maximum results; clamped to the range [0, 50].
            active_only: when True, forgotten facts are excluded.
        """
        cap = max(0, min(limit, _MAX_RECALL))
        if cap == 0:
            return []

        stmt = select(MemoryFact).order_by(
            MemoryFact.importance.desc(),
            MemoryFact.updated_at.desc(),
        )
        if active_only:
            stmt = stmt.where(MemoryFact.active.is_(True))
        if category:
            stmt = stmt.where(MemoryFact.category == category)

        if query:
            # The text match runs in Python, so it must see every candidate row;
            # limiting in SQL first would hide matches ranked below the cut-off.
            qnorm = normalize_text(query)
            facts = [
                f
                for f in self.db.scalars(stmt).all()
                if qnorm in normalize_text(f.content)
                or qnorm in normalize_text(f.category)
            ][:cap]
        else:
            facts = list(self.db.scalars(stmt.limit(cap)).all())

        return [
            {
                "id": f.id,
                "category": f.category,
                "content": f.content,
                "importance": f.importance,
                "source": f.source,
                "updated_at": f.updated_at.isoformat() if f.updated_at else None,
            }
            for f in facts
        ]

    def forget_memory(self, memory_id: int) -> dict[str, Any]:
        """Deactivate fact *memory_id* (soft delete) and commit.

        Raises:
            ValueError: if no fact with that id exists.
        """
        fact = self.db.get(MemoryFact, memory_id)
        if not fact:
            raise ValueError(f"Память #{memory_id} не найдена")
        fact.active = False
        fact.updated_at = _utcnow()
        self.db.commit()
        return {"ok": True, "memory_id": memory_id, "forgotten": fact.content}

    def get_active_facts(self, limit: int = 25) -> list[MemoryFact]:
        """Return up to *limit* active facts ordered by importance, then recency."""
        return list(
            self.db.scalars(
                select(MemoryFact)
                .where(MemoryFact.active.is_(True))
                .order_by(MemoryFact.importance.desc(), MemoryFact.updated_at.desc())
                .limit(limit)
            ).all()
        )

    # --------------------------------------------------------- session summaries

    def get_session_summary(self, session_id: int) -> SessionSummary | None:
        """Return the summary row for *session_id*, or None."""
        return self.db.scalar(
            select(SessionSummary).where(SessionSummary.session_id == session_id)
        )

    def update_session_summary(
        self,
        session_id: int,
        summary: str,
        *,
        message_count: int = 0,
    ) -> dict[str, Any]:
        """Create or overwrite the summary for *session_id* and commit.

        Raises:
            ValueError: if *summary* is empty after stripping.
        """
        text = summary.strip()
        if not text:
            raise ValueError("Пустая сводка")
        row = self.get_session_summary(session_id)
        if not row:
            row = SessionSummary(session_id=session_id)
            self.db.add(row)
        row.summary = text[:_MAX_SUMMARY_LEN]
        row.message_count = message_count
        row.updated_at = _utcnow()
        self.db.commit()
        return {"ok": True, "session_id": session_id, "summary": row.summary}

    def snapshot(self, session_id: int | None = None) -> dict[str, Any]:
        """Return the profile, top active facts and optional session summary in one dict."""
        facts = self.get_active_facts()
        summary_row = self.get_session_summary(session_id) if session_id else None
        return {
            "profile": self.get_profile(),
            "facts": [
                {
                    "id": f.id,
                    "category": f.category,
                    "content": f.content,
                    "importance": f.importance,
                }
                for f in facts
            ],
            "session_summary": summary_row.summary if summary_row else "",
            # NOTE(review): this counts only the facts returned above (capped by
            # get_active_facts' default limit of 25), not all active facts.
            "total_facts": len(facts),
        }