added rp api
This commit is contained in:
@@ -0,0 +1,89 @@
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RpChatClient:
|
||||
def __init__(self) -> None:
|
||||
settings = get_settings()
|
||||
self.base_url = settings.rp_chat_base_url.rstrip("/")
|
||||
self.enabled = settings.rp_chat_enabled
|
||||
self.timeout = settings.rp_chat_timeout_sec
|
||||
|
||||
def _client(self) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(timeout=self.timeout)
|
||||
|
||||
async def health(self) -> dict[str, Any]:
|
||||
async with self._client() as client:
|
||||
response = await client.get(f"{self.base_url}/health")
|
||||
return {"ok": response.status_code == 200, "status_code": response.status_code}
|
||||
|
||||
async def sd_prompt(
|
||||
self,
|
||||
persona_id: str,
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
appearance_override: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"persona_id": persona_id,
|
||||
"messages": messages,
|
||||
"outfit_json": "[]",
|
||||
"use_prose": False,
|
||||
}
|
||||
if appearance_override:
|
||||
payload["appearance_override"] = appearance_override
|
||||
|
||||
async with self._client() as client:
|
||||
response = await client.post(f"{self.base_url}/api/sd-prompt", json=payload)
|
||||
if response.status_code >= 400:
|
||||
return {"ok": False, "error": response.text[:500]}
|
||||
data = response.json()
|
||||
if data.get("skipped") or data.get("error"):
|
||||
return {"ok": False, "error": data.get("error", "should_generate=false"), "raw": data}
|
||||
return {"ok": True, **data}
|
||||
|
||||
async def generate(self, positive: str, negative: str = "") -> dict[str, Any]:
|
||||
async with self._client() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/generate",
|
||||
json={"positive": positive, "negative": negative},
|
||||
)
|
||||
if response.status_code >= 400:
|
||||
return {"ok": False, "error": response.text[:500]}
|
||||
data = response.json()
|
||||
if data.get("status") != "ok" or not data.get("image_path"):
|
||||
return {"ok": False, "error": data.get("detail", "generation failed")}
|
||||
return {"ok": True, **data}
|
||||
|
||||
async def download_image(self, image_path: str) -> bytes | None:
|
||||
path = image_path if image_path.startswith("/") else f"/{image_path}"
|
||||
async with self._client() as client:
|
||||
response = await client.get(f"{self.base_url}{path}")
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
return response.content
|
||||
|
||||
async def save_image_locally(self, image_path: str) -> dict[str, Any]:
|
||||
content = await self.download_image(image_path)
|
||||
if not content:
|
||||
return {"ok": False, "error": f"Не удалось скачать {image_path}"}
|
||||
|
||||
settings = get_settings()
|
||||
out_dir = Path(settings.generated_media_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
filename = f"{uuid.uuid4().hex}.png"
|
||||
(out_dir / filename).write_bytes(content)
|
||||
return {
|
||||
"ok": True,
|
||||
"filename": filename,
|
||||
"url": f"/api/v1/media/generated/{filename}",
|
||||
"source_path": image_path,
|
||||
}
|
||||
Reference in New Issue
Block a user