214 lines
8.2 KiB
Python
214 lines
8.2 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
from collections.abc import AsyncIterator
|
||
|
||
from aiogram import F, Router
|
||
from aiogram.enums import ChatAction
|
||
from aiogram.types import Message
|
||
|
||
from bot.access import is_allowed
|
||
from bot.config import Settings
|
||
from bot.filters import IsLinked
|
||
from bot.ha_client import HaClient
|
||
from bot.sse import SseChunk
|
||
from bot.notify_worker import advance_cursors
|
||
from bot.notice_delivery import send_notice_content
|
||
from bot.storage import LinkedUser, Storage
|
||
|
||
# Module-level logger for the chat handlers in this file.
logger = logging.getLogger(__name__)

# aiogram router on which the chat handlers of this module are registered.
router = Router()

# One asyncio.Lock per Telegram user id. Locks are created lazily by
# _user_lock() and are used by the handlers to refuse a new message while the
# same user's previous one is still being processed. Entries are never removed.
_generation_locks: dict[int, asyncio.Lock] = dict()
|
||
|
||
|
||
def _user_lock(telegram_id: int) -> asyncio.Lock:
    """Return the per-user generation lock, creating and caching it on first use.

    Args:
        telegram_id: Telegram user id used as the key in ``_generation_locks``.

    Returns:
        The same ``asyncio.Lock`` instance for every call with the same id.
    """
    try:
        return _generation_locks[telegram_id]
    except KeyError:
        fresh = _generation_locks[telegram_id] = asyncio.Lock()
        return fresh
|
||
|
||
|
||
async def _iter_stream(stream: AsyncIterator[SseChunk]) -> AsyncIterator[SseChunk]:
|
||
async for chunk in stream:
|
||
yield chunk
|
||
|
||
|
||
def _format_vision_entry(entry: dict, label: str) -> str:
    """Render one vision result dict as the text of a Telegram message.

    Reads ``entry["model"]`` and ``entry["parsed"]``. When ``parsed`` is a dict,
    its ``"description"`` (truncated to 400 characters) and optional
    ``"fitness_hints"`` are included. The final text is capped at 4000
    characters, presumably to stay under Telegram's message length limit.

    Args:
        entry: One vision payload (an item of ``images`` or the event data itself).
        label: Header prefix, e.g. ``"Vision"`` or ``"Vision 2/3"``.
    """
    parsed = entry.get("parsed")
    model = entry.get("model")
    description = ""
    hints = None
    if isinstance(parsed, dict):
        description = str(parsed.get("description") or "")[:400]
        hints = parsed.get("fitness_hints")
    lines = [f"{label} ({model or '?'}):", description or "(нет описания)"]
    if hints:
        lines.append(f"fitness_hints: {hints}")
    return "\n".join(lines)[:4000]


async def _run_chat_stream(
    message: Message,
    settings: Settings,
    storage: Storage,
    linked: LinkedUser,
    stream: AsyncIterator[SseChunk],
) -> None:
    """Consume a Home Assistant SSE chat stream and relay it to the Telegram chat.

    Events handled:
      * ``status``: show the "typing" chat action.
      * ``token``: append ``content`` to the answer being accumulated.
      * ``notice``: forward the stripped ``content`` via ``send_notice_content``.
      * ``vision``: post one summary message per dict in ``images``. If there is
        no non-empty ``images`` list, post one summary of the payload itself.
      * ``error``: send ``message`` (or a default text) and return immediately.
        The partial answer is NOT delivered and cursors are NOT advanced.
      * ``done``: stop reading the stream.
    Unknown events are ignored.

    After the stream ends without an error, the accumulated answer (if not blank)
    is sent in Telegram-sized parts, and then ``advance_cursors`` is called.
    A failure in ``advance_cursors`` is logged, not raised.
    """
    client = HaClient(settings.ha_api_base_url, linked.api_token)
    pieces: list[str] = []

    async for chunk in stream:
        # Normalise the payload once so every branch can call ``.get`` safely,
        # even when ``chunk.data`` is not a dict.
        data = chunk.data if isinstance(chunk.data, dict) else {}
        event = chunk.event

        if event == "status":
            await message.bot.send_chat_action(message.chat.id, ChatAction.TYPING)
        elif event == "token":
            piece = str(data.get("content") or "")
            if piece:
                pieces.append(piece)
        elif event == "notice":
            content = str(data.get("content") or "").strip()
            if content:
                await send_notice_content(
                    message.bot,
                    message.chat.id,
                    content,
                    client,
                    ha_api_base=settings.ha_api_base_url,
                )
        elif event == "vision":
            images = data.get("images")
            if isinstance(images, list) and images:
                total = len(images)
                # Non-dict items are skipped but still count toward the index.
                for index, item in enumerate(images, start=1):
                    if isinstance(item, dict):
                        await message.answer(
                            _format_vision_entry(item, f"Vision {index}/{total}")
                        )
            else:
                await message.answer(_format_vision_entry(data, "Vision"))
        elif event == "error":
            await message.answer(str(data.get("message") or "Ошибка генерации"))
            return
        elif event == "done":
            break

    accumulated = "".join(pieces)
    if accumulated.strip():
        for part in _split_for_edit_or_send(accumulated):
            await message.answer(part)

    try:
        await advance_cursors(storage, client, linked)
    except Exception:
        logger.exception("Failed to advance cursors for telegram_id=%s", linked.telegram_id)
