added RAG, Multiuser, TG bot
This commit is contained in:
@@ -0,0 +1,95 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from app.chat.service import ChatService
|
||||
from app.db.base import SessionLocal
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerationBusyError(Exception):
    """The session is already generating a response."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationHandle:
|
||||
session_id: int
|
||||
user_id: int
|
||||
user_text: str
|
||||
task: asyncio.Task | None = None
|
||||
subscribers: list[asyncio.Queue[str | None]] = field(default_factory=list)
|
||||
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
|
||||
async def broadcast(self, chunk: str | None) -> None:
|
||||
async with self._lock:
|
||||
targets = list(self.subscribers)
|
||||
for queue in targets:
|
||||
try:
|
||||
queue.put_nowait(chunk)
|
||||
except asyncio.QueueFull:
|
||||
logger.debug("generation queue full for session=%s, dropping subscriber", self.session_id)
|
||||
|
||||
def add_subscriber(self) -> asyncio.Queue[str | None]:
|
||||
queue: asyncio.Queue[str | None] = asyncio.Queue(maxsize=512)
|
||||
self.subscribers.append(queue)
|
||||
return queue
|
||||
|
||||
def remove_subscriber(self, queue: asyncio.Queue[str | None]) -> None:
|
||||
if queue in self.subscribers:
|
||||
self.subscribers.remove(queue)
|
||||
|
||||
|
||||
# Active generations keyed by session id; entries are inserted by
# start_generation and removed when _run_generation finishes.
_registry: dict[int, GenerationHandle] = {}
|
||||
# Held around the check-and-insert in start_generation and around the
# removal in _run_generation.
_registry_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def is_generation_active(session_id: int) -> bool:
    """Report whether ``session_id`` currently has an entry in the registry."""
    registered = session_id in _registry
    return registered
|
||||
|
||||
|
||||
def get_active_handle(session_id: int) -> GenerationHandle | None:
    """Return the registered handle for ``session_id``, or ``None`` if absent."""
    try:
        return _registry[session_id]
    except KeyError:
        return None
|
||||
|
||||
|
||||
async def _run_generation(handle: GenerationHandle) -> None:
    """Run one generation to completion and fan its chunks out to subscribers.

    Streams chunks from ``ChatService.stream_response`` into
    ``handle.broadcast``. On failure the exception is logged and an SSE
    ``error`` event is broadcast. Whatever happens, the end-of-stream
    sentinel is sent, the DB session is closed and the handle is removed
    from the registry - each cleanup step runs even if the previous one
    raises, so a session can never be left permanently "busy".

    Args:
        handle: The registered handle describing the generation to run.
    """
    db = None
    try:
        # Created inside the ``try`` so that a failure here still reaches the
        # cleanup below instead of leaving a stale registry entry behind.
        db = SessionLocal()
        service = ChatService(db, handle.user_id)
        # NOTE(review): user_message_saved=True presumably means the caller
        # has already persisted the user's message - confirm in ChatService.
        async for chunk in service.stream_response(
            handle.session_id,
            handle.user_text,
            user_message_saved=True,
        ):
            await handle.broadcast(chunk)
    except Exception as exc:
        logger.exception("Background generation failed session=%s", handle.session_id)
        # NOTE(review): the raw exception text is forwarded to subscribers;
        # confirm that it cannot expose internal details to end users.
        await handle.broadcast(ChatService._sse("error", {"message": str(exc)}))
    finally:
        try:
            # ``None`` tells every subscriber that the stream has ended.
            await handle.broadcast(None)
        finally:
            try:
                if db is not None:
                    db.close()
            finally:
                async with _registry_lock:
                    # Only remove our own entry, never a newer handle that
                    # may have been registered for the same session.
                    if _registry.get(handle.session_id) is handle:
                        _registry.pop(handle.session_id, None)
|
||||
|
||||
|
||||
async def start_generation(session_id: int, user_id: int, user_text: str) -> GenerationHandle:
    """Register a generation for ``session_id`` and schedule it as a task.

    Args:
        session_id: Session the generation belongs to; also the registry key.
        user_id: User on whose behalf the generation runs.
        user_text: The user's message text to respond to.

    Returns:
        The newly registered handle, with ``task`` set to the scheduled task.

    Raises:
        GenerationBusyError: If the session already has a registered handle.
    """
    async with _registry_lock:
        # Check and insert under one lock so a session gets at most one handle.
        if session_id in _registry:
            raise GenerationBusyError()
        new_handle = GenerationHandle(session_id, user_id, user_text)
        _registry[session_id] = new_handle
        # Keeping the task on the handle also keeps a reference to it alive.
        background = asyncio.create_task(_run_generation(new_handle))
        new_handle.task = background
        return new_handle
|
||||
|
||||
|
||||
async def subscribe_generation(handle: GenerationHandle):
    """Yield chunks broadcast on ``handle`` until the end-of-stream sentinel.

    This is an async generator. The subscriber queue is attached on the first
    iteration and always detached when the generator finishes or is closed.

    If the generation task has already finished by the time the subscriber
    attaches, the generator ends immediately: the ``None`` sentinel was
    broadcast before the task completed, so waiting on a fresh queue would
    block forever.

    Args:
        handle: Handle of the generation to listen to.

    Yields:
        Each non-``None`` chunk delivered to this subscriber's queue.
    """
    task = handle.task
    if task is not None and task.done():
        # Nothing more will ever be broadcast on this handle.
        return
    queue = handle.add_subscriber()
    try:
        while True:
            chunk = await queue.get()
            if chunk is None:
                # End-of-stream sentinel.
                break
            yield chunk
    finally:
        handle.remove_subscriber(queue)
|
||||
Reference in New Issue
Block a user