455 lines
17 KiB
Python
455 lines
17 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from pathlib import Path
|
|
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
|
|
|
import httpx
|
|
from dotenv import load_dotenv
|
|
|
|
# Populate os.environ from a local .env file before any SD_* setting below is read.
load_dotenv()

# Module-wide logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def _parse_basic_auth() -> httpx.BasicAuth | None:
    """
    Build HTTP Basic credentials for the Comfy endpoint, if any are configured.

    Vast Caddy on mapped ports often uses Basic realm=restricted.
    Set SD_COMFY_HTTP_BASIC=user:password or SD_COMFY_USER + SD_COMFY_PASSWORD.

    SD_COMFY_HTTP_BASIC takes precedence when non-empty. Without a ':' its whole
    value is used as the password and the user name is empty. Returns None when
    none of the variables is set.
    """
    combined = (os.getenv("SD_COMFY_HTTP_BASIC") or "").strip()
    if combined:
        name, sep, secret = combined.partition(":")
        return httpx.BasicAuth(name, secret) if sep else httpx.BasicAuth("", combined)

    name = (os.getenv("SD_COMFY_USER") or "").strip()
    secret = (os.getenv("SD_COMFY_PASSWORD") or "").strip()
    if not (name or secret):
        return None
    return httpx.BasicAuth(name, secret)


# Resolved once at import time and shared by every Comfy client.
SD_BASIC_AUTH = _parse_basic_auth()
|
|
|
|
|
|
def _parse_comfy_config() -> tuple[str, dict[str, str]]:
    """
    Split SD_BASE_URL into a clean API base URL and gateway query params.

    SD_BASE_URL may be pasted from Vast/Comfy UI with ?token=...
    API paths must be base + /prompt, not ...?token=xxx/prompt

    - An empty or whitespace-only SD_BASE_URL falls back to http://127.0.0.1:8188.
    - A value without a scheme (e.g. "127.0.0.1:8188") is treated as http://;
      urlparse would otherwise read the host as the scheme or as the path.
    - SD_COMFY_TOKEN, when set, overrides any token found in SD_BASE_URL.
    - For trycloudflare.com tunnels the token is removed.

    Returns:
        (base URL without trailing slash, {query key: last value}).
    """
    raw = (os.getenv("SD_BASE_URL") or "").strip() or "http://127.0.0.1:8188"
    extra_token = (os.getenv("SD_COMFY_TOKEN") or "").strip()
    if "://" not in raw:
        raw = f"http://{raw}"

    parsed = urlparse(raw)
    base = f"{parsed.scheme}://{parsed.netloc}{parsed.path.rstrip('/')}".rstrip("/")

    # Keep only the last value of each repeated key.
    query: dict[str, str] = {
        key: values[-1] for key, values in parse_qs(parsed.query).items() if values
    }
    if extra_token:
        query["token"] = extra_token

    # Cloudflare tunnel to localhost:8188 — direct Comfy API, Vast ?token= does not apply
    if "trycloudflare.com" in base.lower() and query.pop("token", None):
        logger.info(
            "SD_BASE_URL is trycloudflare tunnel: Vast token stripped. "
            "Use tunnel for port 8188 only (see instance Port Mapping)."
        )
    return base, query


SD_BASE_URL, SD_QUERY_PARAMS = _parse_comfy_config()
|
|
|
|
|
|
def _comfy_url(path: str) -> str:
    """Join SD_BASE_URL with an API path, inserting the leading '/' when it is missing."""
    return SD_BASE_URL + (path if path.startswith("/") else "/" + path)
|
|
|
|
|
|
def _log_comfy_target() -> str:
    """Describe the Comfy target for log lines, masking the gateway token value."""
    has_token = bool(SD_QUERY_PARAMS.get("token"))
    return f"{SD_BASE_URL}?token=***" if has_token else SD_BASE_URL
