import asyncio
import logging
import os
import uuid
from pathlib import Path
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

import httpx
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)

# HTTP statuses that _comfy_request follows by hand (the API client is created
# with follow_redirects=False so query params can be re-attached on every hop).
_REDIRECT_STATUSES = (301, 302, 303, 307, 308)
# Upper bound on the number of HTTP requests _comfy_request issues per call.
_MAX_REDIRECT_HOPS = 6


def _parse_basic_auth() -> httpx.BasicAuth | None:
    """
    Vast Caddy on mapped ports often uses Basic realm=restricted.
    Set SD_COMFY_HTTP_BASIC=user:password or SD_COMFY_USER + SD_COMFY_PASSWORD.

    Returns:
        An ``httpx.BasicAuth`` built from the environment, or ``None`` when
        none of the variables is set. ``SD_COMFY_HTTP_BASIC`` takes precedence;
        a value without ``:`` is treated as a password with an empty user name.
    """
    raw = (os.getenv("SD_COMFY_HTTP_BASIC") or "").strip()
    if raw:
        user, colon, password = raw.partition(":")
        if not colon:
            # No separator: the whole value is the password.
            user, password = "", raw
        return httpx.BasicAuth(user, password)

    user = (os.getenv("SD_COMFY_USER") or "").strip()
    password = (os.getenv("SD_COMFY_PASSWORD") or "").strip()
    if user or password:
        return httpx.BasicAuth(user, password)
    return None


SD_BASIC_AUTH = _parse_basic_auth()


def _parse_comfy_config() -> tuple[str, dict[str, str]]:
    """
    SD_BASE_URL may be pasted from Vast/Comfy UI with ?token=...
    API paths must be base + /prompt, not ...?token=xxx/prompt

    Returns:
        ``(base, query)`` where ``base`` is ``scheme://netloc[/path]`` without a
        trailing slash and ``query`` holds the query parameters found in
        SD_BASE_URL (last value wins for repeated keys). ``SD_COMFY_TOKEN``,
        when set, overrides ``query["token"]``. For trycloudflare tunnels the
        token is removed from ``query``.
    """
    # A blank / whitespace-only value falls back to the local default.
    raw = (os.getenv("SD_BASE_URL") or "").strip() or "http://127.0.0.1:8188"
    extra_token = (os.getenv("SD_COMFY_TOKEN") or "").strip()

    # A scheme-less value such as "127.0.0.1:8188" makes urlparse() return an
    # empty netloc, which used to yield a base like "://127.0.0.1:8188".
    # Default to plain HTTP in that case.
    if "://" not in raw.split("?", 1)[0]:
        raw = f"http://{raw.lstrip('/')}"

    parsed = urlparse(raw)
    base = f"{parsed.scheme}://{parsed.netloc}"
    path = (parsed.path or "").rstrip("/")
    if path:
        base = f"{base}{path}"

    query: dict[str, str] = {
        key: values[-1] for key, values in parse_qs(parsed.query).items() if values
    }
    if extra_token:
        query["token"] = extra_token

    base = base.rstrip("/")
    # Cloudflare tunnel to localhost:8188 — direct Comfy API, Vast ?token= does not apply
    if "trycloudflare.com" in base.lower():
        if query.pop("token", None):
            logger.info(
                "SD_BASE_URL is trycloudflare tunnel: Vast token stripped. "
                "Use tunnel for port 8188 only (see instance Port Mapping)."
            )
    return base, query


SD_BASE_URL, SD_QUERY_PARAMS = _parse_comfy_config()


def _comfy_url(path: str) -> str:
    """Return SD_BASE_URL joined with *path* (a leading ``/`` is added if missing)."""
    if not path.startswith("/"):
        path = f"/{path}"
    return f"{SD_BASE_URL}{path}"


def _log_comfy_target() -> str:
    """Return SD_BASE_URL for logging, with the token value masked when one is configured."""
    if SD_QUERY_PARAMS.get("token"):
        return f"{SD_BASE_URL}?token=***"
    return SD_BASE_URL


def _absolute_url(location: str, fallback_path: str = "/") -> str:
    """
    Resolve a redirect ``Location`` header against SD_BASE_URL.

    An empty *location* resolves to ``_comfy_url(fallback_path)``; absolute
    http(s) URLs are returned unchanged; anything else is appended to
    SD_BASE_URL.
    """
    if not location:
        return _comfy_url(fallback_path)
    if location.startswith(("http://", "https://")):
        return location
    if location.startswith("/"):
        return f"{SD_BASE_URL}{location}"
    return f"{SD_BASE_URL}/{location}"


