157 lines
5.4 KiB
Python
157 lines
5.4 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
|
|
from services.llm import send_message, send_message_with_model
|
|
from services.personas import get_persona
|
|
|
|
logger = logging.getLogger(__name__)

# System prompt for the scene-to-tags LLM call made by generate_sd_prompt.
# The reply's "shot_type", "action_tags" and "environment_tags" keys are
# consumed by build_positive_prompt. Appearance, quality and POV tags are
# kept out of the LLM's job because build_positive_prompt adds them itself.
PROMPT_BUILDER_SYSTEM = """You are a Stable Diffusion prompt engineer for anime illustration models.
Given a roleplay chat excerpt, output ONLY valid JSON (no markdown):
{
"should_generate": true,
"shot_type": "first_person_pov" | "landscape" | "third_person",
"action_tags": "booru-style tags for pose/action/expression, e.g. 'sitting, smiling, holding_cup'",
"environment_tags": "booru-style tags for location/lighting/time, e.g. 'indoors, kitchen, sunlight, daytime'"
}
Rules:
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined: 'fox_ears' not 'fox ears'.
- Do NOT include appearance/character tags — those are provided separately.
- Do NOT include quality tags, model names, style words, 'pov', or category/metadata words.
- Do NOT invent tags. If unsure — omit.
- Keep each field to 3-6 tags."""
|
|
|
|
|
|
def extract_image_prompt_tag(text: str) -> str | None:
|
|
if "[IMAGE_PROMPT:" not in text:
|
|
return None
|
|
try:
|
|
start = text.index("[IMAGE_PROMPT:") + len("[IMAGE_PROMPT:")
|
|
end = text.index("]", start)
|
|
return text[start:end].strip()
|
|
except ValueError:
|
|
return None
|
|
|
|
|
|
def strip_image_prompt_tag(text: str) -> str:
    """Remove every ``[IMAGE_PROMPT: ...]`` marker from *text* and trim the result.

    Each marker spans from ``[IMAGE_PROMPT:`` to the nearest following ``]``
    and may cross line breaks. An unclosed marker is left in place.
    """
    marker_pattern = re.compile(r"\[IMAGE_PROMPT:.*?\]", re.DOTALL)
    without_markers = marker_pattern.sub("", text)
    return without_markers.strip()
|
|
|
|
|
|
# Model configuration comes from the environment. Every value is stripped so
# stray whitespace (e.g. from a .env file) does not defeat the exact-match
# membership test in _is_pony or make a blank SD_CHECKPOINT look "set" to
# _is_anima.
SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "").strip()
SD_UNET = os.getenv("SD_UNET", "").strip()
SD_PROMPT_MODEL = os.getenv("SD_PROMPT_MODEL", "").strip()

# Checkpoint filenames that get the score_* style quality/negative tags.
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}

# Negative prompt used when SD_CHECKPOINT is one of PONY_CHECKPOINTS.
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
# Negative prompt used when only a UNet (no checkpoint) is configured.
ANIMA_NEGATIVE = "worst quality, low quality, score_1, score_2, score_3, blurry, jpeg artifacts, sepia"
|
|
|
|
|
|
def _is_pony() -> bool:
    """Report whether the configured SD_CHECKPOINT is listed in PONY_CHECKPOINTS."""
    checkpoint_is_pony = SD_CHECKPOINT in PONY_CHECKPOINTS
    return checkpoint_is_pony
|
|
|
|
|
|
def _is_anima() -> bool:
    """Report whether a UNet is configured without any checkpoint (the "Anima" setup)."""
    if not SD_UNET:
        return False
    # A UNet alone selects Anima; a checkpoint alongside it does not.
    return not SD_CHECKPOINT
|
|
|
|
|
|
def _tag_text(value) -> str:
    """Coerce a tag field (str, list/tuple of tags, None, other) to a comma-separated string."""
    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        # LLMs sometimes return tag fields as JSON arrays instead of strings.
        return ", ".join(str(item) for item in value if item is not None)
    return str(value)


def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
    """Assemble the comma-separated positive prompt for the active SD model.

    Tag order: model-specific quality tags, persona appearance tags, outfit
    tags, scene tags (chosen by ``shot_type``), then the persona's LoRA tag.
    Duplicate tags are dropped, keeping the first occurrence.

    Args:
        scene: Parsed prompt-builder reply; reads ``shot_type``,
            ``action_tags`` and ``environment_tags``. Tag fields may be a
            string or a list of strings; None is treated as an empty scene.
        persona: Persona record, or None; reads ``appearance_tags``,
            ``lora_name`` and ``lora_weight`` (defaults to 0.8 when missing
            or None).
        outfit_tags: Extra comma-separated outfit tags.

    Returns:
        The deduplicated prompt, tags joined with ", ".
    """
    if _is_pony():
        quality = "score_9, score_8_up, score_7_up, source_anime, highres"
    elif _is_anima():
        quality = "masterpiece, best quality, score_7, anime"
    else:
        quality = "masterpiece, best quality, highres"

    persona = persona or {}
    scene = scene or {}
    shot_type = scene.get("shot_type")

    parts = [quality, _tag_text(persona.get("appearance_tags")), _tag_text(outfit_tags)]

    if shot_type == "landscape":
        # Landscape shots carry no pose/action tags.
        parts.append(_tag_text(scene.get("environment_tags")))
    else:
        if shot_type == "first_person_pov":
            parts.append("pov, first-person view, looking at viewer")
        parts.append(_tag_text(scene.get("action_tags")))
        parts.append(_tag_text(scene.get("environment_tags")))

    lora = _tag_text(persona.get("lora_name")).strip()
    if lora:
        weight = persona.get("lora_weight")
        if weight is None:
            # A stored NULL weight must not render as "<lora:name:None>".
            weight = 0.8
        parts.append(f"<lora:{lora}:{weight}>")

    # Split on bare commas so "a,b" and newline-separated tags dedupe too.
    seen: set[str] = set()
    tags: list[str] = []
    for part in parts:
        for raw_tag in part.split(","):
            tag = raw_tag.strip()
            if tag and tag not in seen:
                seen.add(tag)
                tags.append(tag)
    return ", ".join(tags)