|
|
|
|
|
|
def _absolute_url(location: str, fallback_path: str = "/") -> str:
|
|
if not location:
|
|
return _comfy_url(fallback_path)
|
|
if location.startswith(("http://", "https://")):
|
|
return location
|
|
if location.startswith("/"):
|
|
return f"{SD_BASE_URL}{location}"
|
|
return f"{SD_BASE_URL}/{location}"
|
|
|
|
|
|
def _url_with_token(url: str) -> str:
|
|
"""Append gateway token to URL (Vast/Cloudflare often strip ?token on redirect)."""
|
|
if not SD_QUERY_PARAMS.get("token"):
|
|
return url
|
|
p = urlparse(url)
|
|
q: dict[str, str] = {}
|
|
for key, values in parse_qs(p.query).items():
|
|
if values:
|
|
q[key] = values[-1]
|
|
q.update(SD_QUERY_PARAMS)
|
|
return urlunparse((p.scheme, p.netloc, p.path, "", urlencode(q), ""))
|
|
|
|
|
|
def _merge_params(extra: dict | None) -> dict | None:
    """
    Combine the gateway query params (SD_QUERY_PARAMS) with per-call params.

    Per-call values win on key clashes. Returns a new dict, or None when both
    sources are empty.
    """
    if not (SD_QUERY_PARAMS or extra):
        return None
    return {**SD_QUERY_PARAMS, **(extra or {})}
|
|
|
|
|
|
def _is_vast_gateway() -> bool:
    """
    True unless SD_BASE_URL points at a trycloudflare.com tunnel.

    The token / cookie-prime handling elsewhere in this module is only applied
    when this returns True.
    """
    return SD_BASE_URL.lower().find("trycloudflare.com") == -1
|
|
|
|
|
|
def _make_comfy_client(*, timeout: float = 300) -> httpx.AsyncClient:
    """
    Create the AsyncClient used for Comfy API calls.

    Automatic redirects are disabled: _comfy_request walks redirects itself so
    the gateway token can be re-attached on each hop. Basic auth comes from
    SD_BASIC_AUTH (may be None).
    """
    client_options = {
        "timeout": timeout,
        "follow_redirects": False,
        "auth": SD_BASIC_AUTH,
    }
    return httpx.AsyncClient(**client_options)
|
|
|
|
|
|
async def _prime_comfy_gateway(client: httpx.AsyncClient) -> None:
    """
    Copy a Vast gateway session cookie into ``client``.

    Vast Caddy: browser opens /?token=… and gets a session cookie; API then works.
    Prime with redirects so Set-Cookie is collected, then merge into the API client.

    Does nothing without a token or for a trycloudflare tunnel. Uses a separate
    short-lived client (30 s timeout, redirects followed). Any failure is logged
    as a warning and otherwise ignored (best effort).
    """
    token = SD_QUERY_PARAMS.get("token")
    if not (token and _is_vast_gateway()):
        return
    try:
        async with httpx.AsyncClient(
            timeout=30,
            follow_redirects=True,
            auth=SD_BASIC_AUTH,
        ) as primer:
            landing = await primer.get(_comfy_url("/"), params={"token": token})
            client.cookies.update(primer.cookies)
            cookie_names = list(primer.cookies.keys())
            logger.info(
                "Comfy gateway prime GET /?token=*** → %s, cookies=%s",
                landing.status_code,
                cookie_names or "(none)",
            )
    except Exception as exc:
        logger.warning("Comfy gateway prime failed: %s", exc)
|
|
|
|
|
|
async def _comfy_request(
    client: httpx.AsyncClient,
    method: str,
    path: str,
    *,
    params: dict | None = None,
    **kwargs,
) -> httpx.Response:
    """
    Send one Comfy API request, walking redirects by hand.

    Comfy API: trycloudflare tunnel = no token.
    Vast IP:PORT gateway = ?token= + cookie prime; follow redirects with token re-attached.

    Args:
        client: a client with automatic redirects disabled (see _make_comfy_client).
        method: HTTP method.
        path: API path relative to SD_BASE_URL, e.g. "/prompt".
        params: per-call query params. On a Vast gateway they are merged over
            SD_QUERY_PARAMS for the first hop.
        **kwargs: forwarded to ``client.request`` (e.g. ``json=``).

    Returns:
        The first non-redirect response, or the last redirect response once six
        hops have been followed.
    """
    url = _comfy_url(path)
    extra = params or {}
    token = SD_QUERY_PARAMS.get("token")
    on_vast = _is_vast_gateway()
    use_vast_auth = on_vast and (bool(token) or SD_BASIC_AUTH is not None)

    # The callers in this module prime right after creating the client. Priming
    # again on every call (e.g. each /history poll) would only add a round trip
    # and a throwaway client, so prime here only while the client holds no cookies.
    if token and on_vast and not client.cookies:
        await _prime_comfy_gateway(client)

    req_params: dict | None = _merge_params(extra) if use_vast_auth else (extra or None)
    resp: httpx.Response | None = None

    for hop in range(6):
        resp = await client.request(method, url, params=req_params, **kwargs)
        if resp.status_code not in (301, 302, 303, 307, 308):
            return resp
        if resp.status_code == 303 and method.upper() != "HEAD":
            # 303 See Other: the target is fetched with GET and without the request body.
            method = "GET"
            for body_key in ("content", "data", "files", "json"):
                kwargs.pop(body_key, None)
        loc = _absolute_url(resp.headers.get("location", ""), path)
        # Gateways tend to drop ?token= from Location, so put it back.
        url = _url_with_token(loc) if use_vast_auth else loc
        req_params = extra or None
        logger.info("Comfy redirect %s hop %s → %s", resp.status_code, hop + 1, url.split("?")[0])

    assert resp is not None  # range(6) always runs at least once
    return resp