def _url_with_token(url: str) -> str:
    """Append gateway token to URL (Vast/Cloudflare often strip ?token on redirect)."""
    if not SD_QUERY_PARAMS.get("token"):
        return url
    p = urlparse(url)
    q: dict[str, str] = {
        key: values[-1] for key, values in parse_qs(p.query).items() if values
    }
    # Every configured query parameter (not only the token) overrides the URL's own.
    q.update(SD_QUERY_PARAMS)
    return urlunparse((p.scheme, p.netloc, p.path, "", urlencode(q), ""))


def _merge_params(extra: dict | None) -> dict | None:
    """Return SD_QUERY_PARAMS overlaid with *extra*, or ``None`` when both are empty."""
    if not SD_QUERY_PARAMS and not extra:
        return None
    merged = dict(SD_QUERY_PARAMS)
    if extra:
        merged.update(extra)
    return merged


def _is_vast_gateway() -> bool:
    """Return True unless SD_BASE_URL points at a trycloudflare tunnel."""
    return "trycloudflare.com" not in SD_BASE_URL.lower()


def _make_comfy_client(*, timeout: float = 300) -> httpx.AsyncClient:
    """Create the API client: configured basic auth, no automatic redirect following."""
    return httpx.AsyncClient(
        timeout=timeout,
        follow_redirects=False,
        auth=SD_BASIC_AUTH,
    )


async def _prime_comfy_gateway(client: httpx.AsyncClient) -> None:
    """
    Vast Caddy: browser opens /?token=… and gets a session cookie; API then works.
    Prime with redirects so Set-Cookie is collected, then merge into the API client.

    Does nothing when no token is configured or the target is a trycloudflare
    tunnel. Best effort: any failure is logged as a warning and swallowed.
    """
    token = SD_QUERY_PARAMS.get("token")
    if not token or not _is_vast_gateway():
        return
    try:
        # A separate short-lived client is used because the API client does
        # not follow redirects.
        async with httpx.AsyncClient(
            timeout=30,
            follow_redirects=True,
            auth=SD_BASIC_AUTH,
        ) as prime:
            r = await prime.get(_comfy_url("/"), params={"token": token})
            client.cookies.update(prime.cookies)
            logger.info(
                "Comfy gateway prime GET /?token=*** → %s, cookies=%s",
                r.status_code,
                list(prime.cookies.keys()) or "(none)",
            )
    except Exception as e:
        logger.warning("Comfy gateway prime failed: %s", e)


async def _comfy_request(
    client: httpx.AsyncClient,
    method: str,
    path: str,
    *,
    params: dict | None = None,
    **kwargs,
) -> httpx.Response:
    """
    Comfy API: trycloudflare tunnel = no token.
    Vast IP:PORT gateway = ?token= + cookie prime; follow redirects with token re-attached.

    Args:
        client: API client (expected to have ``follow_redirects=False``).
        method: HTTP method.
        path: API path relative to SD_BASE_URL.
        params: Extra query parameters for this call.
        **kwargs: Passed through to ``client.request`` on every hop.

    Returns:
        The first non-redirect response, or the last redirect response once
        ``_MAX_REDIRECT_HOPS`` requests have been made.
    """
    url = _comfy_url(path)
    extra = params or {}
    token = SD_QUERY_PARAMS.get("token")
    use_vast_auth = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None)

    # Prime only while the client has no cookies yet. Previously every call
    # (including each one-second history poll) opened a new client and hit the
    # gateway again even though the session cookie was already present.
    if token and _is_vast_gateway() and not client.cookies:
        await _prime_comfy_gateway(client)

    req_params: dict | None = _merge_params(extra) if use_vast_auth else (extra or None)
    resp = await client.request(method, url, params=req_params, **kwargs)

    for hop in range(1, _MAX_REDIRECT_HOPS):
        if resp.status_code not in _REDIRECT_STATUSES:
            break
        loc = _absolute_url(resp.headers.get("location", ""), path)
        # NOTE(review): the token is re-attached to whatever host the Location
        # header names — confirm the gateway never redirects off-host.
        url = _url_with_token(loc) if use_vast_auth else loc
        # Log without the query string so the token is not written to logs.
        logger.info("Comfy redirect %s hop %s → %s", resp.status_code, hop, url.split("?")[0])
        resp = await client.request(method, url, params=extra or None, **kwargs)
    return resp


