fix КФП
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
@@ -20,6 +21,9 @@ DEFAULT_PROFILE: dict[str, Any] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MemoryService:
|
class MemoryService:
|
||||||
def __init__(self, db: Session, user_id: int):
|
def __init__(self, db: Session, user_id: int):
|
||||||
self.db = db
|
self.db = db
|
||||||
@@ -38,10 +42,24 @@ class MemoryService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _schedule_rag(coro) -> None:
|
def _schedule_rag(coro) -> None:
|
||||||
def runner() -> None:
|
def runner() -> None:
|
||||||
asyncio.run(coro)
|
try:
|
||||||
|
asyncio.run(coro)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("RAG background task failed")
|
||||||
|
|
||||||
threading.Thread(target=runner, daemon=True).start()
|
threading.Thread(target=runner, daemon=True).start()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _rag_fact_payload(fact: MemoryFact) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"fact_id": int(fact.id),
|
||||||
|
"user_id": int(fact.user_id),
|
||||||
|
"content": fact.content,
|
||||||
|
"category": fact.category,
|
||||||
|
"importance": int(fact.importance),
|
||||||
|
"active": bool(fact.active),
|
||||||
|
}
|
||||||
|
|
||||||
def get_profile(self) -> dict[str, Any]:
|
def get_profile(self) -> dict[str, Any]:
|
||||||
row = self.db.scalar(select(UserProfile).where(UserProfile.user_id == self.user_id).limit(1))
|
row = self.db.scalar(select(UserProfile).where(UserProfile.user_id == self.user_id).limit(1))
|
||||||
if not row:
|
if not row:
|
||||||
@@ -114,7 +132,7 @@ class MemoryService:
|
|||||||
self.db.commit()
|
self.db.commit()
|
||||||
from app.rag.ingest import index_memory_fact
|
from app.rag.ingest import index_memory_fact
|
||||||
|
|
||||||
self._schedule_rag(index_memory_fact(existing))
|
self._schedule_rag(index_memory_fact(**self._rag_fact_payload(existing)))
|
||||||
result = {
|
result = {
|
||||||
"ok": True,
|
"ok": True,
|
||||||
"action": "updated",
|
"action": "updated",
|
||||||
@@ -139,7 +157,7 @@ class MemoryService:
|
|||||||
self.db.refresh(fact)
|
self.db.refresh(fact)
|
||||||
from app.rag.ingest import index_memory_fact
|
from app.rag.ingest import index_memory_fact
|
||||||
|
|
||||||
self._schedule_rag(index_memory_fact(fact))
|
self._schedule_rag(index_memory_fact(**self._rag_fact_payload(fact)))
|
||||||
result = {
|
result = {
|
||||||
"ok": True,
|
"ok": True,
|
||||||
"action": "created",
|
"action": "created",
|
||||||
|
|||||||
+18
-10
@@ -10,7 +10,7 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import get_settings
|
from app.config import get_settings
|
||||||
from app.db.models import ChatSession, Document, DocumentChunk, MemoryFact
|
from app.db.models import ChatSession, Document, DocumentChunk
|
||||||
from app.rag import embeddings
|
from app.rag import embeddings
|
||||||
from app.rag.chunker import chunk_text
|
from app.rag.chunker import chunk_text
|
||||||
from app.rag.store import (
|
from app.rag.store import (
|
||||||
@@ -22,25 +22,33 @@ from app.rag.store import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def index_memory_fact(fact: MemoryFact) -> None:
|
async def index_memory_fact(
|
||||||
|
*,
|
||||||
|
fact_id: int,
|
||||||
|
user_id: int,
|
||||||
|
content: str,
|
||||||
|
category: str,
|
||||||
|
importance: int,
|
||||||
|
active: bool = True,
|
||||||
|
) -> None:
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
if not settings.rag_enabled or not fact.active:
|
if not settings.rag_enabled or not active:
|
||||||
return
|
return
|
||||||
vectors = await embeddings.embed_texts([fact.content])
|
vectors = await embeddings.embed_texts([content])
|
||||||
if not vectors:
|
if not vectors:
|
||||||
return
|
return
|
||||||
upsert_points(
|
upsert_points(
|
||||||
COLLECTION_FACTS,
|
COLLECTION_FACTS,
|
||||||
[
|
[
|
||||||
qm.PointStruct(
|
qm.PointStruct(
|
||||||
id=int(fact.id),
|
id=int(fact_id),
|
||||||
vector=vectors[0],
|
vector=vectors[0],
|
||||||
payload={
|
payload={
|
||||||
"user_id": fact.user_id,
|
"user_id": user_id,
|
||||||
"fact_id": fact.id,
|
"fact_id": fact_id,
|
||||||
"category": fact.category,
|
"category": category,
|
||||||
"content": fact.content,
|
"content": content,
|
||||||
"importance": fact.importance,
|
"importance": importance,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -23,7 +23,14 @@ async def main() -> None:
|
|||||||
try:
|
try:
|
||||||
facts = db.scalars(select(MemoryFact).where(MemoryFact.active.is_(True))).all()
|
facts = db.scalars(select(MemoryFact).where(MemoryFact.active.is_(True))).all()
|
||||||
for fact in facts:
|
for fact in facts:
|
||||||
await index_memory_fact(fact)
|
await index_memory_fact(
|
||||||
|
fact_id=int(fact.id),
|
||||||
|
user_id=int(fact.user_id),
|
||||||
|
content=fact.content,
|
||||||
|
category=fact.category,
|
||||||
|
importance=int(fact.importance),
|
||||||
|
active=bool(fact.active),
|
||||||
|
)
|
||||||
summaries = db.scalars(select(SessionSummary)).all()
|
summaries = db.scalars(select(SessionSummary)).all()
|
||||||
for row in summaries:
|
for row in summaries:
|
||||||
if row.summary:
|
if row.summary:
|
||||||
|
|||||||
Reference in New Issue
Block a user