Files
Home_assistant/telegram-bot/bot/handlers/chat.py
T
2026-06-13 20:20:56 +00:00

130 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncIterator
from aiogram import F, Router
from aiogram.enums import ChatAction
from aiogram.types import Message
from bot.access import is_allowed
from bot.config import Settings
from bot.filters import IsLinked
from bot.ha_client import HaClient
from bot.sse import SseChunk
from bot.notify_worker import advance_cursors, send_text
from bot.storage import LinkedUser, Storage
# Router on which this module's message handlers are registered.
router = Router()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
_generation_locks: dict[int, asyncio.Lock] = {}
def _user_lock(telegram_id: int) -> asyncio.Lock:
if telegram_id not in _generation_locks:
_generation_locks[telegram_id] = asyncio.Lock()
return _generation_locks[telegram_id]
async def _iter_stream(stream: AsyncIterator[SseChunk]) -> AsyncIterator[SseChunk]:
async for chunk in stream:
yield chunk
async def _run_chat_stream(
    message: Message,
    settings: Settings,
    storage: Storage,
    linked: LinkedUser,
    stream: AsyncIterator[SseChunk],
) -> None:
    """Consume a chat SSE stream and relay its outcome to the Telegram chat.

    Each chunk is dispatched on ``chunk.event``:

    * ``"status"`` - show the "typing" chat action in the chat;
    * ``"token"``  - buffer the ``content`` fragment of the reply;
    * ``"notice"`` - forward the stripped ``content`` immediately via
      ``send_text``;
    * ``"error"``  - answer with the chunk's ``message`` (or a default text)
      and return at once, without sending buffered text or advancing cursors;
    * ``"done"``   - stop reading the stream.

    Chunks with any other event name are ignored.

    Once the stream ends, the buffered reply (if it is not blank) is sent in
    pieces produced by ``_split_for_edit_or_send``, then ``advance_cursors``
    is called with a fresh ``HaClient``. A failure in that last step is
    logged and not propagated.

    Args:
        message: Incoming Telegram message; replies go to its chat.
        settings: Bot settings; ``ha_api_base_url`` is read from it.
        storage: Storage handle passed through to ``advance_cursors``.
        linked: Linked user record; ``api_token`` and ``telegram_id`` are read.
        stream: Async iterator of SSE chunks to consume.
    """
    # Collect fragments and join once instead of growing a string with "+=".
    fragments: list[str] = []

    async for chunk in stream:
        if chunk.event == "status":
            await message.bot.send_chat_action(message.chat.id, ChatAction.TYPING)
        elif chunk.event == "token":
            # "or ''" keeps a missing/None content from becoming the text "None".
            piece = str(chunk.data.get("content") or "")
            if piece:
                fragments.append(piece)
        elif chunk.event == "notice":
            content = str(chunk.data.get("content") or "").strip()
            if content:
                await send_text(message.bot, message.chat.id, content)
        elif chunk.event == "error":
            err = str(chunk.data.get("message") or "Ошибка генерации")
            await message.answer(err)
            return
        elif chunk.event == "done":
            break

    accumulated = "".join(fragments)
    if accumulated.strip():
        for part in _split_for_edit_or_send(accumulated):
            await message.answer(part)

    # NOTE(review): indentation was lost in the reviewed copy; this step is
    # assumed to run whether or not a reply was sent - confirm against the repo.
    client = HaClient(settings.ha_api_base_url, linked.api_token)
    try:
        await advance_cursors(storage, client, linked)
    except Exception:
        # Best effort: the reply has already been delivered, so only log.
        logger.exception("Failed to advance cursors for telegram_id=%s", linked.telegram_id)
def _split_for_edit_or_send(text: str, limit: int = 4096) -> list[str]:
if len(text) <= limit:
return [text]
parts: list[str] = []
remaining = text
while remaining:
if len(remaining) <= limit:
parts.append(remaining)
break
cut = remaining.rfind("\n", 0, limit)
if cut <= 0:
cut = limit
parts.append(remaining[:cut])
remaining = remaining[cut:].lstrip("\n")
return parts
@router.message(F.text & ~F.text.startswith("/"), IsLinked())
async def handle_chat_message(message: Message, settings: Settings, storage: Storage) -> None:
    """Relay a text message that does not start with "/" to Home Assistant.

    The message is ignored when the sender is not allowed, when it has no
    sender or text, or when no linked user record is stored for the sender.
    Only one generation per Telegram user runs at a time: if the user's lock
    is already held, a "please wait" reply is sent instead.

    If the backend reports an active generation for the session, that stream
    is consumed; otherwise the stripped message text is sent as a new message.
    Errors while starting or consuming the stream are logged and reported to
    the user.
    """
    if not is_allowed(message, settings) or not message.from_user or not message.text:
        return

    telegram_id = message.from_user.id
    linked = await storage.get_user(telegram_id)
    if not linked:
        return

    user_lock = _user_lock(telegram_id)
    if user_lock.locked():
        await message.answer("Подожди, предыдущий ответ ещё генерируется.")
        return

    async with user_lock:
        ha = HaClient(settings.ha_api_base_url, linked.api_token)
        prompt = message.text.strip()
        if not prompt:
            return

        try:
            session_id = linked.session_id
            generation = await ha.generation_status(session_id)
            chunks = (
                ha.stream_generation(session_id)
                if generation.get("active")
                else ha.send_message_stream(session_id, prompt)
            )
        except Exception as exc:
            logger.exception("Failed to start chat for telegram_id=%s", telegram_id)
            await message.answer(f"Ошибка связи с Home Assistant: {exc}")
            return

        try:
            await _run_chat_stream(message, settings, storage, linked, _iter_stream(chunks))
        except Exception as exc:
            logger.exception("Chat stream failed for telegram_id=%s", telegram_id)
            await message.answer(f"Ошибка при получении ответа: {exc}")