"""Build Stable Diffusion prompts from roleplay chat excerpts.

An LLM is asked to describe the current scene as booru-style tags; those tags
are combined with the persona's appearance tags, outfit tags and
checkpoint-specific quality tags into a positive prompt, and paired with a
matching negative prompt.
"""

import json
import logging
import os
import re

from services.llm import send_message, send_message_with_model
from services.personas import get_persona

logger = logging.getLogger(__name__)

PROMPT_BUILDER_SYSTEM = """You are a Stable Diffusion prompt engineer for anime illustration models.
Given a roleplay chat excerpt, output ONLY valid JSON (no markdown):
{
  "should_generate": true,
  "shot_type": "first_person_pov" | "landscape" | "third_person",
  "action_tags": "booru-style tags for pose/action/expression, e.g. 'sitting, smiling, holding_cup'",
  "environment_tags": "booru-style tags for location/lighting/time, e.g. 'indoors, kitchen, sunlight, daytime'"
}

Rules:
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined: 'fox_ears' not 'fox ears'.
- Do NOT include appearance/character tags — those are provided separately.
- Do NOT include quality tags, model names, style words, 'pov', or category/metadata words.
- Do NOT invent tags. If unsure — omit.
- Keep each field to 3-6 tags."""

# Single source of truth for the inline "[IMAGE_PROMPT: ...]" marker so that
# extraction and stripping always agree on what counts as a tag.
_IMAGE_PROMPT_RE = re.compile(r"\[IMAGE_PROMPT:(.*?)\]", re.DOTALL)


def extract_image_prompt_tag(text: str) -> str | None:
    """Return the stripped payload of the first ``[IMAGE_PROMPT: ...]`` tag.

    Returns ``None`` when *text* has no such tag or the tag is never closed
    with ``]``.
    """
    match = _IMAGE_PROMPT_RE.search(text)
    return match.group(1).strip() if match else None


def strip_image_prompt_tag(text: str) -> str:
    """Return *text* with every ``[IMAGE_PROMPT: ...]`` tag removed and the
    result stripped of surrounding whitespace."""
    return _IMAGE_PROMPT_RE.sub("", text).strip()


# Stripped so that stray whitespace in the environment cannot silently defeat
# the checkpoint comparisons below.
SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "").strip()
SD_UNET = os.getenv("SD_UNET", "").strip()
SD_PROMPT_MODEL = os.getenv("SD_PROMPT_MODEL", "").strip()

PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}

PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
ANIMA_NEGATIVE = "worst quality, low quality, score_1, score_2, score_3, blurry, jpeg artifacts, sepia"
_DEFAULT_NEGATIVE = "low quality, blurry, bad anatomy, watermark, text"

_DEFAULT_LORA_WEIGHT = 0.8
# Number of trailing user/assistant messages shown to the prompt-builder LLM.
_EXCERPT_MESSAGE_COUNT = 6


def _is_pony() -> bool:
    """Return True when the configured checkpoint is a known Pony checkpoint."""
    return SD_CHECKPOINT in PONY_CHECKPOINTS


def _is_anima() -> bool:
    """Return True when only ``SD_UNET`` is configured (no ``SD_CHECKPOINT``)."""
    return bool(SD_UNET) and not SD_CHECKPOINT


