Fixed SD Promt
This commit is contained in:
+322
-23
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
@@ -11,7 +12,178 @@ load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SD_BASE_URL = os.getenv("SD_BASE_URL", "http://127.0.0.1:8188").rstrip("/")
|
||||
|
||||
def _parse_basic_auth() -> httpx.BasicAuth | None:
|
||||
"""
|
||||
Vast Caddy on mapped ports often uses Basic realm=restricted.
|
||||
Set SD_COMFY_HTTP_BASIC=user:password or SD_COMFY_USER + SD_COMFY_PASSWORD.
|
||||
"""
|
||||
raw = (os.getenv("SD_COMFY_HTTP_BASIC") or "").strip()
|
||||
if raw:
|
||||
if ":" in raw:
|
||||
user, _, password = raw.partition(":")
|
||||
else:
|
||||
user, password = "", raw
|
||||
return httpx.BasicAuth(user, password)
|
||||
user = (os.getenv("SD_COMFY_USER") or "").strip()
|
||||
password = (os.getenv("SD_COMFY_PASSWORD") or "").strip()
|
||||
if user or password:
|
||||
return httpx.BasicAuth(user, password)
|
||||
return None
|
||||
|
||||
|
||||
SD_BASIC_AUTH = _parse_basic_auth()
|
||||
|
||||
|
||||
def _parse_comfy_config() -> tuple[str, dict[str, str]]:
|
||||
"""
|
||||
SD_BASE_URL may be pasted from Vast/Comfy UI with ?token=...
|
||||
API paths must be base + /prompt, not ...?token=xxx/prompt
|
||||
"""
|
||||
raw = (os.getenv("SD_BASE_URL") or "http://127.0.0.1:8188").strip()
|
||||
extra_token = (os.getenv("SD_COMFY_TOKEN") or "").strip()
|
||||
parsed = urlparse(raw)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if path and path != "/":
|
||||
base = f"{base}{path}"
|
||||
query: dict[str, str] = {}
|
||||
for key, values in parse_qs(parsed.query).items():
|
||||
if values:
|
||||
query[key] = values[-1]
|
||||
if extra_token:
|
||||
query["token"] = extra_token
|
||||
base = base.rstrip("/")
|
||||
# Cloudflare tunnel to localhost:8188 — direct Comfy API, Vast ?token= does not apply
|
||||
if "trycloudflare.com" in base.lower():
|
||||
if query.pop("token", None):
|
||||
logger.info(
|
||||
"SD_BASE_URL is trycloudflare tunnel: Vast token stripped. "
|
||||
"Use tunnel for port 8188 only (see instance Port Mapping)."
|
||||
)
|
||||
return base, query
|
||||
|
||||
|
||||
SD_BASE_URL, SD_QUERY_PARAMS = _parse_comfy_config()
|
||||
|
||||
|
||||
def _comfy_url(path: str) -> str:
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
return f"{SD_BASE_URL}{path}"
|
||||
|
||||
|
||||
def _log_comfy_target() -> str:
|
||||
if SD_QUERY_PARAMS.get("token"):
|
||||
return f"{SD_BASE_URL}?token=***"
|
||||
return SD_BASE_URL
|
||||
|
||||
|
||||
def _absolute_url(location: str, fallback_path: str = "/") -> str:
|
||||
if not location:
|
||||
return _comfy_url(fallback_path)
|
||||
if location.startswith(("http://", "https://")):
|
||||
return location
|
||||
if location.startswith("/"):
|
||||
return f"{SD_BASE_URL}{location}"
|
||||
return f"{SD_BASE_URL}/{location}"
|
||||
|
||||
|
||||
def _url_with_token(url: str) -> str:
|
||||
"""Append gateway token to URL (Vast/Cloudflare often strip ?token on redirect)."""
|
||||
if not SD_QUERY_PARAMS.get("token"):
|
||||
return url
|
||||
p = urlparse(url)
|
||||
q: dict[str, str] = {}
|
||||
for key, values in parse_qs(p.query).items():
|
||||
if values:
|
||||
q[key] = values[-1]
|
||||
q.update(SD_QUERY_PARAMS)
|
||||
return urlunparse((p.scheme, p.netloc, p.path, "", urlencode(q), ""))
|
||||
|
||||
|
||||
def _merge_params(extra: dict | None) -> dict | None:
|
||||
if not SD_QUERY_PARAMS and not extra:
|
||||
return None
|
||||
merged = dict(SD_QUERY_PARAMS)
|
||||
if extra:
|
||||
merged.update(extra)
|
||||
return merged
|
||||
|
||||
|
||||
def _is_vast_gateway() -> bool:
|
||||
return "trycloudflare.com" not in SD_BASE_URL.lower()
|
||||
|
||||
|
||||
def _make_comfy_client(*, timeout: float = 300) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
follow_redirects=False,
|
||||
auth=SD_BASIC_AUTH,
|
||||
)
|
||||
|
||||
|
||||
async def _prime_comfy_gateway(client: httpx.AsyncClient) -> None:
|
||||
"""
|
||||
Vast Caddy: browser opens /?token=… and gets a session cookie; API then works.
|
||||
Prime with redirects so Set-Cookie is collected, then merge into the API client.
|
||||
"""
|
||||
token = SD_QUERY_PARAMS.get("token")
|
||||
if not token or not _is_vast_gateway():
|
||||
return
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30,
|
||||
follow_redirects=True,
|
||||
auth=SD_BASIC_AUTH,
|
||||
) as prime:
|
||||
r = await prime.get(_comfy_url("/"), params={"token": token})
|
||||
client.cookies.update(prime.cookies)
|
||||
logger.info(
|
||||
"Comfy gateway prime GET /?token=*** → %s, cookies=%s",
|
||||
r.status_code,
|
||||
list(prime.cookies.keys()) or "(none)",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Comfy gateway prime failed: %s", e)
|
||||
|
||||
|
||||
async def _comfy_request(
|
||||
client: httpx.AsyncClient,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict | None = None,
|
||||
**kwargs,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Comfy API: trycloudflare tunnel = no token.
|
||||
Vast IP:PORT gateway = ?token= + cookie prime; follow redirects with token re-attached.
|
||||
"""
|
||||
url = _comfy_url(path)
|
||||
extra = params or {}
|
||||
token = SD_QUERY_PARAMS.get("token")
|
||||
use_vast_auth = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None)
|
||||
|
||||
if token and _is_vast_gateway():
|
||||
await _prime_comfy_gateway(client)
|
||||
|
||||
req_params: dict | None = _merge_params(extra) if use_vast_auth else (extra or None)
|
||||
resp: httpx.Response | None = None
|
||||
|
||||
for hop in range(6):
|
||||
resp = await client.request(method, url, params=req_params, **kwargs)
|
||||
if resp.status_code not in (301, 302, 303, 307, 308):
|
||||
return resp
|
||||
loc = _absolute_url(resp.headers.get("location", ""), path)
|
||||
url = _url_with_token(loc) if use_vast_auth else loc
|
||||
req_params = extra or None
|
||||
logger.info("Comfy redirect %s hop %s → %s", resp.status_code, hop + 1, url.split("?")[0])
|
||||
|
||||
assert resp is not None
|
||||
return resp
|
||||
|
||||
|
||||
SD_STEPS = int(os.getenv("SD_STEPS", "28"))
|
||||
SD_CFG = float(os.getenv("SD_CFG", "7"))
|
||||
SD_SAMPLER = os.getenv("SD_SAMPLER", "euler")
|
||||
@@ -26,6 +198,8 @@ SD_DEFAULT_NEGATIVE = os.getenv(
|
||||
SD_UNET = os.getenv("SD_UNET", "anima-preview3-base.safetensors")
|
||||
SD_CLIP = os.getenv("SD_CLIP", "qwen_3_06b_base.safetensors")
|
||||
SD_VAE = os.getenv("SD_VAE", "qwen_image_vae.safetensors")
|
||||
SD_STYLE_LORA = os.getenv("SD_STYLE_LORA", "")
|
||||
SD_STYLE_LORA_WEIGHT = float(os.getenv("SD_STYLE_LORA_WEIGHT", "0.7"))
|
||||
|
||||
IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images"))
|
||||
|
||||
@@ -38,19 +212,37 @@ def _use_anima() -> bool:
|
||||
|
||||
|
||||
def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]:
|
||||
# Try new separator first
|
||||
sep = "__NEGATIVE_PROMPT__"
|
||||
if f"\n{sep}\n" in full_prompt:
|
||||
pos, _, neg = full_prompt.partition(f"\n{sep}\n")
|
||||
return pos.strip(), neg.strip()
|
||||
# Fallback to old format
|
||||
if "\n\nNegative prompt:" in full_prompt:
|
||||
pos, _, neg = full_prompt.partition("\n\nNegative prompt:")
|
||||
return pos.strip(), neg.strip()
|
||||
return full_prompt.strip(), SD_DEFAULT_NEGATIVE
|
||||
|
||||
|
||||
def _build_workflow(positive: str, negative: str) -> dict:
|
||||
def _workflow_uses_anima(overrides: dict | None) -> bool:
|
||||
if overrides and overrides.get("checkpoint"):
|
||||
return False
|
||||
if overrides and overrides.get("unet"):
|
||||
return True
|
||||
return _use_anima()
|
||||
|
||||
|
||||
def _build_workflow(positive: str, negative: str, overrides: dict | None = None) -> dict:
|
||||
seed = int(uuid.uuid4().int % 2**32)
|
||||
if _use_anima():
|
||||
return {
|
||||
"44": {"class_type": "UNETLoader", "inputs": {"unet_name": SD_UNET, "weight_dtype": "default"}},
|
||||
"45": {"class_type": "CLIPLoader", "inputs": {"clip_name": SD_CLIP, "type": "stable_diffusion", "device": "default"}},
|
||||
"15": {"class_type": "VAELoader", "inputs": {"vae_name": SD_VAE}},
|
||||
o = overrides or {}
|
||||
if _workflow_uses_anima(o):
|
||||
unet = o.get("unet") or SD_UNET
|
||||
clip = o.get("clip") or SD_CLIP
|
||||
vae = o.get("vae") or SD_VAE
|
||||
workflow = {
|
||||
"44": {"class_type": "UNETLoader", "inputs": {"unet_name": unet, "weight_dtype": "default"}},
|
||||
"45": {"class_type": "CLIPLoader", "inputs": {"clip_name": clip, "type": "stable_diffusion", "device": "default"}},
|
||||
"15": {"class_type": "VAELoader", "inputs": {"vae_name": vae}},
|
||||
"28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 1024, "batch_size": 1}},
|
||||
"11": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["45", 0]}},
|
||||
"12": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["45", 0]}},
|
||||
@@ -68,9 +260,24 @@ def _build_workflow(positive: str, negative: str) -> dict:
|
||||
"8": {"class_type": "VAEDecode", "inputs": {"samples": ["19", 0], "vae": ["15", 0]}},
|
||||
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
|
||||
}
|
||||
# Standard checkpoint workflow (Pony / SDXL)
|
||||
if SD_STYLE_LORA:
|
||||
workflow["46"] = {
|
||||
"class_type": "LoraLoader",
|
||||
"inputs": {
|
||||
"lora_name": SD_STYLE_LORA,
|
||||
"model": ["44", 0],
|
||||
"clip": ["45", 0],
|
||||
"strength_model": SD_STYLE_LORA_WEIGHT,
|
||||
"strength_clip": SD_STYLE_LORA_WEIGHT,
|
||||
},
|
||||
}
|
||||
workflow["19"]["inputs"]["model"] = ["46", 0]
|
||||
workflow["11"]["inputs"]["clip"] = ["46", 1]
|
||||
workflow["12"]["inputs"]["clip"] = ["46", 1]
|
||||
return workflow
|
||||
ckpt = o.get("checkpoint") or SD_CHECKPOINT
|
||||
return {
|
||||
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": SD_CHECKPOINT}},
|
||||
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt}},
|
||||
"5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}},
|
||||
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}},
|
||||
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}},
|
||||
@@ -89,24 +296,78 @@ def _build_workflow(positive: str, negative: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
async def comfy_api_request(
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict | None = None,
|
||||
json_body: dict | None = None,
|
||||
timeout: float = 60,
|
||||
) -> tuple[int, dict | str, dict]:
|
||||
"""
|
||||
Raw Comfy API call for debug. Returns (status_code, parsed_json_or_text, response_headers_subset).
|
||||
"""
|
||||
async with _make_comfy_client(timeout=timeout) as client:
|
||||
await _prime_comfy_gateway(client)
|
||||
token = SD_QUERY_PARAMS.get("token")
|
||||
use_vast = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None)
|
||||
req_params = _merge_params(params) if use_vast else (params or None)
|
||||
req_kwargs: dict = {}
|
||||
if json_body is not None and method.upper() not in ("GET", "HEAD"):
|
||||
req_kwargs["json"] = json_body
|
||||
resp = await _comfy_request(
|
||||
client,
|
||||
method.upper(),
|
||||
path,
|
||||
params=req_params,
|
||||
**req_kwargs,
|
||||
)
|
||||
headers = {
|
||||
k: resp.headers.get(k)
|
||||
for k in ("content-type", "location", "www-authenticate")
|
||||
if resp.headers.get(k)
|
||||
}
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
body = resp.text[:8000]
|
||||
return resp.status_code, body, headers
|
||||
|
||||
|
||||
async def fetch_object_info() -> dict:
|
||||
status, body, _ = await comfy_api_request("GET", "/object_info", timeout=120)
|
||||
if status != 200 or not isinstance(body, dict):
|
||||
raise RuntimeError(f"object_info failed: HTTP {status} {body!s:.300}")
|
||||
return body
|
||||
|
||||
|
||||
async def check_sd() -> bool:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
r = await client.get(f"{SD_BASE_URL}/system_stats")
|
||||
async with _make_comfy_client(timeout=15) as client:
|
||||
await _prime_comfy_gateway(client)
|
||||
r = await _comfy_request(client, "GET", "/system_stats")
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[bytes, str]:
|
||||
async def txt2img(
|
||||
prompt: str,
|
||||
negative_prompt: str | None = None,
|
||||
*,
|
||||
overrides: dict | None = None,
|
||||
) -> tuple[bytes, str]:
|
||||
neg = negative_prompt or SD_DEFAULT_NEGATIVE
|
||||
workflow = _build_workflow(prompt, neg)
|
||||
workflow = _build_workflow(prompt, neg, overrides)
|
||||
client_id = uuid.uuid4().hex
|
||||
|
||||
logger.info("ComfyUI request → %s prompt: %.120s", SD_BASE_URL, prompt)
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
resp = await client.post(
|
||||
f"{SD_BASE_URL}/prompt",
|
||||
logger.info("ComfyUI request → %s prompt: %.120s", _log_comfy_target(), prompt)
|
||||
async with _make_comfy_client() as client:
|
||||
await _prime_comfy_gateway(client)
|
||||
resp = await _comfy_request(
|
||||
client,
|
||||
"POST",
|
||||
"/prompt",
|
||||
json={"prompt": workflow, "client_id": client_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
@@ -115,7 +376,7 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
|
||||
for _ in range(300):
|
||||
await asyncio.sleep(1)
|
||||
hist = await client.get(f"{SD_BASE_URL}/history/{prompt_id}")
|
||||
hist = await _comfy_request(client, "GET", f"/history/{prompt_id}")
|
||||
data = hist.json()
|
||||
if prompt_id in data:
|
||||
entry = data[prompt_id]
|
||||
@@ -127,9 +388,15 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
for node_output in outputs.values():
|
||||
if "images" in node_output:
|
||||
img_info = node_output["images"][0]
|
||||
img_resp = await client.get(
|
||||
f"{SD_BASE_URL}/view",
|
||||
params={"filename": img_info["filename"], "subfolder": img_info.get("subfolder", ""), "type": img_info.get("type", "output")},
|
||||
img_resp = await _comfy_request(
|
||||
client,
|
||||
"GET",
|
||||
"/view",
|
||||
params={
|
||||
"filename": img_info["filename"],
|
||||
"subfolder": img_info.get("subfolder", ""),
|
||||
"type": img_info.get("type", "output"),
|
||||
},
|
||||
)
|
||||
img_resp.raise_for_status()
|
||||
image_bytes = img_resp.content
|
||||
@@ -145,11 +412,43 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
raise RuntimeError("ComfyUI generation timed out or produced no output")
|
||||
|
||||
|
||||
async def generate_from_full_prompt(full_prompt: str) -> tuple[str | None, str | None]:
|
||||
async def generate_from_full_prompt(
|
||||
full_prompt: str,
|
||||
*,
|
||||
overrides: dict | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
positive, negative = split_prompt_and_negative(full_prompt)
|
||||
try:
|
||||
_, rel_path = await txt2img(positive, negative)
|
||||
_, rel_path = await txt2img(positive, negative, overrides=overrides)
|
||||
return rel_path, None
|
||||
except httpx.HTTPStatusError as e:
|
||||
code = e.response.status_code
|
||||
if code == 401:
|
||||
logger.error(
|
||||
"ComfyUI 401: Vast Caddy needs SD_COMFY_TOKEN (or ?token= in SD_BASE_URL) "
|
||||
"and/or SD_COMFY_HTTP_BASIC=user:pass from the instance page. "
|
||||
"Test: curl -u user:pass http://IP:PORT/system_stats "
|
||||
"or open /?token=… in browser then curl with cookies. "
|
||||
"Alternative: trycloudflare URL for localhost:8188 in Port Mapping."
|
||||
)
|
||||
elif code in (301, 302, 303, 307, 308):
|
||||
logger.error(
|
||||
"ComfyUI %s: wrong URL — use trycloudflare tunnel for 8188, not web UI link. "
|
||||
"SD_BASE_URL=https://reviewer-relief-edmonton-specializing.trycloudflare.com "
|
||||
"(no ?token=). Location: %s",
|
||||
code,
|
||||
e.response.headers.get("location"),
|
||||
)
|
||||
else:
|
||||
logger.error("ComfyUI HTTP %s: %s", code, e)
|
||||
return None, str(e)
|
||||
except httpx.ConnectError as e:
|
||||
logger.error(
|
||||
"ComfyUI connect failed (%s): IP:8188 is often not exposed on Vast. "
|
||||
"Use trycloudflare URL from Port Mapping for localhost:8188.",
|
||||
e,
|
||||
)
|
||||
return None, str(e)
|
||||
except Exception as e:
|
||||
logger.error("ComfyUI error: %s", e)
|
||||
return None, str(e)
|
||||
|
||||
Reference in New Issue
Block a user