"""Client for generating images through a ComfyUI HTTP server.

Configuration is read from environment variables (optionally loaded from a
``.env`` file) at import time. Two workflow shapes are supported: a
split-model "Anima" graph (separate UNET / CLIP / VAE loaders) and a
standard single-checkpoint graph.
"""

import asyncio
import logging
import os
import uuid
from pathlib import Path

import httpx
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)

SD_BASE_URL = os.getenv("SD_BASE_URL", "http://127.0.0.1:8188").rstrip("/")
SD_STEPS = int(os.getenv("SD_STEPS", "28"))
SD_CFG = float(os.getenv("SD_CFG", "7"))
SD_SAMPLER = os.getenv("SD_SAMPLER", "euler")
SD_SCHEDULER = os.getenv("SD_SCHEDULER", "normal")
SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "")
SD_DEFAULT_NEGATIVE = os.getenv(
    "SD_DEFAULT_NEGATIVE",
    "low quality, worst quality, blurry, bad anatomy, watermark, text",
)

# Anima split-model settings
SD_UNET = os.getenv("SD_UNET", "anima-preview3-base.safetensors")
SD_CLIP = os.getenv("SD_CLIP", "qwen_3_06b_base.safetensors")
SD_VAE = os.getenv("SD_VAE", "qwen_image_vae.safetensors")

IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images"))

ANIMA_CHECKPOINTS = {"anima-preview3-base.safetensors"}
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}

# Marker separating the positive prompt from an inline negative prompt.
_NEGATIVE_MARKER = "\n\nNegative prompt:"


def _use_anima() -> bool:
    """Return True when the split-model (Anima) workflow should be used.

    That is the case when a UNET name is configured and no single-file
    checkpoint has been set via ``SD_CHECKPOINT``.
    """
    return bool(SD_UNET) and not SD_CHECKPOINT


def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]:
    """Split a combined prompt into ``(positive, negative)`` parts.

    Args:
        full_prompt: Prompt text, optionally containing a
            ``"\\n\\nNegative prompt:"`` marker followed by the negative prompt.

    Returns:
        The stripped positive prompt and the stripped negative prompt. When
        the marker is absent the negative part is ``SD_DEFAULT_NEGATIVE``.
    """
    positive, marker, negative = full_prompt.partition(_NEGATIVE_MARKER)
    if marker:
        return positive.strip(), negative.strip()
    return full_prompt.strip(), SD_DEFAULT_NEGATIVE


def _build_workflow(positive: str, negative: str) -> dict:
    """Build the ComfyUI node graph for one text-to-image generation.

    Args:
        positive: Positive prompt text.
        negative: Negative prompt text.

    Returns:
        A dict keyed by node id, in the shape posted to the ``/prompt``
        endpoint. A fresh 32-bit seed is drawn on every call.
    """
    seed = uuid.uuid4().int % 2**32

    if _use_anima():
        return {
            "44": {
                "class_type": "UNETLoader",
                "inputs": {"unet_name": SD_UNET, "weight_dtype": "default"},
            },
            "45": {
                "class_type": "CLIPLoader",
                "inputs": {
                    "clip_name": SD_CLIP,
                    "type": "stable_diffusion",
                    "device": "default",
                },
            },
            "15": {"class_type": "VAELoader", "inputs": {"vae_name": SD_VAE}},
            "28": {
                "class_type": "EmptyLatentImage",
                "inputs": {"width": 1024, "height": 1024, "batch_size": 1},
            },
            "11": {
                "class_type": "CLIPTextEncode",
                "inputs": {"text": positive, "clip": ["45", 0]},
            },
            "12": {
                "class_type": "CLIPTextEncode",
                "inputs": {"text": negative, "clip": ["45", 0]},
            },
            "19": {
                "class_type": "KSampler",
                "inputs": {
                    "model": ["44", 0],
                    "positive": ["11", 0],
                    "negative": ["12", 0],
                    "latent_image": ["28", 0],
                    "seed": seed,
                    "steps": SD_STEPS,
                    "cfg": SD_CFG,
                    # This branch deliberately re-reads the environment: when
                    # the variables are unset it falls back to "er_sde" /
                    # "simple" rather than the module-level SD_SAMPLER /
                    # SD_SCHEDULER defaults ("euler" / "normal").
                    "sampler_name": os.getenv("SD_SAMPLER", "er_sde"),
                    "scheduler": os.getenv("SD_SCHEDULER", "simple"),
                    "denoise": 1.0,
                },
            },
            "8": {
                "class_type": "VAEDecode",
                "inputs": {"samples": ["19", 0], "vae": ["15", 0]},
            },
            "9": {
                "class_type": "SaveImage",
                "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]},
            },
        }

    # Standard checkpoint workflow (Pony / SDXL)
    return {
        "4": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {"ckpt_name": SD_CHECKPOINT},
        },
        "5": {
            "class_type": "EmptyLatentImage",
            "inputs": {"width": 832, "height": 1216, "batch_size": 1},
        },
        "6": {
            "class_type": "CLIPTextEncode",
            "inputs": {"text": positive, "clip": ["4", 1]},
        },
        "7": {
            "class_type": "CLIPTextEncode",
            "inputs": {"text": negative, "clip": ["4", 1]},
        },
        "8": {
            "class_type": "VAEDecode",
            "inputs": {"samples": ["10", 0], "vae": ["4", 2]},
        },
        "9": {
            "class_type": "SaveImage",
            "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]},
        },
        "10": {
            "class_type": "KSampler",
            "inputs": {
                "model": ["4", 0],
                "positive": ["6", 0],
                "negative": ["7", 0],
                "latent_image": ["5", 0],
                "seed": seed,
                "steps": SD_STEPS,
                "cfg": SD_CFG,
                "sampler_name": SD_SAMPLER,
                "scheduler": SD_SCHEDULER,
                "denoise": 1.0,
            },
        },
    }