|
|
|
|
|
|
# Sampler settings. SD_STEPS / SD_CFG feed both workflows; SD_SAMPLER / SD_SCHEDULER
# are used by the checkpoint workflow only (the Anima graph re-reads those env vars
# with its own defaults).
SD_STEPS = int(os.environ.get("SD_STEPS", "28"))
SD_CFG = float(os.environ.get("SD_CFG", "7"))
SD_SAMPLER = os.environ.get("SD_SAMPLER", "euler")
SD_SCHEDULER = os.environ.get("SD_SCHEDULER", "normal")
# A non-empty SD_CHECKPOINT selects the single-file checkpoint workflow by default.
SD_CHECKPOINT = os.environ.get("SD_CHECKPOINT", "")
SD_DEFAULT_NEGATIVE = os.environ.get(
    "SD_DEFAULT_NEGATIVE",
    "low quality, worst quality, blurry, bad anatomy, watermark, text",
)

# Anima split-model settings: separate UNET / CLIP / VAE files plus an optional style LoRA.
SD_UNET = os.environ.get("SD_UNET", "anima-preview3-base.safetensors")
SD_CLIP = os.environ.get("SD_CLIP", "qwen_3_06b_base.safetensors")
SD_VAE = os.environ.get("SD_VAE", "qwen_image_vae.safetensors")
SD_STYLE_LORA = os.environ.get("SD_STYLE_LORA", "")
SD_STYLE_LORA_WEIGHT = float(os.environ.get("SD_STYLE_LORA_WEIGHT", "0.7"))

# Directory generated images are written to.
IMAGES_DIR = Path(os.environ.get("IMAGES_DIR", "static/images"))

# Known model file names per family.
# NOTE(review): neither set is referenced in this file — presumably used by callers; verify.
ANIMA_CHECKPOINTS = {"anima-preview3-base.safetensors"}
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}
|
|
|
|
|
|
def _use_anima() -> bool:
    """Default model family: Anima split loaders when SD_UNET is set and SD_CHECKPOINT is not."""
    if SD_CHECKPOINT:
        return False
    return bool(SD_UNET)
|
|
|
|
|
|
def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]:
    """
    Split a combined prompt into (positive, negative), both whitespace-stripped.

    Markers are tried in order: the current "\\n__NEGATIVE_PROMPT__\\n" separator,
    then the legacy "\\n\\nNegative prompt:" format. The text is split at the first
    occurrence of the first marker found. Without a marker, the whole text is the
    positive prompt and SD_DEFAULT_NEGATIVE is returned as the negative.
    """
    for marker in ("\n__NEGATIVE_PROMPT__\n", "\n\nNegative prompt:"):
        if marker in full_prompt:
            positive, _, negative = full_prompt.partition(marker)
            return positive.strip(), negative.strip()
    return full_prompt.strip(), SD_DEFAULT_NEGATIVE
