diff --git a/database/db.py b/database/db.py index 38f7a49..1667631 100644 --- a/database/db.py +++ b/database/db.py @@ -86,6 +86,10 @@ async def _migrate_messages_columns(db): await db.execute("ALTER TABLE messages ADD COLUMN image_prompt TEXT") if "image_path" not in cols: await db.execute("ALTER TABLE messages ADD COLUMN image_path TEXT") + if "image_prompt_alt" not in cols: + await db.execute("ALTER TABLE messages ADD COLUMN image_prompt_alt TEXT") + if "image_path_alt" not in cols: + await db.execute("ALTER TABLE messages ADD COLUMN image_path_alt TEXT") async def _migrate_personas_columns(db): @@ -105,6 +109,8 @@ async def _migrate_personas_columns(db): await db.execute("ALTER TABLE personas ADD COLUMN avatar_path TEXT DEFAULT ''") if "alternate_greetings_json" not in cols: await db.execute("ALTER TABLE personas ADD COLUMN alternate_greetings_json TEXT DEFAULT '[]'") + if "appearance_prose" not in cols: + await db.execute("ALTER TABLE personas ADD COLUMN appearance_prose TEXT DEFAULT ''") async def _migrate_sessions_columns(db): @@ -170,3 +176,5 @@ async def _migrate_characters_columns(db): await db.execute("ALTER TABLE characters ADD COLUMN avatar_path TEXT DEFAULT ''") if "alternate_greetings_json" not in cols: await db.execute("ALTER TABLE characters ADD COLUMN alternate_greetings_json TEXT DEFAULT '[]'") + if "appearance_prose" not in cols: + await db.execute("ALTER TABLE characters ADD COLUMN appearance_prose TEXT DEFAULT ''") diff --git a/main.py b/main.py index 802a95a..869840c 100644 --- a/main.py +++ b/main.py @@ -3,9 +3,10 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse -from routers import chat, personas, sessions, characters, images, translate +from routers import chat, personas, sessions, characters, images, translate, debug from database.db import init_db from services.persona_seed import seed_default_personas +from services.system_message_migration import migrate_static_system_messages logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s") @@ -14,6 +15,7 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(messag async def lifespan(app: FastAPI): await init_db() await seed_default_personas() + await migrate_static_system_messages() yield @@ -25,6 +27,7 @@ app.include_router(sessions.router) app.include_router(characters.router) app.include_router(images.router) app.include_router(translate.router) +app.include_router(debug.router) app.mount("/static", StaticFiles(directory="static"), name="static") @@ -34,6 +37,11 @@ async def root(): return FileResponse("static/index.html") +@app.get("/debug") +async def debug_page(): + return FileResponse("static/debug.html") + + @app.get("/health") async def health(): return {"status": "ok"} diff --git a/main_carrie-just-a-friendly-hug-3ccb4b5342bb_spec_v2.png b/main_carrie-just-a-friendly-hug-3ccb4b5342bb_spec_v2.png new file mode 100644 index 0000000..8ef6eb0 Binary files /dev/null and b/main_carrie-just-a-friendly-hug-3ccb4b5342bb_spec_v2.png differ diff --git a/main_clingy-obsessive-girlfriend-af26ead7_spec_v2.png b/main_clingy-obsessive-girlfriend-af26ead7_spec_v2.png new file mode 100644 index 0000000..058ce48 Binary files /dev/null and b/main_clingy-obsessive-girlfriend-af26ead7_spec_v2.png differ diff --git a/main_delta-125aa7a6_spec_v2.png b/main_delta-125aa7a6_spec_v2.png new file mode 100644 index 0000000..8f3f8b4 Binary files /dev/null and b/main_delta-125aa7a6_spec_v2.png differ diff --git a/main_violet-merino-d2e9f62b5d77_spec_v2.png b/main_violet-merino-d2e9f62b5d77_spec_v2.png new file mode 100644 index 0000000..da57117 Binary files /dev/null and b/main_violet-merino-d2e9f62b5d77_spec_v2.png differ diff --git a/main_vulpisfoglia-e0a6befda921_spec_v2.png b/main_vulpisfoglia-e0a6befda921_spec_v2.png new file mode 100644 index 0000000..9b52d48 Binary files /dev/null and b/main_vulpisfoglia-e0a6befda921_spec_v2.png differ diff --git a/main_your-scumbag-superheroine-friend-c65bf1fe881c_spec_v2.png b/main_your-scumbag-superheroine-friend-c65bf1fe881c_spec_v2.png new file mode 100644 index 0000000..fcd461e Binary files /dev/null and b/main_your-scumbag-superheroine-friend-c65bf1fe881c_spec_v2.png differ diff --git a/models/schemas.py b/models/schemas.py index 6ed48a4..6587501 100644 --- a/models/schemas.py +++ b/models/schemas.py @@ -24,6 +24,11 @@ class RegenerateRequest(BaseModel): class ForkSessionRequest(BaseModel): until_message_id: int + +class RebindPersonaRequest(BaseModel): + persona_id: str + clear_history: bool = False + class ChatResponse(BaseModel): reply: str session_id: str diff --git a/routers/characters.py b/routers/characters.py index 22c2036..27e349b 100644 --- a/routers/characters.py +++ b/routers/characters.py @@ -22,6 +22,7 @@ class CardPatch(BaseModel): first_mes: Optional[str] = None mes_example: Optional[str] = None appearance_tags: Optional[str] = None + appearance_prose: Optional[str] = None lora_name: Optional[str] = None lora_weight: Optional[float] = None alternate_greetings_json: Optional[str] = None diff --git a/routers/chat.py b/routers/chat.py index 02988a6..2b49a9a 100644 --- a/routers/chat.py +++ b/routers/chat.py @@ -3,14 +3,12 @@ import logging import os import random -import aiosqlite from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel -from database.db import DB_PATH from models.schemas import ChatRequest, ChatResponse, MessageEditRequest, RegenerateRequest -from services.llm import send_message, stream_message +from services.llm import LLMError, send_message, stream_message from services.memory import ( get_history, add_message, @@ -18,7 +16,6 @@ from services.memory import ( get_or_create_session, get_session, update_session_title, - update_session_persona, get_message_count, get_last_assistant_message_id, update_message_image, @@ -26,7 +23,6 @@ from services.memory import ( update_session_status_quo, update_session_affinity, update_session_genre, - update_session_rpg_settings, update_session_outfit, update_session_plot_arc, upsert_quest, @@ -36,20 +32,23 @@ from services.memory import ( update_message_content, delete_messages_after, delete_message, + upsert_static_system_message, ) from services.personas import get_persona +from services.chat_prompt import get_system_prompt, DEFAULT_PROMPT +from services.session_identity import resolve_session_persona from services.sd_prompt import generate_sd_prompt, strip_image_prompt_tag, extract_image_prompt_tag -from services.lorebook import get_lorebook_context +from services.sd_images import run_sd_for_message from services.character_card import get_character from services import sdbackend as sd_service from services.rpg_facts import extract_facts, merge_facts, facts_to_prompt from services.rpg_plot import generate_plot_arc, should_advance_arc, pop_matching_beats, advance_phase from services.rpg_narrator import narrator_pre, narrator_post +from services.opening import ensure_plot_arc_and_quests, resolve_greeting, process_opening logger = logging.getLogger(__name__) router = APIRouter(prefix="/chat", tags=["chat"]) -DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу." SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes") DEFAULT_RPG_SETTINGS = {"dice": True, "narrator": True, "quests": True, "affinity": True, "choices": True} @@ -72,24 +71,20 @@ def affinity_prompt_block(affinity: int) -> str: return f"\n\n--- Relationship ---\nAffinity toward player: {affinity} ({tone}). Reflect this in your attitude and word choice.\n---" -async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str: - persona = await get_persona(persona_id) - if not persona: - return DEFAULT_PROMPT - prompt = persona["prompt"] - recent = [m for m in history if m["role"] in ("user", "assistant")][-5:] - context = recent + [{"role": "user", "content": user_message}] - if persona.get("lorebook_json"): - lore = get_lorebook_context(persona.get("lorebook_json", "[]"), context) - if lore: - prompt += "\n\n" + lore - if persona_id.startswith("card_"): - card = await get_character(persona_id[5:]) - if card: - lore = get_lorebook_context(card.get("lorebook_json", "[]"), context) - if lore: - prompt += "\n\n" + lore - return prompt +def messages_for_llm(history: list, llm_system_content: str) -> list[dict]: + """Build LLM payload: one system message (static + runtime), no duplicate system rows.""" + out: list[dict] = [] + system_used = False + for m in history: + if m["role"] == "system": + if not system_used: + out.append({"role": "system", "content": llm_system_content}) + system_used = True + else: + out.append({"role": m["role"], "content": m["content"]}) + if not system_used: + out.insert(0, {"role": "system", "content": llm_system_content}) + return out @router.get("/history/{session_id}") @@ -100,11 +95,18 @@ async def get_chat_history(session_id: str): @router.get("/system/{session_id}") async def get_system_blob(session_id: str): history = await get_history(session_id) - system_msg = next((m for m in history if m.get("role") == "system"), None) session = await get_session(session_id) + persona_id = (session.get("persona_id") if session else None) or "default" + persona = await get_persona(persona_id) or {} + system_msg = next((m for m in history if m.get("role") == "system"), None) + stored = system_msg.get("content") if system_msg else "" + live_static = await get_system_prompt(persona_id, history, "") + system_prompt = live_static if live_static else stored quests = await get_quests(session_id) return { - "system_prompt": system_msg.get("content") if system_msg else "", + "persona_id": persona_id, + "persona_name": persona.get("name", persona_id), + "system_prompt": system_prompt, "status_quo": session.get("status_quo") if session else "", "facts_json": session.get("facts_json") if session else "[]", "plot_arc_json": session.get("plot_arc_json") if session else "{}", @@ -119,14 +121,21 @@ async def get_system_blob(session_id: str): @router.post("/init") async def init_chat(request: ChatRequest): - persona_id = request.persona_id or "default" - await get_or_create_session(request.session_id, persona_id) + await get_or_create_session( + request.session_id, + request.persona_id or "default", + ) + persona_id = await resolve_session_persona( + request.session_id, + request.persona_id, + create_persona=request.persona_id, + ) history = await get_history(request.session_id) if history: return {"first_mes": None} system_prompt = await get_system_prompt(persona_id, [], "") - await add_message(request.session_id, "system", system_prompt) + await upsert_static_system_message(request.session_id, system_prompt, []) first_mes = None if request.first_mes_override and request.first_mes_override.strip(): @@ -152,53 +161,47 @@ class RpgBootstrapRequest(BaseModel): genre: str = "adventure" +class OpeningProcessRequest(BaseModel): + session_id: str + persona_id: str = "default" + rpg: bool = False + + +@router.post("/opening/process") +async def opening_process(req: OpeningProcessRequest): + await get_or_create_session(req.session_id, req.persona_id) + persona_id = await resolve_session_persona(req.session_id, req.persona_id) + try: + return await process_opening(req.session_id, persona_id, rpg=req.rpg) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + @router.post("/rpg/bootstrap") async def rpg_bootstrap(req: RpgBootstrapRequest): await get_or_create_session(req.session_id, req.persona_id) - session = await get_session(req.session_id) - persona = await get_persona(req.persona_id) or {} - - # Save genre + persona_id = await resolve_session_persona(req.session_id, req.persona_id) await update_session_genre(req.session_id, req.genre) - - arc_json = (session.get("plot_arc_json") or "{}") if session else "{}" - try: - arc = json.loads(arc_json) if isinstance(arc_json, str) else {} - except Exception: - arc = {} - if not arc: - facts_block = facts_to_prompt((session or {}).get("facts_json", "[]")) - arc = await generate_plot_arc( - persona.get("name", req.persona_id), - persona.get("description", ""), - persona.get("scenario", ""), - persona.get("first_mes", ""), - facts_block=facts_block, - genre=req.genre, - ) - if arc: - from services.memory import update_session_plot_arc - await update_session_plot_arc(req.session_id, json.dumps(arc, ensure_ascii=False)) - - # Seed quests from beats - for beat in arc.get("beats", []): - title = (beat.get("title") or beat.get("injection", "")).strip() - if title: - await upsert_quest(req.session_id, title[:120]) - + persona = await get_persona(persona_id) or {} + greeting = await resolve_greeting(req.session_id, persona) + arc = await ensure_plot_arc_and_quests(req.session_id, persona, greeting, req.genre) quests = await get_quests(req.session_id) return {"plot_arc": arc, "quests": quests} @router.post("/stream") async def chat_stream(request: ChatRequest): - persona_id = request.persona_id or "default" - - await get_or_create_session(request.session_id, persona_id) + await get_or_create_session(request.session_id, request.persona_id) + persona_id = await resolve_session_persona( + request.session_id, + request.persona_id, + create_persona=request.persona_id, + ) history = await get_history(request.session_id) session = await get_session(request.session_id) - system_prompt = await get_system_prompt(persona_id, history, request.message) + static_prompt = await get_system_prompt(persona_id, history, request.message) + runtime_suffix = "" arc = {} roll = None @@ -206,26 +209,27 @@ async def chat_stream(request: ChatRequest): resolution_text = "" narrator_msg = None # shown as narrator bubble before assistant reply rpg_settings = {} + facts_block = "" if session and session.get("rpg_enabled"): rpg_settings = get_rpg_settings(session) facts_block = facts_to_prompt(session.get("facts_json", "[]")) if facts_block: - system_prompt = system_prompt + "\n\n" + facts_block + runtime_suffix += "\n\n" + facts_block try: arc = json.loads(session.get("plot_arc_json") or "{}") except Exception: arc = {} if arc: - system_prompt = system_prompt + "\n\n--- PlotArc ---\n" + json.dumps( + runtime_suffix += "\n\n--- PlotArc ---\n" + json.dumps( {k: arc.get(k) for k in ("title", "phase", "next_beat_hint")}, ensure_ascii=False ) + "\n---" status_quo = (session.get("status_quo") or "").strip() if status_quo: - system_prompt = system_prompt + "\n\n--- Status quo ---\n" + status_quo + "\n---" + runtime_suffix += "\n\n--- Status quo ---\n" + status_quo + "\n---" if rpg_settings.get("affinity", True): aff = int(session.get("affinity") or 0) - system_prompt = system_prompt + affinity_prompt_block(aff) + runtime_suffix += affinity_prompt_block(aff) if rpg_settings.get("narrator", True): persona = await get_persona(persona_id) or {} @@ -274,7 +278,7 @@ async def chat_stream(request: ChatRequest): pre_sq = (pre.get("status_quo_update") or "").strip() if directives: - system_prompt = system_prompt + "\n\n--- Narrator directives ---\n" + "\n".join(f"- {d}" for d in directives) + "\n---" + runtime_suffix += "\n\n--- Narrator directives ---\n" + "\n".join(f"- {d}" for d in directives) + "\n---" if pre_sq: await update_session_status_quo(request.session_id, pre_sq) session["status_quo"] = pre_sq @@ -290,50 +294,37 @@ async def chat_stream(request: ChatRequest): ) narrator_msg = {"roll": roll, "outcome": outcome, "text": resolution_text} - # Inject outcome into system prompt so character reply is consistent if roll is not None: - system_prompt = ( - system_prompt - + f"\n\n--- Mechanics ---\n" - + f"Roll d20={roll}. Outcome: {outcome}.\n" + runtime_suffix += ( + f"\n\n--- Mechanics ---\n" + f"Roll d20={roll}. Outcome: {outcome}.\n" + "Your reply MUST be consistent with this outcome. Do NOT contradict the narrator resolution.\n" + "---" ) - # is_narrator_choice: wrap message so LLM understands context + llm_system = static_prompt + runtime_suffix + user_message_content = request.message if request.is_narrator_choice: user_message_content = f"[Player chose: {request.message}]" - if not history: - await add_message(request.session_id, "system", system_prompt) - elif history[0]["role"] == "system" and history[0]["content"] != system_prompt: - async with aiosqlite.connect(DB_PATH) as db: - await db.execute( - """UPDATE messages SET content = ? - WHERE session_id = ? AND role = 'system' - AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""", - (system_prompt, request.session_id, request.session_id), - ) - await db.commit() + await upsert_static_system_message(request.session_id, static_prompt, history) if not request.skip_user_add: await add_message(request.session_id, "user", user_message_content) messages = await get_history(request.session_id) + llm_messages = messages_for_llm(messages, llm_system) full_reply = [] async def generate(): nonlocal arc - # Send narrator BEFORE streaming so it appears above the reply if narrator_msg: yield f"data: {json.dumps({'narrator': narrator_msg})}\n\n" try: - async for chunk in stream_message( - [{"role": m["role"], "content": m["content"]} for m in messages] - ): + async for chunk in stream_message(llm_messages): full_reply.append(chunk) yield f"data: {json.dumps({'chunk': chunk})}\n\n" except Exception as e: @@ -344,97 +335,111 @@ async def chat_stream(request: ChatRequest): complete = "".join(full_reply) display_text = strip_image_prompt_tag(complete) - hist_with_reply = await get_history(request.session_id) + [ - {"role": "assistant", "content": display_text} - ] - sd_result = await generate_sd_prompt( - hist_with_reply, persona_id, - outfit_json=session.get("outfit_json", "[]") if session else "[]" - ) - prompt_str = (sd_result[0] if sd_result and sd_result[0] else None) or extract_image_prompt_tag(complete) - if (display_text or complete).strip(): - await add_message(request.session_id, "assistant", display_text or complete, image_prompt=prompt_str) + await add_message(request.session_id, "assistant", display_text or complete) choices = [] debug_blocks = [] quests_updated = [] if session and session.get("rpg_enabled"): - if not arc: + try: + if not arc: + persona = await get_persona(persona_id) or {} + arc = await generate_plot_arc( + persona.get("name", persona_id), + persona.get("description", ""), + persona.get("scenario", ""), + persona.get("first_mes", ""), + facts_block=facts_to_prompt(session.get("facts_json", "[]")), + genre=session.get("genre") or "adventure", + ) + if arc: + await update_session_plot_arc( + request.session_id, json.dumps(arc, ensure_ascii=False) + ) + debug_blocks.append({ + "type": "plot_arc", + "text": json.dumps(arc, ensure_ascii=False, indent=2), + }) + if rpg_settings.get("quests", True): + for beat in arc.get("beats", []): + t = (beat.get("title") or beat.get("injection", "")).strip() + if t: + await upsert_quest(request.session_id, t[:120]) + + trig = should_advance_arc(request.message) + if trig and arc: + arc, beats = pop_matching_beats(arc, trig, max_beats=1) + if beats: + await update_session_plot_arc( + request.session_id, json.dumps(arc, ensure_ascii=False) + ) + inj = beats[0].get("injection", "") + if inj: + debug_blocks.append({"type": "narrator_injection", "text": inj}) + if rpg_settings.get("choices", True): + choices += beats[0].get("choices") or [] + if advance_phase(arc): + await update_session_plot_arc( + request.session_id, json.dumps(arc, ensure_ascii=False) + ) + debug_blocks.append({"type": "phase_advance", "text": arc["phase"]}) + + ctx = [ + m for m in (await get_history(request.session_id)) + if m["role"] in ("user", "assistant") + ][-10:] + new_facts = await extract_facts(ctx) + if new_facts: + merged = merge_facts(session.get("facts_json", "[]"), new_facts) + await update_session_facts(request.session_id, merged) + session["facts_json"] = merged + persona = await get_persona(persona_id) or {} - arc = await generate_plot_arc( - persona.get("name", persona_id), - persona.get("description", ""), - persona.get("scenario", ""), - persona.get("first_mes", ""), - facts_block=facts_to_prompt(session.get("facts_json", "[]")), - genre=session.get("genre") or "adventure", + ctx_txt = "\n".join( + f"{m['role']}: {m['content']}" + for m in ctx[-8:] + if m.get("role") in ("user", "assistant") + ) + post = await narrator_post( + persona.get("name", persona_id), + ctx_txt, + json.dumps(arc, ensure_ascii=False) if arc else "", + facts_to_prompt(session.get("facts_json", "[]")), ) - if arc: - await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False)) - debug_blocks.append({"type": "plot_arc", "text": json.dumps(arc, ensure_ascii=False, indent=2)}) - if rpg_settings.get("quests", True): - for beat in arc.get("beats", []): - t = (beat.get("title") or beat.get("injection", "")).strip() - if t: - await upsert_quest(request.session_id, t[:120]) - trig = should_advance_arc(request.message) - if trig and arc: - arc, beats = pop_matching_beats(arc, trig, max_beats=1) - if beats: - await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False)) - inj = beats[0].get("injection", "") - if inj: - debug_blocks.append({"type": "narrator_injection", "text": inj}) - if rpg_settings.get("choices", True): - choices += beats[0].get("choices") or [] - if advance_phase(arc): - await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False)) - debug_blocks.append({"type": "phase_advance", "text": arc["phase"]}) + sq = (post.get("status_quo_update") or "").strip() + if sq: + await update_session_status_quo(request.session_id, sq) + debug_blocks.append({"type": "status_quo", "text": sq}) - ctx = [m for m in (await get_history(request.session_id)) if m["role"] in ("user", "assistant")][-10:] - new_facts = await extract_facts(ctx) - if new_facts: - merged = merge_facts(session.get("facts_json", "[]"), new_facts) - await update_session_facts(request.session_id, merged) - session["facts_json"] = merged + if rpg_settings.get("choices", True): + choices += post.get("choices") or [] - persona = await get_persona(persona_id) or {} - ctx_txt = "\n".join(f"{m['role']}: {m['content']}" for m in ctx[-8:] if m.get("role") in ("user", "assistant")) - post = await narrator_post( - persona.get("name", persona_id), - ctx_txt, - json.dumps(arc, ensure_ascii=False) if arc else "", - facts_to_prompt(session.get("facts_json", "[]")), - ) + if rpg_settings.get("affinity", True): + delta = int(post.get("affinity_delta") or 0) + if delta: + await update_session_affinity(request.session_id, delta) - sq = (post.get("status_quo_update") or "").strip() - if sq: - await update_session_status_quo(request.session_id, sq) - debug_blocks.append({"type": "status_quo", "text": sq}) + outfit_update = post.get("outfit_update") + if isinstance(outfit_update, list) and outfit_update: + outfit_str = json.dumps(outfit_update, ensure_ascii=False) + await update_session_outfit(request.session_id, outfit_str) + session["outfit_json"] = outfit_str - if rpg_settings.get("choices", True): - choices += post.get("choices") or [] - - if rpg_settings.get("affinity", True): - delta = int(post.get("affinity_delta") or 0) - if delta: - await update_session_affinity(request.session_id, delta) - - outfit_update = post.get("outfit_update") - if isinstance(outfit_update, list) and outfit_update: - outfit_str = json.dumps(outfit_update, ensure_ascii=False) - await update_session_outfit(request.session_id, outfit_str) - session["outfit_json"] = outfit_str - - if rpg_settings.get("quests", True): - for qu in (post.get("quest_updates") or []): - t = (qu.get("title") or "").strip() - if t: - await upsert_quest(request.session_id, t[:120], qu.get("status", "active")) - quests_updated = await get_quests(request.session_id) + if rpg_settings.get("quests", True): + for qu in (post.get("quest_updates") or []): + t = (qu.get("title") or "").strip() + if t: + await upsert_quest( + request.session_id, t[:120], qu.get("status", "active") + ) + quests_updated = await get_quests(request.session_id) + except LLMError as e: + logger.warning("RPG post-process skipped after reply: %s", e) + except Exception as e: + logger.exception("RPG post-process failed after reply: %s", e) count = await get_message_count(request.session_id) if count == 2 and not request.skip_user_add: @@ -443,23 +448,50 @@ async def chat_stream(request: ChatRequest): if (session or {}).get("title", "Новый чат") in ("", "Новый чат"): await update_session_title(request.session_id, f"{persona.get('name', persona_id)} — {preview}") - image_path = None - image_error = None - if prompt_str and SD_AUTO_GENERATE: + updated_session = await get_session(request.session_id) or session + hist = await get_history(request.session_id) + bundle = await generate_sd_prompt( + hist, + persona_id, + outfit_json=updated_session.get("outfit_json", "[]") if updated_session else "[]", + ) + prompt_str = bundle.tag_full if bundle else extract_image_prompt_tag(complete) + msg_id = await get_last_assistant_message_id(request.session_id) + + sd_out: dict = {} + if bundle: + yield f"data: {json.dumps({ + 'image_generating': True, + 'image_prompt': bundle.tag_full, + 'image_prompt_alt': bundle.desc_full, + })}\n\n" + sd_out = await run_sd_for_message(bundle, msg_id) + elif prompt_str and SD_AUTO_GENERATE: yield f"data: {json.dumps({'image_generating': True, 'image_prompt': prompt_str})}\n\n" rel, err = await sd_service.generate_from_full_prompt(prompt_str) if rel: - image_path = rel - msg_id = await get_last_assistant_message_id(request.session_id) + sd_out["image_path"] = f"/static/{rel}" if msg_id: await update_message_image(msg_id, rel) else: - image_error = err + sd_out["image_error"] = err + sd_out["image_prompt"] = prompt_str - updated_session = await get_session(request.session_id) affinity = updated_session.get("affinity", 0) if updated_session else 0 - yield f"data: {json.dumps({'done': True, 'image_prompt': prompt_str, 'image_path': f'/static/{image_path}' if image_path else None, 'image_error': image_error, 'choices': choices, 'debug': debug_blocks, 'affinity': affinity, 'quests': quests_updated})}\n\n" + yield f"data: {json.dumps({ + 'done': True, + 'image_prompt': sd_out.get('image_prompt') or prompt_str, + 'image_prompt_alt': sd_out.get('image_prompt_alt'), + 'image_path': sd_out.get('image_path'), + 'image_path_alt': sd_out.get('image_path_alt'), + 'image_error': sd_out.get('image_error'), + 'image_error_alt': sd_out.get('image_error_alt'), + 'choices': choices, + 'debug': debug_blocks, + 'affinity': affinity, + 'quests': quests_updated, + })}\n\n" return StreamingResponse( generate(), @@ -470,23 +502,24 @@ async def chat_stream(request: ChatRequest): @router.post("/", response_model=ChatResponse) async def chat(request: ChatRequest): - persona_id = request.persona_id or "default" - await get_or_create_session(request.session_id, persona_id) + await get_or_create_session(request.session_id, request.persona_id) + persona_id = await resolve_session_persona( + request.session_id, + request.persona_id, + create_persona=request.persona_id, + ) history = await get_history(request.session_id) - system_prompt = await get_system_prompt(persona_id, history, request.message) - - if not history: - await add_message(request.session_id, "system", system_prompt) + static_prompt = await get_system_prompt(persona_id, history, request.message) + await upsert_static_system_message(request.session_id, static_prompt, history) await add_message(request.session_id, "user", request.message) messages = await get_history(request.session_id) - reply = await send_message( - [{"role": m["role"], "content": m["content"]} for m in messages] - ) + llm_messages = messages_for_llm(messages, static_prompt) + reply = await send_message(llm_messages) display = strip_image_prompt_tag(reply) - prompt_tuple = await generate_sd_prompt(messages, persona_id) - prompt_str = prompt_tuple[0] if prompt_tuple else extract_image_prompt_tag(reply) + bundle = await generate_sd_prompt(messages, persona_id) + prompt_str = bundle.tag_full if bundle else extract_image_prompt_tag(reply) await add_message(request.session_id, "assistant", display, image_prompt=prompt_str) @@ -527,7 +560,6 @@ async def regenerate_chat(req: RegenerateRequest): stream_req = ChatRequest( message=user_text, session_id=req.session_id, - persona_id=req.persona_id, skip_user_add=True, ) return await chat_stream(stream_req) diff --git a/routers/debug.py b/routers/debug.py new file mode 100644 index 0000000..af4532f --- /dev/null +++ b/routers/debug.py @@ -0,0 +1,248 @@ +import json +import os + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from services import sdbackend as sd_service +from services.comfy_models import list_node_types, parse_model_lists +from services.llm import ( + CHAT_MODEL, + LLM_FALLBACK_MODEL, + LLMError, + SYSTEM_MODEL, + send_message, + send_message_with_model, +) +from services.personas import get_all_personas +from services.sd_prompt import ( + SD_PROMPT_MODEL, + anima_dual_enabled, + run_prompt_builder, +) + +router = APIRouter(prefix="/debug", tags=["debug"]) + + +class ChatMessage(BaseModel): + role: str + content: str + + +class SdPromptDebugRequest(BaseModel): + persona_id: str = "default" + chat_excerpt: str = "" + messages: list[ChatMessage] | None = None + outfit_json: str = "[]" + appearance_override: str | None = None + use_prose: bool = False + + +class LlmDebugRequest(BaseModel): + model: str = "" + system: str = "" + user: str = "" + messages: list[ChatMessage] | None = None + + +class ComfyRawRequest(BaseModel): + method: str = "GET" + path: str = "/system_stats" + params_json: str = "{}" + body_json: str = "" + + +class ComfyGenerateRequest(BaseModel): + positive: str + negative: str = "" + unet: str | None = None + clip: str | None = None + vae: str | None = None + checkpoint: str | None = None + + +@router.get("/config") +async def debug_config(): + base = sd_service.SD_BASE_URL + return { + "chat_model": CHAT_MODEL, + "system_model": SYSTEM_MODEL, + "llm_fallback_model": LLM_FALLBACK_MODEL, + "sd_prompt_model": SD_PROMPT_MODEL or SYSTEM_MODEL, + "sd_base_url": base, + "sd_has_token": bool(sd_service.SD_QUERY_PARAMS.get("token")), + "sd_anima_dual": anima_dual_enabled(), + "sd_unet": sd_service.SD_UNET, + "sd_clip": sd_service.SD_CLIP, + "sd_vae": sd_service.SD_VAE, + "sd_checkpoint": sd_service.SD_CHECKPOINT, + "sd_steps": sd_service.SD_STEPS, + "sd_cfg": sd_service.SD_CFG, + "router_key_set": bool(os.getenv("ROUTER_KEY")), + } + + +@router.get("/personas") +async def debug_personas(): + personas = await get_all_personas() + return [ + { + "persona_id": pid, + "name": p.get("name", pid), + "appearance_tags": p.get("appearance_tags", ""), + } + for pid, p in personas.items() + ] + + +@router.post("/sd-prompt") +async def debug_sd_prompt(req: SdPromptDebugRequest): + msgs = None + if req.messages: + msgs = [m.model_dump() for m in req.messages] + return await run_prompt_builder( + req.persona_id, + messages=msgs, + chat_excerpt=req.chat_excerpt, + outfit_json=req.outfit_json, + appearance_override=req.appearance_override, + use_prose=req.use_prose, + ) + + +@router.post("/llm") +async def debug_llm(req: LlmDebugRequest): + if req.messages: + messages = [m.model_dump() for m in req.messages] + else: + messages = [] + if req.system.strip(): + messages.append({"role": "system", "content": req.system.strip()}) + if req.user.strip(): + messages.append({"role": "user", "content": req.user.strip()}) + if not messages: + raise HTTPException(status_code=400, detail="Нужны messages или system/user") + + model = (req.model or "").strip() or SD_PROMPT_MODEL or SYSTEM_MODEL + try: + if model in (SYSTEM_MODEL, "") and not req.model: + text = await send_message(messages) + else: + text = await send_message_with_model(messages, model) + return {"model": model, "response": text} + except LLMError as e: + raise HTTPException(status_code=502, detail=str(e)) + + +@router.get("/comfy/ping") +async def debug_comfy_ping(): + try: + status, body, headers = await sd_service.comfy_api_request("GET", "/system_stats") + return {"ok": status == 200, "status": status, "body": body, "headers": headers} + except Exception as e: + return {"ok": False, "error": str(e)} + + +@router.get("/comfy/models") +async def debug_comfy_models(): + try: + info = await sd_service.fetch_object_info() + return { + "models": parse_model_lists(info), + "configured": { + "unet": sd_service.SD_UNET, + "clip": sd_service.SD_CLIP, + "vae": sd_service.SD_VAE, + "checkpoint": sd_service.SD_CHECKPOINT, + }, + "node_type_count": len(list_node_types(info)), + } + except Exception as e: + raise HTTPException(status_code=502, detail=str(e)) + + +@router.get("/comfy/object_info") +async def debug_comfy_object_info(node: str | None = None): + try: + info = await sd_service.fetch_object_info() + if node: + if node not in info: + raise HTTPException(status_code=404, detail=f"Unknown node: {node}") + return {node: info[node]} + return { + "node_types": list_node_types(info), + "models": parse_model_lists(info), + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=502, detail=str(e)) + + +@router.post("/comfy/raw") +async def debug_comfy_raw(req: ComfyRawRequest): + path = req.path.strip() + if not path.startswith("/"): + path = "/" + path + try: + params = json.loads(req.params_json or "{}") + if not isinstance(params, dict): + raise ValueError("params_json must be object") + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"params_json: {e}") + + body = None + if req.body_json.strip(): + try: + body = json.loads(req.body_json) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"body_json: {e}") + + method = req.method.upper() + if method not in ("GET", "POST", "PUT", "DELETE"): + raise HTTPException(status_code=400, detail="method must be GET|POST|PUT|DELETE") + + try: + status, resp_body, headers = await sd_service.comfy_api_request( + method, + path, + params=params or None, + json_body=body, + timeout=120, + ) + return {"status": status, "headers": headers, "body": resp_body} + except Exception as e: + raise HTTPException(status_code=502, detail=str(e)) + + +@router.post("/comfy/generate") +async def debug_comfy_generate(req: ComfyGenerateRequest): + if not req.positive.strip(): + raise HTTPException(status_code=400, detail="positive required") + + overrides: dict[str, str] = {} + if req.unet: + overrides["unet"] = req.unet + if req.clip: + overrides["clip"] = req.clip + if req.vae: + overrides["vae"] = req.vae + if req.checkpoint: + overrides["checkpoint"] = req.checkpoint + + full = req.positive.strip() + if req.negative.strip(): + full += f"\n\nNegative prompt: {req.negative.strip()}" + + try: + rel, err = await sd_service.generate_from_full_prompt( + full, + overrides=overrides or None, + ) + if not rel: + raise HTTPException(status_code=502, detail=err or "generation failed") + return {"image_path": f"/static/{rel}", "status": "ok"} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=502, detail=str(e)) diff --git a/routers/personas.py b/routers/personas.py index 750a90b..1ce2e7a 100644 --- a/routers/personas.py +++ b/routers/personas.py @@ -57,6 +57,7 @@ class PersonaPatch(BaseModel): lora_name: Optional[str] = None lora_weight: Optional[float] = None appearance_tags: Optional[str] = None + appearance_prose: Optional[str] = None personality: Optional[str] = None scenario: Optional[str] = None first_mes: Optional[str] = None diff --git a/routers/sessions.py b/routers/sessions.py index c2a38b5..442b478 100644 --- a/routers/sessions.py +++ b/routers/sessions.py @@ -3,6 +3,7 @@ from services.memory import ( get_all_sessions, get_session, get_or_create_session, + get_history, delete_session, update_session_title, update_session_persona, @@ -17,7 +18,10 @@ from services.memory import ( get_last_message_preview, fork_session, ) -from models.schemas import ForkSessionRequest +from models.schemas import ForkSessionRequest, RebindPersonaRequest +from services.chat_prompt import get_system_prompt +from services.memory import rebind_session_persona +from services.personas import get_persona router = APIRouter(prefix="/sessions", tags=["sessions"]) @@ -46,9 +50,42 @@ async def get_session_route(session_id: str): return s +@router.post("/{session_id}/rebind-persona") +async def rebind_persona(session_id: str, body: RebindPersonaRequest): + session = await get_session(session_id) + if not session: + raise HTTPException(status_code=404, detail="Сессия не найдена") + persona = await get_persona(body.persona_id) + if not persona: + raise HTTPException(status_code=400, detail="Персонаж не найден") + + hist = [] if body.clear_history else await get_history(session_id) + static = await get_system_prompt(body.persona_id, hist, "") + first_mes = (persona.get("first_mes") or "").strip() if body.clear_history else None + + try: + await rebind_session_persona( + session_id, + body.persona_id, + clear_history=body.clear_history, + static_prompt=static, + first_mes=first_mes or None, + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + return { + "persona_id": body.persona_id, + "persona_name": persona.get("name", body.persona_id), + "system_prompt_preview": static[:500], + "clear_history": body.clear_history, + } + + @router.patch("/{session_id}") async def patch_session(session_id: str, data: dict): - await get_or_create_session(session_id, data.get("persona_id", "default")) + create_pid = data.get("persona_id") if "persona_id" in data else None + await get_or_create_session(session_id, create_pid) if "title" in data: await update_session_title(session_id, data["title"]) if "persona_id" in data: diff --git a/services/character_card.py b/services/character_card.py index 2d0e398..eb8004f 100644 --- a/services/character_card.py +++ b/services/character_card.py @@ -45,6 +45,7 @@ def parse_card_v2(data: dict, card_id: str | None = None) -> dict: "first_mes": inner.get("first_mes", ""), "mes_example": inner.get("mes_example", ""), "appearance_tags": _extract_appearance(inner), + "appearance_prose": "", "lorebook_json": json.dumps(entries, ensure_ascii=False), "alternate_greetings": alternates, "alternate_greetings_json": json.dumps(alternates, ensure_ascii=False), @@ -141,13 +142,13 @@ async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0 async with aiosqlite.connect(DB_PATH) as db: await db.execute( - """INSERT OR REPLACE INTO characters - (card_id, name, description, personality, scenario, first_mes, - mes_example, raw_json, lora_name, lora_weight, appearance_tags, lorebook_json, - avatar_path, alternate_greetings_json) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + """INSERT INTO characters + (card_id, name, description, personality, scenario, first_mes, mes_example, + raw_json, lora_name, lora_weight, appearance_tags, appearance_prose, lorebook_json, avatar_path, + alternate_greetings_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( - card_id, + card["card_id"], card["name"], card["description"], card["personality"], @@ -157,10 +158,11 @@ async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0 card["raw_json"], lora_name, lora_weight, - card.get("appearance_tags", ""), + card["appearance_tags"], + card.get("appearance_prose", ""), card["lorebook_json"], card.get("avatar_path", ""), - alt_json, + card.get("alternate_greetings_json", "[]"), ), ) await db.commit() @@ -199,8 +201,8 @@ async def delete_character(card_id: str) -> bool: async def update_appearance_tags(card_id: str, appearance_tags: str): async with aiosqlite.connect(DB_PATH) as db: await db.execute( - "UPDATE characters SET appearance_tags = ? WHERE card_id = ?", - (appearance_tags, card_id), + "UPDATE characters SET appearance_tags = ?, appearance_prose = ? WHERE card_id = ?", + (appearance_tags, "", card_id), ) await db.commit() @@ -228,7 +230,7 @@ async def preview_card_file(content: bytes, filename: str) -> dict: async def update_character(card_id: str, fields: dict) -> bool: allowed = {"name", "description", "personality", "scenario", "first_mes", - "mes_example", "appearance_tags", "lora_name", "lora_weight", "avatar_path", + "mes_example", "appearance_tags", "appearance_prose", "lora_name", "lora_weight", "avatar_path", "alternate_greetings_json"} updates = {k: v for k, v in fields.items() if k in allowed} if not updates: @@ -295,6 +297,7 @@ async def import_card_file( "lora_name": lora_name, "lora_weight": lora_weight, "appearance_tags": saved.get("appearance_tags", ""), + "appearance_prose": saved.get("appearance_prose", ""), "avatar_path": saved.get("avatar_path", ""), "personality": saved.get("personality", ""), "scenario": saved.get("scenario", ""), diff --git a/services/chat_prompt.py b/services/chat_prompt.py new file mode 100644 index 0000000..4e43622 --- /dev/null +++ b/services/chat_prompt.py @@ -0,0 +1,26 @@ +from services.personas import get_persona +from services.lorebook import get_lorebook_context +from services.character_card import get_character + +DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу." + + +async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str: + """Static character prompt only (no RPG runtime blocks).""" + persona = await get_persona(persona_id) + if not persona: + return DEFAULT_PROMPT + prompt = persona["prompt"] + recent = [m for m in history if m["role"] in ("user", "assistant")][-5:] + context = recent + [{"role": "user", "content": user_message}] + if persona.get("lorebook_json"): + lore = get_lorebook_context(persona.get("lorebook_json", "[]"), context) + if lore: + prompt += "\n\n" + lore + if persona_id.startswith("card_"): + card = await get_character(persona_id[5:]) + if card: + lore = get_lorebook_context(card.get("lorebook_json", "[]"), context) + if lore: + prompt += "\n\n" + lore + return prompt diff --git a/services/comfy_models.py b/services/comfy_models.py new file mode 100644 index 0000000..24e5cac --- /dev/null +++ b/services/comfy_models.py @@ -0,0 +1,40 @@ +"""Parse ComfyUI /object_info into usable model lists.""" + +from __future__ import annotations + +# Node types whose combo inputs we expose in the debug UI +_MODEL_NODES: dict[str, tuple[str, str]] = { + "checkpoints": ("CheckpointLoaderSimple", "ckpt_name"), + "unets": ("UNETLoader", "unet_name"), + "clips": ("CLIPLoader", "clip_name"), + "vaes": ("VAELoader", "vae_name"), + "loras": ("LoraLoader", "lora_name"), +} + + +def _combo_options(node_def: dict, input_name: str) -> list[str]: + if not isinstance(node_def, dict): + return [] + required = (node_def.get("input") or {}).get("required") or {} + optional = (node_def.get("input") or {}).get("optional") or {} + spec = required.get(input_name) or optional.get(input_name) + if not spec or not isinstance(spec, (list, tuple)): + return [] + first = spec[0] + if isinstance(first, list): + return [str(x) for x in first] + return [] + + +def parse_model_lists(object_info: dict) -> dict[str, list[str]]: + out: dict[str, list[str]] = {} + for key, (node_type, input_name) in _MODEL_NODES.items(): + node_def = object_info.get(node_type) or {} + options = _combo_options(node_def, input_name) + if options: + out[key] = options + return out + + +def list_node_types(object_info: dict) -> list[str]: + return sorted(k for k in object_info.keys() if isinstance(object_info.get(k), dict)) diff --git a/services/llm.py b/services/llm.py index 2b4ae89..4a1a796 100644 --- a/services/llm.py +++ b/services/llm.py @@ -13,6 +13,8 @@ OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions" CHAT_MODEL = os.getenv("CHAT_MODEL", "mistralai/mistral-nemo") SYSTEM_MODEL = os.getenv("SYSTEM_MODEL", "google/gemini-2.5-flash") +# Softer model when primary returns content_filter / empty / API errors (default: CHAT_MODEL). +LLM_FALLBACK_MODEL = (os.getenv("LLM_FALLBACK_MODEL") or "").strip() or CHAT_MODEL HEADERS = { "Authorization": f"Bearer {OPENROUTER_KEY}", @@ -21,26 +23,128 @@ HEADERS = { } +class LLMError(Exception): + """OpenRouter returned an error or an unexpected response shape.""" + + +def _parse_completion_body(data: dict) -> str: + if not isinstance(data, dict): + raise LLMError(f"Invalid API response: expected object, got {type(data).__name__}") + + if data.get("error"): + err = data["error"] + if isinstance(err, dict): + msg = err.get("message") or str(err) + code = err.get("code") + else: + msg = str(err) + code = None + suffix = f" (code={code})" if code is not None else "" + raise LLMError(f"OpenRouter error{suffix}: {msg}") + + choices = data.get("choices") + if not choices: + preview = str(data)[:400] + raise LLMError(f"OpenRouter response has no 'choices'. Body preview: {preview}") + + first = choices[0] if isinstance(choices[0], dict) else {} + message = first.get("message") or {} + if not isinstance(message, dict): + raise LLMError("OpenRouter choice has no message object") + + finish = first.get("finish_reason") or "" + native_finish = first.get("native_finish_reason") or "" + blocked_reasons = {"content_filter", "safety", "moderation"} + if finish in blocked_reasons or str(native_finish).upper() in ( + "PROHIBITED_CONTENT", + "SAFETY", + "BLOCKED", + ): + raise LLMError( + f"Content blocked by provider (finish_reason={finish}, native={native_finish})" + ) + + content = message.get("content") + if content is not None and str(content).strip(): + return str(content) + + refusal = message.get("refusal") + if refusal: + raise LLMError(f"Model refused the request: {refusal}") + + if finish and finish not in ("stop", "length", "tool_calls", "function_call"): + raise LLMError( + f"OpenRouter finished without content (finish_reason={finish}, native={native_finish})" + ) + + raise LLMError("OpenRouter returned empty message content") + + def _clean(messages: list) -> list: """Filter out messages with empty content.""" return [m for m in messages if (m.get("content") or "").strip()] -async def _post(model: str, messages: list, extra: dict | None = None) -> str: +async def _post_once(model: str, messages: list, extra: dict | None = None) -> str: + if not OPENROUTER_KEY: + raise LLMError("ROUTER_KEY is not set in environment") + payload = {"model": model, "messages": _clean(messages), **(extra or {})} async with httpx.AsyncClient(timeout=90) as client: r = await client.post(OPENROUTER_URL, headers=HEADERS, json=payload) - r.raise_for_status() - return r.json()["choices"][0]["message"]["content"] + try: + data = r.json() + except Exception as e: + raise LLMError(f"Non-JSON response (HTTP {r.status_code}): {r.text[:300]}") from e + + if r.status_code >= 400: + try: + _parse_completion_body(data) + except LLMError: + raise + raise LLMError(f"HTTP {r.status_code}: {data}") + + try: + return _parse_completion_body(data) + except LLMError: + logger.warning( + "OpenRouter completion failed model=%s status=%s body=%.500s", + model, + r.status_code, + data, + ) + raise + + +async def _post(model: str, messages: list, extra: dict | None = None) -> str: + """POST completion; on failure retries once with LLM_FALLBACK_MODEL (usually CHAT_MODEL).""" + try: + return await _post_once(model, messages, extra) + except LLMError as primary_err: + fallback = LLM_FALLBACK_MODEL + if not fallback or fallback == model: + raise + logger.info( + "LLM fallback: %s failed (%s) → retrying with %s", + model, + primary_err, + fallback, + ) + try: + return await _post_once(fallback, messages, extra) + except LLMError as fallback_err: + raise LLMError( + f"{primary_err} (fallback {fallback} also failed: {fallback_err})" + ) from fallback_err async def send_message(messages: list) -> str: - """System model — narrator, facts, SD prompt.""" + """SYSTEM_MODEL with automatic fallback to LLM_FALLBACK_MODEL.""" return await _post(SYSTEM_MODEL, messages) async def send_message_with_model(messages: list, model: str) -> str: - """Explicit model — plot arc, narrator override.""" + """Named model (RPG_*, SD_*) with automatic fallback to LLM_FALLBACK_MODEL.""" return await _post(model, messages) @@ -73,10 +177,19 @@ async def stream_message(messages: list): return try: chunk = json.loads(data) - content = chunk["choices"][0]["delta"].get("content", "") + if chunk.get("error"): + err = chunk["error"] + msg = err.get("message", err) if isinstance(err, dict) else err + raise LLMError(f"OpenRouter stream error: {msg}") + choices = chunk.get("choices") or [] + if not choices: + continue + content = (choices[0].get("delta") or {}).get("content", "") if content: chunk_count += 1 yield content + except LLMError: + raise except Exception: continue except Exception as e: diff --git a/services/memory.py b/services/memory.py index 25f5980..0a97ed3 100644 --- a/services/memory.py +++ b/services/memory.py @@ -2,7 +2,8 @@ import aiosqlite from database.db import DB_PATH -async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict: +async def get_or_create_session(session_id: str, persona_id: str | None = None) -> dict: + """Existing sessions keep their persona_id; persona_id applies only on INSERT.""" async with aiosqlite.connect(DB_PATH) as db: db.row_factory = aiosqlite.Row async with db.execute( @@ -13,9 +14,10 @@ async def get_or_create_session(session_id: str, persona_id: str = "default") -> if row: return dict(row) + pid = (persona_id or "default").strip() or "default" await db.execute( "INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)", - (session_id, persona_id), + (session_id, pid), ) await db.commit() @@ -71,24 +73,99 @@ async def update_session_persona(session_id: str, persona_id: str): (persona_id, session_id), ) - # If persona changed, reset RPG state bound to the persona/arc. if prev is not None and prev != persona_id: - await db.execute( - """UPDATE sessions - SET facts_json = '[]', - global_plot = '', - status_quo = '', - plot_arc_json = '{}' - WHERE session_id = ?""", - (session_id,), - ) - await db.execute( - "DELETE FROM action_resolutions WHERE session_id = ?", - (session_id,), - ) + await _reset_persona_bound_state(db, session_id) await db.commit() +async def _reset_persona_bound_state(db: aiosqlite.Connection, session_id: str) -> None: + await db.execute( + """UPDATE sessions + SET facts_json = '[]', + global_plot = '', + status_quo = '', + plot_arc_json = '{}', + outfit_json = '[]', + affinity = 0 + WHERE session_id = ?""", + (session_id,), + ) + await db.execute("DELETE FROM action_resolutions WHERE session_id = ?", (session_id,)) + await db.execute("DELETE FROM rpg_quests WHERE session_id = ?", (session_id,)) + + +async def upsert_static_system_message( + session_id: str, static_prompt: str, history: list | None = None +) -> bool: + """Store only static persona prompt in messages. Returns True if written.""" + hist = history if history is not None else await get_history(session_id) + async with aiosqlite.connect(DB_PATH) as db: + if not hist: + await db.execute( + """INSERT INTO messages (session_id, role, content, image_prompt, image_path) + VALUES (?, 'system', ?, NULL, NULL)""", + (session_id, static_prompt), + ) + await db.execute( + "UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?", + (session_id,), + ) + await db.commit() + return True + + if hist[0]["role"] == "system": + if hist[0]["content"] == static_prompt: + return False + await db.execute( + """UPDATE messages SET content = ? + WHERE session_id = ? AND role = 'system' + AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""", + (static_prompt, session_id, session_id), + ) + await db.commit() + return True + + await db.execute( + """INSERT INTO messages (session_id, role, content, image_prompt, image_path) + VALUES (?, 'system', ?, NULL, NULL)""", + (session_id, static_prompt), + ) + await db.commit() + return True + + +async def delete_dialog_messages(session_id: str) -> None: + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "DELETE FROM messages WHERE session_id = ? AND role IN ('user', 'assistant')", + (session_id,), + ) + await db.commit() + + +async def rebind_session_persona( + session_id: str, + persona_id: str, + *, + clear_history: bool = False, + static_prompt: str, + first_mes: str | None = None, +) -> None: + session = await get_session(session_id) + if not session: + raise ValueError("Session not found") + + await update_session_persona(session_id, persona_id) + if clear_history: + await delete_dialog_messages(session_id) + + history = await get_history(session_id) + await upsert_static_system_message(session_id, static_prompt, history) + + if clear_history and first_mes and first_mes.strip(): + await add_message(session_id, "assistant", first_mes.strip()) + + async def update_session_rpg(session_id: str, rpg_enabled: bool): async with aiosqlite.connect(DB_PATH) as db: await db.execute( @@ -178,7 +255,8 @@ async def get_history(session_id: str) -> list: async with aiosqlite.connect(DB_PATH) as db: db.row_factory = aiosqlite.Row async with db.execute( - """SELECT id, role, content, image_prompt, image_path + """SELECT id, role, content, image_prompt, image_path, + image_prompt_alt, image_path_alt FROM messages WHERE session_id = ? ORDER BY id""", (session_id,), ) as cursor: @@ -190,6 +268,8 @@ async def get_history(session_id: str) -> list: "content": r["content"], "image_prompt": r["image_prompt"], "image_path": r["image_path"], + "image_prompt_alt": r["image_prompt_alt"], + "image_path_alt": r["image_path_alt"], } for r in rows ] @@ -332,6 +412,33 @@ async def update_message_image(message_id: int, image_path: str): await db.commit() +async def update_message_prompt(message_id: int, image_prompt: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "UPDATE messages SET image_prompt = ? WHERE id = ?", + (image_prompt, message_id), + ) + await db.commit() + + +async def update_message_prompt_alt(message_id: int, image_prompt_alt: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "UPDATE messages SET image_prompt_alt = ? WHERE id = ?", + (image_prompt_alt, message_id), + ) + await db.commit() + + +async def update_message_image_alt(message_id: int, image_path_alt: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "UPDATE messages SET image_path_alt = ? WHERE id = ?", + (image_path_alt, message_id), + ) + await db.commit() + + async def get_last_assistant_message_id(session_id: str) -> int | None: async with aiosqlite.connect(DB_PATH) as db: db.row_factory = aiosqlite.Row diff --git a/services/opening.py b/services/opening.py new file mode 100644 index 0000000..f787621 --- /dev/null +++ b/services/opening.py @@ -0,0 +1,178 @@ +import json +import logging + +from services.memory import ( + get_history, + get_session, + get_last_assistant_message_id, + update_session_plot_arc, + update_session_status_quo, + update_session_affinity, + update_session_outfit, + upsert_quest, + get_quests, +) +from services.personas import get_persona +from services.rpg_facts import facts_to_prompt +from services.rpg_plot import generate_plot_arc +from services.rpg_narrator import narrator_post +from services.sd_prompt import generate_sd_prompt +from services.sd_images import run_sd_for_message + +logger = logging.getLogger(__name__) + +DEFAULT_RPG_SETTINGS = { + "dice": True, + "narrator": True, + "quests": True, + "affinity": True, + "choices": True, +} + + +def get_rpg_settings(session: dict) -> dict: + try: + return {**DEFAULT_RPG_SETTINGS, **json.loads(session.get("rpg_settings_json") or "{}")} + except Exception: + return DEFAULT_RPG_SETTINGS + + +async def resolve_greeting(session_id: str, persona: dict) -> str: + history = await get_history(session_id) + for m in reversed(history): + if m.get("role") == "assistant" and (m.get("content") or "").strip(): + return m["content"].strip() + return (persona.get("first_mes") or "").strip() + + +async def ensure_plot_arc_and_quests( + session_id: str, + persona: dict, + greeting: str, + genre: str, + *, + seed_quests: bool = True, +) -> dict: + session = await get_session(session_id) or {} + arc_json = session.get("plot_arc_json") or "{}" + try: + arc = json.loads(arc_json) if isinstance(arc_json, str) else {} + except Exception: + arc = {} + + if arc: + return arc + + facts_block = facts_to_prompt(session.get("facts_json", "[]")) + arc = await generate_plot_arc( + persona.get("name", "Character"), + persona.get("description", ""), + persona.get("scenario", ""), + greeting, + facts_block=facts_block, + genre=genre, + ) + if not arc: + return {} + + await update_session_plot_arc(session_id, json.dumps(arc, ensure_ascii=False)) + if seed_quests: + for beat in arc.get("beats", []): + title = (beat.get("title") or beat.get("injection", "")).strip() + if title: + await upsert_quest(session_id, title[:120]) + return arc + + +async def process_opening(session_id: str, persona_id: str, *, rpg: bool) -> dict: + session = await get_session(session_id) + if not session: + raise ValueError("Session not found") + + history = await get_history(session_id) + assistant_msgs = [m for m in history if m.get("role") == "assistant"] + if not assistant_msgs: + raise ValueError("No assistant message (first_mes) found") + + first_mes_text = assistant_msgs[-1].get("content", "").strip() + if not first_mes_text: + raise ValueError("Empty first_mes") + + msg_id = await get_last_assistant_message_id(session_id) + persona = await get_persona(persona_id) or {} + rpg_settings = get_rpg_settings(session) + + arc: dict = {} + choices: list = [] + status_quo = session.get("status_quo") or "" + outfit_json = session.get("outfit_json") or "[]" + + if rpg: + genre = session.get("genre") or "adventure" + arc = await ensure_plot_arc_and_quests( + session_id, + persona, + first_mes_text, + genre, + seed_quests=rpg_settings.get("quests", True), + ) + + session = await get_session(session_id) or session + ctx_txt = f"assistant: {first_mes_text}" + arc_json = json.dumps(arc, ensure_ascii=False) if arc else "" + facts_block = facts_to_prompt(session.get("facts_json", "[]")) + + post = await narrator_post( + persona.get("name", persona_id), + ctx_txt, + arc_json, + facts_block, + is_opening=True, + ) + + sq = (post.get("status_quo_update") or "").strip() + if sq: + await update_session_status_quo(session_id, sq) + status_quo = sq + + if rpg_settings.get("choices", True): + choices = post.get("choices") or [] + + if rpg_settings.get("affinity", True): + delta = int(post.get("affinity_delta") or 0) + if delta: + await update_session_affinity(session_id, delta) + + outfit_update = post.get("outfit_update") + if isinstance(outfit_update, list) and outfit_update: + outfit_json = json.dumps(outfit_update, ensure_ascii=False) + await update_session_outfit(session_id, outfit_json) + + if rpg_settings.get("quests", True): + for qu in (post.get("quest_updates") or []): + title = (qu.get("title") or "").strip() + if title: + await upsert_quest(session_id, title[:120], qu.get("status", "active")) + + quests = await get_quests(session_id) + messages = await get_history(session_id) + bundle = await generate_sd_prompt(messages, persona_id, outfit_json=outfit_json) + sd_out = await run_sd_for_message(bundle, msg_id) if bundle else {} + + updated = await get_session(session_id) + affinity = updated.get("affinity", 0) if updated else 0 + + return { + "plot_arc": arc, + "quests": quests, + "outfit_json": outfit_json, + "status_quo": status_quo, + "choices": choices, + "image_prompt": sd_out.get("image_prompt"), + "image_prompt_alt": sd_out.get("image_prompt_alt"), + "image_path": sd_out.get("image_path"), + "image_path_alt": sd_out.get("image_path_alt"), + "image_error": sd_out.get("image_error"), + "image_error_alt": sd_out.get("image_error_alt"), + "affinity": affinity, + } diff --git a/services/personas.py b/services/personas.py index e77d3a7..1414a11 100644 --- a/services/personas.py +++ b/services/personas.py @@ -63,6 +63,7 @@ def _row_to_persona(row: dict) -> dict: "lora_name": row["lora_name"] or "", "lora_weight": row["lora_weight"] if row["lora_weight"] is not None else 0.8, "appearance_tags": row["appearance_tags"] or "", + "appearance_prose": row.get("appearance_prose", "") or "", "personality": row.get("personality", "") or "", "scenario": row.get("scenario", "") or "", "first_mes": row.get("first_mes", "") or "", @@ -117,6 +118,7 @@ async def create_persona( lora_name: str = "", lora_weight: float = 0.8, appearance_tags: str = "", + appearance_prose: str = "", personality: str = "", scenario: str = "", first_mes: str = "", @@ -138,19 +140,19 @@ async def create_persona( await db.execute( """INSERT INTO personas (persona_id, name, emoji, description, prompt, custom, - sd_enabled, lora_name, lora_weight, appearance_tags, + sd_enabled, lora_name, lora_weight, appearance_tags, appearance_prose, personality, scenario, first_mes, mes_example, lorebook_json, avatar_path, alternate_greetings_json) - VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( persona_id, name, emoji, description, final_prompt, - 1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags, + 1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags, appearance_prose, personality, scenario, first_mes, mes_example, lorebook_json, avatar_path, alternate_greetings_json, ), ) await db.commit() - return { + return { "name": name, "emoji": emoji, "description": description, @@ -160,6 +162,7 @@ async def create_persona( "lora_name": lora_name, "lora_weight": lora_weight, "appearance_tags": appearance_tags, + "appearance_prose": appearance_prose, "personality": personality, "scenario": scenario, "first_mes": first_mes, @@ -226,6 +229,7 @@ async def patch_persona(persona_id: str, fields: dict) -> bool: "lora_name", "lora_weight", "appearance_tags", + "appearance_prose", "personality", "scenario", "first_mes", @@ -255,6 +259,19 @@ async def patch_persona(persona_id: str, fields: dict) -> bool: merged = dict(existing) merged.update(updates) updates["prompt"] = build_persona_prompt(merged) + + if "appearance_tags" in updates and "appearance_prose" not in updates: + tags = updates["appearance_tags"].strip() + if tags: + from services.llm import send_message + try: + prose = await send_message([ + {"role": "system", "content": "Convert danbooru tags to natural English description. Output only the description, no markdown."}, + {"role": "user", "content": f"Tags: {tags}"} + ]) + updates["appearance_prose"] = prose.strip() + except Exception: + pass cols = ", ".join(f"{k} = ?" for k in updates) cur2 = await db.execute( diff --git a/services/rpg_facts.py b/services/rpg_facts.py index 253a95c..3620bdc 100644 --- a/services/rpg_facts.py +++ b/services/rpg_facts.py @@ -1,7 +1,10 @@ import json +import logging import os -from services.llm import send_message_with_model, send_message +from services.llm import LLMError, send_message_with_model, send_message + +logger = logging.getLogger(__name__) FACTS_MODEL = os.getenv("RPG_FACTS_MODEL", "").strip() or "deepseek/deepseek-chat-v3" @@ -51,7 +54,19 @@ async def extract_facts(context_messages: list[dict]) -> list[str]: {"role": "user", "content": transcript}, ] - raw = await (send_message_with_model(messages, FACTS_MODEL) if FACTS_MODEL else send_message(messages)) + try: + raw = await ( + send_message_with_model(messages, FACTS_MODEL) + if FACTS_MODEL + else send_message(messages) + ) + except LLMError as e: + logger.warning("extract_facts LLM failed (model=%s): %s", FACTS_MODEL or "SYSTEM", e) + return [] + except Exception as e: + logger.warning("extract_facts unexpected error: %s", e) + return [] + try: data = json.loads(raw.strip()) if isinstance(data, list): diff --git a/services/rpg_narrator.py b/services/rpg_narrator.py index a611baf..c7fa057 100644 --- a/services/rpg_narrator.py +++ b/services/rpg_narrator.py @@ -2,7 +2,7 @@ import json import os import random -from services.llm import send_message_with_model +from services.llm import LLMError, send_message_with_model import logging logger = logging.getLogger(__name__) @@ -63,10 +63,18 @@ async def narrator_pre( f"Facts:\n{facts_block}\n\n" f"Recent context:\n{context}\n" ) - raw = await send_message_with_model( - [{"role": "system", "content": NARRATOR_PRE_SYSTEM}, {"role": "user", "content": user}], - NARRATOR_MODEL, - ) + try: + raw = await send_message_with_model( + [{"role": "system", "content": NARRATOR_PRE_SYSTEM}, {"role": "user", "content": user}], + NARRATOR_MODEL, + ) + except LLMError as e: + logger.warning("Narrator-pre LLM failed (model=%s): %s", NARRATOR_MODEL, e) + return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": ""} + except Exception as e: + logger.warning("Narrator-pre unexpected error: %s", e) + return {"needs_check": False, "directives": [], "status_quo_update": "", "resolution_text": ""} + cleaned = raw.strip() if cleaned.startswith("```"): cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned @@ -87,17 +95,35 @@ async def narrator_post( context: str, global_plot: str, facts_block: str, + is_opening: bool = False, ) -> dict: + opening_block = "" + if is_opening: + opening_block = ( + "\n\nOPENING SCENE: This is the first greeting, not a mid-conversation reply. " + "Extract the character's INITIAL visible clothing from the greeting into outfit_update " + "(danbooru underscore tags), even if clothing did not change during the scene. " + "Set status_quo to describe the opening situation.\n" + ) user = ( f"Persona: {persona_name}\n\n" f"Global plot:\n{global_plot}\n\n" f"Facts:\n{facts_block}\n\n" f"Recent context:\n{context}\n" + f"{opening_block}" ) - raw = await send_message_with_model( - [{"role": "system", "content": NARRATOR_POST_SYSTEM}, {"role": "user", "content": user}], - NARRATOR_MODEL, - ) + try: + raw = await send_message_with_model( + [{"role": "system", "content": NARRATOR_POST_SYSTEM}, {"role": "user", "content": user}], + NARRATOR_MODEL, + ) + except LLMError as e: + logger.warning("Narrator-post LLM failed (model=%s): %s", NARRATOR_MODEL, e) + return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": []} + except Exception as e: + logger.warning("Narrator-post unexpected error: %s", e) + return {"status_quo_update": "", "facts": [], "choices": [], "affinity_delta": 0, "quest_updates": []} + cleaned = raw.strip() if cleaned.startswith("```"): cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned diff --git a/services/rpg_plot.py b/services/rpg_plot.py index c964813..d70e0fc 100644 --- a/services/rpg_plot.py +++ b/services/rpg_plot.py @@ -1,7 +1,7 @@ import json import os -from services.llm import send_message_with_model, send_message +from services.llm import LLMError, send_message_with_model, send_message import logging logger = logging.getLogger(__name__) @@ -63,7 +63,19 @@ async def generate_plot_arc(persona_name: str, persona_desc: str, persona_scenar {"role": "system", "content": ARC_SYSTEM}, {"role": "user", "content": user}, ] - raw = await (send_message_with_model(messages, PLOT_MODEL) if PLOT_MODEL else send_message(messages)) + try: + raw = await ( + send_message_with_model(messages, PLOT_MODEL) + if PLOT_MODEL + else send_message(messages) + ) + except LLMError as e: + logger.warning("generate_plot_arc LLM failed (model=%s): %s", PLOT_MODEL or "SYSTEM", e) + return {} + except Exception as e: + logger.warning("generate_plot_arc unexpected error: %s", e) + return {} + cleaned = raw.strip() # common OpenRouter formatting: fenced json if cleaned.startswith("```"): diff --git a/services/sd_images.py b/services/sd_images.py new file mode 100644 index 0000000..3374b13 --- /dev/null +++ b/services/sd_images.py @@ -0,0 +1,48 @@ +"""Run ComfyUI generation from SdPromptBundle (single hybrid prompt for Anima).""" + +import logging +import os + +from services import sdbackend as sd_service +from services.memory import update_message_image, update_message_prompt, update_message_prompt_alt +from services.sd_prompt import SdPromptBundle + +logger = logging.getLogger(__name__) + +SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes") + + +async def run_sd_for_message(bundle: SdPromptBundle | None, msg_id: int | None) -> dict: + """Generate image, persist prompts/paths on message. Returns fields for API/SSE.""" + out = { + "image_prompt": None, + "image_prompt_alt": None, + "image_path": None, + "image_path_alt": None, + "image_error": None, + "image_error_alt": None, + } + if not bundle or not bundle.tag_full: + return out + + out["image_prompt"] = bundle.tag_full + if bundle.desc_full and bundle.desc_full != bundle.tag_full: + out["image_prompt_alt"] = bundle.desc_full + + if msg_id: + await update_message_prompt(msg_id, bundle.tag_full) + if out["image_prompt_alt"]: + await update_message_prompt_alt(msg_id, out["image_prompt_alt"]) + + if not SD_AUTO_GENERATE: + return out + + rel, err = await sd_service.generate_from_full_prompt(bundle.tag_full) + if rel: + out["image_path"] = f"/static/{rel}" + if msg_id: + await update_message_image(msg_id, rel) + else: + out["image_error"] = err + + return out diff --git a/services/sd_prompt.py b/services/sd_prompt.py index 9e0b31b..d3778fe 100644 --- a/services/sd_prompt.py +++ b/services/sd_prompt.py @@ -2,26 +2,115 @@ import json import logging import os import re +from dataclasses import dataclass from services.llm import send_message, send_message_with_model from services.personas import get_persona logger = logging.getLogger(__name__) +NEGATIVE_PROMPT_SEPARATOR = "\n\n__NEGATIVE_PROMPT__\n\n" + PROMPT_BUILDER_SYSTEM = """You are a Stable Diffusion prompt engineer for anime illustration models. Given a roleplay chat excerpt, output ONLY valid JSON (no markdown): { "should_generate": true, "shot_type": "first_person_pov" | "landscape" | "third_person", - "action_tags": "booru-style tags for pose/action/expression, e.g. 'sitting, smiling, holding_cup'", - "environment_tags": "booru-style tags for location/lighting/time, e.g. 'indoors, kitchen, sunlight, daytime'" + "action_tags": "booru-style tags for pose/action/expression", + "environment_tags": "booru-style tags for location/lighting/time" } Rules: -- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined: 'fox_ears' not 'fox ears'. +- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be underscore_joined. - Do NOT include appearance/character tags — those are provided separately. - Do NOT include quality tags, model names, style words, 'pov', or category/metadata words. - Do NOT invent tags. If unsure — omit. -- Keep each field to 3-6 tags.""" +- Keep action_tags and environment_tags to 3-6 tags each. +- shot_type: default "first_person_pov" for dialogue/intimacy at arm's length. "third_person" only for wide action (fight, chase). "landscape" only when environment is the focus. +- should_generate: false for non-visual beats (pure internal monologue, time skips with no new pose, empty lines). +- NEVER use negative words in tag fields (not, without, naked, nsfw, etc.).""" + +ANIMA_BUILDER_EXTRA = """ +Anima hybrid mode — ALSO include: + "pov_cue": "face_to_face" | "walking_together" | "doorway_invite" | "reach_to_viewer" | "dialogue_close", + "viewer_body_visible": false, + "scene_description": "ONE short English sentence (max 40 words). Camera POV: what the viewer sees. Mood/atmosphere only — do NOT repeat tags from action_tags/environment_tags. Do NOT list comma-separated booru tags." +POV / interaction rules: +- Default viewer_body_visible: false. The viewer's body, hands, or face must NOT appear in the image — only the character toward the camera. +- For hugs, embraces: use arms_out, reaching_towards_viewer, inviting_hug — NOT holding_hands, lifting, carrying, nose_rub (these draw a second body in POV). +- For long messages with time skips ("About an hour later..."), illustrate ONLY the final visible beat (usually the last paragraph). +- scene_description: describe HER toward the camera only — NEVER "someone", "both", "with you", "hand in hand with", or another person's body. +- NEVER use tags: looking_at_each_other, couple, 2girls, 2boys, multiple_girls. For POV walking together omit holding_hands (use walking, smiling, reaching_towards_viewer instead). +- pov_cue: pick the framing that matches the CURRENT beat (walking_together for strolling side by side, doorway_invite for doorway with arms open, reach_to_viewer when she reaches toward camera, face_to_face for close dialogue). +- Illustrate ONLY the beat under === ILLUSTRATE ===; use === Context === for outfit/location hints only. +- Do NOT put English sentences in action_tags or environment_tags — tags only.""" + +POV_CUE_PHRASES: dict[str, str] = { + "face_to_face": "POV: close face-to-face, she looks directly at you", + "walking_together": "POV: walking beside you, profile and shared path visible", + "doorway_invite": "POV: she blocks the doorway, arms open toward you", + "reach_to_viewer": "POV: she reaches toward the camera", + "dialogue_close": "POV: close conversation, she faces you at arm's length", +} + +POV_CUE_DEFAULT = "POV: she stands before you, facing the camera" + +POV_INTERACTION_NEGATIVE = ( + "duplicate, clone, multiple_girls, 2girls, extra_person, pov hands, " + "disembodied hands, extra arms, second person" +) + +_CONTACT_ACTION_KEYWORDS = ( + "hug", "holding_hands", "hand_holding", "arms_out", "embrace", + "reaching", "inviting_hug", "arm_around", "cuddling", +) + +_JUNK_STANDALONE_TAGS = frozenset({ + "white", "black", "skin", "ear", "ears", "girl", "boy", "fox", "wolf", "cat", + "short", "tall", "slim", "golden", "silver", "red", "blue", "green", "purple", + "pink", "brown", "blonde", "eye", "eyes", "hair", +}) + +_INVALID_TAGS = frozenset({ + "pumped_up", "pumped", "looking_at_each_other", "couple", + "2girls", "2boys", "multiple_girls", "multiple_boys", "duo", +}) + +_POV_DROP_ACTION_TAGS = frozenset({ + "holding_hands", "hand_holding", "looking_at_each_other", "couple", + "lifting", "carry", "carrying", "princess_carry", "nose_rub", "nose_boop", +}) + +_TIME_SKIP_RE = re.compile( + r"(?i)\b(?:about an hour later|hours later|later that (?:day|evening|night)|" + r"the next (?:day|morning|evening)|meanwhile|after (?:some )?time)\b[.…\s]*", +) + +_POV_MOOD_FALLBACK: dict[str, str] = { + "walking_together": "Easy warmth and quiet laughter in the afternoon light.", + "doorway_invite": "Cool air and playful tension as she waits in the doorway.", + "reach_to_viewer": "A charged moment as she reaches toward the camera.", + "face_to_face": "Her expression softens in close focus toward the camera.", + "dialogue_close": "Intimate calm in the space between you.", +} + +_INDOOR_ENV_MARKERS = frozenset({"doorway", "indoors", "indoor", "apartment", "inside", "room"}) +_OUTDOOR_ENV_MARKERS = frozenset({"outdoor", "outdoors", "outside", "street"}) + +_POV_PROSE_BANNED = re.compile( + r"\b(someone|both|together with|hand in hand with|another person|second person|" + r"your hands|your fingers|your embrace|your heat|intertwined|with you|" + r"demands your|before you)\b", + re.IGNORECASE, +) + +SD_ANIMA_DUAL_COMPARE = os.getenv("SD_ANIMA_DUAL_COMPARE", "false").lower() in ("1", "true", "yes") + + +@dataclass +class SdPromptBundle: + tag_full: str + negative: str + desc_full: str | None = None def extract_image_prompt_tag(text: str) -> str | None: @@ -44,7 +133,7 @@ SD_UNET = os.getenv("SD_UNET", "") SD_PROMPT_MODEL = os.getenv("SD_PROMPT_MODEL", "").strip() PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"} -PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored" +PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored" ANIMA_NEGATIVE = "worst quality, low quality, score_1, score_2, score_3, blurry, jpeg artifacts, sepia" @@ -56,37 +145,201 @@ def _is_anima() -> bool: return bool(SD_UNET) and not SD_CHECKPOINT -def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str: +def anima_dual_enabled() -> bool: + return _is_anima() and SD_ANIMA_DUAL_COMPARE + + +def _builder_system() -> str: + if _is_anima(): + return PROMPT_BUILDER_SYSTEM + ANIMA_BUILDER_EXTRA + return PROMPT_BUILDER_SYSTEM + + +def _normalize_shot_type(scene: dict) -> dict: + st = (scene.get("shot_type") or "").strip().lower() + if st == "landscape": + scene["shot_type"] = "landscape" + return _sanitize_scene_fields(scene) + if st == "third_person": + action = (scene.get("action_tags") or "").lower() + wide = ("battle", "fight", "chase", "running", "crowd", "wide_shot", "group_shot") + if any(w in action for w in wide): + scene["shot_type"] = "third_person" + return _sanitize_scene_fields(scene) + scene["shot_type"] = "first_person_pov" + if scene.get("viewer_body_visible") is None: + scene["viewer_body_visible"] = False + return _sanitize_scene_fields(scene) + + +def _split_tag_input(tag_str: str) -> list[str]: + return [t.strip() for t in (tag_str or "").split(",") if t.strip()] + + +def _is_sentence_like_tag(tag: str) -> bool: + t = tag.strip() + if len(t) > 45: + return True + if re.search(r"[.!?]", t): + return True + words = t.split() + return len(words) >= 5 and "_" not in t + + +def _filter_tag_field(tag_str: str, *, for_pov: bool, field: str) -> str: + kept: list[str] = [] + for raw in _split_tag_input(tag_str): + key = raw.lower().replace(" ", "_") + if key in _INVALID_TAGS: + continue + if _is_sentence_like_tag(raw): + continue + if for_pov and field == "action" and key in _POV_DROP_ACTION_TAGS: + continue + kept.append(raw if "_" in raw else key) + return ", ".join(kept) + + +def _reconcile_environment_tags(env_str: str) -> str: + tags = _split_tag_input(env_str) + keys = {t.lower().replace(" ", "_") for t in tags} + has_indoor = bool(keys & _INDOOR_ENV_MARKERS) or any( + any(m in k for m in _INDOOR_ENV_MARKERS) for k in keys + ) + has_outdoor = bool(keys & _OUTDOOR_ENV_MARKERS) or any( + any(m in k for m in _OUTDOOR_ENV_MARKERS) for k in keys + ) + if has_indoor and has_outdoor: + tags = [t for t in tags if t.lower().replace(" ", "_") not in _OUTDOOR_ENV_MARKERS] + return ", ".join(tags) + + +def _sanitize_pov_prose(desc: str, scene: dict) -> str: + if not desc or not desc.strip(): + return "" + if scene.get("shot_type") != "first_person_pov": + return desc.strip() + + kept: list[str] = [] + for sentence in re.split(r"(?<=[.!?])\s+", desc.strip()): + s = sentence.strip() + if not s: + continue + if _POV_PROSE_BANNED.search(s): + continue + if re.search(r"\bwolfgirl\b", s, re.I) and re.search( + r"\b(walks|walking|stands)\b", s, re.I + ): + continue + kept.append(s) + out = " ".join(kept).strip() + return re.sub(r"\bat the viewer\b", "at the camera", out, flags=re.IGNORECASE) + + +def _sanitize_scene_fields(scene: dict) -> dict: + scene = dict(scene) + for_pov = scene.get("shot_type") == "first_person_pov" + scene["action_tags"] = _filter_tag_field( + scene.get("action_tags") or "", for_pov=for_pov, field="action" + ) + env = _filter_tag_field(scene.get("environment_tags") or "", for_pov=False, field="env") + scene["environment_tags"] = _reconcile_environment_tags(env) + scene["scene_description"] = _sanitize_pov_prose( + (scene.get("scene_description") or "").strip(), scene + ) + return scene + + +def _scene_should_generate(scene: dict) -> bool: + if scene.get("should_generate") is False: + return False + return True + + +def _sanitize_tags_string(tag_str: str) -> str: + if not tag_str: + return "" + out: list[str] = [] + seen: set[str] = set() + for raw in tag_str.split(","): + t = raw.strip() + if not t: + continue + key = t.lower().replace(" ", "_") + if key in seen: + continue + if key in _INVALID_TAGS: + continue + if "_" not in key and key in _JUNK_STANDALONE_TAGS: + continue + if len(key) <= 2: + continue + seen.add(key) + out.append(t if "_" in t else key) + return ", ".join(out) + + +def _quality_prefix() -> str: if _is_pony(): - quality = "score_9, score_8_up, score_7_up, source_anime, highres" - elif _is_anima(): - quality = "masterpiece, best quality, score_7, anime" - else: - quality = "masterpiece, best quality, highres" + return "score_9, score_8_up, score_7_up, source_anime, highres" + if _is_anima(): + return "masterpiece, best quality, score_7, anime" + return "masterpiece, best quality, highres" - parts = [quality] - appearance = (persona or {}).get("appearance_tags", "") - if appearance: - parts.append(appearance) - if outfit_tags: - parts.append(outfit_tags) +def _appearance_for_persona(persona: dict | None) -> str: + """Tag core uses appearance_tags only (prose is for LLM context, not Comfy tag line).""" + return _sanitize_tags_string((persona or {}).get("appearance_tags", "")) - if scene.get("shot_type") == "landscape": - parts.append(scene.get("environment_tags", "")) - else: - if scene.get("shot_type") == "first_person_pov": - parts.append("pov, first-person view, looking at viewer") - parts.append(scene.get("action_tags", "")) - parts.append(scene.get("environment_tags", "")) +def _dedupe_outfit_tags(outfit_tags: str) -> str: + tags = _split_tag_input(outfit_tags) + keys = {t.lower().replace(" ", "_") for t in tags} + if len(keys & {"jeans", "ripped_jeans", "black_jeans"}) > 1 and "jeans" in keys: + tags = [t for t in tags if t.lower().replace(" ", "_") != "jeans"] + return ", ".join(tags) + + +def _scene_has_physical_contact(scene: dict) -> bool: + action = (scene.get("action_tags") or "").lower() + return any(k in action for k in _CONTACT_ACTION_KEYWORDS) + + +def _infer_pov_cue_from_action(action_tags: str) -> str: + action = (action_tags or "").lower() + if any(k in action for k in ("holding_hands", "hand_holding", "walking", "strolling")): + return "walking_together" + if any(k in action for k in ("doorway", "door", "entry", "threshold")): + if any(k in action for k in ("arms_out", "hug", "embrace", "inviting")): + return "doorway_invite" + if any(k in action for k in ("arms_out", "reaching", "inviting_hug", "hug", "embrace")): + return "reach_to_viewer" + if any(k in action for k in ("sitting", "lying", "bed")): + return "dialogue_close" + return "face_to_face" + + +def _build_pov_phrase(scene: dict) -> str: + if scene.get("shot_type") != "first_person_pov": + return "" + cue = (scene.get("pov_cue") or "").strip().lower().replace("-", "_").replace(" ", "_") + if cue in POV_CUE_PHRASES: + return POV_CUE_PHRASES[cue] + inferred = _infer_pov_cue_from_action(scene.get("action_tags", "")) + return POV_CUE_PHRASES.get(inferred, POV_CUE_DEFAULT) + + +def _append_lora(parts: list[str], persona: dict | None) -> None: lora = (persona or {}).get("lora_name", "") weight = (persona or {}).get("lora_weight", 0.8) if lora: parts.append(f"") + +def _dedupe_comma_join(parts: list[str]) -> str: positive = ", ".join(p.strip() for p in parts if p and p.strip()) - seen, deduped = set(), [] + seen: set[str] = set() + deduped: list[str] = [] for tag in positive.split(", "): t = tag.strip() if t and t not in seen: @@ -95,53 +348,152 @@ def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = return ", ".join(deduped) -async def generate_sd_prompt( - messages: list, - persona_id: str, - outfit_json: str = "[]", -) -> tuple[str | None, str | None]: - persona = await get_persona(persona_id) - # Generate only if persona has appearance tags - if not persona or not (persona.get("appearance_tags") or "").strip(): - logger.debug("sd_prompt skip: persona=%s no appearance_tags", persona_id) - return None, None +def _build_tag_core(scene: dict, persona: dict | None, outfit_tags: str = "") -> str: + """Anchor + structure: quality, appearance, outfit, action/env tags, LoRA. No POV prose, no scene_description.""" + parts = [_quality_prefix()] + appearance = _appearance_for_persona(persona) + if appearance: + parts.append(appearance) + if outfit_tags: + parts.append(_sanitize_tags_string(_dedupe_outfit_tags(outfit_tags))) + if scene.get("shot_type") == "landscape": + parts.append(_sanitize_tags_string(scene.get("environment_tags", ""))) + else: + if not _is_anima() and scene.get("shot_type") == "first_person_pov": + parts.append("pov, first-person view, looking at viewer") + parts.append(_sanitize_tags_string(scene.get("action_tags", ""))) + parts.append(_sanitize_tags_string(scene.get("environment_tags", ""))) + _append_lora(parts, persona) + return _dedupe_comma_join(parts) - recent = [m for m in messages if m["role"] in ("user", "assistant")][-6:] - if not recent: - return None, None - excerpt = "\n".join(f"{m['role']}: {strip_image_prompt_tag(m['content'])}" for m in recent) +def build_positive_prompt_tags_only(scene: dict, persona: dict | None, outfit_tags: str = "") -> str: + """Tags + contextual POV phrase (Anima) or legacy Pony path.""" + if not _is_anima(): + return build_positive_prompt(scene, persona, outfit_tags) + core = _build_tag_core(scene, persona, outfit_tags) + pov = _build_pov_phrase(scene) + if pov: + return f"{core}, {pov}" if core else pov + return core - builder_messages = [ - {"role": "system", "content": PROMPT_BUILDER_SYSTEM}, - {"role": "user", "content": f"Chat:\n{excerpt}"}, - ] - try: - if SD_PROMPT_MODEL: - raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL) - else: - raw = await send_message(builder_messages) - raw = raw.strip() - if raw.startswith("```"): - raw = re.sub(r"^```\w*\n?", "", raw) - raw = re.sub(r"\n?```$", "", raw) - scene = json.loads(raw) - if not isinstance(scene, dict): - logger.warning("sd_prompt: LLM returned non-dict: %.100s", raw) - return None, None - except Exception as e: - logger.warning("sd_prompt failed: %s raw=%.200s", e, locals().get("raw", "")) - return None, None +def _tag_tokens_for_dedupe(tag_line: str) -> set[str]: + tokens: set[str] = set() + for part in tag_line.replace("= 4: + tokens.add(w) + return tokens - try: - outfit_list = json.loads(outfit_json or "[]") - outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else "" - except Exception: - outfit_tags = "" - positive = build_positive_prompt(scene, persona, outfit_tags) +def _trim_redundant_scene_description(desc: str, tag_line: str) -> str: + tag_tokens = _tag_tokens_for_dedupe(tag_line) + if not tag_tokens or not desc.strip(): + return desc.strip() + kept: list[str] = [] + for sentence in re.split(r"(?<=[.!?])\s+", desc.strip()): + s = sentence.strip() + if not s: + continue + words = [w.lower() for w in re.findall(r"[a-zA-Z]{4,}", s)] + if not words: + kept.append(s) + continue + overlap = sum(1 for w in words if w in tag_tokens) / len(words) + if overlap < 0.62: + kept.append(s) + + return " ".join(kept).strip() + + +def _extract_illustrate_content(content: str, max_chars: int = 1400) -> str: + """Long assistant posts (first_mes): use final beat after time-skip, last paragraphs.""" + text = strip_image_prompt_tag(content).strip() + if not text: + return "" + chunks = _TIME_SKIP_RE.split(text) + if len(chunks) > 1: + text = chunks[-1].strip() + if len(text) <= max_chars: + return text + paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()] + if paragraphs: + for n in (1, 2, 3): + tail = "\n\n".join(paragraphs[-n:]) + if len(tail) <= max_chars: + return tail + return paragraphs[-1][-max_chars:] + return text[-max_chars:] + + +def _fallback_mood_prose(scene: dict) -> str: + cue = (scene.get("pov_cue") or "").strip().lower().replace("-", "_").replace(" ", "_") + if cue in _POV_MOOD_FALLBACK: + return _POV_MOOD_FALLBACK[cue] + inferred = _infer_pov_cue_from_action(scene.get("action_tags", "")) + return _POV_MOOD_FALLBACK.get(inferred, "Soft atmosphere; her expression toward the camera.") + + +def _cap_scene_description(desc: str, max_words: int = 40, max_chars: int = 220) -> str: + words = desc.split() + if len(words) > max_words: + desc = " ".join(words[:max_words]) + if len(desc) > max_chars: + desc = desc[: max_chars - 3] + "..." + return desc + + +def build_positive_prompt_hybrid(scene: dict, persona: dict | None, outfit_tags: str = "") -> str: + """Production Anima prompt: tag core + POV cue + short mood prose.""" + if not _is_anima(): + return build_positive_prompt(scene, persona, outfit_tags) + + base = build_positive_prompt_tags_only(scene, persona, outfit_tags) + desc = _trim_redundant_scene_description( + (scene.get("scene_description") or "").strip(), + base, + ) + desc = _cap_scene_description(desc) + if not desc: + desc = _cap_scene_description(_fallback_mood_prose(scene)) + if not desc: + return base + + lora = (persona or {}).get("lora_name", "") + weight = (persona or {}).get("lora_weight", 0.8) + lora_suffix = f" " if lora else "" + if lora_suffix and base.endswith(lora_suffix): + base = base[: -len(lora_suffix)] + return f"{base}. {desc}{lora_suffix}" + return f"{base}. {desc}" + + +def build_positive_prompt(scene: dict, persona: dict | None, outfit_tags: str = "") -> str: + """Legacy entry: Pony/non-Anima full prompt; Anima delegates to tags-only.""" + if _is_anima(): + return build_positive_prompt_tags_only(scene, persona, outfit_tags) + + parts = [_quality_prefix()] + appearance = _appearance_for_persona(persona) + if appearance: + parts.append(appearance) + if outfit_tags: + parts.append(_sanitize_tags_string(_dedupe_outfit_tags(outfit_tags))) + if scene.get("shot_type") == "landscape": + parts.append(_sanitize_tags_string(scene.get("environment_tags", ""))) + else: + if scene.get("shot_type") == "first_person_pov": + parts.append("pov, first-person view, looking at viewer") + parts.append(_sanitize_tags_string(scene.get("action_tags", ""))) + parts.append(_sanitize_tags_string(scene.get("environment_tags", ""))) + _append_lora(parts, persona) + return _dedupe_comma_join(parts) + + +def _negative_for_scene(scene: dict) -> str: if _is_pony(): negative = PONY_NEGATIVE elif _is_anima(): @@ -151,6 +503,228 @@ async def generate_sd_prompt( if scene.get("shot_type") == "first_person_pov": negative += ", third person, over the shoulder" + viewer_visible = scene.get("viewer_body_visible") is True + if not viewer_visible or _scene_has_physical_contact(scene): + negative += ", " + POV_INTERACTION_NEGATIVE - full = positive + f"\n\nNegative prompt: {negative}" - return full, negative + return negative + + +def _format_builder_user_block(persona: dict, messages: list[dict], outfit_json: str) -> str: + lines: list[str] = [] + tags = (persona.get("appearance_tags") or "").strip() + lines.append(f"Character appearance (tags): {tags}") + prose = (persona.get("appearance_prose") or "").strip() + if _is_anima() and prose and prose != tags: + snippet = prose[:300] + ("..." if len(prose) > 300 else "") + lines.append(f"Character notes (do not copy into tags or scene_description): {snippet}") + + try: + outfit_list = json.loads(outfit_json or "[]") + outfit_ref = ", ".join(outfit_list) if isinstance(outfit_list, list) else "" + except Exception: + outfit_ref = "" + + if outfit_ref: + lines.append(f"Current outfit (tags): {outfit_ref}") + + recent = [m for m in messages if m.get("role") in ("user", "assistant")][-6:] + if not recent: + lines.append("\nChat:\n(no messages — return should_generate=false)") + return "\n".join(lines) + + illustrate: list[dict] = [] + if recent[-1]["role"] == "assistant": + illustrate = [recent[-1]] + if len(recent) >= 2 and recent[-2]["role"] == "user": + illustrate.insert(0, recent[-2]) + else: + illustrate = [recent[-1]] + if len(recent) >= 2 and recent[-2]["role"] == "assistant": + illustrate.insert(0, recent[-2]) + + context = [m for m in recent if m not in illustrate] + + lines.append("\n=== ILLUSTRATE (draw THIS beat only) ===") + for m in illustrate: + raw = m.get("content", "") + content = _extract_illustrate_content(raw) if m.get("role") == "assistant" else strip_image_prompt_tag(raw) + lines.append(f"{m['role']}: {content}") + + if context: + lines.append("\n=== Context (outfit/location hints only — do not illustrate old beats) ===") + for m in context: + content = strip_image_prompt_tag(m.get("content", "")) + if len(content) > 800: + content = content[:797] + "..." + lines.append(f"{m['role']}: {content}") + + return "\n".join(lines) + + +def _parse_scene_json(raw: str) -> dict: + cleaned = raw.strip() + if cleaned.startswith("```"): + cleaned = re.sub(r"^```\w*\n?", "", cleaned) + cleaned = re.sub(r"\n?```$", "", cleaned) + scene = json.loads(cleaned) + if not isinstance(scene, dict): + raise ValueError("LLM returned non-object JSON") + return _normalize_shot_type(scene) + + +def _bundle_from_scene(scene: dict, persona: dict, outfit_tags: str) -> SdPromptBundle: + negative = _negative_for_scene(scene) + if _is_anima(): + hybrid = build_positive_prompt_hybrid(scene, persona, outfit_tags) + tag_full = hybrid + NEGATIVE_PROMPT_SEPARATOR + negative + desc_full = None + if anima_dual_enabled(): + tags_only = build_positive_prompt_tags_only(scene, persona, outfit_tags) + desc_full = tags_only + NEGATIVE_PROMPT_SEPARATOR + negative + return SdPromptBundle(tag_full=tag_full, negative=negative, desc_full=desc_full) + + positive = build_positive_prompt(scene, persona, outfit_tags) + tag_full = positive + NEGATIVE_PROMPT_SEPARATOR + negative + return SdPromptBundle(tag_full=tag_full, negative=negative, desc_full=None) + + +def _parse_chat_excerpt(excerpt: str) -> list[dict]: + messages: list[dict] = [] + for line in (excerpt or "").splitlines(): + line = line.strip() + if not line: + continue + lower = line.lower() + if lower.startswith("user:"): + messages.append({"role": "user", "content": line[5:].strip()}) + elif lower.startswith("assistant:"): + messages.append({"role": "assistant", "content": line[10:].strip()}) + elif lower.startswith("system:"): + messages.append({"role": "system", "content": line[7:].strip()}) + else: + messages.append({"role": "user", "content": line}) + return messages + + +async def run_prompt_builder( + persona_id: str, + *, + messages: list[dict] | None = None, + chat_excerpt: str = "", + outfit_json: str = "[]", + appearance_override: str | None = None, + use_prose: bool = False, +) -> dict: + """Debug: full SD prompt builder pipeline with LLM raw output.""" + persona = await get_persona(persona_id) or {} + if appearance_override is not None: + persona = {**persona, "appearance_tags": appearance_override} + + recent = messages if messages is not None else _parse_chat_excerpt(chat_excerpt) + recent = [m for m in recent if m.get("role") in ("user", "assistant")] + + user_block = _format_builder_user_block(persona, recent, outfit_json) + builder_messages = [ + {"role": "system", "content": _builder_system()}, + {"role": "user", "content": user_block}, + ] + model_used = SD_PROMPT_MODEL or "SYSTEM_MODEL" + result: dict = { + "persona_id": persona_id, + "sd_prompt_model": model_used, + "builder_system": _builder_system(), + "builder_user": user_block, + "anima_dual": anima_dual_enabled(), + } + + raw = "" + try: + if SD_PROMPT_MODEL: + raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL) + else: + raw = await send_message(builder_messages) + result["llm_raw"] = raw + scene = _parse_scene_json(raw) + result["scene"] = scene + + if not _scene_should_generate(scene): + result["skipped"] = True + result["error"] = "should_generate=false" + return result + + try: + outfit_tags = ", ".join(json.loads(outfit_json or "[]")) + except Exception: + outfit_tags = "" + + negative = _negative_for_scene(scene) + if _is_anima(): + tags_only = build_positive_prompt_tags_only(scene, persona, outfit_tags) + hybrid = build_positive_prompt_hybrid(scene, persona, outfit_tags) + result["tag_positive"] = tags_only + result["hybrid_positive"] = hybrid + result["negative"] = negative + result["tags_only_full"] = tags_only + NEGATIVE_PROMPT_SEPARATOR + negative + result["hybrid_full"] = hybrid + NEGATIVE_PROMPT_SEPARATOR + negative + result["tag_full"] = result["hybrid_full"] + else: + positive = build_positive_prompt(scene, persona, outfit_tags) + result["tag_positive"] = positive + result["negative"] = negative + result["tag_full"] = positive + NEGATIVE_PROMPT_SEPARATOR + negative + except Exception as e: + result["error"] = str(e) + result["llm_raw"] = raw or result.get("llm_raw", "") + + return result + + +async def generate_sd_prompt( + messages: list, + persona_id: str, + outfit_json: str = "[]", +) -> SdPromptBundle | None: + persona = await get_persona(persona_id) + if not persona: + return None + + recent = [m for m in messages if m["role"] in ("user", "assistant")] + if not recent: + return None + + user_block = _format_builder_user_block(persona, recent, outfit_json) + builder_messages = [ + {"role": "system", "content": _builder_system()}, + {"role": "user", "content": user_block}, + ] + + raw = "" + try: + if SD_PROMPT_MODEL: + raw = await send_message_with_model(builder_messages, SD_PROMPT_MODEL) + else: + raw = await send_message(builder_messages) + scene = _parse_scene_json(raw) + except Exception as e: + logger.warning("sd_prompt failed: %s raw=%.200s", e, raw) + return None + + if not _scene_should_generate(scene): + logger.info("sd_prompt: skipped (should_generate=false)") + return None + + try: + outfit_list = json.loads(outfit_json or "[]") + outfit_tags = ", ".join(outfit_list) if isinstance(outfit_list, list) else "" + except Exception: + outfit_tags = "" + + bundle = _bundle_from_scene(scene, persona, outfit_tags) + if anima_dual_enabled() and bundle.desc_full: + logger.info( + "Anima prompts: hybrid=%.80s | tags_only=%.80s", + bundle.tag_full.split(NEGATIVE_PROMPT_SEPARATOR)[0], + bundle.desc_full.split(NEGATIVE_PROMPT_SEPARATOR)[0], + ) + return bundle diff --git a/services/sdbackend.py b/services/sdbackend.py index aa3874f..3b3bc10 100644 --- a/services/sdbackend.py +++ b/services/sdbackend.py @@ -3,6 +3,7 @@ import logging import os import uuid from pathlib import Path +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse import httpx from dotenv import load_dotenv @@ -11,7 +12,178 @@ load_dotenv() logger = logging.getLogger(__name__) -SD_BASE_URL = os.getenv("SD_BASE_URL", "http://127.0.0.1:8188").rstrip("/") + +def _parse_basic_auth() -> httpx.BasicAuth | None: + """ + Vast Caddy on mapped ports often uses Basic realm=restricted. + Set SD_COMFY_HTTP_BASIC=user:password or SD_COMFY_USER + SD_COMFY_PASSWORD. + """ + raw = (os.getenv("SD_COMFY_HTTP_BASIC") or "").strip() + if raw: + if ":" in raw: + user, _, password = raw.partition(":") + else: + user, password = "", raw + return httpx.BasicAuth(user, password) + user = (os.getenv("SD_COMFY_USER") or "").strip() + password = (os.getenv("SD_COMFY_PASSWORD") or "").strip() + if user or password: + return httpx.BasicAuth(user, password) + return None + + +SD_BASIC_AUTH = _parse_basic_auth() + + +def _parse_comfy_config() -> tuple[str, dict[str, str]]: + """ + SD_BASE_URL may be pasted from Vast/Comfy UI with ?token=... + API paths must be base + /prompt, not ...?token=xxx/prompt + """ + raw = (os.getenv("SD_BASE_URL") or "http://127.0.0.1:8188").strip() + extra_token = (os.getenv("SD_COMFY_TOKEN") or "").strip() + parsed = urlparse(raw) + base = f"{parsed.scheme}://{parsed.netloc}" + path = (parsed.path or "").rstrip("/") + if path and path != "/": + base = f"{base}{path}" + query: dict[str, str] = {} + for key, values in parse_qs(parsed.query).items(): + if values: + query[key] = values[-1] + if extra_token: + query["token"] = extra_token + base = base.rstrip("/") + # Cloudflare tunnel to localhost:8188 — direct Comfy API, Vast ?token= does not apply + if "trycloudflare.com" in base.lower(): + if query.pop("token", None): + logger.info( + "SD_BASE_URL is trycloudflare tunnel: Vast token stripped. " + "Use tunnel for port 8188 only (see instance Port Mapping)." + ) + return base, query + + +SD_BASE_URL, SD_QUERY_PARAMS = _parse_comfy_config() + + +def _comfy_url(path: str) -> str: + if not path.startswith("/"): + path = f"/{path}" + return f"{SD_BASE_URL}{path}" + + +def _log_comfy_target() -> str: + if SD_QUERY_PARAMS.get("token"): + return f"{SD_BASE_URL}?token=***" + return SD_BASE_URL + + +def _absolute_url(location: str, fallback_path: str = "/") -> str: + if not location: + return _comfy_url(fallback_path) + if location.startswith(("http://", "https://")): + return location + if location.startswith("/"): + return f"{SD_BASE_URL}{location}" + return f"{SD_BASE_URL}/{location}" + + +def _url_with_token(url: str) -> str: + """Append gateway token to URL (Vast/Cloudflare often strip ?token on redirect).""" + if not SD_QUERY_PARAMS.get("token"): + return url + p = urlparse(url) + q: dict[str, str] = {} + for key, values in parse_qs(p.query).items(): + if values: + q[key] = values[-1] + q.update(SD_QUERY_PARAMS) + return urlunparse((p.scheme, p.netloc, p.path, "", urlencode(q), "")) + + +def _merge_params(extra: dict | None) -> dict | None: + if not SD_QUERY_PARAMS and not extra: + return None + merged = dict(SD_QUERY_PARAMS) + if extra: + merged.update(extra) + return merged + + +def _is_vast_gateway() -> bool: + return "trycloudflare.com" not in SD_BASE_URL.lower() + + +def _make_comfy_client(*, timeout: float = 300) -> httpx.AsyncClient: + return httpx.AsyncClient( + timeout=timeout, + follow_redirects=False, + auth=SD_BASIC_AUTH, + ) + + +async def _prime_comfy_gateway(client: httpx.AsyncClient) -> None: + """ + Vast Caddy: browser opens /?token=… and gets a session cookie; API then works. + Prime with redirects so Set-Cookie is collected, then merge into the API client. + """ + token = SD_QUERY_PARAMS.get("token") + if not token or not _is_vast_gateway(): + return + try: + async with httpx.AsyncClient( + timeout=30, + follow_redirects=True, + auth=SD_BASIC_AUTH, + ) as prime: + r = await prime.get(_comfy_url("/"), params={"token": token}) + client.cookies.update(prime.cookies) + logger.info( + "Comfy gateway prime GET /?token=*** → %s, cookies=%s", + r.status_code, + list(prime.cookies.keys()) or "(none)", + ) + except Exception as e: + logger.warning("Comfy gateway prime failed: %s", e) + + +async def _comfy_request( + client: httpx.AsyncClient, + method: str, + path: str, + *, + params: dict | None = None, + **kwargs, +) -> httpx.Response: + """ + Comfy API: trycloudflare tunnel = no token. + Vast IP:PORT gateway = ?token= + cookie prime; follow redirects with token re-attached. + """ + url = _comfy_url(path) + extra = params or {} + token = SD_QUERY_PARAMS.get("token") + use_vast_auth = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None) + + if token and _is_vast_gateway(): + await _prime_comfy_gateway(client) + + req_params: dict | None = _merge_params(extra) if use_vast_auth else (extra or None) + resp: httpx.Response | None = None + + for hop in range(6): + resp = await client.request(method, url, params=req_params, **kwargs) + if resp.status_code not in (301, 302, 303, 307, 308): + return resp + loc = _absolute_url(resp.headers.get("location", ""), path) + url = _url_with_token(loc) if use_vast_auth else loc + req_params = extra or None + logger.info("Comfy redirect %s hop %s → %s", resp.status_code, hop + 1, url.split("?")[0]) + + assert resp is not None + return resp + + SD_STEPS = int(os.getenv("SD_STEPS", "28")) SD_CFG = float(os.getenv("SD_CFG", "7")) SD_SAMPLER = os.getenv("SD_SAMPLER", "euler") @@ -26,6 +198,8 @@ SD_DEFAULT_NEGATIVE = os.getenv( SD_UNET = os.getenv("SD_UNET", "anima-preview3-base.safetensors") SD_CLIP = os.getenv("SD_CLIP", "qwen_3_06b_base.safetensors") SD_VAE = os.getenv("SD_VAE", "qwen_image_vae.safetensors") +SD_STYLE_LORA = os.getenv("SD_STYLE_LORA", "") +SD_STYLE_LORA_WEIGHT = float(os.getenv("SD_STYLE_LORA_WEIGHT", "0.7")) IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images")) @@ -38,19 +212,37 @@ def _use_anima() -> bool: def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]: + # Try new separator first + sep = "__NEGATIVE_PROMPT__" + if f"\n{sep}\n" in full_prompt: + pos, _, neg = full_prompt.partition(f"\n{sep}\n") + return pos.strip(), neg.strip() + # Fallback to old format if "\n\nNegative prompt:" in full_prompt: pos, _, neg = full_prompt.partition("\n\nNegative prompt:") return pos.strip(), neg.strip() return full_prompt.strip(), SD_DEFAULT_NEGATIVE -def _build_workflow(positive: str, negative: str) -> dict: +def _workflow_uses_anima(overrides: dict | None) -> bool: + if overrides and overrides.get("checkpoint"): + return False + if overrides and overrides.get("unet"): + return True + return _use_anima() + + +def _build_workflow(positive: str, negative: str, overrides: dict | None = None) -> dict: seed = int(uuid.uuid4().int % 2**32) - if _use_anima(): - return { - "44": {"class_type": "UNETLoader", "inputs": {"unet_name": SD_UNET, "weight_dtype": "default"}}, - "45": {"class_type": "CLIPLoader", "inputs": {"clip_name": SD_CLIP, "type": "stable_diffusion", "device": "default"}}, - "15": {"class_type": "VAELoader", "inputs": {"vae_name": SD_VAE}}, + o = overrides or {} + if _workflow_uses_anima(o): + unet = o.get("unet") or SD_UNET + clip = o.get("clip") or SD_CLIP + vae = o.get("vae") or SD_VAE + workflow = { + "44": {"class_type": "UNETLoader", "inputs": {"unet_name": unet, "weight_dtype": "default"}}, + "45": {"class_type": "CLIPLoader", "inputs": {"clip_name": clip, "type": "stable_diffusion", "device": "default"}}, + "15": {"class_type": "VAELoader", "inputs": {"vae_name": vae}}, "28": {"class_type": "EmptyLatentImage", "inputs": {"width": 1024, "height": 1024, "batch_size": 1}}, "11": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["45", 0]}}, "12": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["45", 0]}}, @@ -68,9 +260,24 @@ def _build_workflow(positive: str, negative: str) -> dict: "8": {"class_type": "VAEDecode", "inputs": {"samples": ["19", 0], "vae": ["15", 0]}}, "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}}, } - # Standard checkpoint workflow (Pony / SDXL) + if SD_STYLE_LORA: + workflow["46"] = { + "class_type": "LoraLoader", + "inputs": { + "lora_name": SD_STYLE_LORA, + "model": ["44", 0], + "clip": ["45", 0], + "strength_model": SD_STYLE_LORA_WEIGHT, + "strength_clip": SD_STYLE_LORA_WEIGHT, + }, + } + workflow["19"]["inputs"]["model"] = ["46", 0] + workflow["11"]["inputs"]["clip"] = ["46", 1] + workflow["12"]["inputs"]["clip"] = ["46", 1] + return workflow + ckpt = o.get("checkpoint") or SD_CHECKPOINT return { - "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": SD_CHECKPOINT}}, + "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt}}, "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}}, "6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}}, "7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}}, @@ -89,24 +296,78 @@ def _build_workflow(positive: str, negative: str) -> dict: } +async def comfy_api_request( + method: str, + path: str, + *, + params: dict | None = None, + json_body: dict | None = None, + timeout: float = 60, +) -> tuple[int, dict | str, dict]: + """ + Raw Comfy API call for debug. Returns (status_code, parsed_json_or_text, response_headers_subset). + """ + async with _make_comfy_client(timeout=timeout) as client: + await _prime_comfy_gateway(client) + token = SD_QUERY_PARAMS.get("token") + use_vast = _is_vast_gateway() and (bool(token) or SD_BASIC_AUTH is not None) + req_params = _merge_params(params) if use_vast else (params or None) + req_kwargs: dict = {} + if json_body is not None and method.upper() not in ("GET", "HEAD"): + req_kwargs["json"] = json_body + resp = await _comfy_request( + client, + method.upper(), + path, + params=req_params, + **req_kwargs, + ) + headers = { + k: resp.headers.get(k) + for k in ("content-type", "location", "www-authenticate") + if resp.headers.get(k) + } + try: + body = resp.json() + except Exception: + body = resp.text[:8000] + return resp.status_code, body, headers + + +async def fetch_object_info() -> dict: + status, body, _ = await comfy_api_request("GET", "/object_info", timeout=120) + if status != 200 or not isinstance(body, dict): + raise RuntimeError(f"object_info failed: HTTP {status} {body!s:.300}") + return body + + async def check_sd() -> bool: try: - async with httpx.AsyncClient(timeout=5) as client: - r = await client.get(f"{SD_BASE_URL}/system_stats") + async with _make_comfy_client(timeout=15) as client: + await _prime_comfy_gateway(client) + r = await _comfy_request(client, "GET", "/system_stats") return r.status_code == 200 except Exception: return False -async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[bytes, str]: +async def txt2img( + prompt: str, + negative_prompt: str | None = None, + *, + overrides: dict | None = None, +) -> tuple[bytes, str]: neg = negative_prompt or SD_DEFAULT_NEGATIVE - workflow = _build_workflow(prompt, neg) + workflow = _build_workflow(prompt, neg, overrides) client_id = uuid.uuid4().hex - logger.info("ComfyUI request → %s prompt: %.120s", SD_BASE_URL, prompt) - async with httpx.AsyncClient(timeout=300) as client: - resp = await client.post( - f"{SD_BASE_URL}/prompt", + logger.info("ComfyUI request → %s prompt: %.120s", _log_comfy_target(), prompt) + async with _make_comfy_client() as client: + await _prime_comfy_gateway(client) + resp = await _comfy_request( + client, + "POST", + "/prompt", json={"prompt": workflow, "client_id": client_id}, ) resp.raise_for_status() @@ -115,7 +376,7 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte for _ in range(300): await asyncio.sleep(1) - hist = await client.get(f"{SD_BASE_URL}/history/{prompt_id}") + hist = await _comfy_request(client, "GET", f"/history/{prompt_id}") data = hist.json() if prompt_id in data: entry = data[prompt_id] @@ -127,9 +388,15 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte for node_output in outputs.values(): if "images" in node_output: img_info = node_output["images"][0] - img_resp = await client.get( - f"{SD_BASE_URL}/view", - params={"filename": img_info["filename"], "subfolder": img_info.get("subfolder", ""), "type": img_info.get("type", "output")}, + img_resp = await _comfy_request( + client, + "GET", + "/view", + params={ + "filename": img_info["filename"], + "subfolder": img_info.get("subfolder", ""), + "type": img_info.get("type", "output"), + }, ) img_resp.raise_for_status() image_bytes = img_resp.content @@ -145,11 +412,43 @@ async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[byte raise RuntimeError("ComfyUI generation timed out or produced no output") -async def generate_from_full_prompt(full_prompt: str) -> tuple[str | None, str | None]: +async def generate_from_full_prompt( + full_prompt: str, + *, + overrides: dict | None = None, +) -> tuple[str | None, str | None]: positive, negative = split_prompt_and_negative(full_prompt) try: - _, rel_path = await txt2img(positive, negative) + _, rel_path = await txt2img(positive, negative, overrides=overrides) return rel_path, None + except httpx.HTTPStatusError as e: + code = e.response.status_code + if code == 401: + logger.error( + "ComfyUI 401: Vast Caddy needs SD_COMFY_TOKEN (or ?token= in SD_BASE_URL) " + "and/or SD_COMFY_HTTP_BASIC=user:pass from the instance page. " + "Test: curl -u user:pass http://IP:PORT/system_stats " + "or open /?token=… in browser then curl with cookies. " + "Alternative: trycloudflare URL for localhost:8188 in Port Mapping." + ) + elif code in (301, 302, 303, 307, 308): + logger.error( + "ComfyUI %s: wrong URL — use trycloudflare tunnel for 8188, not web UI link. " + "SD_BASE_URL=https://reviewer-relief-edmonton-specializing.trycloudflare.com " + "(no ?token=). Location: %s", + code, + e.response.headers.get("location"), + ) + else: + logger.error("ComfyUI HTTP %s: %s", code, e) + return None, str(e) + except httpx.ConnectError as e: + logger.error( + "ComfyUI connect failed (%s): IP:8188 is often not exposed on Vast. " + "Use trycloudflare URL from Port Mapping for localhost:8188.", + e, + ) + return None, str(e) except Exception as e: logger.error("ComfyUI error: %s", e) return None, str(e) diff --git a/services/session_identity.py b/services/session_identity.py new file mode 100644 index 0000000..671ae6f --- /dev/null +++ b/services/session_identity.py @@ -0,0 +1,31 @@ +import logging + +from services.memory import get_session + +logger = logging.getLogger(__name__) + + +async def resolve_session_persona( + session_id: str, + requested: str | None = None, + *, + create_persona: str | None = None, +) -> str: + """ + Session.persona_id is the source of truth. + requested is ignored when it disagrees (logged). create_persona used only if session missing. + """ + session = await get_session(session_id) + if not session: + return (create_persona or requested or "default").strip() or "default" + + bound = (session.get("persona_id") or "default").strip() or "default" + req = (requested or "").strip() + if req and req != bound: + logger.warning( + "persona_id mismatch session=%s bound=%s requested=%s (using bound)", + session_id, + bound, + req, + ) + return bound diff --git a/services/system_message_migration.py b/services/system_message_migration.py new file mode 100644 index 0000000..4eef6fd --- /dev/null +++ b/services/system_message_migration.py @@ -0,0 +1,21 @@ +import logging + +from services.chat_prompt import get_system_prompt +from services.memory import get_all_sessions, get_history, upsert_static_system_message + +logger = logging.getLogger(__name__) + + +async def migrate_static_system_messages() -> int: + """Rebuild stored system rows from sessions.persona_id (strip legacy RPG text).""" + updated = 0 + for session in await get_all_sessions(): + sid = session["session_id"] + persona_id = session.get("persona_id") or "default" + history = await get_history(sid) + static = await get_system_prompt(persona_id, history, "") + if await upsert_static_system_message(sid, static, history): + updated += 1 + if updated: + logger.info("Migrated %s session system message(s) to static persona prompt", updated) + return updated diff --git a/static/avatars/card_carrie_0578318a_dbee17bc.png b/static/avatars/card_carrie_0578318a_dbee17bc.png new file mode 100644 index 0000000..8ef6eb0 Binary files /dev/null and b/static/avatars/card_carrie_0578318a_dbee17bc.png differ diff --git a/static/avatars/card_carrie_926629e6_09f02497.png b/static/avatars/card_carrie_926629e6_09f02497.png new file mode 100644 index 0000000..8ef6eb0 Binary files /dev/null and b/static/avatars/card_carrie_926629e6_09f02497.png differ diff --git a/static/avatars/card_carrie_926629e6_8b9b8891.png b/static/avatars/card_carrie_926629e6_8b9b8891.png new file mode 100644 index 0000000..8ef6eb0 Binary files /dev/null and b/static/avatars/card_carrie_926629e6_8b9b8891.png differ diff --git a/static/avatars/card_carrie_926629e6_e01546a5.png b/static/avatars/card_carrie_926629e6_e01546a5.png new file mode 100644 index 0000000..8ef6eb0 Binary files /dev/null and b/static/avatars/card_carrie_926629e6_e01546a5.png differ diff --git a/static/css/app.css b/static/css/app.css index 6930b96..960be79 100644 --- a/static/css/app.css +++ b/static/css/app.css @@ -283,7 +283,16 @@ header h1 { font-size: 1.1rem; color: #e94560; } .translate-btn:hover { background: #4a90d9; color: white; } .translate-btn:disabled { opacity: 0.5; cursor: default; } -.chat-image { margin-top: 8px; max-width: 100%; border-radius: 8px; border: 1px solid #0f3460; } +.chat-image-wrap { margin-top: 8px; } +.chat-image-label { + font-size: 0.75rem; + color: #888; + margin-bottom: 4px; + text-transform: uppercase; + letter-spacing: 0.04em; +} +.chat-image { max-width: 100%; border-radius: 8px; border: 1px solid #0f3460; display: block; } +.image-prompt-blocks .image-prompt-block + .image-prompt-block { margin-top: 8px; } .image-generating { display: flex; @@ -587,6 +596,11 @@ textarea:focus { border-color: #e94560; } flex-direction: row !important; padding: 8px 0; } +.hint-text { + font-size: 0.8rem; + color: #888; + margin: 0 0 8px; +} .chat-settings-meta { margin-top: 12px; padding: 10px; background: #1a1a2e; border-radius: 8px; diff --git a/static/css/debug.css b/static/css/debug.css new file mode 100644 index 0000000..e0bcf19 --- /dev/null +++ b/static/css/debug.css @@ -0,0 +1,209 @@ +/* app.css sets body { overflow: hidden; height: 100vh } for chat layout */ +html:has(body.debug-page), +body.debug-page { + height: auto; + min-height: 100vh; + overflow-x: hidden; + overflow-y: auto; +} + +.debug-page { + background: #0f0f1a; + color: #ddd; + min-height: 100vh; + padding-bottom: 48px; +} + +.debug-header { + display: flex; + align-items: center; + gap: 16px; + padding: 12px 20px; + border-bottom: 1px solid #1a2744; + background: #16213e; +} + +.debug-header a { + color: #9b7fd4; + text-decoration: none; +} + +.debug-header h1 { + flex: 1; + margin: 0; + font-size: 1.1rem; +} + +.debug-tabs { + display: flex; + gap: 4px; + padding: 8px 16px; + background: #12121f; + border-bottom: 1px solid #1a2744; + flex-wrap: wrap; +} + +.debug-tabs button { + background: transparent; + border: 1px solid #2a3a5c; + color: #aaa; + padding: 8px 14px; + border-radius: 8px; + cursor: pointer; +} + +.debug-tabs button.active { + background: #1a2744; + color: #e94560; + border-color: #e94560; +} + +.debug-main { + padding: 16px 20px 40px; + max-width: 1200px; + margin: 0 auto; + overflow: visible; +} + +.debug-panel { + display: none; +} + +.debug-panel.active { + display: block; +} + +.debug-grid { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); + gap: 12px; + margin-bottom: 12px; +} + +.debug-grid label, +.debug-main > label { + display: flex; + flex-direction: column; + gap: 4px; + font-size: 0.85rem; + color: #aaa; + margin-bottom: 10px; +} + +.debug-grid input, +.debug-grid select, +.debug-main textarea, +.debug-main input, +.debug-main select { + background: #1a1a2e; + border: 1px solid #0f3460; + color: #eee; + border-radius: 6px; + padding: 8px; + font-family: inherit; +} + +.debug-main textarea { + width: 100%; + box-sizing: border-box; + font-family: ui-monospace, monospace; + font-size: 0.85rem; +} + +.debug-btn { + background: #1a2744; + border: 1px solid #3a5080; + color: #ccc; + padding: 8px 16px; + border-radius: 8px; + cursor: pointer; + margin-bottom: 12px; +} + +.debug-btn.primary { + background: #e94560; + border-color: #e94560; + color: #fff; +} + +.debug-btn:hover { + filter: brightness(1.1); +} + +.debug-row { + display: flex; + flex-wrap: wrap; + gap: 8px; + margin-bottom: 8px; +} + +.debug-out { + background: #0a0a14; + border: 1px solid #1a2744; + border-radius: 8px; + padding: 12px; + overflow: auto; + max-height: 420px; + font-size: 0.8rem; + white-space: pre-wrap; + word-break: break-word; +} + +.debug-out.compact { + max-height: 160px; +} + +.debug-out.small { + max-height: 240px; +} + +.debug-split { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 12px; +} + +@media (max-width: 800px) { + .debug-split { + grid-template-columns: 1fr; + } +} + +.debug-split h3, +.debug-main h3 { + font-size: 0.9rem; + color: #9b7fd4; + margin: 16px 0 8px; +} + +.debug-img-wrap { + margin: 12px 0; +} + +.debug-img-wrap img { + max-width: 100%; + max-height: 512px; + border-radius: 8px; + border: 1px solid #333; +} + +.debug-img-wrap.hidden { + display: none; +} + +.model-list-block { + margin-bottom: 8px; +} + +.model-list-block summary { + cursor: pointer; + color: #9b7fd4; +} + +.model-list-block ul { + margin: 4px 0 0; + padding-left: 1.2rem; + font-size: 0.8rem; + max-height: 120px; + overflow: auto; +} diff --git a/static/debug.html b/static/debug.html new file mode 100644 index 0000000..6fff947 --- /dev/null +++ b/static/debug.html @@ -0,0 +1,134 @@ + + + + + + Debug — AI ChatBot + + + + +
+ ← Чат +

Debug

+ +
+ + + +
+
+
Загрузка…
+
+ +
+
+ + + + +
+ + +
+
+

Scene JSON

+
+
+
+

Теги / гибрид

+
+
+
+
+ LLM raw + builder +
+
+
+ +
+
+ +
+ + + +
+
+ +
+
+ + +
+
+ +

Модели в Comfy

+
+ +

Генерация

+
+ + + + +
+ + + + +
+ +

Raw API

+
+ + +
+ + + +
+
+
+ + + + diff --git a/static/index.html b/static/index.html index 98b9646..88f5572 100644 --- a/static/index.html +++ b/static/index.html @@ -13,6 +13,7 @@

🤖 AI Chat

Новый чат + 🛠 @@ -314,6 +315,9 @@ +

Персонаж чата

+

Смена персонажа перепривязывает этот чат. Историю можно сохранить или очистить.

+
@@ -346,6 +350,6 @@ - + diff --git a/static/js/chat.js b/static/js/chat.js index 02266d8..5bda259 100644 --- a/static/js/chat.js +++ b/static/js/chat.js @@ -1,9 +1,9 @@ import { sessionId, currentPersona, dom } from './state.js'; -import { parseImagePromptFromContent, copyToClipboard } from './utils.js'; +import { parseImagePromptFromContent, copyToClipboard, splitSdPromptForCopy } from './utils.js'; export async function initChat(options = {}) { - if (!sessionId || !currentPersona) return; - const payload = { message: '', session_id: sessionId, persona_id: currentPersona }; + if (!sessionId) return; + const payload = { message: '', session_id: sessionId }; if (options.first_mes_override?.trim()) payload.first_mes_override = options.first_mes_override.trim(); const res = await fetch('/chat/init', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(payload) }); if (!res.ok) return; @@ -16,19 +16,22 @@ export function updateEmptyState() { dom.emptyState?.classList.toggle('hidden', !!hasMessages); } -export function createImagePromptBlock(promptText) { +function createImagePromptBlockSingle(label, promptText) { const block = document.createElement('div'); block.className = 'image-prompt-block'; const header = document.createElement('div'); header.className = 'image-prompt-header'; - header.innerHTML = '🎨 SD prompt'; + header.innerHTML = `🎨 ${label}`; const copyBtn = document.createElement('button'); copyBtn.type = 'button'; copyBtn.className = 'copy-prompt-btn'; copyBtn.textContent = 'Копировать'; - copyBtn.addEventListener('click', async () => { - const ok = await copyToClipboard(promptText); + copyBtn.addEventListener('click', async (e) => { + e.preventDefault(); + e.stopPropagation(); + const full = textEl.textContent?.trim() || promptText || ''; + const ok = await copyToClipboard(splitSdPromptForCopy(full)); copyBtn.textContent = ok ? 'Скопировано' : 'Ошибка'; setTimeout(() => { copyBtn.textContent = 'Копировать'; }, 1500); }); @@ -39,11 +42,10 @@ export function createImagePromptBlock(promptText) { regenBtn.className = 'copy-prompt-btn'; regenBtn.textContent = '🖼 Перегенерировать'; regenBtn.addEventListener('click', async () => { - const wrapper = block.parentElement; + const wrapper = block.closest('.message'); regenBtn.disabled = true; regenBtn.textContent = '⏳…'; - wrapper?.querySelector('.chat-image')?.remove(); - wrapper?.querySelector('.image-error')?.remove(); + wrapper?.querySelectorAll('.chat-image-wrap, .chat-image, .image-error').forEach(el => el.remove()); showImageGenerating(wrapper); try { const res = await fetch('/images/generate', { @@ -76,6 +78,26 @@ export function createImagePromptBlock(promptText) { return block; } +export function createImagePromptBlock(promptText, promptAlt = null) { + const wrap = document.createElement('div'); + wrap.className = 'image-prompt-blocks'; + wrap.appendChild(createImagePromptBlockSingle('SD prompt', promptText)); + const alt = (promptAlt || '').trim(); + const main = (promptText || '').trim(); + if (alt && alt !== main) { + wrap.appendChild(createImagePromptBlockSingle('SD prompt (только теги)', promptAlt)); + } + return wrap; +} + +/** Replace or create tag + optional hybrid prompt blocks under a message. */ +export function ensureImagePromptBlocks(wrapper, tagPrompt, altPrompt = null) { + if (!wrapper || !tagPrompt) return; + wrapper.querySelector('.image-prompt-blocks')?.remove(); + wrapper.querySelectorAll('.image-prompt-block').forEach(el => el.remove()); + wrapper.appendChild(createImagePromptBlock(tagPrompt, altPrompt || null)); +} + const OUTCOME_CLASS = { 'critical failure': 'outcome-crit-fail', 'failure': 'outcome-fail', @@ -113,7 +135,7 @@ function renderNarratorMessage(narrator) { return el; } -function renderChoices(wrapper, choices) { +export function renderChoices(wrapper, choices) { if (!choices?.length) return; const row = document.createElement('div'); row.className = 'choice-row'; @@ -169,12 +191,21 @@ export function updateAffinityDisplay(affinity) { el.className = `affinity-display ${affinity > 5 ? 'affinity-high' : affinity < -3 ? 'affinity-low' : ''}`; } -export function appendChatImage(wrapper, imagePath) { +export function appendChatImage(wrapper, imagePath, label = '') { if (!imagePath) return; + const figure = document.createElement('figure'); + figure.className = 'chat-image-wrap'; + if (label) { + const cap = document.createElement('figcaption'); + cap.className = 'chat-image-label'; + cap.textContent = label; + figure.appendChild(cap); + } const img = document.createElement('img'); img.className = 'chat-image'; img.src = imagePath; - wrapper.appendChild(img); + figure.appendChild(img); + wrapper.appendChild(figure); } export function showImageGenerating(wrapper) { @@ -262,7 +293,7 @@ async function regenerateMessage(messageId, wrapper) { const res = await fetch('/chat/regenerate', { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ session_id: sessionId, persona_id: currentPersona, message_id: messageId }), + body: JSON.stringify({ session_id: sessionId, message_id: messageId }), }); if (!res.ok) throw new Error('Ошибка: ' + res.status); removeTyping(); @@ -303,6 +334,8 @@ export async function reloadChatFromServer(id) { m.image_prompt, m.image_path ? `/static/${m.image_path}` : null, m.id, + m.image_prompt_alt, + m.image_path_alt ? `/static/${m.image_path_alt}` : null, ); }); } @@ -344,8 +377,12 @@ async function consumeStream(res) { if (data.image_generating && bubble) { bubble.classList.remove('typing-active'); const wrapper = bubble.parentElement; - if (data.image_prompt && !wrapper.querySelector('.image-prompt-block')) { - wrapper.appendChild(createImagePromptBlock(data.image_prompt)); + if (data.image_prompt) { + ensureImagePromptBlocks( + wrapper, + data.image_prompt, + data.image_prompt_alt || null, + ); } showImageGenerating(wrapper); dom.messagesEl.scrollTop = dom.messagesEl.scrollHeight; @@ -361,14 +398,15 @@ async function consumeStream(res) { bubble.textContent = bubble.textContent.replace(IMAGE_PROMPT_RE, '').trim(); } - if (data.image_prompt && wrapper && !wrapper.querySelector('.image-prompt-block')) { - wrapper.appendChild(createImagePromptBlock(data.image_prompt)); + if (data.image_prompt && wrapper) { + ensureImagePromptBlocks( + wrapper, + data.image_prompt, + data.image_prompt_alt || null, + ); } if (data.image_path && wrapper) { - console.log('[image] appending', data.image_path, 'to', wrapper); - appendChatImage(wrapper, data.image_path); - } else { - console.log('[image] skip: image_path=', data.image_path, 'wrapper=', wrapper); + appendChatImage(wrapper, data.image_path, ''); } if (data.image_error && wrapper) { const err = document.createElement('div'); @@ -388,7 +426,15 @@ async function consumeStream(res) { } } -export function addMessage(role, content = '', imagePrompt = null, imagePath = null, messageId = null) { +export function addMessage( + role, + content = '', + imagePrompt = null, + imagePath = null, + messageId = null, + imagePromptAlt = null, + imagePathAlt = null, +) { updateEmptyState(); const wrapper = document.createElement('div'); wrapper.className = `message ${role}`; @@ -446,8 +492,9 @@ export function addMessage(role, content = '', imagePrompt = null, imagePath = n wrapper.appendChild(translateBtn); } - if (prompt) wrapper.appendChild(createImagePromptBlock(prompt)); - if (imagePath) appendChatImage(wrapper, imagePath); + if (prompt) wrapper.appendChild(createImagePromptBlock(prompt, imagePromptAlt)); + if (imagePath) appendChatImage(wrapper, imagePath, imagePathAlt ? 'Теги' : ''); + if (imagePathAlt) appendChatImage(wrapper, imagePathAlt, 'Гибрид'); attachMessageActions(wrapper, messageId, role); dom.messagesEl.appendChild(wrapper); dom.messagesEl.scrollTop = dom.messagesEl.scrollHeight; @@ -487,7 +534,7 @@ export async function sendMessage(text, isNarratorChoice = false) { const res = await fetch('/chat/stream', { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ message: text, session_id: sessionId, persona_id: currentPersona, is_narrator_choice: isNarratorChoice }), + body: JSON.stringify({ message: text, session_id: sessionId, is_narrator_choice: isNarratorChoice }), }); if (!res.ok) throw new Error('Ошибка сервера: ' + res.status); removeTyping(); diff --git a/static/js/chatSettings.js b/static/js/chatSettings.js index f907422..4efb3fb 100644 --- a/static/js/chatSettings.js +++ b/static/js/chatSettings.js @@ -1,7 +1,10 @@ -import { sessionId, currentPersona, dom } from './state.js'; +import { sessionId, currentPersona, setCurrentPersona, dom } from './state.js'; import { GENRE_LABELS, bindGenreGrid, resetGenreGrid } from './utils.js'; +import { personaIndex } from './personas.js'; const chatSettingsGenres = new Set(); +let chatSettingsPersonaId = 'default'; +let chatSettingsInitialPersonaId = 'default'; function updateChatSettingsGenresLabel() { const el = document.getElementById('chatSettingsGenresLabel'); @@ -15,6 +18,26 @@ function updateChatSettingsGenresLabel() { } } +function fillChatSettingsPersonaGrid() { + const grid = document.getElementById('chatSettingsPersonaGrid'); + if (!grid) return; + grid.innerHTML = ''; + for (const p of personaIndex.values()) { + const card = document.createElement('button'); + card.type = 'button'; + card.className = 'persona-pick-card' + (p.persona_id === chatSettingsPersonaId ? ' selected' : ''); + card.dataset.id = p.persona_id; + card.innerHTML = `${p.emoji || '🤖'}${p.name}`; + card.addEventListener('click', () => { + chatSettingsPersonaId = p.persona_id; + grid.querySelectorAll('.persona-pick-card').forEach(c => { + c.classList.toggle('selected', c.dataset.id === chatSettingsPersonaId); + }); + }); + grid.appendChild(card); + } +} + function loadRpgSettingsToDom(prefix, settings) { document.getElementById(`${prefix}SettingDice`).checked = settings.dice !== false; document.getElementById(`${prefix}SettingNarrator`).checked = settings.narrator !== false; @@ -67,6 +90,10 @@ export async function openChatSettings() { const s = await res.json(); document.getElementById('chatSettingsTitle').value = s.title || ''; + chatSettingsPersonaId = s.persona_id || 'default'; + chatSettingsInitialPersonaId = chatSettingsPersonaId; + fillChatSettingsPersonaGrid(); + const rpgOn = !!s.rpg_enabled; document.getElementById('chatSettingsRpg').checked = rpgOn; document.getElementById('chatSettingsRpgBlock').classList.toggle('hidden', !rpgOn); @@ -117,13 +144,45 @@ export function initChatSettings() { document.getElementById('chatSettingsSave')?.addEventListener('click', async () => { if (!sessionId) return; - const { loadSessions, applySessionUi } = await import('./sessions.js'); + const { loadSessions, applySessionUi, renderSystemBlob } = await import('./sessions.js'); + const { reloadChatFromServer } = await import('./chat.js'); + const { highlightPersonaBar } = await import('./personas.js'); const title = document.getElementById('chatSettingsTitle').value.trim(); const rpgOn = document.getElementById('chatSettingsRpg').checked; const genreValue = [...chatSettingsGenres].join(',') || 'adventure'; const settings = readRpgSettingsFromDom('cs'); + if (chatSettingsPersonaId !== chatSettingsInitialPersonaId) { + const pName = personaIndex.get(chatSettingsPersonaId)?.name || chatSettingsPersonaId; + const keepHistory = confirm( + `Перепривязать чат к «${pName}»?\n\n` + + 'OK — сохранить историю сообщений (персонаж в старых репликах может не совпадать).\n' + + 'Отмена — очистить историю и начать с приветствия нового персонажа.', + ); + const clearHistory = !keepHistory; + + const rebindRes = await fetch(`/sessions/${sessionId}/rebind-persona`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + persona_id: chatSettingsPersonaId, + clear_history: clearHistory, + }), + }); + if (!rebindRes.ok) { + const err = await rebindRes.json().catch(() => ({})); + alert(err.detail || 'Не удалось сменить персонажа'); + return; + } + setCurrentPersona(chatSettingsPersonaId); + chatSettingsInitialPersonaId = chatSettingsPersonaId; + highlightPersonaBar(chatSettingsPersonaId); + await reloadChatFromServer(sessionId); + const blobRes = await fetch(`/chat/system/${sessionId}`); + if (blobRes.ok) renderSystemBlob(await blobRes.json()); + } + await fetch(`/sessions/${sessionId}`, { method: 'PATCH', headers: { 'Content-Type': 'application/json' }, @@ -141,7 +200,7 @@ export function initChatSettings() { let arc = {}; try { arc = JSON.parse(s.plot_arc_json || '{}'); } catch { /* ignore */ } if (!arc || !Object.keys(arc).length) { - await bootstrapRpg(sessionId, currentPersona, genreValue, settings); + await bootstrapRpg(sessionId, chatSettingsPersonaId, genreValue, settings); } } diff --git a/static/js/debug.js b/static/js/debug.js new file mode 100644 index 0000000..4815a01 --- /dev/null +++ b/static/js/debug.js @@ -0,0 +1,217 @@ +const $ = (id) => document.getElementById(id); + +function fmt(obj) { + return typeof obj === 'string' ? obj : JSON.stringify(obj, null, 2); +} + +async function api(path, opts = {}) { + const res = await fetch(path, { + headers: { 'Content-Type': 'application/json', ...(opts.headers || {}) }, + ...opts, + }); + const text = await res.text(); + let data; + try { + data = JSON.parse(text); + } catch { + data = text; + } + if (!res.ok) { + const detail = data?.detail || text || res.statusText; + throw new Error(`${res.status}: ${detail}`); + } + return data; +} + +function initTabs() { + const tabs = document.querySelectorAll('#debugTabs button'); + tabs.forEach((btn) => { + btn.addEventListener('click', () => { + tabs.forEach((t) => t.classList.remove('active')); + btn.classList.add('active'); + document.querySelectorAll('.debug-panel').forEach((p) => p.classList.remove('active')); + $(`panel-${btn.dataset.tab}`).classList.add('active'); + }); + }); +} + +async function loadConfig() { + const c = await api('/debug/config'); + $('configOut').textContent = fmt(c); + $('llmModel').placeholder = c.sd_prompt_model || c.system_model; + return c; +} + +async function loadPersonas() { + const list = await api('/debug/personas'); + const sel = $('sdPersona'); + sel.innerHTML = ''; + for (const p of list) { + const opt = document.createElement('option'); + opt.value = p.persona_id; + opt.textContent = `${p.name} (${p.persona_id})`; + sel.appendChild(opt); + } +} + +async function runSdPrompt() { + $('sdScene').textContent = '…'; + $('sdPrompts').textContent = '…'; + const body = { + persona_id: $('sdPersona').value, + chat_excerpt: $('sdChat').value, + outfit_json: $('sdOutfit').value || '[]', + use_prose: $('sdUseProse') ? $('sdUseProse').checked : false, + }; + const app = $('sdAppearance').value.trim(); + if (app) body.appearance_override = app; + + const data = await api('/debug/sd-prompt', { method: 'POST', body: JSON.stringify(body) }); + $('sdScene').textContent = data.scene ? fmt(data.scene) : (data.error || '—'); + const prompts = []; + if (data.tags_only_full) prompts.push('=== TAGS + POV (no prose) ===\n' + data.tags_only_full); + if (data.hybrid_full) prompts.push('\n=== HYBRID (Comfy) ===\n' + data.hybrid_full); + if (!data.tags_only_full && data.tag_full) prompts.push('=== PROMPT ===\n' + data.tag_full); + $('sdPrompts').textContent = prompts.join('\n') || data.error || '—'; + $('sdLlmRaw').textContent = [ + `model: ${data.sd_prompt_model}`, + `dual: ${data.anima_dual}`, + '', + '--- system ---', + data.builder_system || '', + '', + '--- user ---', + data.builder_user || '', + '', + '--- raw ---', + data.llm_raw || data.error || '', + ].join('\n'); + if (data.tag_full || data.hybrid_full) { + const src = data.hybrid_full || data.tag_full; + const parts = src.includes('__NEGATIVE_PROMPT__') + ? src.split('\n\n__NEGATIVE_PROMPT__\n\n') + : src.includes('\n\nNegative prompt:') + ? src.split('\n\nNegative prompt:') + : [src, '']; + $('genPositive').value = parts[0] || ''; + $('genNegative').value = parts[1] || ''; + } +} + +async function runLlm() { + $('llmOut').textContent = '…'; + const data = await api('/debug/llm', { + method: 'POST', + body: JSON.stringify({ + model: $('llmModel').value.trim(), + system: $('llmSystem').value, + user: $('llmUser').value, + }), + }); + $('llmOut').textContent = `model: ${data.model}\n\n${data.response}`; +} + +function fillModelSelect(sel, options, configured) { + const current = sel.querySelector('option')?.value ?? ''; + sel.innerHTML = ``; + for (const name of options || []) { + const opt = document.createElement('option'); + opt.value = name; + opt.textContent = name; + if (name === configured) opt.selected = true; + sel.appendChild(opt); + } +} + +async function loadComfyModels() { + $('comfyModelLists').textContent = 'Загрузка object_info…'; + const data = await api('/debug/comfy/models'); + const { models, configured } = data; + fillModelSelect($('genUnet'), models.unets, configured.unet); + fillModelSelect($('genClip'), models.clips, configured.clip); + fillModelSelect($('genVae'), models.vaes, configured.vae); + fillModelSelect($('genCkpt'), models.checkpoints, configured.checkpoint); + + const wrap = $('comfyModelLists'); + wrap.innerHTML = ''; + for (const [key, list] of Object.entries(models)) { + const block = document.createElement('details'); + block.className = 'model-list-block'; + block.open = key === 'unets' || key === 'checkpoints'; + block.innerHTML = `${key} (${list.length})`; + const ul = document.createElement('ul'); + for (const item of list) { + const li = document.createElement('li'); + li.textContent = item; + ul.appendChild(li); + } + block.appendChild(ul); + wrap.appendChild(block); + } +} + +async function comfyPing() { + $('comfyPingOut').textContent = '…'; + const data = await api('/debug/comfy/ping'); + $('comfyPingOut').textContent = fmt(data); +} + +async function comfyGenerate() { + $('comfyGenOut').textContent = 'Генерация…'; + $('comfyImgWrap').classList.add('hidden'); + const body = { + positive: $('genPositive').value, + negative: $('genNegative').value, + }; + const u = $('genUnet').value; + const c = $('genClip').value; + const v = $('genVae').value; + const ck = $('genCkpt').value; + if (u) body.unet = u; + if (c) body.clip = c; + if (v) body.vae = v; + if (ck) body.checkpoint = ck; + + const data = await api('/debug/comfy/generate', { + method: 'POST', + body: JSON.stringify(body), + }); + $('comfyGenOut').textContent = fmt(data); + if (data.image_path) { + $('comfyImg').src = data.image_path + '?t=' + Date.now(); + $('comfyImgWrap').classList.remove('hidden'); + } +} + +async function comfyRaw() { + $('comfyRawOut').textContent = '…'; + const data = await api('/debug/comfy/raw', { + method: 'POST', + body: JSON.stringify({ + method: $('rawMethod').value, + path: $('rawPath').value, + params_json: $('rawParams').value || '{}', + body_json: $('rawBody').value || '', + }), + }); + $('comfyRawOut').textContent = fmt(data); +} + +function bind() { + initTabs(); + $('btnReloadConfig').addEventListener('click', loadConfig); + $('btnSdPrompt').addEventListener('click', () => runSdPrompt().catch(showErr)); + $('btnLlm').addEventListener('click', () => runLlm().catch(showErr)); + $('btnComfyPing').addEventListener('click', () => comfyPing().catch(showErr)); + $('btnComfyModels').addEventListener('click', () => loadComfyModels().catch(showErr)); + $('btnComfyGen').addEventListener('click', () => comfyGenerate().catch(showErr)); + $('btnComfyRaw').addEventListener('click', () => comfyRaw().catch(showErr)); +} + +function showErr(e) { + alert(e.message || String(e)); +} + +bind(); +loadConfig().catch(showErr); +loadPersonas().catch(showErr); diff --git a/static/js/newChatWizard.js b/static/js/newChatWizard.js index 37035f6..499f2fe 100644 --- a/static/js/newChatWizard.js +++ b/static/js/newChatWizard.js @@ -1,4 +1,9 @@ -import { setSessionId, setCurrentPersona, currentPersona, dom } from './state.js'; +import { + setSessionId, + setCurrentPersona, + getNewChatDefaultPersona, + dom, +} from './state.js'; import { initWizard, GENRE_LABELS, @@ -7,9 +12,9 @@ import { fillGreetingSelect, getSelectedGreeting, } from './utils.js'; -import { personaIndex, highlightPersona } from './personas.js'; +import { personaIndex } from './personas.js'; -let newChatPersonaId = currentPersona; +let newChatPersonaId = getNewChatDefaultPersona(); let newChatGreetingCtx = null; const newChatGenres = new Set(); const newChatModalEl = document.getElementById('newChatModal'); @@ -84,7 +89,7 @@ function fillNewChatPersonaGrid() { const grid = document.getElementById('newChatPersonaGrid'); if (!grid) return; grid.innerHTML = ''; - newChatPersonaId = currentPersona; + newChatPersonaId = getNewChatDefaultPersona(); for (const p of personaIndex.values()) { const card = document.createElement('button'); card.type = 'button'; @@ -121,34 +126,8 @@ function updateNewChatGenresLabel() { } } -async function bootstrapRpg(sid, personaId, genreValue, settings) { - const { updateQuestPanel, addMessage } = await import('./chat.js'); - await fetch(`/sessions/${sid}`, { - method: 'PATCH', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - rpg_enabled: true, - genre: genreValue, - rpg_settings_json: JSON.stringify(settings), - }), - }); - const res = await fetch('/chat/rpg/bootstrap', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ session_id: sid, persona_id: personaId, genre: genreValue }), - }); - if (res.ok) { - const data = await res.json(); - if (data.quests) updateQuestPanel(data.quests); - if (data.plot_arc) { - const title = data.plot_arc.title || ''; - const hint = data.plot_arc.next_beat_hint || ''; - if (title || hint) addMessage('assistant', `📖 ${title}${hint ? '\n' + hint : ''}`); - } - } -} - export function openNewChatWizard() { + import('./personas.js').then(({ refreshPersonaBarHighlight }) => refreshPersonaBarHighlight()); fillNewChatPersonaGrid(); resetGenreGrid(document.getElementById('newChatGenreGrid'), newChatGenres); updateNewChatGenresLabel(); @@ -161,8 +140,17 @@ export function openNewChatWizard() { } export async function createNewChatFromWizard() { - const { clearMessages, initChat, reloadChatFromServer } = await import('./chat.js'); - const { loadSessions, applySessionUi } = await import('./sessions.js'); + const { + clearMessages, + initChat, + reloadChatFromServer, + showImageGenerating, + removeImageGenerating, + updateQuestPanel, + updateAffinityDisplay, + renderChoices, + } = await import('./chat.js'); + const { loadSessions, applySessionUi, renderSystemBlob } = await import('./sessions.js'); const sid = 'sess_' + Math.random().toString(36).slice(2, 10); setSessionId(sid); @@ -176,10 +164,22 @@ export async function createNewChatFromWizard() { newChatWizard?.reset(); try { + const sessionPatch = { persona_id: newChatPersonaId, rpg_enabled: rpg }; + if (rpg) { + sessionPatch.genre = [...newChatGenres].join(',') || 'adventure'; + sessionPatch.rpg_settings_json = JSON.stringify({ + dice: document.getElementById('ncSettingDice')?.checked ?? true, + narrator: document.getElementById('ncSettingNarrator')?.checked ?? true, + quests: document.getElementById('ncSettingQuests')?.checked ?? true, + affinity: document.getElementById('ncSettingAffinity')?.checked ?? true, + choices: document.getElementById('ncSettingChoices')?.checked ?? true, + }); + } + await fetch(`/sessions/${sid}`, { method: 'PATCH', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ persona_id: newChatPersonaId, rpg_enabled: rpg }), + body: JSON.stringify(sessionPatch), }); if (customTitle) { @@ -194,25 +194,61 @@ export async function createNewChatFromWizard() { dom.headerTitle.textContent = rpg ? `${pName} — RPG` : `${pName} — новый чат`; } - highlightPersona(newChatPersonaId); + const { highlightPersonaBar } = await import('./personas.js'); + highlightPersonaBar(newChatPersonaId); const greetingOverride = getNewChatFirstMesOverride(); await initChat(greetingOverride ? { first_mes_override: greetingOverride } : {}); - if (rpg) { - const genreValue = [...newChatGenres].join(',') || 'adventure'; - const settings = { - dice: document.getElementById('ncSettingDice')?.checked ?? true, - narrator: document.getElementById('ncSettingNarrator')?.checked ?? true, - quests: document.getElementById('ncSettingQuests')?.checked ?? true, - affinity: document.getElementById('ncSettingAffinity')?.checked ?? true, - choices: document.getElementById('ncSettingChoices')?.checked ?? true, - }; - await bootstrapRpg(sid, newChatPersonaId, genreValue, settings); + const assistantWrapper = dom.messagesEl.querySelector('.message.assistant'); + showImageGenerating(assistantWrapper); + + let openingData = null; + try { + const openingRes = await fetch('/chat/opening/process', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + session_id: sid, + persona_id: newChatPersonaId, + rpg, + }), + }); + openingData = await openingRes.json(); + if (!openingRes.ok) { + console.error('opening/process failed:', openingData.detail || openingRes.statusText); + } + } finally { + removeImageGenerating(assistantWrapper); } await reloadChatFromServer(sid); + + if (openingData?.quests?.length) { + updateQuestPanel(openingData.quests); + } + if (openingData?.affinity !== undefined) { + updateAffinityDisplay(openingData.affinity); + } + if (openingData?.choices?.length) { + const wrapper = dom.messagesEl.querySelector('.message.assistant'); + if (wrapper) renderChoices(wrapper, openingData.choices); + } + if (openingData?.image_error) { + const wrapper = dom.messagesEl.querySelector('.message.assistant'); + if (wrapper) { + const err = document.createElement('div'); + err.className = 'image-error'; + err.textContent = '🖼 ' + openingData.image_error; + wrapper.appendChild(err); + } + } + const sessionRes = await fetch(`/sessions/${sid}`); if (sessionRes.ok) applySessionUi(await sessionRes.json()); + + const blobRes = await fetch(`/chat/system/${sid}`); + if (blobRes.ok) renderSystemBlob(await blobRes.json()); + await loadSessions(); } catch (e) { console.error('createNewChat error:', e); diff --git a/static/js/personas.js b/static/js/personas.js index 3b5628b..8546132 100644 --- a/static/js/personas.js +++ b/static/js/personas.js @@ -1,5 +1,9 @@ -import { currentPersona, setCurrentPersona, sessionId } from './state.js'; -import { initChat } from './chat.js'; +import { + currentPersona, + sessionId, + getNewChatDefaultPersona, + setNewChatDefaultPersona, +} from './state.js'; import { initWizard, fillGreetingSelect, getSelectedGreeting } from './utils.js'; export let personaIndex = new Map(); @@ -21,12 +25,18 @@ let cardImportWizard; let cardPreview = null; let cardImportFile = null; -export function highlightPersona(personaId) { +export function highlightPersonaBar(personaId) { document.querySelectorAll('.persona-card').forEach(c => { c.classList.toggle('active', c.dataset.id === personaId); }); } +/** Active session → session persona; otherwise new-chat preset. */ +export function refreshPersonaBarHighlight() { + const id = sessionId ? currentPersona : getNewChatDefaultPersona(); + highlightPersonaBar(id); +} + export async function loadPersonas() { const res = await fetch('/personas/'); const personas = await res.json(); @@ -37,9 +47,11 @@ export async function loadPersonas() { const bar = document.getElementById('personaBar'); bar.innerHTML = ''; + const barActiveId = sessionId ? currentPersona : getNewChatDefaultPersona(); + personas.forEach(p => { const card = document.createElement('div'); - card.className = 'persona-card' + (p.persona_id === currentPersona ? ' active' : ''); + card.className = 'persona-card' + (p.persona_id === barActiveId ? ' active' : ''); card.dataset.id = p.persona_id; const isCard = p.persona_id.startsWith('card_'); const isCustomPersona = p.custom && !isCard; @@ -131,16 +143,8 @@ export async function loadPersonas() { } export async function selectPersona(personaId) { - setCurrentPersona(personaId); - highlightPersona(personaId); - if (sessionId) { - await fetch(`/sessions/${sessionId}`, { - method: 'PATCH', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ persona_id: personaId }), - }); - await initChat(); - } + setNewChatDefaultPersona(personaId); + highlightPersonaBar(personaId); } function fillImpCardForm(preview) { diff --git a/static/js/sessions.js b/static/js/sessions.js index 1cf88af..994be06 100644 --- a/static/js/sessions.js +++ b/static/js/sessions.js @@ -2,7 +2,7 @@ import { sessionId, setSessionId, setCurrentPersona, currentPersona, dom, setRpgEnabled, } from './state.js'; import { updateQuestPanel, updateAffinityDisplay } from './chat.js'; -import { highlightPersona, personaIndex } from './personas.js'; +import { highlightPersonaBar, personaIndex } from './personas.js'; import { formatSessionDate } from './utils.js'; import { openNewChatWizard } from './newChatWizard.js'; @@ -114,7 +114,7 @@ export async function loadChatHistory(id) { const s = await sessionRes.json(); if (s.persona_id) { setCurrentPersona(s.persona_id); - highlightPersona(s.persona_id); + highlightPersonaBar(s.persona_id); } applySessionUi(s); } @@ -155,7 +155,7 @@ export async function initSessions() { let _prevBlobSections = {}; -function renderSystemBlob(blob) { +export function renderSystemBlob(blob) { const tryFmt = (str, fallback = '') => { try { return JSON.stringify(JSON.parse(str), null, 2); } catch { return str || fallback; } }; @@ -165,13 +165,18 @@ function renderSystemBlob(blob) { return ` ${icon} [${q.status}] ${q.title}`; }).join('\n'); + const personaLine = blob.persona_id + ? `[persona] ${blob.persona_name || blob.persona_id} (${blob.persona_id})` + : ''; + const sections = { + persona: personaLine, system_prompt: blob.system_prompt ? `[system_prompt]\n${blob.system_prompt}` : '', status_quo: blob.status_quo ? `[status_quo]\n${blob.status_quo}` : '', affinity: blob.affinity != null ? `[affinity] ${blob.affinity}` : '', genre: blob.genre ? `[genre] ${blob.genre}` : '', rpg_settings: blob.rpg_settings_json && blob.rpg_settings_json !== '{}' ? `[rpg_settings]\n${tryFmt(blob.rpg_settings_json)}` : '', - outfit: blob.outfit_json && blob.outfit_json !== '[]' ? `[outfit]\n${tryFmt(blob.outfit_json)}` : '', + outfit: `[outfit]\n${tryFmt(blob.outfit_json ?? '[]')}`, facts: blob.facts_json && blob.facts_json !== '[]' ? `[facts]\n${tryFmt(blob.facts_json)}` : '', plot_arc: blob.plot_arc_json && blob.plot_arc_json !== '{}' ? `[plot_arc]\n${tryFmt(blob.plot_arc_json)}` : '', quests: questLines ? `[quests]\n${questLines}` : '', diff --git a/static/js/state.js b/static/js/state.js index 6818f22..0c4855b 100644 --- a/static/js/state.js +++ b/static/js/state.js @@ -1,17 +1,31 @@ export let sessionId = localStorage.getItem('chat_session_id') || null; -export let currentPersona = localStorage.getItem('persona_id') || 'default'; +/** Persona bound to the active session (from server, not global preset). */ +export let currentPersona = 'default'; export let sidebarOpen = true; export let rpgEnabled = false; + +const NEW_CHAT_PERSONA_KEY = 'new_chat_persona_id'; + export function toggleSidebar() { sidebarOpen = !sidebarOpen; return sidebarOpen; } +export function getNewChatDefaultPersona() { + return localStorage.getItem(NEW_CHAT_PERSONA_KEY) + || localStorage.getItem('persona_id') + || 'default'; +} + +export function setNewChatDefaultPersona(id) { + const pid = id || 'default'; + localStorage.setItem(NEW_CHAT_PERSONA_KEY, pid); +} + export function setSessionId(id) { sessionId = id; if (id) localStorage.setItem('chat_session_id', id); } export function setCurrentPersona(id) { - currentPersona = id; - localStorage.setItem('persona_id', id); + currentPersona = id || 'default'; } export function setRpgEnabled(v) { rpgEnabled = !!v; } diff --git a/static/js/utils.js b/static/js/utils.js index c169608..c4fcd79 100644 --- a/static/js/utils.js +++ b/static/js/utils.js @@ -6,12 +6,32 @@ export function parseImagePromptFromContent(content) { return { text, prompt }; } +export function splitSdPromptForCopy(fullPrompt) { + if (!fullPrompt) return ''; + const marker = '\n\nNegative prompt:'; + const i = fullPrompt.indexOf(marker); + return (i >= 0 ? fullPrompt.slice(0, i) : fullPrompt).trim(); +} + export async function copyToClipboard(text) { + if (!text) return false; try { await navigator.clipboard.writeText(text); return true; } catch { - return false; + try { + const ta = document.createElement('textarea'); + ta.value = text; + ta.setAttribute('readonly', ''); + ta.style.cssText = 'position:fixed;left:-9999px;top:0'; + document.body.appendChild(ta); + ta.select(); + const ok = document.execCommand('copy'); + document.body.removeChild(ta); + return ok; + } catch { + return false; + } } } diff --git a/tests/test_sd_prompt.py b/tests/test_sd_prompt.py new file mode 100644 index 0000000..6a28f20 --- /dev/null +++ b/tests/test_sd_prompt.py @@ -0,0 +1,243 @@ +"""Unit tests for layered Anima prompt assembly (no LLM).""" + +import os +import sys +from unittest.mock import patch + +import pytest + +# Ensure project root on path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from services import sd_prompt as sp + + +@pytest.fixture +def anima(): + with patch.object(sp, "_is_anima", return_value=True), patch.object(sp, "_is_pony", return_value=False): + yield + + +PERSONA_WOLF = { + "appearance_tags": "wolfgirl, white_hair, golden_eyes, wolf_ears, tail, big_breast", + "appearance_prose": "", + "lora_name": "", +} + +PERSONA_CARRIE = { + "appearance_tags": "short_hair, brown_hair, blue_eyes, skinny", + "appearance_prose": "", + "lora_name": "", +} + + +def test_walking_scene_includes_action_tags_and_contextual_pov(anima): + scene = sp._sanitize_scene_fields({ + "shot_type": "first_person_pov", + "pov_cue": "walking_together", + "viewer_body_visible": False, + "action_tags": "holding_hands, walking, smiling, looking_at_each_other", + "environment_tags": "outdoors, sunlight, golden_hour", + "scene_description": "She walks beside you, laughter in the warm afternoon light.", + }) + hybrid = sp.build_positive_prompt_hybrid(scene, PERSONA_WOLF, "") + assert "walking" in hybrid + assert "smiling" in hybrid + assert "holding_hands" not in hybrid + assert "looking_at_each_other" not in hybrid + assert "outdoors" in hybrid + assert "threshold" not in hybrid.lower() + assert "POV: walking beside you" in hybrid + assert "someone" not in hybrid.lower() + assert "both " not in hybrid.lower() + + +def test_hybrid_differs_from_tags_only_when_prose_present(anima): + scene = { + "shot_type": "first_person_pov", + "pov_cue": "walking_together", + "viewer_body_visible": False, + "action_tags": "holding_hands, walking", + "environment_tags": "outdoors, sunlight", + "scene_description": "Shared laughter drifts through the golden afternoon.", + } + tags_only = sp.build_positive_prompt_tags_only(scene, PERSONA_WOLF, "") + hybrid = sp.build_positive_prompt_hybrid(scene, PERSONA_WOLF, "") + assert tags_only != hybrid + assert "Shared laughter" in hybrid + assert "Shared laughter" not in tags_only + + +def test_carrie_doorway_scene(anima): + scene = { + "shot_type": "first_person_pov", + "pov_cue": "doorway_invite", + "viewer_body_visible": False, + "action_tags": "arms_out, inviting_hug, smirk, looking_at_viewer", + "environment_tags": "doorway, apartment, night, indoors", + "scene_description": "She waits in the doorway with playful hunger in half-lidded eyes.", + } + outfit = "crop_top, ripped_jeans, black_jeans" + hybrid = sp.build_positive_prompt_hybrid(scene, PERSONA_CARRIE, outfit) + assert "arms_out" in hybrid + assert "doorway" in hybrid + assert "crop_top" in hybrid + assert "threshold" not in hybrid.lower() + assert "POV: she blocks the doorway" in hybrid + + +def test_pov_inferred_from_action_when_cue_missing(anima): + scene = { + "shot_type": "first_person_pov", + "action_tags": "holding_hands, walking, smiling", + "environment_tags": "outdoors, park", + "scene_description": "", + } + tags = sp.build_positive_prompt_tags_only(scene, PERSONA_WOLF, "") + assert "POV: walking beside you" in tags + + +def test_negative_includes_interaction_block_for_pov_contact(anima): + scene = { + "shot_type": "first_person_pov", + "viewer_body_visible": False, + "action_tags": "arms_out, hug, inviting_hug", + "environment_tags": "doorway", + } + neg = sp._negative_for_scene(scene) + assert "duplicate" in neg + assert "extra_person" in neg + assert "third person" in neg + + +def test_scene_should_generate_false(): + assert sp._scene_should_generate({"should_generate": False}) is False + assert sp._scene_should_generate({"should_generate": True}) is True + assert sp._scene_should_generate({}) is True + + +def test_format_builder_user_block_illustrate_vs_context(anima): + messages = [ + {"role": "assistant", "content": "Long old first_mes " + ("x" * 900)}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "*walks holding your hand*"}, + ] + block = sp._format_builder_user_block(PERSONA_WOLF, messages, "[]") + assert "=== ILLUSTRATE" in block + assert "=== Context" in block + assert "*walks holding your hand*" in block + assert "Long old first_mes" in block + assert len(block.split("Long old first_mes")[1].split("assistant:")[0]) < 900 + + +def test_bundle_from_scene_anima_uses_hybrid_as_tag_full(anima): + scene = { + "should_generate": True, + "shot_type": "first_person_pov", + "pov_cue": "face_to_face", + "action_tags": "smiling", + "environment_tags": "indoors", + "scene_description": "A warm smile greets you.", + } + with patch.object(sp, "anima_dual_enabled", return_value=False): + bundle = sp._bundle_from_scene(scene, PERSONA_WOLF, "") + assert "A warm smile" in bundle.tag_full + assert bundle.desc_full is None + + +def test_user_example_walking_llm_output_cleaned(anima): + """Regression: LLM prose/sentence leakage and second-person refs.""" + scene = sp._sanitize_scene_fields({ + "shot_type": "first_person_pov", + "pov_cue": "walking_together", + "action_tags": ( + "holding_hands, walking, smiling, looking_at_each_other, " + "A wolfgirl walks hand in hand with someone, both smiling and chatting" + ), + "environment_tags": "outdoor, daylight, path", + "scene_description": ( + "A wolfgirl walks hand in hand with someone, both smiling and chatting under the daylight." + ), + }) + persona = {**PERSONA_WOLF, "appearance_tags": PERSONA_WOLF["appearance_tags"] + ", pumped_up"} + tags_only = sp.build_positive_prompt_tags_only(scene, persona, "") + hybrid = sp.build_positive_prompt_hybrid(scene, persona, "") + assert "pumped_up" not in tags_only + assert "someone" not in hybrid.lower() + assert "both " not in hybrid.lower() + assert ". A wolfgirl walks" not in tags_only + assert tags_only != hybrid or not scene.get("scene_description") + + +def test_user_example_carrie_env_reconciled(anima): + scene = sp._sanitize_scene_fields({ + "shot_type": "first_person_pov", + "pov_cue": "doorway_invite", + "action_tags": "arms_out, inviting_hug, smirk, half-lidded_eyes", + "environment_tags": "doorway, nighttime, outdoor", + "scene_description": ( + "Carrie stands in her doorway at night, arms outstretched toward you with a mischievous smirk." + ), + }) + hybrid = sp.build_positive_prompt_hybrid( + scene, PERSONA_CARRIE, "crop_top, ripped_jeans, black_jeans, jeans" + ) + assert "outdoor" not in hybrid.lower() or "doorway" in hybrid + assert ", jeans," not in f", {hybrid}," + assert "someone" not in hybrid.lower() + + +def test_long_first_mes_uses_final_beat(anima): + carrie_tail = ( + "About an hour later...\n\n" + "Carrie stood at her front door, arms out, smirking. " + '"Come on, hug me. Now." It\'s getting cold out.' + ) + long = ("She shops for clothes.\n\n" * 5) + carrie_tail + excerpt = sp._extract_illustrate_content(long) + assert "front door" in excerpt or "hug me" in excerpt + assert "shops for clothes" not in excerpt + + +def test_hybrid_gets_fallback_when_no_scene_description(anima): + scene = sp._sanitize_scene_fields({ + "shot_type": "first_person_pov", + "pov_cue": "walking_together", + "action_tags": "walking, smiling", + "environment_tags": "outdoor, daylight", + "scene_description": "", + }) + tags_only = sp.build_positive_prompt_tags_only(scene, PERSONA_WOLF, "") + hybrid = sp.build_positive_prompt_hybrid(scene, PERSONA_WOLF, "") + assert hybrid != tags_only + assert "afternoon" in hybrid.lower() or "laughter" in hybrid.lower() + + +def test_yuki_pov_drops_lifting_and_nose_rub(anima): + scene = sp._sanitize_scene_fields({ + "shot_type": "first_person_pov", + "pov_cue": "face_to_face", + "action_tags": "arms_out, lifting, nose_rub, smiling", + "environment_tags": "indoors, warm_lighting", + "scene_description": "Her golden eyes soften with warmth toward the camera.", + }) + hybrid = sp.build_positive_prompt_hybrid(scene, {**PERSONA_WOLF, "appearance_tags": "fox_girl, golden_eyes"}, "pink_sweater") + assert "lifting" not in hybrid + assert "nose_rub" not in hybrid + assert "golden" in hybrid.lower() + + +def test_bundle_tags_only_alt_when_dual_compare(anima): + scene = { + "shot_type": "first_person_pov", + "pov_cue": "dialogue_close", + "action_tags": "smiling", + "environment_tags": "indoors", + "scene_description": "Soft light on her face.", + } + with patch.object(sp, "anima_dual_enabled", return_value=True): + bundle = sp._bundle_from_scene(scene, PERSONA_WOLF, "") + assert bundle.desc_full is not None + assert bundle.desc_full != bundle.tag_full + assert "Soft light" in bundle.tag_full + assert "Soft light" not in bundle.desc_full.split(sp.NEGATIVE_PROMPT_SEPARATOR)[0]