added RAG, Multiuser, TG bot
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
"""RAG: embeddings, Qdrant store, retrieval, ingest."""
|
||||
|
||||
from app.rag import chunker, embeddings, ingest, retriever, store
|
||||
|
||||
__all__ = ["chunker", "embeddings", "ingest", "retriever", "store"]
|
||||
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def chunk_text(text: str, *, chunk_size: int = 800, overlap: int = 120) -> list[str]:
|
||||
cleaned = (text or "").strip()
|
||||
if not cleaned:
|
||||
return []
|
||||
if len(cleaned) <= chunk_size:
|
||||
return [cleaned]
|
||||
chunks: list[str] = []
|
||||
start = 0
|
||||
while start < len(cleaned):
|
||||
end = min(len(cleaned), start + chunk_size)
|
||||
piece = cleaned[start:end].strip()
|
||||
if piece:
|
||||
chunks.append(piece)
|
||||
if end >= len(cleaned):
|
||||
break
|
||||
start = max(0, end - overlap)
|
||||
return chunks
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.llm.client import LLMClient
|
||||
|
||||
|
||||
async def embed_texts(texts: list[str]) -> list[list[float]]:
|
||||
if not texts:
|
||||
return []
|
||||
client = LLMClient()
|
||||
return await client.embed(texts)
|
||||
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from qdrant_client.http import models as qm
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import get_settings
|
||||
from app.db.models import ChatSession, Document, DocumentChunk, MemoryFact
|
||||
from app.rag import embeddings
|
||||
from app.rag.chunker import chunk_text
|
||||
from app.rag.store import (
|
||||
COLLECTION_DOC_CHUNKS,
|
||||
COLLECTION_FACTS,
|
||||
COLLECTION_SUMMARIES,
|
||||
delete_by_filter,
|
||||
upsert_points,
|
||||
)
|
||||
|
||||
|
||||
async def index_memory_fact(fact: MemoryFact) -> None:
|
||||
settings = get_settings()
|
||||
if not settings.rag_enabled or not fact.active:
|
||||
return
|
||||
vectors = await embeddings.embed_texts([fact.content])
|
||||
if not vectors:
|
||||
return
|
||||
upsert_points(
|
||||
COLLECTION_FACTS,
|
||||
[
|
||||
qm.PointStruct(
|
||||
id=int(fact.id),
|
||||
vector=vectors[0],
|
||||
payload={
|
||||
"user_id": fact.user_id,
|
||||
"fact_id": fact.id,
|
||||
"category": fact.category,
|
||||
"content": fact.content,
|
||||
"importance": fact.importance,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def deactivate_memory_fact(fact_id: int) -> None:
|
||||
settings = get_settings()
|
||||
if not settings.rag_enabled:
|
||||
return
|
||||
delete_by_filter(
|
||||
COLLECTION_FACTS,
|
||||
[qm.FieldCondition(key="fact_id", match=qm.MatchValue(value=fact_id))],
|
||||
)
|
||||
|
||||
|
||||
async def index_session_summary(session_id: int, summary: str) -> None:
|
||||
settings = get_settings()
|
||||
if not settings.rag_enabled or not summary.strip():
|
||||
return
|
||||
from app.db.base import SessionLocal
|
||||
|
||||
user_id = 1
|
||||
db = SessionLocal()
|
||||
try:
|
||||
session = db.get(ChatSession, session_id)
|
||||
if session:
|
||||
user_id = session.user_id
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
vectors = await embeddings.embed_texts([summary])
|
||||
if not vectors:
|
||||
return
|
||||
upsert_points(
|
||||
COLLECTION_SUMMARIES,
|
||||
[
|
||||
qm.PointStruct(
|
||||
id=int(session_id),
|
||||
vector=vectors[0],
|
||||
payload={"user_id": user_id, "session_id": session_id, "summary": summary[:4000]},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def ingest_document_file(
|
||||
db: Session,
|
||||
*,
|
||||
user_id: int,
|
||||
title: str,
|
||||
filename: str,
|
||||
raw_bytes: bytes,
|
||||
) -> dict[str, Any]:
|
||||
settings = get_settings()
|
||||
text = raw_bytes.decode("utf-8", errors="replace").strip()
|
||||
if not text:
|
||||
raise ValueError("Пустой документ")
|
||||
|
||||
digest = hashlib.sha256(raw_bytes).hexdigest()
|
||||
doc = Document(
|
||||
user_id=user_id,
|
||||
title=title or filename,
|
||||
filename=filename,
|
||||
content_hash=digest,
|
||||
size_bytes=len(raw_bytes),
|
||||
)
|
||||
db.add(doc)
|
||||
db.flush()
|
||||
|
||||
chunks = chunk_text(text)
|
||||
chunk_rows: list[DocumentChunk] = []
|
||||
for idx, piece in enumerate(chunks):
|
||||
row = DocumentChunk(document_id=doc.id, chunk_index=idx, content=piece)
|
||||
db.add(row)
|
||||
chunk_rows.append(row)
|
||||
db.commit()
|
||||
db.refresh(doc)
|
||||
|
||||
if settings.rag_enabled and chunks:
|
||||
vectors = await embeddings.embed_texts(chunks)
|
||||
points: list[qm.PointStruct] = []
|
||||
for row, vector in zip(chunk_rows, vectors, strict=False):
|
||||
db.refresh(row)
|
||||
point_id = int(row.id)
|
||||
points.append(
|
||||
qm.PointStruct(
|
||||
id=point_id,
|
||||
vector=vector,
|
||||
payload={
|
||||
"user_id": user_id,
|
||||
"document_id": doc.id,
|
||||
"chunk_id": row.id,
|
||||
"chunk_index": row.chunk_index,
|
||||
"title": doc.title,
|
||||
"content": row.content,
|
||||
},
|
||||
)
|
||||
)
|
||||
upsert_points(COLLECTION_DOC_CHUNKS, points)
|
||||
|
||||
return {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"filename": doc.filename,
|
||||
"chunk_count": len(chunks),
|
||||
"size_bytes": doc.size_bytes,
|
||||
"created_at": doc.created_at.isoformat() if doc.created_at else None,
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
# Migrate active memory facts into Qdrant
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import get_settings
|
||||
from app.db.base import SessionLocal, init_db
|
||||
from app.db.models import MemoryFact, SessionSummary
|
||||
from app.rag.ingest import index_memory_fact, index_session_summary
|
||||
from app.rag.store import ensure_collections
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
settings = get_settings()
|
||||
if not settings.rag_enabled:
|
||||
print("RAG disabled; set RAG_ENABLED=true")
|
||||
return
|
||||
init_db()
|
||||
ensure_collections()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
facts = db.scalars(select(MemoryFact).where(MemoryFact.active.is_(True))).all()
|
||||
for fact in facts:
|
||||
await index_memory_fact(fact)
|
||||
summaries = db.scalars(select(SessionSummary)).all()
|
||||
for row in summaries:
|
||||
if row.summary:
|
||||
await index_session_summary(row.session_id, row.summary)
|
||||
print(f"Indexed {len(facts)} facts and {len(summaries)} summaries")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from qdrant_client.http import models as qm
|
||||
|
||||
from app.config import get_settings
|
||||
from app.rag import embeddings
|
||||
from app.rag.store import COLLECTION_DOC_CHUNKS, COLLECTION_FACTS, search
|
||||
|
||||
|
||||
def _user_filter(user_id: int) -> qm.Filter:
|
||||
return qm.Filter(
|
||||
must=[qm.FieldCondition(key="user_id", match=qm.MatchValue(value=user_id))]
|
||||
)
|
||||
|
||||
|
||||
async def retrieve_memory_facts(
|
||||
query: str, *, user_id: int, top_k: int | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
settings = get_settings()
|
||||
if not settings.rag_enabled or not query.strip():
|
||||
return []
|
||||
k = top_k or settings.rag_top_k
|
||||
vectors = await embeddings.embed_texts([query])
|
||||
if not vectors:
|
||||
return []
|
||||
hits = search(COLLECTION_FACTS, vectors[0], limit=k, query_filter=_user_filter(user_id))
|
||||
results: list[dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
payload = hit.payload or {}
|
||||
results.append(
|
||||
{
|
||||
"id": payload.get("fact_id") or hit.id,
|
||||
"category": payload.get("category", "fact"),
|
||||
"content": payload.get("content", ""),
|
||||
"score": hit.score,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def retrieve_document_chunks(
|
||||
query: str, *, user_id: int, top_k: int = 6
|
||||
) -> list[dict[str, Any]]:
|
||||
settings = get_settings()
|
||||
if not settings.rag_enabled or not query.strip():
|
||||
return []
|
||||
vectors = await embeddings.embed_texts([query])
|
||||
if not vectors:
|
||||
return []
|
||||
hits = search(
|
||||
COLLECTION_DOC_CHUNKS, vectors[0], limit=top_k, query_filter=_user_filter(user_id)
|
||||
)
|
||||
out: list[dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
payload = hit.payload or {}
|
||||
out.append(
|
||||
{
|
||||
"document_id": payload.get("document_id"),
|
||||
"chunk_index": payload.get("chunk_index"),
|
||||
"title": payload.get("title", ""),
|
||||
"content": payload.get("content", ""),
|
||||
"score": hit.score,
|
||||
}
|
||||
)
|
||||
return out
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models as qm
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COLLECTION_FACTS = "memory_facts"
|
||||
COLLECTION_SUMMARIES = "session_summaries"
|
||||
COLLECTION_DOC_CHUNKS = "document_chunks"
|
||||
VECTOR_SIZE = 1536
|
||||
|
||||
|
||||
def _client() -> QdrantClient:
|
||||
settings = get_settings()
|
||||
return QdrantClient(url=settings.qdrant_url)
|
||||
|
||||
|
||||
def ensure_collections() -> None:
|
||||
settings = get_settings()
|
||||
if not settings.rag_enabled:
|
||||
return
|
||||
client = _client()
|
||||
for name in (COLLECTION_FACTS, COLLECTION_SUMMARIES, COLLECTION_DOC_CHUNKS):
|
||||
if client.collection_exists(name):
|
||||
continue
|
||||
client.create_collection(
|
||||
collection_name=name,
|
||||
vectors_config=qm.VectorParams(size=VECTOR_SIZE, distance=qm.Distance.COSINE),
|
||||
)
|
||||
logger.info("Created Qdrant collection %s", name)
|
||||
|
||||
|
||||
def upsert_points(collection: str, points: list[qm.PointStruct]) -> None:
|
||||
if not points:
|
||||
return
|
||||
_client().upsert(collection_name=collection, points=points)
|
||||
|
||||
|
||||
def delete_by_filter(collection: str, must: list[qm.FieldCondition]) -> None:
|
||||
_client().delete(
|
||||
collection_name=collection,
|
||||
points_selector=qm.FilterSelector(filter=qm.Filter(must=must)),
|
||||
)
|
||||
|
||||
|
||||
def search(
|
||||
collection: str,
|
||||
vector: list[float],
|
||||
*,
|
||||
limit: int,
|
||||
query_filter: qm.Filter | None = None,
|
||||
) -> list[qm.ScoredPoint]:
|
||||
return _client().search(
|
||||
collection_name=collection,
|
||||
query_vector=vector,
|
||||
limit=limit,
|
||||
query_filter=query_filter,
|
||||
)
|
||||
Reference in New Issue
Block a user