fixed memmory
This commit is contained in:
@@ -4,8 +4,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.memory.service import MemoryService
|
||||
|
||||
from app.memory.parse import is_identity_question
|
||||
|
||||
MAX_FACTS_IN_CONTEXT = 25
|
||||
PROFILE_KEYS = ("name", "timezone", "language", "notes")
|
||||
PROFILE_KEYS = ("name", "age", "timezone", "language", "notes")
|
||||
|
||||
|
||||
def get_memory_snapshot(db: Session, session_id: int | None = None) -> dict[str, Any]:
|
||||
@@ -48,11 +50,34 @@ def format_memory_context(snapshot: dict[str, Any]) -> str:
|
||||
lines.append("")
|
||||
lines.append(
|
||||
"Правила памяти: "
|
||||
"«запомни» → remember_fact. "
|
||||
"«что ты помнишь» → recall_memories или ответ из снимка выше. "
|
||||
"«запомни» → remember_fact (имя/возраст также пишутся в профиль). "
|
||||
"«кто я» / «сколько мне лет» → ответь из профиля и фактов выше, БЕЗ выдумок. "
|
||||
"Роль персонажа (сын, мать и т.п.) — стиль общения, НЕ биография пользователя. "
|
||||
"Если профиль и факты пусты — честно скажи «не помню» и предложи запомнить. "
|
||||
"«забудь #N» → forget_memory. "
|
||||
"Профиль (имя, timezone) → update_profile. "
|
||||
"Длинный чат — update_session_summary с краткой сводкой темы. "
|
||||
"Не выдумывай факты — только то, что в профиле/фактах или сказал пользователь."
|
||||
"Длинный чат — update_session_summary."
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_identity_hint(snapshot: dict[str, Any], user_text: str) -> str:
|
||||
if not is_identity_question(user_text):
|
||||
return ""
|
||||
|
||||
profile = snapshot.get("profile") or {}
|
||||
facts = snapshot.get("facts") or []
|
||||
lines = [
|
||||
"[Вопрос об идентичности пользователя]",
|
||||
"Ответь ТОЛЬКО из данных ниже. Не придумывай роли из сценария персонажа.",
|
||||
]
|
||||
name = (profile.get("name") or "").strip()
|
||||
age = (profile.get("age") or "").strip()
|
||||
if name:
|
||||
lines.append(f"Имя: {name}")
|
||||
if age:
|
||||
lines.append(f"Возраст: {age} лет")
|
||||
for fact in facts:
|
||||
lines.append(f"Факт: {fact.get('content')}")
|
||||
if not name and not age and not facts:
|
||||
lines.append("Данных нет — скажи, что не помнишь.")
|
||||
return "\n".join(lines)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
import re
|
||||
|
||||
IDENTITY_QUESTION = re.compile(
|
||||
r"(кто\s+я|как\s+меня\s+зовут|сколько\s+мне\s+лет|"
|
||||
r"что\s+ты\s+(помнишь|знаешь)\s+(обо\s+мне|про\s+меня)|"
|
||||
r"напомни\s+(кто\s+я|про\s+меня))",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
NAME_PATTERN = re.compile(
|
||||
r"(?:меня\s+зовут|имя[:\s]+|зовут)\s+([A-Za-zА-Яа-яЁё][A-Za-zА-Яа-яЁё\-]*)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
AGE_PATTERN = re.compile(r"(?:мне\s+(\d{1,3})\s+лет|возраст[:\s]+(\d{1,3}))", re.IGNORECASE)
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
return " ".join(text.casefold().split())
|
||||
|
||||
|
||||
def is_identity_question(text: str) -> bool:
|
||||
return bool(IDENTITY_QUESTION.search(text))
|
||||
|
||||
|
||||
def parse_identity(text: str) -> dict[str, str]:
|
||||
result: dict[str, str] = {}
|
||||
name_match = NAME_PATTERN.search(text)
|
||||
if name_match:
|
||||
result["name"] = name_match.group(1)
|
||||
age_match = AGE_PATTERN.search(text)
|
||||
if age_match:
|
||||
result["age"] = age_match.group(1) or age_match.group(2)
|
||||
return result
|
||||
|
||||
|
||||
def texts_are_similar(a: str, b: str) -> bool:
|
||||
na, nb = normalize_text(a), normalize_text(b)
|
||||
if na == nb:
|
||||
return True
|
||||
return na in nb or nb in na
|
||||
@@ -2,13 +2,15 @@ import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.models import MemoryFact, SessionSummary, UserProfile
|
||||
from app.memory.parse import normalize_text, parse_identity, texts_are_similar
|
||||
|
||||
DEFAULT_PROFILE: dict[str, Any] = {
|
||||
"name": "",
|
||||
"age": "",
|
||||
"timezone": "",
|
||||
"language": "ru",
|
||||
"notes": "",
|
||||
@@ -50,6 +52,20 @@ class MemoryService:
|
||||
self.db.commit()
|
||||
return {"ok": True, "profile": current}
|
||||
|
||||
def _find_similar_fact(self, text: str) -> MemoryFact | None:
|
||||
for fact in self.db.scalars(
|
||||
select(MemoryFact).where(MemoryFact.active.is_(True))
|
||||
):
|
||||
if texts_are_similar(fact.content, text):
|
||||
return fact
|
||||
return None
|
||||
|
||||
def _sync_identity_to_profile(self, text: str) -> dict[str, Any] | None:
|
||||
parsed = parse_identity(text)
|
||||
if not parsed:
|
||||
return None
|
||||
return self.update_profile(parsed)
|
||||
|
||||
def remember_fact(
|
||||
self,
|
||||
content: str,
|
||||
@@ -63,26 +79,28 @@ class MemoryService:
|
||||
if not text:
|
||||
raise ValueError("Пустой факт")
|
||||
|
||||
existing = self.db.scalar(
|
||||
select(MemoryFact).where(
|
||||
MemoryFact.active.is_(True),
|
||||
func.lower(MemoryFact.content) == text.lower(),
|
||||
)
|
||||
)
|
||||
profile_sync = self._sync_identity_to_profile(text)
|
||||
|
||||
existing = self._find_similar_fact(text)
|
||||
if existing:
|
||||
if len(text) > len(existing.content):
|
||||
existing.content = text[:2000]
|
||||
existing.category = category or existing.category
|
||||
existing.importance = max(existing.importance, min(5, max(1, importance)))
|
||||
existing.updated_at = datetime.now(timezone.utc)
|
||||
if session_id:
|
||||
existing.session_id = session_id
|
||||
self.db.commit()
|
||||
return {
|
||||
result = {
|
||||
"ok": True,
|
||||
"action": "updated",
|
||||
"memory_id": existing.id,
|
||||
"content": existing.content,
|
||||
"category": existing.category,
|
||||
}
|
||||
if profile_sync:
|
||||
result["profile"] = profile_sync.get("profile")
|
||||
return result
|
||||
|
||||
fact = MemoryFact(
|
||||
category=(category or "fact")[:64],
|
||||
@@ -94,13 +112,16 @@ class MemoryService:
|
||||
self.db.add(fact)
|
||||
self.db.commit()
|
||||
self.db.refresh(fact)
|
||||
return {
|
||||
result = {
|
||||
"ok": True,
|
||||
"action": "created",
|
||||
"memory_id": fact.id,
|
||||
"content": fact.content,
|
||||
"category": fact.category,
|
||||
}
|
||||
if profile_sync:
|
||||
result["profile"] = profile_sync.get("profile")
|
||||
return result
|
||||
|
||||
def recall_memories(
|
||||
self,
|
||||
@@ -118,15 +139,16 @@ class MemoryService:
|
||||
stmt = stmt.where(MemoryFact.active.is_(True))
|
||||
if category:
|
||||
stmt = stmt.where(MemoryFact.category == category)
|
||||
facts = self.db.scalars(stmt.limit(100)).all()
|
||||
if query:
|
||||
pattern = f"%{query.strip()}%"
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
MemoryFact.content.ilike(pattern),
|
||||
MemoryFact.category.ilike(pattern),
|
||||
)
|
||||
)
|
||||
facts = self.db.scalars(stmt.limit(min(limit, 50))).all()
|
||||
qnorm = normalize_text(query)
|
||||
facts = [
|
||||
f
|
||||
for f in facts
|
||||
if qnorm in normalize_text(f.content)
|
||||
or qnorm in normalize_text(f.category)
|
||||
]
|
||||
facts = facts[: min(limit, 50)]
|
||||
return [
|
||||
{
|
||||
"id": f.id,
|
||||
|
||||
Reference in New Issue
Block a user