async def check_sd() -> bool:
    """Return True if the server answers ``/system_stats`` with HTTP 200.

    Any exception raised while probing is treated as "not available" and
    reported as False; the cause is logged at debug level.
    """
    try:
        async with httpx.AsyncClient(timeout=5) as client:
            r = await client.get(f"{SD_BASE_URL}/system_stats")
            return r.status_code == 200
    except Exception as exc:  # best-effort health probe: never raise
        logger.debug("ComfyUI health check failed: %s", exc)
        return False


def _first_image_info(outputs: dict) -> dict | None:
    """Return the first image descriptor found in a history ``outputs`` map.

    Nodes without an ``"images"`` entry, or with an empty one, are skipped.
    Returns None when no node reports an image.
    """
    for node_output in outputs.values():
        images = node_output.get("images")
        if images:
            return images[0]
    return None


def _save_image(image_bytes: bytes) -> str:
    """Write *image_bytes* under ``IMAGES_DIR`` and return the new file name.

    The directory is created if missing; the file gets a random hex name
    with a ``.png`` suffix.
    """
    IMAGES_DIR.mkdir(parents=True, exist_ok=True)
    filename = f"{uuid.uuid4().hex}.png"
    (IMAGES_DIR / filename).write_bytes(image_bytes)
    return filename


async def txt2img(
    prompt: str,
    negative_prompt: str | None = None,
    *,
    poll_interval: float = 1.0,
    max_polls: int = 300,
) -> tuple[bytes, str]:
    """Queue a generation, wait for it to finish and save the first image.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text; ``None`` or an empty string
            falls back to ``SD_DEFAULT_NEGATIVE``.
        poll_interval: Seconds to sleep before each history poll.
        max_polls: Maximum number of history polls before giving up.

    Returns:
        ``(image_bytes, relative_path)`` where ``relative_path`` has the form
        ``"images/<name>.png"`` and the file is written under ``IMAGES_DIR``.

    Raises:
        RuntimeError: The workflow reported an error without an image, it
            finished without any image output, or no history entry appeared
            within ``max_polls`` polls.
        httpx.HTTPStatusError: Queuing the prompt or fetching the image
            returned an error status.
    """
    neg = negative_prompt or SD_DEFAULT_NEGATIVE
    workflow = _build_workflow(prompt, neg)
    client_id = uuid.uuid4().hex

    logger.info("ComfyUI request → %s prompt: %.120s", SD_BASE_URL, prompt)

    async with httpx.AsyncClient(timeout=300) as client:
        resp = await client.post(
            f"{SD_BASE_URL}/prompt",
            json={"prompt": workflow, "client_id": client_id},
        )
        resp.raise_for_status()
        prompt_id = resp.json()["prompt_id"]
        logger.info("ComfyUI queued prompt_id=%s", prompt_id)

        for _ in range(max_polls):
            await asyncio.sleep(poll_interval)
            hist = await client.get(f"{SD_BASE_URL}/history/{prompt_id}")
            data = hist.json()
            if prompt_id not in data:
                # No history entry yet: keep polling.
                continue

            entry = data[prompt_id]
            status = entry.get("status", {})
            failed = status.get("status_str") == "error"
            messages = status.get("messages", [])

            # Log any errors from ComfyUI
            if failed:
                logger.error("ComfyUI workflow error: %s", messages)

            outputs = entry.get("outputs", {})
            img_info = _first_image_info(outputs)
            if img_info is None:
                logger.error(
                    "ComfyUI no image output. status=%s outputs_keys=%s",
                    entry.get("status"),
                    list(outputs.keys()),
                )
                # Report the real cause instead of a generic timeout message.
                if failed:
                    raise RuntimeError(f"ComfyUI workflow failed: {messages}")
                raise RuntimeError(
                    "ComfyUI finished without producing an image"
                )

            img_resp = await client.get(
                f"{SD_BASE_URL}/view",
                params={
                    "filename": img_info["filename"],
                    "subfolder": img_info.get("subfolder", ""),
                    "type": img_info.get("type", "output"),
                },
            )
            img_resp.raise_for_status()
            image_bytes = img_resp.content

            # Disk I/O is blocking; run it off the event loop.
            filename = await asyncio.to_thread(_save_image, image_bytes)
            logger.info("ComfyUI done → saved %s", filename)
            return image_bytes, f"images/{filename}"

    raise RuntimeError(
        f"ComfyUI generation timed out after {max_polls} polls"
    )


async def generate_from_full_prompt(
    full_prompt: str,
) -> tuple[str | None, str | None]:
    """Generate an image from a combined positive/negative prompt.

    Args:
        full_prompt: Prompt text accepted by ``split_prompt_and_negative``.

    Returns:
        ``(relative_path, None)`` on success, or ``(None, error_message)``
        when generation fails for any reason. This function does not raise
        for generation failures.
    """
    positive, negative = split_prompt_and_negative(full_prompt)
    try:
        _, rel_path = await txt2img(positive, negative)
    except Exception as e:  # boundary: convert any failure into an error string
        logger.exception("ComfyUI error: %s", e)
        return None, str(e)
    return rel_path, None