Fixed SD RPG

This commit is contained in:
2026-06-04 08:05:06 +03:00
parent d4cd8f02f4
commit 6189a5fb74
62 changed files with 6969 additions and 552 deletions
+17 -11
View File
@@ -45,6 +45,7 @@ def parse_card_v2(data: dict, card_id: str | None = None) -> dict:
"first_mes": inner.get("first_mes", ""),
"mes_example": inner.get("mes_example", ""),
"appearance_tags": _extract_appearance(inner),
"appearance_prose": "",
"lorebook_json": json.dumps(entries, ensure_ascii=False),
"alternate_greetings": alternates,
"alternate_greetings_json": json.dumps(alternates, ensure_ascii=False),
@@ -120,6 +121,8 @@ def parse_png_card(file_bytes: bytes) -> dict | None:
def build_system_prompt(card: dict) -> str:
from services.chat_prompt import ROLEPLAY_GUARDRAILS
parts = [
f"You are {card['name']}. Stay in character.",
f"Description: {card['description']}",
@@ -129,6 +132,7 @@ def build_system_prompt(card: dict) -> str:
if card.get("mes_example"):
parts.append(f"Example dialogue:\n{card['mes_example']}")
parts.append("Reply only as the character. Do not add image tags.")
parts.append(ROLEPLAY_GUARDRAILS)
return "\n\n".join(p for p in parts if p.split(": ", 1)[-1].strip())
@@ -141,13 +145,13 @@ async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
"""INSERT OR REPLACE INTO characters
(card_id, name, description, personality, scenario, first_mes,
mes_example, raw_json, lora_name, lora_weight, appearance_tags, lorebook_json,
avatar_path, alternate_greetings_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
"""INSERT INTO characters
(card_id, name, description, personality, scenario, first_mes, mes_example,
raw_json, lora_name, lora_weight, appearance_tags, appearance_prose, lorebook_json, avatar_path,
alternate_greetings_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
card_id,
card["card_id"],
card["name"],
card["description"],
card["personality"],
@@ -157,10 +161,11 @@ async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0
card["raw_json"],
lora_name,
lora_weight,
card.get("appearance_tags", ""),
card["appearance_tags"],
card.get("appearance_prose", ""),
card["lorebook_json"],
card.get("avatar_path", ""),
alt_json,
card.get("alternate_greetings_json", "[]"),
),
)
await db.commit()
@@ -199,8 +204,8 @@ async def delete_character(card_id: str) -> bool:
async def update_appearance_tags(card_id: str, appearance_tags: str):
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
"UPDATE characters SET appearance_tags = ? WHERE card_id = ?",
(appearance_tags, card_id),
"UPDATE characters SET appearance_tags = ?, appearance_prose = ? WHERE card_id = ?",
(appearance_tags, "", card_id),
)
await db.commit()
@@ -228,7 +233,7 @@ async def preview_card_file(content: bytes, filename: str) -> dict:
async def update_character(card_id: str, fields: dict) -> bool:
allowed = {"name", "description", "personality", "scenario", "first_mes",
"mes_example", "appearance_tags", "lora_name", "lora_weight", "avatar_path",
"mes_example", "appearance_tags", "appearance_prose", "lora_name", "lora_weight", "avatar_path",
"alternate_greetings_json"}
updates = {k: v for k, v in fields.items() if k in allowed}
if not updates:
@@ -295,6 +300,7 @@ async def import_card_file(
"lora_name": lora_name,
"lora_weight": lora_weight,
"appearance_tags": saved.get("appearance_tags", ""),
"appearance_prose": saved.get("appearance_prose", ""),
"avatar_path": saved.get("avatar_path", ""),
"personality": saved.get("personality", ""),
"scenario": saved.get("scenario", ""),
+30
View File
@@ -0,0 +1,30 @@
from services.personas import get_persona
from services.lorebook import get_lorebook_context
from services.character_card import get_character
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
from services.rp_sanitize import ROLEPLAY_GUARDRAILS # noqa: E402 — re-export for imports
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
"""Build the static character prompt only (no RPG runtime blocks).

Returns DEFAULT_PROMPT when get_persona() finds nothing. Otherwise starts
from the persona's stored prompt and appends whatever text
get_lorebook_context() returns for the persona's lorebook and, for ids
starting with "card_", for the character card's lorebook. The roleplay
guardrails are appended for every persona id except "default".

Args:
    persona_id: Persona identifier; a "card_" prefix is stripped before
        calling get_character().
    history: Message dicts; each must have a "role" key.
    user_message: The message being answered; added to the lorebook
        context as a final user entry.
"""
persona = await get_persona(persona_id)
if not persona:
return DEFAULT_PROMPT
prompt = persona["prompt"]
# Lorebook lookup context: the last five user/assistant messages plus the new user message.
recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
context = recent + [{"role": "user", "content": user_message}]
if persona.get("lorebook_json"):
lore = get_lorebook_context(persona.get("lorebook_json", "[]"), context)
if lore:
prompt += "\n\n" + lore
# persona_id[5:] drops the "card_" prefix to get the character card id.
if persona_id.startswith("card_"):
card = await get_character(persona_id[5:])
if card:
lore = get_lorebook_context(card.get("lorebook_json", "[]"), context)
if lore:
prompt += "\n\n" + lore
if persona_id != "default":
prompt += "\n\n" + ROLEPLAY_GUARDRAILS
return prompt
+40
View File
@@ -0,0 +1,40 @@
"""Parse ComfyUI /object_info into usable model lists."""
from __future__ import annotations
# Node types whose combo inputs we expose in the debug UI.
# Maps the key used in parse_model_lists() output to
# (ComfyUI node type, name of the combo input holding the model list).
_MODEL_NODES: dict[str, tuple[str, str]] = {
"checkpoints": ("CheckpointLoaderSimple", "ckpt_name"),
"unets": ("UNETLoader", "unet_name"),
"clips": ("CLIPLoader", "clip_name"),
"vaes": ("VAELoader", "vae_name"),
"loras": ("LoraLoader", "lora_name"),
}
def _combo_options(node_def: dict, input_name: str) -> list[str]:
if not isinstance(node_def, dict):
return []
required = (node_def.get("input") or {}).get("required") or {}
optional = (node_def.get("input") or {}).get("optional") or {}
spec = required.get(input_name) or optional.get(input_name)
if not spec or not isinstance(spec, (list, tuple)):
return []
first = spec[0]
if isinstance(first, list):
return [str(x) for x in first]
return []
def parse_model_lists(object_info: dict) -> dict[str, list[str]]:
    """Collect the model choices of every node listed in ``_MODEL_NODES``.

    Keys whose node is missing or whose combo input yields no options are
    left out of the result.
    """
    candidates = (
        (key, _combo_options(object_info.get(node_type) or {}, input_name))
        for key, (node_type, input_name) in _MODEL_NODES.items()
    )
    return {key: options for key, options in candidates if options}
def list_node_types(object_info: dict) -> list[str]:
    """Return the sorted keys of ``object_info`` whose value is a dict."""
    names = [name for name, node_def in object_info.items() if isinstance(node_def, dict)]
    names.sort()
    return names
+30
View File
@@ -0,0 +1,30 @@
import os
CHAT_CONTEXT_MAX = int(os.getenv("CHAT_CONTEXT_MAX", "128000"))
def estimate_tokens(text: str) -> int:
    """Estimate a token count as one token per four characters (rounded down)."""
    if not text:
        return 0
    return max(0, len(text) // 4)
def compute_payload_usage(history: list, llm_system: str) -> dict:
    """Estimate context fill for the payload messages_for_llm would send.

    Counts the characters of the system text plus the content of every
    user/assistant message, converts them to tokens at four characters per
    token and relates that to ``CHAT_CONTEXT_MAX``.

    Returns:
        Dict with ``chars``, ``tokens_est``, ``max_tokens_est`` and
        ``percent`` (rounded to one decimal; 0.0 when the budget is 0).
    """
    dialog_chars = sum(
        len(msg.get("content") or "")
        for msg in history
        if msg.get("role") in ("user", "assistant")
    )
    chars = len(llm_system or "") + dialog_chars
    tokens_est = chars // 4 if chars else 0
    budget = CHAT_CONTEXT_MAX
    if budget:
        percent = round(100.0 * tokens_est / budget, 1)
    else:
        percent = 0.0
    return {
        "chars": chars,
        "tokens_est": tokens_est,
        "max_tokens_est": budget,
        "percent": percent,
    }
def context_warning_line(percent: float) -> str:
    """Return a one-line budget warning once usage exceeds 85%, else an empty string."""
    if percent <= 85:
        return ""
    whole = int(percent)
    return f"\n[Context: ~{whole}% of budget — keep replies focused]"
+119 -6
View File
@@ -13,6 +13,8 @@ OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
CHAT_MODEL = os.getenv("CHAT_MODEL", "mistralai/mistral-nemo")
SYSTEM_MODEL = os.getenv("SYSTEM_MODEL", "google/gemini-2.5-flash")
# Softer model when primary returns content_filter / empty / API errors (default: CHAT_MODEL).
LLM_FALLBACK_MODEL = (os.getenv("LLM_FALLBACK_MODEL") or "").strip() or CHAT_MODEL
HEADERS = {
"Authorization": f"Bearer {OPENROUTER_KEY}",
@@ -21,26 +23,128 @@ HEADERS = {
}
class LLMError(Exception):
"""OpenRouter returned an error or an unexpected response shape."""
def _parse_completion_body(data: dict) -> str:
if not isinstance(data, dict):
raise LLMError(f"Invalid API response: expected object, got {type(data).__name__}")
if data.get("error"):
err = data["error"]
if isinstance(err, dict):
msg = err.get("message") or str(err)
code = err.get("code")
else:
msg = str(err)
code = None
suffix = f" (code={code})" if code is not None else ""
raise LLMError(f"OpenRouter error{suffix}: {msg}")
choices = data.get("choices")
if not choices:
preview = str(data)[:400]
raise LLMError(f"OpenRouter response has no 'choices'. Body preview: {preview}")
first = choices[0] if isinstance(choices[0], dict) else {}
message = first.get("message") or {}
if not isinstance(message, dict):
raise LLMError("OpenRouter choice has no message object")
finish = first.get("finish_reason") or ""
native_finish = first.get("native_finish_reason") or ""
blocked_reasons = {"content_filter", "safety", "moderation"}
if finish in blocked_reasons or str(native_finish).upper() in (
"PROHIBITED_CONTENT",
"SAFETY",
"BLOCKED",
):
raise LLMError(
f"Content blocked by provider (finish_reason={finish}, native={native_finish})"
)
content = message.get("content")
if content is not None and str(content).strip():
return str(content)
refusal = message.get("refusal")
if refusal:
raise LLMError(f"Model refused the request: {refusal}")
if finish and finish not in ("stop", "length", "tool_calls", "function_call"):
raise LLMError(
f"OpenRouter finished without content (finish_reason={finish}, native={native_finish})"
)
raise LLMError("OpenRouter returned empty message content")
def _clean(messages: list) -> list:
"""Filter out messages with empty content."""
return [m for m in messages if (m.get("content") or "").strip()]
async def _post(model: str, messages: list, extra: dict | None = None) -> str:
async def _post_once(model: str, messages: list, extra: dict | None = None) -> str:
if not OPENROUTER_KEY:
raise LLMError("ROUTER_KEY is not set in environment")
payload = {"model": model, "messages": _clean(messages), **(extra or {})}
async with httpx.AsyncClient(timeout=90) as client:
r = await client.post(OPENROUTER_URL, headers=HEADERS, json=payload)
r.raise_for_status()
return r.json()["choices"][0]["message"]["content"]
try:
data = r.json()
except Exception as e:
raise LLMError(f"Non-JSON response (HTTP {r.status_code}): {r.text[:300]}") from e
if r.status_code >= 400:
try:
_parse_completion_body(data)
except LLMError:
raise
raise LLMError(f"HTTP {r.status_code}: {data}")
try:
return _parse_completion_body(data)
except LLMError:
logger.warning(
"OpenRouter completion failed model=%s status=%s body=%.500s",
model,
r.status_code,
data,
)
raise
async def _post(model: str, messages: list, extra: dict | None = None) -> str:
    """Request a completion, retrying once with ``LLM_FALLBACK_MODEL`` on failure.

    The retry is skipped (and the original error re-raised) when no fallback
    is configured or the fallback is the model that just failed. If the
    retry fails too, a combined ``LLMError`` is raised, chained from the
    fallback error.
    """
    try:
        return await _post_once(model, messages, extra)
    except LLMError as first_failure:
        backup_model = LLM_FALLBACK_MODEL
        if not backup_model or backup_model == model:
            raise
        logger.info(
            "LLM fallback: %s failed (%s) → retrying with %s",
            model,
            first_failure,
            backup_model,
        )
        try:
            return await _post_once(backup_model, messages, extra)
        except LLMError as second_failure:
            combined = f"{first_failure} (fallback {backup_model} also failed: {second_failure})"
            raise LLMError(combined) from second_failure
async def send_message(messages: list) -> str:
"""System model — narrator, facts, SD prompt."""
"""SYSTEM_MODEL with automatic fallback to LLM_FALLBACK_MODEL."""
return await _post(SYSTEM_MODEL, messages)
async def send_message_with_model(messages: list, model: str) -> str:
"""Explicit model — plot arc, narrator override."""
"""Named model (RPG_*, SD_*) with automatic fallback to LLM_FALLBACK_MODEL."""
return await _post(model, messages)
@@ -73,10 +177,19 @@ async def stream_message(messages: list):
return
try:
chunk = json.loads(data)
content = chunk["choices"][0]["delta"].get("content", "")
if chunk.get("error"):
err = chunk["error"]
msg = err.get("message", err) if isinstance(err, dict) else err
raise LLMError(f"OpenRouter stream error: {msg}")
choices = chunk.get("choices") or []
if not choices:
continue
content = (choices[0].get("delta") or {}).get("content", "")
if content:
chunk_count += 1
yield content
except LLMError:
raise
except Exception:
continue
except Exception as e:
+308 -26
View File
@@ -1,8 +1,11 @@
import json
import aiosqlite
from database.db import DB_PATH
async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict:
async def get_or_create_session(session_id: str, persona_id: str | None = None) -> dict:
"""Existing sessions keep their persona_id; persona_id applies only on INSERT."""
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
@@ -13,9 +16,10 @@ async def get_or_create_session(session_id: str, persona_id: str = "default") ->
if row:
return dict(row)
pid = (persona_id or "default").strip() or "default"
await db.execute(
"INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)",
(session_id, persona_id),
(session_id, pid),
)
await db.commit()
@@ -71,24 +75,104 @@ async def update_session_persona(session_id: str, persona_id: str):
(persona_id, session_id),
)
# If persona changed, reset RPG state bound to the persona/arc.
if prev is not None and prev != persona_id:
await db.execute(
"""UPDATE sessions
SET facts_json = '[]',
global_plot = '',
status_quo = '',
plot_arc_json = '{}'
WHERE session_id = ?""",
(session_id,),
)
await db.execute(
"DELETE FROM action_resolutions WHERE session_id = ?",
(session_id,),
)
await _reset_persona_bound_state(db, session_id)
await db.commit()
async def _reset_persona_bound_state(db: aiosqlite.Connection, session_id: str) -> None:
    """Reset the session columns and rows that belong to the previous persona.

    Clears facts, plot, status quo, plot arc, outfit, scene and affinity,
    restores the default narrative stats, and deletes the session's action
    resolutions and quests. Does not commit; the caller owns the transaction.
    """
    from services.rpg_state import DEFAULT_NARRATIVE_STATS

    reset_sql = (
        "UPDATE sessions\n"
        "SET facts_json = '[]',\n"
        "global_plot = '',\n"
        "status_quo = '',\n"
        "plot_arc_json = '{}',\n"
        "outfit_json = '[]',\n"
        "affinity = 0,\n"
        "scene_json = '{}',\n"
        "narrative_stats_json = ?\n"
        "WHERE session_id = ?"
    )
    default_stats = json.dumps(DEFAULT_NARRATIVE_STATS, ensure_ascii=False)
    await db.execute(reset_sql, (default_stats, session_id))
    for table_sql in (
        "DELETE FROM action_resolutions WHERE session_id = ?",
        "DELETE FROM rpg_quests WHERE session_id = ?",
    ):
        await db.execute(table_sql, (session_id,))
async def upsert_static_system_message(
    session_id: str, static_prompt: str, history: list | None = None
) -> bool:
    """Store only the static persona prompt as the session's system message.

    Args:
        session_id: Session whose system message is written.
        static_prompt: Prompt text to store.
        history: Already-loaded history; fetched with ``get_history`` when None.

    Returns:
        True if a row was inserted or updated, False if the stored prompt
        already equals ``static_prompt``.
    """
    hist = history if history is not None else await get_history(session_id)
    insert_sql = (
        "INSERT INTO messages (session_id, role, content, image_prompt, image_path)\n"
        "VALUES (?, 'system', ?, NULL, NULL)"
    )
    async with aiosqlite.connect(DB_PATH) as db:
        if hist and hist[0]["role"] == "system":
            if hist[0]["content"] == static_prompt:
                return False
            await db.execute(
                "UPDATE messages SET content = ?\n"
                "WHERE session_id = ? AND role = 'system'\n"
                "AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)",
                (static_prompt, session_id, session_id),
            )
            await db.commit()
            return True

        if hist:
            # History exists but does not start with a system row. A system row
            # inserted by an earlier call has a higher id than the dialog, so it
            # never becomes hist[0]; reuse it instead of inserting a duplicate
            # on every call.
            async with db.execute(
                "SELECT id, content FROM messages "
                "WHERE session_id = ? AND role = 'system' ORDER BY id LIMIT 1",
                (session_id,),
            ) as cur:
                row = await cur.fetchone()
            if row is not None:
                if row[1] == static_prompt:
                    return False
                await db.execute(
                    "UPDATE messages SET content = ? WHERE id = ?",
                    (static_prompt, row[0]),
                )
                await db.commit()
                return True

        await db.execute(insert_sql, (session_id, static_prompt))
        if not hist:
            # Only the very first message of a session bumps updated_at.
            await db.execute(
                "UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
                (session_id,),
            )
        await db.commit()
        return True
async def delete_dialog_messages(session_id: str) -> None:
    """Delete all 'user' and 'assistant' messages of the session and commit."""
    sql = "DELETE FROM messages WHERE session_id = ? AND role IN ('user', 'assistant')"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(sql, (session_id,))
        await conn.commit()
async def rebind_session_persona(
    session_id: str,
    persona_id: str,
    *,
    clear_history: bool = False,
    static_prompt: str,
    first_mes: str | None = None,
) -> None:
    """Switch a session to another persona and rewrite its system message.

    With ``clear_history`` the user/assistant messages are deleted first and,
    if ``first_mes`` is non-blank, it is added afterwards as the new opening
    assistant message.

    Raises:
        ValueError: If the session does not exist.
    """
    if not await get_session(session_id):
        raise ValueError("Session not found")

    await update_session_persona(session_id, persona_id)
    if clear_history:
        await delete_dialog_messages(session_id)

    current_history = await get_history(session_id)
    await upsert_static_system_message(session_id, static_prompt, current_history)

    if clear_history and first_mes and first_mes.strip():
        greeting = first_mes.strip()
        await add_message(session_id, "assistant", greeting)
async def update_session_rpg(session_id: str, rpg_enabled: bool):
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
@@ -174,25 +258,116 @@ async def delete_session(session_id: str):
await db.commit()
async def get_history(session_id: str) -> list:
async def get_action_resolutions_map(session_id: str) -> dict[int, dict]:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"""SELECT id, role, content, image_prompt, image_path
"""SELECT message_id, intent_text, roll, outcome, resolution_text
FROM action_resolutions
WHERE session_id = ? AND message_id IS NOT NULL
ORDER BY id""",
(session_id,),
) as cur:
rows = await cur.fetchall()
out: dict[int, dict] = {}
for r in rows:
mid = r["message_id"]
if mid is not None:
out[int(mid)] = {
"intent_text": r["intent_text"],
"roll": r["roll"],
"outcome": r["outcome"],
"resolution_text": r["resolution_text"],
}
return out
def narrator_message_content(narrator: dict) -> str:
    """Serialise a narrator result to the JSON text stored as message content.

    Only ``roll``, ``outcome``, ``text`` (default ``""``) and
    ``original_intent`` are kept; non-ASCII characters are written as-is.
    """
    payload = {key: narrator.get(key) for key in ("roll", "outcome")}
    payload["text"] = narrator.get("text", "")
    payload["original_intent"] = narrator.get("original_intent")
    return json.dumps(payload, ensure_ascii=False)
def parse_narrator_message(content: str) -> dict | None:
    """Decode a stored narrator message.

    Args:
        content: JSON text as produced by ``narrator_message_content``.

    Returns:
        The decoded dict, or None when the text is not valid JSON, is not a
        JSON object, or its ``text`` entry is not a non-blank string.
    """
    try:
        data = json.loads(content or "{}")
    except (json.JSONDecodeError, TypeError):
        return None
    if not isinstance(data, dict):
        return None
    text = data.get("text")
    # A number or list under "text" used to reach .strip() and raise
    # AttributeError; treat it as an invalid narrator message instead.
    if not isinstance(text, str) or not text.strip():
        return None
    return data
async def seed_quests_from_arc(session_id: str, arc: dict) -> int:
    """Create active quests for arc beats that are not already in rpg_quests.

    A beat's quest title is its ``title`` or, failing that, its ``injection``,
    stripped and cut to 120 characters. Beats without a usable title, and
    entries that are not dicts, are skipped.

    Args:
        session_id: Session that receives the quests.
        arc: Plot arc; its ``beats`` entry is read.

    Returns:
        Number of quests created.
    """
    if not arc or not isinstance(arc, dict):
        return 0
    existing = {q["title"] for q in await get_quests(session_id)}
    added = 0
    # "beats" may be present but null in generated arcs; treat that as empty.
    for beat in arc.get("beats") or []:
        if not isinstance(beat, dict):
            continue
        # `or ""` also covers an explicit null "injection", which the default
        # argument of dict.get() does not.
        raw_title = beat.get("title") or beat.get("injection") or ""
        title = str(raw_title).strip()[:120]
        if title and title not in existing:
            await upsert_quest(session_id, title, "active")
            existing.add(title)
            added += 1
    return added
async def get_history(session_id: str) -> list:
resolutions = await get_action_resolutions_map(session_id)
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"""SELECT id, role, content, image_prompt, image_path,
image_prompt_alt, image_path_alt, choices_json
FROM messages WHERE session_id = ? ORDER BY id""",
(session_id,),
) as cursor:
rows = await cursor.fetchall()
return [
{
result = []
for idx, r in enumerate(rows):
item = {
"id": r["id"],
"role": r["role"],
"content": r["content"],
"image_prompt": r["image_prompt"],
"image_path": r["image_path"],
"image_prompt_alt": r["image_prompt_alt"],
"image_path_alt": r["image_path_alt"],
"choices_json": r["choices_json"],
}
for r in rows
]
if r["role"] == "user" and r["id"] in resolutions:
item["action_resolution"] = resolutions[r["id"]]
result.append(item)
if r["role"] == "user" and r["id"] in resolutions:
nxt = rows[idx + 1] if idx + 1 < len(rows) else None
if not nxt or nxt["role"] != "narrator":
res = resolutions[r["id"]]
result.append(
{
"id": -int(r["id"]),
"role": "narrator",
"content": narrator_message_content(
{
"roll": res.get("roll"),
"outcome": res.get("outcome"),
"text": res.get("resolution_text", ""),
}
),
"image_prompt": None,
"image_path": None,
"image_prompt_alt": None,
"image_path_alt": None,
"choices_json": None,
}
)
return result
async def get_message(message_id: int) -> dict | None:
@@ -230,6 +405,38 @@ async def delete_message(message_id: int):
await db.commit()
async def delete_message_and_following(session_id: str, message_id: int) -> bool:
    """Delete the given message and every later message of the session.

    Also sets the session's ``updated_at`` to the current timestamp.
    Always returns True.
    """
    delete_sql = "DELETE FROM messages WHERE session_id = ? AND id >= ?"
    touch_sql = "UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(delete_sql, (session_id, message_id))
        await conn.execute(touch_sql, (session_id,))
        await conn.commit()
        return True
async def update_message_choices(message_id: int, choices_json: str | None):
    """Store ``choices_json`` (or NULL for None) on one message and commit."""
    sql = "UPDATE messages SET choices_json = ? WHERE id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(sql, (choices_json, message_id))
        await conn.commit()
async def clear_choices_for_session(session_id: str):
    """Set ``choices_json`` to NULL on every message of the session and commit."""
    sql = "UPDATE messages SET choices_json = NULL WHERE session_id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(sql, (session_id,))
        await conn.commit()
async def get_last_message_preview(session_id: str, max_len: int = 80) -> str:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
@@ -261,8 +468,9 @@ async def fork_session(source_session_id: str, until_message_id: int) -> str | N
await db.execute(
"""INSERT INTO sessions
(session_id, persona_id, title, rpg_enabled, facts_json, global_plot,
status_quo, plot_arc_json, genre, rpg_settings_json, affinity)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
status_quo, plot_arc_json, genre, rpg_settings_json, affinity,
outfit_json, scene_json, narrative_stats_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
new_id,
source["persona_id"],
@@ -275,6 +483,9 @@ async def fork_session(source_session_id: str, until_message_id: int) -> str | N
source.get("genre", "adventure"),
source.get("rpg_settings_json", "{}"),
source.get("affinity", 0),
source.get("outfit_json", "[]"),
source.get("scene_json", "{}"),
source.get("narrative_stats_json", '{"lust":0,"stamina":10,"tension":0}'),
),
)
async with db.execute(
@@ -309,18 +520,20 @@ async def add_message(
content: str,
image_prompt: str | None = None,
image_path: str | None = None,
):
) -> int:
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
cur = await db.execute(
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
VALUES (?, ?, ?, ?, ?)""",
(session_id, role, content, image_prompt, image_path),
)
msg_id = cur.lastrowid
await db.execute(
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
(session_id,),
)
await db.commit()
return msg_id
async def update_message_image(message_id: int, image_path: str):
@@ -332,6 +545,33 @@ async def update_message_image(message_id: int, image_path: str):
await db.commit()
async def update_message_prompt(message_id: int, image_prompt: str):
    """Overwrite the ``image_prompt`` of one message and commit."""
    sql = "UPDATE messages SET image_prompt = ? WHERE id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(sql, (image_prompt, message_id))
        await conn.commit()
async def update_message_prompt_alt(message_id: int, image_prompt_alt: str):
    """Overwrite the ``image_prompt_alt`` of one message and commit."""
    sql = "UPDATE messages SET image_prompt_alt = ? WHERE id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(sql, (image_prompt_alt, message_id))
        await conn.commit()
async def update_message_image_alt(message_id: int, image_path_alt: str):
    """Overwrite the ``image_path_alt`` of one message and commit."""
    sql = "UPDATE messages SET image_path_alt = ? WHERE id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(sql, (image_path_alt, message_id))
        await conn.commit()
async def get_last_assistant_message_id(session_id: str) -> int | None:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
@@ -362,6 +602,18 @@ async def update_session_affinity(session_id: str, delta: int):
await db.commit()
async def set_session_affinity(session_id: str, value: int):
    """Debug / admin helper: set the session's affinity to an absolute value.

    The value is converted with ``int()`` and limited to the range -30..30
    before it is written; the value actually stored is returned.
    """
    clamped = min(30, max(-30, int(value)))
    sql = "UPDATE sessions SET affinity = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(sql, (clamped, session_id))
        await conn.commit()
    return clamped
async def update_session_genre(session_id: str, genre: str):
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
@@ -389,6 +641,24 @@ async def update_session_outfit(session_id: str, outfit_json: str):
await db.commit()
async def update_session_scene(session_id: str, scene_json: str):
    """Store ``scene_json`` on the session, bump ``updated_at`` and commit."""
    sql = "UPDATE sessions SET scene_json = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(sql, (scene_json, session_id))
        await conn.commit()
async def update_session_narrative_stats(session_id: str, stats_json: str):
    """Store ``stats_json`` as the session's narrative stats, bump ``updated_at`` and commit."""
    sql = "UPDATE sessions SET narrative_stats_json = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(sql, (stats_json, session_id))
        await conn.commit()
async def upsert_quest(session_id: str, title: str, status: str = "active"):
async with aiosqlite.connect(DB_PATH) as db:
async with db.execute(
@@ -429,6 +699,18 @@ async def update_quest_status(session_id: str, title: str, status: str):
await db.commit()
async def update_quest_by_id(quest_id: int, session_id: str, status: str) -> bool:
    """Set the status of one quest, matched by both quest id and session id.

    Returns:
        False when ``status`` is not one of "active", "done", "failed" or
        when no row matched; True when a row was updated.
    """
    allowed_statuses = ("active", "done", "failed")
    if status not in allowed_statuses:
        return False
    sql = "UPDATE rpg_quests SET status = ? WHERE id = ? AND session_id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        cursor = await conn.execute(sql, (status, quest_id, session_id))
        await conn.commit()
        return cursor.rowcount > 0
async def get_message_count(session_id: str) -> int:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
+173
View File
@@ -0,0 +1,173 @@
import json
import logging
from services.memory import (
get_history,
get_session,
get_last_assistant_message_id,
update_session_plot_arc,
update_message_choices,
seed_quests_from_arc,
get_quests,
)
from services.rpg_state import apply_narrator_post
from services.personas import get_persona
from services.rpg_facts import facts_to_prompt
from services.rpg_plot import generate_plot_arc, choices_from_narrator
from services.rpg_context import format_narrator_context
from services.rpg_narrator import narrator_post
from services.sd_prompt import generate_sd_prompt
from services.sd_images import run_sd_for_message
logger = logging.getLogger(__name__)
# Feature switches applied when a session stores no value for a key.
DEFAULT_RPG_SETTINGS = {
    "dice": True,
    "narrator": True,
    "quests": True,
    "affinity": True,
    "choices": True,
    "stats": False,
}


def get_rpg_settings(session: dict) -> dict:
    """Return the session's RPG settings merged over the defaults.

    Args:
        session: Session row; its ``rpg_settings_json`` entry is read.

    Returns:
        A new dict on every call. Falls back to a copy of
        ``DEFAULT_RPG_SETTINGS`` when the stored value is missing, is not
        valid JSON, or does not decode to an object.
    """
    try:
        stored = json.loads(session.get("rpg_settings_json") or "{}")
    except (AttributeError, TypeError, ValueError):
        # AttributeError: session is not a mapping; ValueError covers JSONDecodeError.
        stored = None
    if not isinstance(stored, dict):
        # Return a copy so callers that mutate the result cannot change the defaults.
        return dict(DEFAULT_RPG_SETTINGS)
    return {**DEFAULT_RPG_SETTINGS, **stored}
async def resolve_greeting(session_id: str, persona: dict) -> str:
    """Return the most recent non-blank assistant message, stripped.

    Falls back to the persona's stripped ``first_mes`` (possibly empty) when
    the history holds no such message.
    """
    history = await get_history(session_id)
    latest = next(
        (
            entry["content"].strip()
            for entry in reversed(history)
            if entry.get("role") == "assistant" and (entry.get("content") or "").strip()
        ),
        None,
    )
    if latest is not None:
        return latest
    return (persona.get("first_mes") or "").strip()
async def ensure_plot_arc_and_quests(
    session_id: str,
    persona: dict,
    greeting: str,
    genre: str,
    *,
    seed_quests: bool = True,
) -> dict:
    """Return the session's plot arc, generating and storing one if absent.

    A stored, non-empty arc is returned unchanged. Otherwise a new arc is
    generated from the persona, the greeting, the session facts and the
    genre; it is saved to the session and, when ``seed_quests`` is set, its
    beats are seeded as quests. Returns ``{}`` when generation yields nothing.
    """
    session = await get_session(session_id) or {}

    stored = session.get("plot_arc_json") or "{}"
    arc = {}
    if isinstance(stored, str):
        try:
            arc = json.loads(stored)
        except Exception:
            arc = {}
    if arc:
        return arc

    facts_block = facts_to_prompt(session.get("facts_json", "[]"))
    arc = await generate_plot_arc(
        persona.get("name", "Character"),
        persona.get("description", ""),
        persona.get("scenario", ""),
        greeting,
        facts_block=facts_block,
        genre=genre,
    )
    if not arc:
        return {}

    serialized = json.dumps(arc, ensure_ascii=False)
    await update_session_plot_arc(session_id, serialized)
    if seed_quests:
        await seed_quests_from_arc(session_id, arc)
    return arc
async def process_opening(session_id: str, persona_id: str, *, rpg: bool) -> dict:
"""Run the opening pipeline for a session's first assistant message.

Takes the last assistant message in the history as the opening text,
optionally runs the RPG setup (plot arc, opening narrator pass), generates
an SD prompt and image for that message, stores the narrator's choices on
it, and returns everything the caller needs to render the opening.

Args:
    session_id: Session to process.
    persona_id: Persona used for the narrator name fallback and SD prompt.
    rpg: Whether to run the RPG-specific steps.

Returns:
    Dict with plot_arc, quests, outfit_json, status_quo, choices, the
    image prompt/path/error fields (primary and alt) and affinity.

Raises:
    ValueError: If the session is missing, has no assistant message, or
        the last assistant message is blank.
"""
session = await get_session(session_id)
if not session:
raise ValueError("Session not found")
history = await get_history(session_id)
assistant_msgs = [m for m in history if m.get("role") == "assistant"]
if not assistant_msgs:
raise ValueError("No assistant message (first_mes) found")
# The most recent assistant message is treated as the opening text.
first_mes_text = assistant_msgs[-1].get("content", "").strip()
if not first_mes_text:
raise ValueError("Empty first_mes")
# NOTE(review): presumably this id belongs to the same message as
# first_mes_text (both are "last assistant message") - confirm.
msg_id = await get_last_assistant_message_id(session_id)
persona = await get_persona(persona_id) or {}
rpg_settings = get_rpg_settings(session)
# Defaults used when the RPG branch is not taken.
arc: dict = {}
choices: list = []
status_quo = session.get("status_quo") or ""
outfit_json = session.get("outfit_json") or "[]"
# RPG setup: plot arc (and quests), then the opening narrator pass.
if rpg:
genre = session.get("genre") or "adventure"
arc = await ensure_plot_arc_and_quests(
session_id,
persona,
first_mes_text,
genre,
seed_quests=rpg_settings.get("quests", True),
)
# Re-read the session after the arc step; keep the old row if the read fails.
session = await get_session(session_id) or session
ctx_txt = f"assistant: {first_mes_text}"
arc_json = json.dumps(arc, ensure_ascii=False) if arc else ""
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
quests_pre = await get_quests(session_id)
narr_ctx = format_narrator_context(arc, quests_pre, session.get("status_quo") or "")
post = await narrator_post(
persona.get("name", persona_id),
ctx_txt,
arc_json,
facts_block,
is_opening=True,
extra_context=narr_ctx,
)
# Choices are only surfaced when the "choices" setting is on.
if rpg_settings.get("choices", True):
choices = choices_from_narrator(post.get("choices") or [])
await apply_narrator_post(session_id, post, rpg_settings, session)
# Pick up status_quo / outfit as stored after the narrator post was applied.
session = await get_session(session_id) or session
status_quo = session.get("status_quo") or status_quo
outfit_json = session.get("outfit_json") or outfit_json
quests = await get_quests(session_id)
messages = await get_history(session_id)
bundle = await generate_sd_prompt(
messages,
persona_id,
outfit_json=outfit_json,
scene_json=session.get("scene_json", "{}") if session else "{}",
)
# No bundle means no image work; the image fields in the result are then None.
sd_out = await run_sd_for_message(bundle, msg_id) if bundle else {}
updated = await get_session(session_id)
affinity = updated.get("affinity", 0) if updated else 0
# Persist the choices on the opening message.
if msg_id and choices:
await update_message_choices(
msg_id, json.dumps(choices, ensure_ascii=False)
)
return {
"plot_arc": arc,
"quests": quests,
"outfit_json": outfit_json,
"status_quo": status_quo,
"choices": choices,
"image_prompt": sd_out.get("image_prompt"),
"image_prompt_alt": sd_out.get("image_prompt_alt"),
"image_path": sd_out.get("image_path"),
"image_path_alt": sd_out.get("image_path_alt"),
"image_error": sd_out.get("image_error"),
"image_error_alt": sd_out.get("image_error_alt"),
"affinity": affinity,
}
+94
View File
@@ -0,0 +1,94 @@
"""Danbooru-style outfit tags with color enrichment for stable SD prompts."""
import json
import re
# Words that mark a tag as already coloured when they appear as one of its
# underscore-separated parts (see tag_has_color).
COLOR_TOKENS = frozenset({
"white", "black", "red", "blue", "green", "yellow", "pink", "purple",
"orange", "brown", "gray", "grey", "silver", "gold", "beige", "navy",
"blonde", "dark", "light", "cyan", "teal", "maroon", "crimson",
})
# garment substring -> default color when tag has no color word.
# Order matters: enrich_outfit_tag uses the first needle found in the tag, so
# more specific needles (e.g. "championship_belt") must precede the generic
# ones they contain (e.g. "belt").
_GARMENT_DEFAULT_COLOR: list[tuple[str, str]] = [
("championship_belt", "gold"),
("belt_collar", "gold"),
("belt", "brown"),
("sports_shorts", "black"),
("shorts", "black"),
("torn_tank_top", "white"),
("tank_top", "white"),
("crop_top", "white"),
("t_shirt", "white"),
("shirt", "white"),
("dress", "black"),
("skirt", "black"),
("jeans", "blue"),
("pants", "black"),
("hoodie", "gray"),
("jacket", "black"),
("ribbon", "red"),
("collar", "black"),
("boots", "black"),
("socks", "white"),
]
def _clean_tag(raw: str) -> str:
t = (raw or "").strip().lower()
t = re.sub(r"[^\w]+", "_", t)
t = re.sub(r"_+", "_", t).strip("_")
return t
def tag_has_color(tag: str) -> bool:
    """Tell whether any underscore-separated part of ``tag`` is a colour word."""
    words = tag.lower().split("_")
    return not COLOR_TOKENS.isdisjoint(words)
def enrich_outfit_tag(tag: str) -> str:
    """Prefix a default colour when a clothing tag carries no colour word.

    The tag is cleaned first. Empty tags, tags that already contain a colour
    and tags matching no known garment are returned cleaned but otherwise
    unchanged; the first matching garment needle decides the colour.
    """
    cleaned = _clean_tag(tag)
    if not cleaned or tag_has_color(cleaned):
        return cleaned
    default_color = next(
        (color for needle, color in _GARMENT_DEFAULT_COLOR if needle in cleaned),
        None,
    )
    if default_color is None:
        return cleaned
    return f"{default_color}_{cleaned}"
def normalize_outfit_list(raw: list | None) -> list[str]:
    """Turn a mixed outfit list into unique, cleaned, colour-enriched tags.

    Args:
        raw: List whose items are tag strings or dicts with a ``tag`` or
            ``label`` entry and an optional ``color`` entry. Anything that is
            not a list yields an empty result; items of other types are
            skipped.

    Returns:
        Tags in first-seen order without duplicates.
    """
    out: list[str] = []
    seen: set[str] = set()
    if not isinstance(raw, list):
        return out
    for item in raw:
        if isinstance(item, str):
            t = enrich_outfit_tag(item)
        elif isinstance(item, dict):
            # str() guards against non-string values (numbers, lists) that
            # would otherwise fail on .strip().
            base = _clean_tag(str(item.get("tag") or item.get("label") or ""))
            if not base:
                continue
            # Clean the colour the same way as the tag, so "Dark Blue" becomes
            # "dark_blue" instead of leaking a space into the tag.
            color = _clean_tag(str(item.get("color") or ""))
            if color and not tag_has_color(base):
                t = f"{color}_{base}"
            else:
                t = enrich_outfit_tag(base)
        else:
            continue
        if t and t not in seen:
            seen.add(t)
            out.append(t)
    return out
def outfit_list_to_json(tags: list[str]) -> str:
    """Normalise ``tags`` and serialise the result as a JSON array (non-ASCII kept as-is)."""
    normalized = normalize_outfit_list(tags)
    return json.dumps(normalized, ensure_ascii=False)
def parse_and_normalize_outfit_json(raw: str | None) -> str:
    """Parse stored outfit JSON and return it re-serialised in normalised form.

    Invalid JSON and non-list values count as an empty outfit; a JSON string
    is split on commas into individual tags.
    """
    try:
        parsed = json.loads(raw or "[]")
    except (json.JSONDecodeError, TypeError):
        parsed = []
    if isinstance(parsed, str):
        parsed = [piece.strip() for piece in parsed.split(",") if piece.strip()]
    if not isinstance(parsed, list):
        parsed = []
    return outfit_list_to_json(parsed)
+24 -4
View File
@@ -63,6 +63,7 @@ def _row_to_persona(row: dict) -> dict:
"lora_name": row["lora_name"] or "",
"lora_weight": row["lora_weight"] if row["lora_weight"] is not None else 0.8,
"appearance_tags": row["appearance_tags"] or "",
"appearance_prose": row.get("appearance_prose", "") or "",
"personality": row.get("personality", "") or "",
"scenario": row.get("scenario", "") or "",
"first_mes": row.get("first_mes", "") or "",
@@ -84,6 +85,9 @@ def build_persona_prompt(data: dict) -> str:
if ex:
parts.append(f"Example dialogue:\n{ex}")
parts.append("Stay in character. Reply as the character. Do not add image tags.")
from services.chat_prompt import ROLEPLAY_GUARDRAILS
parts.append(ROLEPLAY_GUARDRAILS)
return "\n\n".join(p for p in parts if p and p.split(": ", 1)[-1].strip())
@@ -117,6 +121,7 @@ async def create_persona(
lora_name: str = "",
lora_weight: float = 0.8,
appearance_tags: str = "",
appearance_prose: str = "",
personality: str = "",
scenario: str = "",
first_mes: str = "",
@@ -138,19 +143,19 @@ async def create_persona(
await db.execute(
"""INSERT INTO personas
(persona_id, name, emoji, description, prompt, custom,
sd_enabled, lora_name, lora_weight, appearance_tags,
sd_enabled, lora_name, lora_weight, appearance_tags, appearance_prose,
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
alternate_greetings_json)
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
persona_id, name, emoji, description, final_prompt,
1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags,
1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags, appearance_prose,
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
alternate_greetings_json,
),
)
await db.commit()
return {
return {
"name": name,
"emoji": emoji,
"description": description,
@@ -160,6 +165,7 @@ async def create_persona(
"lora_name": lora_name,
"lora_weight": lora_weight,
"appearance_tags": appearance_tags,
"appearance_prose": appearance_prose,
"personality": personality,
"scenario": scenario,
"first_mes": first_mes,
@@ -226,6 +232,7 @@ async def patch_persona(persona_id: str, fields: dict) -> bool:
"lora_name",
"lora_weight",
"appearance_tags",
"appearance_prose",
"personality",
"scenario",
"first_mes",
@@ -255,6 +262,19 @@ async def patch_persona(persona_id: str, fields: dict) -> bool:
merged = dict(existing)
merged.update(updates)
updates["prompt"] = build_persona_prompt(merged)
if "appearance_tags" in updates and "appearance_prose" not in updates:
tags = updates["appearance_tags"].strip()
if tags:
from services.llm import send_message
try:
prose = await send_message([
{"role": "system", "content": "Convert danbooru tags to natural English description. Output only the description, no markdown."},
{"role": "user", "content": f"Tags: {tags}"}
])
updates["appearance_prose"] = prose.strip()
except Exception:
pass
cols = ", ".join(f"{k} = ?" for k in updates)
cur2 = await db.execute(
+77
View File
@@ -0,0 +1,77 @@
"""Roleplay output guardrails and OOC stripping."""
import re
# In-character rules appended as the final section of character/persona
# system prompts. The text is sent to the LLM verbatim, including the
# Russian example phrases it must avoid.
ROLEPLAY_GUARDRAILS = (
"[In-character rules — breaking these ruins immersion]\n"
"- Reply ONLY as the character in the present moment: spoken lines, visible actions, "
"thoughts they would actually show.\n"
"- NEVER write: P.S., PS, postscripts, footnotes, section headers, "
'"Статус кво", "Status quo", "To be continued", scene summaries, '
"editorial closings, or foreshadowing asides "
'(e.g. "Когда вы выйдете...", "Ты уже знаешь правду...").\n'
"- Do NOT explain subtext to the reader or predict future scenes. No narrator voice.\n"
"- End inside the scene; do not wrap up with meta commentary.\n"
"- Sections marked MANDATORY (relationship, state) are binding — obey without citing them."
)
# Short reply-format reminder. It starts with a blank line, so it is
# presumably appended to an existing prompt — NOTE(review): confirm at call sites.
RP_OUTPUT_REMINDER = (
"\n\n--- Reply format (MANDATORY) ---\n"
"Next message = in-character only. "
"Forbidden in output: P.S., Статус кво, Status quo, author notes, summaries, footnotes.\n"
"---"
)
def status_quo_prompt_block(status_quo: str) -> str:
    """Wrap the current status quo in an INTERNAL-only prompt section.

    Returns ``""`` when ``status_quo`` is empty or only whitespace.
    """
    situation = (status_quo or "").strip()
    if situation:
        header = "\n\n--- Current situation (INTERNAL — player never sees this block) ---\n"
        footer = (
            "\nBackground truth for you only. "
            "Never echo this header, never open/close replies with 'Статус кво' or summaries. "
            "Show the situation through dialogue and action.\n---"
        )
        return f"{header}{situation}{footer}"
    return ""
# Matches the opening of a paragraph that is out-of-character commentary.
_OOC_PARA_START = re.compile(
    r"^(?:"
    # Markers that end in a letter (or an optional dot): they must not run
    # on into a longer word, so "Psychology" or "Status quotes" do not match.
    r"(?:Статус\s*кво|Status\s*quo|P\.?\s*S\.?|To be continued|Продолжение следует)(?!\w)"
    # Markers that end in punctuation are complete as written. A trailing
    # \b here would demand a word character right after ":" or "]" and so
    # miss the usual "OOC: text" form.
    r"|PS:|Примечание:|Author'?s?\s*note:|OOC:|\[OOC\]"
    r")",
    re.IGNORECASE,
)


def strip_ooc_from_reply(text: str) -> str:
    """Remove common out-of-character tails from a model reply.

    Two passes:
    1. Everything from a line that starts with a "P.S." marker to the end
       of the text is cut.
    2. The rest is split on blank lines, and any paragraph whose first
       non-empty line starts with an OOC marker (status quo, P.S., author's
       note, OOC, "to be continued", and Russian equivalents) is dropped.

    Returns the remaining paragraphs joined by blank lines, or ``""`` when
    nothing is left. Empty or whitespace-only input is returned unchanged
    (``None`` becomes ``""``).
    """
    if not text or not text.strip():
        return text or ""

    out = text.rstrip()
    # Drop trailing P.S. block (often last paragraph).
    out = re.sub(
        r"(?is)\n\s*P\.?\s*S\.?\s*[:.\-—].*$",
        "",
        out,
    ).rstrip()

    kept: list[str] = []
    for part in re.split(r"\n\s*\n", out):
        lines = [ln for ln in part.splitlines() if ln.strip()]
        if not lines:
            continue
        if _OOC_PARA_START.match(lines[0].strip()):
            continue
        kept.append(part)
    if not kept:
        return ""
    return "\n\n".join(kept).strip()
+53
View File
@@ -0,0 +1,53 @@
"""Shared context blocks for RPG narrator / plot LLM calls."""
from services.rpg_plot import count_active_quests
def format_narrator_context(
arc: dict | None,
quests: list | None,
status_quo: str = "",
) -> str:
parts: list[str] = []
arc = arc or {}
beats = arc.get("beats") or []
if not isinstance(beats, list):
beats = []
parts.append(f"Plot phase: {arc.get('phase', 'opening')}. Scripted beats left: {len(beats)}.")
if not beats:
parts.append(
"IMPORTANT: Scripted beats are EXHAUSTED (quests may already be done). "
"The story must CONTINUE — do not stall. "
"Always return 2-4 meaningful choices for the player's next actions. "
"You may add quest_updates with status 'active' for NEW optional threads. "
"Do NOT re-activate quests the player already completed unless they explicitly revisit that thread."
)
elif count_active_quests(quests) == 0:
pending = [
(b.get("title") or b.get("id") or "beat")
for b in beats[:3]
if isinstance(b, dict)
]
parts.append(
"IMPORTANT: No active quests but scripted beats remain — arc was likely desynced. "
"The engine will inject the next beat; prefer choices that fit pending beats: "
+ ", ".join(pending)
+ ". Do NOT treat the arc as finished."
)
hint = (arc.get("next_beat_hint") or "").strip()
if hint:
parts.append(f"Arc hint: {hint}")
if quests:
parts.append("Quest log:")
for q in quests:
parts.append(f" [{q.get('status', 'active')}] {q.get('title', '')}")
else:
parts.append("Quest log: (empty)")
sq = (status_quo or "").strip()
if sq:
parts.append(f"Status quo: {sq[:400]}")
return "\n".join(parts)
+340 -46
View File
@@ -1,76 +1,370 @@
import json
import logging
import os
import re
from services.llm import send_message_with_model, send_message
from services.llm import LLMError, send_message_with_model, send_message
logger = logging.getLogger(__name__)
# Model id used for fact extraction/compression; RPG_FACTS_MODEL overrides the default.
FACTS_MODEL = os.getenv("RPG_FACTS_MODEL", "").strip() or "deepseek/deepseek-chat-v3"
# Hard cap on the number of stored facts (enforced in the compression paths).
FACTS_STORE_LIMIT = int(os.getenv("FACTS_STORE_LIMIT", "100"))
# Default number of facts rendered into a prompt by facts_to_prompt.
FACTS_PROMPT_MAX = int(os.getenv("FACTS_PROMPT_MAX", "40"))
# When the merged list holds more facts than this, merge_facts_persist calls compress_facts.
FACTS_DEDUP_THRESHOLD = int(os.getenv("FACTS_DEDUP_THRESHOLD", "30"))
# Default fact count that compress_facts asks the LLM to stay within.
FACTS_COMPRESS_TARGET = int(os.getenv("FACTS_COMPRESS_TARGET", "22"))
FACTS_SYSTEM = """Extract NEW stable facts from the conversation.
Return ONLY valid JSON (no markdown), as an array of objects:
[{"text": "short durable fact", "rp_day": "when this became true in story time"}]
FACTS_SYSTEM = """Extract stable facts from the conversation.
Return ONLY valid JSON (no markdown), as an array of short strings.
Rules:
- Facts must be durable (names, relations, inventory, locations, world rules).
- Do not include ephemeral actions unless they change state.
- Avoid duplicates.
- Keep each fact <= 120 chars.
Example output:
["User name is Alex", "We are in a ruined castle", "NPC Mira distrusts the user"]"""
- Return at most 5 NEW facts per turn. If nothing new, return [].
- Do NOT repeat or rephrase facts already listed under "Already known".
- Facts must be durable (names, relations, inventory, locations, lasting world state).
- Skip momentary emotions unless they permanently change a relationship.
- text <= 120 chars each.
- rp_day: in-world time label (день 1, второй день, та же ночь, через год). Use RP time hint when unclear."""
FACTS_COMPRESS_SYSTEM = """You consolidate RPG session memory for a long-running chat.
Return ONLY valid JSON (no markdown): an array of {"text": "...", "rp_day": "..."}.
Goals:
- Aggressively MERGE near-duplicates (same topic in RU/EN, Rin/Рин, Grigo/Григорий).
- Keep ONE best fact per topic; combine rp_day if needed (e.g. "день 12").
- DROP redundant, trivial, or superseded facts.
- Keep: names, relationships, key locations, lasting magic/rules, inventory, unresolved threads.
- Target at most {target} facts (fewer is better). Each text <= 120 chars.
- rp_day = in-world labels only."""
_NAME_ALIASES = (
("grigoriy", "григорий"),
("grigo", "григо"),
("grigory", "григорий"),
("rin", "рин"),
("player", "игрок"),
("user", "игрок"),
("glade", "полян"),
("flowers", "цвет"),
("flower", "цвет"),
("magical", "волшеб"),
("magic", "волшеб"),
("glow", "свет"),
("glowing", "свет"),
)
def merge_facts(existing_json: str, new_facts: list[str], limit: int = 80) -> str:
def parse_fact_entry(raw) -> dict | None:
if isinstance(raw, dict):
text = (raw.get("text") or raw.get("fact") or "").strip()
rp_day = (raw.get("rp_day") or raw.get("learned") or raw.get("day") or "").strip()
elif isinstance(raw, str):
text = raw.strip()
rp_day = ""
else:
return None
if not text:
return None
return {"text": text[:120], "rp_day": rp_day[:80]}
def parse_facts_list(facts_json: str | None) -> list[dict]:
try:
existing = json.loads(existing_json or "[]")
if not isinstance(existing, list):
existing = []
data = json.loads(facts_json or "[]")
except json.JSONDecodeError:
existing = []
seen = {str(x).strip() for x in existing if str(x).strip()}
merged = [str(x).strip() for x in existing if str(x).strip()]
for f in new_facts:
s = str(f).strip()
if not s or s in seen:
return []
if not isinstance(data, list):
return []
out: list[dict] = []
seen: set[str] = set()
for item in data:
entry = parse_fact_entry(item)
if not entry:
continue
seen.add(s)
merged.append(s)
if len(merged) > limit:
merged = merged[-limit:]
return json.dumps(merged, ensure_ascii=False)
key = entry["text"].lower()
if key in seen:
continue
seen.add(key)
out.append(entry)
return out
async def extract_facts(context_messages: list[dict]) -> list[str]:
# Build a compact transcript
def facts_list_to_json(facts: list[dict]) -> str:
    """Serialize facts as a JSON array of ``{"text", "rp_day"}`` objects.

    Extra keys are dropped and a missing ``rp_day`` becomes ``""``.
    Non-ASCII text is written as-is.
    """
    rows = []
    for fact in facts:
        rows.append({"text": fact["text"], "rp_day": fact.get("rp_day", "")})
    return json.dumps(rows, ensure_ascii=False)
def rp_day_from_scene(scene_json: str | None) -> str:
try:
scene = json.loads(scene_json or "{}")
if isinstance(scene, dict):
day = (scene.get("day") or "").strip()
if day:
return day[:80]
except (json.JSONDecodeError, TypeError):
pass
return ""
def _normalize_fact_text(text: str) -> str:
    """Lowercase ``text``, unify known aliases, and strip punctuation.

    Every alias from ``_NAME_ALIASES`` is replaced by its canonical form,
    non-word characters become spaces, and whitespace runs collapse to a
    single space.
    """
    normalized = (text or "").lower()
    # Apply the longest aliases first. In table order a short alias
    # ("glow", "grigo") would rewrite the prefix of a longer one
    # ("glowing", "grigory"), leaving a mixed-script token that matches nothing.
    for alias, canonical in sorted(
        _NAME_ALIASES, key=lambda pair: len(pair[0]), reverse=True
    ):
        normalized = normalized.replace(alias, canonical)
    normalized = re.sub(r"[^\w\s]", " ", normalized, flags=re.UNICODE)
    return re.sub(r"\s+", " ", normalized).strip()
def _fact_tokens(text: str) -> set[str]:
    """Return the significant words of ``text`` after normalization.

    Words of one or two characters and a small EN/RU stop-word list are dropped.
    """
    stop = {"the", "and", "that", "with", "this", "have", "has", "was", "are", "для", "что", "как", "это", "на", "в", "и", "а"}
    tokens: set[str] = set()
    for word in re.findall(r"[\w]+", _normalize_fact_text(text), flags=re.UNICODE):
        if len(word) <= 2 or word in stop:
            continue
        tokens.add(word)
    return tokens
def facts_are_similar(a: str, b: str) -> bool:
    """Heuristically decide whether two fact texts describe the same thing.

    Checked in order: case-insensitive equality; one text contained in the
    other while covering at least 35% of its length; Jaccard overlap of the
    normalized token sets of at least 0.32.
    """
    left = a.lower().strip()
    right = b.lower().strip()
    if left == right:
        return True

    # Stable sort: on equal lengths the first argument stays "shorter".
    shorter, longer = sorted((left, right), key=len)
    if shorter in longer and len(shorter) / max(len(longer), 1) >= 0.35:
        return True

    tokens_a = _fact_tokens(a)
    tokens_b = _fact_tokens(b)
    if not (tokens_a and tokens_b):
        return False
    jaccard = len(tokens_a & tokens_b) / len(tokens_a | tokens_b)
    return jaccard >= 0.32
def dedupe_facts_fuzzy(facts: list[dict]) -> list[dict]:
    """Collapse similar facts, keeping the first occurrence's position.

    When a later fact is similar to one already kept, the kept entry takes
    the longer text (capped at 120 characters) and adopts the newcomer's
    ``rp_day`` if it had none. Input dicts are copied, never mutated.
    """
    unique: list[dict] = []
    for fact in facts:
        for kept in unique:
            if not facts_are_similar(fact["text"], kept["text"]):
                continue
            if len(fact["text"]) > len(kept["text"]):
                kept["text"] = fact["text"][:120]
            if fact.get("rp_day") and not kept.get("rp_day"):
                kept["rp_day"] = fact["rp_day"]
            break
        else:
            # Nothing similar kept so far: this is a new topic.
            unique.append(dict(fact))
    return unique
def merge_facts(
    existing_json: str,
    new_facts: list,
    *,
    rp_day_default: str = "",
) -> str:
    """Merge ``new_facts`` into the stored fact list and return it as JSON.

    New entries without a day label get ``rp_day_default`` (capped at 80
    characters). An entry whose text equals a stored one (case-insensitive)
    only fills in a missing ``rp_day``. A fuzzy match replaces the stored
    text when the new text is longer and fills in a missing ``rp_day``.
    Anything else is appended.
    """
    facts = parse_facts_list(existing_json)
    known_keys = {item["text"].lower() for item in facts}
    fallback_day = (rp_day_default or "").strip()[:80]

    for candidate in new_facts or []:
        entry = parse_fact_entry(candidate)
        if not entry:
            continue
        if fallback_day and not entry["rp_day"]:
            entry["rp_day"] = fallback_day
        key = entry["text"].lower()

        if key in known_keys:
            # Exact duplicate: at most backfill the day label.
            twin = next((item for item in facts if item["text"].lower() == key), None)
            if twin is not None and entry["rp_day"] and not twin.get("rp_day"):
                twin["rp_day"] = entry["rp_day"]
            continue

        lookalike = next(
            (item for item in facts if facts_are_similar(entry["text"], item["text"])),
            None,
        )
        if lookalike is not None:
            if len(entry["text"]) > len(lookalike["text"]):
                lookalike["text"] = entry["text"]
            if entry["rp_day"] and not lookalike.get("rp_day"):
                lookalike["rp_day"] = entry["rp_day"]
            continue

        known_keys.add(key)
        facts.append(entry)
    return facts_list_to_json(facts)
async def compress_facts(
    facts: list[dict],
    *,
    scene_context: str = "",
    status_quo: str = "",
    target: int = FACTS_COMPRESS_TARGET,
) -> list[dict]:
    """Ask the facts LLM to merge and prune ``facts`` to about ``target`` entries.

    Args:
        facts: Fact dicts to consolidate.
        scene_context: Optional scene description (first 1500 chars are sent).
        status_quo: Optional status-quo text (first 1500 chars are sent).
        target: Requested maximum number of facts after the merge.

    Returns:
        On success, the LLM's facts after fuzzy dedupe, capped at
        ``FACTS_STORE_LIMIT``. If the LLM call fails, the reply is not valid
        JSON, or it yields no usable entry, the input is fuzzy-deduped
        locally and its last ``target`` entries are returned. Never raises
        for LLM or parse failures.
    """

    def _local_fallback() -> list[dict]:
        # No usable LLM answer: dedupe locally and keep the tail of the list.
        return dedupe_facts_fuzzy(facts)[-target:]

    payload = json.dumps(facts, ensure_ascii=False, indent=2)
    user = (
        f"Current fact count: {len(facts)}. Target after merge: <= {target}.\n\n"
        f"Facts JSON:\n{payload}\n"
    )
    if scene_context.strip():
        user += f"\nCurrent scene:\n{scene_context.strip()[:1500]}\n"
    if status_quo.strip():
        user += f"\nStatus quo:\n{status_quo.strip()[:1500]}\n"

    # The template contains a literal JSON example ({"text": ...}), so
    # str.format() would read those braces as a replacement field and raise
    # KeyError. Substitute the single {target} placeholder directly.
    system = FACTS_COMPRESS_SYSTEM.replace("{target}", str(target))
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    try:
        raw = await (
            send_message_with_model(messages, FACTS_MODEL)
            if FACTS_MODEL
            else send_message(messages)
        )
    except LLMError as e:
        logger.warning("compress_facts LLM failed: %s", e)
        return _local_fallback()
    except Exception as e:
        # Best-effort: compression must never break the chat turn.
        logger.warning("compress_facts unexpected: %s", e)
        return _local_fallback()

    # Models often wrap JSON in a ``` fence; strip it before parsing.
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
    if cleaned.endswith("```"):
        cleaned = cleaned.rsplit("```", 1)[0]
    cleaned = cleaned.strip()

    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        logger.warning("compress_facts JSON parse failed. Raw=%.400s", raw)
        return _local_fallback()

    if isinstance(data, list):
        out = [entry for entry in map(parse_fact_entry, data) if entry]
        if out:
            logger.info("compress_facts: %d -> %d", len(facts), len(out))
            return dedupe_facts_fuzzy(out)[:FACTS_STORE_LIMIT]
    return _local_fallback()
async def merge_facts_persist(
    existing_json: str,
    new_facts: list,
    *,
    rp_day_default: str = "",
    scene_context: str = "",
    status_quo: str = "",
) -> str:
    """Merge new facts into storage, compressing via the LLM when it grows.

    Steps: merge, fuzzy-dedupe, compress toward ``FACTS_COMPRESS_TARGET``
    when the list exceeds ``FACTS_DEDUP_THRESHOLD``, then compress toward
    ``FACTS_STORE_LIMIT`` if it is still above that cap. Returns the
    resulting list as JSON.
    """
    combined_json = merge_facts(
        existing_json, new_facts, rp_day_default=rp_day_default
    )
    facts = dedupe_facts_fuzzy(parse_facts_list(combined_json))
    llm_context = {"scene_context": scene_context, "status_quo": status_quo}

    if len(facts) > FACTS_DEDUP_THRESHOLD:
        compressed = await compress_facts(
            facts, target=FACTS_COMPRESS_TARGET, **llm_context
        )
        facts = dedupe_facts_fuzzy(compressed)

    if len(facts) > FACTS_STORE_LIMIT:
        facts = await compress_facts(
            facts, target=FACTS_STORE_LIMIT, **llm_context
        )
    return facts_list_to_json(facts)
async def extract_facts(
context_messages: list[dict],
*,
rp_day_hint: str = "",
existing_json: str = "",
) -> list[dict]:
transcript = "\n".join(
f"{m.get('role')}: {m.get('content','')}".strip()
f"{m.get('role')}: {m.get('content', '')}".strip()
for m in context_messages
if m.get("role") in ("user", "assistant")
)[-6000:]
hint = (rp_day_hint or "").strip()
known = parse_facts_list(existing_json)
user_parts = []
if known:
known_lines = "\n".join(
f"- [{f.get('rp_day') or '?'}] {f['text']}" for f in known[-40:]
)
user_parts.append(f"Already known facts (do NOT repeat):\n{known_lines}\n")
if hint:
user_parts.append(f"RP time hint: {hint}\n")
user_parts.append(f"New transcript:\n{transcript}")
user = "\n".join(user_parts)
messages = [
{"role": "system", "content": FACTS_SYSTEM},
{"role": "user", "content": transcript},
{"role": "user", "content": user},
]
raw = await (send_message_with_model(messages, FACTS_MODEL) if FACTS_MODEL else send_message(messages))
try:
data = json.loads(raw.strip())
if isinstance(data, list):
return [str(x) for x in data][:40]
except Exception:
raw = await (
send_message_with_model(messages, FACTS_MODEL)
if FACTS_MODEL
else send_message(messages)
)
except LLMError as e:
logger.warning("extract_facts LLM failed (model=%s): %s", FACTS_MODEL or "SYSTEM", e)
return []
except Exception as e:
logger.warning("extract_facts unexpected: %s", e)
return []
cleaned = raw.strip()
if cleaned.startswith("```"):
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
if cleaned.endswith("```"):
cleaned = cleaned.rsplit("```", 1)[0]
cleaned = cleaned.strip()
try:
data = json.loads(cleaned)
if isinstance(data, list):
out: list[dict] = []
for item in data[:8]:
entry = parse_fact_entry(item)
if not entry:
continue
if any(facts_are_similar(entry["text"], k["text"]) for k in known):
continue
if not entry["rp_day"] and hint:
entry["rp_day"] = hint[:80]
out.append(entry)
return out
except json.JSONDecodeError:
logger.warning("extract_facts JSON parse failed. Raw=%.400s", raw)
return []
def facts_to_prompt(facts_json: str, max_items: int = 20) -> str:
try:
facts = json.loads(facts_json or "[]")
if not isinstance(facts, list):
return ""
except json.JSONDecodeError:
return ""
facts = [str(x).strip() for x in facts if str(x).strip()]
def facts_to_prompt(facts_json: str, max_items: int = FACTS_PROMPT_MAX) -> str:
facts = dedupe_facts_fuzzy(parse_facts_list(facts_json))
if not facts:
return ""
block = "\n".join(f"- {x}" for x in facts[-max_items:])
return f"--- Facts (persistent memory) ---\n{block}\n---"
recent = facts[-max_items:]
lines = []
for f in recent:
day = (f.get("rp_day") or "").strip()
if day:
lines.append(f"- [{day}] {f['text']}")
else:
lines.append(f"- {f['text']}")
block = "\n".join(lines)
total = len(facts)
header = "--- Facts (persistent memory"
if total > len(recent):
header += f", showing {len(recent)} of {total}"
header += ") ---"
return f"{header}\n{block}\n---"
+63 -15
View File
@@ -2,7 +2,7 @@ import json
import os
import random
from services.llm import send_message_with_model
from services.llm import LLMError, send_message_with_model
import logging
logger = logging.getLogger(__name__)
@@ -18,8 +18,10 @@ Return ONLY valid JSON (no markdown):
"check_reason": "brief reason why a check is needed (e.g. 'jumping over a pit')",
"directives": ["short imperative rules for the next character reply"],
"resolution_text": "what actually happens as result of the action — written as narrator prose (1-2 sentences). Only if needs_check=true and roll/outcome provided.",
"status_quo_update": "optional short update about the world state"
"status_quo_update": "optional short update about the world state",
"scene_update": {"place": "", "time_of_day": "", "day": "", "weather": "", "exits": [], "layout_note": ""}
}
scene_update: only include keys that changed (partial). Omit scene_update if nothing changed.
If needs_check=false: directives may still guide tone/pacing, resolution_text must be empty string.
If needs_check=true and roll/outcome are provided: resolution_text MUST reflect the outcome.
- critical failure (1): embarrassing or painful failure with extra complication
@@ -32,17 +34,25 @@ After the character replied, update persistent state.
Return ONLY valid JSON (no markdown):
{
"status_quo_update": "what changed in the world/state (1-3 sentences)",
"facts": ["durable facts only"],
"facts": ["durable fact strings OR {\"text\":\"...\",\"rp_day\":\"день 1\"}"],
"choices": [{"id":"a","label":"..."}, ...],
"affinity_delta": 0,
"stats_delta": {"lust": 0, "stamina": 0, "tension": 0},
"scene_update": {"place": "", "place_id": "", "time_of_day": "", "day": "", "weather": "", "exits": [], "layout_note": ""},
"quest_updates": [{"title": "quest title", "status": "active|done|failed"}],
"outfit_update": ["danbooru_tag", "danbooru_tag"]
}
Rules:
- status_quo_update: internal DM state only (facts, location, mood). Never address the player, never use headers like "Status quo"/"Статус кво", P.S., or author commentary.
- affinity_delta: integer -2..+2. Positive if character warmed up to player, negative if pushed away. 0 if neutral.
- stats_delta: each lust/stamina/tension -2..+2 (0 if unchanged). lust=arousal, stamina=energy, tension=stress.
- scene_update: partial location/time schema; only keys that changed. Do not duplicate all of status_quo into scene_update.
- quest_updates: only include if a quest was clearly started, completed, or failed. Empty array otherwise.
- choices: 0-4 options for what the player can do next.
- outfit_update: ONLY include if the character's clothing visibly changed (put on, took off, changed outfit). Use exact danbooru-style underscore_tags (e.g. ["white_dress", "red_ribbon", "barefoot"]). Empty array if no change."""
- choices: 0-4 options for what the player can do next. REQUIRED when scripted beats are exhausted — never return an empty choices array unless the session truly ended.
- outfit_update: ONLY if clothing visibly changed. Use danbooru underscore_tags WITH COLOR when possible
(e.g. white_tank_top, black_sports_shorts, gold_championship_belt, blue_jeans, red_ribbon).
Every garment tag should include a color prefix unless the item is inherently colorless (barefoot, nude).
Never bare generic tags like sports_shorts or torn_tank_top without a color. Empty array if no change."""
async def narrator_pre(
@@ -53,6 +63,7 @@ async def narrator_pre(
user_message: str,
roll: int | None = None,
outcome: str | None = None,
extra_context: str = "",
) -> dict:
roll_block = f"Roll d20={roll}\nOutcome={outcome}\n\n" if roll is not None else ""
user = (
@@ -63,10 +74,20 @@ async def narrator_pre(
f"Facts:\n{facts_block}\n\n"
f"Recent context:\n{context}\n"
)
raw = await send_message_with_model(
[{"role": "system", "content": NARRATOR_PRE_SYSTEM}, {"role": "user", "content": user}],
NARRATOR_MODEL,
)
if extra_context:
user += f"\n--- Session state ---\n{extra_context}\n---\n"
try:
raw = await send_message_with_model(
[{"role": "system", "content": NARRATOR_PRE_SYSTEM}, {"role": "user", "content": user}],
NARRATOR_MODEL,
)
except LLMError as e:
logger.warning("Narrator-pre LLM failed (model=%s): %s", NARRATOR_MODEL, e)
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": "", "_ok": False}
except Exception as e:
logger.warning("Narrator-pre unexpected error: %s", e)
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": "", "_ok": False}
cleaned = raw.strip()
if cleaned.startswith("```"):
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
@@ -76,10 +97,11 @@ async def narrator_pre(
try:
data = json.loads(cleaned)
if isinstance(data, dict):
data["_ok"] = True
return data
except Exception:
logger.warning("Narrator-pre JSON parse failed. Raw=%.500s", raw)
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": ""}
return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": "", "_ok": False}
async def narrator_post(
@@ -87,17 +109,42 @@ async def narrator_post(
context: str,
global_plot: str,
facts_block: str,
is_opening: bool = False,
extra_context: str = "",
) -> dict:
opening_block = ""
if is_opening:
opening_block = (
"\n\nOPENING SCENE: This is the first greeting, not a mid-conversation reply. "
"Extract the character's INITIAL visible clothing from the greeting into outfit_update "
"(danbooru underscore tags WITH color prefixes: white_shirt, black_shorts, gold_belt), "
"even if clothing did not change during the scene. "
"Set status_quo to describe the opening situation. "
"Fill scene_update from greeting and scenario (place, time_of_day, day, layout_note). "
"If the greeting shows clear warmth or hostility toward the player, set affinity_delta "
"non-zero (-2..+2); use 0 only if truly neutral.\n"
)
user = (
f"Persona: {persona_name}\n\n"
f"Global plot:\n{global_plot}\n\n"
f"Facts:\n{facts_block}\n\n"
f"Recent context:\n{context}\n"
f"{opening_block}"
)
raw = await send_message_with_model(
[{"role": "system", "content": NARRATOR_POST_SYSTEM}, {"role": "user", "content": user}],
NARRATOR_MODEL,
)
if extra_context:
user += f"\n--- Session state ---\n{extra_context}\n---\n"
try:
raw = await send_message_with_model(
[{"role": "system", "content": NARRATOR_POST_SYSTEM}, {"role": "user", "content": user}],
NARRATOR_MODEL,
)
except LLMError as e:
logger.warning("Narrator-post LLM failed (model=%s): %s", NARRATOR_MODEL, e)
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": [], "_ok": False}
except Exception as e:
logger.warning("Narrator-post unexpected error: %s", e)
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": [], "_ok": False}
cleaned = raw.strip()
if cleaned.startswith("```"):
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
@@ -107,7 +154,8 @@ async def narrator_post(
try:
data = json.loads(cleaned)
if isinstance(data, dict):
data["_ok"] = True
return data
except Exception:
logger.warning("Narrator-post JSON parse failed. Raw=%.500s", raw)
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": []}
return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": [], "_ok": False}
+405 -5
View File
@@ -1,7 +1,7 @@
import json
import os
from services.llm import send_message_with_model, send_message
from services.llm import LLMError, send_message_with_model, send_message
import logging
logger = logging.getLogger(__name__)
@@ -63,7 +63,19 @@ async def generate_plot_arc(persona_name: str, persona_desc: str, persona_scenar
{"role": "system", "content": ARC_SYSTEM},
{"role": "user", "content": user},
]
raw = await (send_message_with_model(messages, PLOT_MODEL) if PLOT_MODEL else send_message(messages))
try:
raw = await (
send_message_with_model(messages, PLOT_MODEL)
if PLOT_MODEL
else send_message(messages)
)
except LLMError as e:
logger.warning("generate_plot_arc LLM failed (model=%s): %s", PLOT_MODEL or "SYSTEM", e)
return {}
except Exception as e:
logger.warning("generate_plot_arc unexpected error: %s", e)
return {}
cleaned = raw.strip()
# common OpenRouter formatting: fenced json
if cleaned.startswith("```"):
@@ -79,17 +91,236 @@ async def generate_plot_arc(persona_name: str, persona_desc: str, persona_scenar
return {}
def should_advance_arc(user_text: str) -> str | None:
BEAT_MATCH_SYSTEM = """You decide whether the player's latest message should fire ONE scripted plot beat.
Return ONLY valid JSON (no markdown):
{"fire_beat_id": "id from list or null", "confidence": "high|low"}
Rules:
- Fire only if the message clearly matches that beat's narrative intent RIGHT NOW.
- event_driven:rest — stopping to rest, sleep, camp, sauna break, recuperate (not mere sitting still in scene).
- event_driven:travel — leaving, driving, journey, going to a new place, hitting the road.
- event_driven:help_request — explicit plea for help/rescue/assistance.
- event_driven:after_fail / after_success — follow-up to a recent failure/success beat.
- Casual talk, flirting, exploring the current place without leaving does NOT fire travel.
- If nothing fits well, return null.
- Pick at most ONE beat; prefer high confidence only."""
def dice_outcome_to_beat_trigger(outcome: str | None) -> str | None:
"""Map d20 outcome to event_driven beat trigger (after_fail / after_success)."""
o = (outcome or "").strip().lower()
if o in ("failure", "critical failure"):
return "event_driven:after_fail"
if o in ("success", "critical success"):
return "event_driven:after_success"
return None
def should_advance_arc_keywords(user_text: str) -> str | None:
"""Legacy keyword fallback when LLM match is unavailable."""
t = (user_text or "").lower()
if any(x in t for x in ["отдыха", "ночлег", "спим", "сон", "разбить лагерь", "лагерь", "отдохн"]):
if any(x in t for x in ["отдыха", "ночлег", "спим", "сон", "разбить лагерь", "лагерь", "отдохн", "привала", "заправк", "саун"]):
return "event_driven:rest"
if any(x in t for x in ["идем дальше", "пойдем дальше", "в путь", "продолжаем путь", "уходим", "возвращаемся", "переходим"]):
if any(
x in t
for x in [
"идем дальше", "пойдем дальше", "пойдём дальше", "едем дальше", "едем",
"поехали", "выезжаем", "выезжаю", "в путь", "продолжаем путь",
"уходим", "возвращаемся", "переходим", "за рул", "машин", "автомоб",
"дорог", "трас", "шосс", "приех", "прибыва", "стади", "на стадион",
"отправляемся", "выдвигаемся", "в дорогу",
]
):
return "event_driven:travel"
if any(x in t for x in ["помоги", "помочь", "нужна помощь", "спасите", "help"]):
return "event_driven:help_request"
return None
def _parse_llm_json(raw: str) -> dict | list | None:
cleaned = (raw or "").strip()
if cleaned.startswith("```"):
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
if cleaned.endswith("```"):
cleaned = cleaned.rsplit("```", 1)[0]
cleaned = cleaned.strip()
try:
return json.loads(cleaned)
except json.JSONDecodeError:
return None
async def classify_plot_beat(
user_text: str,
beats: list[dict],
recent_context: str = "",
last_dice_outcome: str | None = None,
) -> str | None:
"""LLM: return beat id to fire, or None."""
pending = [b for b in beats if isinstance(b, dict) and b.get("id")]
if not pending or not (user_text or "").strip():
return None
lines = []
for b in pending[:8]:
lines.append(
json.dumps(
{
"id": b.get("id"),
"title": b.get("title", ""),
"trigger": b.get("trigger", ""),
"injection": (b.get("injection") or "")[:200],
},
ensure_ascii=False,
)
)
user = (
f"Player message:\n{user_text.strip()}\n\n"
f"Pending beats:\n" + "\n".join(lines)
)
if recent_context.strip():
user += f"\n\nRecent chat:\n{recent_context.strip()[-2500:]}\n"
if last_dice_outcome:
user += f"\nLast dice outcome this turn: {last_dice_outcome}\n"
messages = [
{"role": "system", "content": BEAT_MATCH_SYSTEM},
{"role": "user", "content": user},
]
try:
raw = await (
send_message_with_model(messages, PLOT_MODEL)
if PLOT_MODEL
else send_message(messages)
)
except LLMError as e:
logger.warning("classify_plot_beat LLM failed: %s", e)
return None
except Exception as e:
logger.warning("classify_plot_beat unexpected: %s", e)
return None
data = _parse_llm_json(raw)
if not isinstance(data, dict):
return None
bid = data.get("fire_beat_id")
if bid in (None, "", "null", "none"):
return None
bid = str(bid).strip()
if data.get("confidence") == "low":
return None
valid_ids = {str(b.get("id")) for b in pending}
if bid in valid_ids:
logger.info("classify_plot_beat: fired %s", bid)
return bid
return None
def pop_beat_by_id(arc: dict, beat_id: str) -> tuple[dict, list[dict]]:
    """Remove the first beat whose id equals ``beat_id`` from ``arc``.

    Ids are compared as strings. ``arc["beats"]`` is rewritten in place.
    Returns ``(arc, [beat])`` or ``(arc, [])`` when nothing matched.
    """
    wanted = str(beat_id)
    popped = None
    rest = []
    for beat in arc.get("beats") or []:
        if popped is None and isinstance(beat, dict) and str(beat.get("id")) == wanted:
            popped = beat
        else:
            rest.append(beat)
    arc["beats"] = rest
    return arc, ([] if popped is None else [popped])
def beat_title(beat: dict) -> str:
    """Return the beat's title, falling back to its injection text.

    The value is cut to 120 characters first and then stripped.
    """
    label = beat.get("title") or beat.get("injection") or ""
    clipped = label[:120]
    return clipped.strip()
def count_active_quests(quests: list | None) -> int:
return sum(1 for q in (quests or []) if q.get("status") == "active")
def prune_beats_for_done_quests(arc: dict, quests: list | None) -> tuple[dict, list[dict]]:
"""Drop beats whose title already matches a done/failed quest (manual quest close desync)."""
done_titles = {
(q.get("title") or "").strip().lower()
for q in (quests or [])
if q.get("status") in ("done", "failed")
}
if not done_titles:
return arc, []
removed, kept = [], []
for b in arc.get("beats") or []:
if isinstance(b, dict) and beat_title(b).lower() in done_titles:
removed.append(b)
else:
kept.append(b)
arc["beats"] = kept
return arc, removed
def pop_next_beats(arc: dict, max_beats: int = 1) -> tuple[dict, list[dict]]:
    """Pop up to ``max_beats`` beats from the front of ``arc["beats"]``.

    Returns ``(arc, popped)``. If the arc has no beats, or ``beats`` is not a
    list, the arc is returned untouched with an empty list. Non-dict entries
    among the popped slice are removed from the arc but not returned.

    A zero or negative ``max_beats`` pops nothing.
    """
    beats = arc.get("beats") or []
    if not isinstance(beats, list) or not beats:
        return arc, []
    # Clamp at 0: a negative count would otherwise slice from the wrong end
    # (beats[:-1]) and silently pop almost the whole queue.
    n = max(0, min(max_beats, len(beats)))
    matched = [b for b in beats[:n] if isinstance(b, dict)]
    arc["beats"] = beats[n:]
    return arc, matched
async def process_arc_beats(
    arc: dict,
    quests: list | None,
    user_text: str,
    *,
    recent_context: str = "",
    last_dice_outcome: str | None = None,
    allow_stuck_recovery: bool = True,
) -> tuple[dict, list[dict], list[dict], str]:
    """
    Prune beats for closed quests, then try to fire one pending beat.

    Strategies run in this order and the first one that fires wins:
    dice-outcome trigger, LLM classification, keyword trigger, and finally
    stuck recovery (only when allowed, beats remain and no quest is active).

    Returns (arc, fired_beats, pruned_beats, mode).
    mode: '' | 'after_dice' | 'llm' | 'trigger' | 'stuck_recovery' | 'pruned'
    """
    if not arc:
        return arc, [], [], ""

    arc, pruned = prune_beats_for_done_quests(arc, quests)
    pending = arc.get("beats") or []

    # 1) A dice result from this turn maps directly onto a beat trigger.
    dice_trigger = dice_outcome_to_beat_trigger(last_dice_outcome)
    if pending and dice_trigger:
        arc, fired = pop_matching_beats(arc, dice_trigger, max_beats=1)
        if fired:
            logger.info(
                "process_arc_beats: after_dice %s -> %s",
                last_dice_outcome,
                fired[0].get("id"),
            )
            return arc, fired, pruned, "after_dice"

    # 2) Ask the classifier model which pending beat (if any) fits the turn.
    if pending:
        chosen_id = await classify_plot_beat(
            user_text, pending, recent_context, last_dice_outcome
        )
        if chosen_id:
            arc, fired = pop_beat_by_id(arc, chosen_id)
            if fired:
                return arc, fired, pruned, "llm"

    # 3) Keyword heuristics on the user's text.
    keyword_trigger = should_advance_arc_keywords(user_text)
    if keyword_trigger:
        arc, fired = pop_matching_beats(arc, keyword_trigger, max_beats=1)
        if fired:
            return arc, fired, pruned, "trigger"

    # 4) Nothing matched and the player has no active quest: push the next beat.
    stuck = allow_stuck_recovery and arc.get("beats") and count_active_quests(quests) == 0
    if stuck:
        arc, fired = pop_next_beats(arc, 1)
        if fired:
            return arc, fired, pruned, "stuck_recovery"

    return (arc, [], pruned, "pruned") if pruned else (arc, [], [], "")
# Plot-arc phase names listed from earliest to latest.
# NOTE(review): presumably indexed by advance_phase() to step forward — confirm.
PHASE_ORDER = ["opening", "hook", "complication", "reveal", "climax", "aftermath"]
@@ -108,6 +339,129 @@ def advance_phase(arc: dict) -> bool:
return True
# System prompt sent by replenish_arc_beats(): asks the model for 2-3 new beats
# (plus an optional next_beat_hint and phase) as a bare JSON object.
BEATS_APPEND_SYSTEM = """You are a narrative designer for an RPG chat.
The plot arc has NO remaining scripted beats. Generate 2-3 NEW beats to continue play.
Return ONLY valid JSON (no markdown):
{
"beats": [
{"id":"b_new_1","title":"short quest title","trigger":"event_driven:rest|event_driven:travel|event_driven:help_request|event_driven:after_fail|event_driven:after_success",
"injection":"1-3 sentences in-world",
"choices":[{"id":"a","label":"..."},{"id":"b","label":"..."}]}
],
"next_beat_hint": "what to push next",
"phase": "hook|complication|reveal|climax|aftermath"
}
Match the current scene and completed quests. Do not restart finished storylines."""
async def replenish_arc_beats(
    arc: dict,
    persona_name: str,
    recent_context: str,
    quests: list,
    genre: str = "adventure",
) -> dict:
    """Append new beats when arc.beats is empty so plot/quest engine can continue.

    Sends the arc summary, quest list and the last 4000 characters of chat to
    the plot model (``PLOT_MODEL`` when set, otherwise the default model) and
    stores the returned beats on ``arc``. ``next_beat_hint`` and ``phase`` are
    copied over when the reply provides them.

    This is best-effort: on any LLM error or unparsable reply the arc is
    returned unchanged. An arc that still has beats is returned immediately
    without calling the model.
    """
    if arc.get("beats"):
        return arc
    quest_lines = "\n".join(
        f"  [{q.get('status')}] {q.get('title')}" for q in (quests or [])
    ) or "  (none)"
    user = (
        f"Character: {persona_name}\n"
        f"Genre: {format_genres(genre)}\n"
        f"Current arc title: {arc.get('title', '')}\n"
        f"Phase: {arc.get('phase', 'aftermath')}\n"
        f"Boundaries: {json.dumps(arc.get('boundaries', []), ensure_ascii=False)}\n"
        f"Quests:\n{quest_lines}\n\n"
        f"Recent chat:\n{recent_context[-4000:]}\n"
    )
    messages = [
        {"role": "system", "content": BEATS_APPEND_SYSTEM},
        {"role": "user", "content": user},
    ]
    try:
        raw = await (
            send_message_with_model(messages, PLOT_MODEL)
            if PLOT_MODEL
            else send_message(messages)
        )
    except LLMError as e:
        logger.warning("replenish_arc_beats failed: %s", e)
        return arc
    except Exception as e:
        logger.warning("replenish_arc_beats unexpected: %s", e)
        return arc

    # Models often wrap JSON in a ```json fence despite the instructions.
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
    if cleaned.endswith("```"):
        cleaned = cleaned.rsplit("```", 1)[0]
    cleaned = cleaned.strip()
    try:
        data = json.loads(cleaned)
    except Exception:
        logger.warning("replenish_arc_beats JSON parse failed. Raw=%.400s", raw)
        return arc
    if not isinstance(data, dict):
        return arc

    # Keep only dict beats: the rest of the plot engine calls .get() on every
    # pending beat, so a stray string/number from the model would crash it.
    raw_beats = data.get("beats")
    new_beats = (
        [b for b in raw_beats if isinstance(b, dict)]
        if isinstance(raw_beats, list)
        else []
    )
    if new_beats:
        arc["beats"] = new_beats
        logger.info("replenish_arc_beats: added %d beats", len(new_beats))
    if data.get("next_beat_hint"):
        arc["next_beat_hint"] = data["next_beat_hint"]
    if data.get("phase"):
        arc["phase"] = data["phase"]
    return arc
async def reconcile_plot_arc(
    session_id: str,
    *,
    replenish_if_empty: bool = True,
    recent_context: str = "",
    persona_name: str = "Character",
    genre: str = "adventure",
) -> tuple[dict, bool]:
    """
    Bring a session's stored plot arc back in line with its quests.

    Beats matching done/failed quests are pruned; when no beats remain (and
    ``replenish_if_empty`` is set) new ones are requested and quests are
    seeded from them. The arc is written back only if something changed.

    Returns (arc, changed). Sessions that are missing or not RPG-enabled
    yield ({}, False).
    """
    # Imported lazily, as in the rest of this module's session helpers.
    from services.memory import (
        get_quests,
        get_session,
        seed_quests_from_arc,
        update_session_plot_arc,
    )

    session = await get_session(session_id)
    if not (session and session.get("rpg_enabled")):
        return {}, False

    try:
        loaded = json.loads(session.get("plot_arc_json") or "{}")
    except (json.JSONDecodeError, TypeError):
        loaded = None
    arc = loaded if isinstance(loaded, dict) else {}

    quests = await get_quests(session_id)
    arc, pruned = prune_beats_for_done_quests(arc, quests)
    changed = bool(pruned)

    needs_refill = replenish_if_empty and not arc.get("beats")
    if needs_refill:
        arc = await replenish_arc_beats(
            arc,
            persona_name,
            recent_context,
            quests,
            genre=session.get("genre") or genre,
        )
        if arc.get("beats"):
            changed = True
            await seed_quests_from_arc(session_id, arc)

    if changed:
        await update_session_plot_arc(session_id, json.dumps(arc, ensure_ascii=False))
    return arc, changed
def pop_matching_beats(arc: dict, trigger: str, max_beats: int = 1) -> tuple[dict, list[dict]]:
beats = arc.get("beats", [])
if not isinstance(beats, list):
@@ -121,3 +475,49 @@ def pop_matching_beats(arc: dict, trigger: str, max_beats: int = 1) -> tuple[dic
arc["beats"] = remaining
return arc, matched
def normalize_choice(
raw: dict,
*,
source: str = "narrator",
beat: dict | None = None,
) -> dict | None:
"""Normalize a choice dict for storage/UI. Adds source and optional beat metadata."""
if not isinstance(raw, dict):
return None
label = (raw.get("label") or "").strip()
if not label:
return None
cid = (raw.get("id") or label[:1].lower() or "a").strip()
out = {"id": cid, "label": label, "source": source}
if beat and source == "plot_beat":
if beat.get("id"):
out["beat_id"] = beat["id"]
title = (beat.get("title") or "").strip()
if title:
out["beat_title"] = title
injection = (beat.get("injection") or "").strip()
if injection:
out["beat_injection"] = injection
return out
def choices_from_beat(beat: dict) -> list[dict]:
    """Normalize a beat's ``choices`` as plot-beat choices, dropping invalid ones."""
    if not isinstance(beat, dict):
        return []
    normalized: list[dict] = []
    for item in beat.get("choices") or []:
        choice = normalize_choice(item, source="plot_beat", beat=beat)
        if choice:
            normalized.append(choice)
    return normalized
def choices_from_narrator(raw_choices: list) -> list[dict]:
    """Normalize narrator-provided choices, dropping entries that fail validation."""
    if not isinstance(raw_choices, list):
        return []
    normalized: list[dict] = []
    for item in raw_choices:
        choice = normalize_choice(item, source="narrator")
        if choice:
            normalized.append(choice)
    return normalized
+321
View File
@@ -0,0 +1,321 @@
import json
DEFAULT_NARRATIVE_STATS = {"lust": 0, "stamina": 10, "tension": 0}
STAT_KEYS = ("lust", "stamina", "tension")
def parse_stats_json(raw: str | None) -> dict:
try:
data = json.loads(raw or "{}") if isinstance(raw, str) else (raw or {})
except Exception:
data = {}
if not isinstance(data, dict):
data = {}
out = dict(DEFAULT_NARRATIVE_STATS)
for k in STAT_KEYS:
try:
out[k] = max(0, min(10, int(data.get(k, out[k]))))
except (TypeError, ValueError):
pass
return out
def parse_scene_json(raw: str | None) -> dict:
try:
data = json.loads(raw or "{}") if isinstance(raw, str) else (raw or {})
except Exception:
data = {}
return data if isinstance(data, dict) else {}
def merge_scene(existing: dict, update: dict | None) -> dict:
if not update or not isinstance(update, dict):
return dict(existing)
merged = dict(existing)
for k, v in update.items():
if v is None:
continue
if k == "exits" and isinstance(v, list):
merged["exits"] = v
elif isinstance(v, str) and v.strip():
merged[k] = v.strip()
elif v != "":
merged[k] = v
return merged
def apply_stats_delta(stats: dict, delta: dict | None) -> dict:
    """Apply a per-turn stat change and return the new clamped stats.

    ``stats`` is first normalized through ``parse_stats_json``. Each delta is
    coerced to ``int``, limited to the range -2..+2 per call, and the result is
    clamped to 0..10. Deltas that cannot be converted count as 0.
    """
    out = parse_stats_json(json.dumps(stats, ensure_ascii=False))
    if not delta or not isinstance(delta, dict):
        return out
    for k in STAT_KEYS:
        try:
            d = int(delta.get(k, 0))
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers non-finite floats such as Infinity.
            d = 0
        d = max(-2, min(2, d))
        if d:
            out[k] = max(0, min(10, out[k] + d))
    return out
def scene_prompt_block(scene: dict) -> str:
    """Render the scene as a prompt section, or ``""`` when nothing is known.

    Emits one ``Label: value`` line per non-blank field (place, place_id,
    time_of_day, day, weather, layout_note) followed by an ``Exits:`` line when
    ``exits`` is a non-empty list. Values are converted with ``str()`` so
    numeric fields such as an integer ``day`` render instead of raising.
    """
    if not scene:
        return ""
    lines = []
    for key, label in (
        ("place", "Location"),
        ("place_id", "Location ID"),
        ("time_of_day", "Time of day"),
        ("day", "Day"),
        ("weather", "Weather"),
        ("layout_note", "Layout"),
    ):
        # Falsy values (None, "", 0) are still treated as "not set".
        v = str(scene.get(key) or "").strip()
        if v:
            lines.append(f"{label}: {v}")
    exits = scene.get("exits")
    if isinstance(exits, list) and exits:
        lines.append("Exits: " + ", ".join(str(e) for e in exits))
    if not lines:
        return ""
    return (
        "\n\n--- Scene ---\n"
        + "\n".join(lines)
        + "\nUse this setting consistently. Do not recite scene labels verbatim.\n---"
    )
# Scoped to current player until multi-user / multi-card relationships exist.
RELATIONSHIP_TARGET_LABEL = "current player (the human you are roleplaying with)"
def _mandatory_state_preamble() -> str:
return (
"MANDATORY ROLEPLAY DIRECTIVE (binding — overrides default cheerfulness or energy):\n"
f"- Target: {RELATIONSHIP_TARGET_LABEL} only.\n"
"- Your reply MUST visibly obey this state in tone, body language, pacing, and what the character attempts.\n"
"- Do NOT contradict a low stamina / high tension / low affinity reading with peppy or intimate behavior unless the text above explicitly allows recovery.\n"
"- Never name affinity, lust, stamina, tension, stats, meters, or numeric values in dialogue.\n"
)
def _band_instruction(value: int, bands: list[tuple[int, str]]) -> str:
"""bands: list of (min_inclusive, instruction) sorted by min descending."""
v = int(value)
for min_val, text in bands:
if v >= min_val:
return text
return bands[-1][1] if bands else ""
def affinity_prompt_block(affinity: int) -> str:
    """Build the mandatory relationship section for the character prompt.

    ``affinity`` is clamped to -30..30 and mapped to an attitude description
    via ``_band_instruction``.
    """
    level = max(-30, min(30, int(affinity)))
    attitude_bands = [
        (10, "Devoted trust: openly affectionate, seeks closeness, defends the player, vulnerable honesty."),
        (5, "Warm bond: friendly, teasing allowed, volunteers help, remembers small kindnesses."),
        (1, "Slight fondness: polite-positive, rare soft moments, not yet intimate."),
        (0, "Neutral professional distance: neither warm nor cold unless scene demands."),
        (-1, "Cool and guarded: short answers, deflects personal topics, skeptical of motives."),
        (-5, "Hostile or deeply distrustful: sarcasm, refusal, may threaten to leave or expose the player."),
        (-30, "Open enmity: antagonistic, undermines, may sabotage or attack socially/physically if fitting genre."),
    ]
    attitude = _band_instruction(level, attitude_bands)
    sections = [
        "\n\n--- Relationship toward player (MANDATORY) ---\n",
        _mandatory_state_preamble(),
        f"Affinity (internal, not spoken): level {level}.\n",
        f"Required attitude toward {RELATIONSHIP_TARGET_LABEL}: {attitude}\n",
        "---",
    ]
    return "".join(sections)
def stats_prompt_block(stats: dict) -> str:
    """Build the mandatory physical/emotional state section for the prompt.

    ``stats`` is normalized through ``parse_stats_json`` (all three keys
    present, each 0..10) and every stat is translated into a behavioral
    instruction via ``_band_instruction``.
    """
    s = parse_stats_json(json.dumps(stats, ensure_ascii=False))
    lust = s["lust"]
    stamina = s["stamina"]
    tension = s["tension"]
    lust_line = _band_instruction(
        lust,
        [
            (9, "Overwhelming arousal colors every beat: breathy, distracted, struggles to stay on task."),
            (7, "Strong desire: flushed skin, lingering touch, voice unsteady, thoughts drift to intimacy."),
            (5, "Clear attraction: meaningful glances, leaning in, playful double meanings if genre fits."),
            (3, "Mild warmth: subtle flirt only when appropriate; easily redirected."),
            (1, "Little romantic charge: platonic focus unless the player escalates."),
            (0, "No romantic or sexual undertone in body language or subtext."),
        ],
    )
    stamina_line = _band_instruction(
        stamina,
        [
            (9, "Peak energy: brisk movement, sharp focus, may offer to take strenuous actions."),
            (7, "Well-rested: alert, steady pace, normal exertion."),
            (5, "Average fatigue: fine for talk/light action; heavy labor needs justification."),
            (4, "Tired: slower reactions, sits when possible, voice softer."),
            (3, "Heavy fatigue: frequent pauses, avoids running/fighting, may ask to stop."),
            (2, "Exhausted: barely moves, slumped posture, short sentences, needs support to walk."),
            (1, "On the verge of collapse: eyelids heavy, may stumble or nearly pass out; minimal action only."),
            (0, "Cannot sustain activity: collapse/immediate sleep/rest is imminent unless helped."),
        ],
    )
    tension_line = _band_instruction(
        tension,
        [
            (9, "Breaking point: trembling, tears or rage close to surface, irrational snap decisions."),
            (7, "High stress: clipped speech, hyper-vigilant, jumps at sounds, defensive."),
            (5, "Uneasy: fidgeting, forced smiles, changes subject from danger."),
            (3, "Mild edge: occasional sigh, watches exits, relaxes with reassurance."),
            (1, "Mostly calm: normal breathing, open posture."),
            (0, "Fully at ease in body and voice."),
        ],
    )
    # The range is written with a plain ASCII hyphen; the previous text read
    # "010" because the dash character had been lost.
    return (
        "\n\n--- Physical & emotional state (MANDATORY) ---\n"
        + _mandatory_state_preamble()
        + f"Internal scales (0-10, never spoken): lust/arousal={lust}, stamina/energy={stamina}, tension/stress={tension}.\n"
        f"- Lust/arousal: {lust_line}\n"
        f"- Stamina/energy: {stamina_line}\n"
        f"- Tension/stress: {tension_line}\n"
        "If multiple apply, combine them (e.g. low stamina + high tension = shaky exhaustion, not peppy panic).\n"
        "---"
    )
def format_narrator_outcome_for_llm(data: dict) -> str:
    """Turn stored narrator JSON into a binding user-turn for the character model.

    Reads ``roll``, ``outcome`` (stripped, lower-cased) and ``text`` and adds
    guidance matching the outcome: failure wording for "failure" /
    "critical failure", amplified wording for "critical success", and plain
    success wording for anything else.
    """
    roll = data.get("roll")
    outcome = (data.get("outcome") or "").strip().lower()
    narration = (data.get("text") or "").strip()

    if outcome == "critical success":
        guidance = (
            "The attempt succeeded dramatically. You may show amplified success aligned with the outcome above."
        )
    elif outcome in ("failure", "critical failure"):
        guidance = (
            "The player's action FAILED as they imagined it. "
            "Do NOT write a success version: no crowd fleeing, no intimidation working, "
            "no effortless victory. Show the failure, embarrassment, or partial result above."
        )
    else:
        guidance = (
            "The attempt succeeded. Your reply must align with the narrator outcome above, not contradict it."
        )

    return "\n".join(
        (
            "--- Narrator ruling (MANDATORY — your next in-character reply MUST follow this) ---",
            f"Roll d20={roll}. Outcome: {outcome}.",
            f"What ACTUALLY happened (canonical truth): {narration}",
            guidance,
            "Respond as the character to THIS outcome only. Never cite dice, rolls, or stats.",
            "---",
        )
    )
def format_user_message_for_llm(content: str, *, has_dice_resolution: bool) -> str:
    """Prefix the player's message with an intent marker when a dice ruling follows.

    Without a dice resolution the content is returned unchanged.
    """
    if has_dice_resolution:
        marker = "[Player stated intent — canonical outcome is in the narrator ruling immediately below]\n"
        return marker + content
    return content
def scene_to_sd_hint(scene: dict) -> str:
    """Render scene fields as ``key: value`` lines for the image-prompt builder.

    Only non-blank place, time_of_day, day, weather and layout_note are
    included. Values are converted with ``str()`` so a numeric ``day`` renders
    instead of raising on ``.strip()``. Returns ``""`` for an empty scene.
    """
    if not scene:
        return ""
    parts = []
    for k in ("place", "time_of_day", "day", "weather", "layout_note"):
        v = str(scene.get(k) or "").strip()
        if v:
            parts.append(f"{k}: {v}")
    return "\n".join(parts)
async def apply_narrator_post(session_id: str, post: dict, rpg_settings: dict, session: dict | None = None) -> dict:
    """Persist narrator_post fields into session. Returns summary of what changed.

    ``post`` is model-generated, so every field is validated before use and
    malformed values are skipped rather than raised. ``session`` is the
    snapshot that scene/facts/stats updates are merged against; it is loaded
    via ``get_session`` when not supplied.

    Feature gates in ``rpg_settings``: ``affinity`` (default on), ``stats``
    (default off) and ``quests`` (default on).

    The summary has the keys status_quo, facts_added, affinity_delta,
    quests_updated, scene, outfit and stats.
    """
    from services.memory import (
        get_session,
        update_session_status_quo,
        update_session_affinity,
        update_session_outfit,
        update_session_scene,
        update_session_narrative_stats,
        update_session_facts,
        upsert_quest,
    )
    from services.rpg_facts import merge_facts_persist, rp_day_from_scene, parse_facts_list

    if session is None:
        session = await get_session(session_id) or {}
    applied = {
        "status_quo": False,
        "facts_added": 0,
        "affinity_delta": 0,
        "quests_updated": 0,
        "scene": False,
        "outfit": False,
        "stats": False,
    }

    sq_raw = post.get("status_quo_update")
    sq = sq_raw.strip() if isinstance(sq_raw, str) else ""
    if sq:
        await update_session_status_quo(session_id, sq)
        applied["status_quo"] = True

    post_facts = post.get("facts") or []
    if isinstance(post_facts, list) and post_facts:
        # One normalized value for both the "before" count and the merge input.
        facts_json = session.get("facts_json") or "[]"
        rp_day = rp_day_from_scene(session.get("scene_json"))
        before = len(parse_facts_list(facts_json))
        merged = await merge_facts_persist(
            facts_json,
            post_facts,
            rp_day_default=rp_day,
            scene_context=json.dumps(
                parse_scene_json(session.get("scene_json")), ensure_ascii=False
            ),
            status_quo=session.get("status_quo") or "",
        )
        await update_session_facts(session_id, merged)
        after = len(parse_facts_list(merged))
        applied["facts_added"] = max(0, after - before)

    if rpg_settings.get("affinity", True):
        try:
            delta = int(post.get("affinity_delta") or 0)
        except (TypeError, ValueError, OverflowError):
            # e.g. "high", a list, or Infinity from the model: ignore.
            delta = 0
        if delta:
            await update_session_affinity(session_id, delta)
            applied["affinity_delta"] = delta

    outfit_update = post.get("outfit_update")
    if isinstance(outfit_update, list) and outfit_update:
        from services.outfit_tags import outfit_list_to_json

        await update_session_outfit(session_id, outfit_list_to_json(outfit_update))
        applied["outfit"] = True

    scene_update = post.get("scene_update")
    if isinstance(scene_update, dict) and scene_update:
        merged = merge_scene(parse_scene_json(session.get("scene_json")), scene_update)
        await update_session_scene(session_id, json.dumps(merged, ensure_ascii=False))
        applied["scene"] = True

    if rpg_settings.get("stats", False):
        stats_delta = post.get("stats_delta")
        if isinstance(stats_delta, dict) and stats_delta:
            current = parse_stats_json(session.get("narrative_stats_json"))
            updated = apply_stats_delta(current, stats_delta)
            await update_session_narrative_stats(
                session_id, json.dumps(updated, ensure_ascii=False)
            )
            applied["stats"] = True

    if rpg_settings.get("quests", True):
        for qu in post.get("quest_updates") or []:
            if not isinstance(qu, dict):
                continue
            title = str(qu.get("title") or "").strip()
            if title:
                await upsert_quest(session_id, title[:120], qu.get("status", "active"))
                applied["quests_updated"] += 1
    return applied
+48
View File
@@ -0,0 +1,48 @@
"""Run ComfyUI generation from SdPromptBundle (single hybrid prompt for Anima)."""
import logging
import os
from services import sdbackend as sd_service
from services.memory import update_message_image, update_message_prompt, update_message_prompt_alt
from services.sd_prompt import SdPromptBundle
logger = logging.getLogger(__name__)
# Read once at import time. Unless the env var is "1", "true" or "yes"
# (case-insensitive), run_sd_for_message() records prompts but skips generation.
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")
async def run_sd_for_message(bundle: SdPromptBundle | None, msg_id: int | None) -> dict:
    """Generate image, persist prompts/paths on message. Returns fields for API/SSE.

    The prompt (and the alternate prompt, when it differs) is always stored on
    the message if ``msg_id`` is given. Actual generation only runs when
    ``SD_AUTO_GENERATE`` is enabled; on failure the backend's error is reported
    in ``image_error``.
    """
    result = dict.fromkeys(
        (
            "image_prompt",
            "image_prompt_alt",
            "image_path",
            "image_path_alt",
            "image_error",
            "image_error_alt",
        )
    )
    if not bundle or not bundle.tag_full:
        return result

    primary = bundle.tag_full
    result["image_prompt"] = primary
    alternate = bundle.desc_full
    if alternate and alternate != primary:
        result["image_prompt_alt"] = alternate

    if msg_id:
        await update_message_prompt(msg_id, primary)
        if result["image_prompt_alt"]:
            await update_message_prompt_alt(msg_id, result["image_prompt_alt"])

    if not SD_AUTO_GENERATE:
        return result

    rel, err = await sd_service.generate_from_full_prompt(primary)
    if not rel:
        result["image_error"] = err
        return result

    result["image_path"] = f"/static/{rel}"
    if msg_id:
        await update_message_image(msg_id, rel)
    return result
+650 -67
View File
@@ -2,26 +2,115 @@ import json
import logging
import os
import re
from dataclasses import dataclass
from services.llm import send_message, send_message_with_model
from services.personas import get_persona
logger = logging.getLogger(__name__)
NEGATIVE_PROMPT_SEPARATOR = "\n\n__NEGATIVE_PROMPT__\n\n"
PROMPT_BUILDER_SYSTEM = """You are a Stable Diffusion prompt engineer for anime illustration models.
Given a roleplay chat excerpt, output ONLY valid JSON (no markdown):
{
"should_generate": true,
"shot_type": "first_person_pov" | "landscape" | "third_person",
"action_tags": "booru-style tags for pose/action/expression, e.g. 'sitting, smiling, holding_cup'",
"environment_tags": "booru-style tags for location/lighting/time, e.g. 'indoors, kitchen, sunlight, daytime'"
"action_tags": "booru-style tags for pose/action/expression",
"environment_tags": "booru-style tags for location/lighting/time"
}
Rules:
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined: 'fox_ears' not 'fox ears'.
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined.
- Do NOT include appearance/character tags — those are provided separately.
- Do NOT include quality tags, model names, style words, 'pov', or category/metadata words.
- Do NOT invent tags. If unsure — omit.
- Keep each field to 3-6 tags."""
- Keep action_tags and environment_tags to 3-6 tags each.
- shot_type: default "first_person_pov" for dialogue/intimacy at arm's length. "third_person" only for wide action (fight, chase). "landscape" only when environment is the focus.
- should_generate: false for non-visual beats (pure internal monologue, time skips with no new pose, empty lines).
- NEVER use negative words in tag fields (not, without, naked, nsfw, etc.)."""
# Extra instructions appended to PROMPT_BUILDER_SYSTEM by _builder_system()
# when the Anima backend is active (see _is_anima()).
ANIMA_BUILDER_EXTRA = """
Anima hybrid mode — ALSO include:
"pov_cue": "face_to_face" | "walking_together" | "doorway_invite" | "reach_to_viewer" | "dialogue_close",
"viewer_body_visible": false,
"scene_description": "ONE short English sentence (max 40 words). Camera POV: what the viewer sees. Mood/atmosphere only — do NOT repeat tags from action_tags/environment_tags. Do NOT list comma-separated booru tags."
POV / interaction rules:
- Default viewer_body_visible: false. The viewer's body, hands, or face must NOT appear in the image — only the character toward the camera.
- For hugs, embraces: use arms_out, reaching_towards_viewer, inviting_hug — NOT holding_hands, lifting, carrying, nose_rub (these draw a second body in POV).
- For long messages with time skips ("About an hour later..."), illustrate ONLY the final visible beat (usually the last paragraph).
- scene_description: describe HER toward the camera only — NEVER "someone", "both", "with you", "hand in hand with", or another person's body.
- NEVER use tags: looking_at_each_other, couple, 2girls, 2boys, multiple_girls. For POV walking together omit holding_hands (use walking, smiling, reaching_towards_viewer instead).
- pov_cue: pick the framing that matches the CURRENT beat (walking_together for strolling side by side, doorway_invite for doorway with arms open, reach_to_viewer when she reaches toward camera, face_to_face for close dialogue).
- Illustrate ONLY the beat under === ILLUSTRATE ===; use === Context === for outfit/location hints only.
- Do NOT put English sentences in action_tags or environment_tags — tags only."""
# Maps a normalized pov_cue value to the POV phrase returned by _build_pov_phrase().
POV_CUE_PHRASES: dict[str, str] = {
    "face_to_face": "POV: close face-to-face, she looks directly at you",
    "walking_together": "POV: walking beside you, profile and shared path visible",
    "doorway_invite": "POV: she blocks the doorway, arms open toward you",
    "reach_to_viewer": "POV: she reaches toward the camera",
    "dialogue_close": "POV: close conversation, she faces you at arm's length",
}
# Last-resort phrase when neither the scene's cue nor the inferred cue is in the table.
POV_CUE_DEFAULT = "POV: she stands before you, facing the camera"
# Negative-prompt terms for POV interaction shots.
# NOTE(review): where this is appended is not visible in this part of the file — confirm.
POV_INTERACTION_NEGATIVE = (
    "duplicate, clone, multiple_girls, 2girls, extra_person, pov hands, "
    "disembodied hands, extra arms, second person"
)
# Substrings that _scene_has_physical_contact() looks for in the action tags.
_CONTACT_ACTION_KEYWORDS = (
    "hug", "holding_hands", "hand_holding", "arms_out", "embrace",
    "reaching", "inviting_hug", "arm_around", "cuddling",
)
# Single-word tags dropped by _sanitize_tags_string() when they appear on their
# own (no underscore), e.g. a bare "white" or "hair".
_JUNK_STANDALONE_TAGS = frozenset({
    "white", "black", "skin", "ear", "ears", "girl", "boy", "fox", "wolf", "cat",
    "short", "tall", "golden", "silver", "red", "blue", "green", "purple",
    "pink", "brown", "eye", "eyes", "hair",
})
# Tags always removed by both _filter_tag_field() and _sanitize_tags_string().
_INVALID_TAGS = frozenset({
    "pumped_up", "pumped", "looking_at_each_other", "couple",
    "2girls", "2boys", "multiple_girls", "multiple_boys", "duo",
})
# Action tags removed by _filter_tag_field() only for first-person POV scenes.
_POV_DROP_ACTION_TAGS = frozenset({
    "holding_hands", "hand_holding", "looking_at_each_other", "couple",
    "lifting", "carry", "carrying", "princess_carry", "nose_rub", "nose_boop",
})
# Case-insensitive match for time-skip phrases ("hours later", "meanwhile", ...)
# plus any trailing dots, ellipsis or whitespace.
# NOTE(review): the caller is not visible in this part of the file — confirm usage.
_TIME_SKIP_RE = re.compile(
    r"(?i)\b(?:about an hour later|hours later|later that (?:day|evening|night)|"
    r"the next (?:day|morning|evening)|meanwhile|after (?:some )?time)\b[.…\s]*",
)
# One mood sentence per pov_cue value.
# NOTE(review): presumably a fallback scene description — consumer not visible here.
_POV_MOOD_FALLBACK: dict[str, str] = {
    "walking_together": "Easy warmth and quiet laughter in the afternoon light.",
    "doorway_invite": "Cool air and playful tension as she waits in the doorway.",
    "reach_to_viewer": "A charged moment as she reaches toward the camera.",
    "face_to_face": "Her expression softens in close focus toward the camera.",
    "dialogue_close": "Intimate calm in the space between you.",
}
# Marker words used by _reconcile_environment_tags() to detect indoor vs. outdoor
# tags; when both kinds appear, exact outdoor markers are removed.
_INDOOR_ENV_MARKERS = frozenset({"doorway", "indoors", "indoor", "apartment", "inside", "room"})
_OUTDOOR_ENV_MARKERS = frozenset({"outdoor", "outdoors", "outside", "street"})
# Phrases implying a second person or the viewer's body; _sanitize_pov_prose()
# drops any sentence that matches in first-person POV scenes.
_POV_PROSE_BANNED = re.compile(
    r"\b(someone|both|together with|hand in hand with|another person|second person|"
    r"your hands|your fingers|your embrace|your heat|intertwined|with you|"
    r"demands your|before you)\b",
    re.IGNORECASE,
)
# Env flag read once at import; combined with _is_anima() in anima_dual_enabled().
SD_ANIMA_DUAL_COMPARE = os.getenv("SD_ANIMA_DUAL_COMPARE", "false").lower() in ("1", "true", "yes")
@dataclass
class SdPromptBundle:
tag_full: str
negative: str
desc_full: str | None = None
def extract_image_prompt_tag(text: str) -> str | None:
@@ -44,7 +133,7 @@ SD_UNET = os.getenv("SD_UNET", "")
SD_PROMPT_MODEL = os.getenv("SD_PROMPT_MODEL", "").strip()
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
ANIMA_NEGATIVE = "worst quality, low quality, score_1, score_2, score_3, blurry, jpeg artifacts, sepia"
@@ -56,37 +145,201 @@ def _is_anima() -> bool:
return bool(SD_UNET) and not SD_CHECKPOINT
def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
def anima_dual_enabled() -> bool:
    """Return True only when the Anima backend is active and dual-compare is switched on."""
    if not _is_anima():
        return False
    return SD_ANIMA_DUAL_COMPARE
def _builder_system() -> str:
    """Return the prompt-builder system prompt, with the Anima extras when applicable."""
    extra = ANIMA_BUILDER_EXTRA if _is_anima() else ""
    return PROMPT_BUILDER_SYSTEM + extra
def _normalize_shot_type(scene: dict) -> dict:
    """Resolve ``shot_type`` to a supported value and sanitize the scene.

    ``landscape`` is kept as-is. ``third_person`` is kept only when the action
    tags mention a wide-action marker (battle, chase, crowd, ...). Everything
    else becomes ``first_person_pov``, with ``viewer_body_visible`` defaulting
    to False when unset.

    The resolved shot type is written onto the passed-in ``scene``; the
    returned dict is the sanitized copy from ``_sanitize_scene_fields``.
    """
    requested = (scene.get("shot_type") or "").strip().lower()

    resolved = "first_person_pov"
    if requested == "landscape":
        resolved = "landscape"
    elif requested == "third_person":
        action = (scene.get("action_tags") or "").lower()
        wide_markers = ("battle", "fight", "chase", "running", "crowd", "wide_shot", "group_shot")
        if any(marker in action for marker in wide_markers):
            resolved = "third_person"

    scene["shot_type"] = resolved
    if resolved == "first_person_pov" and scene.get("viewer_body_visible") is None:
        scene["viewer_body_visible"] = False
    return _sanitize_scene_fields(scene)
def _split_tag_input(tag_str: str) -> list[str]:
return [t.strip() for t in (tag_str or "").split(",") if t.strip()]
def _is_sentence_like_tag(tag: str) -> bool:
t = tag.strip()
if len(t) > 45:
return True
if re.search(r"[.!?]", t):
return True
words = t.split()
return len(words) >= 5 and "_" not in t
def _filter_tag_field(tag_str: str, *, for_pov: bool, field: str) -> str:
    """Drop invalid, sentence-like and (for POV action fields) second-body tags.

    Tags that already contain an underscore are kept verbatim; others are
    emitted lower-cased with spaces turned into underscores.
    """
    drop_pov_actions = for_pov and field == "action"
    kept: list[str] = []
    for raw in _split_tag_input(tag_str):
        key = raw.lower().replace(" ", "_")
        rejected = (
            key in _INVALID_TAGS
            or _is_sentence_like_tag(raw)
            or (drop_pov_actions and key in _POV_DROP_ACTION_TAGS)
        )
        if not rejected:
            kept.append(raw if "_" in raw else key)
    return ", ".join(kept)
def _reconcile_environment_tags(env_str: str) -> str:
    """Resolve indoor/outdoor contradictions in the environment tags.

    A tag set "mentions" indoor or outdoor when any normalized tag contains one
    of the marker words as a substring. When both are mentioned, tags that are
    exactly an outdoor marker are removed; everything else is kept in order.
    """
    tags = _split_tag_input(env_str)
    normalized = [tag.lower().replace(" ", "_") for tag in tags]

    def mentions(markers) -> bool:
        return any(marker in key for key in normalized for marker in markers)

    if mentions(_INDOOR_ENV_MARKERS) and mentions(_OUTDOOR_ENV_MARKERS):
        tags = [
            tag
            for tag, key in zip(tags, normalized)
            if key not in _OUTDOOR_ENV_MARKERS
        ]
    return ", ".join(tags)
def _sanitize_pov_prose(desc: str, scene: dict) -> str:
    """Remove sentences that would draw a second person into a POV image.

    Non-POV scenes get the description back stripped but otherwise untouched.
    For first-person POV, sentences matching ``_POV_PROSE_BANNED`` are dropped,
    as are sentences that mention "wolfgirl" together with walks/walking/stands;
    "at the viewer" is rewritten to "at the camera".
    """
    if not desc or not desc.strip():
        return ""
    text = desc.strip()
    if scene.get("shot_type") != "first_person_pov":
        return text

    def acceptable(sentence: str) -> bool:
        if _POV_PROSE_BANNED.search(sentence):
            return False
        names_wolfgirl = re.search(r"\bwolfgirl\b", sentence, re.I)
        describes_stance = re.search(r"\b(walks|walking|stands)\b", sentence, re.I)
        return not (names_wolfgirl and describes_stance)

    sentences = (part.strip() for part in re.split(r"(?<=[.!?])\s+", text))
    kept = [sentence for sentence in sentences if sentence and acceptable(sentence)]
    joined = " ".join(kept).strip()
    return re.sub(r"\bat the viewer\b", "at the camera", joined, flags=re.IGNORECASE)
def _sanitize_scene_fields(scene: dict) -> dict:
    """Return a copy of ``scene`` with tag fields and description cleaned.

    Action tags are filtered with POV rules when the shot is first-person;
    environment tags are filtered and then reconciled for indoor/outdoor
    conflicts; the description goes through ``_sanitize_pov_prose``.
    """
    cleaned = dict(scene)
    is_pov = cleaned.get("shot_type") == "first_person_pov"

    cleaned["action_tags"] = _filter_tag_field(
        cleaned.get("action_tags") or "", for_pov=is_pov, field="action"
    )

    env_tags = _filter_tag_field(
        cleaned.get("environment_tags") or "", for_pov=False, field="env"
    )
    cleaned["environment_tags"] = _reconcile_environment_tags(env_tags)

    prose = (cleaned.get("scene_description") or "").strip()
    cleaned["scene_description"] = _sanitize_pov_prose(prose, cleaned)
    return cleaned
def _scene_should_generate(scene: dict) -> bool:
if scene.get("should_generate") is False:
return False
return True
def _sanitize_tags_string(tag_str: str) -> str:
    """De-duplicate and clean a comma-separated tag string.

    Tags are compared by their lower-cased, underscore-joined form. Dropped:
    repeats, entries in ``_INVALID_TAGS``, bare junk words from
    ``_JUNK_STANDALONE_TAGS``, and keys of two characters or fewer. Tags that
    contain an underscore keep their original spelling; others are emitted in
    normalized form.
    """
    if not tag_str:
        return ""
    # Insertion-ordered: normalized key -> text to emit.
    accepted: dict[str, str] = {}
    for piece in tag_str.split(","):
        tag = piece.strip()
        key = tag.lower().replace(" ", "_")
        rejected = (
            not tag
            or key in accepted
            or key in _INVALID_TAGS
            or ("_" not in key and key in _JUNK_STANDALONE_TAGS)
            or len(key) <= 2
        )
        if rejected:
            continue
        accepted[key] = tag if "_" in tag else key
    return ", ".join(accepted.values())
def _quality_prefix() -> str:
if _is_pony():
quality = "score_9, score_8_up, score_7_up, source_anime, highres"
elif _is_anima():
quality = "masterpiece, best quality, score_7, anime"
else:
quality = "masterpiece, best quality, highres"
return "score_9, score_8_up, score_7_up, source_anime, highres"
if _is_anima():
return "masterpiece, best quality, score_7, anime"
return "masterpiece, best quality, highres"
parts = [quality]
appearance = (persona or {}).get("appearance_tags", "")
if appearance:
parts.append(appearance)
if outfit_tags:
parts.append(outfit_tags)
def _appearance_for_persona(persona: dict | None) -> str:
    """Return the persona's sanitized ``appearance_tags`` (empty for no persona).

    Only the tag field is used here; appearance prose is not part of the tag line.
    """
    raw_tags = (persona or {}).get("appearance_tags", "")
    return _sanitize_tags_string(raw_tags)
if scene.get("shot_type") == "landscape":
parts.append(scene.get("environment_tags", ""))
else:
if scene.get("shot_type") == "first_person_pov":
parts.append("pov, first-person view, looking at viewer")
parts.append(scene.get("action_tags", ""))
parts.append(scene.get("environment_tags", ""))
def _dedupe_outfit_tags(outfit_tags: str) -> str:
    """Drop the generic ``jeans`` tag when a more specific jeans tag is also present."""
    tags = _split_tag_input(outfit_tags)

    def norm(tag: str) -> str:
        return tag.lower().replace(" ", "_")

    present = {norm(tag) for tag in tags}
    jeans_variants = present & {"jeans", "ripped_jeans", "black_jeans"}
    if "jeans" in present and len(jeans_variants) > 1:
        tags = [tag for tag in tags if norm(tag) != "jeans"]
    return ", ".join(tags)
def _scene_has_physical_contact(scene: dict) -> bool:
    """Return True when the action tags contain any contact keyword as a substring."""
    action = (scene.get("action_tags") or "").lower()
    for keyword in _CONTACT_ACTION_KEYWORDS:
        if keyword in action:
            return True
    return False
def _infer_pov_cue_from_action(action_tags: str) -> str:
action = (action_tags or "").lower()
if any(k in action for k in ("holding_hands", "hand_holding", "walking", "strolling")):
return "walking_together"
if any(k in action for k in ("doorway", "door", "entry", "threshold")):
if any(k in action for k in ("arms_out", "hug", "embrace", "inviting")):
return "doorway_invite"
if any(k in action for k in ("arms_out", "reaching", "inviting_hug", "hug", "embrace")):
return "reach_to_viewer"
if any(k in action for k in ("sitting", "lying", "bed")):
return "dialogue_close"
return "face_to_face"
def _build_pov_phrase(scene: dict) -> str:
    """Return the POV phrase for a first-person scene, or ``""`` for other shots.

    The scene's own ``pov_cue`` wins when it is a known cue (after normalizing
    case, hyphens and spaces); otherwise the cue is inferred from the action tags.
    """
    if scene.get("shot_type") != "first_person_pov":
        return ""
    raw_cue = scene.get("pov_cue") or ""
    cue = raw_cue.strip().lower().replace("-", "_").replace(" ", "_")
    if cue not in POV_CUE_PHRASES:
        cue = _infer_pov_cue_from_action(scene.get("action_tags", ""))
    return POV_CUE_PHRASES.get(cue, POV_CUE_DEFAULT)
def _append_lora(parts: list[str], persona: dict | None) -> None:
lora = (persona or {}).get("lora_name", "")
weight = (persona or {}).get("lora_weight", 0.8)
if lora:
parts.append(f"<lora:{lora}:{weight}>")
def _dedupe_comma_join(parts: list[str]) -> str:
positive = ", ".join(p.strip() for p in parts if p and p.strip())
seen, deduped = set(), []
seen: set[str] = set()
deduped: list[str] = []
for tag in positive.split(", "):
t = tag.strip()
if t and t not in seen:
@@ -95,53 +348,152 @@ def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str =
return ", ".join(deduped)
async def generate_sd_prompt(
messages: list,
persona_id: str,
outfit_json: str = "[]",
) -> tuple[str | None, str | None]:
persona = await get_persona(persona_id)
# Generate only if persona has appearance tags
if not persona or not (persona.get("appearance_tags") or "").strip():
logger.debug("sd_prompt skip: persona=%s no appearance_tags", persona_id)
return None, None
def _build_tag_core(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
    """Build the tag-only prompt core.

    Order: quality prefix, persona appearance, outfit, scene tags, LoRA tag.
    Landscape shots contribute only environment tags; other shots contribute
    action then environment tags, preceded by the literal POV tags when the
    model is not Anima and the shot is first-person. No scene_description prose
    is included.
    """
    shot = scene.get("shot_type")
    segments = [_quality_prefix()]

    look = _appearance_for_persona(persona)
    if look:
        segments.append(look)
    if outfit_tags:
        segments.append(_sanitize_tags_string(_dedupe_outfit_tags(outfit_tags)))

    if shot == "landscape":
        tag_fields = ("environment_tags",)
    else:
        tag_fields = ("action_tags", "environment_tags")
        if not _is_anima() and shot == "first_person_pov":
            segments.append("pov, first-person view, looking at viewer")
    segments.extend(_sanitize_tags_string(scene.get(field, "")) for field in tag_fields)

    _append_lora(segments, persona)
    return _dedupe_comma_join(segments)
recent = [m for m in messages if m["role"] in ("user", "assistant")][-6:]
if not recent:
return None, None
excerpt = "\n".join(f"{m['role']}: {strip_image_prompt_tag(m['content'])}" for m in recent)
def build_positive_prompt_tags_only(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
    """Return the tag core plus the contextual POV phrase (Anima only).

    For non-Anima models this defers to ``build_positive_prompt``.
    """
    if _is_anima():
        tag_core = _build_tag_core(scene, persona, outfit_tags)
        pov_phrase = _build_pov_phrase(scene)
        if not pov_phrase:
            return tag_core
        if not tag_core:
            return pov_phrase
        return f"{tag_core}, {pov_phrase}"
    return build_positive_prompt(scene, persona, outfit_tags)
builder_messages = [
{"role": "system", "content": PROMPT_BUILDER_SYSTEM},
{"role": "user", "content": f"Chat:\n{excerpt}"},
]
try:
if SD_PROMPT_MODEL:
raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
else:
raw = await send_message(builder_messages)
raw = raw.strip()
if raw.startswith("```"):
raw = re.sub(r"^```\w*\n?", "", raw)
raw = re.sub(r"\n?```$", "", raw)
scene = json.loads(raw)
if not isinstance(scene, dict):
logger.warning("sd_prompt: LLM returned non-dict: %.100s", raw)
return None, None
except Exception as e:
logger.warning("sd_prompt failed: %s raw=%.200s", e, locals().get("raw", ""))
return None, None
def _tag_tokens_for_dedupe(tag_line: str) -> set[str]:
tokens: set[str] = set()
for part in tag_line.replace("<lora:", " ").split(","):
for word in re.split(r"[\s_./]+", part.lower()):
w = word.strip()
if len(w) >= 4:
tokens.add(w)
return tokens
try:
outfit_list = json.loads(outfit_json or "[]")
outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
except Exception:
outfit_tags = ""
positive = build_positive_prompt(scene, persona, outfit_tags)
def _trim_redundant_scene_description(desc: str, tag_line: str) -> str:
    """Remove sentences of *desc* that mostly repeat words already in *tag_line*.

    A sentence is dropped when at least 62% of its 4+-letter ASCII words are
    among the tag tokens. Sentences with no such words are always kept. When
    there are no tag tokens or *desc* is blank, the stripped *desc* is returned.
    """
    known = _tag_tokens_for_dedupe(tag_line)
    stripped = desc.strip()
    if not (known and stripped):
        return stripped

    def _worth_keeping(sentence: str) -> bool:
        words = [match.lower() for match in re.findall(r"[a-zA-Z]{4,}", sentence)]
        if not words:
            return True
        hits = sum(1 for word in words if word in known)
        return hits / len(words) < 0.62

    sentences = (part.strip() for part in re.split(r"(?<=[.!?])\s+", stripped))
    return " ".join(s for s in sentences if s and _worth_keeping(s)).strip()
def _extract_illustrate_content(content: str, max_chars: int = 1400) -> str:
    """Pick the part of a (possibly long) assistant post that should be illustrated.

    The text is passed through ``strip_image_prompt_tag``; if ``_TIME_SKIP_RE``
    splits it, only the last chunk is kept. Text within *max_chars* is returned
    whole. Otherwise the longest tail of up to three trailing paragraphs that
    fits is returned, falling back to the last *max_chars* characters.

    Args:
        content: Raw message content.
        max_chars: Upper bound on the length of the returned excerpt.

    Returns:
        The excerpt to illustrate, or ``""`` when the stripped content is empty.
    """
    text = strip_image_prompt_tag(content).strip()
    if not text:
        return ""
    chunks = _TIME_SKIP_RE.split(text)
    if len(chunks) > 1:
        text = chunks[-1].strip()
    if len(text) <= max_chars:
        return text
    paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
    if paragraphs:
        # Try the longest tail first: a shorter tail always fits whenever a
        # longer one does, so ascending order could never return more than
        # the final paragraph.
        for n in (3, 2, 1):
            tail = "\n\n".join(paragraphs[-n:])
            if len(tail) <= max_chars:
                return tail
        # Even the last paragraph is too long: keep its ending.
        return paragraphs[-1][-max_chars:]
    return text[-max_chars:]
def _fallback_mood_prose(scene: dict) -> str:
    """Return stock mood prose for a scene that has no usable description.

    Looks up the normalised ``pov_cue`` in ``_POV_MOOD_FALLBACK``; if absent,
    uses the cue inferred from the action tags, then a generic sentence.
    """
    raw_cue = scene.get("pov_cue") or ""
    cue = raw_cue.strip().lower().replace("-", "_").replace(" ", "_")
    if cue not in _POV_MOOD_FALLBACK:
        guessed = _infer_pov_cue_from_action(scene.get("action_tags", ""))
        return _POV_MOOD_FALLBACK.get(guessed, "Soft atmosphere; her expression toward the camera.")
    return _POV_MOOD_FALLBACK[cue]
def _cap_scene_description(desc: str, max_words: int = 40, max_chars: int = 220) -> str:
words = desc.split()
if len(words) > max_words:
desc = " ".join(words[:max_words])
if len(desc) > max_chars:
desc = desc[: max_chars - 3] + "..."
return desc
def build_positive_prompt_hybrid(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
    """Production Anima prompt: tag core + POV cue + short mood prose.

    Non-Anima models defer to ``build_positive_prompt``. For Anima the
    tags-only prompt is followed by ``". "`` and a capped scene description
    (redundant sentences trimmed, with fallback mood prose when nothing is
    left). If the tag prompt ends with the persona's LoRA tag, that tag is
    moved after the prose.

    Returns:
        The positive prompt string; just the tag prompt when no prose is available.
    """
    if not _is_anima():
        return build_positive_prompt(scene, persona, outfit_tags)
    base = build_positive_prompt_tags_only(scene, persona, outfit_tags)
    desc = _trim_redundant_scene_description(
        (scene.get("scene_description") or "").strip(),
        base,
    )
    desc = _cap_scene_description(desc)
    if not desc:
        desc = _cap_scene_description(_fallback_mood_prose(scene))
    if not desc:
        return base
    lora = (persona or {}).get("lora_name", "")
    weight = (persona or {}).get("lora_weight", 0.8)
    lora_suffix = f" <lora:{lora}:{weight}>" if lora else ""
    if lora_suffix and base.endswith(lora_suffix):
        # The tag prompt is ", "-joined, so cutting " <lora:…>" leaves a
        # dangling comma; strip it so the result is "tags. prose <lora:…>"
        # rather than "tags,. prose <lora:…>".
        base = base[: -len(lora_suffix)].rstrip(", ")
        return f"{base}. {desc}{lora_suffix}"
    return f"{base}. {desc}"
def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
    """Legacy entry: Pony/non-Anima full prompt; Anima delegates to tags-only.

    The non-Anima prompt is exactly the shared tag core (quality prefix,
    appearance, outfit, POV/action/environment tags, LoRA), so it is built by
    ``_build_tag_core`` instead of repeating that logic here.
    """
    if _is_anima():
        return build_positive_prompt_tags_only(scene, persona, outfit_tags)
    # In non-Anima mode ``_build_tag_core`` adds the literal POV tags for
    # first-person shots, matching what this function previously built inline.
    return _build_tag_core(scene, persona, outfit_tags)
def _negative_for_scene(scene: dict) -> str:
if _is_pony():
negative = PONY_NEGATIVE
elif _is_anima():
@@ -151,6 +503,237 @@ async def generate_sd_prompt(
if scene.get("shot_type") == "first_person_pov":
negative += ", third person, over the shoulder"
viewer_visible = scene.get("viewer_body_visible") is True
if not viewer_visible or _scene_has_physical_contact(scene):
negative += ", " + POV_INTERACTION_NEGATIVE
full = positive + f"\n\nNegative prompt: {negative}"
return full, negative
return negative
def _format_builder_user_block(
    persona: dict, messages: list[dict], outfit_json: str, scene_json: str = "{}"
) -> str:
    """Compose the user-role text sent to the prompt-builder model.

    Sections, in order: appearance tags, optional character notes (Anima only,
    truncated to 300 chars), current outfit, scene hint, the beat to
    illustrate (the latest message plus the preceding message of the opposite
    role, if any), and the remaining messages of the last six as context.
    """
    block: list[str] = []

    appearance = (persona.get("appearance_tags") or "").strip()
    block.append(f"Character appearance (tags): {appearance}")

    notes = (persona.get("appearance_prose") or "").strip()
    if _is_anima() and notes and notes != appearance:
        shortened = notes[:300] + ("..." if len(notes) > 300 else "")
        block.append(f"Character notes (do not copy into tags or scene_description): {shortened}")

    # Outfit is best-effort: malformed JSON simply omits the line.
    try:
        parsed_outfit = json.loads(outfit_json or "[]")
        outfit_ref = ", ".join(parsed_outfit) if isinstance(parsed_outfit, list) else ""
    except Exception:
        outfit_ref = ""
    if outfit_ref:
        block.append(f"Current outfit (tags): {outfit_ref}")

    from services.rpg_state import parse_scene_json, scene_to_sd_hint

    hint = scene_to_sd_hint(parse_scene_json(scene_json))
    if hint:
        block.append(f"Scene (location/time):\n{hint}")

    window = [m for m in messages if m.get("role") in ("user", "assistant")][-6:]
    if not window:
        block.append("\nChat:\n(no messages — return should_generate=false)")
        return "\n".join(block)

    # The beat is the latest message, optionally preceded by its counterpart
    # (user before assistant, or assistant before user).
    latest = window[-1]
    partner_role = "user" if latest["role"] == "assistant" else "assistant"
    beat: list[dict] = [latest]
    if len(window) >= 2 and window[-2]["role"] == partner_role:
        beat.insert(0, window[-2])
    background = [m for m in window if m not in beat]

    block.append("\n=== ILLUSTRATE (draw THIS beat only) ===")
    for msg in beat:
        body = msg.get("content", "")
        if msg.get("role") == "assistant":
            body = _extract_illustrate_content(body)
        else:
            body = strip_image_prompt_tag(body)
        block.append(f"{msg['role']}: {body}")

    if background:
        block.append("\n=== Context (outfit/location hints only — do not illustrate old beats) ===")
        for msg in background:
            body = strip_image_prompt_tag(msg.get("content", ""))
            if len(body) > 800:
                body = body[:797] + "..."
            block.append(f"{msg['role']}: {body}")
    return "\n".join(block)
def _parse_scene_json(raw: str) -> dict:
cleaned = raw.strip()
if cleaned.startswith("```"):
cleaned = re.sub(r"^```\w*\n?", "", cleaned)
cleaned = re.sub(r"\n?```$", "", cleaned)
scene = json.loads(cleaned)
if not isinstance(scene, dict):
raise ValueError("LLM returned non-object JSON")
return _normalize_shot_type(scene)
def _bundle_from_scene(scene: dict, persona: dict, outfit_tags: str) -> SdPromptBundle:
    """Build the prompt bundle (positive + negative) for a parsed scene.

    Anima: ``tag_full`` holds the hybrid prompt and, when dual mode is enabled,
    ``desc_full`` holds the tags-only prompt. Other models: ``tag_full`` holds
    the legacy positive prompt and ``desc_full`` is ``None``.
    """
    negative = _negative_for_scene(scene)
    if not _is_anima():
        positive = build_positive_prompt(scene, persona, outfit_tags)
        return SdPromptBundle(
            tag_full=positive + NEGATIVE_PROMPT_SEPARATOR + negative,
            negative=negative,
            desc_full=None,
        )

    hybrid_full = build_positive_prompt_hybrid(scene, persona, outfit_tags) + NEGATIVE_PROMPT_SEPARATOR + negative
    tags_only_full = None
    if anima_dual_enabled():
        tags_only_full = (
            build_positive_prompt_tags_only(scene, persona, outfit_tags)
            + NEGATIVE_PROMPT_SEPARATOR
            + negative
        )
    return SdPromptBundle(tag_full=hybrid_full, negative=negative, desc_full=tags_only_full)
def _parse_chat_excerpt(excerpt: str) -> list[dict]:
messages: list[dict] = []
for line in (excerpt or "").splitlines():
line = line.strip()
if not line:
continue
lower = line.lower()
if lower.startswith("user:"):
messages.append({"role": "user", "content": line[5:].strip()})
elif lower.startswith("assistant:"):
messages.append({"role": "assistant", "content": line[10:].strip()})
elif lower.startswith("system:"):
messages.append({"role": "system", "content": line[7:].strip()})
else:
messages.append({"role": "user", "content": line})
return messages
async def run_prompt_builder(
    persona_id: str,
    *,
    messages: list[dict] | None = None,
    chat_excerpt: str = "",
    outfit_json: str = "[]",
    appearance_override: str | None = None,
    use_prose: bool = False,
) -> dict:
    """Debug: full SD prompt builder pipeline with LLM raw output.

    Args:
        persona_id: Persona to load via ``get_persona``.
        messages: Chat messages; when ``None`` they are parsed from *chat_excerpt*.
        chat_excerpt: ``role: text`` transcript used when *messages* is ``None``.
        outfit_json: JSON list of outfit tags; anything else is treated as no outfit.
        appearance_override: Replaces the persona's ``appearance_tags`` when not ``None``.
        use_prose: Accepted for interface compatibility; not read by this function.

    Returns:
        A dict with the builder inputs, the raw LLM reply, the parsed scene and
        the resulting prompts. On failure it carries ``error`` instead of
        prompts; ``skipped`` is set when the scene says not to generate.
    """
    persona = await get_persona(persona_id) or {}
    if appearance_override is not None:
        persona = {**persona, "appearance_tags": appearance_override}
    recent = messages if messages is not None else _parse_chat_excerpt(chat_excerpt)
    recent = [m for m in recent if m.get("role") in ("user", "assistant")]
    user_block = _format_builder_user_block(persona, recent, outfit_json)
    system_prompt = _builder_system()
    builder_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_block},
    ]
    model_used = SD_PROMPT_MODEL or "SYSTEM_MODEL"
    result: dict = {
        "persona_id": persona_id,
        "sd_prompt_model": model_used,
        "builder_system": system_prompt,
        "builder_user": user_block,
        "anima_dual": anima_dual_enabled(),
    }
    raw = ""
    try:
        if SD_PROMPT_MODEL:
            raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
        else:
            raw = await send_message(builder_messages)
        result["llm_raw"] = raw
        scene = _parse_scene_json(raw)
        result["scene"] = scene
        if not _scene_should_generate(scene):
            result["skipped"] = True
            result["error"] = "should_generate=false"
            return result
        # Same best-effort parsing as generate_sd_prompt: only a JSON list
        # counts, so a JSON string or object is not joined char-/key-wise.
        try:
            outfit_list = json.loads(outfit_json or "[]")
            outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
        except Exception:
            outfit_tags = ""
        negative = _negative_for_scene(scene)
        if _is_anima():
            tags_only = build_positive_prompt_tags_only(scene, persona, outfit_tags)
            hybrid = build_positive_prompt_hybrid(scene, persona, outfit_tags)
            result["tag_positive"] = tags_only
            result["hybrid_positive"] = hybrid
            result["negative"] = negative
            result["tags_only_full"] = tags_only + NEGATIVE_PROMPT_SEPARATOR + negative
            result["hybrid_full"] = hybrid + NEGATIVE_PROMPT_SEPARATOR + negative
            result["tag_full"] = result["hybrid_full"]
        else:
            positive = build_positive_prompt(scene, persona, outfit_tags)
            result["tag_positive"] = positive
            result["negative"] = negative
            result["tag_full"] = positive + NEGATIVE_PROMPT_SEPARATOR + negative
    except Exception as e:
        # Debug endpoint: report the failure in the payload instead of raising.
        result["error"] = str(e)
        result["llm_raw"] = raw or result.get("llm_raw", "")
    return result
async def generate_sd_prompt(
    messages: list,
    persona_id: str,
    outfit_json: str = "[]",
    scene_json: str = "{}",
) -> SdPromptBundle | None:
    """Ask the builder model for a scene and turn it into an SD prompt bundle.

    Args:
        messages: Chat history; only ``user``/``assistant`` entries are used.
        persona_id: Persona to load via ``get_persona``.
        outfit_json: JSON list of outfit tags; anything else means no outfit.
        scene_json: Scene state forwarded to ``_format_builder_user_block``.

    Returns:
        The prompt bundle, or ``None`` when the persona is missing, there are
        no usable messages, the builder call or parse fails, or the scene asks
        not to generate.
    """
    persona = await get_persona(persona_id)
    if not persona:
        return None
    # ``.get`` (as in the sibling helpers) so a message without a role is
    # ignored rather than raising KeyError.
    recent = [m for m in messages if m.get("role") in ("user", "assistant")]
    if not recent:
        return None
    user_block = _format_builder_user_block(persona, recent, outfit_json, scene_json)
    builder_messages = [
        {"role": "system", "content": _builder_system()},
        {"role": "user", "content": user_block},
    ]
    raw = ""
    try:
        if SD_PROMPT_MODEL:
            raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
        else:
            raw = await send_message(builder_messages)
        scene = _parse_scene_json(raw)
    except Exception as e:
        logger.warning("sd_prompt failed: %s raw=%.200s", e, raw)
        return None
    if not _scene_should_generate(scene):
        logger.info("sd_prompt: skipped (should_generate=false)")
        return None
    try:
        outfit_list = json.loads(outfit_json or "[]")
        outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
    except Exception:
        outfit_tags = ""
    bundle = _bundle_from_scene(scene, persona, outfit_tags)
    if anima_dual_enabled() and bundle.desc_full:
        logger.info(
            "Anima prompts: hybrid=%.80s | tags_only=%.80s",
            bundle.tag_full.split(NEGATIVE_PROMPT_SEPARATOR)[0],
            bundle.desc_full.split(NEGATIVE_PROMPT_SEPARATOR)[0],
        )
    return bundle
+323 -24
View File
@@ -3,6 +3,7 @@ import logging
import os
import uuid
from pathlib import Path
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
import httpx
from dotenv import load_dotenv
@@ -11,7 +12,178 @@ load_dotenv()
logger = logging.getLogger(__name__)
SD_BASE_URL = os.getenv("SD_BASE_URL", "http://127.0.0.1:8188").rstrip("/")
def _parse_basic_auth() -> httpx.BasicAuth | None:
    """
    Read HTTP Basic credentials for the Comfy gateway from the environment.

    SD_COMFY_HTTP_BASIC=user:password takes precedence (a value without ":" is
    treated as a password with an empty user); otherwise SD_COMFY_USER and
    SD_COMFY_PASSWORD are used. Returns None when nothing is configured.
    Vast Caddy on mapped ports often uses Basic realm=restricted.
    """
    combined = (os.getenv("SD_COMFY_HTTP_BASIC") or "").strip()
    if combined:
        user, separator, password = combined.partition(":")
        if not separator:
            user, password = "", combined
        return httpx.BasicAuth(user, password)

    user = (os.getenv("SD_COMFY_USER") or "").strip()
    password = (os.getenv("SD_COMFY_PASSWORD") or "").strip()
    if not (user or password):
        return None
    return httpx.BasicAuth(user, password)
SD_BASIC_AUTH = _parse_basic_auth()
def _parse_comfy_config() -> tuple[str, dict[str, str]]:
"""
SD_BASE_URL may be pasted from Vast/Comfy UI with ?token=...
API paths must be base + /prompt, not ...?token=xxx/prompt
"""
raw = (os.getenv("SD_BASE_URL") or "http://127.0.0.1:8188").strip()
extra_token = (os.getenv("SD_COMFY_TOKEN") or "").strip()
parsed = urlparse(raw)
base = f"{parsed.scheme}://{parsed.netloc}"
path = (parsed.path or "").rstrip("/")
if path and path != "/":
base = f"{base}{path}"
query: dict[str, str] = {}
for key, values in parse_qs(parsed.query).items():
if values:
query[key] = values[-1]
if extra_token:
query["token"] = extra_token
base = base.rstrip("/")
# Cloudflare tunnel to localhost:8188 — direct Comfy API, Vast ?token= does not apply
if "trycloudflare.com" in base.lower():
if query.pop("token", None):
logger.info(
"SD_BASE_URL is trycloudflare tunnel: Vast token stripped. "
"Use tunnel for port 8188 only (see instance Port Mapping)."
)
return base, query
SD_BASE_URL, SD_QUERY_PARAMS = _parse_comfy_config()
def _comfy_url(path: str) -> str:
    """Join *path* onto ``SD_BASE_URL``, adding a leading slash if missing."""
    suffix = path if path.startswith("/") else f"/{path}"
    return f"{SD_BASE_URL}{suffix}"
def _log_comfy_target() -> str:
    """Return the base URL for logging, with any configured token masked."""
    has_token = bool(SD_QUERY_PARAMS.get("token"))
    return f"{SD_BASE_URL}?token=***" if has_token else SD_BASE_URL
def _absolute_url(location: str, fallback_path: str = "/") -> str:
if not location:
return _comfy_url(fallback_path)
if location.startswith(("http://", "https://")):
return location
if location.startswith("/"):
return f"{SD_BASE_URL}{location}"
return f"{SD_BASE_URL}/{location}"
def _url_with_token(url: str) -> str:
    """Append gateway token to URL (Vast/Cloudflare often strip ?token on redirect).

    The URL's own query is kept (last value per key) and overlaid with
    ``SD_QUERY_PARAMS``; URL params and fragment are dropped. Without a
    configured token the URL is returned unchanged.
    """
    if not SD_QUERY_PARAMS.get("token"):
        return url
    parts = urlparse(url)
    merged: dict[str, str] = {
        key: values[-1] for key, values in parse_qs(parts.query).items() if values
    }
    merged.update(SD_QUERY_PARAMS)
    return urlunparse((parts.scheme, parts.netloc, parts.path, "", urlencode(merged), ""))
def _merge_params(extra: dict | None) -> dict | None:
    """Overlay *extra* on the configured query params; ``None`` if both are empty."""
    if SD_QUERY_PARAMS or extra:
        combined = dict(SD_QUERY_PARAMS)
        if extra:
            combined.update(extra)
        return combined
    return None
def _is_vast_gateway() -> bool:
    """Return True unless ``SD_BASE_URL`` points at a trycloudflare tunnel."""
    return SD_BASE_URL.lower().find("trycloudflare.com") == -1
def _make_comfy_client(*, timeout: float = 300) -> httpx.AsyncClient:
    """Create an async HTTP client using the configured Basic auth.

    Automatic redirect following is disabled on this client.
    """
    options = {
        "timeout": timeout,
        "follow_redirects": False,
        "auth": SD_BASIC_AUTH,
    }
    return httpx.AsyncClient(**options)
async def _prime_comfy_gateway(client: httpx.AsyncClient) -> None:
    """
    Collect the gateway session cookie and copy it into *client*.

    Vast Caddy: browser opens /?token=… and gets a session cookie; API then works.
    A throwaway redirect-following client performs that GET and its cookies are
    merged into *client*. Does nothing without a token or on a trycloudflare
    base URL; failures are logged and swallowed.
    """
    token = SD_QUERY_PARAMS.get("token")
    if not (token and _is_vast_gateway()):
        return
    try:
        async with httpx.AsyncClient(
            timeout=30,
            follow_redirects=True,
            auth=SD_BASIC_AUTH,
        ) as primer:
            response = await primer.get(_comfy_url("/"), params={"token": token})
            client.cookies.update(primer.cookies)
            cookie_names = list(primer.cookies.keys()) or "(none)"
            logger.info(
                "Comfy gateway prime GET /?token=*** → %s, cookies=%s",
                response.status_code,
                cookie_names,
            )
    except Exception as e:
        logger.warning("Comfy gateway prime failed: %s", e)
async def _comfy_request(
    client: httpx.AsyncClient,
    method: str,
    path: str,
    *,
    params: dict | None = None,
    **kwargs,
) -> httpx.Response:
    """
    Send one Comfy API request, following up to six redirects manually.

    Comfy API: trycloudflare tunnel = no token.
    Vast IP:PORT gateway = ?token= + cookie prime; follow redirects with token re-attached.

    Args:
        client: Client to send with; its cookie jar is reused across calls.
        method: HTTP method.
        path: API path relative to ``SD_BASE_URL``.
        params: Extra query parameters for this request.
        **kwargs: Passed through to ``client.request``.

    Returns:
        The first non-redirect response, or the last redirect response if the
        hop limit is reached.
    """
    url = _comfy_url(path)
    extra = params or {}
    token = SD_QUERY_PARAMS.get("token")
    vast = _is_vast_gateway()
    use_vast_auth = vast and (bool(token) or SD_BASIC_AUTH is not None)
    # Prime only while the client has no cookies yet: callers already prime
    # once, and polling loops would otherwise repeat the gateway GET on
    # every single request.
    if token and vast and not client.cookies:
        await _prime_comfy_gateway(client)
    req_params: dict | None = _merge_params(extra) if use_vast_auth else (extra or None)
    resp: httpx.Response | None = None
    for hop in range(6):
        resp = await client.request(method, url, params=req_params, **kwargs)
        if resp.status_code not in (301, 302, 303, 307, 308):
            return resp
        loc = _absolute_url(resp.headers.get("location", ""), path)
        # Gateways may drop ?token= on redirect, so re-attach it to the target.
        url = _url_with_token(loc) if use_vast_auth else loc
        req_params = extra or None
        logger.info("Comfy redirect %s hop %s → %s", resp.status_code, hop + 1, url.split("?")[0])
    if resp is None:  # unreachable: the loop always runs at least once
        raise RuntimeError("Comfy request produced no response")
    return resp
SD_STEPS = int(os.getenv("SD_STEPS", "28"))
SD_CFG = float(os.getenv("SD_CFG", "7"))
SD_SAMPLER = os.getenv("SD_SAMPLER", "euler")
@@ -26,6 +198,8 @@ SD_DEFAULT_NEGATIVE = os.getenv(
SD_UNET = os.getenv("SD_UNET", "anima-preview3-base.safetensors")
SD_CLIP = os.getenv("SD_CLIP", "qwen_3_06b_base.safetensors")
SD_VAE = os.getenv("SD_VAE", "qwen_image_vae.safetensors")
SD_STYLE_LORA = os.getenv("SD_STYLE_LORA", "")
SD_STYLE_LORA_WEIGHT = float(os.getenv("SD_STYLE_LORA_WEIGHT", "0.7"))
IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images"))
@@ -38,20 +212,38 @@ def _use_anima() -> bool:
def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]:
    """Split a stored full prompt into ``(positive, negative)``.

    The ``__NEGATIVE_PROMPT__`` separator line is tried first, then the legacy
    ``Negative prompt:`` marker. Without either, the whole text is the positive
    prompt and ``SD_DEFAULT_NEGATIVE`` is the negative.
    """
    markers = (
        "\n__NEGATIVE_PROMPT__\n",  # current separator
        "\n\nNegative prompt:",  # old format
    )
    for marker in markers:
        head, found, tail = full_prompt.partition(marker)
        if found:
            return head.strip(), tail.strip()
    return full_prompt.strip(), SD_DEFAULT_NEGATIVE
def _build_workflow(positive: str, negative: str) -> dict:
def _workflow_uses_anima(overrides: dict | None) -> bool:
if overrides and overrides.get("checkpoint"):
return False
if overrides and overrides.get("unet"):
return True
return _use_anima()
def _build_workflow(positive: str, negative: str, overrides: dict | None = None) -> dict:
seed = int(uuid.uuid4().int % 2**32)
if _use_anima():
return {
"44": {"class_type": "UNETLoader", "inputs": {"unet_name": SD_UNET, "weight_dtype": "default"}},
"45": {"class_type": "CLIPLoader", "inputs": {"clip_name": SD_CLIP, "type": "stable_diffusion", "device": "default"}},
"15": {"class_type": "VAELoader", "inputs": {"vae_name": SD_VAE}},
"28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 1024, "batch_size": 1}},
o = overrides or {}
if _workflow_uses_anima(o):
unet = o.get("unet") or SD_UNET
clip = o.get("clip") or SD_CLIP
vae = o.get("vae") or SD_VAE
workflow = {
"44": {"class_type": "UNETLoader", "inputs": {"unet_name": unet, "weight_dtype": "default"}},
"45": {"class_type": "CLIPLoader", "inputs": {"clip_name": clip, "type": "stable_diffusion", "device": "default"}},
"15": {"class_type": "VAELoader", "inputs": {"vae_name": vae}},
"28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 720, "batch_size": 1}},
"11": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["45", 0]}},
"12": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["45", 0]}},
"19": {
@@ -68,9 +260,24 @@ def _build_workflow(positive: str, negative: str) -> dict:
"8": {"class_type": "VAEDecode", "inputs": {"samples": ["19", 0], "vae": ["15", 0]}},
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
}
# Standard checkpoint workflow (Pony / SDXL)
if SD_STYLE_LORA:
workflow["46"] = {
"class_type": "LoraLoader",
"inputs": {
"lora_name": SD_STYLE_LORA,
"model": ["44", 0],
"clip": ["45", 0],
"strength_model": SD_STYLE_LORA_WEIGHT,
"strength_clip": SD_STYLE_LORA_WEIGHT,
},
}
workflow["19"]["inputs"]["model"] = ["46", 0]
workflow["11"]["inputs"]["clip"] = ["46", 1]
workflow["12"]["inputs"]["clip"] = ["46", 1]
return workflow
ckpt = o.get("checkpoint") or SD_CHECKPOINT
return {
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": SD_CHECKPOINT}},
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt}},
"5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}},
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}},
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}},
@@ -89,24 +296,78 @@ def _build_workflow(positive: str, negative: str) -> dict:
}
async def comfy_api_request(
    method: str,
    path: str,
    *,
    params: dict | None = None,
    json_body: dict | None = None,
    timeout: float = 60,
) -> tuple[int, dict | str, dict]:
    """
    Raw Comfy API call for debug. Returns (status_code, parsed_json_or_text, response_headers_subset).

    The body is the decoded JSON when possible, otherwise the first 8000
    characters of the response text. *json_body* is not sent for GET or HEAD.
    """
    verb = method.upper()
    async with _make_comfy_client(timeout=timeout) as client:
        await _prime_comfy_gateway(client)
        token = SD_QUERY_PARAMS.get("token")
        gateway_auth = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None)
        query = _merge_params(params) if gateway_auth else (params or None)

        payload: dict = {}
        if json_body is not None and verb not in ("GET", "HEAD"):
            payload["json"] = json_body

        resp = await _comfy_request(client, verb, path, params=query, **payload)

        # Expose only the headers useful for diagnosing auth/redirect problems.
        headers = {}
        for header_name in ("content-type", "location", "www-authenticate"):
            value = resp.headers.get(header_name)
            if value:
                headers[header_name] = value

        try:
            body = resp.json()
        except Exception:
            body = resp.text[:8000]
        return resp.status_code, body, headers
async def fetch_object_info() -> dict:
    """Fetch ``/object_info`` from Comfy.

    Raises:
        RuntimeError: if the status is not 200 or the body is not a JSON object.
    """
    status, body, _ = await comfy_api_request("GET", "/object_info", timeout=120)
    if status == 200 and isinstance(body, dict):
        return body
    raise RuntimeError(f"object_info failed: HTTP {status} {body!s:.300}")
async def check_sd() -> bool:
try:
async with httpx.AsyncClient(timeout=5) as client:
r = await client.get(f"{SD_BASE_URL}/system_stats")
async with _make_comfy_client(timeout=15) as client:
await _prime_comfy_gateway(client)
r = await _comfy_request(client, "GET", "/system_stats")
return r.status_code == 200
except Exception:
return False
async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[bytes, str]:
async def txt2img(
prompt: str,
negative_prompt: str | None = None,
*,
overrides: dict | None = None,
) -> tuple[bytes, str]:
neg = negative_prompt or SD_DEFAULT_NEGATIVE
workflow = _build_workflow(prompt, neg)
workflow = _build_workflow(prompt, neg, overrides)
client_id = uuid.uuid4().hex
logger.info("ComfyUI request → %s prompt: %.120s", SD_BASE_URL, prompt)
async with httpx.AsyncClient(timeout=300) as client:
resp = await client.post(
f"{SD_BASE_URL}/prompt",
logger.info("ComfyUI request → %s prompt: %.120s", _log_comfy_target(), prompt)
async with _make_comfy_client() as client:
await _prime_comfy_gateway(client)
resp = await _comfy_request(
client,
"POST",
"/prompt",
json={"prompt": workflow, "client_id": client_id},
)
resp.raise_for_status()
@@ -115,7 +376,7 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
for _ in range(300):
await asyncio.sleep(1)
hist = await client.get(f"{SD_BASE_URL}/history/{prompt_id}")
hist = await _comfy_request(client, "GET", f"/history/{prompt_id}")
data = hist.json()
if prompt_id in data:
entry = data[prompt_id]
@@ -127,9 +388,15 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
for node_output in outputs.values():
if "images" in node_output:
img_info = node_output["images"][0]
img_resp = await client.get(
f"{SD_BASE_URL}/view",
params={"filename": img_info["filename"], "subfolder": img_info.get("subfolder", ""), "type": img_info.get("type", "output")},
img_resp = await _comfy_request(
client,
"GET",
"/view",
params={
"filename": img_info["filename"],
"subfolder": img_info.get("subfolder", ""),
"type": img_info.get("type", "output"),
},
)
img_resp.raise_for_status()
image_bytes = img_resp.content
@@ -145,11 +412,43 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
raise RuntimeError("ComfyUI generation timed out or produced no output")
async def generate_from_full_prompt(full_prompt: str) -> tuple[str | None, str | None]:
async def generate_from_full_prompt(
full_prompt: str,
*,
overrides: dict | None = None,
) -> tuple[str | None, str | None]:
positive, negative = split_prompt_and_negative(full_prompt)
try:
_, rel_path = await txt2img(positive, negative)
_, rel_path = await txt2img(positive, negative, overrides=overrides)
return rel_path, None
except httpx.HTTPStatusError as e:
code = e.response.status_code
if code == 401:
logger.error(
"ComfyUI 401: Vast Caddy needs SD_COMFY_TOKEN (or ?token= in SD_BASE_URL) "
"and/or SD_COMFY_HTTP_BASIC=user:pass from the instance page. "
"Test: curl -u user:pass http://IP:PORT/system_stats "
"or open /?token=… in browser then curl with cookies. "
"Alternative: trycloudflare URL for localhost:8188 in Port Mapping."
)
elif code in (301, 302, 303, 307, 308):
logger.error(
"ComfyUI %s: wrong URL — use trycloudflare tunnel for 8188, not web UI link. "
"SD_BASE_URL=https://reviewer-relief-edmonton-specializing.trycloudflare.com "
"(no ?token=). Location: %s",
code,
e.response.headers.get("location"),
)
else:
logger.error("ComfyUI HTTP %s: %s", code, e)
return None, str(e)
except httpx.ConnectError as e:
logger.error(
"ComfyUI connect failed (%s): IP:8188 is often not exposed on Vast. "
"Use trycloudflare URL from Port Mapping for localhost:8188.",
e,
)
return None, str(e)
except Exception as e:
logger.error("ComfyUI error: %s", e)
return None, str(e)
+31
View File
@@ -0,0 +1,31 @@
import logging
from services.memory import get_session
logger = logging.getLogger(__name__)
async def resolve_session_persona(
    session_id: str,
    requested: str | None = None,
    *,
    create_persona: str | None = None,
) -> str:
    """
    Return the persona id to use for a session.

    Session.persona_id is the source of truth: when the session exists its
    bound persona is returned and a disagreeing *requested* value is only
    logged. When the session is missing, *create_persona*, then *requested*,
    then "default" is used.
    """
    session = await get_session(session_id)
    if session:
        bound = (session.get("persona_id") or "default").strip() or "default"
        wanted = (requested or "").strip()
        if wanted and wanted != bound:
            logger.warning(
                "persona_id mismatch session=%s bound=%s requested=%s (using bound)",
                session_id,
                bound,
                wanted,
            )
        return bound
    fallback = create_persona or requested or "default"
    return fallback.strip() or "default"
+21
View File
@@ -0,0 +1,21 @@
import logging
from services.chat_prompt import get_system_prompt
from services.memory import get_all_sessions, get_history, upsert_static_system_message
logger = logging.getLogger(__name__)
async def migrate_static_system_messages() -> int:
    """Rebuild stored system rows from sessions.persona_id (strip legacy RPG text).

    Returns the number of sessions for which ``upsert_static_system_message``
    reported a change.
    """
    changed = 0
    for row in await get_all_sessions():
        session_id = row["session_id"]
        history = await get_history(session_id)
        prompt = await get_system_prompt(row.get("persona_id") or "default", history, "")
        if await upsert_static_system_message(session_id, prompt, history):
            changed += 1
    if changed:
        logger.info("Migrated %s session system message(s) to static persona prompt", changed)
    return changed