|
|
|
|
|
|
def _parse_llm_json(raw: str) -> object:
    """Decode the JSON value in a prompt-builder reply; raises json.JSONDecodeError on failure."""
    text = raw.strip()
    if text.startswith("```"):
        # Models sometimes ignore "no markdown" and wrap the JSON in a code fence.
        text = re.sub(r"^```\w*\n?", "", text)
        text = re.sub(r"\n?```$", "", text)
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span when the JSON is surrounded by prose.
        first, last = text.find("{"), text.rfind("}")
        if first == -1 or last <= first:
            raise
        return json.loads(text[first:last + 1])


def _parse_outfit_tags(outfit_json: str | list | None) -> str:
    """Turn a JSON array (or list) of outfit tag strings into one comma-separated string.

    Malformed JSON or a non-list value yields ""; non-string and blank items are skipped.
    """
    if isinstance(outfit_json, list):
        items = outfit_json
    else:
        try:
            items = json.loads(outfit_json or "[]")
        except (TypeError, ValueError):
            return ""
    if not isinstance(items, list):
        return ""
    return ", ".join(item.strip() for item in items if isinstance(item, str) and item.strip())


def _negative_prompt(shot_type: object) -> str:
    """Pick the negative prompt for the active SD model, adding POV-specific exclusions."""
    if _is_pony():
        negative = PONY_NEGATIVE
    elif _is_anima():
        negative = ANIMA_NEGATIVE
    else:
        negative = "low quality, blurry, bad anatomy, watermark, text"
    if shot_type == "first_person_pov":
        negative += ", third person, over the shoulder"
    return negative


async def generate_sd_prompt(
    messages: list,
    persona_id: str,
    outfit_json: str = "[]",
) -> tuple[str | None, str | None]:
    """Build a Stable Diffusion prompt from the latest turns of a chat.

    Returns ``(None, None)`` in each of these cases:

    * the persona is missing or has no ``appearance_tags``;
    * there are no user/assistant text messages;
    * the prompt-builder LLM call or its JSON parsing fails, or the reply is
      not a JSON object;
    * the LLM explicitly sets ``should_generate`` to false.

    Args:
        messages: Chat history as dicts with ``role`` and ``content`` keys.
            Messages whose content is not a string are ignored.
        persona_id: Id passed to ``get_persona``.
        outfit_json: JSON array of outfit tag strings (a list is also accepted).

    Returns:
        ``(full, negative)``. ``full`` is the positive prompt, then a blank
        line, then ``Negative prompt: <negative>``. ``negative`` is the
        negative prompt alone.
    """
    persona = await get_persona(persona_id)
    # Generate only if persona has appearance tags
    if not persona or not (persona.get("appearance_tags") or "").strip():
        logger.debug("sd_prompt skip: persona=%s no appearance_tags", persona_id)
        return None, None

    # Keep the last 6 user/assistant turns. Non-string content (None,
    # multimodal lists) is skipped instead of crashing strip_image_prompt_tag.
    recent = [
        m for m in (messages or [])
        if m.get("role") in ("user", "assistant") and isinstance(m.get("content"), str)
    ][-6:]
    if not recent:
        return None, None

    excerpt = "\n".join(f"{m['role']}: {strip_image_prompt_tag(m['content'])}" for m in recent)
    builder_messages = [
        {"role": "system", "content": PROMPT_BUILDER_SYSTEM},
        {"role": "user", "content": f"Chat:\n{excerpt}"},
    ]

    raw = ""
    try:
        if SD_PROMPT_MODEL:
            raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
        else:
            raw = await send_message(builder_messages)
        scene = _parse_llm_json(raw or "")
    except Exception as e:
        # Best-effort: any LLM or parse failure yields (None, None) instead of raising.
        logger.warning("sd_prompt failed: %s raw=%.200s", e, raw)
        return None, None

    if not isinstance(scene, dict):
        logger.warning("sd_prompt: LLM returned non-dict: %.100s", raw)
        return None, None
    if scene.get("should_generate") is False:
        # Honour the schema's opt-out flag; only an explicit false skips.
        logger.debug("sd_prompt skip: LLM set should_generate=false")
        return None, None

    positive = build_positive_prompt(scene, persona, _parse_outfit_tags(outfit_json))
    negative = _negative_prompt(scene.get("shot_type"))

    full = positive + f"\n\nNegative prompt: {negative}"
    return full, negative
|