fixed memmory
This commit is contained in:
@@ -5,6 +5,8 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.base import get_db
|
||||
from app.db.models import ChatSession
|
||||
from app.memory.extract import extract_after_turn
|
||||
from app.memory.service import MemoryService
|
||||
|
||||
router = APIRouter()
|
||||
@@ -26,6 +28,13 @@ class SessionSummaryUpdate(BaseModel):
|
||||
message_count: int = 0
|
||||
|
||||
|
||||
class ExtractRequest(BaseModel):
|
||||
session_id: int
|
||||
user_text: str = Field(min_length=1)
|
||||
assistant_text: str = ""
|
||||
force: bool = False
|
||||
|
||||
|
||||
@router.get("/memory")
|
||||
def get_memory_snapshot(
|
||||
session_id: int | None = None,
|
||||
@@ -85,6 +94,23 @@ def forget_fact(memory_id: int, db: Session = Depends(get_db)) -> dict[str, Any]
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/memory/extract")
|
||||
async def extract_memories(
|
||||
payload: ExtractRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
session = db.get(ChatSession, payload.session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return await extract_after_turn(
|
||||
db,
|
||||
payload.session_id,
|
||||
payload.user_text,
|
||||
payload.assistant_text,
|
||||
force=payload.force,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/memory/sessions/{session_id}/summary")
|
||||
def update_session_summary(
|
||||
session_id: int,
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import get_settings
|
||||
from app.character.service import CharacterService
|
||||
from app.chat.notices import (
|
||||
POMODORO_TOOL_NAMES,
|
||||
@@ -16,6 +17,7 @@ from app.memory.context import (
|
||||
format_memory_context,
|
||||
get_memory_snapshot,
|
||||
)
|
||||
from app.memory.extract import extract_after_turn
|
||||
from app.projects.context import format_projects_context, get_projects_snapshot
|
||||
from app.db.models import ChatSession, Message
|
||||
from app.llm.client import LLMClient
|
||||
@@ -184,7 +186,20 @@ class ChatService:
|
||||
if final_content:
|
||||
self._save_message(session_id, "assistant", final_content)
|
||||
|
||||
yield self._sse("done", {})
|
||||
memory_meta: dict[str, Any] = {}
|
||||
if get_settings().memory_auto_extract:
|
||||
extraction = await extract_after_turn(
|
||||
self.db,
|
||||
session_id,
|
||||
user_text,
|
||||
final_content,
|
||||
)
|
||||
memory_meta = {
|
||||
"memory_extracted": extraction.get("count", 0),
|
||||
"memory_saved": extraction.get("saved", []),
|
||||
}
|
||||
|
||||
yield self._sse("done", memory_meta)
|
||||
return
|
||||
|
||||
yield self._sse("error", {"message": "Too many tool call rounds"})
|
||||
|
||||
@@ -21,6 +21,7 @@ class Settings(BaseSettings):
|
||||
database_url: str = "sqlite:///./data/assistant.db"
|
||||
cors_origins: str = "http://localhost:5173,http://localhost:8080,http://localhost:3000"
|
||||
system_prompt_path: str = "./prompts/assistant.md"
|
||||
memory_auto_extract: bool = True
|
||||
|
||||
# Taiga/Gitea on host (not in Docker) — use host.docker.internal from container
|
||||
taiga_base_url: str = "http://host.docker.internal:9000"
|
||||
|
||||
@@ -70,11 +70,13 @@ class LLMClient:
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
*,
|
||||
temperature: float = 0.7,
|
||||
) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": 0.7,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.llm.client import LLMClient
|
||||
from app.memory.service import MemoryService
|
||||
from app.projects.structuring import strip_markdown_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SKIP_USER_PATTERN = re.compile(
|
||||
r"^(ок|ok|да|нет|спасибо|thanks|\.{1,3}|👍|\+1)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
EXTRACTION_PROMPT = """
|
||||
Ты извлекаешь долгосрочные факты о пользователе из фрагмента диалога.
|
||||
Ответь ТОЛЬКО JSON без markdown.
|
||||
|
||||
Схема:
|
||||
{
|
||||
"facts": [
|
||||
{"content": "текст факта", "category": "preference|person|habit|project|fact", "importance": 1}
|
||||
],
|
||||
"profile": {"name": "", "age": "", "timezone": "", "notes": ""}
|
||||
}
|
||||
|
||||
Правила:
|
||||
- Сохраняй устойчивое: имя, возраст, предпочтения, привычки, проекты, семья, работа.
|
||||
- НЕ сохраняй: статус помидоро, погоду, разовые команды, ролевую игру, выдумки ассистента.
|
||||
- profile — только поля с новыми значениями (пустые строки не включай).
|
||||
- facts — короткие утверждения от первого лица пользователя («люблю кофе», «меня зовут …»).
|
||||
- Если нечего сохранять — {"facts": [], "profile": {}}.
|
||||
- Не дублируй уже известное (см. текущий профиль и факты ниже).
|
||||
- importance: 5 критично (имя), 4 важно, 3 обычно, 2 мелочь.
|
||||
""".strip()
|
||||
|
||||
|
||||
def _should_skip_extraction(user_text: str) -> bool:
|
||||
text = user_text.strip()
|
||||
if len(text) < 4:
|
||||
return True
|
||||
if SKIP_USER_PATTERN.match(text):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _call_extractor(
|
||||
user_text: str,
|
||||
assistant_text: str,
|
||||
snapshot: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
profile = snapshot.get("profile") or {}
|
||||
facts = snapshot.get("facts") or []
|
||||
known = [
|
||||
f"Профиль: {json.dumps(profile, ensure_ascii=False)}",
|
||||
"Факты:",
|
||||
*[f"- {f.get('content')}" for f in facts[:30]],
|
||||
]
|
||||
|
||||
llm = LLMClient()
|
||||
result = await llm.complete(
|
||||
[
|
||||
{"role": "system", "content": EXTRACTION_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"\n".join(known)
|
||||
+ "\n\n---\nДиалог:\nПользователь: "
|
||||
+ user_text
|
||||
+ "\nАссистент: "
|
||||
+ (assistant_text[:1500] if assistant_text else "(нет ответа)")
|
||||
),
|
||||
},
|
||||
],
|
||||
temperature=0.2,
|
||||
)
|
||||
raw = strip_markdown_json(result.get("content") or "")
|
||||
if not raw:
|
||||
return {"facts": [], "profile": {}}
|
||||
parsed = json.loads(raw)
|
||||
if not isinstance(parsed, dict):
|
||||
return {"facts": [], "profile": {}}
|
||||
return parsed
|
||||
|
||||
|
||||
async def extract_after_turn(
|
||||
db: Session,
|
||||
session_id: int,
|
||||
user_text: str,
|
||||
assistant_text: str,
|
||||
*,
|
||||
force: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
if not force and _should_skip_extraction(user_text):
|
||||
return {"ok": True, "skipped": "short_message", "saved": []}
|
||||
|
||||
memory = MemoryService(db)
|
||||
snapshot = memory.snapshot(session_id)
|
||||
|
||||
try:
|
||||
parsed = await _call_extractor(user_text, assistant_text, snapshot)
|
||||
except (json.JSONDecodeError, Exception) as exc:
|
||||
logger.warning("Memory extraction failed: %s", exc)
|
||||
return {"ok": False, "error": str(exc), "saved": []}
|
||||
|
||||
saved: list[dict[str, Any]] = []
|
||||
|
||||
profile_updates = parsed.get("profile") or {}
|
||||
if isinstance(profile_updates, dict):
|
||||
filtered = {
|
||||
k: str(v).strip()
|
||||
for k, v in profile_updates.items()
|
||||
if v and str(v).strip()
|
||||
}
|
||||
if filtered:
|
||||
memory.update_profile(filtered)
|
||||
saved.append({"type": "profile", "updates": filtered})
|
||||
|
||||
facts = parsed.get("facts") or []
|
||||
if isinstance(facts, list):
|
||||
for item in facts:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
content = (item.get("content") or "").strip()
|
||||
if not content or len(content) < 3:
|
||||
continue
|
||||
try:
|
||||
result = memory.remember_fact(
|
||||
content,
|
||||
category=str(item.get("category") or "fact")[:64],
|
||||
importance=int(item.get("importance") or 3),
|
||||
session_id=session_id,
|
||||
source="auto",
|
||||
)
|
||||
saved.append({"type": "fact", **result})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return {"ok": True, "saved": saved, "count": len(saved)}
|
||||
@@ -218,6 +218,8 @@ class MemoryService:
|
||||
"category": f.category,
|
||||
"content": f.content,
|
||||
"importance": f.importance,
|
||||
"source": f.source,
|
||||
"updated_at": f.updated_at.isoformat() if f.updated_at else None,
|
||||
}
|
||||
for f in facts
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user