144 lines
5.3 KiB
Python
144 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from bot.sse import SseChunk, iter_sse
|
|
|
|
|
|
class HaApiError(RuntimeError):
    """Error raised by :class:`HaClient` when the backend rejects a request.

    The exception message is the human-readable detail; ``status_code`` carries the
    HTTP status of the failed response, or ``None`` when no status was supplied.
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        # Record the status first; the base class only needs the message.
        self.status_code = status_code
        super().__init__(message)
|
|
|
|
|
|
class HaClient:
    """Async HTTP client for the bot's backend API.

    The instance holds only the base URL, an optional bearer token and a default
    timeout; every call opens its own short-lived ``httpx.AsyncClient``.
    """

    def __init__(self, base_url: str, token: str = "", *, timeout: float = 120.0) -> None:
        # Drop trailing slashes so request paths (which start with "/") can be appended.
        self.base_url = base_url.rstrip("/")
        self.token = token.strip()
        self.timeout = timeout

    def with_token(self, token: str) -> HaClient:
        """Return a new client with the same base URL and timeout but another token."""
        return HaClient(self.base_url, token, timeout=self.timeout)

    def _headers(self, extra: dict[str, str] | None = None) -> dict[str, str]:
        """Build request headers.

        Starts from ``Accept: application/json``, applies ``extra`` on top of it, and
        adds ``Authorization: Bearer <token>`` when a token is set. The authorization
        header is applied last, so ``extra`` cannot override it.
        """
        headers: dict[str, str] = {"Accept": "application/json"}
        if extra:
            headers.update(extra)
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        return headers

    async def _request(
        self,
        method: str,
        path: str,
        *,
        json_body: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
    ) -> Any:
        """Send a request and return the JSON-decoded response body.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            path: Path appended to ``base_url`` (expected to start with ``/``).
            json_body: Optional JSON payload; also sets ``Content-Type``.
            params: Optional query-string parameters.

        Returns:
            The decoded JSON body, or ``None`` for a 204 / empty response.

        Raises:
            HaApiError: If the response status is >= 400. The message is the response
                body, or ``"HTTP <status>"`` when the body is empty.
        """
        url = f"{self.base_url}{path}"
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.request(
                method,
                url,
                headers=self._headers(
                    {"Content-Type": "application/json"} if json_body is not None else None
                ),
                json=json_body,
                params=params,
            )
        if response.status_code >= 400:
            detail = response.text.strip() or f"HTTP {response.status_code}"
            raise HaApiError(detail, response.status_code)
        if response.status_code == 204 or not response.content:
            return None
        return response.json()

    async def login(self, token: str) -> dict[str, Any]:
        """Exchange ``token`` at ``POST /auth/login`` and return the decoded JSON reply.

        This call does not use ``_headers``: it sends no ``Authorization`` header and
        uses a fixed 30-second timeout instead of ``self.timeout``.

        Raises:
            HaApiError: If the response status is >= 400. The message is the response
                body, or a fixed "invalid token" message (in Russian) when it is empty.
        """
        url = f"{self.base_url}/auth/login"
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                url,
                headers={"Content-Type": "application/json"},
                json={"token": token.strip()},
            )
        if response.status_code >= 400:
            raise HaApiError(response.text.strip() or "Неверный токен", response.status_code)
        return response.json()

    async def me(self) -> dict[str, Any]:
        """Return the decoded ``GET /auth/me`` response."""
        return await self._request("GET", "/auth/me")

    async def list_sessions(self) -> list[dict[str, Any]]:
        """Return the chat sessions from ``GET /chat/sessions`` ([] for an empty body)."""
        result = await self._request("GET", "/chat/sessions")
        return list(result or [])

    async def create_session(self, title: str = "Telegram") -> dict[str, Any]:
        """Create a chat session with ``title`` and return the decoded response."""
        return await self._request("POST", "/chat/sessions", json_body={"title": title})

    async def get_messages(
        self,
        session_id: int,
        *,
        after_id: int | None = None,
        limit: int = 100,
    ) -> dict[str, Any]:
        """Fetch messages of ``session_id``.

        ``limit`` is always sent; ``after_id`` is added to the query only when given.
        """
        params: dict[str, Any] = {"limit": limit}
        if after_id is not None:
            params["after_id"] = after_id
        return await self._request(
            "GET",
            f"/chat/sessions/{session_id}/messages",
            params=params,
        )

    async def generation_status(self, session_id: int) -> dict[str, Any]:
        """Return the decoded ``GET /chat/sessions/{id}/generation`` response."""
        return await self._request("GET", f"/chat/sessions/{session_id}/generation")

    async def get_reminders_snapshot(self) -> dict[str, Any]:
        """Return the decoded ``GET /reminders`` response."""
        return await self._request("GET", "/reminders")

    async def get_pomodoro_status(self) -> dict[str, Any]:
        """Return the decoded ``GET /pomodoro/status`` response."""
        return await self._request("GET", "/pomodoro/status")

    async def _stream(
        self,
        method: str,
        path: str,
        *,
        json_body: dict[str, Any] | None = None,
    ) -> AsyncIterator[SseChunk]:
        """Open a streaming request and yield the server-sent events parsed by ``iter_sse``.

        A 404 response ends the stream without yielding anything. Any other status
        >= 400 raises ``HaApiError`` rather than feeding the error body to the SSE
        parser (which would silently hide auth/server errors).

        Raises:
            HaApiError: If the response status is >= 400 and not 404.
        """
        url = f"{self.base_url}{path}"
        accept_sse = {"Accept": "text/event-stream"}
        extra = (
            {"Content-Type": "application/json", **accept_sse}
            if json_body is not None
            else accept_sse
        )
        # Reads stay unbounded so long gaps between events never abort the stream,
        # but connect/write/pool waits are bounded so an unreachable server cannot
        # hang the caller forever.
        timeout = httpx.Timeout(self.timeout, read=None)
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream(
                method,
                url,
                headers=self._headers(extra),
                json=json_body,
            ) as response:
                if response.status_code == 404:
                    return
                if response.status_code >= 400:
                    # A streamed body must be read explicitly before .text is available.
                    await response.aread()
                    detail = response.text.strip() or f"HTTP {response.status_code}"
                    raise HaApiError(detail, response.status_code)
                async for chunk in iter_sse(response):
                    yield chunk

    async def send_message_stream(self, session_id: int, content: str) -> AsyncIterator[SseChunk]:
        """POST ``{"content": content}`` to the session's messages and yield SSE chunks."""
        async for chunk in self._stream(
            "POST",
            f"/chat/sessions/{session_id}/messages",
            json_body={"content": content},
        ):
            yield chunk

    async def stream_generation(self, session_id: int) -> AsyncIterator[SseChunk]:
        """Yield SSE chunks from ``GET /chat/sessions/{id}/generation/stream``."""
        async for chunk in self._stream("GET", f"/chat/sessions/{session_id}/generation/stream"):
            yield chunk

    async def find_or_create_telegram_session(self) -> int:
        """Return the id of the first session titled "Telegram", creating one if none exists."""
        sessions = await self.list_sessions()
        for session in sessions:
            if session.get("title") == "Telegram":
                return int(session["id"])
        created = await self.create_session("Telegram")
        return int(created["id"])
|