96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
|
|
from app.chat.service import ChatService
|
|
from app.db.base import SessionLocal
|
|
|
|
# Module-level logger named after this module; used by the generation code below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GenerationBusyError(Exception):
    """The session is already generating a response.

    Raised by ``start_generation`` when a generation is still registered
    for the requested session id.
    """
|
|
|
|
|
|
@dataclass
|
|
class GenerationHandle:
|
|
session_id: int
|
|
user_id: int
|
|
user_text: str
|
|
task: asyncio.Task | None = None
|
|
subscribers: list[asyncio.Queue[str | None]] = field(default_factory=list)
|
|
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
|
|
|
async def broadcast(self, chunk: str | None) -> None:
|
|
async with self._lock:
|
|
targets = list(self.subscribers)
|
|
for queue in targets:
|
|
try:
|
|
queue.put_nowait(chunk)
|
|
except asyncio.QueueFull:
|
|
logger.debug("generation queue full for session=%s, dropping subscriber", self.session_id)
|
|
|
|
def add_subscriber(self) -> asyncio.Queue[str | None]:
|
|
queue: asyncio.Queue[str | None] = asyncio.Queue(maxsize=512)
|
|
self.subscribers.append(queue)
|
|
return queue
|
|
|
|
def remove_subscriber(self, queue: asyncio.Queue[str | None]) -> None:
|
|
if queue in self.subscribers:
|
|
self.subscribers.remove(queue)
|
|
|
|
|
|
# Generations currently in flight, keyed by session id. Entries are inserted
# by start_generation and removed by _run_generation when the task finishes.
_registry: dict[int, GenerationHandle] = {}
|
|
# Serialises the check-and-insert in start_generation and the removal in
# _run_generation; the read-only helpers below look at _registry without it.
_registry_lock = asyncio.Lock()
|
|
|
|
|
|
def is_generation_active(session_id: int) -> bool:
    """Report whether a generation is currently registered for *session_id*."""
    registered = session_id in _registry
    return registered
|
|
|
|
|
|
def get_active_handle(session_id: int) -> GenerationHandle | None:
    """Return the registered handle for *session_id*, or ``None`` if there is none."""
    try:
        return _registry[session_id]
    except KeyError:
        return None
|
|
|
|
|
|
async def _unregister_handle(handle: GenerationHandle) -> None:
    """Remove *handle* from the registry if it is still the registered one."""
    async with _registry_lock:
        # Identity check: never evict a newer handle for the same session.
        if _registry.get(handle.session_id) is handle:
            del _registry[handle.session_id]


async def _run_generation(handle: GenerationHandle) -> None:
    """Run the generation for *handle* and fan its chunks out to subscribers.

    Streams ``ChatService.stream_response`` for the handle's session and
    broadcasts every chunk. On failure the exception is logged and an
    ``error`` event is broadcast. In all cases the ``None`` end-of-stream
    sentinel is broadcast, the DB session is closed and the handle is removed
    from the registry - each step is guarded so that a failure in one cannot
    skip the others and leave the session permanently marked as busy.
    """
    db = None
    try:
        # Opened inside the ``try`` so that a failure to create the session
        # still reaches the cleanup below.
        db = SessionLocal()
        service = ChatService(db, handle.user_id)
        async for chunk in service.stream_response(
            handle.session_id,
            handle.user_text,
            # Presumably tells the service the user message is already
            # persisted - TODO confirm against ChatService.
            user_message_saved=True,
        ):
            await handle.broadcast(chunk)
    except Exception as exc:
        logger.exception("Background generation failed session=%s", handle.session_id)
        # NOTE(review): str(exc) is forwarded to subscribers verbatim - confirm
        # that exposing internal error text to clients is intended.
        await handle.broadcast(ChatService._sse("error", {"message": str(exc)}))
    finally:
        try:
            await handle.broadcast(None)
        finally:
            try:
                if db is not None:
                    db.close()
            finally:
                await _unregister_handle(handle)
|
|
|
|
|
|
async def start_generation(session_id: int, user_id: int, user_text: str) -> GenerationHandle:
    """Register and launch a background generation for a session.

    Args:
        session_id: Session to generate a response for.
        user_id: Owner of the session; passed on to the generation task.
        user_text: User message the response is generated for.

    Returns:
        The new handle, already registered and with its task scheduled.

    Raises:
        GenerationBusyError: A generation is already registered for the session.
    """
    async with _registry_lock:
        # Check and insert under one lock so two callers cannot both start.
        if session_id in _registry:
            raise GenerationBusyError()
        new_handle = GenerationHandle(
            session_id=session_id,
            user_id=user_id,
            user_text=user_text,
        )
        _registry[session_id] = new_handle
        new_handle.task = asyncio.create_task(_run_generation(new_handle))
    return new_handle
|
|
|
|
|
|
async def subscribe_generation(handle: GenerationHandle):
    """Yield chunks broadcast on *handle* until the ``None`` sentinel arrives.

    The subscription is created when iteration starts and is always removed
    when the generator finishes or is closed.
    """
    inbox = handle.add_subscriber()
    try:
        # ``None`` marks end of stream; everything else is passed through.
        while (item := await inbox.get()) is not None:
            yield item
    finally:
        handle.remove_subscriber(inbox)
|