65 lines
1.7 KiB
Python
65 lines
1.7 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http import models as qm
|
|
|
|
from app.config import get_settings
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)

# Names of the Qdrant collections this module creates and queries.
COLLECTION_FACTS: str = "memory_facts"
COLLECTION_SUMMARIES: str = "session_summaries"
COLLECTION_DOC_CHUNKS: str = "document_chunks"

# Dimensionality used for the vectors of every collection created by
# ensure_collections().
# NOTE(review): presumably this must match the embedding model's output
# size; confirm against the embedder configuration.
VECTOR_SIZE: int = 1536
|
|
|
|
|
|
# Clients already constructed by _client(), keyed by the Qdrant URL they target.
_CLIENTS: dict[str, QdrantClient] = {}


def _client() -> QdrantClient:
    """Return a ``QdrantClient`` for the currently configured ``qdrant_url``.

    Every helper in this module calls this once per operation. Building a new
    client each time (and never closing it) churns connections, so clients are
    cached per URL and reused. The URL is re-read from settings on every call,
    so a changed ``qdrant_url`` transparently gets its own client.

    NOTE(review): this assumes a single client may be shared across callers
    and threads; confirm for the transport in use (REST vs gRPC) and for any
    pre-fork worker setup.
    """
    url = get_settings().qdrant_url
    client = _CLIENTS.get(url)
    if client is None:
        # A concurrent first call may build a second client; the last write
        # wins and the other is simply dropped, which is harmless.
        client = QdrantClient(url=url)
        _CLIENTS[url] = client
    return client
|
|
|
|
|
|
def ensure_collections() -> None:
    """Create whichever of this module's Qdrant collections do not exist yet.

    Returns immediately when ``rag_enabled`` is false in the settings.
    Otherwise each of the facts, summaries and document-chunk collections is
    checked. A missing one is created with ``VECTOR_SIZE``-dimensional vectors
    compared by cosine distance, and its creation is logged at INFO level.
    """
    if not get_settings().rag_enabled:
        return

    qdrant = _client()
    managed = (COLLECTION_FACTS, COLLECTION_SUMMARIES, COLLECTION_DOC_CHUNKS)
    for collection_name in managed:
        # NOTE(review): an existing collection is trusted as-is; a vector size
        # or distance that differs from VECTOR_SIZE / COSINE is not detected.
        if not qdrant.collection_exists(collection_name):
            qdrant.create_collection(
                collection_name=collection_name,
                vectors_config=qm.VectorParams(
                    size=VECTOR_SIZE,
                    distance=qm.Distance.COSINE,
                ),
            )
            logger.info("Created Qdrant collection %s", collection_name)
|
|
|
|
|
|
def upsert_points(collection: str, points: list[qm.PointStruct]) -> None:
    """Upsert *points* into the Qdrant collection named *collection*.

    An empty batch is a no-op: no client is obtained and no request is sent.
    """
    if points:
        _client().upsert(collection_name=collection, points=points)
|
|
|
|
|
|
def delete_by_filter(collection: str, must: list[qm.FieldCondition]) -> None:
    """Delete every point in *collection* that matches all of *must*.

    Args:
        collection: Name of the Qdrant collection to delete from.
        must: Conditions placed in ``Filter.must``. A point is deleted only
            if it satisfies every one of them.

    Raises:
        ValueError: If *must* is empty or ``None``. A filter without
            conditions does not narrow the selection, so the request would
            delete every point in the collection. That is refused so that a
            caller which accidentally builds no conditions cannot wipe data.
    """
    if not must:
        raise ValueError(
            f"delete_by_filter({collection!r}) requires at least one condition; "
            "an empty filter would delete every point in the collection"
        )
    _client().delete(
        collection_name=collection,
        points_selector=qm.FilterSelector(filter=qm.Filter(must=must)),
    )
|
|
|
|
|
|
def search(
    collection: str,
    vector: list[float],
    *,
    limit: int,
    query_filter: qm.Filter | None = None,
) -> list[qm.ScoredPoint]:
    """Run a vector similarity query against *collection*.

    Args:
        collection: Name of the Qdrant collection to query.
        vector: Query embedding.
        limit: Maximum number of hits to return (keyword-only).
        query_filter: Optional payload filter restricting candidate points.

    Returns:
        The scored points returned by the client's ``search`` call.

    NOTE(review): newer qdrant-client releases deprecate ``search`` in favour
    of ``query_points``; confirm against the pinned client version.
    """
    client = _client()
    hits = client.search(
        collection_name=collection,
        query_vector=vector,
        query_filter=query_filter,
        limit=limit,
    )
    return hits
|