# Source file: Home_assistant/backend/app/memory/service.py
# Viewer snapshot: 2026-06-16 10:07:06 +03:00 · 330 lines · 12 KiB · Python

import asyncio
import json
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.db.models import MemoryFact, SessionSummary, UserProfile
from app.memory.parse import normalize_text, parse_identity, texts_are_similar
# Module-level logger; used to report failures of background RAG tasks.
logger = logging.getLogger(__name__)

# Baseline profile. Callers receive a copy of this mapping with any stored
# per-user values merged on top, so every key listed here is always present
# in a profile read. Key order is kept stable because the mapping is
# serialised back to JSON.
DEFAULT_PROFILE: dict[str, Any] = {
    "name": "",
    "age": "",
    "timezone": "",
    "language": "ru",
    "notes": "",
}
class MemoryService:
    """Per-user store for the profile, remembered facts and session summaries.

    All reads and writes are scoped to ``user_id`` and go through the supplied
    SQLAlchemy session; mutating methods commit on that session. Fact and
    summary changes are additionally handed to the ``app.rag.ingest``
    coroutines, which are executed on background threads.
    """

    # Upper bounds applied to strings before they are stored.
    _MAX_CONTENT_LEN = 2000
    _MAX_CATEGORY_LEN = 64
    _MAX_SOURCE_LEN = 32
    _MAX_SUMMARY_LEN = 4000
    # Largest number of facts ``recall_memories`` will ever return.
    _MAX_RECALL = 50

    def __init__(self, db: Session, user_id: int):
        """Bind the service to a database session and to a single user."""
        self.db = db
        self.user_id = user_id

    # ------------------------------------------------------------------
    # Async / RAG plumbing
    # ------------------------------------------------------------------

    @staticmethod
    def _run_async(coro):
        """Run coroutine from sync code; safe inside FastAPI's running event loop.

        Returns the coroutine's result and propagates any exception it raises.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so asyncio.run may create one.
            return asyncio.run(coro)
        # A loop is already running in this thread and asyncio.run refuses to
        # nest, so drive the coroutine on a fresh loop in a worker thread and
        # block until it finishes.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    @staticmethod
    def _schedule_rag(coro) -> None:
        """Run ``coro`` on a daemon thread without waiting for it.

        Failures are logged and never reach the caller: by the time this is
        called the database change has already been committed.
        """

        def runner() -> None:
            try:
                asyncio.run(coro)
            except Exception:
                logger.exception("RAG background task failed")

        threading.Thread(target=runner, daemon=True).start()

    @staticmethod
    def _rag_fact_payload(fact: MemoryFact) -> dict[str, Any]:
        """Build the keyword arguments passed to ``index_memory_fact``.

        Values are copied into plain Python types so the background thread
        does not need to touch the ORM object.
        """
        return {
            "fact_id": int(fact.id),
            "user_id": int(fact.user_id),
            "content": fact.content,
            "category": fact.category,
            "importance": int(fact.importance),
            "active": bool(fact.active),
        }

    @staticmethod
    def _fact_to_dict(fact: MemoryFact) -> dict[str, Any]:
        """Serialise a fact into the dict shape returned by the public API."""
        return {
            "id": fact.id,
            "category": fact.category,
            "content": fact.content,
            "importance": fact.importance,
            "source": fact.source,
            "updated_at": fact.updated_at.isoformat() if fact.updated_at else None,
        }

    @staticmethod
    def _clamp_importance(importance: int) -> int:
        """Clamp ``importance`` into the inclusive range 1..5."""
        return min(5, max(1, importance))

    # ------------------------------------------------------------------
    # Profile
    # ------------------------------------------------------------------

    def get_profile(self) -> dict[str, Any]:
        """Return the user's profile with ``DEFAULT_PROFILE`` as the baseline.

        Stored JSON that is malformed, or that does not decode to an object,
        is ignored and the defaults are returned instead.
        """
        row = self.db.scalar(
            select(UserProfile).where(UserProfile.user_id == self.user_id).limit(1)
        )
        profile = dict(DEFAULT_PROFILE)
        if not row:
            return profile
        try:
            stored = json.loads(row.data_json or "{}")
        except json.JSONDecodeError:
            stored = {}
        # json.loads may legally yield None, a list or a scalar; merging
        # anything other than a mapping would raise inside dict.update.
        if isinstance(stored, dict):
            profile.update(stored)
        return profile

    def update_profile(self, updates: dict[str, Any]) -> dict[str, Any]:
        """Apply ``updates`` to the stored profile and commit.

        A value of ``None`` removes the key from the stored profile. Keys that
        also exist in ``DEFAULT_PROFILE`` come back with their default value
        on the next ``get_profile`` call.

        Returns ``{"ok": True, "profile": <profile as stored>}``.
        """
        row = self.db.scalar(
            select(UserProfile).where(UserProfile.user_id == self.user_id).limit(1)
        )
        if not row:
            row = UserProfile(user_id=self.user_id, data_json="{}")
            self.db.add(row)
            self.db.flush()
        current = self.get_profile()
        for key, value in updates.items():
            if value is None:
                current.pop(key, None)
            else:
                current[key] = value
        row.data_json = json.dumps(current, ensure_ascii=False)
        row.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        return {"ok": True, "profile": current}

    # ------------------------------------------------------------------
    # Facts
    # ------------------------------------------------------------------

    def _find_similar_fact(self, text: str) -> MemoryFact | None:
        """Return the first active fact that ``texts_are_similar`` matches to ``text``."""
        active_facts = select(MemoryFact).where(
            MemoryFact.user_id == self.user_id, MemoryFact.active.is_(True)
        )
        for fact in self.db.scalars(active_facts):
            if texts_are_similar(fact.content, text):
                return fact
        return None

    def _sync_identity_to_profile(self, text: str) -> dict[str, Any] | None:
        """Copy whatever ``parse_identity`` extracts from ``text`` into the profile.

        Returns the ``update_profile`` result, or ``None`` when nothing was
        extracted.
        """
        parsed = parse_identity(text)
        if not parsed:
            return None
        return self.update_profile(parsed)

    def remember_fact(
        self,
        content: str,
        *,
        category: str = "fact",
        source: str = "user",
        session_id: int | None = None,
        importance: int = 3,
    ) -> dict[str, Any]:
        """Store a fact, or refresh an existing similar one.

        When an active fact similar to ``content`` already exists it is
        updated in place: the longer text wins, importance only ever grows,
        and the category/session are overwritten when provided. Otherwise a
        new fact is created. Either way the fact is scheduled for RAG
        indexing after the commit.

        Args:
            content: Fact text; surrounding whitespace is stripped and the
                text is truncated to 2000 characters.
            category: Category label, truncated to 64 characters.
            source: Origin label, truncated to 32 characters (new facts only).
            session_id: Chat session to associate with the fact, if any.
            importance: Priority, clamped to 1..5.

        Returns:
            Dict with ``ok``, ``action`` (``"updated"`` or ``"created"``),
            ``memory_id``, ``content``, ``category`` and, when identity data
            was synced, ``profile``.

        Raises:
            ValueError: If ``content`` is empty after stripping.
        """
        text = content.strip()
        if not text:
            raise ValueError("Пустой факт")
        profile_sync = self._sync_identity_to_profile(text)
        level = self._clamp_importance(importance)

        fact = self._find_similar_fact(text)
        if fact:
            action = "updated"
            if len(text) > len(fact.content):
                fact.content = text[: self._MAX_CONTENT_LEN]
            new_category = category or fact.category
            if new_category:
                # Same length cap as the create path below.
                fact.category = new_category[: self._MAX_CATEGORY_LEN]
            fact.importance = max(fact.importance, level)
            fact.updated_at = datetime.now(timezone.utc)
            if session_id:
                fact.session_id = session_id
            self.db.commit()
        else:
            action = "created"
            fact = MemoryFact(
                user_id=self.user_id,
                category=(category or "fact")[: self._MAX_CATEGORY_LEN],
                content=text[: self._MAX_CONTENT_LEN],
                source=source[: self._MAX_SOURCE_LEN],
                session_id=session_id,
                importance=level,
            )
            self.db.add(fact)
            self.db.commit()
            self.db.refresh(fact)

        from app.rag.ingest import index_memory_fact

        self._schedule_rag(index_memory_fact(**self._rag_fact_payload(fact)))
        result: dict[str, Any] = {
            "ok": True,
            "action": action,
            "memory_id": fact.id,
            "content": fact.content,
            "category": fact.category,
        }
        if profile_sync:
            result["profile"] = profile_sync.get("profile")
        return result

    def recall_memories(
        self,
        *,
        query: str | None = None,
        category: str | None = None,
        limit: int = 20,
        active_only: bool = True,
    ) -> list[dict[str, Any]]:
        """Return the user's facts, most important and most recent first.

        Args:
            query: Optional substring; compared after ``normalize_text``
                against each fact's content and category.
            category: Optional exact category filter.
            limit: Maximum number of results; clamped to 0..50.
            active_only: When true, forgotten facts are excluded.

        Returns:
            A list of serialised facts (see ``_fact_to_dict``).
        """
        stmt = select(MemoryFact).where(MemoryFact.user_id == self.user_id).order_by(
            MemoryFact.importance.desc(),
            MemoryFact.updated_at.desc(),
        )
        if active_only:
            stmt = stmt.where(MemoryFact.active.is_(True))
        if category:
            stmt = stmt.where(MemoryFact.category == category)

        # A negative limit would otherwise slice from the wrong end.
        cap = max(0, min(limit, self._MAX_RECALL))

        if not query:
            facts = self.db.scalars(stmt.limit(cap)).all()
            return [self._fact_to_dict(f) for f in facts]

        # The text filter runs in Python, so a SQL row cap applied first
        # would hide matches ranked below it. Scan in ranking order and
        # stop as soon as enough matches have been collected.
        needle = normalize_text(query)
        matches: list[MemoryFact] = []
        if cap:
            for fact in self.db.scalars(stmt):
                if needle in normalize_text(fact.content) or needle in normalize_text(
                    fact.category
                ):
                    matches.append(fact)
                    if len(matches) >= cap:
                        break
        return [self._fact_to_dict(f) for f in matches]

    def forget_memory(self, memory_id: int) -> dict[str, Any]:
        """Deactivate a fact (soft delete) and schedule its RAG deactivation.

        Raises:
            ValueError: If the fact does not exist or belongs to another user.
        """
        fact = self.db.get(MemoryFact, memory_id)
        if not fact or fact.user_id != self.user_id:
            raise ValueError(f"Память #{memory_id} не найдена")
        fact.active = False
        fact.updated_at = datetime.now(timezone.utc)
        self.db.commit()

        from app.rag.ingest import deactivate_memory_fact

        self._schedule_rag(deactivate_memory_fact(memory_id))
        return {"ok": True, "memory_id": memory_id, "forgotten": fact.content}

    def get_active_facts(self, limit: int = 25) -> list[MemoryFact]:
        """Return up to ``limit`` active facts, most important and recent first."""
        stmt = (
            select(MemoryFact)
            .where(MemoryFact.user_id == self.user_id, MemoryFact.active.is_(True))
            .order_by(MemoryFact.importance.desc(), MemoryFact.updated_at.desc())
            .limit(limit)
        )
        return list(self.db.scalars(stmt).all())

    def _count_active_facts(self) -> int:
        """Count the user's active facts with a single aggregate query."""
        from sqlalchemy import func

        total = self.db.scalar(
            select(func.count(MemoryFact.id)).where(
                MemoryFact.user_id == self.user_id, MemoryFact.active.is_(True)
            )
        )
        return int(total or 0)

    # ------------------------------------------------------------------
    # Session summaries
    # ------------------------------------------------------------------

    def _get_owned_session(self, session_id: int):
        """Return the chat session if it exists and belongs to this user, else ``None``."""
        from app.db.models import ChatSession

        session = self.db.get(ChatSession, session_id)
        if not session or session.user_id != self.user_id:
            return None
        return session

    def get_session_summary(self, session_id: int) -> SessionSummary | None:
        """Return the stored summary for one of the user's sessions.

        ``None`` is returned when the session is missing, belongs to another
        user, or has no summary yet.
        """
        if self._get_owned_session(session_id) is None:
            return None
        return self.db.scalar(
            select(SessionSummary).where(SessionSummary.session_id == session_id)
        )

    def update_session_summary(
        self,
        session_id: int,
        summary: str,
        *,
        message_count: int = 0,
    ) -> dict[str, Any]:
        """Create or replace the summary of one of the user's sessions.

        The summary is stripped, truncated to 4000 characters, committed and
        then scheduled for RAG indexing.

        Raises:
            ValueError: If ``summary`` is empty after stripping, or the
                session does not exist or belongs to another user.
        """
        text = summary.strip()
        if not text:
            raise ValueError("Пустая сводка")
        if self._get_owned_session(session_id) is None:
            raise ValueError("Session not found")

        row = self.db.scalar(
            select(SessionSummary).where(SessionSummary.session_id == session_id)
        )
        if not row:
            row = SessionSummary(session_id=session_id)
            self.db.add(row)
        row.summary = text[: self._MAX_SUMMARY_LEN]
        row.message_count = message_count
        row.updated_at = datetime.now(timezone.utc)
        self.db.commit()

        from app.rag.ingest import index_session_summary

        self._schedule_rag(index_session_summary(session_id, row.summary))
        return {"ok": True, "session_id": session_id, "summary": row.summary}

    # ------------------------------------------------------------------
    # Context snapshot
    # ------------------------------------------------------------------

    def _retrieve_rag_facts(self, query: str, svc: Any) -> list[dict[str, Any]]:
        """Fetch facts relevant to ``query`` from the RAG retriever, best effort.

        Any failure is logged and reported as an empty list so the caller can
        fall back to the stored facts.
        """
        try:
            from app.rag.retriever import retrieve_memory_facts

            # Resolve the setting on the calling thread: the coroutine below
            # may execute on a worker thread (see _run_async), and SQLAlchemy
            # sessions must not be used from more than one thread.
            top_k = int(svc.get_effective("rag_top_k"))

            async def _load() -> list[dict[str, Any]]:
                return await retrieve_memory_facts(
                    query, user_id=self.user_id, top_k=top_k
                )

            return self._run_async(_load())
        except Exception:
            logger.exception("RAG retrieval failed; falling back to stored facts")
            return []

    def snapshot(self, session_id: int | None = None, query: str | None = None) -> dict[str, Any]:
        """Assemble the memory context for a conversation turn.

        When RAG is enabled (both in the effective settings and in the app
        config) and ``query`` is non-blank, facts come from the retriever;
        if that yields nothing, or RAG is off, the top stored facts are used.

        Returns:
            Dict with ``profile``, ``facts``, ``session_summary`` (empty
            string when unavailable) and ``total_facts`` (number of active
            facts for the user).
        """
        from app.config import get_settings
        from app.settings.service import SettingsService

        settings = get_settings()
        svc = SettingsService(self.db)
        rag_on = bool(svc.get_effective("rag_enabled")) and settings.rag_enabled
        total_facts = self._count_active_facts()

        facts_payload: list[dict[str, Any]] = []
        if rag_on and (query or "").strip():
            facts_payload = self._retrieve_rag_facts(query or "", svc)
        if not facts_payload:
            stored = self.get_active_facts(limit=settings.memory_facts_in_context)
            facts_payload = [self._fact_to_dict(f) for f in stored]

        summary_row = self.get_session_summary(session_id) if session_id else None
        return {
            "profile": self.get_profile(),
            "facts": facts_payload,
            "session_summary": summary_row.summary if summary_row else "",
            "total_facts": total_facts,
        }