Fixed SD RPG
This commit is contained in:
+17
-11
@@ -45,6 +45,7 @@ def parse_card_v2(data: dict, card_id: str | None = None) -> dict:
|
||||
"first_mes": inner.get("first_mes", ""),
|
||||
"mes_example": inner.get("mes_example", ""),
|
||||
"appearance_tags": _extract_appearance(inner),
|
||||
"appearance_prose": "",
|
||||
"lorebook_json": json.dumps(entries, ensure_ascii=False),
|
||||
"alternate_greetings": alternates,
|
||||
"alternate_greetings_json": json.dumps(alternates, ensure_ascii=False),
|
||||
@@ -120,6 +121,8 @@ def parse_png_card(file_bytes: bytes) -> dict | None:
|
||||
|
||||
|
||||
def build_system_prompt(card: dict) -> str:
|
||||
from services.chat_prompt import ROLEPLAY_GUARDRAILS
|
||||
|
||||
parts = [
|
||||
f"You are {card['name']}. Stay in character.",
|
||||
f"Description: {card['description']}",
|
||||
@@ -129,6 +132,7 @@ def build_system_prompt(card: dict) -> str:
|
||||
if card.get("mes_example"):
|
||||
parts.append(f"Example dialogue:\n{card['mes_example']}")
|
||||
parts.append("Reply only as the character. Do not add image tags.")
|
||||
parts.append(ROLEPLAY_GUARDRAILS)
|
||||
return "\n\n".join(p for p in parts if p.split(": ", 1)[-1].strip())
|
||||
|
||||
|
||||
@@ -141,13 +145,13 @@ async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0
|
||||
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""INSERT OR REPLACE INTO characters
|
||||
(card_id, name, description, personality, scenario, first_mes,
|
||||
mes_example, raw_json, lora_name, lora_weight, appearance_tags, lorebook_json,
|
||||
avatar_path, alternate_greetings_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
"""INSERT INTO characters
|
||||
(card_id, name, description, personality, scenario, first_mes, mes_example,
|
||||
raw_json, lora_name, lora_weight, appearance_tags, appearance_prose, lorebook_json, avatar_path,
|
||||
alternate_greetings_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
card_id,
|
||||
card["card_id"],
|
||||
card["name"],
|
||||
card["description"],
|
||||
card["personality"],
|
||||
@@ -157,10 +161,11 @@ async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0
|
||||
card["raw_json"],
|
||||
lora_name,
|
||||
lora_weight,
|
||||
card.get("appearance_tags", ""),
|
||||
card["appearance_tags"],
|
||||
card.get("appearance_prose", ""),
|
||||
card["lorebook_json"],
|
||||
card.get("avatar_path", ""),
|
||||
alt_json,
|
||||
card.get("alternate_greetings_json", "[]"),
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
@@ -199,8 +204,8 @@ async def delete_character(card_id: str) -> bool:
|
||||
async def update_appearance_tags(card_id: str, appearance_tags: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE characters SET appearance_tags = ? WHERE card_id = ?",
|
||||
(appearance_tags, card_id),
|
||||
"UPDATE characters SET appearance_tags = ?, appearance_prose = ? WHERE card_id = ?",
|
||||
(appearance_tags, "", card_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@@ -228,7 +233,7 @@ async def preview_card_file(content: bytes, filename: str) -> dict:
|
||||
|
||||
async def update_character(card_id: str, fields: dict) -> bool:
|
||||
allowed = {"name", "description", "personality", "scenario", "first_mes",
|
||||
"mes_example", "appearance_tags", "lora_name", "lora_weight", "avatar_path",
|
||||
"mes_example", "appearance_tags", "appearance_prose", "lora_name", "lora_weight", "avatar_path",
|
||||
"alternate_greetings_json"}
|
||||
updates = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not updates:
|
||||
@@ -295,6 +300,7 @@ async def import_card_file(
|
||||
"lora_name": lora_name,
|
||||
"lora_weight": lora_weight,
|
||||
"appearance_tags": saved.get("appearance_tags", ""),
|
||||
"appearance_prose": saved.get("appearance_prose", ""),
|
||||
"avatar_path": saved.get("avatar_path", ""),
|
||||
"personality": saved.get("personality", ""),
|
||||
"scenario": saved.get("scenario", ""),
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from services.personas import get_persona
|
||||
from services.lorebook import get_lorebook_context
|
||||
from services.character_card import get_character
|
||||
|
||||
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
|
||||
|
||||
from services.rp_sanitize import ROLEPLAY_GUARDRAILS # noqa: E402 — re-export for imports
|
||||
|
||||
|
||||
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
|
||||
"""Static character prompt only (no RPG runtime blocks)."""
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona:
|
||||
return DEFAULT_PROMPT
|
||||
prompt = persona["prompt"]
|
||||
recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
|
||||
context = recent + [{"role": "user", "content": user_message}]
|
||||
if persona.get("lorebook_json"):
|
||||
lore = get_lorebook_context(persona.get("lorebook_json", "[]"), context)
|
||||
if lore:
|
||||
prompt += "\n\n" + lore
|
||||
if persona_id.startswith("card_"):
|
||||
card = await get_character(persona_id[5:])
|
||||
if card:
|
||||
lore = get_lorebook_context(card.get("lorebook_json", "[]"), context)
|
||||
if lore:
|
||||
prompt += "\n\n" + lore
|
||||
if persona_id != "default":
|
||||
prompt += "\n\n" + ROLEPLAY_GUARDRAILS
|
||||
return prompt
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Parse ComfyUI /object_info into usable model lists."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Node types whose combo inputs we expose in the debug UI
|
||||
_MODEL_NODES: dict[str, tuple[str, str]] = {
|
||||
"checkpoints": ("CheckpointLoaderSimple", "ckpt_name"),
|
||||
"unets": ("UNETLoader", "unet_name"),
|
||||
"clips": ("CLIPLoader", "clip_name"),
|
||||
"vaes": ("VAELoader", "vae_name"),
|
||||
"loras": ("LoraLoader", "lora_name"),
|
||||
}
|
||||
|
||||
|
||||
def _combo_options(node_def: dict, input_name: str) -> list[str]:
|
||||
if not isinstance(node_def, dict):
|
||||
return []
|
||||
required = (node_def.get("input") or {}).get("required") or {}
|
||||
optional = (node_def.get("input") or {}).get("optional") or {}
|
||||
spec = required.get(input_name) or optional.get(input_name)
|
||||
if not spec or not isinstance(spec, (list, tuple)):
|
||||
return []
|
||||
first = spec[0]
|
||||
if isinstance(first, list):
|
||||
return [str(x) for x in first]
|
||||
return []
|
||||
|
||||
|
||||
def parse_model_lists(object_info: dict) -> dict[str, list[str]]:
|
||||
out: dict[str, list[str]] = {}
|
||||
for key, (node_type, input_name) in _MODEL_NODES.items():
|
||||
node_def = object_info.get(node_type) or {}
|
||||
options = _combo_options(node_def, input_name)
|
||||
if options:
|
||||
out[key] = options
|
||||
return out
|
||||
|
||||
|
||||
def list_node_types(object_info: dict) -> list[str]:
|
||||
return sorted(k for k in object_info.keys() if isinstance(object_info.get(k), dict))
|
||||
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
|
||||
CHAT_CONTEXT_MAX = int(os.getenv("CHAT_CONTEXT_MAX", "128000"))
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
|
||||
return max(0, len(text or "") // 4)
|
||||
|
||||
|
||||
def compute_payload_usage(history: list, llm_system: str) -> dict:
|
||||
"""Estimate context fill for the payload messages_for_llm would send."""
|
||||
chars = len(llm_system or "")
|
||||
for m in history:
|
||||
if m.get("role") in ("user", "assistant"):
|
||||
chars += len(m.get("content") or "")
|
||||
tokens_est = chars // 4 if chars else 0
|
||||
max_tokens = CHAT_CONTEXT_MAX
|
||||
percent = round(100.0 * tokens_est / max_tokens, 1) if max_tokens else 0.0
|
||||
return {
|
||||
"chars": chars,
|
||||
"tokens_est": tokens_est,
|
||||
"max_tokens_est": max_tokens,
|
||||
"percent": percent,
|
||||
}
|
||||
|
||||
|
||||
def context_warning_line(percent: float) -> str:
|
||||
if percent <= 85:
|
||||
return ""
|
||||
return f"\n[Context: ~{int(percent)}% of budget — keep replies focused]"
|
||||
+119
-6
@@ -13,6 +13,8 @@ OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
|
||||
|
||||
CHAT_MODEL = os.getenv("CHAT_MODEL", "mistralai/mistral-nemo")
|
||||
SYSTEM_MODEL = os.getenv("SYSTEM_MODEL", "google/gemini-2.5-flash")
|
||||
# Softer model when primary returns content_filter / empty / API errors (default: CHAT_MODEL).
|
||||
LLM_FALLBACK_MODEL = (os.getenv("LLM_FALLBACK_MODEL") or "").strip() or CHAT_MODEL
|
||||
|
||||
HEADERS = {
|
||||
"Authorization": f"Bearer {OPENROUTER_KEY}",
|
||||
@@ -21,26 +23,128 @@ HEADERS = {
|
||||
}
|
||||
|
||||
|
||||
class LLMError(Exception):
|
||||
"""OpenRouter returned an error or an unexpected response shape."""
|
||||
|
||||
|
||||
def _parse_completion_body(data: dict) -> str:
|
||||
if not isinstance(data, dict):
|
||||
raise LLMError(f"Invalid API response: expected object, got {type(data).__name__}")
|
||||
|
||||
if data.get("error"):
|
||||
err = data["error"]
|
||||
if isinstance(err, dict):
|
||||
msg = err.get("message") or str(err)
|
||||
code = err.get("code")
|
||||
else:
|
||||
msg = str(err)
|
||||
code = None
|
||||
suffix = f" (code={code})" if code is not None else ""
|
||||
raise LLMError(f"OpenRouter error{suffix}: {msg}")
|
||||
|
||||
choices = data.get("choices")
|
||||
if not choices:
|
||||
preview = str(data)[:400]
|
||||
raise LLMError(f"OpenRouter response has no 'choices'. Body preview: {preview}")
|
||||
|
||||
first = choices[0] if isinstance(choices[0], dict) else {}
|
||||
message = first.get("message") or {}
|
||||
if not isinstance(message, dict):
|
||||
raise LLMError("OpenRouter choice has no message object")
|
||||
|
||||
finish = first.get("finish_reason") or ""
|
||||
native_finish = first.get("native_finish_reason") or ""
|
||||
blocked_reasons = {"content_filter", "safety", "moderation"}
|
||||
if finish in blocked_reasons or str(native_finish).upper() in (
|
||||
"PROHIBITED_CONTENT",
|
||||
"SAFETY",
|
||||
"BLOCKED",
|
||||
):
|
||||
raise LLMError(
|
||||
f"Content blocked by provider (finish_reason={finish}, native={native_finish})"
|
||||
)
|
||||
|
||||
content = message.get("content")
|
||||
if content is not None and str(content).strip():
|
||||
return str(content)
|
||||
|
||||
refusal = message.get("refusal")
|
||||
if refusal:
|
||||
raise LLMError(f"Model refused the request: {refusal}")
|
||||
|
||||
if finish and finish not in ("stop", "length", "tool_calls", "function_call"):
|
||||
raise LLMError(
|
||||
f"OpenRouter finished without content (finish_reason={finish}, native={native_finish})"
|
||||
)
|
||||
|
||||
raise LLMError("OpenRouter returned empty message content")
|
||||
|
||||
|
||||
def _clean(messages: list) -> list:
|
||||
"""Filter out messages with empty content."""
|
||||
return [m for m in messages if (m.get("content") or "").strip()]
|
||||
|
||||
|
||||
async def _post(model: str, messages: list, extra: dict | None = None) -> str:
|
||||
async def _post_once(model: str, messages: list, extra: dict | None = None) -> str:
|
||||
if not OPENROUTER_KEY:
|
||||
raise LLMError("ROUTER_KEY is not set in environment")
|
||||
|
||||
payload = {"model": model, "messages": _clean(messages), **(extra or {})}
|
||||
async with httpx.AsyncClient(timeout=90) as client:
|
||||
r = await client.post(OPENROUTER_URL, headers=HEADERS, json=payload)
|
||||
r.raise_for_status()
|
||||
return r.json()["choices"][0]["message"]["content"]
|
||||
try:
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
raise LLMError(f"Non-JSON response (HTTP {r.status_code}): {r.text[:300]}") from e
|
||||
|
||||
if r.status_code >= 400:
|
||||
try:
|
||||
_parse_completion_body(data)
|
||||
except LLMError:
|
||||
raise
|
||||
raise LLMError(f"HTTP {r.status_code}: {data}")
|
||||
|
||||
try:
|
||||
return _parse_completion_body(data)
|
||||
except LLMError:
|
||||
logger.warning(
|
||||
"OpenRouter completion failed model=%s status=%s body=%.500s",
|
||||
model,
|
||||
r.status_code,
|
||||
data,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def _post(model: str, messages: list, extra: dict | None = None) -> str:
|
||||
"""POST completion; on failure retries once with LLM_FALLBACK_MODEL (usually CHAT_MODEL)."""
|
||||
try:
|
||||
return await _post_once(model, messages, extra)
|
||||
except LLMError as primary_err:
|
||||
fallback = LLM_FALLBACK_MODEL
|
||||
if not fallback or fallback == model:
|
||||
raise
|
||||
logger.info(
|
||||
"LLM fallback: %s failed (%s) → retrying with %s",
|
||||
model,
|
||||
primary_err,
|
||||
fallback,
|
||||
)
|
||||
try:
|
||||
return await _post_once(fallback, messages, extra)
|
||||
except LLMError as fallback_err:
|
||||
raise LLMError(
|
||||
f"{primary_err} (fallback {fallback} also failed: {fallback_err})"
|
||||
) from fallback_err
|
||||
|
||||
|
||||
async def send_message(messages: list) -> str:
|
||||
"""System model — narrator, facts, SD prompt."""
|
||||
"""SYSTEM_MODEL with automatic fallback to LLM_FALLBACK_MODEL."""
|
||||
return await _post(SYSTEM_MODEL, messages)
|
||||
|
||||
|
||||
async def send_message_with_model(messages: list, model: str) -> str:
|
||||
"""Explicit model — plot arc, narrator override."""
|
||||
"""Named model (RPG_*, SD_*) with automatic fallback to LLM_FALLBACK_MODEL."""
|
||||
return await _post(model, messages)
|
||||
|
||||
|
||||
@@ -73,10 +177,19 @@ async def stream_message(messages: list):
|
||||
return
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
content = chunk["choices"][0]["delta"].get("content", "")
|
||||
if chunk.get("error"):
|
||||
err = chunk["error"]
|
||||
msg = err.get("message", err) if isinstance(err, dict) else err
|
||||
raise LLMError(f"OpenRouter stream error: {msg}")
|
||||
choices = chunk.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
content = (choices[0].get("delta") or {}).get("content", "")
|
||||
if content:
|
||||
chunk_count += 1
|
||||
yield content
|
||||
except LLMError:
|
||||
raise
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
|
||||
+308
-26
@@ -1,8 +1,11 @@
|
||||
import json
|
||||
|
||||
import aiosqlite
|
||||
from database.db import DB_PATH
|
||||
|
||||
|
||||
async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict:
|
||||
async def get_or_create_session(session_id: str, persona_id: str | None = None) -> dict:
|
||||
"""Existing sessions keep their persona_id; persona_id applies only on INSERT."""
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
@@ -13,9 +16,10 @@ async def get_or_create_session(session_id: str, persona_id: str = "default") ->
|
||||
if row:
|
||||
return dict(row)
|
||||
|
||||
pid = (persona_id or "default").strip() or "default"
|
||||
await db.execute(
|
||||
"INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)",
|
||||
(session_id, persona_id),
|
||||
(session_id, pid),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@@ -71,24 +75,104 @@ async def update_session_persona(session_id: str, persona_id: str):
|
||||
(persona_id, session_id),
|
||||
)
|
||||
|
||||
# If persona changed, reset RPG state bound to the persona/arc.
|
||||
if prev is not None and prev != persona_id:
|
||||
await db.execute(
|
||||
"""UPDATE sessions
|
||||
SET facts_json = '[]',
|
||||
global_plot = '',
|
||||
status_quo = '',
|
||||
plot_arc_json = '{}'
|
||||
WHERE session_id = ?""",
|
||||
(session_id,),
|
||||
)
|
||||
await db.execute(
|
||||
"DELETE FROM action_resolutions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await _reset_persona_bound_state(db, session_id)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _reset_persona_bound_state(db: aiosqlite.Connection, session_id: str) -> None:
|
||||
from services.rpg_state import DEFAULT_NARRATIVE_STATS
|
||||
|
||||
stats_default = json.dumps(DEFAULT_NARRATIVE_STATS, ensure_ascii=False)
|
||||
await db.execute(
|
||||
"""UPDATE sessions
|
||||
SET facts_json = '[]',
|
||||
global_plot = '',
|
||||
status_quo = '',
|
||||
plot_arc_json = '{}',
|
||||
outfit_json = '[]',
|
||||
affinity = 0,
|
||||
scene_json = '{}',
|
||||
narrative_stats_json = ?
|
||||
WHERE session_id = ?""",
|
||||
(stats_default, session_id),
|
||||
)
|
||||
await db.execute("DELETE FROM action_resolutions WHERE session_id = ?", (session_id,))
|
||||
await db.execute("DELETE FROM rpg_quests WHERE session_id = ?", (session_id,))
|
||||
|
||||
|
||||
async def upsert_static_system_message(
|
||||
session_id: str, static_prompt: str, history: list | None = None
|
||||
) -> bool:
|
||||
"""Store only static persona prompt in messages. Returns True if written."""
|
||||
hist = history if history is not None else await get_history(session_id)
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
if not hist:
|
||||
await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, 'system', ?, NULL, NULL)""",
|
||||
(session_id, static_prompt),
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
if hist[0]["role"] == "system":
|
||||
if hist[0]["content"] == static_prompt:
|
||||
return False
|
||||
await db.execute(
|
||||
"""UPDATE messages SET content = ?
|
||||
WHERE session_id = ? AND role = 'system'
|
||||
AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
|
||||
(static_prompt, session_id, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, 'system', ?, NULL, NULL)""",
|
||||
(session_id, static_prompt),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def delete_dialog_messages(session_id: str) -> None:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"DELETE FROM messages WHERE session_id = ? AND role IN ('user', 'assistant')",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def rebind_session_persona(
|
||||
session_id: str,
|
||||
persona_id: str,
|
||||
*,
|
||||
clear_history: bool = False,
|
||||
static_prompt: str,
|
||||
first_mes: str | None = None,
|
||||
) -> None:
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
raise ValueError("Session not found")
|
||||
|
||||
await update_session_persona(session_id, persona_id)
|
||||
if clear_history:
|
||||
await delete_dialog_messages(session_id)
|
||||
|
||||
history = await get_history(session_id)
|
||||
await upsert_static_system_message(session_id, static_prompt, history)
|
||||
|
||||
if clear_history and first_mes and first_mes.strip():
|
||||
await add_message(session_id, "assistant", first_mes.strip())
|
||||
|
||||
|
||||
async def update_session_rpg(session_id: str, rpg_enabled: bool):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
@@ -174,25 +258,116 @@ async def delete_session(session_id: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_history(session_id: str) -> list:
|
||||
async def get_action_resolutions_map(session_id: str) -> dict[int, dict]:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"""SELECT id, role, content, image_prompt, image_path
|
||||
"""SELECT message_id, intent_text, roll, outcome, resolution_text
|
||||
FROM action_resolutions
|
||||
WHERE session_id = ? AND message_id IS NOT NULL
|
||||
ORDER BY id""",
|
||||
(session_id,),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
out: dict[int, dict] = {}
|
||||
for r in rows:
|
||||
mid = r["message_id"]
|
||||
if mid is not None:
|
||||
out[int(mid)] = {
|
||||
"intent_text": r["intent_text"],
|
||||
"roll": r["roll"],
|
||||
"outcome": r["outcome"],
|
||||
"resolution_text": r["resolution_text"],
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def narrator_message_content(narrator: dict) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"roll": narrator.get("roll"),
|
||||
"outcome": narrator.get("outcome"),
|
||||
"text": narrator.get("text", ""),
|
||||
"original_intent": narrator.get("original_intent"),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def parse_narrator_message(content: str) -> dict | None:
|
||||
try:
|
||||
data = json.loads(content or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
if not isinstance(data, dict) or not (data.get("text") or "").strip():
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
async def seed_quests_from_arc(session_id: str, arc: dict) -> int:
|
||||
"""Create active quests for arc beats that are not already in rpg_quests."""
|
||||
if not arc:
|
||||
return 0
|
||||
existing = {q["title"] for q in await get_quests(session_id)}
|
||||
added = 0
|
||||
for beat in arc.get("beats", []):
|
||||
title = (beat.get("title") or beat.get("injection", "")).strip()[:120]
|
||||
if title and title not in existing:
|
||||
await upsert_quest(session_id, title, "active")
|
||||
existing.add(title)
|
||||
added += 1
|
||||
return added
|
||||
|
||||
|
||||
async def get_history(session_id: str) -> list:
|
||||
resolutions = await get_action_resolutions_map(session_id)
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"""SELECT id, role, content, image_prompt, image_path,
|
||||
image_prompt_alt, image_path_alt, choices_json
|
||||
FROM messages WHERE session_id = ? ORDER BY id""",
|
||||
(session_id,),
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
{
|
||||
result = []
|
||||
for idx, r in enumerate(rows):
|
||||
item = {
|
||||
"id": r["id"],
|
||||
"role": r["role"],
|
||||
"content": r["content"],
|
||||
"image_prompt": r["image_prompt"],
|
||||
"image_path": r["image_path"],
|
||||
"image_prompt_alt": r["image_prompt_alt"],
|
||||
"image_path_alt": r["image_path_alt"],
|
||||
"choices_json": r["choices_json"],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
if r["role"] == "user" and r["id"] in resolutions:
|
||||
item["action_resolution"] = resolutions[r["id"]]
|
||||
result.append(item)
|
||||
if r["role"] == "user" and r["id"] in resolutions:
|
||||
nxt = rows[idx + 1] if idx + 1 < len(rows) else None
|
||||
if not nxt or nxt["role"] != "narrator":
|
||||
res = resolutions[r["id"]]
|
||||
result.append(
|
||||
{
|
||||
"id": -int(r["id"]),
|
||||
"role": "narrator",
|
||||
"content": narrator_message_content(
|
||||
{
|
||||
"roll": res.get("roll"),
|
||||
"outcome": res.get("outcome"),
|
||||
"text": res.get("resolution_text", ""),
|
||||
}
|
||||
),
|
||||
"image_prompt": None,
|
||||
"image_path": None,
|
||||
"image_prompt_alt": None,
|
||||
"image_path_alt": None,
|
||||
"choices_json": None,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def get_message(message_id: int) -> dict | None:
|
||||
@@ -230,6 +405,38 @@ async def delete_message(message_id: int):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def delete_message_and_following(session_id: str, message_id: int) -> bool:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"DELETE FROM messages WHERE session_id = ? AND id >= ?",
|
||||
(session_id, message_id),
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def update_message_choices(message_id: int, choices_json: str | None):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET choices_json = ? WHERE id = ?",
|
||||
(choices_json, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def clear_choices_for_session(session_id: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET choices_json = NULL WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_last_message_preview(session_id: str, max_len: int = 80) -> str:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
@@ -261,8 +468,9 @@ async def fork_session(source_session_id: str, until_message_id: int) -> str | N
|
||||
await db.execute(
|
||||
"""INSERT INTO sessions
|
||||
(session_id, persona_id, title, rpg_enabled, facts_json, global_plot,
|
||||
status_quo, plot_arc_json, genre, rpg_settings_json, affinity)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
status_quo, plot_arc_json, genre, rpg_settings_json, affinity,
|
||||
outfit_json, scene_json, narrative_stats_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
new_id,
|
||||
source["persona_id"],
|
||||
@@ -275,6 +483,9 @@ async def fork_session(source_session_id: str, until_message_id: int) -> str | N
|
||||
source.get("genre", "adventure"),
|
||||
source.get("rpg_settings_json", "{}"),
|
||||
source.get("affinity", 0),
|
||||
source.get("outfit_json", "[]"),
|
||||
source.get("scene_json", "{}"),
|
||||
source.get("narrative_stats_json", '{"lust":0,"stamina":10,"tension":0}'),
|
||||
),
|
||||
)
|
||||
async with db.execute(
|
||||
@@ -309,18 +520,20 @@ async def add_message(
|
||||
content: str,
|
||||
image_prompt: str | None = None,
|
||||
image_path: str | None = None,
|
||||
):
|
||||
) -> int:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
cur = await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(session_id, role, content, image_prompt, image_path),
|
||||
)
|
||||
msg_id = cur.lastrowid
|
||||
await db.execute(
|
||||
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return msg_id
|
||||
|
||||
|
||||
async def update_message_image(message_id: int, image_path: str):
|
||||
@@ -332,6 +545,33 @@ async def update_message_image(message_id: int, image_path: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_prompt(message_id: int, image_prompt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_prompt = ? WHERE id = ?",
|
||||
(image_prompt, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_prompt_alt(message_id: int, image_prompt_alt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_prompt_alt = ? WHERE id = ?",
|
||||
(image_prompt_alt, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_image_alt(message_id: int, image_path_alt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_path_alt = ? WHERE id = ?",
|
||||
(image_path_alt, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_last_assistant_message_id(session_id: str) -> int | None:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
@@ -362,6 +602,18 @@ async def update_session_affinity(session_id: str, delta: int):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def set_session_affinity(session_id: str, value: int):
|
||||
"""Debug / admin: set absolute affinity (-30..30)."""
|
||||
clamped = max(-30, min(30, int(value)))
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET affinity = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(clamped, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
return clamped
|
||||
|
||||
|
||||
async def update_session_genre(session_id: str, genre: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
@@ -389,6 +641,24 @@ async def update_session_outfit(session_id: str, outfit_json: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_session_scene(session_id: str, scene_json: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET scene_json = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(scene_json, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_session_narrative_stats(session_id: str, stats_json: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET narrative_stats_json = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(stats_json, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def upsert_quest(session_id: str, title: str, status: str = "active"):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
async with db.execute(
|
||||
@@ -429,6 +699,18 @@ async def update_quest_status(session_id: str, title: str, status: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_quest_by_id(quest_id: int, session_id: str, status: str) -> bool:
|
||||
if status not in ("active", "done", "failed"):
|
||||
return False
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
cur = await db.execute(
|
||||
"UPDATE rpg_quests SET status = ? WHERE id = ? AND session_id = ?",
|
||||
(status, quest_id, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
async def get_message_count(session_id: str) -> int:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from services.memory import (
|
||||
get_history,
|
||||
get_session,
|
||||
get_last_assistant_message_id,
|
||||
update_session_plot_arc,
|
||||
update_message_choices,
|
||||
seed_quests_from_arc,
|
||||
get_quests,
|
||||
)
|
||||
from services.rpg_state import apply_narrator_post
|
||||
from services.personas import get_persona
|
||||
from services.rpg_facts import facts_to_prompt
|
||||
from services.rpg_plot import generate_plot_arc, choices_from_narrator
|
||||
from services.rpg_context import format_narrator_context
|
||||
from services.rpg_narrator import narrator_post
|
||||
from services.sd_prompt import generate_sd_prompt
|
||||
from services.sd_images import run_sd_for_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RPG_SETTINGS = {
|
||||
"dice": True,
|
||||
"narrator": True,
|
||||
"quests": True,
|
||||
"affinity": True,
|
||||
"choices": True,
|
||||
"stats": False,
|
||||
}
|
||||
|
||||
|
||||
def get_rpg_settings(session: dict) -> dict:
|
||||
try:
|
||||
return {**DEFAULT_RPG_SETTINGS, **json.loads(session.get("rpg_settings_json") or "{}")}
|
||||
except Exception:
|
||||
return DEFAULT_RPG_SETTINGS
|
||||
|
||||
|
||||
async def resolve_greeting(session_id: str, persona: dict) -> str:
|
||||
history = await get_history(session_id)
|
||||
for m in reversed(history):
|
||||
if m.get("role") == "assistant" and (m.get("content") or "").strip():
|
||||
return m["content"].strip()
|
||||
return (persona.get("first_mes") or "").strip()
|
||||
|
||||
|
||||
async def ensure_plot_arc_and_quests(
|
||||
session_id: str,
|
||||
persona: dict,
|
||||
greeting: str,
|
||||
genre: str,
|
||||
*,
|
||||
seed_quests: bool = True,
|
||||
) -> dict:
|
||||
session = await get_session(session_id) or {}
|
||||
arc_json = session.get("plot_arc_json") or "{}"
|
||||
try:
|
||||
arc = json.loads(arc_json) if isinstance(arc_json, str) else {}
|
||||
except Exception:
|
||||
arc = {}
|
||||
|
||||
if arc:
|
||||
return arc
|
||||
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
arc = await generate_plot_arc(
|
||||
persona.get("name", "Character"),
|
||||
persona.get("description", ""),
|
||||
persona.get("scenario", ""),
|
||||
greeting,
|
||||
facts_block=facts_block,
|
||||
genre=genre,
|
||||
)
|
||||
if not arc:
|
||||
return {}
|
||||
|
||||
await update_session_plot_arc(session_id, json.dumps(arc, ensure_ascii=False))
|
||||
if seed_quests:
|
||||
await seed_quests_from_arc(session_id, arc)
|
||||
return arc
|
||||
|
||||
|
||||
async def process_opening(session_id: str, persona_id: str, *, rpg: bool) -> dict:
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
raise ValueError("Session not found")
|
||||
|
||||
history = await get_history(session_id)
|
||||
assistant_msgs = [m for m in history if m.get("role") == "assistant"]
|
||||
if not assistant_msgs:
|
||||
raise ValueError("No assistant message (first_mes) found")
|
||||
|
||||
first_mes_text = assistant_msgs[-1].get("content", "").strip()
|
||||
if not first_mes_text:
|
||||
raise ValueError("Empty first_mes")
|
||||
|
||||
msg_id = await get_last_assistant_message_id(session_id)
|
||||
persona = await get_persona(persona_id) or {}
|
||||
rpg_settings = get_rpg_settings(session)
|
||||
|
||||
arc: dict = {}
|
||||
choices: list = []
|
||||
status_quo = session.get("status_quo") or ""
|
||||
outfit_json = session.get("outfit_json") or "[]"
|
||||
|
||||
if rpg:
|
||||
genre = session.get("genre") or "adventure"
|
||||
arc = await ensure_plot_arc_and_quests(
|
||||
session_id,
|
||||
persona,
|
||||
first_mes_text,
|
||||
genre,
|
||||
seed_quests=rpg_settings.get("quests", True),
|
||||
)
|
||||
|
||||
session = await get_session(session_id) or session
|
||||
ctx_txt = f"assistant: {first_mes_text}"
|
||||
arc_json = json.dumps(arc, ensure_ascii=False) if arc else ""
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
|
||||
quests_pre = await get_quests(session_id)
|
||||
narr_ctx = format_narrator_context(arc, quests_pre, session.get("status_quo") or "")
|
||||
post = await narrator_post(
|
||||
persona.get("name", persona_id),
|
||||
ctx_txt,
|
||||
arc_json,
|
||||
facts_block,
|
||||
is_opening=True,
|
||||
extra_context=narr_ctx,
|
||||
)
|
||||
|
||||
if rpg_settings.get("choices", True):
|
||||
choices = choices_from_narrator(post.get("choices") or [])
|
||||
|
||||
await apply_narrator_post(session_id, post, rpg_settings, session)
|
||||
session = await get_session(session_id) or session
|
||||
status_quo = session.get("status_quo") or status_quo
|
||||
outfit_json = session.get("outfit_json") or outfit_json
|
||||
|
||||
quests = await get_quests(session_id)
|
||||
messages = await get_history(session_id)
|
||||
bundle = await generate_sd_prompt(
|
||||
messages,
|
||||
persona_id,
|
||||
outfit_json=outfit_json,
|
||||
scene_json=session.get("scene_json", "{}") if session else "{}",
|
||||
)
|
||||
sd_out = await run_sd_for_message(bundle, msg_id) if bundle else {}
|
||||
|
||||
updated = await get_session(session_id)
|
||||
affinity = updated.get("affinity", 0) if updated else 0
|
||||
|
||||
if msg_id and choices:
|
||||
await update_message_choices(
|
||||
msg_id, json.dumps(choices, ensure_ascii=False)
|
||||
)
|
||||
|
||||
return {
|
||||
"plot_arc": arc,
|
||||
"quests": quests,
|
||||
"outfit_json": outfit_json,
|
||||
"status_quo": status_quo,
|
||||
"choices": choices,
|
||||
"image_prompt": sd_out.get("image_prompt"),
|
||||
"image_prompt_alt": sd_out.get("image_prompt_alt"),
|
||||
"image_path": sd_out.get("image_path"),
|
||||
"image_path_alt": sd_out.get("image_path_alt"),
|
||||
"image_error": sd_out.get("image_error"),
|
||||
"image_error_alt": sd_out.get("image_error_alt"),
|
||||
"affinity": affinity,
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
"""Danbooru-style outfit tags with color enrichment for stable SD prompts."""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
COLOR_TOKENS = frozenset({
|
||||
"white", "black", "red", "blue", "green", "yellow", "pink", "purple",
|
||||
"orange", "brown", "gray", "grey", "silver", "gold", "beige", "navy",
|
||||
"blonde", "dark", "light", "cyan", "teal", "maroon", "crimson",
|
||||
})
|
||||
|
||||
# garment substring -> default color when tag has no color word
|
||||
_GARMENT_DEFAULT_COLOR: list[tuple[str, str]] = [
|
||||
("championship_belt", "gold"),
|
||||
("belt_collar", "gold"),
|
||||
("belt", "brown"),
|
||||
("sports_shorts", "black"),
|
||||
("shorts", "black"),
|
||||
("torn_tank_top", "white"),
|
||||
("tank_top", "white"),
|
||||
("crop_top", "white"),
|
||||
("t_shirt", "white"),
|
||||
("shirt", "white"),
|
||||
("dress", "black"),
|
||||
("skirt", "black"),
|
||||
("jeans", "blue"),
|
||||
("pants", "black"),
|
||||
("hoodie", "gray"),
|
||||
("jacket", "black"),
|
||||
("ribbon", "red"),
|
||||
("collar", "black"),
|
||||
("boots", "black"),
|
||||
("socks", "white"),
|
||||
]
|
||||
|
||||
|
||||
def _clean_tag(raw: str) -> str:
|
||||
t = (raw or "").strip().lower()
|
||||
t = re.sub(r"[^\w]+", "_", t)
|
||||
t = re.sub(r"_+", "_", t).strip("_")
|
||||
return t
|
||||
|
||||
|
||||
def tag_has_color(tag: str) -> bool:
|
||||
parts = tag.lower().split("_")
|
||||
return any(p in COLOR_TOKENS for p in parts)
|
||||
|
||||
|
||||
def enrich_outfit_tag(tag: str) -> str:
|
||||
"""Add a color prefix when the tag names clothing but omits color."""
|
||||
t = _clean_tag(tag)
|
||||
if not t or tag_has_color(t):
|
||||
return t
|
||||
for needle, color in _GARMENT_DEFAULT_COLOR:
|
||||
if needle in t:
|
||||
return f"{color}_{t}"
|
||||
return t
|
||||
|
||||
|
||||
def normalize_outfit_list(raw: list | None) -> list[str]:
|
||||
out: list[str] = []
|
||||
seen: set[str] = set()
|
||||
if not isinstance(raw, list):
|
||||
return out
|
||||
for item in raw:
|
||||
if isinstance(item, str):
|
||||
t = enrich_outfit_tag(item)
|
||||
elif isinstance(item, dict):
|
||||
label = (item.get("tag") or item.get("label") or "").strip()
|
||||
color = (item.get("color") or "").strip().lower()
|
||||
base = _clean_tag(label)
|
||||
if not base:
|
||||
continue
|
||||
t = f"{color}_{base}" if color and not tag_has_color(base) else enrich_outfit_tag(base)
|
||||
else:
|
||||
continue
|
||||
if t and t not in seen:
|
||||
seen.add(t)
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
|
||||
def outfit_list_to_json(tags: list[str]) -> str:
|
||||
return json.dumps(normalize_outfit_list(tags), ensure_ascii=False)
|
||||
|
||||
|
||||
def parse_and_normalize_outfit_json(raw: str | None) -> str:
|
||||
try:
|
||||
data = json.loads(raw or "[]")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = []
|
||||
if isinstance(data, str):
|
||||
data = [x.strip() for x in data.split(",") if x.strip()]
|
||||
return outfit_list_to_json(data if isinstance(data, list) else [])
|
||||
+24
-4
@@ -63,6 +63,7 @@ def _row_to_persona(row: dict) -> dict:
|
||||
"lora_name": row["lora_name"] or "",
|
||||
"lora_weight": row["lora_weight"] if row["lora_weight"] is not None else 0.8,
|
||||
"appearance_tags": row["appearance_tags"] or "",
|
||||
"appearance_prose": row.get("appearance_prose", "") or "",
|
||||
"personality": row.get("personality", "") or "",
|
||||
"scenario": row.get("scenario", "") or "",
|
||||
"first_mes": row.get("first_mes", "") or "",
|
||||
@@ -84,6 +85,9 @@ def build_persona_prompt(data: dict) -> str:
|
||||
if ex:
|
||||
parts.append(f"Example dialogue:\n{ex}")
|
||||
parts.append("Stay in character. Reply as the character. Do not add image tags.")
|
||||
from services.chat_prompt import ROLEPLAY_GUARDRAILS
|
||||
|
||||
parts.append(ROLEPLAY_GUARDRAILS)
|
||||
return "\n\n".join(p for p in parts if p and p.split(": ", 1)[-1].strip())
|
||||
|
||||
|
||||
@@ -117,6 +121,7 @@ async def create_persona(
|
||||
lora_name: str = "",
|
||||
lora_weight: float = 0.8,
|
||||
appearance_tags: str = "",
|
||||
appearance_prose: str = "",
|
||||
personality: str = "",
|
||||
scenario: str = "",
|
||||
first_mes: str = "",
|
||||
@@ -138,19 +143,19 @@ async def create_persona(
|
||||
await db.execute(
|
||||
"""INSERT INTO personas
|
||||
(persona_id, name, emoji, description, prompt, custom,
|
||||
sd_enabled, lora_name, lora_weight, appearance_tags,
|
||||
sd_enabled, lora_name, lora_weight, appearance_tags, appearance_prose,
|
||||
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
|
||||
alternate_greetings_json)
|
||||
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
persona_id, name, emoji, description, final_prompt,
|
||||
1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags,
|
||||
1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags, appearance_prose,
|
||||
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
|
||||
alternate_greetings_json,
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
return {
|
||||
return {
|
||||
"name": name,
|
||||
"emoji": emoji,
|
||||
"description": description,
|
||||
@@ -160,6 +165,7 @@ async def create_persona(
|
||||
"lora_name": lora_name,
|
||||
"lora_weight": lora_weight,
|
||||
"appearance_tags": appearance_tags,
|
||||
"appearance_prose": appearance_prose,
|
||||
"personality": personality,
|
||||
"scenario": scenario,
|
||||
"first_mes": first_mes,
|
||||
@@ -226,6 +232,7 @@ async def patch_persona(persona_id: str, fields: dict) -> bool:
|
||||
"lora_name",
|
||||
"lora_weight",
|
||||
"appearance_tags",
|
||||
"appearance_prose",
|
||||
"personality",
|
||||
"scenario",
|
||||
"first_mes",
|
||||
@@ -255,6 +262,19 @@ async def patch_persona(persona_id: str, fields: dict) -> bool:
|
||||
merged = dict(existing)
|
||||
merged.update(updates)
|
||||
updates["prompt"] = build_persona_prompt(merged)
|
||||
|
||||
if "appearance_tags" in updates and "appearance_prose" not in updates:
|
||||
tags = updates["appearance_tags"].strip()
|
||||
if tags:
|
||||
from services.llm import send_message
|
||||
try:
|
||||
prose = await send_message([
|
||||
{"role": "system", "content": "Convert danbooru tags to natural English description. Output only the description, no markdown."},
|
||||
{"role": "user", "content": f"Tags: {tags}"}
|
||||
])
|
||||
updates["appearance_prose"] = prose.strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cols = ", ".join(f"{k} = ?" for k in updates)
|
||||
cur2 = await db.execute(
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Roleplay output guardrails and OOC stripping."""
|
||||
|
||||
import re
|
||||
|
||||
ROLEPLAY_GUARDRAILS = (
|
||||
"[In-character rules — breaking these ruins immersion]\n"
|
||||
"- Reply ONLY as the character in the present moment: spoken lines, visible actions, "
|
||||
"thoughts they would actually show.\n"
|
||||
"- NEVER write: P.S., PS, postscripts, footnotes, section headers, "
|
||||
'"Статус кво", "Status quo", "To be continued", scene summaries, '
|
||||
"editorial closings, or foreshadowing asides "
|
||||
'(e.g. "Когда вы выйдете...", "Ты уже знаешь правду...").\n'
|
||||
"- Do NOT explain subtext to the reader or predict future scenes. No narrator voice.\n"
|
||||
"- End inside the scene; do not wrap up with meta commentary.\n"
|
||||
"- Sections marked MANDATORY (relationship, state) are binding — obey without citing them."
|
||||
)
|
||||
|
||||
RP_OUTPUT_REMINDER = (
|
||||
"\n\n--- Reply format (MANDATORY) ---\n"
|
||||
"Next message = in-character only. "
|
||||
"Forbidden in output: P.S., Статус кво, Status quo, author notes, summaries, footnotes.\n"
|
||||
"---"
|
||||
)
|
||||
|
||||
|
||||
def status_quo_prompt_block(status_quo: str) -> str:
|
||||
sq = (status_quo or "").strip()
|
||||
if not sq:
|
||||
return ""
|
||||
return (
|
||||
"\n\n--- Current situation (INTERNAL — player never sees this block) ---\n"
|
||||
+ sq
|
||||
+ "\nBackground truth for you only. "
|
||||
"Never echo this header, never open/close replies with 'Статус кво' or summaries. "
|
||||
"Show the situation through dialogue and action.\n---"
|
||||
)
|
||||
|
||||
|
||||
_OOC_PARA_START = re.compile(
|
||||
r"^(?:"
|
||||
r"Статус\s*кво|Status\s*quo|"
|
||||
r"P\.?\s*S\.?|PS:|"
|
||||
r"Примечание:|Author'?s?\s*note:|"
|
||||
r"OOC:|\\[OOC\\]|"
|
||||
r"To be continued|Продолжение следует"
|
||||
r")\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def strip_ooc_from_reply(text: str) -> str:
|
||||
"""Remove common OOC tails (P.S., Статус кво paragraphs, etc.)."""
|
||||
if not text or not text.strip():
|
||||
return text or ""
|
||||
|
||||
out = text.rstrip()
|
||||
|
||||
# Drop trailing P.S. block (often last paragraph).
|
||||
out = re.sub(
|
||||
r"(?is)\n\s*P\.?\s*S\.?\s*[:.\-—].*$",
|
||||
"",
|
||||
out,
|
||||
).rstrip()
|
||||
|
||||
parts = re.split(r"\n\s*\n", out)
|
||||
kept: list[str] = []
|
||||
for part in parts:
|
||||
lines = [ln for ln in part.splitlines() if ln.strip()]
|
||||
if not lines:
|
||||
continue
|
||||
if _OOC_PARA_START.match(lines[0].strip()):
|
||||
continue
|
||||
kept.append(part)
|
||||
|
||||
if not kept:
|
||||
return ""
|
||||
return "\n\n".join(kept).strip()
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Shared context blocks for RPG narrator / plot LLM calls."""
|
||||
|
||||
from services.rpg_plot import count_active_quests
|
||||
|
||||
|
||||
def format_narrator_context(
|
||||
arc: dict | None,
|
||||
quests: list | None,
|
||||
status_quo: str = "",
|
||||
) -> str:
|
||||
parts: list[str] = []
|
||||
arc = arc or {}
|
||||
beats = arc.get("beats") or []
|
||||
if not isinstance(beats, list):
|
||||
beats = []
|
||||
|
||||
parts.append(f"Plot phase: {arc.get('phase', 'opening')}. Scripted beats left: {len(beats)}.")
|
||||
if not beats:
|
||||
parts.append(
|
||||
"IMPORTANT: Scripted beats are EXHAUSTED (quests may already be done). "
|
||||
"The story must CONTINUE — do not stall. "
|
||||
"Always return 2-4 meaningful choices for the player's next actions. "
|
||||
"You may add quest_updates with status 'active' for NEW optional threads. "
|
||||
"Do NOT re-activate quests the player already completed unless they explicitly revisit that thread."
|
||||
)
|
||||
elif count_active_quests(quests) == 0:
|
||||
pending = [
|
||||
(b.get("title") or b.get("id") or "beat")
|
||||
for b in beats[:3]
|
||||
if isinstance(b, dict)
|
||||
]
|
||||
parts.append(
|
||||
"IMPORTANT: No active quests but scripted beats remain — arc was likely desynced. "
|
||||
"The engine will inject the next beat; prefer choices that fit pending beats: "
|
||||
+ ", ".join(pending)
|
||||
+ ". Do NOT treat the arc as finished."
|
||||
)
|
||||
hint = (arc.get("next_beat_hint") or "").strip()
|
||||
if hint:
|
||||
parts.append(f"Arc hint: {hint}")
|
||||
|
||||
if quests:
|
||||
parts.append("Quest log:")
|
||||
for q in quests:
|
||||
parts.append(f" [{q.get('status', 'active')}] {q.get('title', '')}")
|
||||
else:
|
||||
parts.append("Quest log: (empty)")
|
||||
|
||||
sq = (status_quo or "").strip()
|
||||
if sq:
|
||||
parts.append(f"Status quo: {sq[:400]}")
|
||||
|
||||
return "\n".join(parts)
|
||||
+340
-46
@@ -1,76 +1,370 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
from services.llm import send_message_with_model, send_message
|
||||
from services.llm import LLMError, send_message_with_model, send_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FACTS_MODEL = os.getenv("RPG_FACTS_MODEL", "").strip() or "deepseek/deepseek-chat-v3"
|
||||
FACTS_STORE_LIMIT = int(os.getenv("FACTS_STORE_LIMIT", "100"))
|
||||
FACTS_PROMPT_MAX = int(os.getenv("FACTS_PROMPT_MAX", "40"))
|
||||
FACTS_DEDUP_THRESHOLD = int(os.getenv("FACTS_DEDUP_THRESHOLD", "30"))
|
||||
FACTS_COMPRESS_TARGET = int(os.getenv("FACTS_COMPRESS_TARGET", "22"))
|
||||
|
||||
FACTS_SYSTEM = """Extract NEW stable facts from the conversation.
|
||||
Return ONLY valid JSON (no markdown), as an array of objects:
|
||||
[{"text": "short durable fact", "rp_day": "when this became true in story time"}]
|
||||
|
||||
FACTS_SYSTEM = """Extract stable facts from the conversation.
|
||||
Return ONLY valid JSON (no markdown), as an array of short strings.
|
||||
Rules:
|
||||
- Facts must be durable (names, relations, inventory, locations, world rules).
|
||||
- Do not include ephemeral actions unless they change state.
|
||||
- Avoid duplicates.
|
||||
- Keep each fact <= 120 chars.
|
||||
Example output:
|
||||
["User name is Alex", "We are in a ruined castle", "NPC Mira distrusts the user"]"""
|
||||
- Return at most 5 NEW facts per turn. If nothing new, return [].
|
||||
- Do NOT repeat or rephrase facts already listed under "Already known".
|
||||
- Facts must be durable (names, relations, inventory, locations, lasting world state).
|
||||
- Skip momentary emotions unless they permanently change a relationship.
|
||||
- text <= 120 chars each.
|
||||
- rp_day: in-world time label (день 1, второй день, та же ночь, через год). Use RP time hint when unclear."""
|
||||
|
||||
FACTS_COMPRESS_SYSTEM = """You consolidate RPG session memory for a long-running chat.
|
||||
Return ONLY valid JSON (no markdown): an array of {"text": "...", "rp_day": "..."}.
|
||||
|
||||
Goals:
|
||||
- Aggressively MERGE near-duplicates (same topic in RU/EN, Rin/Рин, Grigo/Григорий).
|
||||
- Keep ONE best fact per topic; combine rp_day if needed (e.g. "день 1–2").
|
||||
- DROP redundant, trivial, or superseded facts.
|
||||
- Keep: names, relationships, key locations, lasting magic/rules, inventory, unresolved threads.
|
||||
- Target at most {target} facts (fewer is better). Each text <= 120 chars.
|
||||
- rp_day = in-world labels only."""
|
||||
|
||||
_NAME_ALIASES = (
|
||||
("grigoriy", "григорий"),
|
||||
("grigo", "григо"),
|
||||
("grigory", "григорий"),
|
||||
("rin", "рин"),
|
||||
("player", "игрок"),
|
||||
("user", "игрок"),
|
||||
("glade", "полян"),
|
||||
("flowers", "цвет"),
|
||||
("flower", "цвет"),
|
||||
("magical", "волшеб"),
|
||||
("magic", "волшеб"),
|
||||
("glow", "свет"),
|
||||
("glowing", "свет"),
|
||||
)
|
||||
|
||||
|
||||
def merge_facts(existing_json: str, new_facts: list[str], limit: int = 80) -> str:
|
||||
def parse_fact_entry(raw) -> dict | None:
|
||||
if isinstance(raw, dict):
|
||||
text = (raw.get("text") or raw.get("fact") or "").strip()
|
||||
rp_day = (raw.get("rp_day") or raw.get("learned") or raw.get("day") or "").strip()
|
||||
elif isinstance(raw, str):
|
||||
text = raw.strip()
|
||||
rp_day = ""
|
||||
else:
|
||||
return None
|
||||
if not text:
|
||||
return None
|
||||
return {"text": text[:120], "rp_day": rp_day[:80]}
|
||||
|
||||
|
||||
def parse_facts_list(facts_json: str | None) -> list[dict]:
|
||||
try:
|
||||
existing = json.loads(existing_json or "[]")
|
||||
if not isinstance(existing, list):
|
||||
existing = []
|
||||
data = json.loads(facts_json or "[]")
|
||||
except json.JSONDecodeError:
|
||||
existing = []
|
||||
|
||||
seen = {str(x).strip() for x in existing if str(x).strip()}
|
||||
merged = [str(x).strip() for x in existing if str(x).strip()]
|
||||
for f in new_facts:
|
||||
s = str(f).strip()
|
||||
if not s or s in seen:
|
||||
return []
|
||||
if not isinstance(data, list):
|
||||
return []
|
||||
out: list[dict] = []
|
||||
seen: set[str] = set()
|
||||
for item in data:
|
||||
entry = parse_fact_entry(item)
|
||||
if not entry:
|
||||
continue
|
||||
seen.add(s)
|
||||
merged.append(s)
|
||||
|
||||
if len(merged) > limit:
|
||||
merged = merged[-limit:]
|
||||
return json.dumps(merged, ensure_ascii=False)
|
||||
key = entry["text"].lower()
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(entry)
|
||||
return out
|
||||
|
||||
|
||||
async def extract_facts(context_messages: list[dict]) -> list[str]:
|
||||
# Build a compact transcript
|
||||
def facts_list_to_json(facts: list[dict]) -> str:
|
||||
return json.dumps(
|
||||
[{"text": f["text"], "rp_day": f.get("rp_day", "")} for f in facts],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def rp_day_from_scene(scene_json: str | None) -> str:
|
||||
try:
|
||||
scene = json.loads(scene_json or "{}")
|
||||
if isinstance(scene, dict):
|
||||
day = (scene.get("day") or "").strip()
|
||||
if day:
|
||||
return day[:80]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def _normalize_fact_text(text: str) -> str:
|
||||
t = (text or "").lower()
|
||||
for a, b in _NAME_ALIASES:
|
||||
t = t.replace(a, b)
|
||||
t = re.sub(r"[^\w\s]", " ", t, flags=re.UNICODE)
|
||||
return re.sub(r"\s+", " ", t).strip()
|
||||
|
||||
|
||||
def _fact_tokens(text: str) -> set[str]:
|
||||
words = re.findall(r"[\w]+", _normalize_fact_text(text), flags=re.UNICODE)
|
||||
stop = {"the", "and", "that", "with", "this", "have", "has", "was", "are", "для", "что", "как", "это", "на", "в", "и", "а"}
|
||||
return {w for w in words if len(w) > 2 and w not in stop}
|
||||
|
||||
|
||||
def facts_are_similar(a: str, b: str) -> bool:
|
||||
al, bl = a.lower().strip(), b.lower().strip()
|
||||
if al == bl:
|
||||
return True
|
||||
shorter, longer = (al, bl) if len(al) <= len(bl) else (bl, al)
|
||||
if shorter in longer and len(shorter) / max(len(longer), 1) >= 0.35:
|
||||
return True
|
||||
ta, tb = _fact_tokens(a), _fact_tokens(b)
|
||||
if not ta or not tb:
|
||||
return False
|
||||
overlap = len(ta & tb) / len(ta | tb)
|
||||
return overlap >= 0.32
|
||||
|
||||
|
||||
def dedupe_facts_fuzzy(facts: list[dict]) -> list[dict]:
|
||||
out: list[dict] = []
|
||||
for f in facts:
|
||||
placed = False
|
||||
for i, existing in enumerate(out):
|
||||
if facts_are_similar(f["text"], existing["text"]):
|
||||
if len(f["text"]) > len(existing["text"]):
|
||||
out[i]["text"] = f["text"][:120]
|
||||
if f.get("rp_day") and not existing.get("rp_day"):
|
||||
out[i]["rp_day"] = f["rp_day"]
|
||||
placed = True
|
||||
break
|
||||
if not placed:
|
||||
out.append(dict(f))
|
||||
return out
|
||||
|
||||
|
||||
def merge_facts(
|
||||
existing_json: str,
|
||||
new_facts: list,
|
||||
*,
|
||||
rp_day_default: str = "",
|
||||
) -> str:
|
||||
merged = parse_facts_list(existing_json)
|
||||
seen = {f["text"].lower() for f in merged}
|
||||
default_day = (rp_day_default or "").strip()[:80]
|
||||
|
||||
for raw in new_facts or []:
|
||||
entry = parse_fact_entry(raw)
|
||||
if not entry:
|
||||
continue
|
||||
if not entry["rp_day"] and default_day:
|
||||
entry["rp_day"] = default_day
|
||||
key = entry["text"].lower()
|
||||
if key in seen:
|
||||
for i, existing in enumerate(merged):
|
||||
if existing["text"].lower() == key:
|
||||
if entry["rp_day"] and not existing.get("rp_day"):
|
||||
merged[i]["rp_day"] = entry["rp_day"]
|
||||
break
|
||||
continue
|
||||
dup = False
|
||||
for i, existing in enumerate(merged):
|
||||
if facts_are_similar(entry["text"], existing["text"]):
|
||||
if len(entry["text"]) > len(existing["text"]):
|
||||
merged[i]["text"] = entry["text"]
|
||||
if entry["rp_day"] and not existing.get("rp_day"):
|
||||
merged[i]["rp_day"] = entry["rp_day"]
|
||||
dup = True
|
||||
break
|
||||
if dup:
|
||||
continue
|
||||
seen.add(key)
|
||||
merged.append(entry)
|
||||
|
||||
return facts_list_to_json(merged)
|
||||
|
||||
|
||||
async def compress_facts(
|
||||
facts: list[dict],
|
||||
*,
|
||||
scene_context: str = "",
|
||||
status_quo: str = "",
|
||||
target: int = FACTS_COMPRESS_TARGET,
|
||||
) -> list[dict]:
|
||||
payload = json.dumps(facts, ensure_ascii=False, indent=2)
|
||||
user = (
|
||||
f"Current fact count: {len(facts)}. Target after merge: <= {target}.\n\n"
|
||||
f"Facts JSON:\n{payload}\n"
|
||||
)
|
||||
if scene_context.strip():
|
||||
user += f"\nCurrent scene:\n{scene_context.strip()[:1500]}\n"
|
||||
if status_quo.strip():
|
||||
user += f"\nStatus quo:\n{status_quo.strip()[:1500]}\n"
|
||||
|
||||
system = FACTS_COMPRESS_SYSTEM.format(target=target)
|
||||
messages = [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
try:
|
||||
raw = await (
|
||||
send_message_with_model(messages, FACTS_MODEL)
|
||||
if FACTS_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("compress_facts LLM failed: %s", e)
|
||||
return dedupe_facts_fuzzy(facts)[-target:]
|
||||
except Exception as e:
|
||||
logger.warning("compress_facts unexpected: %s", e)
|
||||
return dedupe_facts_fuzzy(facts)[-target:]
|
||||
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
if cleaned.endswith("```"):
|
||||
cleaned = cleaned.rsplit("```", 1)[0]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
if isinstance(data, list):
|
||||
out = []
|
||||
for item in data:
|
||||
entry = parse_fact_entry(item)
|
||||
if entry:
|
||||
out.append(entry)
|
||||
if out:
|
||||
logger.info("compress_facts: %d -> %d", len(facts), len(out))
|
||||
return dedupe_facts_fuzzy(out)[:FACTS_STORE_LIMIT]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("compress_facts JSON parse failed. Raw=%.400s", raw)
|
||||
return dedupe_facts_fuzzy(facts)[-target:]
|
||||
|
||||
|
||||
async def merge_facts_persist(
|
||||
existing_json: str,
|
||||
new_facts: list,
|
||||
*,
|
||||
rp_day_default: str = "",
|
||||
scene_context: str = "",
|
||||
status_quo: str = "",
|
||||
) -> str:
|
||||
"""Merge, fuzzy-dedupe, LLM-compress when list grows too large."""
|
||||
merged_json = merge_facts(
|
||||
existing_json, new_facts, rp_day_default=rp_day_default
|
||||
)
|
||||
facts = dedupe_facts_fuzzy(parse_facts_list(merged_json))
|
||||
if len(facts) > FACTS_DEDUP_THRESHOLD:
|
||||
facts = await compress_facts(
|
||||
facts,
|
||||
scene_context=scene_context,
|
||||
status_quo=status_quo,
|
||||
target=FACTS_COMPRESS_TARGET,
|
||||
)
|
||||
facts = dedupe_facts_fuzzy(facts)
|
||||
if len(facts) > FACTS_STORE_LIMIT:
|
||||
facts = await compress_facts(
|
||||
facts,
|
||||
scene_context=scene_context,
|
||||
status_quo=status_quo,
|
||||
target=FACTS_STORE_LIMIT,
|
||||
)
|
||||
return facts_list_to_json(facts)
|
||||
|
||||
|
||||
async def extract_facts(
|
||||
context_messages: list[dict],
|
||||
*,
|
||||
rp_day_hint: str = "",
|
||||
existing_json: str = "",
|
||||
) -> list[dict]:
|
||||
transcript = "\n".join(
|
||||
f"{m.get('role')}: {m.get('content','')}".strip()
|
||||
f"{m.get('role')}: {m.get('content', '')}".strip()
|
||||
for m in context_messages
|
||||
if m.get("role") in ("user", "assistant")
|
||||
)[-6000:]
|
||||
|
||||
hint = (rp_day_hint or "").strip()
|
||||
known = parse_facts_list(existing_json)
|
||||
user_parts = []
|
||||
if known:
|
||||
known_lines = "\n".join(
|
||||
f"- [{f.get('rp_day') or '?'}] {f['text']}" for f in known[-40:]
|
||||
)
|
||||
user_parts.append(f"Already known facts (do NOT repeat):\n{known_lines}\n")
|
||||
if hint:
|
||||
user_parts.append(f"RP time hint: {hint}\n")
|
||||
user_parts.append(f"New transcript:\n{transcript}")
|
||||
user = "\n".join(user_parts)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": FACTS_SYSTEM},
|
||||
{"role": "user", "content": transcript},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
|
||||
raw = await (send_message_with_model(messages, FACTS_MODEL) if FACTS_MODEL else send_message(messages))
|
||||
try:
|
||||
data = json.loads(raw.strip())
|
||||
if isinstance(data, list):
|
||||
return [str(x) for x in data][:40]
|
||||
except Exception:
|
||||
raw = await (
|
||||
send_message_with_model(messages, FACTS_MODEL)
|
||||
if FACTS_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("extract_facts LLM failed (model=%s): %s", FACTS_MODEL or "SYSTEM", e)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning("extract_facts unexpected: %s", e)
|
||||
return []
|
||||
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
if cleaned.endswith("```"):
|
||||
cleaned = cleaned.rsplit("```", 1)[0]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
if isinstance(data, list):
|
||||
out: list[dict] = []
|
||||
for item in data[:8]:
|
||||
entry = parse_fact_entry(item)
|
||||
if not entry:
|
||||
continue
|
||||
if any(facts_are_similar(entry["text"], k["text"]) for k in known):
|
||||
continue
|
||||
if not entry["rp_day"] and hint:
|
||||
entry["rp_day"] = hint[:80]
|
||||
out.append(entry)
|
||||
return out
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("extract_facts JSON parse failed. Raw=%.400s", raw)
|
||||
return []
|
||||
|
||||
|
||||
def facts_to_prompt(facts_json: str, max_items: int = 20) -> str:
|
||||
try:
|
||||
facts = json.loads(facts_json or "[]")
|
||||
if not isinstance(facts, list):
|
||||
return ""
|
||||
except json.JSONDecodeError:
|
||||
return ""
|
||||
facts = [str(x).strip() for x in facts if str(x).strip()]
|
||||
def facts_to_prompt(facts_json: str, max_items: int = FACTS_PROMPT_MAX) -> str:
|
||||
facts = dedupe_facts_fuzzy(parse_facts_list(facts_json))
|
||||
if not facts:
|
||||
return ""
|
||||
block = "\n".join(f"- {x}" for x in facts[-max_items:])
|
||||
return f"--- Facts (persistent memory) ---\n{block}\n---"
|
||||
|
||||
recent = facts[-max_items:]
|
||||
lines = []
|
||||
for f in recent:
|
||||
day = (f.get("rp_day") or "").strip()
|
||||
if day:
|
||||
lines.append(f"- [{day}] {f['text']}")
|
||||
else:
|
||||
lines.append(f"- {f['text']}")
|
||||
block = "\n".join(lines)
|
||||
total = len(facts)
|
||||
header = "--- Facts (persistent memory"
|
||||
if total > len(recent):
|
||||
header += f", showing {len(recent)} of {total}"
|
||||
header += ") ---"
|
||||
return f"{header}\n{block}\n---"
|
||||
|
||||
+63
-15
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
import random
|
||||
|
||||
from services.llm import send_message_with_model
|
||||
from services.llm import LLMError, send_message_with_model
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,8 +18,10 @@ Return ONLY valid JSON (no markdown):
|
||||
"check_reason": "brief reason why a check is needed (e.g. 'jumping over a pit')",
|
||||
"directives": ["short imperative rules for the next character reply"],
|
||||
"resolution_text": "what actually happens as result of the action — written as narrator prose (1-2 sentences). Only if needs_check=true and roll/outcome provided.",
|
||||
"status_quo_update": "optional short update about the world state"
|
||||
"status_quo_update": "optional short update about the world state",
|
||||
"scene_update": {"place": "", "time_of_day": "", "day": "", "weather": "", "exits": [], "layout_note": ""}
|
||||
}
|
||||
scene_update: only include keys that changed (partial). Omit scene_update if nothing changed.
|
||||
If needs_check=false: directives may still guide tone/pacing, resolution_text must be empty string.
|
||||
If needs_check=true and roll/outcome are provided: resolution_text MUST reflect the outcome.
|
||||
- critical failure (1): embarrassing or painful failure with extra complication
|
||||
@@ -32,17 +34,25 @@ After the character replied, update persistent state.
|
||||
Return ONLY valid JSON (no markdown):
|
||||
{
|
||||
"status_quo_update": "what changed in the world/state (1-3 sentences)",
|
||||
"facts": ["durable facts only"],
|
||||
"facts": ["durable fact strings OR {\"text\":\"...\",\"rp_day\":\"день 1\"}"],
|
||||
"choices": [{"id":"a","label":"..."}, ...],
|
||||
"affinity_delta": 0,
|
||||
"stats_delta": {"lust": 0, "stamina": 0, "tension": 0},
|
||||
"scene_update": {"place": "", "place_id": "", "time_of_day": "", "day": "", "weather": "", "exits": [], "layout_note": ""},
|
||||
"quest_updates": [{"title": "quest title", "status": "active|done|failed"}],
|
||||
"outfit_update": ["danbooru_tag", "danbooru_tag"]
|
||||
}
|
||||
Rules:
|
||||
- status_quo_update: internal DM state only (facts, location, mood). Never address the player, never use headers like "Status quo"/"Статус кво", P.S., or author commentary.
|
||||
- affinity_delta: integer -2..+2. Positive if character warmed up to player, negative if pushed away. 0 if neutral.
|
||||
- stats_delta: each lust/stamina/tension -2..+2 (0 if unchanged). lust=arousal, stamina=energy, tension=stress.
|
||||
- scene_update: partial location/time schema; only keys that changed. Do not duplicate all of status_quo into scene_update.
|
||||
- quest_updates: only include if a quest was clearly started, completed, or failed. Empty array otherwise.
|
||||
- choices: 0-4 options for what the player can do next.
|
||||
- outfit_update: ONLY include if the character's clothing visibly changed (put on, took off, changed outfit). Use exact danbooru-style underscore_tags (e.g. ["white_dress", "red_ribbon", "barefoot"]). Empty array if no change."""
|
||||
- choices: 0-4 options for what the player can do next. REQUIRED when scripted beats are exhausted — never return an empty choices array unless the session truly ended.
|
||||
- outfit_update: ONLY if clothing visibly changed. Use danbooru underscore_tags WITH COLOR when possible
|
||||
(e.g. white_tank_top, black_sports_shorts, gold_championship_belt, blue_jeans, red_ribbon).
|
||||
Every garment tag should include a color prefix unless the item is inherently colorless (barefoot, nude).
|
||||
Never bare generic tags like sports_shorts or torn_tank_top without a color. Empty array if no change."""
|
||||
|
||||
|
||||
async def narrator_pre(
|
||||
@@ -53,6 +63,7 @@ async def narrator_pre(
|
||||
user_message: str,
|
||||
roll: int | None = None,
|
||||
outcome: str | None = None,
|
||||
extra_context: str = "",
|
||||
) -> dict:
|
||||
roll_block = f"Roll d20={roll}\nOutcome={outcome}\n\n" if roll is not None else ""
|
||||
user = (
|
||||
@@ -63,10 +74,20 @@ async def narrator_pre(
|
||||
f"Facts:\n{facts_block}\n\n"
|
||||
f"Recent context:\n{context}\n"
|
||||
)
|
||||
raw = await send_message_with_model(
|
||||
[{"role": "system", "content": NARRATOR_PRE_SYSTEM}, {"role": "user", "content": user}],
|
||||
NARRATOR_MODEL,
|
||||
)
|
||||
if extra_context:
|
||||
user += f"\n--- Session state ---\n{extra_context}\n---\n"
|
||||
try:
|
||||
raw = await send_message_with_model(
|
||||
[{"role": "system", "content": NARRATOR_PRE_SYSTEM}, {"role": "user", "content": user}],
|
||||
NARRATOR_MODEL,
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("Narrator-pre LLM failed (model=%s): %s", NARRATOR_MODEL, e)
|
||||
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": "", "_ok": False}
|
||||
except Exception as e:
|
||||
logger.warning("Narrator-pre unexpected error: %s", e)
|
||||
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": "", "_ok": False}
|
||||
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
@@ -76,10 +97,11 @@ async def narrator_pre(
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
if isinstance(data, dict):
|
||||
data["_ok"] = True
|
||||
return data
|
||||
except Exception:
|
||||
logger.warning("Narrator-pre JSON parse failed. Raw=%.500s", raw)
|
||||
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": ""}
|
||||
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": "", "_ok": False}
|
||||
|
||||
|
||||
async def narrator_post(
|
||||
@@ -87,17 +109,42 @@ async def narrator_post(
|
||||
context: str,
|
||||
global_plot: str,
|
||||
facts_block: str,
|
||||
is_opening: bool = False,
|
||||
extra_context: str = "",
|
||||
) -> dict:
|
||||
opening_block = ""
|
||||
if is_opening:
|
||||
opening_block = (
|
||||
"\n\nOPENING SCENE: This is the first greeting, not a mid-conversation reply. "
|
||||
"Extract the character's INITIAL visible clothing from the greeting into outfit_update "
|
||||
"(danbooru underscore tags WITH color prefixes: white_shirt, black_shorts, gold_belt), "
|
||||
"even if clothing did not change during the scene. "
|
||||
"Set status_quo to describe the opening situation. "
|
||||
"Fill scene_update from greeting and scenario (place, time_of_day, day, layout_note). "
|
||||
"If the greeting shows clear warmth or hostility toward the player, set affinity_delta "
|
||||
"non-zero (-2..+2); use 0 only if truly neutral.\n"
|
||||
)
|
||||
user = (
|
||||
f"Persona: {persona_name}\n\n"
|
||||
f"Global plot:\n{global_plot}\n\n"
|
||||
f"Facts:\n{facts_block}\n\n"
|
||||
f"Recent context:\n{context}\n"
|
||||
f"{opening_block}"
|
||||
)
|
||||
raw = await send_message_with_model(
|
||||
[{"role": "system", "content": NARRATOR_POST_SYSTEM}, {"role": "user", "content": user}],
|
||||
NARRATOR_MODEL,
|
||||
)
|
||||
if extra_context:
|
||||
user += f"\n--- Session state ---\n{extra_context}\n---\n"
|
||||
try:
|
||||
raw = await send_message_with_model(
|
||||
[{"role": "system", "content": NARRATOR_POST_SYSTEM}, {"role": "user", "content": user}],
|
||||
NARRATOR_MODEL,
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("Narrator-post LLM failed (model=%s): %s", NARRATOR_MODEL, e)
|
||||
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": [], "_ok": False}
|
||||
except Exception as e:
|
||||
logger.warning("Narrator-post unexpected error: %s", e)
|
||||
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": [], "_ok": False}
|
||||
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
@@ -107,7 +154,8 @@ async def narrator_post(
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
if isinstance(data, dict):
|
||||
data["_ok"] = True
|
||||
return data
|
||||
except Exception:
|
||||
logger.warning("Narrator-post JSON parse failed. Raw=%.500s", raw)
|
||||
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": []}
|
||||
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": [], "_ok": False}
|
||||
|
||||
+405
-5
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from services.llm import send_message_with_model, send_message
|
||||
from services.llm import LLMError, send_message_with_model, send_message
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,7 +63,19 @@ async def generate_plot_arc(persona_name: str, persona_desc: str, persona_scenar
|
||||
{"role": "system", "content": ARC_SYSTEM},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
raw = await (send_message_with_model(messages, PLOT_MODEL) if PLOT_MODEL else send_message(messages))
|
||||
try:
|
||||
raw = await (
|
||||
send_message_with_model(messages, PLOT_MODEL)
|
||||
if PLOT_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("generate_plot_arc LLM failed (model=%s): %s", PLOT_MODEL or "SYSTEM", e)
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning("generate_plot_arc unexpected error: %s", e)
|
||||
return {}
|
||||
|
||||
cleaned = raw.strip()
|
||||
# common OpenRouter formatting: fenced json
|
||||
if cleaned.startswith("```"):
|
||||
@@ -79,17 +91,236 @@ async def generate_plot_arc(persona_name: str, persona_desc: str, persona_scenar
|
||||
return {}
|
||||
|
||||
|
||||
def should_advance_arc(user_text: str) -> str | None:
|
||||
BEAT_MATCH_SYSTEM = """You decide whether the player's latest message should fire ONE scripted plot beat.
|
||||
Return ONLY valid JSON (no markdown):
|
||||
{"fire_beat_id": "id from list or null", "confidence": "high|low"}
|
||||
|
||||
Rules:
|
||||
- Fire only if the message clearly matches that beat's narrative intent RIGHT NOW.
|
||||
- event_driven:rest — stopping to rest, sleep, camp, sauna break, recuperate (not mere sitting still in scene).
|
||||
- event_driven:travel — leaving, driving, journey, going to a new place, hitting the road.
|
||||
- event_driven:help_request — explicit plea for help/rescue/assistance.
|
||||
- event_driven:after_fail / after_success — follow-up to a recent failure/success beat.
|
||||
- Casual talk, flirting, exploring the current place without leaving does NOT fire travel.
|
||||
- If nothing fits well, return null.
|
||||
- Pick at most ONE beat; prefer high confidence only."""
|
||||
|
||||
|
||||
def dice_outcome_to_beat_trigger(outcome: str | None) -> str | None:
|
||||
"""Map d20 outcome to event_driven beat trigger (after_fail / after_success)."""
|
||||
o = (outcome or "").strip().lower()
|
||||
if o in ("failure", "critical failure"):
|
||||
return "event_driven:after_fail"
|
||||
if o in ("success", "critical success"):
|
||||
return "event_driven:after_success"
|
||||
return None
|
||||
|
||||
|
||||
def should_advance_arc_keywords(user_text: str) -> str | None:
|
||||
"""Legacy keyword fallback when LLM match is unavailable."""
|
||||
t = (user_text or "").lower()
|
||||
if any(x in t for x in ["отдыха", "ночлег", "спим", "сон", "разбить лагерь", "лагерь", "отдохн"]):
|
||||
if any(x in t for x in ["отдыха", "ночлег", "спим", "сон", "разбить лагерь", "лагерь", "отдохн", "привала", "заправк", "саун"]):
|
||||
return "event_driven:rest"
|
||||
if any(x in t for x in ["идем дальше", "пойдем дальше", "в путь", "продолжаем путь", "уходим", "возвращаемся", "переходим"]):
|
||||
if any(
|
||||
x in t
|
||||
for x in [
|
||||
"идем дальше", "пойдем дальше", "пойдём дальше", "едем дальше", "едем",
|
||||
"поехали", "выезжаем", "выезжаю", "в путь", "продолжаем путь",
|
||||
"уходим", "возвращаемся", "переходим", "за рул", "машин", "автомоб",
|
||||
"дорог", "трас", "шосс", "приех", "прибыва", "стади", "на стадион",
|
||||
"отправляемся", "выдвигаемся", "в дорогу",
|
||||
]
|
||||
):
|
||||
return "event_driven:travel"
|
||||
if any(x in t for x in ["помоги", "помочь", "нужна помощь", "спасите", "help"]):
|
||||
return "event_driven:help_request"
|
||||
return None
|
||||
|
||||
|
||||
def _parse_llm_json(raw: str) -> dict | list | None:
|
||||
cleaned = (raw or "").strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
if cleaned.endswith("```"):
|
||||
cleaned = cleaned.rsplit("```", 1)[0]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
return json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
async def classify_plot_beat(
|
||||
user_text: str,
|
||||
beats: list[dict],
|
||||
recent_context: str = "",
|
||||
last_dice_outcome: str | None = None,
|
||||
) -> str | None:
|
||||
"""LLM: return beat id to fire, or None."""
|
||||
pending = [b for b in beats if isinstance(b, dict) and b.get("id")]
|
||||
if not pending or not (user_text or "").strip():
|
||||
return None
|
||||
|
||||
lines = []
|
||||
for b in pending[:8]:
|
||||
lines.append(
|
||||
json.dumps(
|
||||
{
|
||||
"id": b.get("id"),
|
||||
"title": b.get("title", ""),
|
||||
"trigger": b.get("trigger", ""),
|
||||
"injection": (b.get("injection") or "")[:200],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
user = (
|
||||
f"Player message:\n{user_text.strip()}\n\n"
|
||||
f"Pending beats:\n" + "\n".join(lines)
|
||||
)
|
||||
if recent_context.strip():
|
||||
user += f"\n\nRecent chat:\n{recent_context.strip()[-2500:]}\n"
|
||||
if last_dice_outcome:
|
||||
user += f"\nLast dice outcome this turn: {last_dice_outcome}\n"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": BEAT_MATCH_SYSTEM},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
try:
|
||||
raw = await (
|
||||
send_message_with_model(messages, PLOT_MODEL)
|
||||
if PLOT_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("classify_plot_beat LLM failed: %s", e)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("classify_plot_beat unexpected: %s", e)
|
||||
return None
|
||||
|
||||
data = _parse_llm_json(raw)
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
bid = data.get("fire_beat_id")
|
||||
if bid in (None, "", "null", "none"):
|
||||
return None
|
||||
bid = str(bid).strip()
|
||||
if data.get("confidence") == "low":
|
||||
return None
|
||||
valid_ids = {str(b.get("id")) for b in pending}
|
||||
if bid in valid_ids:
|
||||
logger.info("classify_plot_beat: fired %s", bid)
|
||||
return bid
|
||||
return None
|
||||
|
||||
|
||||
def pop_beat_by_id(arc: dict, beat_id: str) -> tuple[dict, list[dict]]:
|
||||
beats = arc.get("beats") or []
|
||||
matched, remaining = [], []
|
||||
for b in beats:
|
||||
if isinstance(b, dict) and str(b.get("id")) == str(beat_id) and not matched:
|
||||
matched.append(b)
|
||||
else:
|
||||
remaining.append(b)
|
||||
arc["beats"] = remaining
|
||||
return arc, matched
|
||||
|
||||
|
||||
def beat_title(beat: dict) -> str:
|
||||
return ((beat.get("title") or beat.get("injection") or "")[:120]).strip()
|
||||
|
||||
|
||||
def count_active_quests(quests: list | None) -> int:
|
||||
return sum(1 for q in (quests or []) if q.get("status") == "active")
|
||||
|
||||
|
||||
def prune_beats_for_done_quests(arc: dict, quests: list | None) -> tuple[dict, list[dict]]:
|
||||
"""Drop beats whose title already matches a done/failed quest (manual quest close desync)."""
|
||||
done_titles = {
|
||||
(q.get("title") or "").strip().lower()
|
||||
for q in (quests or [])
|
||||
if q.get("status") in ("done", "failed")
|
||||
}
|
||||
if not done_titles:
|
||||
return arc, []
|
||||
removed, kept = [], []
|
||||
for b in arc.get("beats") or []:
|
||||
if isinstance(b, dict) and beat_title(b).lower() in done_titles:
|
||||
removed.append(b)
|
||||
else:
|
||||
kept.append(b)
|
||||
arc["beats"] = kept
|
||||
return arc, removed
|
||||
|
||||
|
||||
def pop_next_beats(arc: dict, max_beats: int = 1) -> tuple[dict, list[dict]]:
|
||||
beats = arc.get("beats") or []
|
||||
if not isinstance(beats, list) or not beats:
|
||||
return arc, []
|
||||
n = min(max_beats, len(beats))
|
||||
matched = [b for b in beats[:n] if isinstance(b, dict)]
|
||||
arc["beats"] = beats[n:]
|
||||
return arc, matched
|
||||
|
||||
|
||||
async def process_arc_beats(
|
||||
arc: dict,
|
||||
quests: list | None,
|
||||
user_text: str,
|
||||
*,
|
||||
recent_context: str = "",
|
||||
last_dice_outcome: str | None = None,
|
||||
allow_stuck_recovery: bool = True,
|
||||
) -> tuple[dict, list[dict], list[dict], str]:
|
||||
"""
|
||||
Prune completed beats, then fire by dice outcome, LLM match, keywords, or stuck recovery.
|
||||
Returns (arc, fired_beats, pruned_beats, mode).
|
||||
mode: '' | 'after_dice' | 'llm' | 'trigger' | 'stuck_recovery' | 'pruned'
|
||||
"""
|
||||
if not arc:
|
||||
return arc, [], [], ""
|
||||
|
||||
arc, pruned = prune_beats_for_done_quests(arc, quests)
|
||||
beats_pending = arc.get("beats") or []
|
||||
|
||||
dice_trig = dice_outcome_to_beat_trigger(last_dice_outcome)
|
||||
if dice_trig and beats_pending:
|
||||
arc, fired = pop_matching_beats(arc, dice_trig, max_beats=1)
|
||||
if fired:
|
||||
logger.info(
|
||||
"process_arc_beats: after_dice %s -> %s",
|
||||
last_dice_outcome,
|
||||
fired[0].get("id"),
|
||||
)
|
||||
return arc, fired, pruned, "after_dice"
|
||||
|
||||
if beats_pending:
|
||||
beat_id = await classify_plot_beat(
|
||||
user_text, beats_pending, recent_context, last_dice_outcome
|
||||
)
|
||||
if beat_id:
|
||||
arc, fired = pop_beat_by_id(arc, beat_id)
|
||||
if fired:
|
||||
return arc, fired, pruned, "llm"
|
||||
|
||||
trig = should_advance_arc_keywords(user_text)
|
||||
if trig:
|
||||
arc, fired = pop_matching_beats(arc, trig, max_beats=1)
|
||||
if fired:
|
||||
return arc, fired, pruned, "trigger"
|
||||
|
||||
if allow_stuck_recovery and arc.get("beats") and count_active_quests(quests) == 0:
|
||||
arc, fired = pop_next_beats(arc, 1)
|
||||
if fired:
|
||||
return arc, fired, pruned, "stuck_recovery"
|
||||
|
||||
if pruned:
|
||||
return arc, [], pruned, "pruned"
|
||||
return arc, [], [], ""
|
||||
|
||||
|
||||
PHASE_ORDER = ["opening", "hook", "complication", "reveal", "climax", "aftermath"]
|
||||
|
||||
|
||||
@@ -108,6 +339,129 @@ def advance_phase(arc: dict) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
BEATS_APPEND_SYSTEM = """You are a narrative designer for an RPG chat.
|
||||
The plot arc has NO remaining scripted beats. Generate 2-3 NEW beats to continue play.
|
||||
Return ONLY valid JSON (no markdown):
|
||||
{
|
||||
"beats": [
|
||||
{"id":"b_new_1","title":"short quest title","trigger":"event_driven:rest|event_driven:travel|event_driven:help_request|event_driven:after_fail|event_driven:after_success",
|
||||
"injection":"1-3 sentences in-world",
|
||||
"choices":[{"id":"a","label":"..."},{"id":"b","label":"..."}]}
|
||||
],
|
||||
"next_beat_hint": "what to push next",
|
||||
"phase": "hook|complication|reveal|climax|aftermath"
|
||||
}
|
||||
Match the current scene and completed quests. Do not restart finished storylines."""
|
||||
|
||||
|
||||
async def replenish_arc_beats(
|
||||
arc: dict,
|
||||
persona_name: str,
|
||||
recent_context: str,
|
||||
quests: list,
|
||||
genre: str = "adventure",
|
||||
) -> dict:
|
||||
"""Append new beats when arc.beats is empty so plot/quest engine can continue."""
|
||||
if arc.get("beats"):
|
||||
return arc
|
||||
|
||||
quest_lines = "\n".join(
|
||||
f" [{q.get('status')}] {q.get('title')}" for q in (quests or [])
|
||||
) or " (none)"
|
||||
user = (
|
||||
f"Character: {persona_name}\n"
|
||||
f"Genre: {format_genres(genre)}\n"
|
||||
f"Current arc title: {arc.get('title', '')}\n"
|
||||
f"Phase: {arc.get('phase', 'aftermath')}\n"
|
||||
f"Boundaries: {json.dumps(arc.get('boundaries', []), ensure_ascii=False)}\n"
|
||||
f"Quests:\n{quest_lines}\n\n"
|
||||
f"Recent chat:\n{recent_context[-4000:]}\n"
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": BEATS_APPEND_SYSTEM},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
try:
|
||||
raw = await (
|
||||
send_message_with_model(messages, PLOT_MODEL)
|
||||
if PLOT_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("replenish_arc_beats failed: %s", e)
|
||||
return arc
|
||||
except Exception as e:
|
||||
logger.warning("replenish_arc_beats unexpected: %s", e)
|
||||
return arc
|
||||
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
if cleaned.endswith("```"):
|
||||
cleaned = cleaned.rsplit("```", 1)[0]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
except Exception:
|
||||
logger.warning("replenish_arc_beats JSON parse failed. Raw=%.400s", raw)
|
||||
return arc
|
||||
|
||||
new_beats = data.get("beats") if isinstance(data, dict) else []
|
||||
if isinstance(new_beats, list) and new_beats:
|
||||
arc["beats"] = new_beats
|
||||
logger.info("replenish_arc_beats: added %d beats", len(new_beats))
|
||||
if isinstance(data, dict) and data.get("next_beat_hint"):
|
||||
arc["next_beat_hint"] = data["next_beat_hint"]
|
||||
if isinstance(data, dict) and data.get("phase"):
|
||||
arc["phase"] = data["phase"]
|
||||
return arc
|
||||
|
||||
|
||||
async def reconcile_plot_arc(
|
||||
session_id: str,
|
||||
*,
|
||||
replenish_if_empty: bool = True,
|
||||
recent_context: str = "",
|
||||
persona_name: str = "Character",
|
||||
genre: str = "adventure",
|
||||
) -> tuple[dict, bool]:
|
||||
"""
|
||||
Prune beats that match done quests; replenish if empty. Persists arc when changed.
|
||||
Returns (arc, changed).
|
||||
"""
|
||||
from services.memory import get_session, get_quests, update_session_plot_arc, seed_quests_from_arc
|
||||
|
||||
session = await get_session(session_id)
|
||||
if not session or not session.get("rpg_enabled"):
|
||||
return {}, False
|
||||
try:
|
||||
arc = json.loads(session.get("plot_arc_json") or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
arc = {}
|
||||
if not isinstance(arc, dict):
|
||||
arc = {}
|
||||
|
||||
quests = await get_quests(session_id)
|
||||
arc, pruned = prune_beats_for_done_quests(arc, quests)
|
||||
changed = bool(pruned)
|
||||
|
||||
if replenish_if_empty and not arc.get("beats"):
|
||||
arc = await replenish_arc_beats(
|
||||
arc,
|
||||
persona_name,
|
||||
recent_context,
|
||||
quests,
|
||||
genre=session.get("genre") or genre,
|
||||
)
|
||||
if arc.get("beats"):
|
||||
changed = True
|
||||
await seed_quests_from_arc(session_id, arc)
|
||||
|
||||
if changed:
|
||||
await update_session_plot_arc(session_id, json.dumps(arc, ensure_ascii=False))
|
||||
return arc, changed
|
||||
|
||||
|
||||
def pop_matching_beats(arc: dict, trigger: str, max_beats: int = 1) -> tuple[dict, list[dict]]:
|
||||
beats = arc.get("beats", [])
|
||||
if not isinstance(beats, list):
|
||||
@@ -121,3 +475,49 @@ def pop_matching_beats(arc: dict, trigger: str, max_beats: int = 1) -> tuple[dic
|
||||
arc["beats"] = remaining
|
||||
return arc, matched
|
||||
|
||||
|
||||
def normalize_choice(
|
||||
raw: dict,
|
||||
*,
|
||||
source: str = "narrator",
|
||||
beat: dict | None = None,
|
||||
) -> dict | None:
|
||||
"""Normalize a choice dict for storage/UI. Adds source and optional beat metadata."""
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
label = (raw.get("label") or "").strip()
|
||||
if not label:
|
||||
return None
|
||||
cid = (raw.get("id") or label[:1].lower() or "a").strip()
|
||||
out = {"id": cid, "label": label, "source": source}
|
||||
if beat and source == "plot_beat":
|
||||
if beat.get("id"):
|
||||
out["beat_id"] = beat["id"]
|
||||
title = (beat.get("title") or "").strip()
|
||||
if title:
|
||||
out["beat_title"] = title
|
||||
injection = (beat.get("injection") or "").strip()
|
||||
if injection:
|
||||
out["beat_injection"] = injection
|
||||
return out
|
||||
|
||||
|
||||
def choices_from_beat(beat: dict) -> list[dict]:
|
||||
if not isinstance(beat, dict):
|
||||
return []
|
||||
return [
|
||||
c for c in (
|
||||
normalize_choice(item, source="plot_beat", beat=beat)
|
||||
for item in (beat.get("choices") or [])
|
||||
)
|
||||
if c
|
||||
]
|
||||
|
||||
|
||||
def choices_from_narrator(raw_choices: list) -> list[dict]:
|
||||
if not isinstance(raw_choices, list):
|
||||
return []
|
||||
return [
|
||||
c for c in (normalize_choice(item, source="narrator") for item in raw_choices) if c
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,321 @@
|
||||
import json
|
||||
|
||||
DEFAULT_NARRATIVE_STATS = {"lust": 0, "stamina": 10, "tension": 0}
|
||||
STAT_KEYS = ("lust", "stamina", "tension")
|
||||
|
||||
|
||||
def parse_stats_json(raw: str | None) -> dict:
|
||||
try:
|
||||
data = json.loads(raw or "{}") if isinstance(raw, str) else (raw or {})
|
||||
except Exception:
|
||||
data = {}
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
out = dict(DEFAULT_NARRATIVE_STATS)
|
||||
for k in STAT_KEYS:
|
||||
try:
|
||||
out[k] = max(0, min(10, int(data.get(k, out[k]))))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return out
|
||||
|
||||
|
||||
def parse_scene_json(raw: str | None) -> dict:
|
||||
try:
|
||||
data = json.loads(raw or "{}") if isinstance(raw, str) else (raw or {})
|
||||
except Exception:
|
||||
data = {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def merge_scene(existing: dict, update: dict | None) -> dict:
|
||||
if not update or not isinstance(update, dict):
|
||||
return dict(existing)
|
||||
merged = dict(existing)
|
||||
for k, v in update.items():
|
||||
if v is None:
|
||||
continue
|
||||
if k == "exits" and isinstance(v, list):
|
||||
merged["exits"] = v
|
||||
elif isinstance(v, str) and v.strip():
|
||||
merged[k] = v.strip()
|
||||
elif v != "":
|
||||
merged[k] = v
|
||||
return merged
|
||||
|
||||
|
||||
def apply_stats_delta(stats: dict, delta: dict | None) -> dict:
|
||||
out = parse_stats_json(json.dumps(stats, ensure_ascii=False))
|
||||
if not delta or not isinstance(delta, dict):
|
||||
return out
|
||||
for k in STAT_KEYS:
|
||||
try:
|
||||
d = int(delta.get(k, 0))
|
||||
except (TypeError, ValueError):
|
||||
d = 0
|
||||
d = max(-2, min(2, d))
|
||||
if d:
|
||||
out[k] = max(0, min(10, out[k] + d))
|
||||
return out
|
||||
|
||||
|
||||
def scene_prompt_block(scene: dict) -> str:
|
||||
if not scene:
|
||||
return ""
|
||||
lines = []
|
||||
for key, label in (
|
||||
("place", "Location"),
|
||||
("place_id", "Location ID"),
|
||||
("time_of_day", "Time of day"),
|
||||
("day", "Day"),
|
||||
("weather", "Weather"),
|
||||
("layout_note", "Layout"),
|
||||
):
|
||||
v = (scene.get(key) or "").strip()
|
||||
if v:
|
||||
lines.append(f"{label}: {v}")
|
||||
exits = scene.get("exits")
|
||||
if isinstance(exits, list) and exits:
|
||||
lines.append("Exits: " + ", ".join(str(e) for e in exits))
|
||||
if not lines:
|
||||
return ""
|
||||
return (
|
||||
"\n\n--- Scene ---\n"
|
||||
+ "\n".join(lines)
|
||||
+ "\nUse this setting consistently. Do not recite scene labels verbatim.\n---"
|
||||
)
|
||||
|
||||
|
||||
# Scoped to current player until multi-user / multi-card relationships exist.
|
||||
RELATIONSHIP_TARGET_LABEL = "current player (the human you are roleplaying with)"
|
||||
|
||||
|
||||
def _mandatory_state_preamble() -> str:
|
||||
return (
|
||||
"MANDATORY ROLEPLAY DIRECTIVE (binding — overrides default cheerfulness or energy):\n"
|
||||
f"- Target: {RELATIONSHIP_TARGET_LABEL} only.\n"
|
||||
"- Your reply MUST visibly obey this state in tone, body language, pacing, and what the character attempts.\n"
|
||||
"- Do NOT contradict a low stamina / high tension / low affinity reading with peppy or intimate behavior unless the text above explicitly allows recovery.\n"
|
||||
"- Never name affinity, lust, stamina, tension, stats, meters, or numeric values in dialogue.\n"
|
||||
)
|
||||
|
||||
|
||||
def _band_instruction(value: int, bands: list[tuple[int, str]]) -> str:
|
||||
"""bands: list of (min_inclusive, instruction) sorted by min descending."""
|
||||
v = int(value)
|
||||
for min_val, text in bands:
|
||||
if v >= min_val:
|
||||
return text
|
||||
return bands[-1][1] if bands else ""
|
||||
|
||||
|
||||
def affinity_prompt_block(affinity: int) -> str:
|
||||
aff = max(-30, min(30, int(affinity)))
|
||||
attitude = _band_instruction(
|
||||
aff,
|
||||
[
|
||||
(10, "Devoted trust: openly affectionate, seeks closeness, defends the player, vulnerable honesty."),
|
||||
(5, "Warm bond: friendly, teasing allowed, volunteers help, remembers small kindnesses."),
|
||||
(1, "Slight fondness: polite-positive, rare soft moments, not yet intimate."),
|
||||
(0, "Neutral professional distance: neither warm nor cold unless scene demands."),
|
||||
(-1, "Cool and guarded: short answers, deflects personal topics, skeptical of motives."),
|
||||
(-5, "Hostile or deeply distrustful: sarcasm, refusal, may threaten to leave or expose the player."),
|
||||
(-30, "Open enmity: antagonistic, undermines, may sabotage or attack socially/physically if fitting genre."),
|
||||
],
|
||||
)
|
||||
return (
|
||||
"\n\n--- Relationship toward player (MANDATORY) ---\n"
|
||||
+ _mandatory_state_preamble()
|
||||
+ f"Affinity (internal, not spoken): level {aff}.\n"
|
||||
f"Required attitude toward {RELATIONSHIP_TARGET_LABEL}: {attitude}\n"
|
||||
"---"
|
||||
)
|
||||
|
||||
|
||||
def stats_prompt_block(stats: dict) -> str:
|
||||
s = parse_stats_json(json.dumps(stats, ensure_ascii=False))
|
||||
lust = s["lust"]
|
||||
stamina = s["stamina"]
|
||||
tension = s["tension"]
|
||||
|
||||
lust_line = _band_instruction(
|
||||
lust,
|
||||
[
|
||||
(9, "Overwhelming arousal colors every beat: breathy, distracted, struggles to stay on task."),
|
||||
(7, "Strong desire: flushed skin, lingering touch, voice unsteady, thoughts drift to intimacy."),
|
||||
(5, "Clear attraction: meaningful glances, leaning in, playful double meanings if genre fits."),
|
||||
(3, "Mild warmth: subtle flirt only when appropriate; easily redirected."),
|
||||
(1, "Little romantic charge: platonic focus unless the player escalates."),
|
||||
(0, "No romantic or sexual undertone in body language or subtext."),
|
||||
],
|
||||
)
|
||||
stamina_line = _band_instruction(
|
||||
stamina,
|
||||
[
|
||||
(9, "Peak energy: brisk movement, sharp focus, may offer to take strenuous actions."),
|
||||
(7, "Well-rested: alert, steady pace, normal exertion."),
|
||||
(5, "Average fatigue: fine for talk/light action; heavy labor needs justification."),
|
||||
(4, "Tired: slower reactions, sits when possible, voice softer."),
|
||||
(3, "Heavy fatigue: frequent pauses, avoids running/fighting, may ask to stop."),
|
||||
(2, "Exhausted: barely moves, slumped posture, short sentences, needs support to walk."),
|
||||
(1, "On the verge of collapse: eyelids heavy, may stumble or nearly pass out; minimal action only."),
|
||||
(0, "Cannot sustain activity: collapse/immediate sleep/rest is imminent unless helped."),
|
||||
],
|
||||
)
|
||||
tension_line = _band_instruction(
|
||||
tension,
|
||||
[
|
||||
(9, "Breaking point: trembling, tears or rage close to surface, irrational snap decisions."),
|
||||
(7, "High stress: clipped speech, hyper-vigilant, jumps at sounds, defensive."),
|
||||
(5, "Uneasy: fidgeting, forced smiles, changes subject from danger."),
|
||||
(3, "Mild edge: occasional sigh, watches exits, relaxes with reassurance."),
|
||||
(1, "Mostly calm: normal breathing, open posture."),
|
||||
(0, "Fully at ease in body and voice."),
|
||||
],
|
||||
)
|
||||
|
||||
return (
|
||||
"\n\n--- Physical & emotional state (MANDATORY) ---\n"
|
||||
+ _mandatory_state_preamble()
|
||||
+ f"Internal scales (0–10, never spoken): lust/arousal={lust}, stamina/energy={stamina}, tension/stress={tension}.\n"
|
||||
f"- Lust/arousal: {lust_line}\n"
|
||||
f"- Stamina/energy: {stamina_line}\n"
|
||||
f"- Tension/stress: {tension_line}\n"
|
||||
"If multiple apply, combine them (e.g. low stamina + high tension = shaky exhaustion, not peppy panic).\n"
|
||||
"---"
|
||||
)
|
||||
|
||||
|
||||
def format_narrator_outcome_for_llm(data: dict) -> str:
|
||||
"""Turn stored narrator JSON into a binding user-turn for the character model."""
|
||||
roll = data.get("roll")
|
||||
outcome = (data.get("outcome") or "").strip().lower()
|
||||
text = (data.get("text") or "").strip()
|
||||
lines = [
|
||||
"--- Narrator ruling (MANDATORY — your next in-character reply MUST follow this) ---",
|
||||
f"Roll d20={roll}. Outcome: {outcome}.",
|
||||
f"What ACTUALLY happened (canonical truth): {text}",
|
||||
]
|
||||
if outcome in ("failure", "critical failure"):
|
||||
lines.append(
|
||||
"The player's action FAILED as they imagined it. "
|
||||
"Do NOT write a success version: no crowd fleeing, no intimidation working, "
|
||||
"no effortless victory. Show the failure, embarrassment, or partial result above."
|
||||
)
|
||||
elif outcome == "critical success":
|
||||
lines.append(
|
||||
"The attempt succeeded dramatically. You may show amplified success aligned with the outcome above."
|
||||
)
|
||||
else:
|
||||
lines.append(
|
||||
"The attempt succeeded. Your reply must align with the narrator outcome above, not contradict it."
|
||||
)
|
||||
lines.append("Respond as the character to THIS outcome only. Never cite dice, rolls, or stats.")
|
||||
lines.append("---")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_user_message_for_llm(content: str, *, has_dice_resolution: bool) -> str:
|
||||
if not has_dice_resolution:
|
||||
return content
|
||||
return (
|
||||
"[Player stated intent — canonical outcome is in the narrator ruling immediately below]\n"
|
||||
+ content
|
||||
)
|
||||
|
||||
|
||||
def scene_to_sd_hint(scene: dict) -> str:
|
||||
if not scene:
|
||||
return ""
|
||||
parts = []
|
||||
for k in ("place", "time_of_day", "day", "weather", "layout_note"):
|
||||
v = (scene.get(k) or "").strip()
|
||||
if v:
|
||||
parts.append(f"{k}: {v}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
async def apply_narrator_post(session_id: str, post: dict, rpg_settings: dict, session: dict | None = None) -> dict:
|
||||
"""Persist narrator_post fields into session. Returns summary of what changed."""
|
||||
from services.memory import (
|
||||
get_session,
|
||||
update_session_status_quo,
|
||||
update_session_affinity,
|
||||
update_session_outfit,
|
||||
update_session_scene,
|
||||
update_session_narrative_stats,
|
||||
update_session_facts,
|
||||
upsert_quest,
|
||||
)
|
||||
from services.rpg_facts import merge_facts_persist, rp_day_from_scene, parse_facts_list
|
||||
|
||||
if session is None:
|
||||
session = await get_session(session_id) or {}
|
||||
|
||||
applied = {
|
||||
"status_quo": False,
|
||||
"facts_added": 0,
|
||||
"affinity_delta": 0,
|
||||
"quests_updated": 0,
|
||||
"scene": False,
|
||||
"outfit": False,
|
||||
}
|
||||
|
||||
sq = (post.get("status_quo_update") or "").strip()
|
||||
if sq:
|
||||
await update_session_status_quo(session_id, sq)
|
||||
applied["status_quo"] = True
|
||||
|
||||
post_facts = post.get("facts") or []
|
||||
if isinstance(post_facts, list) and post_facts:
|
||||
rp_day = rp_day_from_scene(session.get("scene_json"))
|
||||
before = len(parse_facts_list(session.get("facts_json") or "[]"))
|
||||
merged = await merge_facts_persist(
|
||||
session.get("facts_json", "[]"),
|
||||
post_facts,
|
||||
rp_day_default=rp_day,
|
||||
scene_context=json.dumps(
|
||||
parse_scene_json(session.get("scene_json")), ensure_ascii=False
|
||||
),
|
||||
status_quo=session.get("status_quo") or "",
|
||||
)
|
||||
await update_session_facts(session_id, merged)
|
||||
after = len(parse_facts_list(merged))
|
||||
applied["facts_added"] = max(0, after - before)
|
||||
|
||||
if rpg_settings.get("affinity", True):
|
||||
delta = int(post.get("affinity_delta") or 0)
|
||||
if delta:
|
||||
await update_session_affinity(session_id, delta)
|
||||
applied["affinity_delta"] = delta
|
||||
|
||||
outfit_update = post.get("outfit_update")
|
||||
if isinstance(outfit_update, list) and outfit_update:
|
||||
from services.outfit_tags import outfit_list_to_json
|
||||
|
||||
await update_session_outfit(session_id, outfit_list_to_json(outfit_update))
|
||||
applied["outfit"] = True
|
||||
|
||||
scene_update = post.get("scene_update")
|
||||
if isinstance(scene_update, dict) and scene_update:
|
||||
merged = merge_scene(parse_scene_json(session.get("scene_json")), scene_update)
|
||||
await update_session_scene(session_id, json.dumps(merged, ensure_ascii=False))
|
||||
applied["scene"] = True
|
||||
|
||||
if rpg_settings.get("stats", False):
|
||||
stats_delta = post.get("stats_delta")
|
||||
if isinstance(stats_delta, dict) and stats_delta:
|
||||
current = parse_stats_json(session.get("narrative_stats_json"))
|
||||
updated = apply_stats_delta(current, stats_delta)
|
||||
await update_session_narrative_stats(
|
||||
session_id, json.dumps(updated, ensure_ascii=False)
|
||||
)
|
||||
|
||||
if rpg_settings.get("quests", True):
|
||||
for qu in post.get("quest_updates") or []:
|
||||
title = (qu.get("title") or "").strip()
|
||||
if title:
|
||||
await upsert_quest(session_id, title[:120], qu.get("status", "active"))
|
||||
applied["quests_updated"] += 1
|
||||
|
||||
return applied
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Run ComfyUI generation from SdPromptBundle (single hybrid prompt for Anima)."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from services import sdbackend as sd_service
|
||||
from services.memory import update_message_image, update_message_prompt, update_message_prompt_alt
|
||||
from services.sd_prompt import SdPromptBundle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
async def run_sd_for_message(bundle: SdPromptBundle | None, msg_id: int | None) -> dict:
|
||||
"""Generate image, persist prompts/paths on message. Returns fields for API/SSE."""
|
||||
out = {
|
||||
"image_prompt": None,
|
||||
"image_prompt_alt": None,
|
||||
"image_path": None,
|
||||
"image_path_alt": None,
|
||||
"image_error": None,
|
||||
"image_error_alt": None,
|
||||
}
|
||||
if not bundle or not bundle.tag_full:
|
||||
return out
|
||||
|
||||
out["image_prompt"] = bundle.tag_full
|
||||
if bundle.desc_full and bundle.desc_full != bundle.tag_full:
|
||||
out["image_prompt_alt"] = bundle.desc_full
|
||||
|
||||
if msg_id:
|
||||
await update_message_prompt(msg_id, bundle.tag_full)
|
||||
if out["image_prompt_alt"]:
|
||||
await update_message_prompt_alt(msg_id, out["image_prompt_alt"])
|
||||
|
||||
if not SD_AUTO_GENERATE:
|
||||
return out
|
||||
|
||||
rel, err = await sd_service.generate_from_full_prompt(bundle.tag_full)
|
||||
if rel:
|
||||
out["image_path"] = f"/static/{rel}"
|
||||
if msg_id:
|
||||
await update_message_image(msg_id, rel)
|
||||
else:
|
||||
out["image_error"] = err
|
||||
|
||||
return out
|
||||
+650
-67
@@ -2,26 +2,115 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from services.llm import send_message, send_message_with_model
|
||||
from services.personas import get_persona
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NEGATIVE_PROMPT_SEPARATOR = "\n\n__NEGATIVE_PROMPT__\n\n"
|
||||
|
||||
PROMPT_BUILDER_SYSTEM = """You are a Stable Diffusion prompt engineer for anime illustration models.
|
||||
Given a roleplay chat excerpt, output ONLY valid JSON (no markdown):
|
||||
{
|
||||
"should_generate": true,
|
||||
"shot_type": "first_person_pov" | "landscape" | "third_person",
|
||||
"action_tags": "booru-style tags for pose/action/expression, e.g. 'sitting, smiling, holding_cup'",
|
||||
"environment_tags": "booru-style tags for location/lighting/time, e.g. 'indoors, kitchen, sunlight, daytime'"
|
||||
"action_tags": "booru-style tags for pose/action/expression",
|
||||
"environment_tags": "booru-style tags for location/lighting/time"
|
||||
}
|
||||
Rules:
|
||||
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined: 'fox_ears' not 'fox ears'.
|
||||
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined.
|
||||
- Do NOT include appearance/character tags — those are provided separately.
|
||||
- Do NOT include quality tags, model names, style words, 'pov', or category/metadata words.
|
||||
- Do NOT invent tags. If unsure — omit.
|
||||
- Keep each field to 3-6 tags."""
|
||||
- Keep action_tags and environment_tags to 3-6 tags each.
|
||||
- shot_type: default "first_person_pov" for dialogue/intimacy at arm's length. "third_person" only for wide action (fight, chase). "landscape" only when environment is the focus.
|
||||
- should_generate: false for non-visual beats (pure internal monologue, time skips with no new pose, empty lines).
|
||||
- NEVER use negative words in tag fields (not, without, naked, nsfw, etc.)."""
|
||||
|
||||
ANIMA_BUILDER_EXTRA = """
|
||||
Anima hybrid mode — ALSO include:
|
||||
"pov_cue": "face_to_face" | "walking_together" | "doorway_invite" | "reach_to_viewer" | "dialogue_close",
|
||||
"viewer_body_visible": false,
|
||||
"scene_description": "ONE short English sentence (max 40 words). Camera POV: what the viewer sees. Mood/atmosphere only — do NOT repeat tags from action_tags/environment_tags. Do NOT list comma-separated booru tags."
|
||||
POV / interaction rules:
|
||||
- Default viewer_body_visible: false. The viewer's body, hands, or face must NOT appear in the image — only the character toward the camera.
|
||||
- For hugs, embraces: use arms_out, reaching_towards_viewer, inviting_hug — NOT holding_hands, lifting, carrying, nose_rub (these draw a second body in POV).
|
||||
- For long messages with time skips ("About an hour later..."), illustrate ONLY the final visible beat (usually the last paragraph).
|
||||
- scene_description: describe HER toward the camera only — NEVER "someone", "both", "with you", "hand in hand with", or another person's body.
|
||||
- NEVER use tags: looking_at_each_other, couple, 2girls, 2boys, multiple_girls. For POV walking together omit holding_hands (use walking, smiling, reaching_towards_viewer instead).
|
||||
- pov_cue: pick the framing that matches the CURRENT beat (walking_together for strolling side by side, doorway_invite for doorway with arms open, reach_to_viewer when she reaches toward camera, face_to_face for close dialogue).
|
||||
- Illustrate ONLY the beat under === ILLUSTRATE ===; use === Context === for outfit/location hints only.
|
||||
- Do NOT put English sentences in action_tags or environment_tags — tags only."""
|
||||
|
||||
POV_CUE_PHRASES: dict[str, str] = {
|
||||
"face_to_face": "POV: close face-to-face, she looks directly at you",
|
||||
"walking_together": "POV: walking beside you, profile and shared path visible",
|
||||
"doorway_invite": "POV: she blocks the doorway, arms open toward you",
|
||||
"reach_to_viewer": "POV: she reaches toward the camera",
|
||||
"dialogue_close": "POV: close conversation, she faces you at arm's length",
|
||||
}
|
||||
|
||||
POV_CUE_DEFAULT = "POV: she stands before you, facing the camera"
|
||||
|
||||
POV_INTERACTION_NEGATIVE = (
|
||||
"duplicate, clone, multiple_girls, 2girls, extra_person, pov hands, "
|
||||
"disembodied hands, extra arms, second person"
|
||||
)
|
||||
|
||||
_CONTACT_ACTION_KEYWORDS = (
|
||||
"hug", "holding_hands", "hand_holding", "arms_out", "embrace",
|
||||
"reaching", "inviting_hug", "arm_around", "cuddling",
|
||||
)
|
||||
|
||||
_JUNK_STANDALONE_TAGS = frozenset({
|
||||
"white", "black", "skin", "ear", "ears", "girl", "boy", "fox", "wolf", "cat",
|
||||
"short", "tall", "golden", "silver", "red", "blue", "green", "purple",
|
||||
"pink", "brown", "eye", "eyes", "hair",
|
||||
})
|
||||
|
||||
_INVALID_TAGS = frozenset({
|
||||
"pumped_up", "pumped", "looking_at_each_other", "couple",
|
||||
"2girls", "2boys", "multiple_girls", "multiple_boys", "duo",
|
||||
})
|
||||
|
||||
_POV_DROP_ACTION_TAGS = frozenset({
|
||||
"holding_hands", "hand_holding", "looking_at_each_other", "couple",
|
||||
"lifting", "carry", "carrying", "princess_carry", "nose_rub", "nose_boop",
|
||||
})
|
||||
|
||||
_TIME_SKIP_RE = re.compile(
|
||||
r"(?i)\b(?:about an hour later|hours later|later that (?:day|evening|night)|"
|
||||
r"the next (?:day|morning|evening)|meanwhile|after (?:some )?time)\b[.…\s]*",
|
||||
)
|
||||
|
||||
_POV_MOOD_FALLBACK: dict[str, str] = {
|
||||
"walking_together": "Easy warmth and quiet laughter in the afternoon light.",
|
||||
"doorway_invite": "Cool air and playful tension as she waits in the doorway.",
|
||||
"reach_to_viewer": "A charged moment as she reaches toward the camera.",
|
||||
"face_to_face": "Her expression softens in close focus toward the camera.",
|
||||
"dialogue_close": "Intimate calm in the space between you.",
|
||||
}
|
||||
|
||||
_INDOOR_ENV_MARKERS = frozenset({"doorway", "indoors", "indoor", "apartment", "inside", "room"})
|
||||
_OUTDOOR_ENV_MARKERS = frozenset({"outdoor", "outdoors", "outside", "street"})
|
||||
|
||||
_POV_PROSE_BANNED = re.compile(
|
||||
r"\b(someone|both|together with|hand in hand with|another person|second person|"
|
||||
r"your hands|your fingers|your embrace|your heat|intertwined|with you|"
|
||||
r"demands your|before you)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
SD_ANIMA_DUAL_COMPARE = os.getenv("SD_ANIMA_DUAL_COMPARE", "false").lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SdPromptBundle:
|
||||
tag_full: str
|
||||
negative: str
|
||||
desc_full: str | None = None
|
||||
|
||||
|
||||
def extract_image_prompt_tag(text: str) -> str | None:
|
||||
@@ -44,7 +133,7 @@ SD_UNET = os.getenv("SD_UNET", "")
|
||||
SD_PROMPT_MODEL = os.getenv("SD_PROMPT_MODEL", "").strip()
|
||||
|
||||
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}
|
||||
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
|
||||
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
|
||||
ANIMA_NEGATIVE = "worst quality, low quality, score_1, score_2, score_3, blurry, jpeg artifacts, sepia"
|
||||
|
||||
|
||||
@@ -56,37 +145,201 @@ def _is_anima() -> bool:
|
||||
return bool(SD_UNET) and not SD_CHECKPOINT
|
||||
|
||||
|
||||
def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
def anima_dual_enabled() -> bool:
|
||||
return _is_anima() and SD_ANIMA_DUAL_COMPARE
|
||||
|
||||
|
||||
def _builder_system() -> str:
|
||||
if _is_anima():
|
||||
return PROMPT_BUILDER_SYSTEM + ANIMA_BUILDER_EXTRA
|
||||
return PROMPT_BUILDER_SYSTEM
|
||||
|
||||
|
||||
def _normalize_shot_type(scene: dict) -> dict:
|
||||
st = (scene.get("shot_type") or "").strip().lower()
|
||||
if st == "landscape":
|
||||
scene["shot_type"] = "landscape"
|
||||
return _sanitize_scene_fields(scene)
|
||||
if st == "third_person":
|
||||
action = (scene.get("action_tags") or "").lower()
|
||||
wide = ("battle", "fight", "chase", "running", "crowd", "wide_shot", "group_shot")
|
||||
if any(w in action for w in wide):
|
||||
scene["shot_type"] = "third_person"
|
||||
return _sanitize_scene_fields(scene)
|
||||
scene["shot_type"] = "first_person_pov"
|
||||
if scene.get("viewer_body_visible") is None:
|
||||
scene["viewer_body_visible"] = False
|
||||
return _sanitize_scene_fields(scene)
|
||||
|
||||
|
||||
def _split_tag_input(tag_str: str) -> list[str]:
|
||||
return [t.strip() for t in (tag_str or "").split(",") if t.strip()]
|
||||
|
||||
|
||||
def _is_sentence_like_tag(tag: str) -> bool:
|
||||
t = tag.strip()
|
||||
if len(t) > 45:
|
||||
return True
|
||||
if re.search(r"[.!?]", t):
|
||||
return True
|
||||
words = t.split()
|
||||
return len(words) >= 5 and "_" not in t
|
||||
|
||||
|
||||
def _filter_tag_field(tag_str: str, *, for_pov: bool, field: str) -> str:
|
||||
kept: list[str] = []
|
||||
for raw in _split_tag_input(tag_str):
|
||||
key = raw.lower().replace(" ", "_")
|
||||
if key in _INVALID_TAGS:
|
||||
continue
|
||||
if _is_sentence_like_tag(raw):
|
||||
continue
|
||||
if for_pov and field == "action" and key in _POV_DROP_ACTION_TAGS:
|
||||
continue
|
||||
kept.append(raw if "_" in raw else key)
|
||||
return ", ".join(kept)
|
||||
|
||||
|
||||
def _reconcile_environment_tags(env_str: str) -> str:
|
||||
tags = _split_tag_input(env_str)
|
||||
keys = {t.lower().replace(" ", "_") for t in tags}
|
||||
has_indoor = bool(keys & _INDOOR_ENV_MARKERS) or any(
|
||||
any(m in k for m in _INDOOR_ENV_MARKERS) for k in keys
|
||||
)
|
||||
has_outdoor = bool(keys & _OUTDOOR_ENV_MARKERS) or any(
|
||||
any(m in k for m in _OUTDOOR_ENV_MARKERS) for k in keys
|
||||
)
|
||||
if has_indoor and has_outdoor:
|
||||
tags = [t for t in tags if t.lower().replace(" ", "_") not in _OUTDOOR_ENV_MARKERS]
|
||||
return ", ".join(tags)
|
||||
|
||||
|
||||
def _sanitize_pov_prose(desc: str, scene: dict) -> str:
|
||||
if not desc or not desc.strip():
|
||||
return ""
|
||||
if scene.get("shot_type") != "first_person_pov":
|
||||
return desc.strip()
|
||||
|
||||
kept: list[str] = []
|
||||
for sentence in re.split(r"(?<=[.!?])\s+", desc.strip()):
|
||||
s = sentence.strip()
|
||||
if not s:
|
||||
continue
|
||||
if _POV_PROSE_BANNED.search(s):
|
||||
continue
|
||||
if re.search(r"\bwolfgirl\b", s, re.I) and re.search(
|
||||
r"\b(walks|walking|stands)\b", s, re.I
|
||||
):
|
||||
continue
|
||||
kept.append(s)
|
||||
out = " ".join(kept).strip()
|
||||
return re.sub(r"\bat the viewer\b", "at the camera", out, flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def _sanitize_scene_fields(scene: dict) -> dict:
|
||||
scene = dict(scene)
|
||||
for_pov = scene.get("shot_type") == "first_person_pov"
|
||||
scene["action_tags"] = _filter_tag_field(
|
||||
scene.get("action_tags") or "", for_pov=for_pov, field="action"
|
||||
)
|
||||
env = _filter_tag_field(scene.get("environment_tags") or "", for_pov=False, field="env")
|
||||
scene["environment_tags"] = _reconcile_environment_tags(env)
|
||||
scene["scene_description"] = _sanitize_pov_prose(
|
||||
(scene.get("scene_description") or "").strip(), scene
|
||||
)
|
||||
return scene
|
||||
|
||||
|
||||
def _scene_should_generate(scene: dict) -> bool:
|
||||
if scene.get("should_generate") is False:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _sanitize_tags_string(tag_str: str) -> str:
|
||||
if not tag_str:
|
||||
return ""
|
||||
out: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for raw in tag_str.split(","):
|
||||
t = raw.strip()
|
||||
if not t:
|
||||
continue
|
||||
key = t.lower().replace(" ", "_")
|
||||
if key in seen:
|
||||
continue
|
||||
if key in _INVALID_TAGS:
|
||||
continue
|
||||
if "_" not in key and key in _JUNK_STANDALONE_TAGS:
|
||||
continue
|
||||
if len(key) <= 2:
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(t if "_" in t else key)
|
||||
return ", ".join(out)
|
||||
|
||||
|
||||
def _quality_prefix() -> str:
|
||||
if _is_pony():
|
||||
quality = "score_9, score_8_up, score_7_up, source_anime, highres"
|
||||
elif _is_anima():
|
||||
quality = "masterpiece, best quality, score_7, anime"
|
||||
else:
|
||||
quality = "masterpiece, best quality, highres"
|
||||
return "score_9, score_8_up, score_7_up, source_anime, highres"
|
||||
if _is_anima():
|
||||
return "masterpiece, best quality, score_7, anime"
|
||||
return "masterpiece, best quality, highres"
|
||||
|
||||
parts = [quality]
|
||||
|
||||
appearance = (persona or {}).get("appearance_tags", "")
|
||||
if appearance:
|
||||
parts.append(appearance)
|
||||
if outfit_tags:
|
||||
parts.append(outfit_tags)
|
||||
def _appearance_for_persona(persona: dict | None) -> str:
|
||||
"""Tag core uses appearance_tags only (prose is for LLM context, not Comfy tag line)."""
|
||||
return _sanitize_tags_string((persona or {}).get("appearance_tags", ""))
|
||||
|
||||
if scene.get("shot_type") == "landscape":
|
||||
parts.append(scene.get("environment_tags", ""))
|
||||
else:
|
||||
if scene.get("shot_type") == "first_person_pov":
|
||||
parts.append("pov, first-person view, looking at viewer")
|
||||
parts.append(scene.get("action_tags", ""))
|
||||
parts.append(scene.get("environment_tags", ""))
|
||||
|
||||
def _dedupe_outfit_tags(outfit_tags: str) -> str:
|
||||
tags = _split_tag_input(outfit_tags)
|
||||
keys = {t.lower().replace(" ", "_") for t in tags}
|
||||
if len(keys & {"jeans", "ripped_jeans", "black_jeans"}) > 1 and "jeans" in keys:
|
||||
tags = [t for t in tags if t.lower().replace(" ", "_") != "jeans"]
|
||||
return ", ".join(tags)
|
||||
|
||||
|
||||
def _scene_has_physical_contact(scene: dict) -> bool:
|
||||
action = (scene.get("action_tags") or "").lower()
|
||||
return any(k in action for k in _CONTACT_ACTION_KEYWORDS)
|
||||
|
||||
|
||||
def _infer_pov_cue_from_action(action_tags: str) -> str:
|
||||
action = (action_tags or "").lower()
|
||||
if any(k in action for k in ("holding_hands", "hand_holding", "walking", "strolling")):
|
||||
return "walking_together"
|
||||
if any(k in action for k in ("doorway", "door", "entry", "threshold")):
|
||||
if any(k in action for k in ("arms_out", "hug", "embrace", "inviting")):
|
||||
return "doorway_invite"
|
||||
if any(k in action for k in ("arms_out", "reaching", "inviting_hug", "hug", "embrace")):
|
||||
return "reach_to_viewer"
|
||||
if any(k in action for k in ("sitting", "lying", "bed")):
|
||||
return "dialogue_close"
|
||||
return "face_to_face"
|
||||
|
||||
|
||||
def _build_pov_phrase(scene: dict) -> str:
|
||||
if scene.get("shot_type") != "first_person_pov":
|
||||
return ""
|
||||
cue = (scene.get("pov_cue") or "").strip().lower().replace("-", "_").replace(" ", "_")
|
||||
if cue in POV_CUE_PHRASES:
|
||||
return POV_CUE_PHRASES[cue]
|
||||
inferred = _infer_pov_cue_from_action(scene.get("action_tags", ""))
|
||||
return POV_CUE_PHRASES.get(inferred, POV_CUE_DEFAULT)
|
||||
|
||||
|
||||
def _append_lora(parts: list[str], persona: dict | None) -> None:
|
||||
lora = (persona or {}).get("lora_name", "")
|
||||
weight = (persona or {}).get("lora_weight", 0.8)
|
||||
if lora:
|
||||
parts.append(f"<lora:{lora}:{weight}>")
|
||||
|
||||
|
||||
def _dedupe_comma_join(parts: list[str]) -> str:
|
||||
positive = ", ".join(p.strip() for p in parts if p and p.strip())
|
||||
seen, deduped = set(), []
|
||||
seen: set[str] = set()
|
||||
deduped: list[str] = []
|
||||
for tag in positive.split(", "):
|
||||
t = tag.strip()
|
||||
if t and t not in seen:
|
||||
@@ -95,53 +348,152 @@ def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str =
|
||||
return ", ".join(deduped)
|
||||
|
||||
|
||||
async def generate_sd_prompt(
|
||||
messages: list,
|
||||
persona_id: str,
|
||||
outfit_json: str = "[]",
|
||||
) -> tuple[str | None, str | None]:
|
||||
persona = await get_persona(persona_id)
|
||||
# Generate only if persona has appearance tags
|
||||
if not persona or not (persona.get("appearance_tags") or "").strip():
|
||||
logger.debug("sd_prompt skip: persona=%s no appearance_tags", persona_id)
|
||||
return None, None
|
||||
def _build_tag_core(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
"""Anchor + structure: quality, appearance, outfit, action/env tags, LoRA. No POV prose, no scene_description."""
|
||||
parts = [_quality_prefix()]
|
||||
appearance = _appearance_for_persona(persona)
|
||||
if appearance:
|
||||
parts.append(appearance)
|
||||
if outfit_tags:
|
||||
parts.append(_sanitize_tags_string(_dedupe_outfit_tags(outfit_tags)))
|
||||
if scene.get("shot_type") == "landscape":
|
||||
parts.append(_sanitize_tags_string(scene.get("environment_tags", "")))
|
||||
else:
|
||||
if not _is_anima() and scene.get("shot_type") == "first_person_pov":
|
||||
parts.append("pov, first-person view, looking at viewer")
|
||||
parts.append(_sanitize_tags_string(scene.get("action_tags", "")))
|
||||
parts.append(_sanitize_tags_string(scene.get("environment_tags", "")))
|
||||
_append_lora(parts, persona)
|
||||
return _dedupe_comma_join(parts)
|
||||
|
||||
recent = [m for m in messages if m["role"] in ("user", "assistant")][-6:]
|
||||
if not recent:
|
||||
return None, None
|
||||
|
||||
excerpt = "\n".join(f"{m['role']}: {strip_image_prompt_tag(m['content'])}" for m in recent)
|
||||
def build_positive_prompt_tags_only(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
"""Tags + contextual POV phrase (Anima) or legacy Pony path."""
|
||||
if not _is_anima():
|
||||
return build_positive_prompt(scene, persona, outfit_tags)
|
||||
core = _build_tag_core(scene, persona, outfit_tags)
|
||||
pov = _build_pov_phrase(scene)
|
||||
if pov:
|
||||
return f"{core}, {pov}" if core else pov
|
||||
return core
|
||||
|
||||
builder_messages = [
|
||||
{"role": "system", "content": PROMPT_BUILDER_SYSTEM},
|
||||
{"role": "user", "content": f"Chat:\n{excerpt}"},
|
||||
]
|
||||
|
||||
try:
|
||||
if SD_PROMPT_MODEL:
|
||||
raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
|
||||
else:
|
||||
raw = await send_message(builder_messages)
|
||||
raw = raw.strip()
|
||||
if raw.startswith("```"):
|
||||
raw = re.sub(r"^```\w*\n?", "", raw)
|
||||
raw = re.sub(r"\n?```$", "", raw)
|
||||
scene = json.loads(raw)
|
||||
if not isinstance(scene, dict):
|
||||
logger.warning("sd_prompt: LLM returned non-dict: %.100s", raw)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.warning("sd_prompt failed: %s raw=%.200s", e, locals().get("raw", ""))
|
||||
return None, None
|
||||
def _tag_tokens_for_dedupe(tag_line: str) -> set[str]:
|
||||
tokens: set[str] = set()
|
||||
for part in tag_line.replace("<lora:", " ").split(","):
|
||||
for word in re.split(r"[\s_./]+", part.lower()):
|
||||
w = word.strip()
|
||||
if len(w) >= 4:
|
||||
tokens.add(w)
|
||||
return tokens
|
||||
|
||||
try:
|
||||
outfit_list = json.loads(outfit_json or "[]")
|
||||
outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
|
||||
except Exception:
|
||||
outfit_tags = ""
|
||||
|
||||
positive = build_positive_prompt(scene, persona, outfit_tags)
|
||||
def _trim_redundant_scene_description(desc: str, tag_line: str) -> str:
|
||||
tag_tokens = _tag_tokens_for_dedupe(tag_line)
|
||||
if not tag_tokens or not desc.strip():
|
||||
return desc.strip()
|
||||
|
||||
kept: list[str] = []
|
||||
for sentence in re.split(r"(?<=[.!?])\s+", desc.strip()):
|
||||
s = sentence.strip()
|
||||
if not s:
|
||||
continue
|
||||
words = [w.lower() for w in re.findall(r"[a-zA-Z]{4,}", s)]
|
||||
if not words:
|
||||
kept.append(s)
|
||||
continue
|
||||
overlap = sum(1 for w in words if w in tag_tokens) / len(words)
|
||||
if overlap < 0.62:
|
||||
kept.append(s)
|
||||
|
||||
return " ".join(kept).strip()
|
||||
|
||||
|
||||
def _extract_illustrate_content(content: str, max_chars: int = 1400) -> str:
|
||||
"""Long assistant posts (first_mes): use final beat after time-skip, last paragraphs."""
|
||||
text = strip_image_prompt_tag(content).strip()
|
||||
if not text:
|
||||
return ""
|
||||
chunks = _TIME_SKIP_RE.split(text)
|
||||
if len(chunks) > 1:
|
||||
text = chunks[-1].strip()
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
|
||||
if paragraphs:
|
||||
for n in (1, 2, 3):
|
||||
tail = "\n\n".join(paragraphs[-n:])
|
||||
if len(tail) <= max_chars:
|
||||
return tail
|
||||
return paragraphs[-1][-max_chars:]
|
||||
return text[-max_chars:]
|
||||
|
||||
|
||||
def _fallback_mood_prose(scene: dict) -> str:
|
||||
cue = (scene.get("pov_cue") or "").strip().lower().replace("-", "_").replace(" ", "_")
|
||||
if cue in _POV_MOOD_FALLBACK:
|
||||
return _POV_MOOD_FALLBACK[cue]
|
||||
inferred = _infer_pov_cue_from_action(scene.get("action_tags", ""))
|
||||
return _POV_MOOD_FALLBACK.get(inferred, "Soft atmosphere; her expression toward the camera.")
|
||||
|
||||
|
||||
def _cap_scene_description(desc: str, max_words: int = 40, max_chars: int = 220) -> str:
|
||||
words = desc.split()
|
||||
if len(words) > max_words:
|
||||
desc = " ".join(words[:max_words])
|
||||
if len(desc) > max_chars:
|
||||
desc = desc[: max_chars - 3] + "..."
|
||||
return desc
|
||||
|
||||
|
||||
def build_positive_prompt_hybrid(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
"""Production Anima prompt: tag core + POV cue + short mood prose."""
|
||||
if not _is_anima():
|
||||
return build_positive_prompt(scene, persona, outfit_tags)
|
||||
|
||||
base = build_positive_prompt_tags_only(scene, persona, outfit_tags)
|
||||
desc = _trim_redundant_scene_description(
|
||||
(scene.get("scene_description") or "").strip(),
|
||||
base,
|
||||
)
|
||||
desc = _cap_scene_description(desc)
|
||||
if not desc:
|
||||
desc = _cap_scene_description(_fallback_mood_prose(scene))
|
||||
if not desc:
|
||||
return base
|
||||
|
||||
lora = (persona or {}).get("lora_name", "")
|
||||
weight = (persona or {}).get("lora_weight", 0.8)
|
||||
lora_suffix = f" <lora:{lora}:{weight}>" if lora else ""
|
||||
if lora_suffix and base.endswith(lora_suffix):
|
||||
base = base[: -len(lora_suffix)]
|
||||
return f"{base}. {desc}{lora_suffix}"
|
||||
return f"{base}. {desc}"
|
||||
|
||||
|
||||
def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
"""Legacy entry: Pony/non-Anima full prompt; Anima delegates to tags-only."""
|
||||
if _is_anima():
|
||||
return build_positive_prompt_tags_only(scene, persona, outfit_tags)
|
||||
|
||||
parts = [_quality_prefix()]
|
||||
appearance = _appearance_for_persona(persona)
|
||||
if appearance:
|
||||
parts.append(appearance)
|
||||
if outfit_tags:
|
||||
parts.append(_sanitize_tags_string(_dedupe_outfit_tags(outfit_tags)))
|
||||
if scene.get("shot_type") == "landscape":
|
||||
parts.append(_sanitize_tags_string(scene.get("environment_tags", "")))
|
||||
else:
|
||||
if scene.get("shot_type") == "first_person_pov":
|
||||
parts.append("pov, first-person view, looking at viewer")
|
||||
parts.append(_sanitize_tags_string(scene.get("action_tags", "")))
|
||||
parts.append(_sanitize_tags_string(scene.get("environment_tags", "")))
|
||||
_append_lora(parts, persona)
|
||||
return _dedupe_comma_join(parts)
|
||||
|
||||
|
||||
def _negative_for_scene(scene: dict) -> str:
|
||||
if _is_pony():
|
||||
negative = PONY_NEGATIVE
|
||||
elif _is_anima():
|
||||
@@ -151,6 +503,237 @@ async def generate_sd_prompt(
|
||||
|
||||
if scene.get("shot_type") == "first_person_pov":
|
||||
negative += ", third person, over the shoulder"
|
||||
viewer_visible = scene.get("viewer_body_visible") is True
|
||||
if not viewer_visible or _scene_has_physical_contact(scene):
|
||||
negative += ", " + POV_INTERACTION_NEGATIVE
|
||||
|
||||
full = positive + f"\n\nNegative prompt: {negative}"
|
||||
return full, negative
|
||||
return negative
|
||||
|
||||
|
||||
def _format_builder_user_block(
|
||||
persona: dict, messages: list[dict], outfit_json: str, scene_json: str = "{}"
|
||||
) -> str:
|
||||
lines: list[str] = []
|
||||
tags = (persona.get("appearance_tags") or "").strip()
|
||||
lines.append(f"Character appearance (tags): {tags}")
|
||||
prose = (persona.get("appearance_prose") or "").strip()
|
||||
if _is_anima() and prose and prose != tags:
|
||||
snippet = prose[:300] + ("..." if len(prose) > 300 else "")
|
||||
lines.append(f"Character notes (do not copy into tags or scene_description): {snippet}")
|
||||
|
||||
try:
|
||||
outfit_list = json.loads(outfit_json or "[]")
|
||||
outfit_ref = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
|
||||
except Exception:
|
||||
outfit_ref = ""
|
||||
|
||||
if outfit_ref:
|
||||
lines.append(f"Current outfit (tags): {outfit_ref}")
|
||||
|
||||
from services.rpg_state import parse_scene_json, scene_to_sd_hint
|
||||
|
||||
scene_hint = scene_to_sd_hint(parse_scene_json(scene_json))
|
||||
if scene_hint:
|
||||
lines.append(f"Scene (location/time):\n{scene_hint}")
|
||||
|
||||
recent = [m for m in messages if m.get("role") in ("user", "assistant")][-6:]
|
||||
if not recent:
|
||||
lines.append("\nChat:\n(no messages — return should_generate=false)")
|
||||
return "\n".join(lines)
|
||||
|
||||
illustrate: list[dict] = []
|
||||
if recent[-1]["role"] == "assistant":
|
||||
illustrate = [recent[-1]]
|
||||
if len(recent) >= 2 and recent[-2]["role"] == "user":
|
||||
illustrate.insert(0, recent[-2])
|
||||
else:
|
||||
illustrate = [recent[-1]]
|
||||
if len(recent) >= 2 and recent[-2]["role"] == "assistant":
|
||||
illustrate.insert(0, recent[-2])
|
||||
|
||||
context = [m for m in recent if m not in illustrate]
|
||||
|
||||
lines.append("\n=== ILLUSTRATE (draw THIS beat only) ===")
|
||||
for m in illustrate:
|
||||
raw = m.get("content", "")
|
||||
content = _extract_illustrate_content(raw) if m.get("role") == "assistant" else strip_image_prompt_tag(raw)
|
||||
lines.append(f"{m['role']}: {content}")
|
||||
|
||||
if context:
|
||||
lines.append("\n=== Context (outfit/location hints only — do not illustrate old beats) ===")
|
||||
for m in context:
|
||||
content = strip_image_prompt_tag(m.get("content", ""))
|
||||
if len(content) > 800:
|
||||
content = content[:797] + "..."
|
||||
lines.append(f"{m['role']}: {content}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _parse_scene_json(raw: str) -> dict:
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = re.sub(r"^```\w*\n?", "", cleaned)
|
||||
cleaned = re.sub(r"\n?```$", "", cleaned)
|
||||
scene = json.loads(cleaned)
|
||||
if not isinstance(scene, dict):
|
||||
raise ValueError("LLM returned non-object JSON")
|
||||
return _normalize_shot_type(scene)
|
||||
|
||||
|
||||
def _bundle_from_scene(scene: dict, persona: dict, outfit_tags: str) -> SdPromptBundle:
|
||||
negative = _negative_for_scene(scene)
|
||||
if _is_anima():
|
||||
hybrid = build_positive_prompt_hybrid(scene, persona, outfit_tags)
|
||||
tag_full = hybrid + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
desc_full = None
|
||||
if anima_dual_enabled():
|
||||
tags_only = build_positive_prompt_tags_only(scene, persona, outfit_tags)
|
||||
desc_full = tags_only + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
return SdPromptBundle(tag_full=tag_full, negative=negative, desc_full=desc_full)
|
||||
|
||||
positive = build_positive_prompt(scene, persona, outfit_tags)
|
||||
tag_full = positive + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
return SdPromptBundle(tag_full=tag_full, negative=negative, desc_full=None)
|
||||
|
||||
|
||||
def _parse_chat_excerpt(excerpt: str) -> list[dict]:
|
||||
messages: list[dict] = []
|
||||
for line in (excerpt or "").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
lower = line.lower()
|
||||
if lower.startswith("user:"):
|
||||
messages.append({"role": "user", "content": line[5:].strip()})
|
||||
elif lower.startswith("assistant:"):
|
||||
messages.append({"role": "assistant", "content": line[10:].strip()})
|
||||
elif lower.startswith("system:"):
|
||||
messages.append({"role": "system", "content": line[7:].strip()})
|
||||
else:
|
||||
messages.append({"role": "user", "content": line})
|
||||
return messages
|
||||
|
||||
|
||||
async def run_prompt_builder(
|
||||
persona_id: str,
|
||||
*,
|
||||
messages: list[dict] | None = None,
|
||||
chat_excerpt: str = "",
|
||||
outfit_json: str = "[]",
|
||||
appearance_override: str | None = None,
|
||||
use_prose: bool = False,
|
||||
) -> dict:
|
||||
"""Debug: full SD prompt builder pipeline with LLM raw output."""
|
||||
persona = await get_persona(persona_id) or {}
|
||||
if appearance_override is not None:
|
||||
persona = {**persona, "appearance_tags": appearance_override}
|
||||
|
||||
recent = messages if messages is not None else _parse_chat_excerpt(chat_excerpt)
|
||||
recent = [m for m in recent if m.get("role") in ("user", "assistant")]
|
||||
|
||||
user_block = _format_builder_user_block(persona, recent, outfit_json)
|
||||
builder_messages = [
|
||||
{"role": "system", "content": _builder_system()},
|
||||
{"role": "user", "content": user_block},
|
||||
]
|
||||
model_used = SD_PROMPT_MODEL or "SYSTEM_MODEL"
|
||||
result: dict = {
|
||||
"persona_id": persona_id,
|
||||
"sd_prompt_model": model_used,
|
||||
"builder_system": _builder_system(),
|
||||
"builder_user": user_block,
|
||||
"anima_dual": anima_dual_enabled(),
|
||||
}
|
||||
|
||||
raw = ""
|
||||
try:
|
||||
if SD_PROMPT_MODEL:
|
||||
raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
|
||||
else:
|
||||
raw = await send_message(builder_messages)
|
||||
result["llm_raw"] = raw
|
||||
scene = _parse_scene_json(raw)
|
||||
result["scene"] = scene
|
||||
|
||||
if not _scene_should_generate(scene):
|
||||
result["skipped"] = True
|
||||
result["error"] = "should_generate=false"
|
||||
return result
|
||||
|
||||
try:
|
||||
outfit_tags = ", ".join(json.loads(outfit_json or "[]"))
|
||||
except Exception:
|
||||
outfit_tags = ""
|
||||
|
||||
negative = _negative_for_scene(scene)
|
||||
if _is_anima():
|
||||
tags_only = build_positive_prompt_tags_only(scene, persona, outfit_tags)
|
||||
hybrid = build_positive_prompt_hybrid(scene, persona, outfit_tags)
|
||||
result["tag_positive"] = tags_only
|
||||
result["hybrid_positive"] = hybrid
|
||||
result["negative"] = negative
|
||||
result["tags_only_full"] = tags_only + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
result["hybrid_full"] = hybrid + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
result["tag_full"] = result["hybrid_full"]
|
||||
else:
|
||||
positive = build_positive_prompt(scene, persona, outfit_tags)
|
||||
result["tag_positive"] = positive
|
||||
result["negative"] = negative
|
||||
result["tag_full"] = positive + NEGATIVE_PROMPT_SEPARATOR + negative
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
result["llm_raw"] = raw or result.get("llm_raw", "")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def generate_sd_prompt(
|
||||
messages: list,
|
||||
persona_id: str,
|
||||
outfit_json: str = "[]",
|
||||
scene_json: str = "{}",
|
||||
) -> SdPromptBundle | None:
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona:
|
||||
return None
|
||||
|
||||
recent = [m for m in messages if m["role"] in ("user", "assistant")]
|
||||
if not recent:
|
||||
return None
|
||||
|
||||
user_block = _format_builder_user_block(persona, recent, outfit_json, scene_json)
|
||||
builder_messages = [
|
||||
{"role": "system", "content": _builder_system()},
|
||||
{"role": "user", "content": user_block},
|
||||
]
|
||||
|
||||
raw = ""
|
||||
try:
|
||||
if SD_PROMPT_MODEL:
|
||||
raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
|
||||
else:
|
||||
raw = await send_message(builder_messages)
|
||||
scene = _parse_scene_json(raw)
|
||||
except Exception as e:
|
||||
logger.warning("sd_prompt failed: %s raw=%.200s", e, raw)
|
||||
return None
|
||||
|
||||
if not _scene_should_generate(scene):
|
||||
logger.info("sd_prompt: skipped (should_generate=false)")
|
||||
return None
|
||||
|
||||
try:
|
||||
outfit_list = json.loads(outfit_json or "[]")
|
||||
outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
|
||||
except Exception:
|
||||
outfit_tags = ""
|
||||
|
||||
bundle = _bundle_from_scene(scene, persona, outfit_tags)
|
||||
if anima_dual_enabled() and bundle.desc_full:
|
||||
logger.info(
|
||||
"Anima prompts: hybrid=%.80s | tags_only=%.80s",
|
||||
bundle.tag_full.split(NEGATIVE_PROMPT_SEPARATOR)[0],
|
||||
bundle.desc_full.split(NEGATIVE_PROMPT_SEPARATOR)[0],
|
||||
)
|
||||
return bundle
|
||||
|
||||
+323
-24
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
@@ -11,7 +12,178 @@ load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SD_BASE_URL = os.getenv("SD_BASE_URL", "http://127.0.0.1:8188").rstrip("/")
|
||||
|
||||
def _parse_basic_auth() -> httpx.BasicAuth | None:
|
||||
"""
|
||||
Vast Caddy on mapped ports often uses Basic realm=restricted.
|
||||
Set SD_COMFY_HTTP_BASIC=user:password or SD_COMFY_USER + SD_COMFY_PASSWORD.
|
||||
"""
|
||||
raw = (os.getenv("SD_COMFY_HTTP_BASIC") or "").strip()
|
||||
if raw:
|
||||
if ":" in raw:
|
||||
user, _, password = raw.partition(":")
|
||||
else:
|
||||
user, password = "", raw
|
||||
return httpx.BasicAuth(user, password)
|
||||
user = (os.getenv("SD_COMFY_USER") or "").strip()
|
||||
password = (os.getenv("SD_COMFY_PASSWORD") or "").strip()
|
||||
if user or password:
|
||||
return httpx.BasicAuth(user, password)
|
||||
return None
|
||||
|
||||
|
||||
SD_BASIC_AUTH = _parse_basic_auth()
|
||||
|
||||
|
||||
def _parse_comfy_config() -> tuple[str, dict[str, str]]:
|
||||
"""
|
||||
SD_BASE_URL may be pasted from Vast/Comfy UI with ?token=...
|
||||
API paths must be base + /prompt, not ...?token=xxx/prompt
|
||||
"""
|
||||
raw = (os.getenv("SD_BASE_URL") or "http://127.0.0.1:8188").strip()
|
||||
extra_token = (os.getenv("SD_COMFY_TOKEN") or "").strip()
|
||||
parsed = urlparse(raw)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if path and path != "/":
|
||||
base = f"{base}{path}"
|
||||
query: dict[str, str] = {}
|
||||
for key, values in parse_qs(parsed.query).items():
|
||||
if values:
|
||||
query[key] = values[-1]
|
||||
if extra_token:
|
||||
query["token"] = extra_token
|
||||
base = base.rstrip("/")
|
||||
# Cloudflare tunnel to localhost:8188 — direct Comfy API, Vast ?token= does not apply
|
||||
if "trycloudflare.com" in base.lower():
|
||||
if query.pop("token", None):
|
||||
logger.info(
|
||||
"SD_BASE_URL is trycloudflare tunnel: Vast token stripped. "
|
||||
"Use tunnel for port 8188 only (see instance Port Mapping)."
|
||||
)
|
||||
return base, query
|
||||
|
||||
|
||||
SD_BASE_URL, SD_QUERY_PARAMS = _parse_comfy_config()
|
||||
|
||||
|
||||
def _comfy_url(path: str) -> str:
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
return f"{SD_BASE_URL}{path}"
|
||||
|
||||
|
||||
def _log_comfy_target() -> str:
|
||||
if SD_QUERY_PARAMS.get("token"):
|
||||
return f"{SD_BASE_URL}?token=***"
|
||||
return SD_BASE_URL
|
||||
|
||||
|
||||
def _absolute_url(location: str, fallback_path: str = "/") -> str:
|
||||
if not location:
|
||||
return _comfy_url(fallback_path)
|
||||
if location.startswith(("http://", "https://")):
|
||||
return location
|
||||
if location.startswith("/"):
|
||||
return f"{SD_BASE_URL}{location}"
|
||||
return f"{SD_BASE_URL}/{location}"
|
||||
|
||||
|
||||
def _url_with_token(url: str) -> str:
|
||||
"""Append gateway token to URL (Vast/Cloudflare often strip ?token on redirect)."""
|
||||
if not SD_QUERY_PARAMS.get("token"):
|
||||
return url
|
||||
p = urlparse(url)
|
||||
q: dict[str, str] = {}
|
||||
for key, values in parse_qs(p.query).items():
|
||||
if values:
|
||||
q[key] = values[-1]
|
||||
q.update(SD_QUERY_PARAMS)
|
||||
return urlunparse((p.scheme, p.netloc, p.path, "", urlencode(q), ""))
|
||||
|
||||
|
||||
def _merge_params(extra: dict | None) -> dict | None:
|
||||
if not SD_QUERY_PARAMS and not extra:
|
||||
return None
|
||||
merged = dict(SD_QUERY_PARAMS)
|
||||
if extra:
|
||||
merged.update(extra)
|
||||
return merged
|
||||
|
||||
|
||||
def _is_vast_gateway() -> bool:
|
||||
return "trycloudflare.com" not in SD_BASE_URL.lower()
|
||||
|
||||
|
||||
def _make_comfy_client(*, timeout: float = 300) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
follow_redirects=False,
|
||||
auth=SD_BASIC_AUTH,
|
||||
)
|
||||
|
||||
|
||||
async def _prime_comfy_gateway(client: httpx.AsyncClient) -> None:
|
||||
"""
|
||||
Vast Caddy: browser opens /?token=… and gets a session cookie; API then works.
|
||||
Prime with redirects so Set-Cookie is collected, then merge into the API client.
|
||||
"""
|
||||
token = SD_QUERY_PARAMS.get("token")
|
||||
if not token or not _is_vast_gateway():
|
||||
return
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30,
|
||||
follow_redirects=True,
|
||||
auth=SD_BASIC_AUTH,
|
||||
) as prime:
|
||||
r = await prime.get(_comfy_url("/"), params={"token": token})
|
||||
client.cookies.update(prime.cookies)
|
||||
logger.info(
|
||||
"Comfy gateway prime GET /?token=*** → %s, cookies=%s",
|
||||
r.status_code,
|
||||
list(prime.cookies.keys()) or "(none)",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Comfy gateway prime failed: %s", e)
|
||||
|
||||
|
||||
async def _comfy_request(
|
||||
client: httpx.AsyncClient,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict | None = None,
|
||||
**kwargs,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Comfy API: trycloudflare tunnel = no token.
|
||||
Vast IP:PORT gateway = ?token= + cookie prime; follow redirects with token re-attached.
|
||||
"""
|
||||
url = _comfy_url(path)
|
||||
extra = params or {}
|
||||
token = SD_QUERY_PARAMS.get("token")
|
||||
use_vast_auth = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None)
|
||||
|
||||
if token and _is_vast_gateway():
|
||||
await _prime_comfy_gateway(client)
|
||||
|
||||
req_params: dict | None = _merge_params(extra) if use_vast_auth else (extra or None)
|
||||
resp: httpx.Response | None = None
|
||||
|
||||
for hop in range(6):
|
||||
resp = await client.request(method, url, params=req_params, **kwargs)
|
||||
if resp.status_code not in (301, 302, 303, 307, 308):
|
||||
return resp
|
||||
loc = _absolute_url(resp.headers.get("location", ""), path)
|
||||
url = _url_with_token(loc) if use_vast_auth else loc
|
||||
req_params = extra or None
|
||||
logger.info("Comfy redirect %s hop %s → %s", resp.status_code, hop + 1, url.split("?")[0])
|
||||
|
||||
assert resp is not None
|
||||
return resp
|
||||
|
||||
|
||||
SD_STEPS = int(os.getenv("SD_STEPS", "28"))
|
||||
SD_CFG = float(os.getenv("SD_CFG", "7"))
|
||||
SD_SAMPLER = os.getenv("SD_SAMPLER", "euler")
|
||||
@@ -26,6 +198,8 @@ SD_DEFAULT_NEGATIVE = os.getenv(
|
||||
SD_UNET = os.getenv("SD_UNET", "anima-preview3-base.safetensors")
|
||||
SD_CLIP = os.getenv("SD_CLIP", "qwen_3_06b_base.safetensors")
|
||||
SD_VAE = os.getenv("SD_VAE", "qwen_image_vae.safetensors")
|
||||
SD_STYLE_LORA = os.getenv("SD_STYLE_LORA", "")
|
||||
SD_STYLE_LORA_WEIGHT = float(os.getenv("SD_STYLE_LORA_WEIGHT", "0.7"))
|
||||
|
||||
IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images"))
|
||||
|
||||
@@ -38,20 +212,38 @@ def _use_anima() -> bool:
|
||||
|
||||
|
||||
def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]:
|
||||
# Try new separator first
|
||||
sep = "__NEGATIVE_PROMPT__"
|
||||
if f"\n{sep}\n" in full_prompt:
|
||||
pos, _, neg = full_prompt.partition(f"\n{sep}\n")
|
||||
return pos.strip(), neg.strip()
|
||||
# Fallback to old format
|
||||
if "\n\nNegative prompt:" in full_prompt:
|
||||
pos, _, neg = full_prompt.partition("\n\nNegative prompt:")
|
||||
return pos.strip(), neg.strip()
|
||||
return full_prompt.strip(), SD_DEFAULT_NEGATIVE
|
||||
|
||||
|
||||
def _build_workflow(positive: str, negative: str) -> dict:
|
||||
def _workflow_uses_anima(overrides: dict | None) -> bool:
|
||||
if overrides and overrides.get("checkpoint"):
|
||||
return False
|
||||
if overrides and overrides.get("unet"):
|
||||
return True
|
||||
return _use_anima()
|
||||
|
||||
|
||||
def _build_workflow(positive: str, negative: str, overrides: dict | None = None) -> dict:
|
||||
seed = int(uuid.uuid4().int % 2**32)
|
||||
if _use_anima():
|
||||
return {
|
||||
"44": {"class_type": "UNETLoader", "inputs": {"unet_name": SD_UNET, "weight_dtype": "default"}},
|
||||
"45": {"class_type": "CLIPLoader", "inputs": {"clip_name": SD_CLIP, "type": "stable_diffusion", "device": "default"}},
|
||||
"15": {"class_type": "VAELoader", "inputs": {"vae_name": SD_VAE}},
|
||||
"28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 1024, "batch_size": 1}},
|
||||
o = overrides or {}
|
||||
if _workflow_uses_anima(o):
|
||||
unet = o.get("unet") or SD_UNET
|
||||
clip = o.get("clip") or SD_CLIP
|
||||
vae = o.get("vae") or SD_VAE
|
||||
workflow = {
|
||||
"44": {"class_type": "UNETLoader", "inputs": {"unet_name": unet, "weight_dtype": "default"}},
|
||||
"45": {"class_type": "CLIPLoader", "inputs": {"clip_name": clip, "type": "stable_diffusion", "device": "default"}},
|
||||
"15": {"class_type": "VAELoader", "inputs": {"vae_name": vae}},
|
||||
"28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 720, "batch_size": 1}},
|
||||
"11": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["45", 0]}},
|
||||
"12": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["45", 0]}},
|
||||
"19": {
|
||||
@@ -68,9 +260,24 @@ def _build_workflow(positive: str, negative: str) -> dict:
|
||||
"8": {"class_type": "VAEDecode", "inputs": {"samples": ["19", 0], "vae": ["15", 0]}},
|
||||
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
|
||||
}
|
||||
# Standard checkpoint workflow (Pony / SDXL)
|
||||
if SD_STYLE_LORA:
|
||||
workflow["46"] = {
|
||||
"class_type": "LoraLoader",
|
||||
"inputs": {
|
||||
"lora_name": SD_STYLE_LORA,
|
||||
"model": ["44", 0],
|
||||
"clip": ["45", 0],
|
||||
"strength_model": SD_STYLE_LORA_WEIGHT,
|
||||
"strength_clip": SD_STYLE_LORA_WEIGHT,
|
||||
},
|
||||
}
|
||||
workflow["19"]["inputs"]["model"] = ["46", 0]
|
||||
workflow["11"]["inputs"]["clip"] = ["46", 1]
|
||||
workflow["12"]["inputs"]["clip"] = ["46", 1]
|
||||
return workflow
|
||||
ckpt = o.get("checkpoint") or SD_CHECKPOINT
|
||||
return {
|
||||
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": SD_CHECKPOINT}},
|
||||
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt}},
|
||||
"5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}},
|
||||
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}},
|
||||
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}},
|
||||
@@ -89,24 +296,78 @@ def _build_workflow(positive: str, negative: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
async def comfy_api_request(
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: dict | None = None,
|
||||
json_body: dict | None = None,
|
||||
timeout: float = 60,
|
||||
) -> tuple[int, dict | str, dict]:
|
||||
"""
|
||||
Raw Comfy API call for debug. Returns (status_code, parsed_json_or_text, response_headers_subset).
|
||||
"""
|
||||
async with _make_comfy_client(timeout=timeout) as client:
|
||||
await _prime_comfy_gateway(client)
|
||||
token = SD_QUERY_PARAMS.get("token")
|
||||
use_vast = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None)
|
||||
req_params = _merge_params(params) if use_vast else (params or None)
|
||||
req_kwargs: dict = {}
|
||||
if json_body is not None and method.upper() not in ("GET", "HEAD"):
|
||||
req_kwargs["json"] = json_body
|
||||
resp = await _comfy_request(
|
||||
client,
|
||||
method.upper(),
|
||||
path,
|
||||
params=req_params,
|
||||
**req_kwargs,
|
||||
)
|
||||
headers = {
|
||||
k: resp.headers.get(k)
|
||||
for k in ("content-type", "location", "www-authenticate")
|
||||
if resp.headers.get(k)
|
||||
}
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
body = resp.text[:8000]
|
||||
return resp.status_code, body, headers
|
||||
|
||||
|
||||
async def fetch_object_info() -> dict:
|
||||
status, body, _ = await comfy_api_request("GET", "/object_info", timeout=120)
|
||||
if status != 200 or not isinstance(body, dict):
|
||||
raise RuntimeError(f"object_info failed: HTTP {status} {body!s:.300}")
|
||||
return body
|
||||
|
||||
|
||||
async def check_sd() -> bool:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
r = await client.get(f"{SD_BASE_URL}/system_stats")
|
||||
async with _make_comfy_client(timeout=15) as client:
|
||||
await _prime_comfy_gateway(client)
|
||||
r = await _comfy_request(client, "GET", "/system_stats")
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[bytes, str]:
|
||||
async def txt2img(
|
||||
prompt: str,
|
||||
negative_prompt: str | None = None,
|
||||
*,
|
||||
overrides: dict | None = None,
|
||||
) -> tuple[bytes, str]:
|
||||
neg = negative_prompt or SD_DEFAULT_NEGATIVE
|
||||
workflow = _build_workflow(prompt, neg)
|
||||
workflow = _build_workflow(prompt, neg, overrides)
|
||||
client_id = uuid.uuid4().hex
|
||||
|
||||
logger.info("ComfyUI request → %s prompt: %.120s", SD_BASE_URL, prompt)
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
resp = await client.post(
|
||||
f"{SD_BASE_URL}/prompt",
|
||||
logger.info("ComfyUI request → %s prompt: %.120s", _log_comfy_target(), prompt)
|
||||
async with _make_comfy_client() as client:
|
||||
await _prime_comfy_gateway(client)
|
||||
resp = await _comfy_request(
|
||||
client,
|
||||
"POST",
|
||||
"/prompt",
|
||||
json={"prompt": workflow, "client_id": client_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
@@ -115,7 +376,7 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
|
||||
for _ in range(300):
|
||||
await asyncio.sleep(1)
|
||||
hist = await client.get(f"{SD_BASE_URL}/history/{prompt_id}")
|
||||
hist = await _comfy_request(client, "GET", f"/history/{prompt_id}")
|
||||
data = hist.json()
|
||||
if prompt_id in data:
|
||||
entry = data[prompt_id]
|
||||
@@ -127,9 +388,15 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
for node_output in outputs.values():
|
||||
if "images" in node_output:
|
||||
img_info = node_output["images"][0]
|
||||
img_resp = await client.get(
|
||||
f"{SD_BASE_URL}/view",
|
||||
params={"filename": img_info["filename"], "subfolder": img_info.get("subfolder", ""), "type": img_info.get("type", "output")},
|
||||
img_resp = await _comfy_request(
|
||||
client,
|
||||
"GET",
|
||||
"/view",
|
||||
params={
|
||||
"filename": img_info["filename"],
|
||||
"subfolder": img_info.get("subfolder", ""),
|
||||
"type": img_info.get("type", "output"),
|
||||
},
|
||||
)
|
||||
img_resp.raise_for_status()
|
||||
image_bytes = img_resp.content
|
||||
@@ -145,11 +412,43 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
raise RuntimeError("ComfyUI generation timed out or produced no output")
|
||||
|
||||
|
||||
async def generate_from_full_prompt(full_prompt: str) -> tuple[str | None, str | None]:
|
||||
async def generate_from_full_prompt(
|
||||
full_prompt: str,
|
||||
*,
|
||||
overrides: dict | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
positive, negative = split_prompt_and_negative(full_prompt)
|
||||
try:
|
||||
_, rel_path = await txt2img(positive, negative)
|
||||
_, rel_path = await txt2img(positive, negative, overrides=overrides)
|
||||
return rel_path, None
|
||||
except httpx.HTTPStatusError as e:
|
||||
code = e.response.status_code
|
||||
if code == 401:
|
||||
logger.error(
|
||||
"ComfyUI 401: Vast Caddy needs SD_COMFY_TOKEN (or ?token= in SD_BASE_URL) "
|
||||
"and/or SD_COMFY_HTTP_BASIC=user:pass from the instance page. "
|
||||
"Test: curl -u user:pass http://IP:PORT/system_stats "
|
||||
"or open /?token=… in browser then curl with cookies. "
|
||||
"Alternative: trycloudflare URL for localhost:8188 in Port Mapping."
|
||||
)
|
||||
elif code in (301, 302, 303, 307, 308):
|
||||
logger.error(
|
||||
"ComfyUI %s: wrong URL — use trycloudflare tunnel for 8188, not web UI link. "
|
||||
"SD_BASE_URL=https://reviewer-relief-edmonton-specializing.trycloudflare.com "
|
||||
"(no ?token=). Location: %s",
|
||||
code,
|
||||
e.response.headers.get("location"),
|
||||
)
|
||||
else:
|
||||
logger.error("ComfyUI HTTP %s: %s", code, e)
|
||||
return None, str(e)
|
||||
except httpx.ConnectError as e:
|
||||
logger.error(
|
||||
"ComfyUI connect failed (%s): IP:8188 is often not exposed on Vast. "
|
||||
"Use trycloudflare URL from Port Mapping for localhost:8188.",
|
||||
e,
|
||||
)
|
||||
return None, str(e)
|
||||
except Exception as e:
|
||||
logger.error("ComfyUI error: %s", e)
|
||||
return None, str(e)
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import logging
|
||||
|
||||
from services.memory import get_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resolve_session_persona(
|
||||
session_id: str,
|
||||
requested: str | None = None,
|
||||
*,
|
||||
create_persona: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Session.persona_id is the source of truth.
|
||||
requested is ignored when it disagrees (logged). create_persona used only if session missing.
|
||||
"""
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
return (create_persona or requested or "default").strip() or "default"
|
||||
|
||||
bound = (session.get("persona_id") or "default").strip() or "default"
|
||||
req = (requested or "").strip()
|
||||
if req and req != bound:
|
||||
logger.warning(
|
||||
"persona_id mismatch session=%s bound=%s requested=%s (using bound)",
|
||||
session_id,
|
||||
bound,
|
||||
req,
|
||||
)
|
||||
return bound
|
||||
@@ -0,0 +1,21 @@
|
||||
import logging
|
||||
|
||||
from services.chat_prompt import get_system_prompt
|
||||
from services.memory import get_all_sessions, get_history, upsert_static_system_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def migrate_static_system_messages() -> int:
|
||||
"""Rebuild stored system rows from sessions.persona_id (strip legacy RPG text)."""
|
||||
updated = 0
|
||||
for session in await get_all_sessions():
|
||||
sid = session["session_id"]
|
||||
persona_id = session.get("persona_id") or "default"
|
||||
history = await get_history(sid)
|
||||
static = await get_system_prompt(persona_id, history, "")
|
||||
if await upsert_static_system_message(sid, static, history):
|
||||
updated += 1
|
||||
if updated:
|
||||
logger.info("Migrated %s session system message(s) to static persona prompt", updated)
|
||||
return updated
|
||||
Reference in New Issue
Block a user