|
|
|
|
|
|
def _workflow_uses_anima(overrides: dict | None) -> bool:
|
|
if overrides and overrides.get("checkpoint"):
|
|
return False
|
|
if overrides and overrides.get("unet"):
|
|
return True
|
|
return _use_anima()
|
|
|
|
|
|
def _anima_workflow(positive: str, negative: str, o: dict, seed: int) -> dict:
    """
    Node graph for the Anima split-model setup (separate UNET, CLIP and VAE loaders).

    Loader file names come from ``o`` ("unet" / "clip" / "vae") or the
    SD_UNET / SD_CLIP / SD_VAE defaults. When SD_STYLE_LORA is set, a LoraLoader
    node "46" sits between the loaders and the sampler / text encoders. Sampler
    and scheduler are read from the environment at call time, with the defaults
    "er_sde" / "simple".
    """
    if SD_STYLE_LORA:
        model_node, clip_node, clip_slot = "46", "46", 1
    else:
        model_node, clip_node, clip_slot = "44", "45", 0

    graph = {
        "44": {"class_type": "UNETLoader", "inputs": {"unet_name": o.get("unet") or SD_UNET, "weight_dtype": "default"}},
        "45": {"class_type": "CLIPLoader", "inputs": {"clip_name": o.get("clip") or SD_CLIP, "type": "stable_diffusion", "device": "default"}},
        "15": {"class_type": "VAELoader", "inputs": {"vae_name": o.get("vae") or SD_VAE}},
        "28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 720, "batch_size": 1}},
        "11": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": [clip_node, clip_slot]}},
        "12": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": [clip_node, clip_slot]}},
        "19": {
            "class_type": "KSampler",
            "inputs": {
                "model": [model_node, 0],
                "positive": ["11", 0],
                "negative": ["12", 0],
                "latent_image": ["28", 0],
                "seed": seed,
                "steps": SD_STEPS,
                "cfg": SD_CFG,
                "sampler_name": os.getenv("SD_SAMPLER", "er_sde"),
                "scheduler": os.getenv("SD_SCHEDULER", "simple"),
                "denoise": 1.0,
            },
        },
        "8": {"class_type": "VAEDecode", "inputs": {"samples": ["19", 0], "vae": ["15", 0]}},
        "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
    }
    if SD_STYLE_LORA:
        graph["46"] = {
            "class_type": "LoraLoader",
            "inputs": {
                "lora_name": SD_STYLE_LORA,
                "model": ["44", 0],
                "clip": ["45", 0],
                "strength_model": SD_STYLE_LORA_WEIGHT,
                "strength_clip": SD_STYLE_LORA_WEIGHT,
            },
        }
    return graph


def _checkpoint_workflow(positive: str, negative: str, o: dict, seed: int) -> dict:
    """Node graph for a single-file checkpoint (``o["checkpoint"]`` or SD_CHECKPOINT)."""
    ckpt_name = o.get("checkpoint") or SD_CHECKPOINT
    return {
        "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt_name}},
        "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}},
        "6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}},
        "7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}},
        "8": {"class_type": "VAEDecode", "inputs": {"samples": ["10", 0], "vae": ["4", 2]}},
        "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
        "10": {
            "class_type": "KSampler",
            "inputs": {
                "model": ["4", 0],
                "positive": ["6", 0],
                "negative": ["7", 0],
                "latent_image": ["5", 0],
                "seed": seed,
                "steps": SD_STEPS,
                "cfg": SD_CFG,
                "sampler_name": SD_SAMPLER,
                "scheduler": SD_SCHEDULER,
                "denoise": 1.0,
            },
        },
    }


def _build_workflow(positive: str, negative: str, overrides: dict | None = None) -> dict:
    """
    Build the ComfyUI API-format prompt graph for one generation.

    A fresh 32-bit seed is drawn from uuid4 on every call. The model family
    follows _workflow_uses_anima(overrides).
    """
    seed = int(uuid.uuid4().int % 2**32)
    opts = overrides or {}
    if _workflow_uses_anima(opts):
        return _anima_workflow(positive, negative, opts, seed)
    return _checkpoint_workflow(positive, negative, opts, seed)
|
|
|
|
|
|
async def comfy_api_request(
    method: str,
    path: str,
    *,
    params: dict | None = None,
    json_body: dict | None = None,
    timeout: float = 60,
) -> tuple[int, dict | str, dict]:
    """
    Raw Comfy API call for debug. Returns (status_code, parsed_json_or_text, response_headers_subset).

    Args:
        method: HTTP method (case-insensitive).
        path: API path relative to SD_BASE_URL.
        params: extra query params; gateway params are merged in for Vast targets.
        json_body: JSON request body. It is ignored (with a warning) for GET/HEAD.
        timeout: client timeout in seconds.

    Returns:
        ``(status, body, headers)``. ``body`` is the decoded JSON, or the first
        8000 characters of the text when the response is not JSON. ``headers``
        holds only the present ones among content-type / location / www-authenticate.
    """
    verb = method.upper()
    req_kwargs: dict = {}
    if json_body is not None:
        if verb in ("GET", "HEAD"):
            logger.warning("comfy_api_request: json_body ignored for %s %s", verb, path)
        else:
            req_kwargs["json"] = json_body

    async with _make_comfy_client(timeout=timeout) as client:
        await _prime_comfy_gateway(client)
        token = SD_QUERY_PARAMS.get("token")
        use_vast = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None)
        req_params = _merge_params(params) if use_vast else (params or None)
        resp = await _comfy_request(
            client,
            verb,
            path,
            params=req_params,
            **req_kwargs,
        )
        headers = {
            k: resp.headers.get(k)
            for k in ("content-type", "location", "www-authenticate")
            if resp.headers.get(k)
        }
        try:
            body = resp.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError: not JSON, return raw text
            body = resp.text[:8000]
        return resp.status_code, body, headers
