"""Async HTTP client for the HA backend API.

Covers the auth, chat-session, reminder and pomodoro REST endpoints, plus the
server-sent-event (SSE) streams used for chat generation.
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any

import httpx

from bot.sse import SseChunk, iter_sse

# Title of the chat session reused for Telegram conversations.
_TELEGRAM_SESSION_TITLE = "Telegram"


class HaApiError(RuntimeError):
    """Raised when the HA API answers with an HTTP error status (>= 400).

    The exception message is the (stripped) response body, or a fallback when
    the body is empty. ``status_code`` holds the HTTP status, or ``None``.
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        super().__init__(message)
        self.status_code = status_code


def _raise_for_status(response: httpx.Response, fallback: str | None = None) -> None:
    """Raise :class:`HaApiError` if *response* has a status code >= 400.

    The body must already be loaded; for streamed responses call
    ``await response.aread()`` first. The stripped body text becomes the error
    message. When it is empty, *fallback* is used, or ``"HTTP <code>"`` if no
    fallback is given.
    """
    if response.status_code < 400:
        return
    detail = response.text.strip() or fallback or f"HTTP {response.status_code}"
    raise HaApiError(detail, response.status_code)


class HaClient:
    """Async client for the HA backend REST/SSE API.

    Every call opens its own short-lived ``httpx.AsyncClient``. The instance
    only stores connection settings.

    Args:
        base_url: API root URL; a trailing ``/`` is stripped.
        token: Bearer token. Surrounding whitespace is stripped. An empty
            token means no ``Authorization`` header is sent.
        timeout: Timeout in seconds for regular requests. For streams it
            bounds only the connect phase.
    """

    def __init__(self, base_url: str, token: str = "", *, timeout: float = 120.0) -> None:
        self.base_url = base_url.rstrip("/")
        self.token = token.strip()
        self.timeout = timeout

    def with_token(self, token: str) -> HaClient:
        """Return a new client with the same base URL and timeout but *token*."""
        return HaClient(self.base_url, token, timeout=self.timeout)

    def _url(self, path: str) -> str:
        """Append *path* (expected to start with ``/``) to the base URL."""
        return f"{self.base_url}{path}"

    def _headers(self, extra: dict[str, str] | None = None) -> dict[str, str]:
        """Build request headers.

        Starts from ``Accept: application/json`` and applies *extra* on top.
        The bearer token is added last, so *extra* cannot override it.
        """
        headers: dict[str, str] = {"Accept": "application/json"}
        if extra:
            headers.update(extra)
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        return headers

    async def _request(
        self,
        method: str,
        path: str,
        *,
        json_body: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
    ) -> Any:
        """Send a request and return the JSON-decoded response body.

        Returns ``None`` for a 204 response or an empty body.

        Raises:
            HaApiError: if the response status is >= 400.
        """
        extra = {"Content-Type": "application/json"} if json_body is not None else None
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.request(
                method,
                self._url(path),
                headers=self._headers(extra),
                json=json_body,
                params=params,
            )
            _raise_for_status(response)
            if response.status_code == 204 or not response.content:
                return None
            return response.json()

    async def login(self, token: str) -> dict[str, Any]:
        """``POST /auth/login`` with ``{"token": <stripped token>}``.

        Sent without this client's ``Authorization`` header, with a fixed
        30-second timeout. Returns the JSON-decoded response body.

        Raises:
            HaApiError: on status >= 400. If the body is empty, the message
                falls back to "Неверный токен" ("invalid token").
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                self._url("/auth/login"),
                headers={"Content-Type": "application/json"},
                json={"token": token.strip()},
            )
            _raise_for_status(response, fallback="Неверный токен")
            return response.json()

    async def me(self) -> dict[str, Any]:
        """Return the JSON body of ``GET /auth/me``."""
        return await self._request("GET", "/auth/me")

    async def list_sessions(self) -> list[dict[str, Any]]:
        """Return ``GET /chat/sessions`` as a list (empty if the body is empty)."""
        result = await self._request("GET", "/chat/sessions")
        return list(result or [])

    async def create_session(self, title: str = _TELEGRAM_SESSION_TITLE) -> dict[str, Any]:
        """Create a chat session titled *title*; return the JSON response body."""
        return await self._request("POST", "/chat/sessions", json_body={"title": title})

    async def get_messages(
        self,
        session_id: int,
        *,
        after_id: int | None = None,
        limit: int = 100,
    ) -> dict[str, Any]:
        """Fetch messages of *session_id*.

        ``limit`` is always sent as a query parameter. ``after_id`` is sent
        only when given; presumably it restricts results to messages after
        that id (server-side semantics, not visible here).
        """
        params: dict[str, Any] = {"limit": limit}
        if after_id is not None:
            params["after_id"] = after_id
        return await self._request(
            "GET",
            f"/chat/sessions/{session_id}/messages",
            params=params,
        )

    async def generation_status(self, session_id: int) -> dict[str, Any]:
        """Return the JSON body of ``GET /chat/sessions/{id}/generation``."""
        return await self._request("GET", f"/chat/sessions/{session_id}/generation")

    async def get_reminders_snapshot(self) -> dict[str, Any]:
        """Return the JSON body of ``GET /reminders``."""
        return await self._request("GET", "/reminders")

    async def get_pomodoro_status(self) -> dict[str, Any]:
        """Return the JSON body of ``GET /pomodoro/status``."""
        return await self._request("GET", "/pomodoro/status")

    async def _stream(
        self,
        method: str,
        path: str,
        *,
        json_body: dict[str, Any] | None = None,
    ) -> AsyncIterator[SseChunk]:
        """Open an SSE request and yield chunks parsed by ``iter_sse``.

        A 404 response ends the stream without yielding anything. Presumably
        it means there is nothing to stream (e.g. no active generation); TODO
        confirm. Any other status >= 400 raises :class:`HaApiError` instead
        of feeding the error body to the SSE parser.

        The read/write/pool timeouts are disabled because a stream may sit
        idle for a long time. Only connecting is bounded, by ``self.timeout``,
        so an unreachable server cannot hang forever.
        """
        extra = {"Accept": "text/event-stream"}
        if json_body is not None:
            extra["Content-Type"] = "application/json"
        timeout = httpx.Timeout(None, connect=self.timeout)
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream(
                method,
                self._url(path),
                headers=self._headers(extra),
                json=json_body,
            ) as response:
                if response.status_code == 404:
                    return
                if response.status_code >= 400:
                    # Streamed bodies are not loaded; read it so .text is available.
                    await response.aread()
                    _raise_for_status(response)
                async for chunk in iter_sse(response):
                    yield chunk

    async def send_message_stream(self, session_id: int, content: str) -> AsyncIterator[SseChunk]:
        """POST *content* to the session's messages and yield the SSE reply chunks."""
        async for chunk in self._stream(
            "POST",
            f"/chat/sessions/{session_id}/messages",
            json_body={"content": content},
        ):
            yield chunk

    async def stream_generation(self, session_id: int) -> AsyncIterator[SseChunk]:
        """Yield SSE chunks from ``GET /chat/sessions/{id}/generation/stream``."""
        async for chunk in self._stream("GET", f"/chat/sessions/{session_id}/generation/stream"):
            yield chunk

    async def find_or_create_telegram_session(self) -> int:
        """Return the id of the first session titled "Telegram", creating one if absent."""
        for session in await self.list_sessions():
            if session.get("title") == _TELEGRAM_SESSION_TITLE:
                return int(session["id"])
        created = await self.create_session(_TELEGRAM_SESSION_TITLE)
        return int(created["id"])