fix КФП

This commit is contained in:
2026-06-16 10:07:06 +03:00
parent b1506f8695
commit 70910b82d2
3 changed files with 47 additions and 14 deletions
+20 -2
View File
@@ -1,5 +1,6 @@
import asyncio import asyncio
import json import json
import logging
import threading import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -20,6 +21,9 @@ DEFAULT_PROFILE: dict[str, Any] = {
} }
logger = logging.getLogger(__name__)
class MemoryService: class MemoryService:
def __init__(self, db: Session, user_id: int): def __init__(self, db: Session, user_id: int):
self.db = db self.db = db
@@ -38,10 +42,24 @@ class MemoryService:
@staticmethod @staticmethod
def _schedule_rag(coro) -> None: def _schedule_rag(coro) -> None:
def runner() -> None: def runner() -> None:
try:
asyncio.run(coro) asyncio.run(coro)
except Exception:
logger.exception("RAG background task failed")
threading.Thread(target=runner, daemon=True).start() threading.Thread(target=runner, daemon=True).start()
@staticmethod
def _rag_fact_payload(fact: MemoryFact) -> dict[str, Any]:
return {
"fact_id": int(fact.id),
"user_id": int(fact.user_id),
"content": fact.content,
"category": fact.category,
"importance": int(fact.importance),
"active": bool(fact.active),
}
def get_profile(self) -> dict[str, Any]: def get_profile(self) -> dict[str, Any]:
row = self.db.scalar(select(UserProfile).where(UserProfile.user_id == self.user_id).limit(1)) row = self.db.scalar(select(UserProfile).where(UserProfile.user_id == self.user_id).limit(1))
if not row: if not row:
@@ -114,7 +132,7 @@ class MemoryService:
self.db.commit() self.db.commit()
from app.rag.ingest import index_memory_fact from app.rag.ingest import index_memory_fact
self._schedule_rag(index_memory_fact(existing)) self._schedule_rag(index_memory_fact(**self._rag_fact_payload(existing)))
result = { result = {
"ok": True, "ok": True,
"action": "updated", "action": "updated",
@@ -139,7 +157,7 @@ class MemoryService:
self.db.refresh(fact) self.db.refresh(fact)
from app.rag.ingest import index_memory_fact from app.rag.ingest import index_memory_fact
self._schedule_rag(index_memory_fact(fact)) self._schedule_rag(index_memory_fact(**self._rag_fact_payload(fact)))
result = { result = {
"ok": True, "ok": True,
"action": "created", "action": "created",
+18 -10
View File
@@ -10,7 +10,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import get_settings from app.config import get_settings
from app.db.models import ChatSession, Document, DocumentChunk, MemoryFact from app.db.models import ChatSession, Document, DocumentChunk
from app.rag import embeddings from app.rag import embeddings
from app.rag.chunker import chunk_text from app.rag.chunker import chunk_text
from app.rag.store import ( from app.rag.store import (
@@ -22,25 +22,33 @@ from app.rag.store import (
) )
async def index_memory_fact(fact: MemoryFact) -> None: async def index_memory_fact(
*,
fact_id: int,
user_id: int,
content: str,
category: str,
importance: int,
active: bool = True,
) -> None:
settings = get_settings() settings = get_settings()
if not settings.rag_enabled or not fact.active: if not settings.rag_enabled or not active:
return return
vectors = await embeddings.embed_texts([fact.content]) vectors = await embeddings.embed_texts([content])
if not vectors: if not vectors:
return return
upsert_points( upsert_points(
COLLECTION_FACTS, COLLECTION_FACTS,
[ [
qm.PointStruct( qm.PointStruct(
id=int(fact.id), id=int(fact_id),
vector=vectors[0], vector=vectors[0],
payload={ payload={
"user_id": fact.user_id, "user_id": user_id,
"fact_id": fact.id, "fact_id": fact_id,
"category": fact.category, "category": category,
"content": fact.content, "content": content,
"importance": fact.importance, "importance": importance,
}, },
) )
], ],
+8 -1
View File
@@ -23,7 +23,14 @@ async def main() -> None:
try: try:
facts = db.scalars(select(MemoryFact).where(MemoryFact.active.is_(True))).all() facts = db.scalars(select(MemoryFact).where(MemoryFact.active.is_(True))).all()
for fact in facts: for fact in facts:
await index_memory_fact(fact) await index_memory_fact(
fact_id=int(fact.id),
user_id=int(fact.user_id),
content=fact.content,
category=fact.category,
importance=int(fact.importance),
active=bool(fact.active),
)
summaries = db.scalars(select(SessionSummary)).all() summaries = db.scalars(select(SessionSummary)).all()
for row in summaries: for row in summaries:
if row.summary: if row.summary: