Fixed SD RPG
This commit is contained in:
@@ -22,6 +22,7 @@ class CardPatch(BaseModel):
|
||||
first_mes: Optional[str] = None
|
||||
mes_example: Optional[str] = None
|
||||
appearance_tags: Optional[str] = None
|
||||
appearance_prose: Optional[str] = None
|
||||
lora_name: Optional[str] = None
|
||||
lora_weight: Optional[float] = None
|
||||
alternate_greetings_json: Optional[str] = None
|
||||
|
||||
+498
-222
@@ -3,14 +3,12 @@ import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import aiosqlite
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from database.db import DB_PATH
|
||||
from models.schemas import ChatRequest, ChatResponse, MessageEditRequest, RegenerateRequest
|
||||
from services.llm import send_message, stream_message
|
||||
from services.llm import LLMError, send_message, stream_message
|
||||
from services.memory import (
|
||||
get_history,
|
||||
add_message,
|
||||
@@ -18,41 +16,74 @@ from services.memory import (
|
||||
get_or_create_session,
|
||||
get_session,
|
||||
update_session_title,
|
||||
update_session_persona,
|
||||
get_message_count,
|
||||
get_last_assistant_message_id,
|
||||
update_message_image,
|
||||
update_session_facts,
|
||||
update_session_status_quo,
|
||||
update_session_affinity,
|
||||
update_session_genre,
|
||||
update_session_rpg_settings,
|
||||
update_session_outfit,
|
||||
update_session_plot_arc,
|
||||
upsert_quest,
|
||||
get_quests,
|
||||
seed_quests_from_arc,
|
||||
narrator_message_content,
|
||||
parse_narrator_message,
|
||||
add_action_resolution,
|
||||
get_message,
|
||||
update_message_content,
|
||||
delete_messages_after,
|
||||
delete_message,
|
||||
delete_message_and_following,
|
||||
update_message_choices,
|
||||
clear_choices_for_session,
|
||||
upsert_static_system_message,
|
||||
)
|
||||
from services.context_budget import compute_payload_usage, context_warning_line
|
||||
from services.rpg_state import (
|
||||
apply_narrator_post,
|
||||
parse_scene_json,
|
||||
parse_stats_json,
|
||||
scene_prompt_block,
|
||||
affinity_prompt_block,
|
||||
stats_prompt_block,
|
||||
format_narrator_outcome_for_llm,
|
||||
format_user_message_for_llm,
|
||||
)
|
||||
from services.personas import get_persona
|
||||
from services.chat_prompt import get_system_prompt, DEFAULT_PROMPT
|
||||
from services.session_identity import resolve_session_persona
|
||||
from services.sd_prompt import generate_sd_prompt, strip_image_prompt_tag, extract_image_prompt_tag
|
||||
from services.lorebook import get_lorebook_context
|
||||
from services.rp_sanitize import RP_OUTPUT_REMINDER, strip_ooc_from_reply
|
||||
from services.sd_images import run_sd_for_message
|
||||
from services.character_card import get_character
|
||||
from services import sdbackend as sd_service
|
||||
from services.rpg_facts import extract_facts, merge_facts, facts_to_prompt
|
||||
from services.rpg_plot import generate_plot_arc, should_advance_arc, pop_matching_beats, advance_phase
|
||||
from services.rpg_facts import extract_facts, merge_facts_persist, facts_to_prompt, rp_day_from_scene
|
||||
from services.rpg_context import format_narrator_context
|
||||
from services.rpg_plot import (
|
||||
generate_plot_arc,
|
||||
process_arc_beats,
|
||||
advance_phase,
|
||||
replenish_arc_beats,
|
||||
reconcile_plot_arc,
|
||||
reconcile_plot_arc,
|
||||
choices_from_beat,
|
||||
choices_from_narrator,
|
||||
)
|
||||
from services.rpg_narrator import narrator_pre, narrator_post
|
||||
from services.opening import ensure_plot_arc_and_quests, resolve_greeting, process_opening
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
|
||||
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")
|
||||
|
||||
DEFAULT_RPG_SETTINGS = {"dice": True, "narrator": True, "quests": True, "affinity": True, "choices": True}
|
||||
DEFAULT_RPG_SETTINGS = {
|
||||
"dice": True,
|
||||
"narrator": True,
|
||||
"quests": True,
|
||||
"affinity": True,
|
||||
"choices": True,
|
||||
"stats": False,
|
||||
}
|
||||
|
||||
|
||||
def get_rpg_settings(session: dict) -> dict:
|
||||
@@ -62,34 +93,61 @@ def get_rpg_settings(session: dict) -> dict:
|
||||
return DEFAULT_RPG_SETTINGS
|
||||
|
||||
|
||||
def affinity_prompt_block(affinity: int) -> str:
|
||||
if affinity >= 10: tone = "very warm, trusting, affectionate"
|
||||
elif affinity >= 5: tone = "friendly and open"
|
||||
elif affinity >= 1: tone = "slightly positive"
|
||||
elif affinity <= -5: tone = "hostile or deeply distrustful"
|
||||
elif affinity <= -1: tone = "cold and wary"
|
||||
else: tone = "neutral"
|
||||
return f"\n\n--- Relationship ---\nAffinity toward player: {affinity} ({tone}). Reflect this in your attitude and word choice.\n---"
|
||||
def build_rpg_runtime_suffix(session: dict, rpg_settings: dict, facts_block: str = "") -> str:
|
||||
runtime_suffix = ""
|
||||
if facts_block:
|
||||
runtime_suffix += "\n\n" + facts_block
|
||||
try:
|
||||
arc = json.loads(session.get("plot_arc_json") or "{}")
|
||||
except Exception:
|
||||
arc = {}
|
||||
if arc:
|
||||
runtime_suffix += "\n\n--- PlotArc ---\n" + json.dumps(
|
||||
{k: arc.get(k) for k in ("title", "phase", "next_beat_hint")}, ensure_ascii=False
|
||||
) + "\n---"
|
||||
status_quo = (session.get("status_quo") or "").strip()
|
||||
if status_quo:
|
||||
from services.rp_sanitize import status_quo_prompt_block
|
||||
|
||||
runtime_suffix += status_quo_prompt_block(status_quo)
|
||||
scene = parse_scene_json(session.get("scene_json"))
|
||||
block = scene_prompt_block(scene)
|
||||
if block:
|
||||
runtime_suffix += block
|
||||
if rpg_settings.get("affinity", True):
|
||||
runtime_suffix += affinity_prompt_block(int(session.get("affinity") or 0))
|
||||
if rpg_settings.get("stats", False):
|
||||
stats = parse_stats_json(session.get("narrative_stats_json"))
|
||||
runtime_suffix += stats_prompt_block(stats)
|
||||
return runtime_suffix
|
||||
|
||||
|
||||
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona:
|
||||
return DEFAULT_PROMPT
|
||||
prompt = persona["prompt"]
|
||||
recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
|
||||
context = recent + [{"role": "user", "content": user_message}]
|
||||
if persona.get("lorebook_json"):
|
||||
lore = get_lorebook_context(persona.get("lorebook_json", "[]"), context)
|
||||
if lore:
|
||||
prompt += "\n\n" + lore
|
||||
if persona_id.startswith("card_"):
|
||||
card = await get_character(persona_id[5:])
|
||||
if card:
|
||||
lore = get_lorebook_context(card.get("lorebook_json", "[]"), context)
|
||||
if lore:
|
||||
prompt += "\n\n" + lore
|
||||
return prompt
|
||||
def messages_for_llm(history: list, llm_system_content: str) -> list[dict]:
|
||||
"""Build LLM payload: one system message (static + runtime), no duplicate system rows."""
|
||||
out: list[dict] = []
|
||||
system_used = False
|
||||
for m in history:
|
||||
if m["role"] == "system":
|
||||
if not system_used:
|
||||
out.append({"role": "system", "content": llm_system_content})
|
||||
system_used = True
|
||||
elif m["role"] == "narrator":
|
||||
data = parse_narrator_message(m.get("content") or "")
|
||||
if data:
|
||||
out.append({"role": "user", "content": format_narrator_outcome_for_llm(data)})
|
||||
elif m["role"] == "user":
|
||||
has_res = bool(m.get("action_resolution"))
|
||||
out.append({
|
||||
"role": "user",
|
||||
"content": format_user_message_for_llm(
|
||||
m["content"], has_dice_resolution=has_res
|
||||
),
|
||||
})
|
||||
else:
|
||||
out.append({"role": m["role"], "content": m["content"]})
|
||||
if not system_used:
|
||||
out.insert(0, {"role": "system", "content": llm_system_content})
|
||||
return out
|
||||
|
||||
|
||||
@router.get("/history/{session_id}")
|
||||
@@ -100,33 +158,68 @@ async def get_chat_history(session_id: str):
|
||||
@router.get("/system/{session_id}")
|
||||
async def get_system_blob(session_id: str):
|
||||
history = await get_history(session_id)
|
||||
system_msg = next((m for m in history if m.get("role") == "system"), None)
|
||||
session = await get_session(session_id)
|
||||
if session and session.get("rpg_enabled"):
|
||||
persona_id_pre = (session.get("persona_id") or "default")
|
||||
persona_pre = await get_persona(persona_id_pre) or {}
|
||||
await reconcile_plot_arc(
|
||||
session_id,
|
||||
persona_name=persona_pre.get("name", persona_id_pre),
|
||||
recent_context=(session.get("status_quo") or "")[:2000],
|
||||
genre=session.get("genre") or "adventure",
|
||||
)
|
||||
session = await get_session(session_id) or session
|
||||
persona_id = (session.get("persona_id") if session else None) or "default"
|
||||
persona = await get_persona(persona_id) or {}
|
||||
system_msg = next((m for m in history if m.get("role") == "system"), None)
|
||||
stored = system_msg.get("content") if system_msg else ""
|
||||
live_static = await get_system_prompt(persona_id, history, "")
|
||||
system_prompt = live_static if live_static else stored
|
||||
quests = await get_quests(session_id)
|
||||
rpg_settings = get_rpg_settings(session) if session else DEFAULT_RPG_SETTINGS
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]")) if session else ""
|
||||
runtime_suffix = ""
|
||||
if session and session.get("rpg_enabled"):
|
||||
runtime_suffix = build_rpg_runtime_suffix(session, rpg_settings, facts_block)
|
||||
llm_system = system_prompt + runtime_suffix
|
||||
context_usage = compute_payload_usage(history, llm_system)
|
||||
return {
|
||||
"system_prompt": system_msg.get("content") if system_msg else "",
|
||||
"persona_id": persona_id,
|
||||
"persona_name": persona.get("name", persona_id),
|
||||
"system_prompt": system_prompt,
|
||||
"status_quo": session.get("status_quo") if session else "",
|
||||
"global_plot": session.get("global_plot") if session else "",
|
||||
"facts_json": session.get("facts_json") if session else "[]",
|
||||
"plot_arc_json": session.get("plot_arc_json") if session else "{}",
|
||||
"outfit_json": session.get("outfit_json") if session else "[]",
|
||||
"scene_json": session.get("scene_json") if session else "{}",
|
||||
"narrative_stats_json": session.get("narrative_stats_json") if session else "{}",
|
||||
"affinity": session.get("affinity", 0) if session else 0,
|
||||
"genre": session.get("genre", "") if session else "",
|
||||
"rpg_settings_json": session.get("rpg_settings_json") if session else "{}",
|
||||
"rpg_enabled": bool(session.get("rpg_enabled")) if session else False,
|
||||
"quests": quests,
|
||||
"context_usage": context_usage,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
async def init_chat(request: ChatRequest):
|
||||
persona_id = request.persona_id or "default"
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
await get_or_create_session(
|
||||
request.session_id,
|
||||
request.persona_id or "default",
|
||||
)
|
||||
persona_id = await resolve_session_persona(
|
||||
request.session_id,
|
||||
request.persona_id,
|
||||
create_persona=request.persona_id,
|
||||
)
|
||||
history = await get_history(request.session_id)
|
||||
if history:
|
||||
return {"first_mes": None}
|
||||
|
||||
system_prompt = await get_system_prompt(persona_id, [], "")
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
await upsert_static_system_message(request.session_id, system_prompt, [])
|
||||
|
||||
first_mes = None
|
||||
if request.first_mes_override and request.first_mes_override.strip():
|
||||
@@ -152,53 +245,67 @@ class RpgBootstrapRequest(BaseModel):
|
||||
genre: str = "adventure"
|
||||
|
||||
|
||||
class OpeningProcessRequest(BaseModel):
|
||||
session_id: str
|
||||
persona_id: str = "default"
|
||||
rpg: bool = False
|
||||
|
||||
|
||||
@router.post("/opening/process")
|
||||
async def opening_process(req: OpeningProcessRequest):
|
||||
await get_or_create_session(req.session_id, req.persona_id)
|
||||
persona_id = await resolve_session_persona(req.session_id, req.persona_id)
|
||||
try:
|
||||
return await process_opening(req.session_id, persona_id, rpg=req.rpg)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/rpg/bootstrap")
|
||||
async def rpg_bootstrap(req: RpgBootstrapRequest):
|
||||
await get_or_create_session(req.session_id, req.persona_id)
|
||||
session = await get_session(req.session_id)
|
||||
persona = await get_persona(req.persona_id) or {}
|
||||
|
||||
# Save genre
|
||||
persona_id = await resolve_session_persona(req.session_id, req.persona_id)
|
||||
await update_session_genre(req.session_id, req.genre)
|
||||
|
||||
arc_json = (session.get("plot_arc_json") or "{}") if session else "{}"
|
||||
try:
|
||||
arc = json.loads(arc_json) if isinstance(arc_json, str) else {}
|
||||
except Exception:
|
||||
arc = {}
|
||||
if not arc:
|
||||
facts_block = facts_to_prompt((session or {}).get("facts_json", "[]"))
|
||||
arc = await generate_plot_arc(
|
||||
persona.get("name", req.persona_id),
|
||||
persona.get("description", ""),
|
||||
persona.get("scenario", ""),
|
||||
persona.get("first_mes", ""),
|
||||
facts_block=facts_block,
|
||||
genre=req.genre,
|
||||
persona = await get_persona(persona_id) or {}
|
||||
greeting = await resolve_greeting(req.session_id, persona)
|
||||
arc = await ensure_plot_arc_and_quests(req.session_id, persona, greeting, req.genre)
|
||||
session = await get_session(req.session_id) or {}
|
||||
rpg_settings = get_rpg_settings(session)
|
||||
if rpg_settings.get("narrator", True) and greeting:
|
||||
arc_json = json.dumps(arc, ensure_ascii=False) if arc else ""
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
post = await narrator_post(
|
||||
persona.get("name", persona_id),
|
||||
f"assistant: {greeting}",
|
||||
arc_json,
|
||||
facts_block,
|
||||
is_opening=True,
|
||||
)
|
||||
if arc:
|
||||
from services.memory import update_session_plot_arc
|
||||
await update_session_plot_arc(req.session_id, json.dumps(arc, ensure_ascii=False))
|
||||
|
||||
# Seed quests from beats
|
||||
for beat in arc.get("beats", []):
|
||||
title = (beat.get("title") or beat.get("injection", "")).strip()
|
||||
if title:
|
||||
await upsert_quest(req.session_id, title[:120])
|
||||
|
||||
await apply_narrator_post(req.session_id, post, rpg_settings, session)
|
||||
quests = await get_quests(req.session_id)
|
||||
return {"plot_arc": arc, "quests": quests}
|
||||
updated = await get_session(req.session_id) or {}
|
||||
return {
|
||||
"plot_arc": arc,
|
||||
"quests": quests,
|
||||
"affinity": updated.get("affinity", 0),
|
||||
"scene_json": updated.get("scene_json", "{}"),
|
||||
"narrative_stats_json": updated.get("narrative_stats_json", "{}"),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
persona_id = request.persona_id or "default"
|
||||
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
await get_or_create_session(request.session_id, request.persona_id)
|
||||
persona_id = await resolve_session_persona(
|
||||
request.session_id,
|
||||
request.persona_id,
|
||||
create_persona=request.persona_id,
|
||||
)
|
||||
|
||||
history = await get_history(request.session_id)
|
||||
session = await get_session(request.session_id)
|
||||
system_prompt = await get_system_prompt(persona_id, history, request.message)
|
||||
static_prompt = await get_system_prompt(persona_id, history, request.message)
|
||||
runtime_suffix = ""
|
||||
|
||||
arc = {}
|
||||
roll = None
|
||||
@@ -206,26 +313,24 @@ async def chat_stream(request: ChatRequest):
|
||||
resolution_text = ""
|
||||
narrator_msg = None # shown as narrator bubble before assistant reply
|
||||
rpg_settings = {}
|
||||
facts_block = ""
|
||||
|
||||
narrator_extra = ""
|
||||
pre = {}
|
||||
directives: list = []
|
||||
pre_ok = False
|
||||
if session and session.get("rpg_enabled"):
|
||||
rpg_settings = get_rpg_settings(session)
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
if facts_block:
|
||||
system_prompt = system_prompt + "\n\n" + facts_block
|
||||
try:
|
||||
arc = json.loads(session.get("plot_arc_json") or "{}")
|
||||
except Exception:
|
||||
arc = {}
|
||||
if arc:
|
||||
system_prompt = system_prompt + "\n\n--- PlotArc ---\n" + json.dumps(
|
||||
{k: arc.get(k) for k in ("title", "phase", "next_beat_hint")}, ensure_ascii=False
|
||||
) + "\n---"
|
||||
status_quo = (session.get("status_quo") or "").strip()
|
||||
if status_quo:
|
||||
system_prompt = system_prompt + "\n\n--- Status quo ---\n" + status_quo + "\n---"
|
||||
if rpg_settings.get("affinity", True):
|
||||
aff = int(session.get("affinity") or 0)
|
||||
system_prompt = system_prompt + affinity_prompt_block(aff)
|
||||
|
||||
quests_list = await get_quests(request.session_id)
|
||||
narr_ctx = format_narrator_context(
|
||||
arc, quests_list, session.get("status_quo") or ""
|
||||
)
|
||||
|
||||
if rpg_settings.get("narrator", True):
|
||||
persona = await get_persona(persona_id) or {}
|
||||
@@ -241,7 +346,9 @@ async def chat_stream(request: ChatRequest):
|
||||
json.dumps(arc, ensure_ascii=False) if arc else "",
|
||||
facts_block,
|
||||
request.message,
|
||||
extra_context=narr_ctx,
|
||||
)
|
||||
pre_ok = bool(pre.get("_ok"))
|
||||
|
||||
needs_check = pre.get("needs_check", False) and rpg_settings.get("dice", True)
|
||||
|
||||
@@ -265,6 +372,7 @@ async def chat_stream(request: ChatRequest):
|
||||
request.message,
|
||||
roll=roll,
|
||||
outcome=outcome,
|
||||
extra_context=narr_ctx,
|
||||
)
|
||||
resolution_text = (pre2.get("resolution_text") or "").strip()
|
||||
directives = pre2.get("directives") or []
|
||||
@@ -274,66 +382,95 @@ async def chat_stream(request: ChatRequest):
|
||||
pre_sq = (pre.get("status_quo_update") or "").strip()
|
||||
|
||||
if directives:
|
||||
system_prompt = system_prompt + "\n\n--- Narrator directives ---\n" + "\n".join(f"- {d}" for d in directives) + "\n---"
|
||||
narrator_extra += (
|
||||
"\n\n--- Narrator directives ---\n"
|
||||
+ "\n".join(f"- {d}" for d in directives)
|
||||
+ "\n---"
|
||||
)
|
||||
if pre_sq:
|
||||
await update_session_status_quo(request.session_id, pre_sq)
|
||||
session["status_quo"] = pre_sq
|
||||
|
||||
pre_for_scene = pre2 if needs_check else pre
|
||||
scene_up = pre_for_scene.get("scene_update")
|
||||
if isinstance(scene_up, dict) and scene_up:
|
||||
from services.rpg_state import merge_scene
|
||||
from services.memory import update_session_scene
|
||||
|
||||
merged = merge_scene(
|
||||
parse_scene_json(session.get("scene_json")), scene_up
|
||||
)
|
||||
scene_str = json.dumps(merged, ensure_ascii=False)
|
||||
await update_session_scene(request.session_id, scene_str)
|
||||
session["scene_json"] = scene_str
|
||||
|
||||
if resolution_text:
|
||||
await add_action_resolution(
|
||||
request.session_id,
|
||||
intent_text=request.message,
|
||||
roll=roll,
|
||||
outcome=outcome,
|
||||
resolution_text=resolution_text,
|
||||
message_id=None,
|
||||
)
|
||||
narrator_msg = {"roll": roll, "outcome": outcome, "text": resolution_text}
|
||||
narrator_msg = {
|
||||
"roll": roll,
|
||||
"outcome": outcome,
|
||||
"text": resolution_text,
|
||||
"original_intent": request.message,
|
||||
}
|
||||
|
||||
# Inject outcome into system prompt so character reply is consistent
|
||||
if roll is not None:
|
||||
system_prompt = (
|
||||
system_prompt
|
||||
+ f"\n\n--- Mechanics ---\n"
|
||||
+ f"Roll d20={roll}. Outcome: {outcome}.\n"
|
||||
+ "Your reply MUST be consistent with this outcome. Do NOT contradict the narrator resolution.\n"
|
||||
+ "---"
|
||||
if roll is not None and resolution_text:
|
||||
narrator_extra += (
|
||||
f"\n\n--- Mechanics (this turn) ---\n"
|
||||
f"Roll d20={roll}. Outcome: {outcome}.\n"
|
||||
f"Narrator resolution: {resolution_text}\n"
|
||||
"The character's next reply MUST match the narrator ruling in the message history "
|
||||
"(immediately after the player's intent). Do NOT re-enact the attempt as full success on failure.\n"
|
||||
"---"
|
||||
)
|
||||
|
||||
# is_narrator_choice: wrap message so LLM understands context
|
||||
runtime_suffix = build_rpg_runtime_suffix(session, rpg_settings, facts_block) + narrator_extra
|
||||
|
||||
llm_system = static_prompt + runtime_suffix
|
||||
if persona_id != "default" or (session and session.get("rpg_enabled")):
|
||||
llm_system += RP_OUTPUT_REMINDER
|
||||
|
||||
user_message_content = request.message
|
||||
if request.is_narrator_choice:
|
||||
user_message_content = f"[Player chose: {request.message}]"
|
||||
|
||||
if not history:
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
elif history[0]["role"] == "system" and history[0]["content"] != system_prompt:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""UPDATE messages SET content = ?
|
||||
WHERE session_id = ? AND role = 'system'
|
||||
AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
|
||||
(system_prompt, request.session_id, request.session_id),
|
||||
)
|
||||
await db.commit()
|
||||
await upsert_static_system_message(request.session_id, static_prompt, history)
|
||||
|
||||
user_msg_id = None
|
||||
if not request.skip_user_add:
|
||||
await add_message(request.session_id, "user", user_message_content)
|
||||
await clear_choices_for_session(request.session_id)
|
||||
user_msg_id = await add_message(request.session_id, "user", user_message_content)
|
||||
if narrator_msg and narrator_msg.get("roll") is not None and user_msg_id:
|
||||
await add_action_resolution(
|
||||
request.session_id,
|
||||
intent_text=request.message,
|
||||
roll=narrator_msg["roll"],
|
||||
outcome=narrator_msg["outcome"],
|
||||
resolution_text=narrator_msg["text"],
|
||||
message_id=user_msg_id,
|
||||
)
|
||||
narrator_msg["user_message_id"] = user_msg_id
|
||||
if narrator_msg and (narrator_msg.get("text") or "").strip():
|
||||
await add_message(
|
||||
request.session_id,
|
||||
"narrator",
|
||||
narrator_message_content(narrator_msg),
|
||||
)
|
||||
messages = await get_history(request.session_id)
|
||||
usage = compute_payload_usage(messages, llm_system)
|
||||
warn = context_warning_line(usage.get("percent", 0))
|
||||
if warn:
|
||||
llm_system += warn
|
||||
llm_messages = messages_for_llm(messages, llm_system)
|
||||
|
||||
full_reply = []
|
||||
|
||||
async def generate():
|
||||
nonlocal arc
|
||||
|
||||
# Send narrator BEFORE streaming so it appears above the reply
|
||||
if narrator_msg:
|
||||
yield f"data: {json.dumps({'narrator': narrator_msg})}\n\n"
|
||||
|
||||
try:
|
||||
async for chunk in stream_message(
|
||||
[{"role": m["role"], "content": m["content"]} for m in messages]
|
||||
):
|
||||
async for chunk in stream_message(llm_messages):
|
||||
full_reply.append(chunk)
|
||||
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
|
||||
except Exception as e:
|
||||
@@ -342,99 +479,176 @@ async def chat_stream(request: ChatRequest):
|
||||
return
|
||||
|
||||
complete = "".join(full_reply)
|
||||
display_text = strip_image_prompt_tag(complete)
|
||||
raw_display = strip_image_prompt_tag(complete)
|
||||
display_text = strip_ooc_from_reply(raw_display)
|
||||
|
||||
hist_with_reply = await get_history(request.session_id) + [
|
||||
{"role": "assistant", "content": display_text}
|
||||
]
|
||||
sd_result = await generate_sd_prompt(
|
||||
hist_with_reply, persona_id,
|
||||
outfit_json=session.get("outfit_json", "[]") if session else "[]"
|
||||
)
|
||||
prompt_str = (sd_result[0] if sd_result and sd_result[0] else None) or extract_image_prompt_tag(complete)
|
||||
|
||||
if (display_text or complete).strip():
|
||||
await add_message(request.session_id, "assistant", display_text or complete, image_prompt=prompt_str)
|
||||
if (display_text or raw_display).strip():
|
||||
await add_message(request.session_id, "assistant", display_text or raw_display)
|
||||
|
||||
choices = []
|
||||
debug_blocks = []
|
||||
quests_updated = []
|
||||
narrator_meta = {}
|
||||
|
||||
if session and session.get("rpg_enabled"):
|
||||
if not arc:
|
||||
persona = await get_persona(persona_id) or {}
|
||||
arc = await generate_plot_arc(
|
||||
persona.get("name", persona_id),
|
||||
persona.get("description", ""),
|
||||
persona.get("scenario", ""),
|
||||
persona.get("first_mes", ""),
|
||||
facts_block=facts_to_prompt(session.get("facts_json", "[]")),
|
||||
genre=session.get("genre") or "adventure",
|
||||
)
|
||||
try:
|
||||
if not arc:
|
||||
persona = await get_persona(persona_id) or {}
|
||||
arc = await generate_plot_arc(
|
||||
persona.get("name", persona_id),
|
||||
persona.get("description", ""),
|
||||
persona.get("scenario", ""),
|
||||
persona.get("first_mes", ""),
|
||||
facts_block=facts_to_prompt(session.get("facts_json", "[]")),
|
||||
genre=session.get("genre") or "adventure",
|
||||
)
|
||||
if arc:
|
||||
await update_session_plot_arc(
|
||||
request.session_id, json.dumps(arc, ensure_ascii=False)
|
||||
)
|
||||
debug_blocks.append({
|
||||
"type": "plot_arc",
|
||||
"text": json.dumps(arc, ensure_ascii=False, indent=2),
|
||||
})
|
||||
if rpg_settings.get("quests", True):
|
||||
await seed_quests_from_arc(request.session_id, arc)
|
||||
|
||||
quests_list = await get_quests(request.session_id)
|
||||
if arc:
|
||||
await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
|
||||
debug_blocks.append({"type": "plot_arc", "text": json.dumps(arc, ensure_ascii=False, indent=2)})
|
||||
if rpg_settings.get("quests", True):
|
||||
for beat in arc.get("beats", []):
|
||||
t = (beat.get("title") or beat.get("injection", "")).strip()
|
||||
if t:
|
||||
await upsert_quest(request.session_id, t[:120])
|
||||
beat_ctx = "\n".join(
|
||||
f"{m['role']}: {m['content']}"
|
||||
for m in (await get_history(request.session_id))[-6:]
|
||||
if m.get("role") in ("user", "assistant")
|
||||
)
|
||||
arc, beats, pruned, beat_mode = await process_arc_beats(
|
||||
arc,
|
||||
quests_list,
|
||||
request.message,
|
||||
recent_context=beat_ctx,
|
||||
last_dice_outcome=outcome if roll is not None else None,
|
||||
)
|
||||
if pruned or beats:
|
||||
await update_session_plot_arc(
|
||||
request.session_id, json.dumps(arc, ensure_ascii=False)
|
||||
)
|
||||
if pruned:
|
||||
debug_blocks.append({
|
||||
"type": "plot_arc_prune",
|
||||
"text": f"Removed {len(pruned)} beat(s) already completed as quests",
|
||||
})
|
||||
if beats:
|
||||
inj = beats[0].get("injection", "")
|
||||
if inj:
|
||||
debug_blocks.append({"type": "narrator_injection", "text": inj})
|
||||
if rpg_settings.get("choices", True):
|
||||
choices += choices_from_beat(beats[0])
|
||||
if beat_mode in ("after_dice", "llm", "trigger", "stuck_recovery"):
|
||||
debug_blocks.append({
|
||||
"type": "plot_arc",
|
||||
"text": (
|
||||
f"Beat fired ({beat_mode}): "
|
||||
f"«{beats[0].get('title', '')}»"
|
||||
),
|
||||
})
|
||||
if advance_phase(arc):
|
||||
await update_session_plot_arc(
|
||||
request.session_id, json.dumps(arc, ensure_ascii=False)
|
||||
)
|
||||
debug_blocks.append({"type": "phase_advance", "text": arc["phase"]})
|
||||
if pruned and not arc.get("beats"):
|
||||
narrator_meta["arc_pruned"] = len(pruned)
|
||||
if beat_mode:
|
||||
narrator_meta["beat_mode"] = beat_mode
|
||||
|
||||
trig = should_advance_arc(request.message)
|
||||
if trig and arc:
|
||||
arc, beats = pop_matching_beats(arc, trig, max_beats=1)
|
||||
if beats:
|
||||
await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
|
||||
inj = beats[0].get("injection", "")
|
||||
if inj:
|
||||
debug_blocks.append({"type": "narrator_injection", "text": inj})
|
||||
if rpg_settings.get("choices", True):
|
||||
choices += beats[0].get("choices") or []
|
||||
if advance_phase(arc):
|
||||
await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
|
||||
debug_blocks.append({"type": "phase_advance", "text": arc["phase"]})
|
||||
ctx = [
|
||||
m for m in (await get_history(request.session_id))
|
||||
if m["role"] in ("user", "assistant")
|
||||
][-10:]
|
||||
new_facts = await extract_facts(
|
||||
ctx,
|
||||
rp_day_hint=rp_day_from_scene(session.get("scene_json")),
|
||||
existing_json=session.get("facts_json", "[]"),
|
||||
)
|
||||
if new_facts:
|
||||
merged = await merge_facts_persist(
|
||||
session.get("facts_json", "[]"),
|
||||
new_facts,
|
||||
rp_day_default=rp_day_from_scene(session.get("scene_json")),
|
||||
scene_context=json.dumps(
|
||||
parse_scene_json(session.get("scene_json")),
|
||||
ensure_ascii=False,
|
||||
),
|
||||
status_quo=session.get("status_quo") or "",
|
||||
)
|
||||
await update_session_facts(request.session_id, merged)
|
||||
session["facts_json"] = merged
|
||||
|
||||
ctx = [m for m in (await get_history(request.session_id)) if m["role"] in ("user", "assistant")][-10:]
|
||||
new_facts = await extract_facts(ctx)
|
||||
if new_facts:
|
||||
merged = merge_facts(session.get("facts_json", "[]"), new_facts)
|
||||
await update_session_facts(request.session_id, merged)
|
||||
session["facts_json"] = merged
|
||||
persona = await get_persona(persona_id) or {}
|
||||
ctx_txt = "\n".join(
|
||||
f"{m['role']}: {m['content']}"
|
||||
for m in ctx[-8:]
|
||||
if m.get("role") in ("user", "assistant")
|
||||
)
|
||||
narr_ctx_post = format_narrator_context(
|
||||
arc, await get_quests(request.session_id), session.get("status_quo") or ""
|
||||
)
|
||||
post = await narrator_post(
|
||||
persona.get("name", persona_id),
|
||||
ctx_txt,
|
||||
json.dumps(arc, ensure_ascii=False) if arc else "",
|
||||
facts_to_prompt(session.get("facts_json", "[]")),
|
||||
extra_context=narr_ctx_post,
|
||||
)
|
||||
|
||||
persona = await get_persona(persona_id) or {}
|
||||
ctx_txt = "\n".join(f"{m['role']}: {m['content']}" for m in ctx[-8:] if m.get("role") in ("user", "assistant"))
|
||||
post = await narrator_post(
|
||||
persona.get("name", persona_id),
|
||||
ctx_txt,
|
||||
json.dumps(arc, ensure_ascii=False) if arc else "",
|
||||
facts_to_prompt(session.get("facts_json", "[]")),
|
||||
)
|
||||
sq = (post.get("status_quo_update") or "").strip()
|
||||
if sq:
|
||||
debug_blocks.append({"type": "status_quo", "text": sq})
|
||||
|
||||
sq = (post.get("status_quo_update") or "").strip()
|
||||
if sq:
|
||||
await update_session_status_quo(request.session_id, sq)
|
||||
debug_blocks.append({"type": "status_quo", "text": sq})
|
||||
if rpg_settings.get("choices", True):
|
||||
choices += choices_from_narrator(post.get("choices") or [])
|
||||
|
||||
if rpg_settings.get("choices", True):
|
||||
choices += post.get("choices") or []
|
||||
applied = await apply_narrator_post(
|
||||
request.session_id, post, rpg_settings, session
|
||||
)
|
||||
narrator_meta = {
|
||||
"pre_ok": pre_ok,
|
||||
"post_ok": bool(post.get("_ok")),
|
||||
"choices_count": len(choices),
|
||||
"directives_count": len(directives),
|
||||
"dice": roll is not None,
|
||||
**applied,
|
||||
}
|
||||
|
||||
if rpg_settings.get("affinity", True):
|
||||
delta = int(post.get("affinity_delta") or 0)
|
||||
if delta:
|
||||
await update_session_affinity(request.session_id, delta)
|
||||
if not arc.get("beats"):
|
||||
persona = await get_persona(persona_id) or {}
|
||||
arc = await replenish_arc_beats(
|
||||
arc,
|
||||
persona.get("name", persona_id),
|
||||
ctx_txt,
|
||||
await get_quests(request.session_id),
|
||||
session.get("genre") or "adventure",
|
||||
)
|
||||
if arc.get("beats"):
|
||||
await update_session_plot_arc(
|
||||
request.session_id, json.dumps(arc, ensure_ascii=False)
|
||||
)
|
||||
debug_blocks.append({
|
||||
"type": "plot_arc",
|
||||
"text": f"Added {len(arc.get('beats', []))} new plot beats",
|
||||
})
|
||||
narrator_meta["beats_replenished"] = len(arc.get("beats", []))
|
||||
if rpg_settings.get("quests", True):
|
||||
await seed_quests_from_arc(request.session_id, arc)
|
||||
outfit_update = post.get("outfit_update")
|
||||
if isinstance(outfit_update, list) and outfit_update:
|
||||
from services.outfit_tags import outfit_list_to_json
|
||||
|
||||
outfit_update = post.get("outfit_update")
|
||||
if isinstance(outfit_update, list) and outfit_update:
|
||||
outfit_str = json.dumps(outfit_update, ensure_ascii=False)
|
||||
await update_session_outfit(request.session_id, outfit_str)
|
||||
session["outfit_json"] = outfit_str
|
||||
|
||||
if rpg_settings.get("quests", True):
|
||||
for qu in (post.get("quest_updates") or []):
|
||||
t = (qu.get("title") or "").strip()
|
||||
if t:
|
||||
await upsert_quest(request.session_id, t[:120], qu.get("status", "active"))
|
||||
session["outfit_json"] = outfit_list_to_json(outfit_update)
|
||||
quests_updated = await get_quests(request.session_id)
|
||||
except LLMError as e:
|
||||
logger.warning("RPG post-process skipped after reply: %s", e)
|
||||
except Exception as e:
|
||||
logger.exception("RPG post-process failed after reply: %s", e)
|
||||
|
||||
count = await get_message_count(request.session_id)
|
||||
if count == 2 and not request.skip_user_add:
|
||||
@@ -443,23 +657,63 @@ async def chat_stream(request: ChatRequest):
|
||||
if (session or {}).get("title", "Новый чат") in ("", "Новый чат"):
|
||||
await update_session_title(request.session_id, f"{persona.get('name', persona_id)} — {preview}")
|
||||
|
||||
image_path = None
|
||||
image_error = None
|
||||
if prompt_str and SD_AUTO_GENERATE:
|
||||
updated_session = await get_session(request.session_id) or session
|
||||
hist = await get_history(request.session_id)
|
||||
bundle = await generate_sd_prompt(
|
||||
hist,
|
||||
persona_id,
|
||||
outfit_json=updated_session.get("outfit_json", "[]") if updated_session else "[]",
|
||||
scene_json=updated_session.get("scene_json", "{}") if updated_session else "{}",
|
||||
)
|
||||
prompt_str = bundle.tag_full if bundle else extract_image_prompt_tag(complete)
|
||||
msg_id = await get_last_assistant_message_id(request.session_id)
|
||||
if msg_id and choices:
|
||||
await update_message_choices(
|
||||
msg_id, json.dumps(choices, ensure_ascii=False)
|
||||
)
|
||||
|
||||
sd_out: dict = {}
|
||||
if bundle:
|
||||
yield f"data: {json.dumps({
|
||||
'image_generating': True,
|
||||
'image_prompt': bundle.tag_full,
|
||||
'image_prompt_alt': bundle.desc_full,
|
||||
})}\n\n"
|
||||
sd_out = await run_sd_for_message(bundle, msg_id)
|
||||
elif prompt_str and SD_AUTO_GENERATE:
|
||||
yield f"data: {json.dumps({'image_generating': True, 'image_prompt': prompt_str})}\n\n"
|
||||
rel, err = await sd_service.generate_from_full_prompt(prompt_str)
|
||||
if rel:
|
||||
image_path = rel
|
||||
msg_id = await get_last_assistant_message_id(request.session_id)
|
||||
sd_out["image_path"] = f"/static/{rel}"
|
||||
if msg_id:
|
||||
await update_message_image(msg_id, rel)
|
||||
else:
|
||||
image_error = err
|
||||
sd_out["image_error"] = err
|
||||
sd_out["image_prompt"] = prompt_str
|
||||
|
||||
updated_session = await get_session(request.session_id)
|
||||
affinity = updated_session.get("affinity", 0) if updated_session else 0
|
||||
done_payload = {
|
||||
"done": True,
|
||||
"assistant_message_id": msg_id,
|
||||
"assistant_content": display_text or raw_display,
|
||||
"image_prompt": sd_out.get("image_prompt") or prompt_str,
|
||||
"image_prompt_alt": sd_out.get("image_prompt_alt"),
|
||||
"image_path": sd_out.get("image_path"),
|
||||
"image_path_alt": sd_out.get("image_path_alt"),
|
||||
"image_error": sd_out.get("image_error"),
|
||||
"image_error_alt": sd_out.get("image_error_alt"),
|
||||
"choices": choices,
|
||||
"debug": debug_blocks,
|
||||
"affinity": affinity,
|
||||
"quests": quests_updated,
|
||||
"narrator_meta": narrator_meta,
|
||||
}
|
||||
if rpg_settings.get("stats") and updated_session:
|
||||
done_payload["narrative_stats"] = parse_stats_json(
|
||||
updated_session.get("narrative_stats_json")
|
||||
)
|
||||
|
||||
yield f"data: {json.dumps({'done': True, 'image_prompt': prompt_str, 'image_path': f'/static/{image_path}' if image_path else None, 'image_error': image_error, 'choices': choices, 'debug': debug_blocks, 'affinity': affinity, 'quests': quests_updated})}\n\n"
|
||||
yield f"data: {json.dumps(done_payload)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
@@ -470,23 +724,37 @@ async def chat_stream(request: ChatRequest):
|
||||
|
||||
@router.post("/", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
persona_id = request.persona_id or "default"
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
await get_or_create_session(request.session_id, request.persona_id)
|
||||
persona_id = await resolve_session_persona(
|
||||
request.session_id,
|
||||
request.persona_id,
|
||||
create_persona=request.persona_id,
|
||||
)
|
||||
|
||||
history = await get_history(request.session_id)
|
||||
system_prompt = await get_system_prompt(persona_id, history, request.message)
|
||||
|
||||
if not history:
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
static_prompt = await get_system_prompt(persona_id, history, request.message)
|
||||
await upsert_static_system_message(request.session_id, static_prompt, history)
|
||||
|
||||
await add_message(request.session_id, "user", request.message)
|
||||
messages = await get_history(request.session_id)
|
||||
reply = await send_message(
|
||||
[{"role": m["role"], "content": m["content"]} for m in messages]
|
||||
session = await get_session(request.session_id)
|
||||
llm_system = static_prompt
|
||||
if session and session.get("rpg_enabled"):
|
||||
rpg_settings = get_rpg_settings(session)
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
llm_system += build_rpg_runtime_suffix(session, rpg_settings, facts_block)
|
||||
if persona_id != "default" or (session and session.get("rpg_enabled")):
|
||||
llm_system += RP_OUTPUT_REMINDER
|
||||
llm_messages = messages_for_llm(messages, llm_system)
|
||||
reply = await send_message(llm_messages)
|
||||
display = strip_ooc_from_reply(strip_image_prompt_tag(reply))
|
||||
bundle = await generate_sd_prompt(
|
||||
messages,
|
||||
persona_id,
|
||||
outfit_json=session.get("outfit_json", "[]") if session else "[]",
|
||||
scene_json=session.get("scene_json", "{}") if session else "{}",
|
||||
)
|
||||
display = strip_image_prompt_tag(reply)
|
||||
prompt_tuple = await generate_sd_prompt(messages, persona_id)
|
||||
prompt_str = prompt_tuple[0] if prompt_tuple else extract_image_prompt_tag(reply)
|
||||
prompt_str = bundle.tag_full if bundle else extract_image_prompt_tag(reply)
|
||||
|
||||
await add_message(request.session_id, "assistant", display, image_prompt=prompt_str)
|
||||
|
||||
@@ -497,6 +765,15 @@ async def chat(request: ChatRequest):
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/messages/{message_id}")
|
||||
async def remove_message(message_id: int):
|
||||
msg = await get_message(message_id)
|
||||
if not msg:
|
||||
raise HTTPException(status_code=404, detail="Сообщение не найдено")
|
||||
await delete_message_and_following(msg["session_id"], message_id)
|
||||
return {"status": "deleted", "message_id": message_id}
|
||||
|
||||
|
||||
@router.patch("/messages/{message_id}")
|
||||
async def edit_message(message_id: int, req: MessageEditRequest):
|
||||
msg = await get_message(message_id)
|
||||
@@ -527,7 +804,6 @@ async def regenerate_chat(req: RegenerateRequest):
|
||||
stream_req = ChatRequest(
|
||||
message=user_text,
|
||||
session_id=req.session_id,
|
||||
persona_id=req.persona_id,
|
||||
skip_user_add=True,
|
||||
)
|
||||
return await chat_stream(stream_req)
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services import sdbackend as sd_service
|
||||
from services.comfy_models import list_node_types, parse_model_lists
|
||||
from services.llm import (
|
||||
CHAT_MODEL,
|
||||
LLM_FALLBACK_MODEL,
|
||||
LLMError,
|
||||
SYSTEM_MODEL,
|
||||
send_message,
|
||||
send_message_with_model,
|
||||
)
|
||||
from services.personas import get_all_personas
|
||||
from services.sd_prompt import (
|
||||
SD_PROMPT_MODEL,
|
||||
anima_dual_enabled,
|
||||
run_prompt_builder,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/debug", tags=["debug"])
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class SdPromptDebugRequest(BaseModel):
|
||||
persona_id: str = "default"
|
||||
chat_excerpt: str = ""
|
||||
messages: list[ChatMessage] | None = None
|
||||
outfit_json: str = "[]"
|
||||
appearance_override: str | None = None
|
||||
use_prose: bool = False
|
||||
|
||||
|
||||
class LlmDebugRequest(BaseModel):
|
||||
model: str = ""
|
||||
system: str = ""
|
||||
user: str = ""
|
||||
messages: list[ChatMessage] | None = None
|
||||
|
||||
|
||||
class ComfyRawRequest(BaseModel):
|
||||
method: str = "GET"
|
||||
path: str = "/system_stats"
|
||||
params_json: str = "{}"
|
||||
body_json: str = ""
|
||||
|
||||
|
||||
class ComfyGenerateRequest(BaseModel):
|
||||
positive: str
|
||||
negative: str = ""
|
||||
unet: str | None = None
|
||||
clip: str | None = None
|
||||
vae: str | None = None
|
||||
checkpoint: str | None = None
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def debug_config():
|
||||
base = sd_service.SD_BASE_URL
|
||||
return {
|
||||
"chat_model": CHAT_MODEL,
|
||||
"system_model": SYSTEM_MODEL,
|
||||
"llm_fallback_model": LLM_FALLBACK_MODEL,
|
||||
"sd_prompt_model": SD_PROMPT_MODEL or SYSTEM_MODEL,
|
||||
"sd_base_url": base,
|
||||
"sd_has_token": bool(sd_service.SD_QUERY_PARAMS.get("token")),
|
||||
"sd_anima_dual": anima_dual_enabled(),
|
||||
"sd_unet": sd_service.SD_UNET,
|
||||
"sd_clip": sd_service.SD_CLIP,
|
||||
"sd_vae": sd_service.SD_VAE,
|
||||
"sd_checkpoint": sd_service.SD_CHECKPOINT,
|
||||
"sd_steps": sd_service.SD_STEPS,
|
||||
"sd_cfg": sd_service.SD_CFG,
|
||||
"router_key_set": bool(os.getenv("ROUTER_KEY")),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/personas")
|
||||
async def debug_personas():
|
||||
personas = await get_all_personas()
|
||||
return [
|
||||
{
|
||||
"persona_id": pid,
|
||||
"name": p.get("name", pid),
|
||||
"appearance_tags": p.get("appearance_tags", ""),
|
||||
}
|
||||
for pid, p in personas.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/sd-prompt")
|
||||
async def debug_sd_prompt(req: SdPromptDebugRequest):
|
||||
msgs = None
|
||||
if req.messages:
|
||||
msgs = [m.model_dump() for m in req.messages]
|
||||
return await run_prompt_builder(
|
||||
req.persona_id,
|
||||
messages=msgs,
|
||||
chat_excerpt=req.chat_excerpt,
|
||||
outfit_json=req.outfit_json,
|
||||
appearance_override=req.appearance_override,
|
||||
use_prose=req.use_prose,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/llm")
|
||||
async def debug_llm(req: LlmDebugRequest):
|
||||
if req.messages:
|
||||
messages = [m.model_dump() for m in req.messages]
|
||||
else:
|
||||
messages = []
|
||||
if req.system.strip():
|
||||
messages.append({"role": "system", "content": req.system.strip()})
|
||||
if req.user.strip():
|
||||
messages.append({"role": "user", "content": req.user.strip()})
|
||||
if not messages:
|
||||
raise HTTPException(status_code=400, detail="Нужны messages или system/user")
|
||||
|
||||
model = (req.model or "").strip() or SD_PROMPT_MODEL or SYSTEM_MODEL
|
||||
try:
|
||||
if model in (SYSTEM_MODEL, "") and not req.model:
|
||||
text = await send_message(messages)
|
||||
else:
|
||||
text = await send_message_with_model(messages, model)
|
||||
return {"model": model, "response": text}
|
||||
except LLMError as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/comfy/ping")
|
||||
async def debug_comfy_ping():
|
||||
try:
|
||||
status, body, headers = await sd_service.comfy_api_request("GET", "/system_stats")
|
||||
return {"ok": status == 200, "status": status, "body": body, "headers": headers}
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/comfy/models")
|
||||
async def debug_comfy_models():
|
||||
try:
|
||||
info = await sd_service.fetch_object_info()
|
||||
return {
|
||||
"models": parse_model_lists(info),
|
||||
"configured": {
|
||||
"unet": sd_service.SD_UNET,
|
||||
"clip": sd_service.SD_CLIP,
|
||||
"vae": sd_service.SD_VAE,
|
||||
"checkpoint": sd_service.SD_CHECKPOINT,
|
||||
},
|
||||
"node_type_count": len(list_node_types(info)),
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/comfy/object_info")
|
||||
async def debug_comfy_object_info(node: str | None = None):
|
||||
try:
|
||||
info = await sd_service.fetch_object_info()
|
||||
if node:
|
||||
if node not in info:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown node: {node}")
|
||||
return {node: info[node]}
|
||||
return {
|
||||
"node_types": list_node_types(info),
|
||||
"models": parse_model_lists(info),
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/comfy/raw")
|
||||
async def debug_comfy_raw(req: ComfyRawRequest):
|
||||
path = req.path.strip()
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
try:
|
||||
params = json.loads(req.params_json or "{}")
|
||||
if not isinstance(params, dict):
|
||||
raise ValueError("params_json must be object")
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"params_json: {e}")
|
||||
|
||||
body = None
|
||||
if req.body_json.strip():
|
||||
try:
|
||||
body = json.loads(req.body_json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"body_json: {e}")
|
||||
|
||||
method = req.method.upper()
|
||||
if method not in ("GET", "POST", "PUT", "DELETE"):
|
||||
raise HTTPException(status_code=400, detail="method must be GET|POST|PUT|DELETE")
|
||||
|
||||
try:
|
||||
status, resp_body, headers = await sd_service.comfy_api_request(
|
||||
method,
|
||||
path,
|
||||
params=params or None,
|
||||
json_body=body,
|
||||
timeout=120,
|
||||
)
|
||||
return {"status": status, "headers": headers, "body": resp_body}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/comfy/generate")
|
||||
async def debug_comfy_generate(req: ComfyGenerateRequest):
|
||||
if not req.positive.strip():
|
||||
raise HTTPException(status_code=400, detail="positive required")
|
||||
|
||||
overrides: dict[str, str] = {}
|
||||
if req.unet:
|
||||
overrides["unet"] = req.unet
|
||||
if req.clip:
|
||||
overrides["clip"] = req.clip
|
||||
if req.vae:
|
||||
overrides["vae"] = req.vae
|
||||
if req.checkpoint:
|
||||
overrides["checkpoint"] = req.checkpoint
|
||||
|
||||
full = req.positive.strip()
|
||||
if req.negative.strip():
|
||||
full += f"\n\nNegative prompt: {req.negative.strip()}"
|
||||
|
||||
try:
|
||||
rel, err = await sd_service.generate_from_full_prompt(
|
||||
full,
|
||||
overrides=overrides or None,
|
||||
)
|
||||
if not rel:
|
||||
raise HTTPException(status_code=502, detail=err or "generation failed")
|
||||
return {"image_path": f"/static/{rel}", "status": "ok"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
@@ -57,6 +57,7 @@ class PersonaPatch(BaseModel):
|
||||
lora_name: Optional[str] = None
|
||||
lora_weight: Optional[float] = None
|
||||
appearance_tags: Optional[str] = None
|
||||
appearance_prose: Optional[str] = None
|
||||
personality: Optional[str] = None
|
||||
scenario: Optional[str] = None
|
||||
first_mes: Optional[str] = None
|
||||
|
||||
+195
-2
@@ -1,8 +1,11 @@
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from services.memory import (
|
||||
get_all_sessions,
|
||||
get_session,
|
||||
get_or_create_session,
|
||||
get_history,
|
||||
delete_session,
|
||||
update_session_title,
|
||||
update_session_persona,
|
||||
@@ -14,10 +17,27 @@ from services.memory import (
|
||||
update_session_genre,
|
||||
update_session_rpg_settings,
|
||||
get_quests,
|
||||
update_quest_by_id,
|
||||
set_session_affinity,
|
||||
update_session_narrative_stats,
|
||||
update_session_outfit,
|
||||
update_session_scene,
|
||||
update_session_plot_arc,
|
||||
get_last_message_preview,
|
||||
fork_session,
|
||||
)
|
||||
from models.schemas import ForkSessionRequest
|
||||
from models.schemas import (
|
||||
ForkSessionRequest,
|
||||
RebindPersonaRequest,
|
||||
QuestStatusPatch,
|
||||
RpgStateDebugPatch,
|
||||
SessionContextPatch,
|
||||
)
|
||||
from services.rpg_plot import reconcile_plot_arc
|
||||
from services.rpg_state import parse_stats_json
|
||||
from services.chat_prompt import get_system_prompt
|
||||
from services.memory import rebind_session_persona
|
||||
from services.personas import get_persona
|
||||
|
||||
router = APIRouter(prefix="/sessions", tags=["sessions"])
|
||||
|
||||
@@ -35,9 +55,149 @@ async def list_sessions():
|
||||
|
||||
@router.get("/{session_id}/quests")
|
||||
async def list_quests(session_id: str):
|
||||
session = await get_session(session_id)
|
||||
if session and session.get("rpg_enabled"):
|
||||
persona = await get_persona(session.get("persona_id") or "default") or {}
|
||||
await reconcile_plot_arc(
|
||||
session_id,
|
||||
persona_name=persona.get("name", session.get("persona_id") or "Character"),
|
||||
recent_context=(session.get("status_quo") or "")[:2000],
|
||||
genre=session.get("genre") or "adventure",
|
||||
)
|
||||
return await get_quests(session_id)
|
||||
|
||||
|
||||
@router.patch("/{session_id}/context")
|
||||
async def patch_session_context(session_id: str, body: SessionContextPatch):
|
||||
"""Live-edit session context (outfit, scene, plot, facts, status quo)."""
|
||||
from services.outfit_tags import parse_and_normalize_outfit_json
|
||||
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Сессия не найдена")
|
||||
|
||||
if body.status_quo is not None:
|
||||
await update_session_status_quo(session_id, body.status_quo)
|
||||
if body.global_plot is not None:
|
||||
await update_session_global_plot(session_id, body.global_plot)
|
||||
if body.outfit_json is not None:
|
||||
try:
|
||||
normalized = parse_and_normalize_outfit_json(body.outfit_json)
|
||||
json.loads(normalized)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"outfit_json: {e}") from e
|
||||
await update_session_outfit(session_id, normalized)
|
||||
if body.scene_json is not None:
|
||||
try:
|
||||
json.loads(body.scene_json or "{}")
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"scene_json: {e}") from e
|
||||
await update_session_scene(session_id, body.scene_json)
|
||||
if body.facts_json is not None:
|
||||
from services.rpg_facts import (
|
||||
parse_facts_list,
|
||||
facts_list_to_json,
|
||||
dedupe_facts_fuzzy,
|
||||
compress_facts,
|
||||
FACTS_DEDUP_THRESHOLD,
|
||||
FACTS_COMPRESS_TARGET,
|
||||
)
|
||||
|
||||
try:
|
||||
facts = dedupe_facts_fuzzy(parse_facts_list(body.facts_json))
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"facts_json: {e}") from e
|
||||
if len(facts) > FACTS_DEDUP_THRESHOLD:
|
||||
facts = await compress_facts(
|
||||
facts,
|
||||
status_quo=(session.get("status_quo") or ""),
|
||||
scene_context=session.get("scene_json") or "{}",
|
||||
target=FACTS_COMPRESS_TARGET,
|
||||
)
|
||||
facts = dedupe_facts_fuzzy(facts)
|
||||
normalized = facts_list_to_json(facts)
|
||||
await update_session_facts(session_id, normalized)
|
||||
if body.plot_arc_json is not None:
|
||||
try:
|
||||
json.loads(body.plot_arc_json or "{}")
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"plot_arc_json: {e}") from e
|
||||
await update_session_plot_arc(session_id, body.plot_arc_json)
|
||||
if body.affinity is not None:
|
||||
await set_session_affinity(session_id, body.affinity)
|
||||
stats_changed = any(
|
||||
getattr(body, k) is not None for k in ("lust", "stamina", "tension")
|
||||
)
|
||||
if stats_changed:
|
||||
stats = parse_stats_json(session.get("narrative_stats_json"))
|
||||
for key in ("lust", "stamina", "tension"):
|
||||
val = getattr(body, key, None)
|
||||
if val is not None:
|
||||
stats[key] = max(0, min(10, int(val)))
|
||||
await update_session_narrative_stats(
|
||||
session_id, json.dumps(stats, ensure_ascii=False)
|
||||
)
|
||||
|
||||
updated = await get_session(session_id) or session
|
||||
return {
|
||||
"status": "updated",
|
||||
"outfit_json": updated.get("outfit_json", "[]"),
|
||||
"scene_json": updated.get("scene_json", "{}"),
|
||||
"affinity": updated.get("affinity", 0),
|
||||
"narrative_stats": parse_stats_json(updated.get("narrative_stats_json")),
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/{session_id}/rpg-state")
|
||||
async def patch_rpg_state(session_id: str, body: RpgStateDebugPatch):
|
||||
"""Debug: set affinity and/or narrative stats (lust/stamina/tension 0–10)."""
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Сессия не найдена")
|
||||
affinity = session.get("affinity", 0)
|
||||
if body.affinity is not None:
|
||||
affinity = await set_session_affinity(session_id, body.affinity)
|
||||
stats = parse_stats_json(session.get("narrative_stats_json"))
|
||||
changed_stats = False
|
||||
for key in ("lust", "stamina", "tension"):
|
||||
val = getattr(body, key, None)
|
||||
if val is not None:
|
||||
stats[key] = max(0, min(10, int(val)))
|
||||
changed_stats = True
|
||||
if changed_stats:
|
||||
import json as _json
|
||||
|
||||
await update_session_narrative_stats(
|
||||
session_id, _json.dumps(stats, ensure_ascii=False)
|
||||
)
|
||||
return {
|
||||
"affinity": affinity,
|
||||
"narrative_stats": stats,
|
||||
"target": "current_player",
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/{session_id}/quests/{quest_id}")
|
||||
async def patch_quest(session_id: str, quest_id: int, body: QuestStatusPatch):
|
||||
status = body.status.strip()
|
||||
if status not in ("active", "done", "failed"):
|
||||
raise HTTPException(status_code=400, detail="status must be active, done, or failed")
|
||||
ok = await update_quest_by_id(quest_id, session_id, status)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="Quest not found")
|
||||
if status in ("done", "failed"):
|
||||
session = await get_session(session_id)
|
||||
if session and session.get("rpg_enabled"):
|
||||
persona = await get_persona(session.get("persona_id") or "default") or {}
|
||||
await reconcile_plot_arc(
|
||||
session_id,
|
||||
persona_name=persona.get("name", session.get("persona_id") or "Character"),
|
||||
recent_context=(session.get("status_quo") or "")[:2000],
|
||||
genre=session.get("genre") or "adventure",
|
||||
)
|
||||
return {"status": "updated", "quest_id": quest_id, "new_status": status}
|
||||
|
||||
|
||||
@router.get("/{session_id}")
|
||||
async def get_session_route(session_id: str):
|
||||
s = await get_session(session_id)
|
||||
@@ -46,9 +206,42 @@ async def get_session_route(session_id: str):
|
||||
return s
|
||||
|
||||
|
||||
@router.post("/{session_id}/rebind-persona")
|
||||
async def rebind_persona(session_id: str, body: RebindPersonaRequest):
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Сессия не найдена")
|
||||
persona = await get_persona(body.persona_id)
|
||||
if not persona:
|
||||
raise HTTPException(status_code=400, detail="Персонаж не найден")
|
||||
|
||||
hist = [] if body.clear_history else await get_history(session_id)
|
||||
static = await get_system_prompt(body.persona_id, hist, "")
|
||||
first_mes = (persona.get("first_mes") or "").strip() if body.clear_history else None
|
||||
|
||||
try:
|
||||
await rebind_session_persona(
|
||||
session_id,
|
||||
body.persona_id,
|
||||
clear_history=body.clear_history,
|
||||
static_prompt=static,
|
||||
first_mes=first_mes or None,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
return {
|
||||
"persona_id": body.persona_id,
|
||||
"persona_name": persona.get("name", body.persona_id),
|
||||
"system_prompt_preview": static[:500],
|
||||
"clear_history": body.clear_history,
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/{session_id}")
|
||||
async def patch_session(session_id: str, data: dict):
|
||||
await get_or_create_session(session_id, data.get("persona_id", "default"))
|
||||
create_pid = data.get("persona_id") if "persona_id" in data else None
|
||||
await get_or_create_session(session_id, create_pid)
|
||||
if "title" in data:
|
||||
await update_session_title(session_id, data["title"])
|
||||
if "persona_id" in data:
|
||||
|
||||
Reference in New Issue
Block a user