"""In-process registry of background chat generations with fan-out to subscribers."""

import asyncio
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass, field

from app.chat.service import ChatService
from app.db.base import SessionLocal

logger = logging.getLogger(__name__)

# Number of data chunks buffered per subscriber before it is considered stalled.
_SUBSCRIBER_BUFFER_SIZE = 512


class GenerationBusyError(Exception):
    """The session is already generating a response."""


@dataclass
class GenerationHandle:
    """State of one running generation plus the queues of everyone listening to it.

    Chunks are pushed to every subscriber queue; ``None`` is the end-of-stream
    sentinel that tells a consumer to stop reading.
    """

    session_id: int
    user_id: int
    user_text: str
    # Background task running the generation; set by start_generation().
    task: asyncio.Task | None = None
    # Queues of the currently attached subscribers.
    subscribers: list[asyncio.Queue[str | None]] = field(default_factory=list)
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # True once the end-of-stream sentinel has been broadcast.
    finished: bool = field(default=False, init=False)

    @staticmethod
    def _data_slots_exhausted(queue: asyncio.Queue[str | None]) -> bool:
        """Return True when only the slot reserved for the sentinel is left."""
        return queue.maxsize > 0 and queue.qsize() >= queue.maxsize - 1

    async def broadcast(self, chunk: str | None) -> None:
        """Deliver ``chunk`` to every subscriber without blocking on slow ones.

        A subscriber whose buffer is exhausted is removed and receives the
        ``None`` sentinel instead of the chunk, so its consumer terminates
        rather than waiting forever on a queue nobody feeds anymore.

        Args:
            chunk: Data to fan out, or ``None`` to signal end of stream.
        """
        async with self._lock:
            if chunk is None:
                self.finished = True
            targets = list(self.subscribers)
        for queue in targets:
            payload = chunk
            if chunk is not None and self._data_slots_exhausted(queue):
                logger.debug(
                    "generation queue full for session=%s, dropping subscriber",
                    self.session_id,
                )
                self.remove_subscriber(queue)
                # The last slot is never used for data, so the sentinel fits.
                payload = None
            try:
                queue.put_nowait(payload)
            except asyncio.QueueFull:
                # Only reachable for a queue that was filled by code other than
                # this method; keep the original best-effort behavior.
                logger.debug(
                    "generation queue overflow for session=%s, subscriber removed",
                    self.session_id,
                )
                self.remove_subscriber(queue)

    def add_subscriber(self) -> asyncio.Queue[str | None]:
        """Create and return a queue that receives subsequent chunks.

        If the generation has already finished, the returned queue is not
        registered and already holds the ``None`` sentinel, so a consumer
        that attaches late stops immediately instead of hanging.
        """
        # One extra slot is reserved for the end-of-stream sentinel.
        queue: asyncio.Queue[str | None] = asyncio.Queue(
            maxsize=_SUBSCRIBER_BUFFER_SIZE + 1
        )
        if self.finished:
            queue.put_nowait(None)
        else:
            self.subscribers.append(queue)
        return queue

    def remove_subscriber(self, queue: asyncio.Queue[str | None]) -> None:
        """Detach ``queue``; a queue that is not attached is ignored."""
        if queue in self.subscribers:
            self.subscribers.remove(queue)


# Generations currently running, keyed by session id.
_registry: dict[int, GenerationHandle] = {}
_registry_lock = asyncio.Lock()


def is_generation_active(session_id: int) -> bool:
    """Return True if a generation is registered for ``session_id``."""
    return session_id in _registry


def get_active_handle(session_id: int) -> GenerationHandle | None:
    """Return the registered handle for ``session_id``, or None if there is none."""
    return _registry.get(session_id)


async def _run_generation(handle: GenerationHandle) -> None:
    """Run the generation for ``handle`` and fan its chunks out to subscribers.

    Any exception is logged and forwarded to subscribers as an "error" event.
    Whatever happens, subscribers receive the end-of-stream sentinel, the DB
    session is closed and the handle is removed from the registry.
    """
    db = None
    try:
        # Created inside the try so that a failure here still triggers cleanup.
        db = SessionLocal()
        service = ChatService(db, handle.user_id)
        # NOTE(review): user_message_saved=True presumably means the caller has
        # already persisted the user's message - confirm against ChatService.
        async for chunk in service.stream_response(
            handle.session_id,
            handle.user_text,
            user_message_saved=True,
        ):
            await handle.broadcast(chunk)
    except Exception as exc:
        logger.exception("Background generation failed session=%s", handle.session_id)
        await handle.broadcast(ChatService._sse("error", {"message": str(exc)}))
    finally:
        # Nested so that each cleanup step runs even if the previous one raises.
        try:
            await handle.broadcast(None)
        finally:
            try:
                if db is not None:
                    db.close()
            finally:
                async with _registry_lock:
                    # Only unregister ourselves, never a newer handle.
                    if _registry.get(handle.session_id) is handle:
                        _registry.pop(handle.session_id, None)


async def start_generation(session_id: int, user_id: int, user_text: str) -> GenerationHandle:
    """Register and launch a background generation for ``session_id``.

    Returns:
        The handle of the newly started generation.

    Raises:
        GenerationBusyError: A generation is already registered for the session.
    """
    async with _registry_lock:
        if session_id in _registry:
            raise GenerationBusyError()
        handle = GenerationHandle(session_id=session_id, user_id=user_id, user_text=user_text)
        _registry[session_id] = handle
        handle.task = asyncio.create_task(_run_generation(handle))
        return handle


async def subscribe_generation(handle: GenerationHandle) -> AsyncIterator[str]:
    """Yield chunks broadcast on ``handle`` until the end-of-stream sentinel.

    The subscription starts when iteration starts, so chunks broadcast before
    the first item is requested are not delivered. The subscriber is detached
    when the iterator finishes or is closed.
    """
    queue = handle.add_subscriber()
    try:
        while True:
            chunk = await queue.get()
            if chunk is None:
                break
            yield chunk
    finally:
        handle.remove_subscriber(queue)