|
||
|
||
|
||
def _split_for_edit_or_send(text: str, limit: int = 4096) -> list[str]:
|
||
if len(text) <= limit:
|
||
return [text]
|
||
parts: list[str] = []
|
||
remaining = text
|
||
while remaining:
|
||
if len(remaining) <= limit:
|
||
parts.append(remaining)
|
||
break
|
||
cut = remaining.rfind("\n", 0, limit)
|
||
if cut <= 0:
|
||
cut = limit
|
||
parts.append(remaining[:cut])
|
||
remaining = remaining[cut:].lstrip("\n")
|
||
return parts
|
||
|
||
|
||
@router.message(F.text & ~F.text.startswith("/"), IsLinked())
async def handle_chat_message(message: Message, settings: Settings, storage: Storage) -> None:
    """Forward a plain-text (non-command) message to the user's linked HA chat session.

    The message is ignored when the sender is not allowed, has no text, or has
    no linked account. Only one message per user is processed at a time; while
    the user's lock is held they get a "please wait" reply instead.

    NOTE(review): if Home Assistant reports a generation already active for the
    session, that stream is relayed and the newly received text is not sent.
    """
    if not is_allowed(message, settings) or not message.from_user or not message.text:
        return

    user_id = message.from_user.id
    linked = await storage.get_user(user_id)
    if not linked:
        return

    user_lock = _user_lock(user_id)
    # Nothing is awaited between this check and acquiring the lock below.
    if user_lock.locked():
        await message.answer("Подожди, предыдущий ответ ещё генерируется.")
        return

    async with user_lock:
        ha = HaClient(settings.ha_api_base_url, linked.api_token)
        text = message.text.strip()
        if not text:
            return

        try:
            status = await ha.generation_status(linked.session_id)
            stream = (
                ha.stream_generation(linked.session_id)
                if status.get("active")
                else ha.send_message_stream(linked.session_id, text)
            )
        except Exception as exc:
            logger.exception("Failed to start chat for telegram_id=%s", user_id)
            await message.answer(f"Ошибка связи с Home Assistant: {exc}")
            return

        try:
            await _run_chat_stream(message, settings, storage, linked, _iter_stream(stream))
        except Exception as exc:
            logger.exception("Chat stream failed for telegram_id=%s", user_id)
            await message.answer(f"Ошибка при получении ответа: {exc}")
|
||
|
||
|
||
@router.message(F.photo, IsLinked())
async def handle_chat_photo(message: Message, settings: Settings, storage: Storage) -> None:
    """Send a Telegram photo (plus its optional caption) to the user's linked HA chat session.

    The message is ignored when the sender is not allowed, there is no photo,
    or the user has no linked account. Only one message per user is processed
    at a time; while the user's lock is held they get a "please wait" reply.

    If Home Assistant already has a generation active for the session, that
    stream is relayed and the photo is not used. The photo is therefore
    downloaded from Telegram only when a new message with the image will
    actually be sent.
    """
    if not is_allowed(message, settings):
        return
    if not message.from_user or not message.photo:
        return

    linked = await storage.get_user(message.from_user.id)
    if not linked:
        return

    lock = _user_lock(message.from_user.id)
    if lock.locked():
        await message.answer("Подожди, предыдущий ответ ещё генерируется.")
        return

    async with lock:
        client = HaClient(settings.ha_api_base_url, linked.api_token)
        caption = (message.caption or "").strip()
        # Presumably the largest available size (Telegram typically lists
        # photo sizes smallest-first) -- TODO confirm.
        photo = message.photo[-1]

        try:
            status = await client.generation_status(linked.session_id)
            if status.get("active"):
                # Resume the running generation. The image would not be sent
                # in this case, so skip downloading it.
                stream = client.stream_generation(linked.session_id)
            else:
                file = await message.bot.get_file(photo.file_id)
                if not file.file_path:
                    await message.answer("Не удалось получить файл фото.")
                    return
                downloaded = await message.bot.download_file(file.file_path)
                image_bytes = downloaded.read() if hasattr(downloaded, "read") else bytes(downloaded)
                stream = client.send_message_with_image_stream(
                    linked.session_id,
                    caption,
                    image_bytes,
                    filename="telegram.jpg",
                )
        except Exception as exc:
            logger.exception("Failed to start chat photo for telegram_id=%s", message.from_user.id)
            await message.answer(f"Ошибка связи с Home Assistant: {exc}")
            return

        try:
            await _run_chat_stream(message, settings, storage, linked, _iter_stream(stream))
        except Exception as exc:
            logger.exception("Chat photo stream failed for telegram_id=%s", message.from_user.id)
            await message.answer(f"Ошибка при получении ответа: {exc}")
|