"""Telegram chat handler that relays text messages to the Home Assistant API.

A linked user's plain-text message (anything not starting with ``/``) is sent
to the backend through ``HaClient``; the reply arrives as a stream of
``SseChunk`` events and is forwarded to the chat once the stream finishes.
"""

from __future__ import annotations

import asyncio
import logging
from collections.abc import AsyncIterator

from aiogram import F, Router
from aiogram.enums import ChatAction
from aiogram.types import Message

from bot.access import is_allowed
from bot.config import Settings
from bot.filters import IsLinked
from bot.ha_client import HaClient
from bot.sse import SseChunk
from bot.notify_worker import advance_cursors, send_text
from bot.storage import LinkedUser, Storage

router = Router()
logger = logging.getLogger(__name__)

# One lock per Telegram user id, created lazily by ``_user_lock``.
# Entries are never removed, so the dict grows with the number of distinct users.
_generation_locks: dict[int, asyncio.Lock] = {}


def _user_lock(telegram_id: int) -> asyncio.Lock:
    """Return the lock associated with ``telegram_id``, creating it on first use."""
    lock = _generation_locks.get(telegram_id)
    if lock is None:
        lock = _generation_locks[telegram_id] = asyncio.Lock()
    return lock


async def _close_stream(stream: AsyncIterator[SseChunk]) -> None:
    """Finalise ``stream`` if it exposes ``aclose()``; never raise.

    Async generators that are abandoned mid-iteration are only cleaned up when
    they are garbage collected, so callers close them explicitly. A failure
    during cleanup is logged and swallowed so that it cannot turn an otherwise
    successful reply into an error.
    """
    aclose = getattr(stream, "aclose", None)
    if aclose is None:
        # Plain async iterators have nothing to finalise.
        return
    try:
        await aclose()
    except Exception:
        logger.exception("Failed to close chat stream")


async def _iter_stream(stream: AsyncIterator[SseChunk]) -> AsyncIterator[SseChunk]:
    """Re-yield every chunk of ``stream`` and close it when iteration ends.

    The ``finally`` clause runs both when ``stream`` is exhausted and when the
    consumer stops early and closes this generator, so the wrapped stream is
    always finalised.
    """
    try:
        async for chunk in stream:
            yield chunk
    finally:
        await _close_stream(stream)


async def _run_chat_stream(
    message: Message,
    settings: Settings,
    storage: Storage,
    linked: LinkedUser,
    stream: AsyncIterator[SseChunk],
) -> None:
    """Consume ``stream`` and deliver the generated reply to the chat.

    Event handling:
      * ``status`` - show the "typing" chat action.
      * ``token``  - collect the ``content`` fragment of the reply.
      * ``notice`` - send the stripped ``content`` immediately via ``send_text``.
      * ``error``  - answer with the error message and stop (no reply text is
        sent and cursors are not advanced).
      * ``done``   - stop reading.

    After the stream ends, the collected text (if not blank) is sent in
    Telegram-sized parts and ``advance_cursors`` is called; a failure there is
    logged and not propagated.

    Args:
        message: Incoming Telegram message; replies go to its chat.
        settings: Provides ``ha_api_base_url`` for the API client.
        storage: Passed through to ``advance_cursors``.
        linked: Linked user record (``api_token``, ``telegram_id``).
        stream: Source of ``SseChunk`` events; closed before returning.
    """
    pieces: list[str] = []
    try:
        async for chunk in stream:
            event = chunk.event
            if event == "status":
                await message.bot.send_chat_action(message.chat.id, ChatAction.TYPING)
            elif event == "token":
                piece = str(chunk.data.get("content") or "")
                if piece:
                    pieces.append(piece)
            elif event == "notice":
                content = str(chunk.data.get("content") or "").strip()
                if content:
                    await send_text(message.bot, message.chat.id, content)
            elif event == "error":
                err = str(chunk.data.get("message") or "Ошибка генерации")
                await message.answer(err)
                return
            elif event == "done":
                break
    finally:
        # The loop can exit early (done / error / exception); release the
        # stream now instead of waiting for garbage collection.
        await _close_stream(stream)

    accumulated = "".join(pieces)
    if accumulated.strip():
        for part in _split_for_edit_or_send(accumulated):
            await message.answer(part)

    client = HaClient(settings.ha_api_base_url, linked.api_token)
    try:
        await advance_cursors(storage, client, linked)
    except Exception:
        # Best effort: the reply has already been delivered.
        logger.exception("Failed to advance cursors for telegram_id=%s", linked.telegram_id)


def _split_for_edit_or_send(text: str, limit: int = 4096) -> list[str]:
    """Split ``text`` into chunks of at most ``limit`` characters.

    Text that already fits is returned unchanged as a single-element list.
    Longer text is cut at the last newline found before ``limit`` (the newline
    itself and any directly following newlines are dropped); when there is no
    usable newline the text is cut hard at ``limit``. Chunks that contain only
    whitespace are skipped, since sending them would yield a blank message.

    Args:
        text: Text to split.
        limit: Maximum chunk length in characters.

    Returns:
        The chunks in order. May be empty when ``text`` is longer than
        ``limit`` but consists solely of whitespace.

    Raises:
        ValueError: If ``text`` does not fit and ``limit`` is not positive
            (splitting could never make progress).
    """
    if len(text) <= limit:
        return [text]
    if limit <= 0:
        raise ValueError("limit must be positive")

    parts: list[str] = []
    remaining = text
    while remaining:
        if len(remaining) <= limit:
            chunk, remaining = remaining, ""
        else:
            cut = remaining.rfind("\n", 0, limit)
            if cut <= 0:
                # No newline in range, or only at position 0: cut hard.
                cut = limit
            chunk = remaining[:cut]
            remaining = remaining[cut:].lstrip("\n")
        if chunk.strip():
            parts.append(chunk)
    return parts


@router.message(F.text & ~F.text.startswith("/"), IsLinked())
async def handle_chat_message(message: Message, settings: Settings, storage: Storage) -> None:
    """Handle a non-command text message from a linked user.

    The message is ignored when access is denied, the sender or text is
    missing, or no stored user record exists. If the per-user lock is already
    held, the user is asked to wait. Otherwise, while holding the lock, the
    handler asks the API whether a generation is already active for the
    session: if so it attaches to that stream (the new text is not sent),
    otherwise it sends the text and streams the reply. Failures are logged and
    reported to the user.

    Args:
        message: Incoming Telegram message.
        settings: Bot settings (access rules, API base URL).
        storage: User storage used to look up the linked account.
    """
    if not is_allowed(message, settings):
        return
    if not message.from_user or not message.text:
        return

    user_id = message.from_user.id
    linked = await storage.get_user(user_id)
    if not linked:
        return

    lock = _user_lock(user_id)
    if lock.locked():
        await message.answer("Подожди, предыдущий ответ ещё генерируется.")
        return

    async with lock:
        client = HaClient(settings.ha_api_base_url, linked.api_token)
        content = message.text.strip()
        if not content:
            return

        try:
            status = await client.generation_status(linked.session_id)
            if status.get("active"):
                stream = client.stream_generation(linked.session_id)
            else:
                stream = client.send_message_stream(linked.session_id, content)
        except Exception as exc:
            logger.exception("Failed to start chat for telegram_id=%s", user_id)
            await message.answer(f"Ошибка связи с Home Assistant: {exc}")
            return

        try:
            await _run_chat_stream(message, settings, storage, linked, _iter_stream(stream))
        except Exception as exc:
            logger.exception("Chat stream failed for telegram_id=%s", user_id)
            await message.answer(f"Ошибка при получении ответа: {exc}")