Fixed RPG
This commit is contained in:
+124
-33
@@ -7,7 +7,19 @@ import aiosqlite
|
||||
from database.db import DB_PATH
|
||||
|
||||
|
||||
def parse_card_v2(data: dict) -> dict:
|
||||
def _normalize_alternate_greetings(inner: dict) -> list[str]:
|
||||
raw = inner.get("alternate_greetings") or []
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = []
|
||||
for item in raw:
|
||||
text = str(item).strip()
|
||||
if text and text not in out:
|
||||
out.append(text)
|
||||
return out
|
||||
|
||||
|
||||
def parse_card_v2(data: dict, card_id: str | None = None) -> dict:
|
||||
inner = data.get("data", data)
|
||||
if isinstance(inner, str):
|
||||
inner = json.loads(inner)
|
||||
@@ -17,12 +29,15 @@ def parse_card_v2(data: dict) -> dict:
|
||||
if isinstance(entries, dict):
|
||||
entries = list(entries.values())
|
||||
|
||||
alternates = _normalize_alternate_greetings(inner)
|
||||
cid = card_id or (
|
||||
inner.get("name", "imported").lower().replace(" ", "_")[:48]
|
||||
+ "_"
|
||||
+ uuid.uuid4().hex[:8]
|
||||
)
|
||||
|
||||
return {
|
||||
"card_id": (
|
||||
inner.get("name", "imported").lower().replace(" ", "_")[:48]
|
||||
+ "_"
|
||||
+ uuid.uuid4().hex[:8]
|
||||
),
|
||||
"card_id": cid,
|
||||
"name": inner.get("name", "Character"),
|
||||
"description": inner.get("description", ""),
|
||||
"personality": inner.get("personality", ""),
|
||||
@@ -31,10 +46,22 @@ def parse_card_v2(data: dict) -> dict:
|
||||
"mes_example": inner.get("mes_example", ""),
|
||||
"appearance_tags": _extract_appearance(inner),
|
||||
"lorebook_json": json.dumps(entries, ensure_ascii=False),
|
||||
"alternate_greetings": alternates,
|
||||
"alternate_greetings_json": json.dumps(alternates, ensure_ascii=False),
|
||||
"raw_json": json.dumps(data if "data" in data else {"data": inner}, ensure_ascii=False),
|
||||
}
|
||||
|
||||
|
||||
def parse_card_bytes(content: bytes, filename: str) -> dict:
|
||||
if filename.lower().endswith(".png"):
|
||||
card = parse_png_card(content)
|
||||
if not card:
|
||||
raise ValueError("PNG does not contain character card metadata")
|
||||
card["_png_bytes"] = content
|
||||
return card
|
||||
return parse_card_v2(json.loads(content.decode("utf-8")))
|
||||
|
||||
|
||||
def _extract_appearance(inner: dict) -> str:
|
||||
"""Extract booru-style appearance tags from character fields."""
|
||||
import re
|
||||
@@ -107,12 +134,18 @@ def build_system_prompt(card: dict) -> str:
|
||||
|
||||
async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0.8) -> dict:
|
||||
card_id = card["card_id"]
|
||||
alt_json = card.get("alternate_greetings_json")
|
||||
if alt_json is None:
|
||||
alts = card.get("alternate_greetings") or []
|
||||
alt_json = json.dumps(alts, ensure_ascii=False) if isinstance(alts, list) else "[]"
|
||||
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""INSERT OR REPLACE INTO characters
|
||||
(card_id, name, description, personality, scenario, first_mes,
|
||||
mes_example, raw_json, lora_name, lora_weight, appearance_tags, lorebook_json, avatar_path)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
mes_example, raw_json, lora_name, lora_weight, appearance_tags, lorebook_json,
|
||||
avatar_path, alternate_greetings_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
card_id,
|
||||
card["name"],
|
||||
@@ -127,6 +160,7 @@ async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0
|
||||
card.get("appearance_tags", ""),
|
||||
card["lorebook_json"],
|
||||
card.get("avatar_path", ""),
|
||||
alt_json,
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
@@ -140,7 +174,7 @@ async def get_character(card_id: str) -> dict | None:
|
||||
"SELECT * FROM characters WHERE card_id = ?", (card_id,)
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
return dict(row) if row else None
|
||||
return card_to_api(dict(row)) if row else None
|
||||
|
||||
|
||||
async def list_characters() -> list:
|
||||
@@ -171,9 +205,31 @@ async def update_appearance_tags(card_id: str, appearance_tags: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
def card_to_api(card: dict) -> dict:
|
||||
alts = card.get("alternate_greetings")
|
||||
if alts is None:
|
||||
try:
|
||||
alts = json.loads(card.get("alternate_greetings_json") or "[]")
|
||||
except Exception:
|
||||
alts = []
|
||||
if not isinstance(alts, list):
|
||||
alts = []
|
||||
return {**card, "alternate_greetings": alts}
|
||||
|
||||
|
||||
async def preview_card_file(content: bytes, filename: str) -> dict:
|
||||
card = parse_card_bytes(content, filename)
|
||||
png_bytes = card.pop("_png_bytes", None)
|
||||
preview = card_to_api(card)
|
||||
preview["is_png"] = bool(png_bytes)
|
||||
preview["alternate_count"] = len(preview.get("alternate_greetings") or [])
|
||||
return preview
|
||||
|
||||
|
||||
async def update_character(card_id: str, fields: dict) -> bool:
|
||||
allowed = {"name", "description", "personality", "scenario", "first_mes",
|
||||
"mes_example", "appearance_tags", "lora_name", "lora_weight", "avatar_path"}
|
||||
"mes_example", "appearance_tags", "lora_name", "lora_weight", "avatar_path",
|
||||
"alternate_greetings_json"}
|
||||
updates = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not updates:
|
||||
return False
|
||||
@@ -187,37 +243,72 @@ async def update_character(card_id: str, fields: dict) -> bool:
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
async def import_card_file(content: bytes, filename: str, lora_name: str = "", lora_weight: float = 0.8) -> dict:
|
||||
if filename.lower().endswith(".png"):
|
||||
card = parse_png_card(content)
|
||||
if not card:
|
||||
raise ValueError("PNG does not contain character card metadata")
|
||||
# Use the PNG itself as avatar
|
||||
avatar_rel = _save_avatar_bytes(content, f"card_{card['card_id']}")
|
||||
async def import_card_file(
|
||||
content: bytes,
|
||||
filename: str,
|
||||
lora_name: str = "",
|
||||
lora_weight: float = 0.8,
|
||||
overrides: dict | None = None,
|
||||
card_id: str | None = None,
|
||||
) -> dict:
|
||||
card = parse_card_bytes(content, filename)
|
||||
png_bytes = card.pop("_png_bytes", None)
|
||||
|
||||
if card_id:
|
||||
card["card_id"] = card_id
|
||||
|
||||
if overrides:
|
||||
for key in (
|
||||
"name", "description", "personality", "scenario", "first_mes",
|
||||
"mes_example", "appearance_tags", "lorebook_json",
|
||||
):
|
||||
if key in overrides and overrides[key] is not None:
|
||||
card[key] = overrides[key]
|
||||
if overrides.get("alternate_greetings_json") is not None:
|
||||
card["alternate_greetings_json"] = overrides["alternate_greetings_json"]
|
||||
elif overrides.get("alternate_greetings") is not None:
|
||||
alts = overrides["alternate_greetings"]
|
||||
if isinstance(alts, str):
|
||||
try:
|
||||
alts = json.loads(alts)
|
||||
except Exception:
|
||||
alts = []
|
||||
card["alternate_greetings"] = alts
|
||||
card["alternate_greetings_json"] = json.dumps(alts, ensure_ascii=False)
|
||||
|
||||
if png_bytes:
|
||||
avatar_rel = _save_avatar_bytes(png_bytes, f"card_{card['card_id']}")
|
||||
card["avatar_path"] = avatar_rel
|
||||
else:
|
||||
card = parse_card_v2(json.loads(content.decode("utf-8")))
|
||||
|
||||
saved = await save_character(card, lora_name=lora_name, lora_weight=lora_weight)
|
||||
|
||||
persona_id = f"card_{saved['card_id']}"
|
||||
from services.personas import create_persona, get_persona
|
||||
from services.personas import create_persona, get_persona, patch_persona
|
||||
|
||||
existing = await get_persona(persona_id)
|
||||
persona_fields = {
|
||||
"name": saved["name"],
|
||||
"emoji": "🎭",
|
||||
"description": (saved["description"] or "")[:80] or "Character card",
|
||||
"prompt": build_system_prompt(saved),
|
||||
"sd_enabled": True,
|
||||
"lora_name": lora_name,
|
||||
"lora_weight": lora_weight,
|
||||
"appearance_tags": saved.get("appearance_tags", ""),
|
||||
"avatar_path": saved.get("avatar_path", ""),
|
||||
"personality": saved.get("personality", ""),
|
||||
"scenario": saved.get("scenario", ""),
|
||||
"first_mes": saved.get("first_mes", ""),
|
||||
"mes_example": saved.get("mes_example", ""),
|
||||
"lorebook_json": saved.get("lorebook_json", "[]"),
|
||||
"alternate_greetings_json": saved.get("alternate_greetings_json", "[]"),
|
||||
}
|
||||
if not existing:
|
||||
await create_persona(
|
||||
persona_id=persona_id,
|
||||
name=saved["name"],
|
||||
emoji="🎭",
|
||||
description=saved["description"][:80] or "Character card",
|
||||
prompt=build_system_prompt(saved),
|
||||
sd_enabled=True,
|
||||
lora_name=lora_name,
|
||||
lora_weight=lora_weight,
|
||||
appearance_tags=saved.get("appearance_tags", ""),
|
||||
avatar_path=saved.get("avatar_path", ""),
|
||||
)
|
||||
return saved
|
||||
await create_persona(persona_id=persona_id, **persona_fields)
|
||||
else:
|
||||
await patch_persona(persona_id, persona_fields)
|
||||
|
||||
return card_to_api(saved)
|
||||
|
||||
|
||||
def _save_avatar_bytes(png_bytes: bytes, prefix: str) -> str:
|
||||
|
||||
+61
-54
@@ -1,12 +1,18 @@
|
||||
import httpx
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_KEY = os.getenv("ROUTER_KEY")
|
||||
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
|
||||
MODEL = "google/gemini-2.5-flash"
|
||||
|
||||
CHAT_MODEL = os.getenv("CHAT_MODEL", "mistralai/mistral-nemo")
|
||||
SYSTEM_MODEL = os.getenv("SYSTEM_MODEL", "google/gemini-2.5-flash")
|
||||
|
||||
HEADERS = {
|
||||
"Authorization": f"Bearer {OPENROUTER_KEY}",
|
||||
@@ -14,66 +20,67 @@ HEADERS = {
|
||||
"HTTP-Referer": "http://localhost:8000",
|
||||
}
|
||||
|
||||
|
||||
def _clean(messages: list) -> list:
|
||||
"""Filter out messages with empty content."""
|
||||
return [m for m in messages if (m.get("content") or "").strip()]
|
||||
|
||||
|
||||
async def _post(model: str, messages: list, extra: dict | None = None) -> str:
|
||||
payload = {"model": model, "messages": _clean(messages), **(extra or {})}
|
||||
async with httpx.AsyncClient(timeout=90) as client:
|
||||
r = await client.post(OPENROUTER_URL, headers=HEADERS, json=payload)
|
||||
r.raise_for_status()
|
||||
return r.json()["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
async def send_message(messages: list) -> str:
|
||||
"""Обычный запрос — используем для внутренних нужд"""
|
||||
payload = {
|
||||
"model": MODEL,
|
||||
"messages": messages,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
response = await client.post(
|
||||
OPENROUTER_URL,
|
||||
headers=HEADERS,
|
||||
json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
"""System model — narrator, facts, SD prompt."""
|
||||
return await _post(SYSTEM_MODEL, messages)
|
||||
|
||||
|
||||
async def send_message_with_model(messages: list, model: str) -> str:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=90) as client:
|
||||
response = await client.post(
|
||||
OPENROUTER_URL,
|
||||
headers=HEADERS,
|
||||
json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
"""Explicit model — plot arc, narrator override."""
|
||||
return await _post(model, messages)
|
||||
|
||||
|
||||
async def stream_message(messages: list):
|
||||
"""Стриминг — отдаём чанки по мере получения"""
|
||||
"""Chat model stream — roleplay dialogue."""
|
||||
payload = {
|
||||
"model": MODEL,
|
||||
"messages": messages,
|
||||
"model": CHAT_MODEL,
|
||||
"messages": _clean(messages),
|
||||
"stream": True,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
OPENROUTER_URL,
|
||||
headers=HEADERS,
|
||||
json=payload
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:] # убираем "data: "
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
import json
|
||||
chunk = json.loads(data)
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except Exception:
|
||||
continue
|
||||
timeout = httpx.Timeout(connect=10, read=120, write=10, pool=5)
|
||||
chunk_count = 0
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
try:
|
||||
async with client.stream("POST", OPENROUTER_URL, headers=HEADERS, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
buf = ""
|
||||
async for raw in response.aiter_bytes():
|
||||
text = raw.decode("utf-8", errors="replace")
|
||||
if not buf and chunk_count == 0:
|
||||
logger.info("stream first bytes: %.200s", text)
|
||||
buf += text
|
||||
while "\n" in buf:
|
||||
line, buf = buf.split("\n", 1)
|
||||
line = line.rstrip("\r")
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:]
|
||||
if data == "[DONE]":
|
||||
return
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
content = chunk["choices"][0]["delta"].get("content", "")
|
||||
if content:
|
||||
chunk_count += 1
|
||||
yield content
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error("stream_message error after %d chunks: %s", chunk_count, e)
|
||||
raise
|
||||
finally:
|
||||
logger.info("stream_message finished: %d chunks", chunk_count)
|
||||
|
||||
@@ -380,6 +380,15 @@ async def update_session_rpg_settings(session_id: str, settings_json: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_session_outfit(session_id: str, outfit_json: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET outfit_json = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(outfit_json, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def upsert_quest(session_id: str, title: str, status: str = "active"):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
async with db.execute(
|
||||
|
||||
@@ -69,6 +69,7 @@ def _row_to_persona(row: dict) -> dict:
|
||||
"mes_example": row.get("mes_example", "") or "",
|
||||
"lorebook_json": row.get("lorebook_json", "[]") or "[]",
|
||||
"avatar_path": row.get("avatar_path", "") or "",
|
||||
"alternate_greetings_json": row.get("alternate_greetings_json", "[]") or "[]",
|
||||
}
|
||||
|
||||
|
||||
@@ -122,6 +123,7 @@ async def create_persona(
|
||||
mes_example: str = "",
|
||||
lorebook_json: str = "[]",
|
||||
avatar_path: str = "",
|
||||
alternate_greetings_json: str = "[]",
|
||||
) -> dict:
|
||||
final_prompt = prompt.strip() or build_persona_prompt(
|
||||
{
|
||||
@@ -137,12 +139,14 @@ async def create_persona(
|
||||
"""INSERT INTO personas
|
||||
(persona_id, name, emoji, description, prompt, custom,
|
||||
sd_enabled, lora_name, lora_weight, appearance_tags,
|
||||
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path)
|
||||
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
|
||||
alternate_greetings_json)
|
||||
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
persona_id, name, emoji, description, final_prompt,
|
||||
1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags,
|
||||
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
|
||||
alternate_greetings_json,
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
@@ -162,6 +166,7 @@ async def create_persona(
|
||||
"mes_example": mes_example,
|
||||
"lorebook_json": lorebook_json,
|
||||
"avatar_path": avatar_path,
|
||||
"alternate_greetings_json": alternate_greetings_json,
|
||||
}
|
||||
|
||||
|
||||
@@ -227,6 +232,7 @@ async def patch_persona(persona_id: str, fields: dict) -> bool:
|
||||
"mes_example",
|
||||
"lorebook_json",
|
||||
"avatar_path",
|
||||
"alternate_greetings_json",
|
||||
}
|
||||
updates = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not updates:
|
||||
|
||||
@@ -35,12 +35,14 @@ Return ONLY valid JSON (no markdown):
|
||||
"facts": ["durable facts only"],
|
||||
"choices": [{"id":"a","label":"..."}, ...],
|
||||
"affinity_delta": 0,
|
||||
"quest_updates": [{"title": "quest title", "status": "active|done|failed"}]
|
||||
"quest_updates": [{"title": "quest title", "status": "active|done|failed"}],
|
||||
"outfit_update": ["danbooru_tag", "danbooru_tag"]
|
||||
}
|
||||
Rules:
|
||||
- affinity_delta: integer -2..+2. Positive if character warmed up to player, negative if pushed away. 0 if neutral.
|
||||
- quest_updates: only include if a quest was clearly started, completed, or failed. Empty array otherwise.
|
||||
- choices: 0-4 options for what the player can do next."""
|
||||
- choices: 0-4 options for what the player can do next.
|
||||
- outfit_update: ONLY include if the character's clothing visibly changed (put on, took off, changed outfit). Use exact danbooru-style underscore_tags (e.g. ["white_dress", "red_ribbon", "barefoot"]). Empty array if no change."""
|
||||
|
||||
|
||||
async def narrator_pre(
|
||||
|
||||
+19
-1
@@ -27,7 +27,7 @@ Return ONLY valid JSON (no markdown):
|
||||
"cast": [{"name":"NPC name","role":"helper|antagonist|bystander","motivation":"..."}],
|
||||
"secrets": ["hidden truths not revealed yet"],
|
||||
"beats": [
|
||||
{"id":"b1","trigger":"event_driven:rest|event_driven:travel|event_driven:help_request|event_driven:after_fail|event_driven:after_success",
|
||||
{"id":"b1","title":"short quest title (3-6 words)","trigger":"event_driven:rest|event_driven:travel|event_driven:help_request|event_driven:after_fail|event_driven:after_success",
|
||||
"injection":"1-3 sentences to introduce the beat WITHOUT breaking current scene",
|
||||
"choices":[{"id":"a","label":"..."},{"id":"b","label":"..."}]}
|
||||
],
|
||||
@@ -90,6 +90,24 @@ def should_advance_arc(user_text: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
PHASE_ORDER = ["opening", "hook", "complication", "reveal", "climax", "aftermath"]
|
||||
|
||||
|
||||
def advance_phase(arc: dict) -> bool:
|
||||
"""Advance arc to next phase if beats are exhausted. Returns True if phase changed."""
|
||||
current = arc.get("phase", "opening")
|
||||
if arc.get("beats"):
|
||||
return False
|
||||
try:
|
||||
idx = PHASE_ORDER.index(current)
|
||||
except ValueError:
|
||||
return False
|
||||
if idx + 1 >= len(PHASE_ORDER):
|
||||
return False
|
||||
arc["phase"] = PHASE_ORDER[idx + 1]
|
||||
return True
|
||||
|
||||
|
||||
def pop_matching_beats(arc: dict, trigger: str, max_beats: int = 1) -> tuple[dict, list[dict]]:
|
||||
beats = arc.get("beats", [])
|
||||
if not isinstance(beats, list):
|
||||
|
||||
+65
-34
@@ -1,20 +1,24 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from services.llm import send_message
|
||||
|
||||
from services.llm import send_message, send_message_with_model
|
||||
from services.personas import get_persona
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROMPT_BUILDER_SYSTEM = """You are a Stable Diffusion prompt engineer for anime illustration models.
|
||||
Given a roleplay chat excerpt and character appearance hints, output ONLY valid JSON (no markdown):
|
||||
Given a roleplay chat excerpt, output ONLY valid JSON (no markdown):
|
||||
{
|
||||
"should_generate": true,
|
||||
"shot_type": "first_person_pov" | "landscape" | "third_person",
|
||||
"appearance_tags": "booru-style tags for character appearance extracted from hints, e.g. 'white hair, wolf ears, wolf tail, yellow eyes'",
|
||||
"action_tags": "booru-style tags for pose/action, e.g. 'sitting, smiling, looking at viewer'",
|
||||
"environment_tags": "booru-style tags for location/lighting, e.g. 'indoors, kitchen, sunlight'"
|
||||
"action_tags": "booru-style tags for pose/action/expression, e.g. 'sitting, smiling, holding_cup'",
|
||||
"environment_tags": "booru-style tags for location/lighting/time, e.g. 'indoors, kitchen, sunlight, daytime'"
|
||||
}
|
||||
Rules:
|
||||
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be written as single tags: 'white hair' not 'white, hair'. 'wolf ears' not 'wolf, ears'.
|
||||
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined: 'fox_ears' not 'fox ears'.
|
||||
- Do NOT include appearance/character tags — those are provided separately.
|
||||
- Do NOT include quality tags, model names, style words, 'pov', or category/metadata words.
|
||||
- Do NOT invent tags. If unsure — omit.
|
||||
- Keep each field to 3-6 tags."""
|
||||
@@ -35,19 +39,38 @@ def strip_image_prompt_tag(text: str) -> str:
|
||||
return re.sub(r"\[IMAGE_PROMPT:.*?\]", "", text, flags=re.DOTALL).strip()
|
||||
|
||||
|
||||
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}
|
||||
SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "")
|
||||
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
|
||||
SD_UNET = os.getenv("SD_UNET", "")
|
||||
SD_PROMPT_MODEL = os.getenv("SD_PROMPT_MODEL", "").strip()
|
||||
|
||||
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}
|
||||
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
|
||||
ANIMA_NEGATIVE = "worst quality, low quality, score_1, score_2, score_3, blurry, jpeg artifacts, sepia"
|
||||
|
||||
|
||||
def _is_pony() -> bool:
|
||||
return SD_CHECKPOINT in PONY_CHECKPOINTS
|
||||
|
||||
|
||||
def _is_anima() -> bool:
|
||||
return bool(SD_UNET) and not SD_CHECKPOINT
|
||||
|
||||
|
||||
def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str:
|
||||
if _is_pony():
|
||||
quality = "score_9, score_8_up, score_7_up, source_anime, highres"
|
||||
elif _is_anima():
|
||||
quality = "masterpiece, best quality, score_7, anime"
|
||||
else:
|
||||
quality = "masterpiece, best quality, highres"
|
||||
|
||||
def build_positive_prompt(scene: dict, persona: dict | None) -> str:
|
||||
is_pony = SD_CHECKPOINT in PONY_CHECKPOINTS
|
||||
quality = "score_9, score_8_up, score_7_up, source_anime, highres" if is_pony else "masterpiece, best quality, highres"
|
||||
parts = [quality]
|
||||
|
||||
# prefer LLM-extracted appearance over raw persona tags
|
||||
appearance = scene.get("appearance_tags") or (persona or {}).get("appearance_tags", "")
|
||||
appearance = (persona or {}).get("appearance_tags", "")
|
||||
if appearance:
|
||||
parts.append(appearance)
|
||||
if outfit_tags:
|
||||
parts.append(outfit_tags)
|
||||
|
||||
if scene.get("shot_type") == "landscape":
|
||||
parts.append(scene.get("environment_tags", ""))
|
||||
@@ -75,9 +98,12 @@ def build_positive_prompt(scene: dict, persona: dict | None) -> str:
|
||||
async def generate_sd_prompt(
|
||||
messages: list,
|
||||
persona_id: str,
|
||||
outfit_json: str = "[]",
|
||||
) -> tuple[str | None, str | None]:
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona or not persona.get("sd_enabled"):
|
||||
# Generate only if persona has appearance tags
|
||||
if not persona or not (persona.get("appearance_tags") or "").strip():
|
||||
logger.debug("sd_prompt skip: persona=%s no appearance_tags", persona_id)
|
||||
return None, None
|
||||
|
||||
recent = [m for m in messages if m["role"] in ("user", "assistant")][-6:]
|
||||
@@ -86,40 +112,45 @@ async def generate_sd_prompt(
|
||||
|
||||
excerpt = "\n".join(f"{m['role']}: {strip_image_prompt_tag(m['content'])}" for m in recent)
|
||||
|
||||
appearance = persona.get("appearance_tags", "")
|
||||
# For card personas, also include description for better visual context
|
||||
if persona_id.startswith("card_"):
|
||||
from services.character_card import get_character
|
||||
card = await get_character(persona_id[5:])
|
||||
if card and card.get("description"):
|
||||
appearance = f"{appearance}\nCharacter description: {card['description'][:400]}"
|
||||
|
||||
builder_messages = [
|
||||
{"role": "system", "content": PROMPT_BUILDER_SYSTEM},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Persona appearance hints: {appearance}\n\nChat:\n{excerpt}",
|
||||
},
|
||||
{"role": "user", "content": f"Chat:\n{excerpt}"},
|
||||
]
|
||||
|
||||
try:
|
||||
raw = await send_message(builder_messages)
|
||||
if SD_PROMPT_MODEL:
|
||||
raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL)
|
||||
else:
|
||||
raw = await send_message(builder_messages)
|
||||
raw = raw.strip()
|
||||
if raw.startswith("```"):
|
||||
raw = re.sub(r"^```\w*\n?", "", raw)
|
||||
raw = re.sub(r"\n?```$", "", raw)
|
||||
scene = json.loads(raw)
|
||||
except (json.JSONDecodeError, Exception):
|
||||
if not isinstance(scene, dict):
|
||||
logger.warning("sd_prompt: LLM returned non-dict: %.100s", raw)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.warning("sd_prompt failed: %s raw=%.200s", e, locals().get("raw", ""))
|
||||
return None, None
|
||||
|
||||
try:
|
||||
outfit_list = json.loads(outfit_json or "[]")
|
||||
outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else ""
|
||||
except Exception:
|
||||
outfit_tags = ""
|
||||
|
||||
positive = build_positive_prompt(scene, persona, outfit_tags)
|
||||
|
||||
if _is_pony():
|
||||
negative = PONY_NEGATIVE
|
||||
elif _is_anima():
|
||||
negative = ANIMA_NEGATIVE
|
||||
else:
|
||||
negative = "low quality, blurry, bad anatomy, watermark, text"
|
||||
|
||||
positive = build_positive_prompt(scene, persona)
|
||||
is_pony = SD_CHECKPOINT in PONY_CHECKPOINTS
|
||||
negative = PONY_NEGATIVE if is_pony else "low quality, blurry, bad anatomy, watermark, text"
|
||||
if scene.get("shot_type") == "first_person_pov":
|
||||
negative += ", third person, over the shoulder"
|
||||
|
||||
full = positive
|
||||
if negative:
|
||||
full += f"\n\nNegative prompt: {negative}"
|
||||
full = positive + f"\n\nNegative prompt: {negative}"
|
||||
return full, negative
|
||||
|
||||
+56
-22
@@ -16,13 +16,26 @@ SD_STEPS = int(os.getenv("SD_STEPS", "28"))
|
||||
SD_CFG = float(os.getenv("SD_CFG", "7"))
|
||||
SD_SAMPLER = os.getenv("SD_SAMPLER", "euler")
|
||||
SD_SCHEDULER = os.getenv("SD_SCHEDULER", "normal")
|
||||
SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "NetaYumev35_pretrained_all_in_one.safetensors")
|
||||
SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "")
|
||||
SD_DEFAULT_NEGATIVE = os.getenv(
|
||||
"SD_DEFAULT_NEGATIVE",
|
||||
"low quality, worst quality, blurry, bad anatomy, watermark, text",
|
||||
)
|
||||
|
||||
# Anima split-model settings
|
||||
SD_UNET = os.getenv("SD_UNET", "anima-preview3-base.safetensors")
|
||||
SD_CLIP = os.getenv("SD_CLIP", "qwen_3_06b_base.safetensors")
|
||||
SD_VAE = os.getenv("SD_VAE", "qwen_image_vae.safetensors")
|
||||
|
||||
IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images"))
|
||||
|
||||
ANIMA_CHECKPOINTS = {"anima-preview3-base.safetensors"}
|
||||
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}
|
||||
|
||||
|
||||
def _use_anima() -> bool:
|
||||
return bool(SD_UNET) and not SD_CHECKPOINT
|
||||
|
||||
|
||||
def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]:
|
||||
if "\n\nNegative prompt:" in full_prompt:
|
||||
@@ -32,26 +45,44 @@ def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]:
|
||||
|
||||
|
||||
def _build_workflow(positive: str, negative: str) -> dict:
|
||||
"""Minimal KSampler workflow for ComfyUI API."""
|
||||
seed = int(uuid.uuid4().int % 2**32)
|
||||
if _use_anima():
|
||||
return {
|
||||
"44": {"class_type": "UNETLoader", "inputs": {"unet_name": SD_UNET, "weight_dtype": "default"}},
|
||||
"45": {"class_type": "CLIPLoader", "inputs": {"clip_name": SD_CLIP, "type": "stable_diffusion", "device": "default"}},
|
||||
"15": {"class_type": "VAELoader", "inputs": {"vae_name": SD_VAE}},
|
||||
"28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 1024, "batch_size": 1}},
|
||||
"11": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["45", 0]}},
|
||||
"12": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["45", 0]}},
|
||||
"19": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": ["44", 0], "positive": ["11", 0], "negative": ["12", 0],
|
||||
"latent_image": ["28", 0], "seed": seed,
|
||||
"steps": SD_STEPS, "cfg": SD_CFG,
|
||||
"sampler_name": os.getenv("SD_SAMPLER", "er_sde"),
|
||||
"scheduler": os.getenv("SD_SCHEDULER", "simple"),
|
||||
"denoise": 1.0,
|
||||
},
|
||||
},
|
||||
"8": {"class_type": "VAEDecode", "inputs": {"samples": ["19", 0], "vae": ["15", 0]}},
|
||||
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
|
||||
}
|
||||
# Standard checkpoint workflow (Pony / SDXL)
|
||||
return {
|
||||
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": SD_CHECKPOINT}},
|
||||
"5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}},
|
||||
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}},
|
||||
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}},
|
||||
"8": {"class_type": "VAEDecode", "inputs": {"samples": ["10", 0], "vae": ["4", 2]}},
|
||||
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
|
||||
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": SD_CHECKPOINT}},
|
||||
"5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}},
|
||||
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}},
|
||||
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}},
|
||||
"8": {"class_type": "VAEDecode", "inputs": {"samples": ["10", 0], "vae": ["4", 2]}},
|
||||
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
|
||||
"10": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": ["4", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["7", 0],
|
||||
"latent_image": ["5", 0],
|
||||
"seed": int(uuid.uuid4().int % 2**32),
|
||||
"steps": SD_STEPS,
|
||||
"cfg": SD_CFG,
|
||||
"sampler_name": SD_SAMPLER,
|
||||
"scheduler": SD_SCHEDULER,
|
||||
"model": ["4", 0], "positive": ["6", 0], "negative": ["7", 0],
|
||||
"latent_image": ["5", 0], "seed": seed,
|
||||
"steps": SD_STEPS, "cfg": SD_CFG,
|
||||
"sampler_name": SD_SAMPLER, "scheduler": SD_SCHEDULER,
|
||||
"denoise": 1.0,
|
||||
},
|
||||
},
|
||||
@@ -74,7 +105,6 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
|
||||
logger.info("ComfyUI request → %s prompt: %.120s", SD_BASE_URL, prompt)
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
# queue the prompt
|
||||
resp = await client.post(
|
||||
f"{SD_BASE_URL}/prompt",
|
||||
json={"prompt": workflow, "client_id": client_id},
|
||||
@@ -83,14 +113,17 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
prompt_id = resp.json()["prompt_id"]
|
||||
logger.info("ComfyUI queued prompt_id=%s", prompt_id)
|
||||
|
||||
# poll until done
|
||||
for _ in range(300):
|
||||
await asyncio.sleep(1)
|
||||
hist = await client.get(f"{SD_BASE_URL}/history/{prompt_id}")
|
||||
data = hist.json()
|
||||
if prompt_id in data:
|
||||
outputs = data[prompt_id]["outputs"]
|
||||
# find first image output
|
||||
entry = data[prompt_id]
|
||||
# Log any errors from ComfyUI
|
||||
if entry.get("status", {}).get("status_str") == "error":
|
||||
msgs = entry.get("status", {}).get("messages", [])
|
||||
logger.error("ComfyUI workflow error: %s", msgs)
|
||||
outputs = entry.get("outputs", {})
|
||||
for node_output in outputs.values():
|
||||
if "images" in node_output:
|
||||
img_info = node_output["images"][0]
|
||||
@@ -100,12 +133,13 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte
|
||||
)
|
||||
img_resp.raise_for_status()
|
||||
image_bytes = img_resp.content
|
||||
|
||||
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
filename = f"{uuid.uuid4().hex}.png"
|
||||
(IMAGES_DIR / filename).write_bytes(image_bytes)
|
||||
logger.info("ComfyUI done → saved %s", filename)
|
||||
return image_bytes, f"images/{filename}"
|
||||
logger.error("ComfyUI no image output. status=%s outputs_keys=%s",
|
||||
entry.get("status"), list(outputs.keys()))
|
||||
break
|
||||
|
||||
raise RuntimeError("ComfyUI generation timed out or produced no output")
|
||||
|
||||
Reference in New Issue
Block a user