68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from qdrant_client.http import models as qm
|
|
|
|
from app.config import get_settings
|
|
from app.rag import embeddings
|
|
from app.rag.store import COLLECTION_DOC_CHUNKS, COLLECTION_FACTS, search
|
|
|
|
|
|
def _user_filter(user_id: int) -> qm.Filter:
    """Build a Qdrant filter restricting results to a single user.

    The returned filter has one mandatory condition: the point payload's
    ``user_id`` field must equal ``user_id`` exactly.
    """
    owner_match = qm.MatchValue(value=user_id)
    owner_condition = qm.FieldCondition(key="user_id", match=owner_match)
    return qm.Filter(must=[owner_condition])
|
|
|
|
|
|
async def retrieve_memory_facts(
    query: str, *, user_id: int, top_k: int | None = None
) -> list[dict[str, Any]]:
    """Return the stored memory facts most similar to ``query`` for one user.

    Args:
        query: Free-text query that is embedded and used for the vector
            search. A blank or whitespace-only query yields no results.
        user_id: Only points whose payload ``user_id`` equals this value are
            searched (see ``_user_filter``).
        top_k: Maximum number of facts to return. ``None`` means "use
            ``settings.rag_top_k``". An explicit value, including ``0``, is
            honoured as given, and a non-positive limit returns ``[]``.

    Returns:
        One dict per hit, in the order produced by ``search``, with keys
        ``id``, ``category``, ``content`` and ``score``. Returns ``[]`` when
        RAG is disabled, the query is blank, the limit is not positive, or
        the embedder returns no vectors.
    """
    settings = get_settings()
    if not settings.rag_enabled or not query.strip():
        return []

    # `top_k or default` would silently turn an explicit 0 into the default
    # limit; only fall back when the caller did not pass a value at all.
    k = settings.rag_top_k if top_k is None else top_k
    if k <= 0:
        # No results requested: skip the embedding call entirely.
        return []

    vectors = await embeddings.embed_texts([query])
    if not vectors:
        return []

    # NOTE(review): `search` is invoked synchronously inside a coroutine. If it
    # performs blocking network I/O it will stall the event loop while it
    # runs. Confirm, and consider `await asyncio.to_thread(search, ...)`.
    hits = search(
        COLLECTION_FACTS, vectors[0], limit=k, query_filter=_user_filter(user_id)
    )

    results: list[dict[str, Any]] = []
    for hit in hits:
        payload = hit.payload or {}
        results.append(
            {
                # Prefer the payload's `fact_id`; fall back to the point id
                # when it is missing or falsy.
                "id": payload.get("fact_id") or hit.id,
                "category": payload.get("category", "fact"),
                "content": payload.get("content", ""),
                "score": hit.score,
            }
        )
    return results
|
|
|
|
|
|
async def retrieve_document_chunks(
    query: str, *, user_id: int, top_k: int = 6
) -> list[dict[str, Any]]:
    """Return the document chunks most similar to ``query`` for one user.

    Args:
        query: Free-text query that is embedded and used for the vector
            search. A blank or whitespace-only query yields no results.
        user_id: Only points whose payload ``user_id`` equals this value are
            searched (see ``_user_filter``).
        top_k: Maximum number of chunks to return. The default is fixed at 6
            and is not read from settings. A non-positive value returns
            ``[]`` without embedding the query.

    Returns:
        One dict per hit, in the order produced by ``search``, with keys
        ``document_id``, ``chunk_index``, ``title``, ``content`` and
        ``score``. Missing ``document_id``/``chunk_index`` payload fields
        come back as ``None``; missing ``title``/``content`` come back as
        ``""``. Returns ``[]`` when RAG is disabled, the query is blank,
        ``top_k`` is not positive, or the embedder returns no vectors.
    """
    settings = get_settings()
    if not settings.rag_enabled or not query.strip():
        return []

    if top_k <= 0:
        # No results requested: don't pay for an embedding or forward a
        # non-positive limit to the vector store.
        return []

    vectors = await embeddings.embed_texts([query])
    if not vectors:
        return []

    # NOTE(review): `search` is invoked synchronously inside a coroutine. If it
    # performs blocking network I/O it will stall the event loop while it
    # runs. Confirm, and consider `await asyncio.to_thread(search, ...)`.
    hits = search(
        COLLECTION_DOC_CHUNKS, vectors[0], limit=top_k, query_filter=_user_filter(user_id)
    )

    out: list[dict[str, Any]] = []
    for hit in hits:
        payload = hit.payload or {}
        out.append(
            {
                "document_id": payload.get("document_id"),
                "chunk_index": payload.get("chunk_index"),
                "title": payload.get("title", ""),
                "content": payload.get("content", ""),
                "score": hit.score,
            }
        )
    return out
|