# An empty environment value (e.g. "SD_STEPS=" in .env) falls back to the
# default instead of crashing int()/float() at import time.
SD_STEPS = int(os.getenv("SD_STEPS") or "28")
SD_CFG = float(os.getenv("SD_CFG") or "7")
SD_SAMPLER = os.getenv("SD_SAMPLER", "euler")
SD_SCHEDULER = os.getenv("SD_SCHEDULER", "normal")
SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "")
SD_DEFAULT_NEGATIVE = os.getenv(
    "SD_DEFAULT_NEGATIVE",
    "low quality, worst quality, blurry, bad anatomy, watermark, text",
)

# Anima split-model settings
SD_UNET = os.getenv("SD_UNET", "anima-preview3-base.safetensors")
SD_CLIP = os.getenv("SD_CLIP", "qwen_3_06b_base.safetensors") SD_VAE = os.getenv("SD_VAE", "qwen_image_vae.safetensors") SD_STYLE_LORA = os.getenv("SD_STYLE_LORA", "") SD_STYLE_LORA_WEIGHT = float(os.getenv("SD_STYLE_LORA_WEIGHT", "0.7")) IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images")) ANIMA_CHECKPOINTS = {"anima-preview3-base.safetensors"} PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"} def _use_anima() -> bool: return bool(SD_UNET) and not SD_CHECKPOINT def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]: # Try new separator first sep = "__NEGATIVE_PROMPT__" if f"\n{sep}\n" in full_prompt: pos, _, neg = full_prompt.partition(f"\n{sep}\n") return pos.strip(), neg.strip() # Fallback to old format if "\n\nNegative prompt:" in full_prompt: pos, _, neg = full_prompt.partition("\n\nNegative prompt:") return pos.strip(), neg.strip() return full_prompt.strip(), SD_DEFAULT_NEGATIVE def _workflow_uses_anima(overrides: dict | None) -> bool: if overrides and overrides.get("checkpoint"): return False if overrides and overrides.get("unet"): return True return _use_anima() def _build_workflow(positive: str, negative: str, overrides: dict | None = None) -> dict: seed = int(uuid.uuid4().int % 2**32) o = overrides or {} if _workflow_uses_anima(o): unet = o.get("unet") or SD_UNET clip = o.get("clip") or SD_CLIP vae = o.get("vae") or SD_VAE workflow = { "44": {"class_type": "UNETLoader", "inputs": {"unet_name": unet, "weight_dtype": "default"}}, "45": {"class_type": "CLIPLoader", "inputs": {"clip_name": clip, "type": "stable_diffusion", "device": "default"}}, "15": {"class_type": "VAELoader", "inputs": {"vae_name": vae}}, "28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 1024, "batch_size": 1}}, "11": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["45", 0]}}, "12": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["45", 0]}}, "19": { "class_type": "KSampler", 
"inputs": { "model": ["44", 0], "positive": ["11", 0], "negative": ["12", 0], "latent_image": ["28", 0], "seed": seed, "steps": SD_STEPS, "cfg": SD_CFG, "sampler_name": os.getenv("SD_SAMPLER", "er_sde"), "scheduler": os.getenv("SD_SCHEDULER", "simple"), "denoise": 1.0, }, }, "8": {"class_type": "VAEDecode", "inputs": {"samples": ["19", 0], "vae": ["15", 0]}}, "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}}, } if SD_STYLE_LORA: workflow["46"] = { "class_type": "LoraLoader", "inputs": { "lora_name": SD_STYLE_LORA, "model": ["44", 0], "clip": ["45", 0], "strength_model": SD_STYLE_LORA_WEIGHT, "strength_clip": SD_STYLE_LORA_WEIGHT, }, } workflow["19"]["inputs"]["model"] = ["46", 0] workflow["11"]["inputs"]["clip"] = ["46", 1] workflow["12"]["inputs"]["clip"] = ["46", 1] return workflow ckpt = o.get("checkpoint") or SD_CHECKPOINT return { "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt}}, "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}}, "6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}}, "7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}}, "8": {"class_type": "VAEDecode", "inputs": {"samples": ["10", 0], "vae": ["4", 2]}}, "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}}, "10": { "class_type": "KSampler", "inputs": { "model": ["4", 0], "positive": ["6", 0], "negative": ["7", 0], "latent_image": ["5", 0], "seed": seed, "steps": SD_STEPS, "cfg": SD_CFG, "sampler_name": SD_SAMPLER, "scheduler": SD_SCHEDULER, "denoise": 1.0, }, }, } async def comfy_api_request( method: str, path: str, *, params: dict | None = None, json_body: dict | None = None, timeout: float = 60, ) -> tuple[int, dict | str, dict]: """ Raw Comfy API call for debug. Returns (status_code, parsed_json_or_text, response_headers_subset). 
""" async with _make_comfy_client(timeout=timeout) as client: await _prime_comfy_gateway(client) token = SD_QUERY_PARAMS.get("token") use_vast = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None) req_params = _merge_params(params) if use_vast else (params or None) req_kwargs: dict = {} if json_body is not None and method.upper() not in ("GET", "HEAD"): req_kwargs["json"] = json_body resp = await _comfy_request( client, method.upper(), path, params=req_params, **req_kwargs, ) headers = { k: resp.headers.get(k) for k in ("content-type", "location", "www-authenticate") if resp.headers.get(k) } try: body = resp.json() except Exception: body = resp.text[:8000] return resp.status_code, body, headers async def fetch_object_info() -> dict: status, body, _ = await comfy_api_request("GET", "/object_info", timeout=120) if status != 200 or not isinstance(body, dict): raise RuntimeError(f"object_info failed: HTTP {status} {body!s:.300}") return body async def check_sd() -> bool: try: async with _make_comfy_client(timeout=15) as client: await _prime_comfy_gateway(client) r = await _comfy_request(client, "GET", "/system_stats") return r.status_code == 200 except Exception: return False async def txt2img( prompt: str, negative_prompt: str | None = None, *, overrides: dict | None = None, ) -> tuple[bytes, str]: neg = negative_prompt or SD_DEFAULT_NEGATIVE workflow = _build_workflow(prompt, neg, overrides) client_id = uuid.uuid4().hex logger.info("ComfyUI request → %s prompt: %.120s", _log_comfy_target(), prompt) async with _make_comfy_client() as client: await _prime_comfy_gateway(client) resp = await _comfy_request( client, "POST", "/prompt", json={"prompt": workflow, "client_id": client_id}, ) resp.raise_for_status() prompt_id = resp.json()["prompt_id"] logger.info("ComfyUI queued prompt_id=%s", prompt_id) for _ in range(300): await asyncio.sleep(1) hist = await _comfy_request(client, "GET", f"/history/{prompt_id}") data = hist.json() if prompt_id in data: entry = 
data[prompt_id] # Log any errors from ComfyUI if entry.get("status", {}).get("status_str") == "error": msgs = entry.get("status", {}).get("messages", []) logger.error("ComfyUI workflow error: %s", msgs) outputs = entry.get("outputs", {}) for node_output in outputs.values(): if "images" in node_output: img_info = node_output["images"][0] img_resp = await _comfy_request( client, "GET", "/view", params={ "filename": img_info["filename"], "subfolder": img_info.get("subfolder", ""), "type": img_info.get("type", "output"), }, ) img_resp.raise_for_status() image_bytes = img_resp.content IMAGES_DIR.mkdir(parents=True, exist_ok=True) filename = f"{uuid.uuid4().hex}.png" (IMAGES_DIR / filename).write_bytes(image_bytes) logger.info("ComfyUI done → saved %s", filename) return image_bytes, f"images/{filename}" logger.error("ComfyUI no image output. status=%s outputs_keys=%s", entry.get("status"), list(outputs.keys())) break raise RuntimeError("ComfyUI generation timed out or produced no output") async def generate_from_full_prompt( full_prompt: str, *, overrides: dict | None = None, ) -> tuple[str | None, str | None]: positive, negative = split_prompt_and_negative(full_prompt) try: _, rel_path = await txt2img(positive, negative, overrides=overrides) return rel_path, None except httpx.HTTPStatusError as e: code = e.response.status_code if code == 401: logger.error( "ComfyUI 401: Vast Caddy needs SD_COMFY_TOKEN (or ?token= in SD_BASE_URL) " "and/or SD_COMFY_HTTP_BASIC=user:pass from the instance page. " "Test: curl -u user:pass http://IP:PORT/system_stats " "or open /?token=… in browser then curl with cookies. " "Alternative: trycloudflare URL for localhost:8188 in Port Mapping." ) elif code in (301, 302, 303, 307, 308): logger.error( "ComfyUI %s: wrong URL — use trycloudflare tunnel for 8188, not web UI link. " "SD_BASE_URL=https://reviewer-relief-edmonton-specializing.trycloudflare.com " "(no ?token=). 
Location: %s", code, e.response.headers.get("location"), ) else: logger.error("ComfyUI HTTP %s: %s", code, e) return None, str(e) except httpx.ConnectError as e: logger.error( "ComfyUI connect failed (%s): IP:8188 is often not exposed on Vast. " "Use trycloudflare URL from Port Mapping for localhost:8188.", e, ) return None, str(e) except Exception as e: logger.error("ComfyUI error: %s", e) return None, str(e)