Fixed SD Promt
This commit is contained in:
@@ -0,0 +1,248 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services import sdbackend as sd_service
|
||||
from services.comfy_models import list_node_types, parse_model_lists
|
||||
from services.llm import (
|
||||
CHAT_MODEL,
|
||||
LLM_FALLBACK_MODEL,
|
||||
LLMError,
|
||||
SYSTEM_MODEL,
|
||||
send_message,
|
||||
send_message_with_model,
|
||||
)
|
||||
from services.personas import get_all_personas
|
||||
from services.sd_prompt import (
|
||||
SD_PROMPT_MODEL,
|
||||
anima_dual_enabled,
|
||||
run_prompt_builder,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/debug", tags=["debug"])
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class SdPromptDebugRequest(BaseModel):
|
||||
persona_id: str = "default"
|
||||
chat_excerpt: str = ""
|
||||
messages: list[ChatMessage] | None = None
|
||||
outfit_json: str = "[]"
|
||||
appearance_override: str | None = None
|
||||
use_prose: bool = False
|
||||
|
||||
|
||||
class LlmDebugRequest(BaseModel):
|
||||
model: str = ""
|
||||
system: str = ""
|
||||
user: str = ""
|
||||
messages: list[ChatMessage] | None = None
|
||||
|
||||
|
||||
class ComfyRawRequest(BaseModel):
|
||||
method: str = "GET"
|
||||
path: str = "/system_stats"
|
||||
params_json: str = "{}"
|
||||
body_json: str = ""
|
||||
|
||||
|
||||
class ComfyGenerateRequest(BaseModel):
|
||||
positive: str
|
||||
negative: str = ""
|
||||
unet: str | None = None
|
||||
clip: str | None = None
|
||||
vae: str | None = None
|
||||
checkpoint: str | None = None
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def debug_config():
|
||||
base = sd_service.SD_BASE_URL
|
||||
return {
|
||||
"chat_model": CHAT_MODEL,
|
||||
"system_model": SYSTEM_MODEL,
|
||||
"llm_fallback_model": LLM_FALLBACK_MODEL,
|
||||
"sd_prompt_model": SD_PROMPT_MODEL or SYSTEM_MODEL,
|
||||
"sd_base_url": base,
|
||||
"sd_has_token": bool(sd_service.SD_QUERY_PARAMS.get("token")),
|
||||
"sd_anima_dual": anima_dual_enabled(),
|
||||
"sd_unet": sd_service.SD_UNET,
|
||||
"sd_clip": sd_service.SD_CLIP,
|
||||
"sd_vae": sd_service.SD_VAE,
|
||||
"sd_checkpoint": sd_service.SD_CHECKPOINT,
|
||||
"sd_steps": sd_service.SD_STEPS,
|
||||
"sd_cfg": sd_service.SD_CFG,
|
||||
"router_key_set": bool(os.getenv("ROUTER_KEY")),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/personas")
|
||||
async def debug_personas():
|
||||
personas = await get_all_personas()
|
||||
return [
|
||||
{
|
||||
"persona_id": pid,
|
||||
"name": p.get("name", pid),
|
||||
"appearance_tags": p.get("appearance_tags", ""),
|
||||
}
|
||||
for pid, p in personas.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/sd-prompt")
|
||||
async def debug_sd_prompt(req: SdPromptDebugRequest):
|
||||
msgs = None
|
||||
if req.messages:
|
||||
msgs = [m.model_dump() for m in req.messages]
|
||||
return await run_prompt_builder(
|
||||
req.persona_id,
|
||||
messages=msgs,
|
||||
chat_excerpt=req.chat_excerpt,
|
||||
outfit_json=req.outfit_json,
|
||||
appearance_override=req.appearance_override,
|
||||
use_prose=req.use_prose,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/llm")
|
||||
async def debug_llm(req: LlmDebugRequest):
|
||||
if req.messages:
|
||||
messages = [m.model_dump() for m in req.messages]
|
||||
else:
|
||||
messages = []
|
||||
if req.system.strip():
|
||||
messages.append({"role": "system", "content": req.system.strip()})
|
||||
if req.user.strip():
|
||||
messages.append({"role": "user", "content": req.user.strip()})
|
||||
if not messages:
|
||||
raise HTTPException(status_code=400, detail="Нужны messages или system/user")
|
||||
|
||||
model = (req.model or "").strip() or SD_PROMPT_MODEL or SYSTEM_MODEL
|
||||
try:
|
||||
if model in (SYSTEM_MODEL, "") and not req.model:
|
||||
text = await send_message(messages)
|
||||
else:
|
||||
text = await send_message_with_model(messages, model)
|
||||
return {"model": model, "response": text}
|
||||
except LLMError as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/comfy/ping")
|
||||
async def debug_comfy_ping():
|
||||
try:
|
||||
status, body, headers = await sd_service.comfy_api_request("GET", "/system_stats")
|
||||
return {"ok": status == 200, "status": status, "body": body, "headers": headers}
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/comfy/models")
|
||||
async def debug_comfy_models():
|
||||
try:
|
||||
info = await sd_service.fetch_object_info()
|
||||
return {
|
||||
"models": parse_model_lists(info),
|
||||
"configured": {
|
||||
"unet": sd_service.SD_UNET,
|
||||
"clip": sd_service.SD_CLIP,
|
||||
"vae": sd_service.SD_VAE,
|
||||
"checkpoint": sd_service.SD_CHECKPOINT,
|
||||
},
|
||||
"node_type_count": len(list_node_types(info)),
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/comfy/object_info")
|
||||
async def debug_comfy_object_info(node: str | None = None):
|
||||
try:
|
||||
info = await sd_service.fetch_object_info()
|
||||
if node:
|
||||
if node not in info:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown node: {node}")
|
||||
return {node: info[node]}
|
||||
return {
|
||||
"node_types": list_node_types(info),
|
||||
"models": parse_model_lists(info),
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/comfy/raw")
|
||||
async def debug_comfy_raw(req: ComfyRawRequest):
|
||||
path = req.path.strip()
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
try:
|
||||
params = json.loads(req.params_json or "{}")
|
||||
if not isinstance(params, dict):
|
||||
raise ValueError("params_json must be object")
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"params_json: {e}")
|
||||
|
||||
body = None
|
||||
if req.body_json.strip():
|
||||
try:
|
||||
body = json.loads(req.body_json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"body_json: {e}")
|
||||
|
||||
method = req.method.upper()
|
||||
if method not in ("GET", "POST", "PUT", "DELETE"):
|
||||
raise HTTPException(status_code=400, detail="method must be GET|POST|PUT|DELETE")
|
||||
|
||||
try:
|
||||
status, resp_body, headers = await sd_service.comfy_api_request(
|
||||
method,
|
||||
path,
|
||||
params=params or None,
|
||||
json_body=body,
|
||||
timeout=120,
|
||||
)
|
||||
return {"status": status, "headers": headers, "body": resp_body}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/comfy/generate")
|
||||
async def debug_comfy_generate(req: ComfyGenerateRequest):
|
||||
if not req.positive.strip():
|
||||
raise HTTPException(status_code=400, detail="positive required")
|
||||
|
||||
overrides: dict[str, str] = {}
|
||||
if req.unet:
|
||||
overrides["unet"] = req.unet
|
||||
if req.clip:
|
||||
overrides["clip"] = req.clip
|
||||
if req.vae:
|
||||
overrides["vae"] = req.vae
|
||||
if req.checkpoint:
|
||||
overrides["checkpoint"] = req.checkpoint
|
||||
|
||||
full = req.positive.strip()
|
||||
if req.negative.strip():
|
||||
full += f"\n\nNegative prompt: {req.negative.strip()}"
|
||||
|
||||
try:
|
||||
rel, err = await sd_service.generate_from_full_prompt(
|
||||
full,
|
||||
overrides=overrides or None,
|
||||
)
|
||||
if not rel:
|
||||
raise HTTPException(status_code=502, detail=err or "generation failed")
|
||||
return {"image_path": f"/static/{rel}", "status": "ok"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
Reference in New Issue
Block a user