Fixed SD Promt
This commit is contained in:
+14
-11
@@ -45,6 +45,7 @@ def parse_card_v2(data: dict, card_id: str | None = None) -> dict:
|
||||
"first_mes": inner.get("first_mes", ""),
|
||||
"mes_example": inner.get("mes_example", ""),
|
||||
"appearance_tags": _extract_appearance(inner),
|
||||
"appearance_prose": "",
|
||||
"lorebook_json": json.dumps(entries, ensure_ascii=False),
|
||||
"alternate_greetings": alternates,
|
||||
"alternate_greetings_json": json.dumps(alternates, ensure_ascii=False),
|
||||
@@ -141,13 +142,13 @@ async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0
|
||||
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""INSERT OR REPLACE INTO characters
|
||||
(card_id, name, description, personality, scenario, first_mes,
|
||||
mes_example, raw_json, lora_name, lora_weight, appearance_tags, lorebook_json,
|
||||
avatar_path, alternate_greetings_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
"""INSERT INTO characters
|
||||
(card_id, name, description, personality, scenario, first_mes, mes_example,
|
||||
raw_json, lora_name, lora_weight, appearance_tags, appearance_prose, lorebook_json, avatar_path,
|
||||
alternate_greetings_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
card_id,
|
||||
card["card_id"],
|
||||
card["name"],
|
||||
card["description"],
|
||||
card["personality"],
|
||||
@@ -157,10 +158,11 @@ async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0
|
||||
card["raw_json"],
|
||||
lora_name,
|
||||
lora_weight,
|
||||
card.get("appearance_tags", ""),
|
||||
card["appearance_tags"],
|
||||
card.get("appearance_prose", ""),
|
||||
card["lorebook_json"],
|
||||
card.get("avatar_path", ""),
|
||||
alt_json,
|
||||
card.get("alternate_greetings_json", "[]"),
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
@@ -199,8 +201,8 @@ async def delete_character(card_id: str) -> bool:
|
||||
async def update_appearance_tags(card_id: str, appearance_tags: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE characters SET appearance_tags = ? WHERE card_id = ?",
|
||||
(appearance_tags, card_id),
|
||||
"UPDATE characters SET appearance_tags = ?, appearance_prose = ? WHERE card_id = ?",
|
||||
(appearance_tags, "", card_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@@ -228,7 +230,7 @@ async def preview_card_file(content: bytes, filename: str) -> dict:
|
||||
|
||||
async def update_character(card_id: str, fields: dict) -> bool:
|
||||
allowed = {"name", "description", "personality", "scenario", "first_mes",
|
||||
"mes_example", "appearance_tags", "lora_name", "lora_weight", "avatar_path",
|
||||
"mes_example", "appearance_tags", "appearance_prose", "lora_name", "lora_weight", "avatar_path",
|
||||
"alternate_greetings_json"}
|
||||
updates = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not updates:
|
||||
@@ -295,6 +297,7 @@ async def import_card_file(
|
||||
"lora_name": lora_name,
|
||||
"lora_weight": lora_weight,
|
||||
"appearance_tags": saved.get("appearance_tags", ""),
|
||||
"appearance_prose": saved.get("appearance_prose", ""),
|
||||
"avatar_path": saved.get("avatar_path", ""),
|
||||
"personality": saved.get("personality", ""),
|
||||
"scenario": saved.get("scenario", ""),
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from services.personas import get_persona
|
||||
from services.lorebook import get_lorebook_context
|
||||
from services.character_card import get_character
|
||||
|
||||
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
|
||||
|
||||
|
||||
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
|
||||
"""Static character prompt only (no RPG runtime blocks)."""
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona:
|
||||
return DEFAULT_PROMPT
|
||||
prompt = persona["prompt"]
|
||||
recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
|
||||
context = recent + [{"role": "user", "content": user_message}]
|
||||
if persona.get("lorebook_json"):
|
||||
lore = get_lorebook_context(persona.get("lorebook_json", "[]"), context)
|
||||
if lore:
|
||||
prompt += "\n\n" + lore
|
||||
if persona_id.startswith("card_"):
|
||||
card = await get_character(persona_id[5:])
|
||||
if card:
|
||||
lore = get_lorebook_context(card.get("lorebook_json", "[]"), context)
|
||||
if lore:
|
||||
prompt += "\n\n" + lore
|
||||
return prompt
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Parse ComfyUI /object_info into usable model lists."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Node types whose combo inputs we expose in the debug UI
|
||||
_MODEL_NODES: dict[str, tuple[str, str]] = {
|
||||
"checkpoints": ("CheckpointLoaderSimple", "ckpt_name"),
|
||||
"unets": ("UNETLoader", "unet_name"),
|
||||
"clips": ("CLIPLoader", "clip_name"),
|
||||
"vaes": ("VAELoader", "vae_name"),
|
||||
"loras": ("LoraLoader", "lora_name"),
|
||||
}
|
||||
|
||||
|
||||
def _combo_options(node_def: dict, input_name: str) -> list[str]:
|
||||
if not isinstance(node_def, dict):
|
||||
return []
|
||||
required = (node_def.get("input") or {}).get("required") or {}
|
||||
optional = (node_def.get("input") or {}).get("optional") or {}
|
||||
spec = required.get(input_name) or optional.get(input_name)
|
||||
if not spec or not isinstance(spec, (list, tuple)):
|
||||
return []
|
||||
first = spec[0]
|
||||
if isinstance(first, list):
|
||||
return [str(x) for x in first]
|
||||
return []
|
||||
|
||||
|
||||
def parse_model_lists(object_info: dict) -> dict[str, list[str]]:
|
||||
out: dict[str, list[str]] = {}
|
||||
for key, (node_type, input_name) in _MODEL_NODES.items():
|
||||
node_def = object_info.get(node_type) or {}
|
||||
options = _combo_options(node_def, input_name)
|
||||
if options:
|
||||
out[key] = options
|
||||
return out
|
||||
|
||||
|
||||
def list_node_types(object_info: dict) -> list[str]:
|
||||
return sorted(k for k in object_info.keys() if isinstance(object_info.get(k), dict))
|
||||
+119
-6
@@ -13,6 +13,8 @@ OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
|
||||
|
||||
CHAT_MODEL = os.getenv("CHAT_MODEL", "mistralai/mistral-nemo")
|
||||
SYSTEM_MODEL = os.getenv("SYSTEM_MODEL", "google/gemini-2.5-flash")
|
||||
# Softer model when primary returns content_filter / empty / API errors (default: CHAT_MODEL).
|
||||
LLM_FALLBACK_MODEL = (os.getenv("LLM_FALLBACK_MODEL") or "").strip() or CHAT_MODEL
|
||||
|
||||
HEADERS = {
|
||||
"Authorization": f"Bearer {OPENROUTER_KEY}",
|
||||
@@ -21,26 +23,128 @@ HEADERS = {
|
||||
}
|
||||
|
||||
|
||||
class LLMError(Exception):
|
||||
"""OpenRouter returned an error or an unexpected response shape."""
|
||||
|
||||
|
||||
def _parse_completion_body(data: dict) -> str:
|
||||
if not isinstance(data, dict):
|
||||
raise LLMError(f"Invalid API response: expected object, got {type(data).__name__}")
|
||||
|
||||
if data.get("error"):
|
||||
err = data["error"]
|
||||
if isinstance(err, dict):
|
||||
msg = err.get("message") or str(err)
|
||||
code = err.get("code")
|
||||
else:
|
||||
msg = str(err)
|
||||
code = None
|
||||
suffix = f" (code={code})" if code is not None else ""
|
||||
raise LLMError(f"OpenRouter error{suffix}: {msg}")
|
||||
|
||||
choices = data.get("choices")
|
||||
if not choices:
|
||||
preview = str(data)[:400]
|
||||
raise LLMError(f"OpenRouter response has no 'choices'. Body preview: {preview}")
|
||||
|
||||
first = choices[0] if isinstance(choices[0], dict) else {}
|
||||
message = first.get("message") or {}
|
||||
if not isinstance(message, dict):
|
||||
raise LLMError("OpenRouter choice has no message object")
|
||||
|
||||
finish = first.get("finish_reason") or ""
|
||||
native_finish = first.get("native_finish_reason") or ""
|
||||
blocked_reasons = {"content_filter", "safety", "moderation"}
|
||||
if finish in blocked_reasons or str(native_finish).upper() in (
|
||||
"PROHIBITED_CONTENT",
|
||||
"SAFETY",
|
||||
"BLOCKED",
|
||||
):
|
||||
raise LLMError(
|
||||
f"Content blocked by provider (finish_reason={finish}, native={native_finish})"
|
||||
)
|
||||
|
||||
content = message.get("content")
|
||||
if content is not None and str(content).strip():
|
||||
return str(content)
|
||||
|
||||
refusal = message.get("refusal")
|
||||
if refusal:
|
||||
raise LLMError(f"Model refused the request: {refusal}")
|
||||
|
||||
if finish and finish not in ("stop", "length", "tool_calls", "function_call"):
|
||||
raise LLMError(
|
||||
f"OpenRouter finished without content (finish_reason={finish}, native={native_finish})"
|
||||
)
|
||||
|
||||
raise LLMError("OpenRouter returned empty message content")
|
||||
|
||||
|
||||
def _clean(messages: list) -> list:
|
||||
"""Filter out messages with empty content."""
|
||||
return [m for m in messages if (m.get("content") or "").strip()]
|
||||
|
||||
|
||||
async def _post(model: str, messages: list, extra: dict | None = None) -> str:
|
||||
async def _post_once(model: str, messages: list, extra: dict | None = None) -> str:
|
||||
if not OPENROUTER_KEY:
|
||||
raise LLMError("ROUTER_KEY is not set in environment")
|
||||
|
||||
payload = {"model": model, "messages": _clean(messages), **(extra or {})}
|
||||
async with httpx.AsyncClient(timeout=90) as client:
|
||||
r = await client.post(OPENROUTER_URL, headers=HEADERS, json=payload)
|
||||
r.raise_for_status()
|
||||
return r.json()["choices"][0]["message"]["content"]
|
||||
try:
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
raise LLMError(f"Non-JSON response (HTTP {r.status_code}): {r.text[:300]}") from e
|
||||
|
||||
if r.status_code >= 400:
|
||||
try:
|
||||
_parse_completion_body(data)
|
||||
except LLMError:
|
||||
raise
|
||||
raise LLMError(f"HTTP {r.status_code}: {data}")
|
||||
|
||||
try:
|
||||
return _parse_completion_body(data)
|
||||
except LLMError:
|
||||
logger.warning(
|
||||
"OpenRouter completion failed model=%s status=%s body=%.500s",
|
||||
model,
|
||||
r.status_code,
|
||||
data,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def _post(model: str, messages: list, extra: dict | None = None) -> str:
|
||||
"""POST completion; on failure retries once with LLM_FALLBACK_MODEL (usually CHAT_MODEL)."""
|
||||
try:
|
||||
return await _post_once(model, messages, extra)
|
||||
except LLMError as primary_err:
|
||||
fallback = LLM_FALLBACK_MODEL
|
||||
if not fallback or fallback == model:
|
||||
raise
|
||||
logger.info(
|
||||
"LLM fallback: %s failed (%s) → retrying with %s",
|
||||
model,
|
||||
primary_err,
|
||||
fallback,
|
||||
)
|
||||
try:
|
||||
return await _post_once(fallback, messages, extra)
|
||||
except LLMError as fallback_err:
|
||||
raise LLMError(
|
||||
f"{primary_err} (fallback {fallback} also failed: {fallback_err})"
|
||||
) from fallback_err
|
||||
|
||||
|
||||
async def send_message(messages: list) -> str:
|
||||
"""System model — narrator, facts, SD prompt."""
|
||||
"""SYSTEM_MODEL with automatic fallback to LLM_FALLBACK_MODEL."""
|
||||
return await _post(SYSTEM_MODEL, messages)
|
||||
|
||||
|
||||
async def send_message_with_model(messages: list, model: str) -> str:
|
||||
"""Explicit model — plot arc, narrator override."""
|
||||
"""Named model (RPG_*, SD_*) with automatic fallback to LLM_FALLBACK_MODEL."""
|
||||
return await _post(model, messages)
|
||||
|
||||
|
||||
@@ -73,10 +177,19 @@ async def stream_message(messages: list):
|
||||
return
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
content = chunk["choices"][0]["delta"].get("content", "")
|
||||
if chunk.get("error"):
|
||||
err = chunk["error"]
|
||||
msg = err.get("message", err) if isinstance(err, dict) else err
|
||||
raise LLMError(f"OpenRouter stream error: {msg}")
|
||||
choices = chunk.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
content = (choices[0].get("delta") or {}).get("content", "")
|
||||
if content:
|
||||
chunk_count += 1
|
||||
yield content
|
||||
except LLMError:
|
||||
raise
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
|
||||
+124
-17
@@ -2,7 +2,8 @@ import aiosqlite
|
||||
from database.db import DB_PATH
|
||||
|
||||
|
||||
async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict:
|
||||
async def get_or_create_session(session_id: str, persona_id: str | None = None) -> dict:
|
||||
"""Existing sessions keep their persona_id; persona_id applies only on INSERT."""
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
@@ -13,9 +14,10 @@ async def get_or_create_session(session_id: str, persona_id: str = "default") ->
|
||||
if row:
|
||||
return dict(row)
|
||||
|
||||
pid = (persona_id or "default").strip() or "default"
|
||||
await db.execute(
|
||||
"INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)",
|
||||
(session_id, persona_id),
|
||||
(session_id, pid),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@@ -71,24 +73,99 @@ async def update_session_persona(session_id: str, persona_id: str):
|
||||
(persona_id, session_id),
|
||||
)
|
||||
|
||||
# If persona changed, reset RPG state bound to the persona/arc.
|
||||
if prev is not None and prev != persona_id:
|
||||
await db.execute(
|
||||
"""UPDATE sessions
|
||||
SET facts_json = '[]',
|
||||
global_plot = '',
|
||||
status_quo = '',
|
||||
plot_arc_json = '{}'
|
||||
WHERE session_id = ?""",
|
||||
(session_id,),
|
||||
)
|
||||
await db.execute(
|
||||
"DELETE FROM action_resolutions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await _reset_persona_bound_state(db, session_id)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _reset_persona_bound_state(db: aiosqlite.Connection, session_id: str) -> None:
|
||||
await db.execute(
|
||||
"""UPDATE sessions
|
||||
SET facts_json = '[]',
|
||||
global_plot = '',
|
||||
status_quo = '',
|
||||
plot_arc_json = '{}',
|
||||
outfit_json = '[]',
|
||||
affinity = 0
|
||||
WHERE session_id = ?""",
|
||||
(session_id,),
|
||||
)
|
||||
await db.execute("DELETE FROM action_resolutions WHERE session_id = ?", (session_id,))
|
||||
await db.execute("DELETE FROM rpg_quests WHERE session_id = ?", (session_id,))
|
||||
|
||||
|
||||
async def upsert_static_system_message(
|
||||
session_id: str, static_prompt: str, history: list | None = None
|
||||
) -> bool:
|
||||
"""Store only static persona prompt in messages. Returns True if written."""
|
||||
hist = history if history is not None else await get_history(session_id)
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
if not hist:
|
||||
await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, 'system', ?, NULL, NULL)""",
|
||||
(session_id, static_prompt),
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
if hist[0]["role"] == "system":
|
||||
if hist[0]["content"] == static_prompt:
|
||||
return False
|
||||
await db.execute(
|
||||
"""UPDATE messages SET content = ?
|
||||
WHERE session_id = ? AND role = 'system'
|
||||
AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
|
||||
(static_prompt, session_id, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, 'system', ?, NULL, NULL)""",
|
||||
(session_id, static_prompt),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def delete_dialog_messages(session_id: str) -> None:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"DELETE FROM messages WHERE session_id = ? AND role IN ('user', 'assistant')",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def rebind_session_persona(
|
||||
session_id: str,
|
||||
persona_id: str,
|
||||
*,
|
||||
clear_history: bool = False,
|
||||
static_prompt: str,
|
||||
first_mes: str | None = None,
|
||||
) -> None:
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
raise ValueError("Session not found")
|
||||
|
||||
await update_session_persona(session_id, persona_id)
|
||||
if clear_history:
|
||||
await delete_dialog_messages(session_id)
|
||||
|
||||
history = await get_history(session_id)
|
||||
await upsert_static_system_message(session_id, static_prompt, history)
|
||||
|
||||
if clear_history and first_mes and first_mes.strip():
|
||||
await add_message(session_id, "assistant", first_mes.strip())
|
||||
|
||||
|
||||
async def update_session_rpg(session_id: str, rpg_enabled: bool):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
@@ -178,7 +255,8 @@ async def get_history(session_id: str) -> list:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"""SELECT id, role, content, image_prompt, image_path
|
||||
"""SELECT id, role, content, image_prompt, image_path,
|
||||
image_prompt_alt, image_path_alt
|
||||
FROM messages WHERE session_id = ? ORDER BY id""",
|
||||
(session_id,),
|
||||
) as cursor:
|
||||
@@ -190,6 +268,8 @@ async def get_history(session_id: str) -> list:
|
||||
"content": r["content"],
|
||||
"image_prompt": r["image_prompt"],
|
||||
"image_path": r["image_path"],
|
||||
"image_prompt_alt": r["image_prompt_alt"],
|
||||
"image_path_alt": r["image_path_alt"],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
@@ -332,6 +412,33 @@ async def update_message_image(message_id: int, image_path: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_prompt(message_id: int, image_prompt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_prompt = ? WHERE id = ?",
|
||||
(image_prompt, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_prompt_alt(message_id: int, image_prompt_alt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_prompt_alt = ? WHERE id = ?",
|
||||
(image_prompt_alt, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_image_alt(message_id: int, image_path_alt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_path_alt = ? WHERE id = ?",
|
||||
(image_path_alt, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_last_assistant_message_id(session_id: str) -> int | None:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from services.memory import (
|
||||
get_history,
|
||||
get_session,
|
||||
get_last_assistant_message_id,
|
||||
update_session_plot_arc,
|
||||
update_session_status_quo,
|
||||
update_session_affinity,
|
||||
update_session_outfit,
|
||||
upsert_quest,
|
||||
get_quests,
|
||||
)
|
||||
from services.personas import get_persona
|
||||
from services.rpg_facts import facts_to_prompt
|
||||
from services.rpg_plot import generate_plot_arc
|
||||
from services.rpg_narrator import narrator_post
|
||||
from services.sd_prompt import generate_sd_prompt
|
||||
from services.sd_images import run_sd_for_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RPG_SETTINGS = {
|
||||
"dice": True,
|
||||
"narrator": True,
|
||||
"quests": True,
|
||||
"affinity": True,
|
||||
"choices": True,
|
||||
}
|
||||
|
||||
|
||||
def get_rpg_settings(session: dict) -> dict:
|
||||
try:
|
||||
return {**DEFAULT_RPG_SETTINGS, **json.loads(session.get("rpg_settings_json") or "{}")}
|
||||
except Exception:
|
||||
return DEFAULT_RPG_SETTINGS
|
||||
|
||||
|
||||
async def resolve_greeting(session_id: str, persona: dict) -> str:
|
||||
history = await get_history(session_id)
|
||||
for m in reversed(history):
|
||||
if m.get("role") == "assistant" and (m.get("content") or "").strip():
|
||||
return m["content"].strip()
|
||||
return (persona.get("first_mes") or "").strip()
|
||||
|
||||
|
||||
async def ensure_plot_arc_and_quests(
|
||||
session_id: str,
|
||||
persona: dict,
|
||||
greeting: str,
|
||||
genre: str,
|
||||
*,
|
||||
seed_quests: bool = True,
|
||||
) -> dict:
|
||||
session = await get_session(session_id) or {}
|
||||
arc_json = session.get("plot_arc_json") or "{}"
|
||||
try:
|
||||
arc = json.loads(arc_json) if isinstance(arc_json, str) else {}
|
||||
except Exception:
|
||||
arc = {}
|
||||
|
||||
if arc:
|
||||
return arc
|
||||
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
arc = await generate_plot_arc(
|
||||
persona.get("name", "Character"),
|
||||
persona.get("description", ""),
|
||||
persona.get("scenario", ""),
|
||||
greeting,
|
||||
facts_block=facts_block,
|
||||
genre=genre,
|
||||
)
|
||||
if not arc:
|
||||
return {}
|
||||
|
||||
await update_session_plot_arc(session_id, json.dumps(arc, ensure_ascii=False))
|
||||
if seed_quests:
|
||||
for beat in arc.get("beats", []):
|
||||
title = (beat.get("title") or beat.get("injection", "")).strip()
|
||||
if title:
|
||||
await upsert_quest(session_id, title[:120])
|
||||
return arc
|
||||
|
||||
|
||||
async def process_opening(session_id: str, persona_id: str, *, rpg: bool) -> dict:
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
raise ValueError("Session not found")
|
||||
|
||||
history = await get_history(session_id)
|
||||
assistant_msgs = [m for m in history if m.get("role") == "assistant"]
|
||||
if not assistant_msgs:
|
||||
raise ValueError("No assistant message (first_mes) found")
|
||||
|
||||
first_mes_text = assistant_msgs[-1].get("content", "").strip()
|
||||
if not first_mes_text:
|
||||
raise ValueError("Empty first_mes")
|
||||
|
||||
msg_id = await get_last_assistant_message_id(session_id)
|
||||
persona = await get_persona(persona_id) or {}
|
||||
rpg_settings = get_rpg_settings(session)
|
||||
|
||||
arc: dict = {}
|
||||
choices: list = []
|
||||
status_quo = session.get("status_quo") or ""
|
||||
outfit_json = session.get("outfit_json") or "[]"
|
||||
|
||||
if rpg:
|
||||
genre = session.get("genre") or "adventure"
|
||||
arc = await ensure_plot_arc_and_quests(
|
||||
session_id,
|
||||
persona,
|
||||
first_mes_text,
|
||||
genre,
|
||||
seed_quests=rpg_settings.get("quests", True),
|
||||
)
|
||||
|
||||
session = await get_session(session_id) or session
|
||||
ctx_txt = f"assistant: {first_mes_text}"
|
||||
arc_json = json.dumps(arc, ensure_ascii=False) if arc else ""
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
|
||||
post = await narrator_post(
|
||||
persona.get("name", persona_id),
|
||||
ctx_txt,
|
||||
arc_json,
|
||||
facts_block,
|
||||
is_opening=True,
|
||||
)
|
||||
|
||||
sq = (post.get("status_quo_update") or "").strip()
|
||||
if sq:
|
||||
await update_session_status_quo(session_id, sq)
|
||||
status_quo = sq
|
||||
|
||||
if rpg_settings.get("choices", True):
|
||||
choices = post.get("choices") or []
|
||||
|
||||
if rpg_settings.get("affinity", True):
|
||||
delta = int(post.get("affinity_delta") or 0)
|
||||
if delta:
|
||||
await update_session_affinity(session_id, delta)
|
||||
|
||||
outfit_update = post.get("outfit_update")
|
||||
if isinstance(outfit_update, list) and outfit_update:
|
||||
outfit_json = json.dumps(outfit_update, ensure_ascii=False)
|
||||
await update_session_outfit(session_id, outfit_json)
|
||||
|
||||
if rpg_settings.get("quests", True):
|
||||
for qu in (post.get("quest_updates") or []):
|
||||
title = (qu.get("title") or "").strip()
|
||||
if title:
|
||||
await upsert_quest(session_id, title[:120], qu.get("status", "active"))
|
||||
|
||||
quests = await get_quests(session_id)
|
||||
messages = await get_history(session_id)
|
||||
bundle = await generate_sd_prompt(messages, persona_id, outfit_json=outfit_json)
|
||||
sd_out = await run_sd_for_message(bundle, msg_id) if bundle else {}
|
||||
|
||||
updated = await get_session(session_id)
|
||||
affinity = updated.get("affinity", 0) if updated else 0
|
||||
|
||||
return {
|
||||
"plot_arc": arc,
|
||||
"quests": quests,
|
||||
"outfit_json": outfit_json,
|
||||
"status_quo": status_quo,
|
||||
"choices": choices,
|
||||
"image_prompt": sd_out.get("image_prompt"),
|
||||
"image_prompt_alt": sd_out.get("image_prompt_alt"),
|
||||
"image_path": sd_out.get("image_path"),
|
||||
"image_path_alt": sd_out.get("image_path_alt"),
|
||||
"image_error": sd_out.get("image_error"),
|
||||
"image_error_alt": sd_out.get("image_error_alt"),
|
||||
"affinity": affinity,
|
||||
}
|
||||
+21
-4
@@ -63,6 +63,7 @@ def _row_to_persona(row: dict) -> dict:
|
||||
"lora_name": row["lora_name"] or "",
|
||||
"lora_weight": row["lora_weight"] if row["lora_weight"] is not None else 0.8,
|
||||
"appearance_tags": row["appearance_tags"] or "",
|
||||
"appearance_prose": row.get("appearance_prose", "") or "",
|
||||
"personality": row.get("personality", "") or "",
|
||||
"scenario": row.get("scenario", "") or "",
|
||||
"first_mes": row.get("first_mes", "") or "",
|
||||
@@ -117,6 +118,7 @@ async def create_persona(
|
||||
lora_name: str = "",
|
||||
lora_weight: float = 0.8,
|
||||
appearance_tags: str = "",
|
||||
appearance_prose: str = "",
|
||||
personality: str = "",
|
||||
scenario: str = "",
|
||||
first_mes: str = "",
|
||||
@@ -138,19 +140,19 @@ async def create_persona(
|
||||
await db.execute(
|
||||
"""INSERT INTO personas
|
||||
(persona_id, name, emoji, description, prompt, custom,
|
||||
sd_enabled, lora_name, lora_weight, appearance_tags,
|
||||
sd_enabled, lora_name, lora_weight, appearance_tags, appearance_prose,
|
||||
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
|
||||
alternate_greetings_json)
|
||||
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
persona_id, name, emoji, description, final_prompt,
|
||||
1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags,
|
||||
1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags, appearance_prose,
|
||||
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
|
||||
alternate_greetings_json,
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
return {
|
||||
return {
|
||||
"name": name,
|
||||
"emoji": emoji,
|
||||
"description": description,
|
||||
@@ -160,6 +162,7 @@ async def create_persona(
|
||||
"lora_name": lora_name,
|
||||
"lora_weight": lora_weight,
|
||||
"appearance_tags": appearance_tags,
|
||||
"appearance_prose": appearance_prose,
|
||||
"personality": personality,
|
||||
"scenario": scenario,
|
||||
"first_mes": first_mes,
|
||||
@@ -226,6 +229,7 @@ async def patch_persona(persona_id: str, fields: dict) -> bool:
|
||||
"lora_name",
|
||||
"lora_weight",
|
||||
"appearance_tags",
|
||||
"appearance_prose",
|
||||
"personality",
|
||||
"scenario",
|
||||
"first_mes",
|
||||
@@ -255,6 +259,19 @@ async def patch_persona(persona_id: str, fields: dict) -> bool:
|
||||
merged = dict(existing)
|
||||
merged.update(updates)
|
||||
updates["prompt"] = build_persona_prompt(merged)
|
||||
|
||||
if "appearance_tags" in updates and "appearance_prose" not in updates:
|
||||
tags = updates["appearance_tags"].strip()
|
||||
if tags:
|
||||
from services.llm import send_message
|
||||
try:
|
||||
prose = await send_message([
|
||||
{"role": "system", "content": "Convert danbooru tags to natural English description. Output only the description, no markdown."},
|
||||
{"role": "user", "content": f"Tags: {tags}"}
|
||||
])
|
||||
updates["appearance_prose"] = prose.strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cols = ", ".join(f"{k} = ?" for k in updates)
|
||||
cur2 = await db.execute(
|
||||
|
||||
+17
-2
@@ -1,7 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from services.llm import send_message_with_model, send_message
|
||||
from services.llm import LLMError, send_message_with_model, send_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FACTS_MODEL = os.getenv("RPG_FACTS_MODEL", "").strip() or "deepseek/deepseek-chat-v3"
|
||||
|
||||
@@ -51,7 +54,19 @@ async def extract_facts(context_messages: list[dict]) -> list[str]:
|
||||
{"role": "user", "content": transcript},
|
||||
]
|
||||
|
||||
raw = await (send_message_with_model(messages, FACTS_MODEL) if FACTS_MODEL else send_message(messages))
|
||||
try:
|
||||
raw = await (
|
||||
send_message_with_model(messages, FACTS_MODEL)
|
||||
if FACTS_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("extract_facts LLM failed (model=%s): %s", FACTS_MODEL or "SYSTEM", e)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning("extract_facts unexpected error: %s", e)
|
||||
return []
|
||||
|
||||
try:
|
||||
data = json.loads(raw.strip())
|
||||
if isinstance(data, list):
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
import random
|
||||
|
||||
from services.llm import send_message_with_model
|
||||
from services.llm import LLMError, send_message_with_model
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,10 +63,18 @@ async def narrator_pre(
|
||||
f"Facts:\n{facts_block}\n\n"
|
||||
f"Recent context:\n{context}\n"
|
||||
)
|
||||
raw = await send_message_with_model(
|
||||
[{"role": "system", "content": NARRATOR_PRE_SYSTEM}, {"role": "user", "content": user}],
|
||||
NARRATOR_MODEL,
|
||||
)
|
||||
try:
|
||||
raw = await send_message_with_model(
|
||||
[{"role": "system", "content": NARRATOR_PRE_SYSTEM}, {"role": "user", "content": user}],
|
||||
NARRATOR_MODEL,
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("Narrator-pre LLM failed (model=%s): %s", NARRATOR_MODEL, e)
|
||||
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": ""}
|
||||
except Exception as e:
|
||||
logger.warning("Narrator-pre unexpected error: %s", e)
|
||||
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": ""}
|
||||
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
@@ -87,17 +95,35 @@ async def narrator_post(
|
||||
context: str,
|
||||
global_plot: str,
|
||||
facts_block: str,
|
||||
is_opening: bool = False,
|
||||
) -> dict:
|
||||
opening_block = ""
|
||||
if is_opening:
|
||||
opening_block = (
|
||||
"\n\nOPENING SCENE: This is the first greeting, not a mid-conversation reply. "
|
||||
"Extract the character's INITIAL visible clothing from the greeting into outfit_update "
|
||||
"(danbooru underscore tags), even if clothing did not change during the scene. "
|
||||
"Set status_quo to describe the opening situation.\n"
|
||||
)
|
||||
user = (
|
||||
f"Persona: {persona_name}\n\n"
|
||||
f"Global plot:\n{global_plot}\n\n"
|
||||
f"Facts:\n{facts_block}\n\n"
|
||||
f"Recent context:\n{context}\n"
|
||||
f"{opening_block}"
|
||||
)
|
||||
raw = await send_message_with_model(
|
||||
[{"role": "system", "content": NARRATOR_POST_SYSTEM}, {"role": "user", "content": user}],
|
||||
NARRATOR_MODEL,
|
||||
)
|
||||
try:
|
||||
raw = await send_message_with_model(
|
||||
[{"role": "system", "content": NARRATOR_POST_SYSTEM}, {"role": "user", "content": user}],
|
||||
NARRATOR_MODEL,
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("Narrator-post LLM failed (model=%s): %s", NARRATOR_MODEL, e)
|
||||
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": []}
|
||||
except Exception as e:
|
||||
logger.warning("Narrator-post unexpected error: %s", e)
|
||||
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": []}
|
||||
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
|
||||
+14
-2
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from services.llm import send_message_with_model, send_message
|
||||
from services.llm import LLMError, send_message_with_model, send_message
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,7 +63,19 @@ async def generate_plot_arc(persona_name: str, persona_desc: str, persona_scenar
|
||||
{"role": "system", "content": ARC_SYSTEM},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
raw = await (send_message_with_model(messages, PLOT_MODEL) if PLOT_MODEL else send_message(messages))
|
||||
try:
|
||||
raw = await (
|
||||
send_message_with_model(messages, PLOT_MODEL)
|
||||
if PLOT_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("generate_plot_arc LLM failed (model=%s): %s", PLOT_MODEL or "SYSTEM", e)
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning("generate_plot_arc unexpected error: %s", e)
|
||||
return {}
|
||||
|
||||
cleaned = raw.strip()
|
||||
# common OpenRouter formatting: fenced json
|
||||
if cleaned.startswith("```"):
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Run ComfyUI generation from SdPromptBundle (single hybrid prompt for Anima)."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from services import sdbackend as sd_service
|
||||
from services.memory import update_message_image, update_message_prompt, update_message_prompt_alt
|
||||
from services.sd_prompt import SdPromptBundle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
async def run_sd_for_message(bundle: SdPromptBundle | None, msg_id: int | None) -> dict:
|
||||
"""Generate image, persist prompts/paths on message. Returns fields for API/SSE."""
|
||||
out = {
|
||||
"image_prompt": None,
|
||||
"image_prompt_alt": None,
|
||||
"image_path": None,
|
||||
"image_path_alt": None,
|
||||
"image_error": None,
|
||||
"image_error_alt": None,
|
||||
}
|
||||
if not bundle or not bundle.tag_full:
|
||||
return out
|
||||
|
||||
out["image_prompt"] = bundle.tag_full
|
||||
if bundle.desc_full and bundle.desc_full != bundle.tag_full:
|
||||
out["image_prompt_alt"] = bundle.desc_full
|
||||
|
||||
if msg_id:
|
||||
await update_message_prompt(msg_id, bundle.tag_full)
|
||||
if out["image_prompt_alt"]:
|
||||
await update_message_prompt_alt(msg_id, out["image_prompt_alt"])
|
||||
|
||||
if not SD_AUTO_GENERATE:
|
||||
return out
|
||||
|
||||
rel, err = await sd_service.generate_from_full_prompt(bundle.tag_full)
|
||||
if rel:
|
||||
out["image_path"] = f"/static/{rel}"
|
||||
if msg_id:
|
||||
await update_message_image(msg_id, rel)
|
||||
else:
|
||||
out["image_error"] = err
|
||||
|
||||
return out
|
||||
+641
-67
@@ -2,26 +2,115 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from services.llm import send_message, send_message_with_model
|
||||
from services.personas import get_persona
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NEGATIVE_PROMPT_SEPARATOR = "\n\n__NEGATIVE_PROMPT__\n\n"
|
||||
|
||||
PROMPT_BUILDER_SYSTEM = """You are a Stable Diffusion prompt engineer for anime illustration models.
|
||||
Given a roleplay chat excerpt, output ONLY valid JSON (no markdown):
|
||||
{
|
||||
"should_generate": true,
|
||||
"shot_type": "first_person_pov" | "landscape" | "third_person",
|
||||
"action_tags": "booru-style tags for pose/action/expression, e.g. 'sitting, smiling, holding_cup'",
|
||||
"environment_tags": "booru-style tags for location/lighting/time, e.g. 'indoors, kitchen, sunlight, daytime'"
|
||||
"action_tags": "booru-style tags for pose/action/expression",
|
||||
"environment_tags": "booru-style tags for location/lighting/time"
|
||||
}
|
||||
Rules:
|
||||
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined: 'fox_ears' not 'fox ears'.
|
||||
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined.
|
||||
- Do NOT include appearance/character tags — those are provided separately.
|
||||
- Do NOT include quality tags, model names, style words, 'pov', or category/metadata words.
|
||||
- Do NOT invent tags. If unsure — omit.
|
||||
- Keep each field to 3-6 tags."""
|
||||
- Keep action_tags and environment_tags to 3-6 tags each.
|
||||
- shot_type: default "first_person_pov" for dialogue/intimacy at arm's length. "third_person" only for wide action (fight, chase). "landscape" only when environment is the focus.
|
||||
- should_generate: false for non-visual beats (pure internal monologue, time skips with no new pose, empty lines).
|
||||
- NEVER use negative words in tag fields (not, without, naked, nsfw, etc.)."""
|
||||
|
||||
ANIMA_BUILDER_EXTRA = """
|
||||
Anima hybrid mode — ALSO include:
|
||||
"pov_cue": "face_to_face" | "walking_together" | "doorway_invite" | "reach_to_viewer" | "dialogue_close",
|
||||
"viewer_body_visible": false,
|
||||
"scene_description": "ONE short English sentence (max 40 words). Camera POV: what the viewer sees. Mood/atmosphere only — do NOT repeat tags from action_tags/environment_tags. Do NOT list comma-separated booru tags."
|
||||
POV / interaction rules:
|
||||
- Default viewer_body_visible: false. The viewer's body, hands, or face must NOT appear in the image — only the character toward the camera.
|
||||
- For hugs, embraces: use arms_out, reaching_towards_viewer, inviting_hug — NOT holding_hands, lifting, carrying, nose_rub (these draw a second body in POV).
|
||||
- For long messages with time skips ("About an hour later..."), illustrate ONLY the final visible beat (usually the last paragraph).
|
||||
- scene_description: describe HER toward the camera only — NEVER "someone", "both", "with you", "hand in hand with", or another person's body.
|
||||
- NEVER use tags: looking_at_each_other, couple, 2girls, 2boys, multiple_girls. For POV walking together omit holding_hands (use walking, smiling, reaching_towards_viewer instead).
|
||||
- pov_cue: pick the framing that matches the CURRENT beat (walking_together for strolling side by side, doorway_invite for doorway with arms open, reach_to_viewer when she reaches toward camera, face_to_face for close dialogue).
|
||||
- Illustrate ONLY the beat under === ILLUSTRATE ===; use === Context === for outfit/location hints only.
|
||||
- Do NOT put English sentences in action_tags or environment_tags — tags only."""
|
||||
|
||||
POV_CUE_PHRASES: dict[str, str] = {
|
||||
"face_to_face": "POV: close face-to-face, she looks directly at you",
|
||||
"walking_together": "POV: walking beside you, profile and shared path visible",
|
||||
"doorway_invite": "POV: she blocks the doorway, arms open toward you",
|
||||
"reach_to_viewer": "POV: she reaches toward the camera",
|
||||
"dialogue_close": "POV: close conversation, she faces you at arm's length",
|
||||
}
|
||||
|
||||
POV_CUE_DEFAULT = "POV: she stands before you, facing the camera"
|
||||
|
||||
POV_INTERACTION_NEGATIVE = (
|
||||
"duplicate, clone, multiple_girls, 2girls, extra_person, pov hands, "
|
||||
"disembodied hands, extra arms, second person"
|
||||
)
|
||||
|
||||
_CONTACT_ACTION_KEYWORDS = (
|
||||
"hug", "holding_hands", "hand_holding", "arms_out", "embrace",
|
||||
"reaching", "inviting_hug", "arm_around", "cuddling",
|
||||
)
|
||||
|
||||
_JUNK_STANDALONE_TAGS = frozenset({
|
||||
"white", "black", "skin", "ear", "ears", "girl", "boy", "fox", "wolf", "cat",
|
||||
"short", "tall", "slim", "golden", "silver", "red", "blue", "green", "purple",
|
||||
"pink", "brown", "blonde", "eye", "eyes", "hair",
|
||||
})
|
||||
|
||||
_INVALID_TAGS = frozenset({
|
||||
"pumped_up", "pumped", "looking_at_each_other", "couple",
|
||||
"2girls", "2boys", "multiple_girls", "multiple_boys", "duo",
|
||||
})
|
||||
|
||||
_POV_DROP_ACTION_TAGS = frozenset({
|
||||
"holding_hands", "hand_holding", "looking_at_each_other", "couple",
|
||||
"lifting", "carry", "carrying", "princess_carry", "nose_rub", "nose_boop",
|
||||
})
|
||||
|
||||
_TIME_SKIP_RE = re.compile(
|
||||
r"(?i)\b(?:about an hour later|hours later|later that (?:day|evening|night)|"
|
||||
r"the next (?:day|morning|evening)|meanwhile|after (?:some )?time)\b[.…\s]*",
|
||||
)
|
||||
|
||||
_POV_MOOD_FALLBACK: dict[str, str] = {
|
||||
"walking_together": "Easy warmth and quiet laughter in the afternoon light.",
|
||||
"doorway_invite": "Cool air and playful tension as she waits in the doorway.",
|
||||
"reach_to_viewer": "A charged moment as she reaches toward the camera.",
|
||||
"face_to_face": "Her expression softens in close focus toward the camera.",
|
||||
"dialogue_close": "Intimate calm in the space between you.",
|
||||
}
|
||||
|
||||
_INDOOR_ENV_MARKERS = frozenset({"doorway", "indoors", "indoor", "apartment", "inside", "room"})
|
||||
_OUTDOOR_ENV_MARKERS = frozenset({"outdoor", "outdoors", "outside", "street"})
|
||||
|
||||
_POV_PROSE_BANNED = re.compile(
|
||||
r"\b(someone|both|together with|hand in hand with|another person|second person|"
|
||||
r"your hands|your fingers|your embrace|your heat|intertwined|with you|"
|
||||
r"demands your|before you)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
SD_ANIMA_DUAL_COMPARE = os.getenv("SD_ANIMA_DUAL_COMPARE", "false").lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SdPromptBundle:
|
||||
tag_full: str
|
||||
negative: str
|
||||
desc_full: str | None = None
|
||||
|
||||
|
||||
def extract_image_prompt_tag(text: str) -> str | None:
|
||||
@@ -44,7 +133,7 @@ SD_UNET = os.getenv("SD_UNET", "")
|
||||
SD_PROMPT_MODEL = os.getenv("SD_PROMPT_MODEL", "").strip()
|
||||
|
||||
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}
|
||||
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
|
||||
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
|
||||
ANIMA_NEGATIVE = "worst quality, low quality, score_1, score_2, score_3, blurry, jpeg artifacts, sepia"
|
||||
|
||||
|
||||
@@ -56,37 +145,201 @@ def _is_anima() -> bool:
|
||||
return bool(SD_UNET) and not SD_CHECKPOINT
|
||||
|
||||
|
||||
def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
def anima_dual_enabled() -> bool:
|
||||
return _is_anima() and SD_ANIMA_DUAL_COMPARE
|
||||
|
||||
|
||||
def _builder_system() -> str:
|
||||
if _is_anima():
|
||||
return PROMPT_BUILDER_SYSTEM + ANIMA_BUILDER_EXTRA
|
||||
return PROMPT_BUILDER_SYSTEM
|
||||
|
||||
|
||||
def _normalize_shot_type(scene: dict) -> dict:
|
||||
st = (scene.get("shot_type") or "").strip().lower()
|
||||
if st == "landscape":
|
||||
scene["shot_type"] = "landscape"
|
||||
return _sanitize_scene_fields(scene)
|
||||
if st == "third_person":
|
||||
action = (scene.get("action_tags") or "").lower()
|
||||
wide = ("battle", "fight", "chase", "running", "crowd", "wide_shot", "group_shot")
|
||||
if any(w in action for w in wide):
|
||||
scene["shot_type"] = "third_person"
|
||||
return _sanitize_scene_fields(scene)
|
||||
scene["shot_type"] = "first_person_pov"
|
||||
if scene.get("viewer_body_visible") is None:
|
||||
scene["viewer_body_visible"] = False
|
||||
return _sanitize_scene_fields(scene)
|
||||
|
||||
|
||||
def _split_tag_input(tag_str: str) -> list[str]:
|
||||
return [t.strip() for t in (tag_str or "").split(",") if t.strip()]
|
||||
|
||||
|
||||
def _is_sentence_like_tag(tag: str) -> bool:
|
||||
t = tag.strip()
|
||||
if len(t) > 45:
|
||||
return True
|
||||
if re.search(r"[.!?]", t):
|
||||
return True
|
||||
words = t.split()
|
||||
return len(words) >= 5 and "_" not in t
|
||||
|
||||
|
||||
def _filter_tag_field(tag_str: str, *, for_pov: bool, field: str) -> str:
|
||||
kept: list[str] = []
|
||||
for raw in _split_tag_input(tag_str):
|
||||
key = raw.lower().replace(" ", "_")
|
||||
if key in _INVALID_TAGS:
|
||||
continue
|
||||
if _is_sentence_like_tag(raw):
|
||||
continue
|
||||
if for_pov and field == "action" and key in _POV_DROP_ACTION_TAGS:
|
||||
continue
|
||||
kept.append(raw if "_" in raw else key)
|
||||
return ", ".join(kept)
|
||||
|
||||
|
||||
def _reconcile_environment_tags(env_str: str) -> str:
|
||||
tags = _split_tag_input(env_str)
|
||||
keys = {t.lower().replace(" ", "_") for t in tags}
|
||||
has_indoor = bool(keys & _INDOOR_ENV_MARKERS) or any(
|
||||
any(m in k for m in _INDOOR_ENV_MARKERS) for k in keys
|
||||
)
|
||||
has_outdoor = bool(keys & _OUTDOOR_ENV_MARKERS) or any(
|
||||
any(m in k for m in _OUTDOOR_ENV_MARKERS) for k in keys
|
||||
)
|
||||
if has_indoor and has_outdoor:
|
||||
tags = [t for t in tags if t.lower().replace(" ", "_") not in _OUTDOOR_ENV_MARKERS]
|
||||
return ", ".join(tags)
|
||||
|
||||
|
||||
def _sanitize_pov_prose(desc: str, scene: dict) -> str:
|
||||
if not desc or not desc.strip():
|
||||
return ""
|
||||
if scene.get("shot_type") != "first_person_pov":
|
||||
return desc.strip()
|
||||
|
||||
kept: list[str] = []
|
||||
for sentence in re.split(r"(?<=[.!?])\s+", desc.strip()):
|
||||
s = sentence.strip()
|
||||
if not s:
|
||||
continue
|
||||
if _POV_PROSE_BANNED.search(s):
|
||||
continue
|
||||
if re.search(r"\bwolfgirl\b", s, re.I) and re.search(
|
||||
r"\b(walks|walking|stands)\b", s, re.I
|
||||
):
|
||||
continue
|
||||
kept.append(s)
|
||||
out = " ".join(kept).strip()
|
||||
return re.sub(r"\bat the viewer\b", "at the camera", out, flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def _sanitize_scene_fields(scene: dict) -> dict:
|
||||
scene = dict(scene)
|
||||
for_pov = scene.get("shot_type") == "first_person_pov"
|
||||
scene["action_tags"] = _filter_tag_field(
|
||||
scene.get("action_tags") or "", for_pov=for_pov, field="action"
|
||||
)
|
||||
env = _filter_tag_field(scene.get("environment_tags") or "", for_pov=False, field="env")
|
||||
scene["environment_tags"] = _reconcile_environment_tags(env)
|
||||
scene["scene_description"] = _sanitize_pov_prose(
|
||||
(scene.get("scene_description") or "").strip(), scene
|
||||
)
|
||||
return scene
|
||||
|
||||
|
||||
def _scene_should_generate(scene: dict) -> bool:
|
||||
if scene.get("should_generate") is False:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _sanitize_tags_string(tag_str: str) -> str:
|
||||
if not tag_str:
|
||||
return ""
|
||||
out: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for raw in tag_str.split(","):
|
||||
t = raw.strip()
|
||||
if not t:
|
||||
continue
|
||||
key = t.lower().replace(" ", "_")
|
||||
if key in seen:
|
||||
continue
|
||||
if key in _INVALID_TAGS:
|
||||
continue
|
||||
if "_" not in key and key in _JUNK_STANDALONE_TAGS:
|
||||
continue
|
||||
if len(key) <= 2:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(t if "_" in t else key)
|
||||
return ", ".join(out)
|
||||
|
||||
|
||||
def _quality_prefix() -> str:
|
||||
if _is_pony():
|
||||
quality = "score_9, score_8_up, score_7_up, source_anime, highres"
|
||||
elif _is_anima():
|
||||
quality = "masterpiece, best quality, score_7, anime"
|
||||
else:
|
||||
quality = "masterpiece, best quality, highres"
|
||||
return "score_9, score_8_up, score_7_up, source_anime, highres"
|
||||
if _is_anima():
|
||||
return "masterpiece, best quality, score_7, anime"
|
||||
return "masterpiece, best quality, highres"
|
||||
|
||||
parts = [quality]
|
||||
|
||||
appearance = (persona or {}).get("appearance_tags", "")
|
||||
if appearance:
|
||||
parts.append(appearance)
|
||||
if outfit_tags:
|
||||
parts.append(outfit_tags)
|
||||
def _appearance_for_persona(persona: dict | None) -> str:
|
||||
"""Tag core uses appearance_tags only (prose is for LLM context, not Comfy tag line)."""
|
||||
return _sanitize_tags_string((persona or {}).get("appearance_tags", ""))
|
||||
|
||||
if scene.get("shot_type") == "landscape":
|
||||
parts.append(scene.get("environment_tags", ""))
|
||||
else:
|
||||
if scene.get("shot_type") == "first_person_pov":
|
||||
parts.append("pov, first-person view, looking at viewer")
|
||||
parts.append(scene.get("action_tags", ""))
|
||||
parts.append(scene.get("environment_tags", ""))
|
||||
|
||||
def _dedupe_outfit_tags(outfit_tags: str) -> str:
|
||||
tags = _split_tag_input(outfit_tags)
|
||||
keys = {t.lower().replace(" ", "_") for t in tags}
|
||||
if len(keys & {"jeans", "ripped_jeans", "black_jeans"}) > 1 and "jeans" in keys:
|
||||
tags = [t for t in tags if t.lower().replace(" ", "_") != "jeans"]
|
||||
return ", ".join(tags)
|
||||
|
||||
|
||||
def _scene_has_physical_contact(scene: dict) -> bool:
|
||||
action = (scene.get("action_tags") or "").lower()
|
||||
return any(k in action for k in _CONTACT_ACTION_KEYWORDS)
|
||||
|
||||
|
||||
def _infer_pov_cue_from_action(action_tags: str) -> str:
|
||||
action = (action_tags or "").lower()
|
||||
if any(k in action for k in ("holding_hands", "hand_holding", "walking", "strolling")):
|
||||
return "walking_together"
|
||||
if any(k in action for k in ("doorway", "door", "entry", "threshold")):
|
||||
if any(k in action for k in ("arms_out", "hug", "embrace", "inviting")):
|
||||
return "doorway_invite"
|
||||
if any(k in action for k in ("arms_out", "reaching", "inviting_hug", "hug", "embrace")):
|
||||
return "reach_to_viewer"
|
||||
if any(k in action for k in ("sitting", "lying", "bed")):
|
||||
return "dialogue_close"
|
||||
return "face_to_face"
|
||||
|
||||
|
||||
def _build_pov_phrase(scene: dict) -> str:
|
||||
if scene.get("shot_type") != "first_person_pov":
|
||||
return ""
|
||||
cue = (scene.get("pov_cue") or "").strip().lower().replace("-", "_").replace(" ", "_")
|
||||
if cue in POV_CUE_PHRASES:
|
||||
return POV_CUE_PHRASES[cue]
|
||||
inferred = _infer_pov_cue_from_action(scene.get("action_tags", ""))
|
||||
return POV_CUE_PHRASES.get(inferred, POV_CUE_DEFAULT)
|
||||
|
||||
|
||||
def _append_lora(parts: list[str], persona: dict | None) -> None:
|
||||
lora = (persona or {}).get("lora_name", "")
|
||||
weight = (persona or {}).get("lora_weight", 0.8)
|
||||
if lora:
|
||||
parts.append(f"<lora:{lora}:{weight}>")
|
||||
|
||||
|
||||
def _dedupe_comma_join(parts: list[str]) -> str:
|
||||
positive = ", ".join(p.strip() for p in parts if p and p.strip())
|
||||
seen, deduped = set(), []
|
||||
seen: set[str] = set()
|
||||
deduped: list[str] = []
|
||||
for tag in positive.split(", "):
|
||||
t = tag.strip()
|
||||
if t and t not in seen:
|
||||
@@ -95,53 +348,152 @@ def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str =
|
||||
return ", ".join(deduped)
|
||||
|
||||
|
||||
async def generate_sd_prompt(
|
||||
messages: list,
|
||||
persona_id: str,
|
||||
outfit_json: str = "[]",
|
||||
) -> tuple[str | None, str | None]:
|
||||
persona = await get_persona(persona_id)
|
||||
# Generate only if persona has appearance tags
|
||||
if not persona or not (persona.get("appearance_tags") or "").strip():
|
||||
logger.debug("sd_prompt skip: persona=%s no appearance_tags", persona_id)
|
||||
return None, None
|
||||
def _build_tag_core(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
"""Anchor + structure: quality, appearance, outfit, action/env tags, LoRA. No POV prose, no scene_description."""
|
||||
parts = [_quality_prefix()]
|
||||
appearance = _appearance_for_persona(persona)
|
||||
if appearance:
|
||||
parts.append(appearance)
|
||||
if outfit_tags:
|
||||
parts.append(_sanitize_tags_string(_dedupe_outfit_tags(outfit_tags)))
|
||||
if scene.get("shot_type") == "landscape":
|
||||
parts.append(_sanitize_tags_string(scene.get("environment_tags", "")))
|
||||
else:
|
||||
if not _is_anima() and scene.get("shot_type") == "first_person_pov":
|
||||
parts.append("pov, first-person view, looking at viewer")
|
||||
parts.append(_sanitize_tags_string(scene.get("action_tags", "")))
|
||||
parts.append(_sanitize_tags_string(scene.get("environment_tags", "")))
|
||||
_append_lora(parts, persona)
|
||||
return _dedupe_comma_join(parts)
|
||||
|
||||
recent = [m for m in messages if m["role"] in ("user", "assistant")][-6:]
|
||||
if not recent:
|
||||
return None, None
|
||||
|
||||
excerpt = "\n".join(f"{m['role']}: {strip_image_prompt_tag(m['content'])}" for m in recent)
|
||||
def build_positive_prompt_tags_only(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
"""Tags + contextual POV phrase (Anima) or legacy Pony path."""
|
||||
if not _is_anima():
|
||||
return build_positive_prompt(scene, persona, outfit_tags)
|
||||
core = _build_tag_core(scene, persona, outfit_tags)
|
||||
pov = _build_pov_phrase(scene)
|
||||
if pov:
|
||||
return f"{core}, {pov}" if core else pov
|
||||
return core
|
||||
|
||||
builder_messages = [
|
||||
{"role": "system", "content": PROMPT_BUILDER_SYSTEM},
|
||||
{"role": "user", "content": f"Chat:\n{excerpt}"},
|
||||
]
|
||||
|
||||
try:
|
||||
if SD_PROMPT_MODEL:
|
||||
raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
|
||||
else:
|
||||
raw = await send_message(builder_messages)
|
||||
raw = raw.strip()
|
||||
if raw.startswith("```"):
|
||||
raw = re.sub(r"^```\w*\n?", "", raw)
|
||||
raw = re.sub(r"\n?```$", "", raw)
|
||||
scene = json.loads(raw)
|
||||
if not isinstance(scene, dict):
|
||||
logger.warning("sd_prompt: LLM returned non-dict: %.100s", raw)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.warning("sd_prompt failed: %s raw=%.200s", e, locals().get("raw", ""))
|
||||
return None, None
|
||||
def _tag_tokens_for_dedupe(tag_line: str) -> set[str]:
|
||||
tokens: set[str] = set()
|
||||
for part in tag_line.replace("<lora:", " ").split(","):
|
||||
for word in re.split(r"[\s_./]+", part.lower()):
|
||||
w = word.strip()
|
||||
if len(w) >= 4:
|
||||
tokens.add(w)
|
||||
return tokens
|
||||
|
||||
try:
|
||||
outfit_list = json.loads(outfit_json or "[]")
|
||||
outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
|
||||
except Exception:
|
||||
outfit_tags = ""
|
||||
|
||||
positive = build_positive_prompt(scene, persona, outfit_tags)
|
||||
def _trim_redundant_scene_description(desc: str, tag_line: str) -> str:
|
||||
tag_tokens = _tag_tokens_for_dedupe(tag_line)
|
||||
if not tag_tokens or not desc.strip():
|
||||
return desc.strip()
|
||||
|
||||
kept: list[str] = []
|
||||
for sentence in re.split(r"(?<=[.!?])\s+", desc.strip()):
|
||||
s = sentence.strip()
|
||||
if not s:
|
||||
continue
|
||||
words = [w.lower() for w in re.findall(r"[a-zA-Z]{4,}", s)]
|
||||
if not words:
|
||||
kept.append(s)
|
||||
continue
|
||||
overlap = sum(1 for w in words if w in tag_tokens) / len(words)
|
||||
if overlap < 0.62:
|
||||
kept.append(s)
|
||||
|
||||
return " ".join(kept).strip()
|
||||
|
||||
|
||||
def _extract_illustrate_content(content: str, max_chars: int = 1400) -> str:
|
||||
"""Long assistant posts (first_mes): use final beat after time-skip, last paragraphs."""
|
||||
text = strip_image_prompt_tag(content).strip()
|
||||
if not text:
|
||||
return ""
|
||||
chunks = _TIME_SKIP_RE.split(text)
|
||||
if len(chunks) > 1:
|
||||
text = chunks[-1].strip()
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
|
||||
if paragraphs:
|
||||
for n in (1, 2, 3):
|
||||
tail = "\n\n".join(paragraphs[-n:])
|
||||
if len(tail) <= max_chars:
|
||||
return tail
|
||||
return paragraphs[-1][-max_chars:]
|
||||
return text[-max_chars:]
|
||||
|
||||
|
||||
def _fallback_mood_prose(scene: dict) -> str:
|
||||
cue = (scene.get("pov_cue") or "").strip().lower().replace("-", "_").replace(" ", "_")
|
||||
if cue in _POV_MOOD_FALLBACK:
|
||||
return _POV_MOOD_FALLBACK[cue]
|
||||
inferred = _infer_pov_cue_from_action(scene.get("action_tags", ""))
|
||||
return _POV_MOOD_FALLBACK.get(inferred, "Soft atmosphere; her expression toward the camera.")
|
||||
|
||||
|
||||
def _cap_scene_description(desc: str, max_words: int = 40, max_chars: int = 220) -> str:
|
||||
words = desc.split()
|
||||
if len(words) > max_words:
|
||||
desc = " ".join(words[:max_words])
|
||||
if len(desc) > max_chars:
|
||||
desc = desc[: max_chars - 3] + "..."
|
||||
return desc
|
||||
|
||||
|
||||
def build_positive_prompt_hybrid(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
"""Production Anima prompt: tag core + POV cue + short mood prose."""
|
||||
if not _is_anima():
|
||||
return build_positive_prompt(scene, persona, outfit_tags)
|
||||
|
||||
base = build_positive_prompt_tags_only(scene, persona, outfit_tags)
|
||||
desc = _trim_redundant_scene_description(
|
||||
(scene.get("scene_description") or "").strip(),
|
||||
base,
|
||||
)
|
||||
desc = _cap_scene_description(desc)
|
||||
if not desc:
|
||||
desc = _cap_scene_description(_fallback_mood_prose(scene))
|
||||
if not desc:
|
||||
return base
|
||||
|
||||
lora = (persona or {}).get("lora_name", "")
|
||||
weight = (persona or {}).get("lora_weight", 0.8)
|
||||
lora_suffix = f" <lora:{lora}:{weight}>" if lora else ""
|
||||
if lora_suffix and base.endswith(lora_suffix):
|
||||
base = base[: -len(lora_suffix)]
|
||||
return f"{base}. {desc}{lora_suffix}"
|
||||
return f"{base}. {desc}"
|
||||
|
||||
|
||||
def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
"""Legacy entry: Pony/non-Anima full prompt; Anima delegates to tags-only."""
|
||||
if _is_anima():
|
||||
return build_positive_prompt_tags_only(scene, persona, outfit_tags)
|
||||
|
||||
parts = [_quality_prefix()]
|
||||
appearance = _appearance_for_persona(persona)
|
||||
if appearance:
|
||||
parts.append(appearance)
|
||||
if outfit_tags:
|
||||
parts.append(_sanitize_tags_string(_dedupe_outfit_tags(outfit_tags)))
|
||||
if scene.get("shot_type") == "landscape":
|
||||
parts.append(_sanitize_tags_string(scene.get("environment_tags", "")))
|
||||
else:
|
||||
if scene.get("shot_type") == "first_person_pov":
|
||||
parts.append("pov, first-person view, looking at viewer")
|
||||
parts.append(_sanitize_tags_string(scene.get("action_tags", "")))
|
||||
parts.append(_sanitize_tags_string(scene.get("environment_tags", "")))
|
||||
_append_lora(parts, persona)
|
||||
return _dedupe_comma_join(parts)
|
||||
|
||||
|
||||
def _negative_for_scene(scene: dict) -> str:
|
||||
if _is_pony():
|
||||
negative = PONY_NEGATIVE
|
||||
elif _is_anima():
|
||||
@@ -151,6 +503,228 @@ async def generate_sd_prompt(
|
||||
|
||||
if scene.get("shot_type") == "first_person_pov":
|
||||
negative += ", third person, over the shoulder"
|
||||
viewer_visible = scene.get("viewer_body_visible") is True
|
||||
if not viewer_visible or _scene_has_physical_contact(scene):
|
||||
negative += ", " + POV_INTERACTION_NEGATIVE
|
||||
|
||||
full = positive + f"\n\nNegative prompt: {negative}"
|
||||
return full, negative
|
||||
return negative
|
||||
|
||||
|
||||
def _format_builder_user_block(persona: dict, messages: list[dict], outfit_json: str) -> str:
|
||||
lines: list[str] = []
|
||||
tags = (persona.get("appearance_tags") or "").strip()
|
||||
lines.append(f"Character appearance (tags): {tags}")
|
||||
prose = (persona.get("appearance_prose") or "").strip()
|
||||
if _is_anima() and prose and prose != tags:
|
||||
snippet = prose[:300] + ("..." if len(prose) > 300 else "")
|
||||
lines.append(f"Character notes (do not copy into tags or scene_description): {snippet}")
|
||||
|
||||
try:
|
||||
outfit_list = json.loads(outfit_json or "[]")
|
||||
outfit_ref = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
|
||||
except Exception:
|
||||
outfit_ref = ""
|
||||
|
||||
if outfit_ref:
|
||||
lines.append(f"Current outfit (tags): {outfit_ref}")
|
||||
|
||||
recent = [m for m in messages if m.get("role") in ("user", "assistant")][-6:]
|
||||
if not recent:
|
||||
lines.append("\nChat:\n(no messages — return should_generate=false)")
|
||||
return "\n".join(lines)
|
||||
|
||||
illustrate: list[dict] = []
|
||||
if recent[-1]["role"] == "assistant":
|
||||
illustrate = [recent[-1]]
|
||||
if len(recent) >= 2 and recent[-2]["role"] == "user":
|
||||
illustrate.insert(0, recent[-2])
|
||||
else:
|
||||
illustrate = [recent[-1]]
|
||||
if len(recent) >= 2 and recent[-2]["role"] == "assistant":
|
||||
illustrate.insert(0, recent[-2])
|
||||
|
||||
context = [m for m in recent if m not in illustrate]
|
||||
|
||||
lines.append("\n=== ILLUSTRATE (draw THIS beat only) ===")
|
||||
for m in illustrate:
|
||||
raw = m.get("content", "")
|
||||
content = _extract_illustrate_content(raw) if m.get("role") == "assistant" else strip_image_prompt_tag(raw)
|
||||
lines.append(f"{m['role']}: {content}")
|
||||
|
||||
if context:
|
||||
lines.append("\n=== Context (outfit/location hints only — do not illustrate old beats) ===")
|
||||
for m in context:
|
||||
content = strip_image_prompt_tag(m.get("content", ""))
|
||||
if len(content) > 800:
|
||||
content = content[:797] + "..."
|
||||
lines.append(f"{m['role']}: {content}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _parse_scene_json(raw: str) -> dict:
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = re.sub(r"^```\w*\n?", "", cleaned)
|
||||
cleaned = re.sub(r"\n?```$", "", cleaned)
|
||||
scene = json.loads(cleaned)
|
||||
if not isinstance(scene, dict):
|
||||
raise ValueError("LLM returned non-object JSON")
|
||||
return _normalize_shot_type(scene)
|
||||
|
||||
|
||||
def _bundle_from_scene(scene: dict, persona: dict, outfit_tags: str) -> SdPromptBundle:
|
||||
negative = _negative_for_scene(scene)
|
||||
if _is_anima():
|
||||
hybrid = build_positive_prompt_hybrid(scene, persona, outfit_tags)
|
||||
tag_full = hybrid + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
desc_full = None
|
||||
if anima_dual_enabled():
|
||||
tags_only = build_positive_prompt_tags_only(scene, persona, outfit_tags)
|
||||
desc_full = tags_only + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
return SdPromptBundle(tag_full=tag_full, negative=negative, desc_full=desc_full)
|
||||
|
||||
positive = build_positive_prompt(scene, persona, outfit_tags)
|
||||
tag_full = positive + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
return SdPromptBundle(tag_full=tag_full, negative=negative, desc_full=None)
|
||||
|
||||
|
||||
def _parse_chat_excerpt(excerpt: str) -> list[dict]:
|
||||
messages: list[dict] = []
|
||||
for line in (excerpt or "").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
lower = line.lower()
|
||||
if lower.startswith("user:"):
|
||||
messages.append({"role": "user", "content": line[5:].strip()})
|
||||
elif lower.startswith("assistant:"):
|
||||
messages.append({"role": "assistant", "content": line[10:].strip()})
|
||||
elif lower.startswith("system:"):
|
||||
messages.append({"role": "system", "content": line[7:].strip()})
|
||||
else:
|
||||
messages.append({"role": "user", "content": line})
|
||||
return messages
|
||||
|
||||
|
||||
async def run_prompt_builder(
|
||||
persona_id: str,
|
||||
*,
|
||||
messages: list[dict] | None = None,
|
||||
chat_excerpt: str = "",
|
||||
outfit_json: str = "[]",
|
||||
appearance_override: str | None = None,
|
||||
use_prose: bool = False,
|
||||
) -> dict:
|
||||
"""Debug: full SD prompt builder pipeline with LLM raw output."""
|
||||
persona = await get_persona(persona_id) or {}
|
||||
if appearance_override is not None:
|
||||
persona = {**persona, "appearance_tags": appearance_override}
|
||||
|
||||
recent = messages if messages is not None else _parse_chat_excerpt(chat_excerpt)
|
||||
recent = [m for m in recent if m.get("role") in ("user", "assistant")]
|
||||
|
||||
user_block = _format_builder_user_block(persona, recent, outfit_json)
|
||||
builder_messages = [
|
||||
{"role": "system", "content": _builder_system()},
|
||||
{"role": "user", "content": user_block},
|
||||
]
|
||||
model_used = SD_PROMPT_MODEL or "SYSTEM_MODEL"
|
||||
result: dict = {
|
||||
"persona_id": persona_id,
|
||||
"sd_prompt_model": model_used,
|
||||
"builder_system": _builder_system(),
|
||||
"builder_user": user_block,
|
||||
"anima_dual": anima_dual_enabled(),
|
||||
}
|
||||
|
||||
raw = ""
|
||||
try:
|
||||
if SD_PROMPT_MODEL:
|
||||
raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
|
||||
else:
|
||||
raw = await send_message(builder_messages)
|
||||
result["llm_raw"] = raw
|
||||
scene = _parse_scene_json(raw)
|
||||
result["scene"] = scene
|
||||
|
||||
if not _scene_should_generate(scene):
|
||||
result["skipped"] = True
|
||||
result["error"] = "should_generate=false"
|
||||
return result
|
||||
|
||||
try:
|
||||
outfit_tags = ", ".join(json.loads(outfit_json or "[]"))
|
||||
except Exception:
|
||||
outfit_tags = ""
|
||||
|
||||
negative = _negative_for_scene(scene)
|
||||
if _is_anima():
|
||||
tags_only = build_positive_prompt_tags_only(scene, persona, outfit_tags)
|
||||
hybrid = build_positive_prompt_hybrid(scene, persona, outfit_tags)
|
||||
result["tag_positive"] = tags_only
|
||||
result["hybrid_positive"] = hybrid
|
||||
result["negative"] = negative
|
||||
result["tags_only_full"] = tags_only + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
result["hybrid_full"] = hybrid + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
result["tag_full"] = result["hybrid_full"]
|
||||
else:
|
||||
positive = build_positive_prompt(scene, persona, outfit_tags)
|
||||
result["tag_positive"] = positive
|
||||
result["negative"] = negative
|
||||
result["tag_full"] = positive + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
result["llm_raw"] = raw or result.get("llm_raw", "")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def generate_sd_prompt(
|
||||
messages: list,
|
||||
persona_id: str,
|
||||
outfit_json: str = "[]",
|
||||
) -> SdPromptBundle | None:
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona:
|
||||
return None
|
||||
|
||||
recent = [m for m in messages if m["role"] in ("user", "assistant")]
|
||||
if not recent:
|
||||
return None
|
||||
|
||||
user_block = _format_builder_user_block(persona, recent, outfit_json)
|
||||
builder_messages = [
|
||||
{"role": "system", "content": _builder_system()},
|
||||
{"role": "user", "content": user_block},
|
||||
]
|
||||
|
||||
raw = ""
|
||||
try:
|
||||
if SD_PROMPT_MODEL:
|
||||
raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
|
||||
else:
|
||||
raw = await send_message(builder_messages)
|
||||
scene = _parse_scene_json(raw)
|
||||
except Exception as e:
|
||||
logger.warning("sd_prompt failed: %s raw=%.200s", e, raw)
|
||||
return None
|
||||
|
||||
if not _scene_should_generate(scene):
|
||||
logger.info("sd_prompt: skipped (should_generate=false)")
|
||||
return None
|
||||
|
||||
try:
|
||||
outfit_list = json.loads(outfit_json or "[]")
|
||||
outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
|
||||
except Exception:
|
||||
outfit_tags = ""
|
||||
|
||||
bundle = _bundle_from_scene(scene, persona, outfit_tags)
|
||||
if anima_dual_enabled() and bundle.desc_full:
|
||||
logger.info(
|
||||
"Anima prompts: hybrid=%.80s | tags_only=%.80s",
|
||||
bundle.tag_full.split(NEGATIVE_PROMPT_SEPARATOR)[0],
|
||||
bundle.desc_full.split(NEGATIVE_PROMPT_SEPARATOR)[0],
|
||||
)
|
||||
return bundle
|
||||
|
||||
+322
-23
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
@@ -11,7 +12,178 @@ load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SD_BASE_URL = os.getenv("SD_BASE_URL", "http://127.0.0.1:8188").rstrip("/")
|
||||
|
||||
def _parse_basic_auth() -> httpx.BasicAuth | None:
|
||||
"""
|
||||
Vast Caddy on mapped ports often uses Basic realm=restricted.
|
||||
Set SD_COMFY_HTTP_BASIC=user:password or SD_COMFY_USER + SD_COMFY_PASSWORD.
|
||||
"""
|
||||
raw = (os.getenv("SD_COMFY_HTTP_BASIC") or "").strip()
|
||||
if raw:
|
||||
if ":" in raw:
|
||||
user, _, password = raw.partition(":")
|
||||
else:
|
||||
user, password = "", raw
|
||||
return httpx.BasicAuth(user, password)
|
||||
user = (os.getenv("SD_COMFY_USER") or "").strip()
|
||||
password = (os.getenv("SD_COMFY_PASSWORD") or "").strip()
|
||||
if user or password:
|
||||
return httpx.BasicAuth(user, password)
|
||||
return None
|
||||
|
||||
|
||||
SD_BASIC_AUTH = _parse_basic_auth()
|
||||
|
||||
|
||||
def _parse_comfy_config() -> tuple[str, dict[str, str]]:
|
||||
"""
|
||||
SD_BASE_URL may be pasted from Vast/Comfy UI with ?token=...
|
||||
API paths must be base + /prompt, not ...?token=xxx/prompt
|
||||
"""
|
||||
raw = (os.getenv("SD_BASE_URL") or "http://127.0.0.1:8188").strip()
|
||||
extra_token = (os.getenv("SD_COMFY_TOKEN") or "").strip()
|
||||
parsed = urlparse(raw)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if path and path != "/":
|
||||
base = f"{base}{path}"
|
||||
query: dict[str, str] = {}
|
||||
for key, values in parse_qs(parsed.query).items():
|
||||
if values:
|
||||
query[key] = values[-1]
|
||||
if extra_token:
|
||||
query["token"] = extra_token
|
||||
base = base.rstrip("/")
|
||||
# Cloudflare tunnel to localhost:8188 — direct Comfy API, Vast ?token= does not apply
|
||||
if "trycloudflare.com" in base.lower():
|
||||
if query.pop("token", None):
|
||||
logger.info(
|
||||
"SD_BASE_URL is trycloudflare tunnel: Vast token stripped. "
|
||||
"Use tunnel for port 8188 only (see instance Port Mapping)."
|
||||
)
|
||||
return base, query
|
||||
|
||||
|
||||
SD_BASE_URL, SD_QUERY_PARAMS = _parse_comfy_config()
|
||||
|
||||
|
||||
def _comfy_url(path: str) -> str:
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
return f"{SD_BASE_URL}{path}"
|
||||
|
||||
|
||||
def _log_comfy_target() -> str:
|
||||
if SD_QUERY_PARAMS.get("token"):
|
||||
return f"{SD_BASE_URL}?token=***"
|
||||
return SD_BASE_URL
|
||||
|
||||
|
||||
def _absolute_url(location: str, fallback_path: str = "/") -> str:
|
||||
if not location:
|
||||
return _comfy_url(fallback_path)
|
||||
if location.startswith(("http://", "https://")):
|
||||
return location
|
||||
if location.startswith("/"):
|
||||
return f"{SD_BASE_URL}{location}"
|
||||
return f"{SD_BASE_URL}/{location}"
|
||||
|
||||
|
||||
def _url_with_token(url: str) -> str:
|
||||
"""Append gateway token to URL (Vast/Cloudflare often strip ?token on redirect)."""
|
||||
if not SD_QUERY_PARAMS.get("token"):
|
||||
return url
|
||||
p = urlparse(url)
|
||||
q: dict[str, str] = {}
|
||||
for key, values in parse_qs(p.query).items():
|
||||
if values:
|
||||
q[key] = values[-1]
|
||||
q.update(SD_QUERY_PARAMS)
|
||||
return urlunparse((p.scheme, p.netloc, p.path, "", urlencode(q), ""))
|
||||
|
||||
|
||||
def _merge_params(extra: dict | None) -> dict | None:
|
||||
if not SD_QUERY_PARAMS and not extra:
|
||||
return None
|
||||
merged = dict(SD_QUERY_PARAMS)
|
||||
if extra:
|
||||
merged.update(extra)
|
||||
return merged
|
||||
|
||||
|
||||
def _is_vast_gateway() -> bool:
|
||||
return "trycloudflare.com" not in SD_BASE_URL.lower()
|
||||
|
||||
|
||||
def _make_comfy_client(*, timeout: float = 300) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
follow_redirects=False,
|
||||
auth=SD_BASIC_AUTH,
|
||||
)
|
||||
|
||||
|
||||
async def _prime_comfy_gateway(client: httpx.AsyncClient) -> None:
|
||||
"""
|
||||
Vast Caddy: browser opens /?token=… and gets a session cookie; API then works.
|
||||
Prime with redirects so Set-Cookie is collected, then merge into the API client.
|
||||
"""
|
||||
token = SD_QUERY_PARAMS.get("token")
|
||||
if not token or not _is_vast_gateway():
|
||||
return
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30,
|
||||
follow_redirects=True,
|
||||
auth=SD_BASIC_AUTH,
|
||||
) as prime:
|
||||
r = await prime.get(_comfy_url("/"), params={"token": token})
|
||||
client.cookies.update(prime.cookies)
|
||||
logger.info(
|
||||
"Comfy gateway prime GET /?token=*** → %s, cookies=%s",
|
||||
r.status_code,
|
||||
list(prime.cookies.keys()) or "(none)",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Comfy gateway prime failed: %s", e)
|
||||
|
||||
|
||||
async def _comfy_request(
|
||||
client: httpx.AsyncClient,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict | None = None,
|
||||
**kwargs,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Comfy API: trycloudflare tunnel = no token.
|
||||
Vast IP:PORT gateway = ?token= + cookie prime; follow redirects with token re-attached.
|
||||
"""
|
||||
url = _comfy_url(path)
|
||||
extra = params or {}
|
||||
token = SD_QUERY_PARAMS.get("token")
|
||||
use_vast_auth = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None)
|
||||
|
||||
if token and _is_vast_gateway():
|
||||
await _prime_comfy_gateway(client)
|
||||
|
||||
req_params: dict | None = _merge_params(extra) if use_vast_auth else (extra or None)
|
||||
resp: httpx.Response | None = None
|
||||
|
||||
for hop in range(6):
|
||||
resp = await client.request(method, url, params=req_params, **kwargs)
|
||||
if resp.status_code not in (301, 302, 303, 307, 308):
|
||||
return resp
|
||||
loc = _absolute_url(resp.headers.get("location", ""), path)
|
||||
url = _url_with_token(loc) if use_vast_auth else loc
|
||||
req_params = extra or None
|
||||
logger.info("Comfy redirect %s hop %s → %s", resp.status_code, hop + 1, url.split("?")[0])
|
||||
|
||||
assert resp is not None
|
||||
return resp
|
||||
|
||||
|
||||
SD_STEPS = int(os.getenv("SD_STEPS", "28"))
|
||||
SD_CFG = float(os.getenv("SD_CFG", "7"))
|
||||
SD_SAMPLER = os.getenv("SD_SAMPLER", "euler")
|
||||
@@ -26,6 +198,8 @@ SD_DEFAULT_NEGATIVE = os.getenv(
|
||||
SD_UNET = os.getenv("SD_UNET", "anima-preview3-base.safetensors")
|
||||
SD_CLIP = os.getenv("SD_CLIP", "qwen_3_06b_base.safetensors")
|
||||
SD_VAE = os.getenv("SD_VAE", "qwen_image_vae.safetensors")
|
||||
SD_STYLE_LORA = os.getenv("SD_STYLE_LORA", "")
|
||||
SD_STYLE_LORA_WEIGHT = float(os.getenv("SD_STYLE_LORA_WEIGHT", "0.7"))
|
||||
|
||||
IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images"))
|
||||
|
||||
@@ -38,19 +212,37 @@ def _use_anima() -> bool:
|
||||
|
||||
|
||||
def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]:
|
||||
# Try new separator first
|
||||
sep = "__NEGATIVE_PROMPT__"
|
||||
if f"\n{sep}\n" in full_prompt:
|
||||
pos, _, neg = full_prompt.partition(f"\n{sep}\n")
|
||||
return pos.strip(), neg.strip()
|
||||
# Fallback to old format
|
||||
if "\n\nNegative prompt:" in full_prompt:
|
||||
pos, _, neg = full_prompt.partition("\n\nNegative prompt:")
|
||||
return pos.strip(), neg.strip()
|
||||
return full_prompt.strip(), SD_DEFAULT_NEGATIVE
|
||||
|
||||
|
||||
def _build_workflow(positive: str, negative: str) -> dict:
|
||||
def _workflow_uses_anima(overrides: dict | None) -> bool:
|
||||
if overrides and overrides.get("checkpoint"):
|
||||
return False
|
||||
if overrides and overrides.get("unet"):
|
||||
return True
|
||||
return _use_anima()
|
||||
|
||||
|
||||
def _build_workflow(positive: str, negative: str, overrides: dict | None = None) -> dict:
|
||||
seed = int(uuid.uuid4().int % 2**32)
|
||||
if _use_anima():
|
||||
return {
|
||||
"44": {"class_type": "UNETLoader", "inputs": {"unet_name": SD_UNET, "weight_dtype": "default"}},
|
||||
"45": {"class_type": "CLIPLoader", "inputs": {"clip_name": SD_CLIP, "type": "stable_diffusion", "device": "default"}},
|
||||
"15": {"class_type": "VAELoader", "inputs": {"vae_name": SD_VAE}},
|
||||
o = overrides or {}
|
||||
if _workflow_uses_anima(o):
|
||||
unet = o.get("unet") or SD_UNET
|
||||
clip = o.get("clip") or SD_CLIP
|
||||
vae = o.get("vae") or SD_VAE
|
||||
workflow = {
|
||||
"44": {"class_type": "UNETLoader", "inputs": {"unet_name": unet, "weight_dtype": "default"}},
|
||||
"45": {"class_type": "CLIPLoader", "inputs": {"clip_name": clip, "type": "stable_diffusion", "device": "default"}},
|
||||
"15": {"class_type": "VAELoader", "inputs": {"vae_name": vae}},
|
||||
"28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 1024, "batch_size": 1}},
|
||||
"11": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["45", 0]}},
|
||||
"12": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["45", 0]}},
|
||||
@@ -68,9 +260,24 @@ def _build_workflow(positive: str, negative: str) -> dict:
|
||||
"8": {"class_type": "VAEDecode", "inputs": {"samples": ["19", 0], "vae": ["15", 0]}},
|
||||
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
|
||||
}
|
||||
# Standard checkpoint workflow (Pony / SDXL)
|
||||
if SD_STYLE_LORA:
|
||||
workflow["46"] = {
|
||||
"class_type": "LoraLoader",
|
||||
"inputs": {
|
||||
"lora_name": SD_STYLE_LORA,
|
||||
"model": ["44", 0],
|
||||
"clip": ["45", 0],
|
||||
"strength_model": SD_STYLE_LORA_WEIGHT,
|
||||
"strength_clip": SD_STYLE_LORA_WEIGHT,
|
||||
},
|
||||
}
|
||||
workflow["19"]["inputs"]["model"] = ["46", 0]
|
||||
workflow["11"]["inputs"]["clip"] = ["46", 1]
|
||||
workflow["12"]["inputs"]["clip"] = ["46", 1]
|
||||
return workflow
|
||||
ckpt = o.get("checkpoint") or SD_CHECKPOINT
|
||||
return {
|
||||
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": SD_CHECKPOINT}},
|
||||
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt}},
|
||||
"5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}},
|
||||
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}},
|
||||
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}},
|
||||
@@ -89,24 +296,78 @@ def _build_workflow(positive: str, negative: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
async def comfy_api_request(
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict | None = None,
|
||||
json_body: dict | None = None,
|
||||
timeout: float = 60,
|
||||
) -> tuple[int, dict | str, dict]:
|
||||
"""
|
||||
Raw Comfy API call for debug. Returns (status_code, parsed_json_or_text, response_headers_subset).
|
||||
"""
|
||||
async with _make_comfy_client(timeout=timeout) as client:
|
||||
await _prime_comfy_gateway(client)
|
||||
token = SD_QUERY_PARAMS.get("token")
|
||||
use_vast = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None)
|
||||
req_params = _merge_params(params) if use_vast else (params or None)
|
||||
req_kwargs: dict = {}
|
||||
if json_body is not None and method.upper() not in ("GET", "HEAD"):
|
||||
req_kwargs["json"] = json_body
|
||||
resp = await _comfy_request(
|
||||
client,
|
||||
method.upper(),
|
||||
path,
|
||||
params=req_params,
|
||||
**req_kwargs,
|
||||
)
|
||||
headers = {
|
||||
k: resp.headers.get(k)
|
||||
for k in ("content-type", "location", "www-authenticate")
|
||||
if resp.headers.get(k)
|
||||
}
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
body = resp.text[:8000]
|
||||
return resp.status_code, body, headers
|
||||
|
||||
|
||||
async def fetch_object_info() -> dict:
|
||||
status, body, _ = await comfy_api_request("GET", "/object_info", timeout=120)
|
||||
if status != 200 or not isinstance(body, dict):
|
||||
raise RuntimeError(f"object_info failed: HTTP {status} {body!s:.300}")
|
||||
return body
|
||||
|
||||
|
||||
async def check_sd() -> bool:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
r = await client.get(f"{SD_BASE_URL}/system_stats")
|
||||
async with _make_comfy_client(timeout=15) as client:
|
||||
await _prime_comfy_gateway(client)
|
||||
r = await _comfy_request(client, "GET", "/system_stats")
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[bytes, str]:
|
||||
async def txt2img(
|
||||
prompt: str,
|
||||
negative_prompt: str | None = None,
|
||||
*,
|
||||
overrides: dict | None = None,
|
||||
) -> tuple[bytes, str]:
|
||||
neg = negative_prompt or SD_DEFAULT_NEGATIVE
|
||||
workflow = _build_workflow(prompt, neg)
|
||||
workflow = _build_workflow(prompt, neg, overrides)
|
||||
client_id = uuid.uuid4().hex
|
||||
|
||||
logger.info("ComfyUI request → %s prompt: %.120s", SD_BASE_URL, prompt)
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
resp = await client.post(
|
||||
f"{SD_BASE_URL}/prompt",
|
||||
logger.info("ComfyUI request → %s prompt: %.120s", _log_comfy_target(), prompt)
|
||||
async with _make_comfy_client() as client:
|
||||
await _prime_comfy_gateway(client)
|
||||
resp = await _comfy_request(
|
||||
client,
|
||||
"POST",
|
||||
"/prompt",
|
||||
json={"prompt": workflow, "client_id": client_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
@@ -115,7 +376,7 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
|
||||
for _ in range(300):
|
||||
await asyncio.sleep(1)
|
||||
hist = await client.get(f"{SD_BASE_URL}/history/{prompt_id}")
|
||||
hist = await _comfy_request(client, "GET", f"/history/{prompt_id}")
|
||||
data = hist.json()
|
||||
if prompt_id in data:
|
||||
entry = data[prompt_id]
|
||||
@@ -127,9 +388,15 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
for node_output in outputs.values():
|
||||
if "images" in node_output:
|
||||
img_info = node_output["images"][0]
|
||||
img_resp = await client.get(
|
||||
f"{SD_BASE_URL}/view",
|
||||
params={"filename": img_info["filename"], "subfolder": img_info.get("subfolder", ""), "type": img_info.get("type", "output")},
|
||||
img_resp = await _comfy_request(
|
||||
client,
|
||||
"GET",
|
||||
"/view",
|
||||
params={
|
||||
"filename": img_info["filename"],
|
||||
"subfolder": img_info.get("subfolder", ""),
|
||||
"type": img_info.get("type", "output"),
|
||||
},
|
||||
)
|
||||
img_resp.raise_for_status()
|
||||
image_bytes = img_resp.content
|
||||
@@ -145,11 +412,43 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
raise RuntimeError("ComfyUI generation timed out or produced no output")
|
||||
|
||||
|
||||
async def generate_from_full_prompt(full_prompt: str) -> tuple[str | None, str | None]:
|
||||
async def generate_from_full_prompt(
|
||||
full_prompt: str,
|
||||
*,
|
||||
overrides: dict | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
positive, negative = split_prompt_and_negative(full_prompt)
|
||||
try:
|
||||
_, rel_path = await txt2img(positive, negative)
|
||||
_, rel_path = await txt2img(positive, negative, overrides=overrides)
|
||||
return rel_path, None
|
||||
except httpx.HTTPStatusError as e:
|
||||
code = e.response.status_code
|
||||
if code == 401:
|
||||
logger.error(
|
||||
"ComfyUI 401: Vast Caddy needs SD_COMFY_TOKEN (or ?token= in SD_BASE_URL) "
|
||||
"and/or SD_COMFY_HTTP_BASIC=user:pass from the instance page. "
|
||||
"Test: curl -u user:pass http://IP:PORT/system_stats "
|
||||
"or open /?token=… in browser then curl with cookies. "
|
||||
"Alternative: trycloudflare URL for localhost:8188 in Port Mapping."
|
||||
)
|
||||
elif code in (301, 302, 303, 307, 308):
|
||||
logger.error(
|
||||
"ComfyUI %s: wrong URL — use trycloudflare tunnel for 8188, not web UI link. "
|
||||
"SD_BASE_URL=https://reviewer-relief-edmonton-specializing.trycloudflare.com "
|
||||
"(no ?token=). Location: %s",
|
||||
code,
|
||||
e.response.headers.get("location"),
|
||||
)
|
||||
else:
|
||||
logger.error("ComfyUI HTTP %s: %s", code, e)
|
||||
return None, str(e)
|
||||
except httpx.ConnectError as e:
|
||||
logger.error(
|
||||
"ComfyUI connect failed (%s): IP:8188 is often not exposed on Vast. "
|
||||
"Use trycloudflare URL from Port Mapping for localhost:8188.",
|
||||
e,
|
||||
)
|
||||
return None, str(e)
|
||||
except Exception as e:
|
||||
logger.error("ComfyUI error: %s", e)
|
||||
return None, str(e)
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import logging
|
||||
|
||||
from services.memory import get_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resolve_session_persona(
|
||||
session_id: str,
|
||||
requested: str | None = None,
|
||||
*,
|
||||
create_persona: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Session.persona_id is the source of truth.
|
||||
requested is ignored when it disagrees (logged). create_persona used only if session missing.
|
||||
"""
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
return (create_persona or requested or "default").strip() or "default"
|
||||
|
||||
bound = (session.get("persona_id") or "default").strip() or "default"
|
||||
req = (requested or "").strip()
|
||||
if req and req != bound:
|
||||
logger.warning(
|
||||
"persona_id mismatch session=%s bound=%s requested=%s (using bound)",
|
||||
session_id,
|
||||
bound,
|
||||
req,
|
||||
)
|
||||
return bound
|
||||
@@ -0,0 +1,21 @@
|
||||
import logging
|
||||
|
||||
from services.chat_prompt import get_system_prompt
|
||||
from services.memory import get_all_sessions, get_history, upsert_static_system_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def migrate_static_system_messages() -> int:
|
||||
"""Rebuild stored system rows from sessions.persona_id (strip legacy RPG text)."""
|
||||
updated = 0
|
||||
for session in await get_all_sessions():
|
||||
sid = session["session_id"]
|
||||
persona_id = session.get("persona_id") or "default"
|
||||
history = await get_history(sid)
|
||||
static = await get_system_prompt(persona_id, history, "")
|
||||
if await upsert_static_system_message(sid, static, history):
|
||||
updated += 1
|
||||
if updated:
|
||||
logger.info("Migrated %s session system message(s) to static persona prompt", updated)
|
||||
return updated
|
||||
Reference in New Issue
Block a user