Files
Home_assistant/telegram-bot/bot/ha_client.py
T
2026-06-16 04:38:23 +00:00

191 lines
7.3 KiB
Python

from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
import httpx
from bot.sse import SseChunk, iter_sse
class HaApiError(RuntimeError):
    """Error reported by the HA backend API.

    ``status_code`` holds the HTTP status of the failed response when one is
    available, and ``None`` otherwise.
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        # The HTTP status (if any) is exposed so callers can branch on it.
        self.status_code = status_code
        super().__init__(message)
class HaClient:
    """Async HTTP client for the HA backend's REST and SSE endpoints.

    The instance only stores the API base URL, an optional bearer token and a
    timeout; every call opens its own short-lived ``httpx.AsyncClient``.
    """

    def __init__(self, base_url: str, token: str = "", *, timeout: float = 120.0) -> None:
        """Store connection settings.

        Args:
            base_url: API root that request paths are appended to; trailing
                slashes are removed.
            token: Bearer token; surrounding whitespace is stripped, and an
                empty token means no ``Authorization`` header is sent.
            timeout: Timeout in seconds for regular requests; also used as the
                connect timeout for streaming (SSE) requests.
        """
        self.base_url = base_url.rstrip("/")
        self.token = token.strip()
        self.timeout = timeout

    def with_token(self, token: str) -> HaClient:
        """Return a new client with the same base URL and timeout but *token*."""
        return HaClient(self.base_url, token, timeout=self.timeout)

    def _headers(self, extra: dict[str, str] | None = None) -> dict[str, str]:
        """Build request headers.

        Starts from ``Accept: application/json`` and applies *extra* on top.
        The bearer token is added last, so *extra* can never override it.
        """
        headers: dict[str, str] = {"Accept": "application/json"}
        if extra:
            headers.update(extra)
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        return headers

    @staticmethod
    def _error_detail(text: str, status_code: int) -> str:
        """Return the message used for an HTTP error response.

        This is the stripped response body, or ``HTTP <status>`` when the body
        is empty or whitespace-only.
        """
        return text.strip() or f"HTTP {status_code}"

    def _stream_timeout(self) -> httpx.Timeout:
        """Return the timeout for SSE requests.

        Event streams may stay open indefinitely, so read/write/pool are
        unbounded. The connect phase is still bounded by ``self.timeout`` so an
        unreachable server cannot hang the caller forever.
        """
        return httpx.Timeout(None, connect=self.timeout)

    async def _raise_stream_error(self, response: httpx.Response) -> None:
        """Read a failed streaming response's body and raise ``HaApiError``.

        This never returns normally; it always raises.
        """
        body = await response.aread()
        raise HaApiError(
            self._error_detail(body.decode("utf-8", errors="replace"), response.status_code),
            response.status_code,
        )

    async def _request(
        self,
        method: str,
        path: str,
        *,
        json_body: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
    ) -> Any:
        """Send a request to ``base_url + path`` and return the decoded JSON.

        Args:
            method: HTTP method.
            path: Path appended to the base URL (should start with ``/``).
            json_body: Optional JSON body; when given, ``Content-Type:
                application/json`` is set.
            params: Optional query parameters.

        Returns:
            The decoded JSON body, or ``None`` for a 204 or an empty body.

        Raises:
            HaApiError: If the response status is >= 400.
        """
        url = f"{self.base_url}{path}"
        extra = {"Content-Type": "application/json"} if json_body is not None else None
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.request(
                method,
                url,
                headers=self._headers(extra),
                json=json_body,
                params=params,
            )
            if response.status_code >= 400:
                raise HaApiError(
                    self._error_detail(response.text, response.status_code),
                    response.status_code,
                )
            if response.status_code == 204 or not response.content:
                return None
            return response.json()

    async def login(self, token: str) -> dict[str, Any]:
        """POST *token* (stripped) to ``/auth/login`` and return the JSON reply.

        This uses a fixed 30 s timeout and sends no ``Authorization`` header.

        Raises:
            HaApiError: If the status is >= 400. The message is the response
                body, or a fixed "invalid token" text when the body is empty.
        """
        url = f"{self.base_url}/auth/login"
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                url,
                headers={"Content-Type": "application/json"},
                json={"token": token.strip()},
            )
            if response.status_code >= 400:
                raise HaApiError(response.text.strip() or "Неверный токен", response.status_code)
            return response.json()

    async def me(self) -> dict[str, Any]:
        """GET ``/auth/me``."""
        return await self._request("GET", "/auth/me")

    async def list_sessions(self) -> list[dict[str, Any]]:
        """GET ``/chat/sessions``; an empty response yields ``[]``."""
        result = await self._request("GET", "/chat/sessions")
        return list(result or [])

    async def create_session(self, title: str = "Telegram") -> dict[str, Any]:
        """POST a new chat session with *title* to ``/chat/sessions``."""
        return await self._request("POST", "/chat/sessions", json_body={"title": title})

    async def get_messages(
        self,
        session_id: int,
        *,
        after_id: int | None = None,
        limit: int = 100,
    ) -> dict[str, Any]:
        """GET the messages of *session_id*.

        ``limit`` is always sent; ``after_id`` is sent only when not ``None``.
        """
        params: dict[str, Any] = {"limit": limit}
        if after_id is not None:
            params["after_id"] = after_id
        return await self._request(
            "GET",
            f"/chat/sessions/{session_id}/messages",
            params=params,
        )

    async def generation_status(self, session_id: int) -> dict[str, Any]:
        """GET ``/chat/sessions/{session_id}/generation``."""
        return await self._request("GET", f"/chat/sessions/{session_id}/generation")

    async def get_reminders_snapshot(self) -> dict[str, Any]:
        """GET ``/reminders``."""
        return await self._request("GET", "/reminders")

    async def get_pomodoro_status(self) -> dict[str, Any]:
        """GET ``/pomodoro/status``."""
        return await self._request("GET", "/pomodoro/status")

    async def _stream(self, method: str, path: str, *, json_body: dict[str, Any] | None = None) -> AsyncIterator[SseChunk]:
        """Open an SSE request to ``base_url + path`` and yield parsed chunks.

        A 404 ends the iteration without yielding anything. Presumably this
        means "nothing to stream"; confirm against the server. Any other
        status >= 400 raises ``HaApiError`` with the response body, the same
        as the image upload stream, instead of feeding the error body to the
        SSE parser.
        """
        url = f"{self.base_url}{path}"
        if json_body is not None:
            extra = {"Content-Type": "application/json", "Accept": "text/event-stream"}
        else:
            extra = {"Accept": "text/event-stream"}
        async with httpx.AsyncClient(timeout=self._stream_timeout()) as client:
            async with client.stream(
                method,
                url,
                headers=self._headers(extra),
                json=json_body,
            ) as response:
                if response.status_code == 404:
                    return
                if response.status_code >= 400:
                    await self._raise_stream_error(response)
                async for chunk in iter_sse(response):
                    yield chunk

    async def send_message_stream(self, session_id: int, content: str) -> AsyncIterator[SseChunk]:
        """POST a text message to *session_id* and yield the SSE reply chunks."""
        async for chunk in self._stream(
            "POST",
            f"/chat/sessions/{session_id}/messages",
            json_body={"content": content},
        ):
            yield chunk

    async def send_message_with_image_stream(
        self,
        session_id: int,
        content: str,
        image_bytes: bytes,
        *,
        filename: str = "photo.jpg",
        content_type: str = "image/jpeg",
    ) -> AsyncIterator[SseChunk]:
        """POST a message with an image (multipart) and yield SSE reply chunks.

        The text goes in form field ``content`` and the image in file field
        ``image``.

        Raises:
            HaApiError: If the response status is >= 400 (404 included).
        """
        url = f"{self.base_url}/chat/sessions/{session_id}/messages"
        files = {"image": (filename, image_bytes, content_type)}
        data = {"content": content}
        async with httpx.AsyncClient(timeout=self._stream_timeout()) as client:
            async with client.stream(
                "POST",
                url,
                headers=self._headers({"Accept": "text/event-stream"}),
                files=files,
                data=data,
            ) as response:
                if response.status_code >= 400:
                    await self._raise_stream_error(response)
                async for chunk in iter_sse(response):
                    yield chunk

    async def stream_generation(self, session_id: int) -> AsyncIterator[SseChunk]:
        """Yield SSE chunks from ``/chat/sessions/{session_id}/generation/stream``."""
        async for chunk in self._stream("GET", f"/chat/sessions/{session_id}/generation/stream"):
            yield chunk

    async def find_or_create_telegram_session(self) -> int:
        """Return the id of the first session titled "Telegram".

        If no such session exists, one is created with that title.
        """
        sessions = await self.list_sessions()
        for session in sessions:
            if session.get("title") == "Telegram":
                return int(session["id"])
        created = await self.create_session("Telegram")
        return int(created["id"])

    async def download_media(self, path_or_url: str, *, ha_api_base: str | None = None) -> bytes:
        """Download a media file and return its raw bytes.

        Args:
            path_or_url: Absolute http(s) URL, or a path resolved against the
                server origin via ``resolve_media_url``.
            ha_api_base: API base used for resolving relative paths; defaults
                to ``self.base_url``.

        Raises:
            HaApiError: If the response status is >= 400.
        """
        base = ha_api_base or self.base_url
        url = resolve_media_url(base, path_or_url)
        # NOTE(review): the bearer token is attached even when *url* is an
        # absolute URL on a different host than the API, so the token is sent
        # to that host. Consider sending Authorization only for same-origin
        # URLs once it is confirmed that media is always served from the API
        # host.
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(url, headers=self._headers())
            if response.status_code >= 400:
                raise HaApiError(
                    self._error_detail(response.text, response.status_code),
                    response.status_code,
                )
            return response.content
def resolve_media_url(ha_api_base: str, path_or_url: str) -> str:
    """Turn a media reference returned by the API into an absolute URL.

    Absolute ``http://`` / ``https://`` URLs are returned unchanged, apart from
    stripping surrounding whitespace. The scheme check is case-insensitive
    because URI schemes are (RFC 3986 §3.1). Anything else is treated as a
    path on the server origin: trailing slashes and a trailing ``/api/v1`` are
    removed from *ha_api_base*, and the path gets a leading ``/`` if it lacks
    one.

    Args:
        ha_api_base: API base URL, e.g. ``http://host:8000/api/v1``.
        path_or_url: Absolute URL or server-relative path. ``None`` or an
            empty string resolves to ``<origin>/``.

    Returns:
        The absolute URL.
    """
    raw = (path_or_url or "").strip()
    if raw.lower().startswith(("http://", "https://")):
        return raw
    api_suffix = "/api/v1"
    origin = ha_api_base.rstrip("/")
    if origin.endswith(api_suffix):
        origin = origin[: -len(api_suffix)]
    if not raw.startswith("/"):
        raw = "/" + raw
    return origin + raw