|
|
|
|
|
|
async def fetch_object_info() -> dict:
    """
    Fetch Comfy's /object_info response (120 s timeout).

    Raises RuntimeError unless the call returns HTTP 200 with a JSON object.
    """
    status, body, _headers = await comfy_api_request("GET", "/object_info", timeout=120)
    if status == 200 and isinstance(body, dict):
        return body
    raise RuntimeError(f"object_info failed: HTTP {status} {body!s:.300}")
|
|
|
|
|
|
async def check_sd() -> bool:
    """
    Report whether ComfyUI answers GET /system_stats with HTTP 200 (15 s timeout).

    This is a best-effort probe: any exception counts as "not available" and
    yields False instead of propagating. The reason is logged so an unreachable
    backend does not fail silently.
    """
    try:
        async with _make_comfy_client(timeout=15) as client:
            await _prime_comfy_gateway(client)
            r = await _comfy_request(client, "GET", "/system_stats")
    except Exception as e:
        logger.info("ComfyUI health check failed for %s: %s", _log_comfy_target(), e)
        return False
    if r.status_code != 200:
        logger.info("ComfyUI health check for %s got HTTP %s", _log_comfy_target(), r.status_code)
    return r.status_code == 200
|
|
|
|
|
|
async def _download_comfy_image(client: httpx.AsyncClient, img_info: dict) -> tuple[bytes, str]:
    """
    Fetch one Comfy output image via /view and store it under IMAGES_DIR.

    ``img_info`` is one entry of a history ``outputs[...]["images"]`` list.
    "filename" is required; "subfolder" and "type" default to "" and "output".

    Returns ``(image_bytes, "images/<hex>.png")``.
    """
    img_resp = await _comfy_request(
        client,
        "GET",
        "/view",
        params={
            "filename": img_info["filename"],
            "subfolder": img_info.get("subfolder", ""),
            "type": img_info.get("type", "output"),
        },
    )
    img_resp.raise_for_status()
    image_bytes = img_resp.content
    IMAGES_DIR.mkdir(parents=True, exist_ok=True)
    filename = f"{uuid.uuid4().hex}.png"
    (IMAGES_DIR / filename).write_bytes(image_bytes)
    logger.info("ComfyUI done → saved %s", filename)
    return image_bytes, f"images/{filename}"


