added RAG, Multiuser, TG bot
This commit is contained in:
@@ -141,3 +141,24 @@ class HaClient:
|
||||
return int(session["id"])
|
||||
created = await self.create_session("Telegram")
|
||||
return int(created["id"])
|
||||
|
||||
async def download_media(self, path_or_url: str, *, ha_api_base: str | None = None) -> bytes:
|
||||
base = ha_api_base or self.base_url
|
||||
url = resolve_media_url(base, path_or_url)
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.get(url, headers=self._headers())
|
||||
if response.status_code >= 400:
|
||||
raise HaApiError(response.text.strip() or f"HTTP {response.status_code}", response.status_code)
|
||||
return response.content
|
||||
|
||||
|
||||
def resolve_media_url(ha_api_base: str, path_or_url: str) -> str:
|
||||
raw = (path_or_url or "").strip()
|
||||
if raw.startswith("http://") or raw.startswith("https://"):
|
||||
return raw
|
||||
origin = ha_api_base.rstrip("/")
|
||||
if origin.endswith("/api/v1"):
|
||||
origin = origin[: -len("/api/v1")]
|
||||
if not raw.startswith("/"):
|
||||
raw = f"/{raw}"
|
||||
return f"{origin}{raw}"
|
||||
|
||||
@@ -13,7 +13,8 @@ from bot.config import Settings
|
||||
from bot.filters import IsLinked
|
||||
from bot.ha_client import HaClient
|
||||
from bot.sse import SseChunk
|
||||
from bot.notify_worker import advance_cursors, send_text
|
||||
from bot.notify_worker import advance_cursors
|
||||
from bot.notice_delivery import send_notice_content
|
||||
from bot.storage import LinkedUser, Storage
|
||||
|
||||
router = Router()
|
||||
@@ -52,7 +53,13 @@ async def _run_chat_stream(
|
||||
elif chunk.event == "notice":
|
||||
content = str(chunk.data.get("content") or "").strip()
|
||||
if content:
|
||||
await send_text(message.bot, message.chat.id, content)
|
||||
await send_notice_content(
|
||||
message.bot,
|
||||
message.chat.id,
|
||||
content,
|
||||
HaClient(settings.ha_api_base_url, linked.api_token),
|
||||
ha_api_base=settings.ha_api_base_url,
|
||||
)
|
||||
elif chunk.event == "error":
|
||||
err = str(chunk.data.get("message") or "Ошибка генерации")
|
||||
await message.answer(err)
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.types import BufferedInputFile
|
||||
|
||||
from bot.ha_client import HaClient
|
||||
from bot.tg_util import send_text, split_telegram_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IMAGE_MD_RE = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")
|
||||
TG_CAPTION_MAX = 1024
|
||||
|
||||
|
||||
def parse_notice_content(content: str) -> tuple[str, list[str]]:
|
||||
image_paths: list[str] = []
|
||||
|
||||
def _replace(match: re.Match[str]) -> str:
|
||||
image_paths.append(match.group(2).strip())
|
||||
return ""
|
||||
|
||||
text = IMAGE_MD_RE.sub(_replace, content)
|
||||
text = _plain_markdown(text)
|
||||
return text, image_paths
|
||||
|
||||
|
||||
def _plain_markdown(text: str) -> str:
|
||||
text = re.sub(r"```[^\n]*\n(.*?)```", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
|
||||
text = re.sub(r"\*([^*]+)\*", r"\1", text)
|
||||
text = re.sub(r"__([^_]+)__", r"\1", text)
|
||||
text = re.sub(r"`([^`]+)`", r"\1", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
async def send_notice_content(
|
||||
bot: Bot,
|
||||
chat_id: int,
|
||||
content: str,
|
||||
client: HaClient,
|
||||
*,
|
||||
ha_api_base: str,
|
||||
) -> None:
|
||||
caption, image_paths = parse_notice_content(content)
|
||||
if not image_paths:
|
||||
await send_text(bot, chat_id, caption or content)
|
||||
return
|
||||
|
||||
caption_chunks = split_telegram_message(caption, TG_CAPTION_MAX) if caption else []
|
||||
first_caption = caption_chunks[0] if caption_chunks else None
|
||||
|
||||
for index, image_path in enumerate(image_paths):
|
||||
try:
|
||||
image_bytes = await client.download_media(image_path, ha_api_base=ha_api_base)
|
||||
except Exception:
|
||||
logger.exception("Failed to download image %s for chat_id=%s", image_path, chat_id)
|
||||
fallback = f"{caption}\n\n(не удалось загрузить: {image_path})".strip()
|
||||
await send_text(bot, chat_id, fallback or image_path)
|
||||
return
|
||||
|
||||
cap = first_caption if index == 0 else None
|
||||
await bot.send_photo(
|
||||
chat_id,
|
||||
BufferedInputFile(image_bytes, filename="image.png"),
|
||||
caption=cap or None,
|
||||
)
|
||||
|
||||
if len(caption_chunks) > 1:
|
||||
for extra in caption_chunks[1:]:
|
||||
await send_text(bot, chat_id, extra)
|
||||
elif caption and not first_caption and len(caption) > TG_CAPTION_MAX:
|
||||
await send_text(bot, chat_id, caption)
|
||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
|
||||
from aiogram import Bot
|
||||
|
||||
from bot.ha_client import HaClient
|
||||
from bot.notice_delivery import send_notice_content
|
||||
from bot.storage import LinkedUser, Storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -14,29 +15,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOTICE_ROLES = frozenset({"notice", "character"})
|
||||
TG_MAX_LEN = 4096
|
||||
|
||||
|
||||
def split_telegram_message(text: str, limit: int = TG_MAX_LEN) -> list[str]:
|
||||
if len(text) <= limit:
|
||||
return [text]
|
||||
chunks: list[str] = []
|
||||
remaining = text
|
||||
while remaining:
|
||||
if len(remaining) <= limit:
|
||||
chunks.append(remaining)
|
||||
break
|
||||
split_at = remaining.rfind("\n", 0, limit)
|
||||
if split_at <= 0:
|
||||
split_at = limit
|
||||
chunks.append(remaining[:split_at])
|
||||
remaining = remaining[split_at:].lstrip("\n")
|
||||
return chunks
|
||||
|
||||
|
||||
async def send_text(bot: Bot, chat_id: int, text: str) -> None:
|
||||
for chunk in split_telegram_message(text):
|
||||
await bot.send_message(chat_id, chunk)
|
||||
|
||||
|
||||
async def advance_cursors(
|
||||
@@ -113,7 +91,13 @@ async def sync_notices_for_user(
|
||||
pending.sort(key=lambda item: item[0])
|
||||
for _, content in pending:
|
||||
try:
|
||||
await send_text(bot, user.telegram_id, content)
|
||||
await send_notice_content(
|
||||
bot,
|
||||
user.telegram_id,
|
||||
content,
|
||||
client,
|
||||
ha_api_base=ha_base_url,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to send notice to telegram_id=%s", user.telegram_id)
|
||||
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiogram import Bot
|
||||
|
||||
TG_MAX_LEN = 4096
|
||||
|
||||
|
||||
def split_telegram_message(text: str, limit: int = TG_MAX_LEN) -> list[str]:
|
||||
if len(text) <= limit:
|
||||
return [text]
|
||||
chunks: list[str] = []
|
||||
remaining = text
|
||||
while remaining:
|
||||
if len(remaining) <= limit:
|
||||
chunks.append(remaining)
|
||||
break
|
||||
split_at = remaining.rfind("\n", 0, limit)
|
||||
if split_at <= 0:
|
||||
split_at = limit
|
||||
chunks.append(remaining[:split_at])
|
||||
remaining = remaining[split_at:].lstrip("\n")
|
||||
return chunks
|
||||
|
||||
|
||||
async def send_text(bot: Bot, chat_id: int, text: str) -> None:
|
||||
for chunk in split_telegram_message(text):
|
||||
await bot.send_message(chat_id, chunk)
|
||||
@@ -0,0 +1,23 @@
|
||||
from bot.ha_client import resolve_media_url
|
||||
from bot.notice_delivery import parse_notice_content
|
||||
|
||||
|
||||
def test_resolve_media_url_relative():
|
||||
url = resolve_media_url(
|
||||
"https://home.example.com/api/v1",
|
||||
"/api/v1/media/generated/abc.png",
|
||||
)
|
||||
assert url == "https://home.example.com/api/v1/media/generated/abc.png"
|
||||
|
||||
|
||||
def test_parse_notice_content_extracts_image():
|
||||
content = (
|
||||
"🎨 **Картинка готова**\n\n"
|
||||
"\n\n"
|
||||
"**Comfy (+):**\n```\n1girl, smile\n```"
|
||||
)
|
||||
text, paths = parse_notice_content(content)
|
||||
assert paths == ["/api/v1/media/generated/abc.png"]
|
||||
assert "![image]" not in text
|
||||
assert "Картинка готова" in text
|
||||
assert "1girl, smile" in text
|
||||
Reference in New Issue
Block a user