Files
Home_assistant/backend/app/memory/service.py
T
2026-06-10 08:32:20 +03:00

227 lines
7.2 KiB
Python

import json
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.db.models import MemoryFact, SessionSummary, UserProfile
from app.memory.parse import normalize_text, parse_identity, texts_are_similar
# Baseline user profile: the keys every profile exposes and the value each one
# takes when nothing has been stored for it. Stored data is layered on top of
# a copy of this mapping, so it must never be mutated in place.
DEFAULT_PROFILE: dict[str, Any] = {
    "name": "",
    "age": "",
    "timezone": "",
    "language": "ru",
    "notes": "",
}
class MemoryService:
    """Read and write the assistant's long-term memory through a DB session.

    Three kinds of records are handled:

    * the user profile - a JSON document kept in ``UserProfile.data_json``;
    * remembered facts - ``MemoryFact`` rows that can be deactivated;
    * per-session summaries - ``SessionSummary`` rows keyed by session id.

    Every mutating method commits the session itself.
    """

    # Truncation limits applied to values before they are stored.
    _MAX_CONTENT_LEN = 2000
    _MAX_CATEGORY_LEN = 64
    _MAX_SOURCE_LEN = 32
    _MAX_SUMMARY_LEN = 4000
    # Upper bound on the number of facts ``recall_memories`` returns.
    _MAX_RECALL = 50
    # Inclusive range an importance value is clamped to.
    _MIN_IMPORTANCE = 1
    _MAX_IMPORTANCE = 5

    def __init__(self, db: Session):
        """Bind the service to the SQLAlchemy session *db*."""
        self.db = db

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _clamp_importance(cls, importance: int) -> int:
        """Return *importance* limited to the allowed 1..5 range."""
        return min(cls._MAX_IMPORTANCE, max(cls._MIN_IMPORTANCE, importance))

    @staticmethod
    def _decode_profile(row: UserProfile | None) -> dict[str, Any]:
        """Return the defaults overlaid with the JSON stored on *row*.

        A missing row, malformed JSON, or JSON that is not an object all
        yield a plain copy of ``DEFAULT_PROFILE``.
        """
        profile = dict(DEFAULT_PROFILE)
        if row is None:
            return profile
        try:
            stored = json.loads(row.data_json or "{}")
        except json.JSONDecodeError:
            stored = {}
        # Valid JSON is not necessarily an object ("[]", "1", "null", ...);
        # merging anything but a dict would raise.
        if isinstance(stored, dict):
            profile.update(stored)
        return profile

    def _find_similar_fact(self, text: str) -> MemoryFact | None:
        """Return the first active fact whose content is similar to *text*.

        Similarity is decided by ``texts_are_similar``; ``None`` is returned
        when no active fact matches.
        """
        active_facts = self.db.scalars(
            select(MemoryFact).where(MemoryFact.active.is_(True))
        )
        return next(
            (fact for fact in active_facts if texts_are_similar(fact.content, text)),
            None,
        )

    def _sync_identity_to_profile(self, text: str) -> dict[str, Any] | None:
        """Apply whatever ``parse_identity`` extracts from *text* to the profile.

        Returns the ``update_profile`` result, or ``None`` when nothing was
        extracted.
        """
        parsed = parse_identity(text)
        if not parsed:
            return None
        return self.update_profile(parsed)

    # ------------------------------------------------------------------
    # Profile
    # ------------------------------------------------------------------

    def get_profile(self) -> dict[str, Any]:
        """Return the stored profile merged over ``DEFAULT_PROFILE``."""
        row = self.db.scalar(select(UserProfile).limit(1))
        return self._decode_profile(row)

    def update_profile(self, updates: dict[str, Any]) -> dict[str, Any]:
        """Merge *updates* into the profile and commit.

        A value of ``None`` removes the key from the stored document; any
        other value overwrites it. The profile row is created on first use.

        Returns:
            ``{"ok": True, "profile": ...}`` where the profile is what
            ``get_profile`` will report afterwards, i.e. a removed key that
            has a default shows that default.
        """
        row = self.db.scalar(select(UserProfile).limit(1))
        if row is None:
            row = UserProfile(data_json="{}")
            self.db.add(row)

        # Decode the row already in hand instead of querying for it again.
        current = self._decode_profile(row)
        for key, value in updates.items():
            if value is None:
                current.pop(key, None)
            else:
                current[key] = value

        row.data_json = json.dumps(current, ensure_ascii=False)
        row.updated_at = datetime.now(timezone.utc)
        self.db.commit()

        visible = dict(DEFAULT_PROFILE)
        visible.update(current)
        return {"ok": True, "profile": visible}

    # ------------------------------------------------------------------
    # Facts
    # ------------------------------------------------------------------

    def remember_fact(
        self,
        content: str,
        *,
        category: str = "fact",
        source: str = "user",
        session_id: int | None = None,
        importance: int = 3,
    ) -> dict[str, Any]:
        """Store a fact, or refresh an existing similar one.

        The text is first passed through identity parsing so that profile
        fields mentioned in it are synced. If an active fact similar to the
        text exists it is updated in place (longer text wins, importance
        never decreases); otherwise a new fact is created.

        Args:
            content: Fact text; surrounding whitespace is stripped.
            category: Category label, truncated to 64 characters.
            source: Origin label, truncated to 32 characters (new facts only).
            session_id: Session to associate with the fact, if any.
            importance: Priority, clamped to 1..5.

        Returns:
            A dict with ``ok``, ``action`` (``"created"`` or ``"updated"``),
            ``memory_id``, ``content`` and ``category``, plus ``profile``
            when the profile was synced.

        Raises:
            ValueError: If *content* is empty after stripping.
        """
        text = content.strip()
        if not text:
            raise ValueError("Пустой факт")

        profile_sync = self._sync_identity_to_profile(text)
        level = self._clamp_importance(importance)

        fact = self._find_similar_fact(text)
        if fact is not None:
            action = "updated"
            # Keep whichever wording carries more detail.
            if len(text) > len(fact.content):
                fact.content = text[: self._MAX_CONTENT_LEN]
            # Same length limit as on creation.
            fact.category = (category or fact.category)[: self._MAX_CATEGORY_LEN]
            fact.importance = max(fact.importance, level)
            fact.updated_at = datetime.now(timezone.utc)
            # Compare against None so a session id of 0 is not ignored.
            if session_id is not None:
                fact.session_id = session_id
            self.db.commit()
        else:
            action = "created"
            fact = MemoryFact(
                category=(category or "fact")[: self._MAX_CATEGORY_LEN],
                content=text[: self._MAX_CONTENT_LEN],
                source=source[: self._MAX_SOURCE_LEN],
                session_id=session_id,
                importance=level,
            )
            self.db.add(fact)
            self.db.commit()
            # Reload so database-generated values such as the id are present.
            self.db.refresh(fact)

        result: dict[str, Any] = {
            "ok": True,
            "action": action,
            "memory_id": fact.id,
            "content": fact.content,
            "category": fact.category,
        }
        if profile_sync:
            result["profile"] = profile_sync.get("profile")
        return result

    def recall_memories(
        self,
        *,
        query: str | None = None,
        category: str | None = None,
        limit: int = 20,
        active_only: bool = True,
    ) -> list[dict[str, Any]]:
        """Return stored facts, most important and most recent first.

        Args:
            query: Optional substring; matched after ``normalize_text``
                against each fact's content and category.
            category: Optional exact category filter.
            limit: Maximum number of facts; clamped to the range 0..50.
            active_only: Skip deactivated facts when true.

        Returns:
            A list of dicts with ``id``, ``category``, ``content``,
            ``importance``, ``source`` and ``updated_at`` (ISO 8601 string or
            ``None``).
        """
        stmt = select(MemoryFact).order_by(
            MemoryFact.importance.desc(),
            MemoryFact.updated_at.desc(),
        )
        if active_only:
            stmt = stmt.where(MemoryFact.active.is_(True))
        if category:
            stmt = stmt.where(MemoryFact.category == category)

        # A negative limit must not turn into a negative slice/LIMIT.
        cap = max(0, min(limit, self._MAX_RECALL))
        if cap == 0:
            return []

        if query:
            # The text filter runs in Python, so the row cap has to be applied
            # to the matches rather than to the rows scanned; otherwise a
            # match ranked below the first N rows would never be found.
            needle = normalize_text(query)
            facts = []
            for fact in self.db.scalars(stmt):
                if needle in normalize_text(fact.content) or needle in normalize_text(
                    fact.category
                ):
                    facts.append(fact)
                    if len(facts) >= cap:
                        break
        else:
            facts = self.db.scalars(stmt.limit(cap)).all()

        return [
            {
                "id": fact.id,
                "category": fact.category,
                "content": fact.content,
                "importance": fact.importance,
                "source": fact.source,
                "updated_at": fact.updated_at.isoformat() if fact.updated_at else None,
            }
            for fact in facts
        ]

    def forget_memory(self, memory_id: int) -> dict[str, Any]:
        """Deactivate the fact with id *memory_id* (the row is kept).

        Returns:
            ``{"ok": True, "memory_id": ..., "forgotten": <content>}``.

        Raises:
            ValueError: If no fact with that id exists.
        """
        fact = self.db.get(MemoryFact, memory_id)
        if fact is None:
            raise ValueError(f"Память #{memory_id} не найдена")
        fact.active = False
        fact.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        return {"ok": True, "memory_id": memory_id, "forgotten": fact.content}

    def get_active_facts(self, limit: int = 25) -> list[MemoryFact]:
        """Return up to *limit* active facts, most important and recent first."""
        stmt = (
            select(MemoryFact)
            .where(MemoryFact.active.is_(True))
            .order_by(MemoryFact.importance.desc(), MemoryFact.updated_at.desc())
            .limit(limit)
        )
        return list(self.db.scalars(stmt).all())

    # ------------------------------------------------------------------
    # Session summaries
    # ------------------------------------------------------------------

    def get_session_summary(self, session_id: int) -> SessionSummary | None:
        """Return the summary row for *session_id*, or ``None`` if absent."""
        return self.db.scalar(
            select(SessionSummary).where(SessionSummary.session_id == session_id)
        )

    def update_session_summary(
        self,
        session_id: int,
        summary: str,
        *,
        message_count: int = 0,
    ) -> dict[str, Any]:
        """Create or overwrite the summary of a session and commit.

        Args:
            session_id: Session the summary belongs to.
            summary: Summary text; stripped and truncated to 4000 characters.
            message_count: Value stored in the row's ``message_count``.

        Returns:
            ``{"ok": True, "session_id": ..., "summary": <stored text>}``.

        Raises:
            ValueError: If *summary* is empty after stripping.
        """
        text = summary.strip()
        if not text:
            raise ValueError("Пустая сводка")

        row = self.get_session_summary(session_id)
        if row is None:
            row = SessionSummary(session_id=session_id)
            self.db.add(row)

        row.summary = text[: self._MAX_SUMMARY_LEN]
        row.message_count = message_count
        row.updated_at = datetime.now(timezone.utc)
        self.db.commit()
        return {"ok": True, "session_id": session_id, "summary": row.summary}

    # ------------------------------------------------------------------
    # Aggregate view
    # ------------------------------------------------------------------

    def snapshot(self, session_id: int | None = None) -> dict[str, Any]:
        """Return the profile, the top active facts and a session summary.

        Args:
            session_id: Session whose summary to include; with ``None`` the
                summary is an empty string.

        Returns:
            A dict with ``profile``, ``facts`` (id/category/content/importance
            per fact), ``session_summary`` and ``total_facts``.
        """
        facts = self.get_active_facts()
        # Compare against None so a session id of 0 is still looked up.
        summary_row = (
            self.get_session_summary(session_id) if session_id is not None else None
        )
        return {
            "profile": self.get_profile(),
            "facts": [
                {
                    "id": fact.id,
                    "category": fact.category,
                    "content": fact.content,
                    "importance": fact.importance,
                }
                for fact in facts
            ],
            "session_summary": summary_row.summary if summary_row else "",
            # Counts the facts returned above (bounded by the default limit
            # of get_active_facts), not every active fact in the table.
            "total_facts": len(facts),
        }