added api
This commit is contained in:
@@ -0,0 +1,276 @@
|
||||
import asyncio
|
||||
import random
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
ANIMA_QUALITY_PREFIX = "masterpiece, best quality, score_7, anime"
|
||||
ANIMA_DEFAULT_NEGATIVE = (
|
||||
"worst quality, low quality, score_1, score_2, score_3, blurry, jpeg artifacts, sepia"
|
||||
)
|
||||
|
||||
ROFL_PROMPTS = [
|
||||
f"{ANIMA_QUALITY_PREFIX}, confused cat in tiny business suit, server room, meme, chibi",
|
||||
f"{ANIMA_QUALITY_PREFIX}, potato with sunglasses on skateboard, absurd cartoon, silly",
|
||||
f"{ANIMA_QUALITY_PREFIX}, astronaut watering houseplant on the moon, wholesome, cute",
|
||||
f"{ANIMA_QUALITY_PREFIX}, rubber duck as judge at programming contest, comic style",
|
||||
f"{ANIMA_QUALITY_PREFIX}, llama DJ at house party, neon lights, party, silly",
|
||||
]
|
||||
|
||||
|
||||
def _use_anima(settings) -> bool:
|
||||
return bool(settings.comfyui_unet.strip()) and not settings.comfyui_checkpoint.strip()
|
||||
|
||||
|
||||
def _build_anima_workflow(
|
||||
positive: str,
|
||||
negative: str,
|
||||
seed: int,
|
||||
settings,
|
||||
) -> dict[str, Any]:
|
||||
workflow: dict[str, Any] = {
|
||||
"44": {
|
||||
"class_type": "UNETLoader",
|
||||
"inputs": {"unet_name": settings.comfyui_unet, "weight_dtype": "default"},
|
||||
},
|
||||
"45": {
|
||||
"class_type": "CLIPLoader",
|
||||
"inputs": {
|
||||
"clip_name": settings.comfyui_clip,
|
||||
"type": "stable_diffusion",
|
||||
"device": "default",
|
||||
},
|
||||
},
|
||||
"15": {
|
||||
"class_type": "VAELoader",
|
||||
"inputs": {"vae_name": settings.comfyui_vae},
|
||||
},
|
||||
"28": {
|
||||
"class_type": "EmptyLatentImage",
|
||||
"inputs": {
|
||||
"width": settings.comfyui_width,
|
||||
"height": settings.comfyui_height,
|
||||
"batch_size": 1,
|
||||
},
|
||||
},
|
||||
"11": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": positive, "clip": ["45", 0]},
|
||||
},
|
||||
"12": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": negative, "clip": ["45", 0]},
|
||||
},
|
||||
"19": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": ["44", 0],
|
||||
"positive": ["11", 0],
|
||||
"negative": ["12", 0],
|
||||
"latent_image": ["28", 0],
|
||||
"seed": seed,
|
||||
"steps": settings.comfyui_steps,
|
||||
"cfg": settings.comfyui_cfg,
|
||||
"sampler_name": settings.comfyui_sampler,
|
||||
"scheduler": settings.comfyui_scheduler,
|
||||
"denoise": 1.0,
|
||||
},
|
||||
},
|
||||
"8": {
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": {"samples": ["19", 0], "vae": ["15", 0]},
|
||||
},
|
||||
"9": {
|
||||
"class_type": "SaveImage",
|
||||
"inputs": {"filename_prefix": "assistant", "images": ["8", 0]},
|
||||
},
|
||||
}
|
||||
|
||||
lora = settings.comfyui_style_lora.strip()
|
||||
if lora:
|
||||
workflow["46"] = {
|
||||
"class_type": "LoraLoader",
|
||||
"inputs": {
|
||||
"lora_name": lora,
|
||||
"model": ["44", 0],
|
||||
"clip": ["45", 0],
|
||||
"strength_model": settings.comfyui_style_lora_weight,
|
||||
"strength_clip": settings.comfyui_style_lora_weight,
|
||||
},
|
||||
}
|
||||
workflow["19"]["inputs"]["model"] = ["46", 0]
|
||||
workflow["11"]["inputs"]["clip"] = ["46", 1]
|
||||
workflow["12"]["inputs"]["clip"] = ["46", 1]
|
||||
|
||||
return workflow
|
||||
|
||||
|
||||
def _build_checkpoint_workflow(
|
||||
positive: str,
|
||||
negative: str,
|
||||
seed: int,
|
||||
settings,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"4": {
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {"ckpt_name": settings.comfyui_checkpoint},
|
||||
},
|
||||
"5": {
|
||||
"class_type": "EmptyLatentImage",
|
||||
"inputs": {
|
||||
"width": settings.comfyui_width,
|
||||
"height": settings.comfyui_height,
|
||||
"batch_size": 1,
|
||||
},
|
||||
},
|
||||
"6": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": positive, "clip": ["4", 1]},
|
||||
},
|
||||
"7": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": negative, "clip": ["4", 1]},
|
||||
},
|
||||
"10": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": ["4", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["7", 0],
|
||||
"latent_image": ["5", 0],
|
||||
"seed": seed,
|
||||
"steps": settings.comfyui_steps,
|
||||
"cfg": settings.comfyui_cfg,
|
||||
"sampler_name": settings.comfyui_sampler,
|
||||
"scheduler": settings.comfyui_scheduler,
|
||||
"denoise": 1.0,
|
||||
},
|
||||
},
|
||||
"8": {
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": {"samples": ["10", 0], "vae": ["4", 2]},
|
||||
},
|
||||
"9": {
|
||||
"class_type": "SaveImage",
|
||||
"inputs": {"filename_prefix": "assistant", "images": ["8", 0]},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_workflow(positive: str, negative: str, seed: int, settings) -> dict[str, Any]:
|
||||
if _use_anima(settings):
|
||||
return _build_anima_workflow(positive, negative, seed, settings)
|
||||
return _build_checkpoint_workflow(positive, negative, seed, settings)
|
||||
|
||||
|
||||
def _wrap_positive_prompt(prompt: str, settings) -> str:
|
||||
text = prompt.strip()
|
||||
if not text:
|
||||
return text
|
||||
if _use_anima(settings) and ANIMA_QUALITY_PREFIX.lower() not in text.lower():
|
||||
return f"{ANIMA_QUALITY_PREFIX}, {text}"
|
||||
return text
|
||||
|
||||
|
||||
class ComfyUIClient:
|
||||
def __init__(self) -> None:
|
||||
settings = get_settings()
|
||||
self.base_url = settings.comfyui_base_url.rstrip("/")
|
||||
self.enabled = settings.comfyui_enabled
|
||||
self.settings = settings
|
||||
self.output_dir = Path(settings.generated_media_dir)
|
||||
self.poll_interval = settings.comfyui_poll_interval_sec
|
||||
self.timeout = settings.comfyui_timeout_sec
|
||||
|
||||
def _default_negative(self) -> str:
|
||||
if _use_anima(self.settings):
|
||||
return self.settings.comfyui_negative_prompt or ANIMA_DEFAULT_NEGATIVE
|
||||
return self.settings.comfyui_negative_prompt
|
||||
|
||||
def _ensure_output_dir(self) -> None:
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
negative_prompt: str | None = None,
|
||||
seed: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if not self.enabled:
|
||||
return {"ok": False, "error": "ComfyUI отключён (COMFYUI_ENABLED=false)"}
|
||||
|
||||
if not _use_anima(self.settings) and not self.settings.comfyui_checkpoint.strip():
|
||||
return {
|
||||
"ok": False,
|
||||
"error": "Не задан COMFYUI_UNET (Anima) или COMFYUI_CHECKPOINT",
|
||||
}
|
||||
|
||||
self._ensure_output_dir()
|
||||
seed = seed if seed is not None else random.randint(1, 2**31 - 1)
|
||||
positive = _wrap_positive_prompt(prompt, self.settings)
|
||||
negative = negative_prompt or self._default_negative()
|
||||
workflow = _build_workflow(positive, negative, seed, self.settings)
|
||||
client_id = str(uuid.uuid4())
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/prompt",
|
||||
json={"prompt": workflow, "client_id": client_id},
|
||||
)
|
||||
if response.status_code >= 400:
|
||||
return {"ok": False, "error": f"ComfyUI prompt error: {response.text[:300]}"}
|
||||
prompt_id = response.json().get("prompt_id")
|
||||
if not prompt_id:
|
||||
return {"ok": False, "error": "ComfyUI не вернул prompt_id"}
|
||||
|
||||
elapsed = 0.0
|
||||
while elapsed < self.timeout:
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
elapsed += self.poll_interval
|
||||
hist_resp = await client.get(f"{self.base_url}/history/{prompt_id}")
|
||||
if hist_resp.status_code != 200:
|
||||
continue
|
||||
history = hist_resp.json()
|
||||
if prompt_id not in history:
|
||||
continue
|
||||
entry = history[prompt_id]
|
||||
status = (entry.get("status") or {}).get("status_str")
|
||||
if status == "error":
|
||||
msgs = entry.get("status", {}).get("messages", [])
|
||||
return {"ok": False, "error": f"ComfyUI workflow error: {msgs}"}
|
||||
|
||||
outputs = entry.get("outputs") or {}
|
||||
for node_output in outputs.values():
|
||||
images = node_output.get("images") or []
|
||||
if not images:
|
||||
continue
|
||||
image_info = images[0]
|
||||
view_params = {
|
||||
"filename": image_info["filename"],
|
||||
"subfolder": image_info.get("subfolder", ""),
|
||||
"type": image_info.get("type", "output"),
|
||||
}
|
||||
img_resp = await client.get(f"{self.base_url}/view", params=view_params)
|
||||
if img_resp.status_code != 200:
|
||||
continue
|
||||
filename = f"{uuid.uuid4().hex}.png"
|
||||
out_path = self.output_dir / filename
|
||||
out_path.write_bytes(img_resp.content)
|
||||
return {
|
||||
"ok": True,
|
||||
"filename": filename,
|
||||
"url": f"/api/v1/media/generated/{filename}",
|
||||
"prompt": positive,
|
||||
"backend": "anima" if _use_anima(self.settings) else "checkpoint",
|
||||
}
|
||||
|
||||
return {"ok": False, "error": f"Таймаут генерации ({self.timeout}s)"}
|
||||
|
||||
def random_rofl_prompt(self) -> str:
|
||||
return random.choice(ROFL_PROMPTS)
|
||||
Reference in New Issue
Block a user