Files
Home_assistant/backend/app/chat/generation.py
T
2026-06-13 20:20:56 +00:00

96 lines
3.1 KiB
Python

import asyncio
import logging
from dataclasses import dataclass, field
from app.chat.service import ChatService
from app.db.base import SessionLocal
logger = logging.getLogger(__name__)
class GenerationBusyError(Exception):
    """Raised when the session is already generating a response."""
@dataclass
class GenerationHandle:
session_id: int
user_id: int
user_text: str
task: asyncio.Task | None = None
subscribers: list[asyncio.Queue[str | None]] = field(default_factory=list)
_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
async def broadcast(self, chunk: str | None) -> None:
async with self._lock:
targets = list(self.subscribers)
for queue in targets:
try:
queue.put_nowait(chunk)
except asyncio.QueueFull:
logger.debug("generation queue full for session=%s, dropping subscriber", self.session_id)
def add_subscriber(self) -> asyncio.Queue[str | None]:
queue: asyncio.Queue[str | None] = asyncio.Queue(maxsize=512)
self.subscribers.append(queue)
return queue
def remove_subscriber(self, queue: asyncio.Queue[str | None]) -> None:
if queue in self.subscribers:
self.subscribers.remove(queue)
# Active generations keyed by chat session id. An entry is inserted by
# start_generation() and removed by _run_generation() when it finishes.
_registry: dict[int, GenerationHandle] = {}
# Guards the check-and-insert and compare-and-remove updates of _registry.
_registry_lock = asyncio.Lock()
def is_generation_active(session_id: int) -> bool:
    """Tell whether a generation is currently registered for ``session_id``."""
    registered = session_id in _registry
    return registered
def get_active_handle(session_id: int) -> GenerationHandle | None:
    """Return the registered handle for ``session_id``, or ``None`` if absent."""
    try:
        return _registry[session_id]
    except KeyError:
        return None
async def _run_generation(handle: GenerationHandle) -> None:
    """Stream the response for ``handle`` and fan each chunk out to subscribers.

    Any exception raised while setting up or streaming is logged and sent to
    subscribers as an ``error`` event. Whatever happens, the end-of-stream
    marker (``None``) is broadcast, the database session is closed, and the
    handle is removed from the registry.

    Args:
        handle: The generation to run; its ``session_id``, ``user_id`` and
            ``user_text`` are passed on to ``ChatService``.
    """
    db = None
    try:
        # Created inside ``try`` so that a failure here still reaches the
        # cleanup below; otherwise the session would stay registered as
        # busy and subscribers would never be told the stream has ended.
        db = SessionLocal()
        service = ChatService(db, handle.user_id)
        async for chunk in service.stream_response(
            handle.session_id,
            handle.user_text,
            user_message_saved=True,
        ):
            await handle.broadcast(chunk)
    except Exception as exc:
        logger.exception("Background generation failed session=%s", handle.session_id)
        await handle.broadcast(ChatService._sse("error", {"message": str(exc)}))
    finally:
        # Cleanup steps are chained with nested ``finally`` so a failure in
        # one step cannot skip the ones after it.
        try:
            await handle.broadcast(None)
        finally:
            try:
                if db is not None:
                    db.close()
            finally:
                async with _registry_lock:
                    # Remove only our own entry, never a different handle
                    # stored under the same session id.
                    if _registry.get(handle.session_id) is handle:
                        _registry.pop(handle.session_id, None)
async def start_generation(session_id: int, user_id: int, user_text: str) -> GenerationHandle:
    """Register a new generation for the session and launch its background task.

    Args:
        session_id: Chat session to generate a response for.
        user_id: User on whose behalf the generation runs.
        user_text: The user message to answer.

    Returns:
        The newly registered handle, with ``task`` set to the started task.

    Raises:
        GenerationBusyError: If the session already has a registered generation.
    """
    async with _registry_lock:
        already_running = session_id in _registry
        if already_running:
            raise GenerationBusyError()
        new_handle = GenerationHandle(
            session_id=session_id,
            user_id=user_id,
            user_text=user_text,
        )
        _registry[session_id] = new_handle
    # Keeping the task on the handle holds a strong reference to it.
    new_handle.task = asyncio.create_task(_run_generation(new_handle))
    return new_handle
async def subscribe_generation(handle: GenerationHandle):
    """Yield chunks broadcast on ``handle`` until the end-of-stream marker.

    A fresh subscriber queue is registered when iteration starts and is
    always unregistered when the generator finishes or is closed. A ``None``
    item ends the stream and is not yielded.
    """
    inbox = handle.add_subscriber()
    try:
        while (item := await inbox.get()) is not None:
            yield item
    finally:
        handle.remove_subscriber(inbox)