async def txt2img(
    prompt: str,
    negative_prompt: str | None = None,
    *,
    overrides: dict | None = None,
) -> tuple[bytes, str]:
    """
    Render ``prompt`` with ComfyUI and save the first output image.

    The request runs in these steps:

    1. Queue the graph from _build_workflow via POST /prompt.
    2. Poll /history/{prompt_id}, sleeping 1 s between polls, for at most
       300 polls.
    3. Download the first image listed in the finished entry's outputs.
    4. Write it to IMAGES_DIR under a random ``<hex>.png`` name.

    Args:
        prompt: positive prompt text.
        negative_prompt: negative prompt; falsy values fall back to SD_DEFAULT_NEGATIVE.
        overrides: optional model overrides forwarded to _build_workflow.

    Returns:
        ``(image_bytes, "images/<hex>.png")``.
        NOTE(review): the "images/" prefix is fixed regardless of IMAGES_DIR —
        confirm it matches how that directory is served.

    Raises:
        httpx.HTTPStatusError: queueing, a history poll or the image download failed.
        RuntimeError: the workflow reported an error, finished without an image,
            or did not appear in the history within the polling window.
    """
    neg = negative_prompt or SD_DEFAULT_NEGATIVE
    workflow = _build_workflow(prompt, neg, overrides)
    client_id = uuid.uuid4().hex

    logger.info("ComfyUI request → %s prompt: %.120s", _log_comfy_target(), prompt)
    async with _make_comfy_client() as client:
        await _prime_comfy_gateway(client)
        resp = await _comfy_request(
            client,
            "POST",
            "/prompt",
            json={"prompt": workflow, "client_id": client_id},
        )
        resp.raise_for_status()
        prompt_id = resp.json()["prompt_id"]
        logger.info("ComfyUI queued prompt_id=%s", prompt_id)

        for _ in range(300):
            await asyncio.sleep(1)
            hist = await _comfy_request(client, "GET", f"/history/{prompt_id}")
            # Surface gateway/auth failures as HTTP errors instead of as a JSON
            # decode error on an HTML login page.
            hist.raise_for_status()
            data = hist.json()
            if prompt_id not in data:
                continue  # no history entry yet; keep polling

            entry = data[prompt_id]
            status = entry.get("status", {})
            error_msgs = None
            if status.get("status_str") == "error":
                error_msgs = status.get("messages", [])
                logger.error("ComfyUI workflow error: %s", error_msgs)

            outputs = entry.get("outputs", {})
            for node_output in outputs.values():
                images = node_output.get("images")
                if images:  # skip nodes that report an empty image list
                    return await _download_comfy_image(client, images[0])

            logger.error(
                "ComfyUI no image output. status=%s outputs_keys=%s",
                entry.get("status"),
                list(outputs.keys()),
            )
            if error_msgs is not None:
                raise RuntimeError(f"ComfyUI workflow failed: {error_msgs!s:.500}")
            break

    raise RuntimeError("ComfyUI generation timed out or produced no output")
|
|
|
|
|
|
def _redact_comfy_secret(text: str) -> str:
    """Replace the configured gateway token (raw or URL-encoded) in ``text`` with ``***``."""
    from urllib.parse import quote, quote_plus

    token = SD_QUERY_PARAMS.get("token")
    if not token:
        return text
    # Longest form first, so a raw token that is a prefix of its encoding is fully masked.
    for form in sorted({token, quote(token, safe=""), quote_plus(token)}, key=len, reverse=True):
        text = text.replace(form, "***")
    return text


async def generate_from_full_prompt(
    full_prompt: str,
    *,
    overrides: dict | None = None,
) -> tuple[str | None, str | None]:
    """
    Split ``full_prompt`` into positive/negative parts and render it with txt2img.

    Returns:
        ``(relative_image_path, None)`` on success, or ``(None, error_message)``
        on any generation failure. Exceptions are logged, not propagated.

    The gateway token is masked in both the logs and the returned message:
    httpx error strings embed the request URL, which carries ``?token=`` on
    Vast gateways.
    """
    positive, negative = split_prompt_and_negative(full_prompt)
    try:
        _, rel_path = await txt2img(positive, negative, overrides=overrides)
        return rel_path, None
    except httpx.HTTPStatusError as e:
        message = _redact_comfy_secret(str(e))
        code = e.response.status_code
        if code == 401:
            logger.error(
                "ComfyUI 401: Vast Caddy needs SD_COMFY_TOKEN (or ?token= in SD_BASE_URL) "
                "and/or SD_COMFY_HTTP_BASIC=user:pass from the instance page. "
                "Test: curl -u user:pass http://IP:PORT/system_stats "
                "or open /?token=… in browser then curl with cookies. "
                "Alternative: trycloudflare URL for localhost:8188 in Port Mapping."
            )
        elif code in (301, 302, 303, 307, 308):
            logger.error(
                "ComfyUI %s: wrong URL — use trycloudflare tunnel for 8188, not web UI link. "
                "SD_BASE_URL=https://<your-tunnel>.trycloudflare.com "
                "(no ?token=). Location: %s",
                code,
                _redact_comfy_secret(str(e.response.headers.get("location"))),
            )
        else:
            logger.error("ComfyUI HTTP %s: %s", code, message)
        return None, message
    except httpx.ConnectError as e:
        message = _redact_comfy_secret(str(e))
        logger.error(
            "ComfyUI connect failed (%s): IP:8188 is often not exposed on Vast. "
            "Use trycloudflare URL from Port Mapping for localhost:8188.",
            message,
        )
        return None, message
    except Exception as e:
        message = _redact_comfy_secret(str(e))
        # No traceback on purpose: the formatted exception text could expose the token.
        logger.error("ComfyUI error (%s): %s", type(e).__name__, message)
        return None, message
|