def _as_tag_string(value: object) -> str:
    """Coerce a tag field to a comma-separated string.

    Strings pass through unchanged. Lists/tuples are joined with ``", "``,
    skipping non-string items. Anything else (``None``, numbers, dicts)
    becomes an empty string, so malformed LLM or database values cannot crash
    prompt assembly.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return ", ".join(item for item in value if isinstance(item, str))
    return ""


def _dedupe_tags(prompt: str) -> str:
    """Split *prompt* on commas, drop blanks and repeats, and re-join.

    The first occurrence of each tag wins, so the original ordering is kept.
    """
    stripped = (piece.strip() for piece in prompt.split(","))
    return ", ".join(dict.fromkeys(tag for tag in stripped if tag))


def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
    """Assemble the positive prompt for one scene.

    Tag order is: quality tags (chosen by checkpoint family), persona
    appearance tags, outfit tags, the POV phrase (first-person shots only),
    action tags (omitted for landscape shots), environment tags, and finally
    the persona's LoRA tag when one is configured.

    Args:
        scene: Scene description; reads ``shot_type``, ``action_tags`` and
            ``environment_tags``.
        persona: Persona record; reads ``appearance_tags``, ``lora_name`` and
            ``lora_weight``. May be ``None``.
        outfit_tags: Comma-separated outfit tags.

    Returns:
        A ``", "``-joined prompt with empty and duplicate tags removed.
    """
    persona = persona or {}

    if _is_pony():
        quality = "score_9, score_8_up, score_7_up, source_anime, highres"
    elif _is_anima():
        quality = "masterpiece, best quality, score_7, anime"
    else:
        quality = "masterpiece, best quality, highres"

    parts = [
        quality,
        _as_tag_string(persona.get("appearance_tags")),
        _as_tag_string(outfit_tags),
    ]

    shot_type = scene.get("shot_type")
    if shot_type != "landscape":
        if shot_type == "first_person_pov":
            parts.append("pov, first-person view, looking at viewer")
        parts.append(_as_tag_string(scene.get("action_tags")))
    parts.append(_as_tag_string(scene.get("environment_tags")))

    lora = persona.get("lora_name")
    if lora:
        weight = persona.get("lora_weight")
        if weight is None:
            # A stored null would otherwise render as the literal "None".
            weight = _DEFAULT_LORA_WEIGHT
        # NOTE(review): this f-string was empty, so the LoRA was never
        # emitted. Restored using the "<lora:name:weight>" inline syntax —
        # confirm the image backend parses this form.
        parts.append(f"<lora:{lora}:{weight}>")

    return _dedupe_tags(", ".join(parts))


def _build_excerpt(messages: list) -> str:
    """Render the last few user/assistant messages as ``role: text`` lines.

    Image-prompt tags are stripped from each message. Content that is not a
    string is rendered as empty text. Returns an empty string when there are
    no user/assistant messages.
    """
    recent = [m for m in messages if m["role"] in ("user", "assistant")]
    lines = []
    for message in recent[-_EXCERPT_MESSAGE_COUNT:]:
        content = message["content"]
        text = strip_image_prompt_tag(content) if isinstance(content, str) else ""
        lines.append(f"{message['role']}: {text}")
    return "\n".join(lines)


def _strip_code_fence(raw: str) -> str:
    """Remove a surrounding Markdown code fence (three backticks) from *raw*."""
    raw = raw.strip()
    if raw.startswith("```"):
        raw = re.sub(r"^```\w*\n?", "", raw)
        raw = re.sub(r"\n?```$", "", raw)
    return raw


def _parse_outfit_tags(outfit_json: str) -> str:
    """Decode a JSON list of outfit tags into a comma-separated string.

    Returns an empty string when *outfit_json* is not valid JSON or does not
    decode to a list.
    """
    try:
        outfit = json.loads(outfit_json or "[]")
    except (TypeError, ValueError):
        return ""
    return _as_tag_string(outfit) if isinstance(outfit, list) else ""


async def generate_sd_prompt(
    messages: list,
    persona_id: str,
    outfit_json: str = "[]",
) -> tuple[str | None, str | None]:
    """Generate a Stable Diffusion prompt for the current point in a chat.

    Args:
        messages: Chat history; each item is indexed by ``"role"`` and
            ``"content"``. Only the last six user/assistant messages are used.
        persona_id: Identifier passed to ``get_persona``.
        outfit_json: JSON-encoded list of outfit tags.

    Returns:
        ``(full, negative)``, where ``full`` is the positive prompt followed
        by a blank line and ``"Negative prompt: <negative>"``. Returns
        ``(None, None)`` when the persona is missing or has no appearance
        tags, when there are no user/assistant messages, or when the LLM call
        fails or does not return a JSON object.
    """
    persona = await get_persona(persona_id)
    # Generate only if persona has appearance tags
    if not persona or not _as_tag_string(persona.get("appearance_tags")).strip():
        logger.debug("sd_prompt skip: persona=%s no appearance_tags", persona_id)
        return None, None

    excerpt = _build_excerpt(messages)
    if not excerpt:
        return None, None

    builder_messages = [
        {"role": "system", "content": PROMPT_BUILDER_SYSTEM},
        {"role": "user", "content": f"Chat:\n{excerpt}"},
    ]

    # Bound up front so the failure log can always report what was received.
    raw = ""
    try:
        if SD_PROMPT_MODEL:
            raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
        else:
            raw = await send_message(builder_messages)
        raw = _strip_code_fence(raw)
        scene = json.loads(raw)
    except Exception as e:
        # Best-effort feature: any LLM or parsing failure means "no image".
        logger.warning("sd_prompt failed: %s raw=%.200s", e, raw)
        return None, None

    if not isinstance(scene, dict):
        logger.warning("sd_prompt: LLM returned non-dict: %.100s", raw)
        return None, None

    # NOTE(review): the system prompt asks for a "should_generate" field, but
    # it is not consulted here — confirm whether a false value should skip
    # generation.

    positive = build_positive_prompt(scene, persona, _parse_outfit_tags(outfit_json))

    if _is_pony():
        negative = PONY_NEGATIVE
    elif _is_anima():
        negative = ANIMA_NEGATIVE
    else:
        negative = _DEFAULT_NEGATIVE
    if scene.get("shot_type") == "first_person_pov":
        negative += ", third person, over the shoulder"

    full = f"{positive}\n\nNegative prompt: {negative}"
    return full, negative