added RAG, Multiuser, TG bot
This commit is contained in:
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from bot.sse import SseChunk, iter_sse
|
||||
|
||||
|
||||
class HaApiError(RuntimeError):
|
||||
def __init__(self, message: str, status_code: int | None = None) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class HaClient:
|
||||
def __init__(self, base_url: str, token: str = "", *, timeout: float = 120.0) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.token = token.strip()
|
||||
self.timeout = timeout
|
||||
|
||||
def with_token(self, token: str) -> HaClient:
|
||||
return HaClient(self.base_url, token, timeout=self.timeout)
|
||||
|
||||
def _headers(self, extra: dict[str, str] | None = None) -> dict[str, str]:
|
||||
headers: dict[str, str] = {"Accept": "application/json"}
|
||||
if extra:
|
||||
headers.update(extra)
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
return headers
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
json_body: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
url = f"{self.base_url}{path}"
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.request(
|
||||
method,
|
||||
url,
|
||||
headers=self._headers(
|
||||
{"Content-Type": "application/json"} if json_body is not None else None
|
||||
),
|
||||
json=json_body,
|
||||
params=params,
|
||||
)
|
||||
if response.status_code >= 400:
|
||||
detail = response.text.strip() or f"HTTP {response.status_code}"
|
||||
raise HaApiError(detail, response.status_code)
|
||||
if response.status_code == 204 or not response.content:
|
||||
return None
|
||||
return response.json()
|
||||
|
||||
async def login(self, token: str) -> dict[str, Any]:
|
||||
url = f"{self.base_url}/auth/login"
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={"token": token.strip()},
|
||||
)
|
||||
if response.status_code >= 400:
|
||||
raise HaApiError(response.text.strip() or "Неверный токен", response.status_code)
|
||||
return response.json()
|
||||
|
||||
async def me(self) -> dict[str, Any]:
|
||||
return await self._request("GET", "/auth/me")
|
||||
|
||||
async def list_sessions(self) -> list[dict[str, Any]]:
|
||||
result = await self._request("GET", "/chat/sessions")
|
||||
return list(result or [])
|
||||
|
||||
async def create_session(self, title: str = "Telegram") -> dict[str, Any]:
|
||||
return await self._request("POST", "/chat/sessions", json_body={"title": title})
|
||||
|
||||
async def get_messages(
|
||||
self,
|
||||
session_id: int,
|
||||
*,
|
||||
after_id: int | None = None,
|
||||
limit: int = 100,
|
||||
) -> dict[str, Any]:
|
||||
params: dict[str, Any] = {"limit": limit}
|
||||
if after_id is not None:
|
||||
params["after_id"] = after_id
|
||||
return await self._request(
|
||||
"GET",
|
||||
f"/chat/sessions/{session_id}/messages",
|
||||
params=params,
|
||||
)
|
||||
|
||||
async def generation_status(self, session_id: int) -> dict[str, Any]:
|
||||
return await self._request("GET", f"/chat/sessions/{session_id}/generation")
|
||||
|
||||
async def get_reminders_snapshot(self) -> dict[str, Any]:
|
||||
return await self._request("GET", "/reminders")
|
||||
|
||||
async def get_pomodoro_status(self) -> dict[str, Any]:
|
||||
return await self._request("GET", "/pomodoro/status")
|
||||
|
||||
async def _stream(self, method: str, path: str, *, json_body: dict[str, Any] | None = None) -> AsyncIterator[SseChunk]:
|
||||
url = f"{self.base_url}{path}"
|
||||
async with httpx.AsyncClient(timeout=None) as client:
|
||||
async with client.stream(
|
||||
method,
|
||||
url,
|
||||
headers=self._headers(
|
||||
{"Content-Type": "application/json", "Accept": "text/event-stream"}
|
||||
if json_body is not None
|
||||
else {"Accept": "text/event-stream"}
|
||||
),
|
||||
json=json_body,
|
||||
) as response:
|
||||
if response.status_code == 404:
|
||||
return
|
||||
async for chunk in iter_sse(response):
|
||||
yield chunk
|
||||
|
||||
async def send_message_stream(self, session_id: int, content: str) -> AsyncIterator[SseChunk]:
|
||||
async for chunk in self._stream(
|
||||
"POST",
|
||||
f"/chat/sessions/{session_id}/messages",
|
||||
json_body={"content": content},
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def stream_generation(self, session_id: int) -> AsyncIterator[SseChunk]:
|
||||
async for chunk in self._stream("GET", f"/chat/sessions/{session_id}/generation/stream"):
|
||||
yield chunk
|
||||
|
||||
async def find_or_create_telegram_session(self) -> int:
|
||||
sessions = await self.list_sessions()
|
||||
for session in sessions:
|
||||
if session.get("title") == "Telegram":
|
||||
return int(session["id"])
|
||||
created = await self.create_session("Telegram")
|
||||
return int(created["id"])
|
||||
Reference in New Issue
Block a user