Fixed SD Prompt

This commit is contained in:
2026-06-02 15:03:39 +03:00
parent d4cd8f02f4
commit 03cbda5dce
46 changed files with 3285 additions and 429 deletions
+1
View File
@@ -22,6 +22,7 @@ class CardPatch(BaseModel):
first_mes: Optional[str] = None
mes_example: Optional[str] = None
appearance_tags: Optional[str] = None
appearance_prose: Optional[str] = None
lora_name: Optional[str] = None
lora_weight: Optional[float] = None
alternate_greetings_json: Optional[str] = None
+219 -187
View File
@@ -3,14 +3,12 @@ import logging
import os
import random
import aiosqlite
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from database.db import DB_PATH
from models.schemas import ChatRequest, ChatResponse, MessageEditRequest, RegenerateRequest
from services.llm import send_message, stream_message
from services.llm import LLMError, send_message, stream_message
from services.memory import (
get_history,
add_message,
@@ -18,7 +16,6 @@ from services.memory import (
get_or_create_session,
get_session,
update_session_title,
update_session_persona,
get_message_count,
get_last_assistant_message_id,
update_message_image,
@@ -26,7 +23,6 @@ from services.memory import (
update_session_status_quo,
update_session_affinity,
update_session_genre,
update_session_rpg_settings,
update_session_outfit,
update_session_plot_arc,
upsert_quest,
@@ -36,20 +32,23 @@ from services.memory import (
update_message_content,
delete_messages_after,
delete_message,
upsert_static_system_message,
)
from services.personas import get_persona
from services.chat_prompt import get_system_prompt, DEFAULT_PROMPT
from services.session_identity import resolve_session_persona
from services.sd_prompt import generate_sd_prompt, strip_image_prompt_tag, extract_image_prompt_tag
from services.lorebook import get_lorebook_context
from services.sd_images import run_sd_for_message
from services.character_card import get_character
from services import sdbackend as sd_service
from services.rpg_facts import extract_facts, merge_facts, facts_to_prompt
from services.rpg_plot import generate_plot_arc, should_advance_arc, pop_matching_beats, advance_phase
from services.rpg_narrator import narrator_pre, narrator_post
from services.opening import ensure_plot_arc_and_quests, resolve_greeting, process_opening
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/chat", tags=["chat"])
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")
DEFAULT_RPG_SETTINGS = {"dice": True, "narrator": True, "quests": True, "affinity": True, "choices": True}
@@ -72,24 +71,20 @@ def affinity_prompt_block(affinity: int) -> str:
return f"\n\n--- Relationship ---\nAffinity toward player: {affinity} ({tone}). Reflect this in your attitude and word choice.\n---"
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
persona = await get_persona(persona_id)
if not persona:
return DEFAULT_PROMPT
prompt = persona["prompt"]
recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
context = recent + [{"role": "user", "content": user_message}]
if persona.get("lorebook_json"):
lore = get_lorebook_context(persona.get("lorebook_json", "[]"), context)
if lore:
prompt += "\n\n" + lore
if persona_id.startswith("card_"):
card = await get_character(persona_id[5:])
if card:
lore = get_lorebook_context(card.get("lorebook_json", "[]"), context)
if lore:
prompt += "\n\n" + lore
return prompt
def messages_for_llm(history: list, llm_system_content: str) -> list[dict]:
    """Assemble the LLM payload so that it carries exactly one system message.

    The first ``system`` row found in ``history`` is replaced, at its position,
    by ``llm_system_content``; any later ``system`` rows are dropped. Every
    other row is copied as a ``{"role", "content"}`` pair (extra keys are not
    forwarded). If ``history`` has no ``system`` row at all, the system message
    is placed at the front of the payload.
    """
    system_entry = {"role": "system", "content": llm_system_content}
    payload: list[dict] = []
    seen_system = False
    for row in history:
        if row["role"] != "system":
            payload.append({"role": row["role"], "content": row["content"]})
            continue
        if seen_system:
            # Duplicate system rows never reach the model.
            continue
        payload.append(system_entry)
        seen_system = True
    if seen_system:
        return payload
    return [system_entry, *payload]
@router.get("/history/{session_id}")
@@ -100,11 +95,18 @@ async def get_chat_history(session_id: str):
@router.get("/system/{session_id}")
async def get_system_blob(session_id: str):
history = await get_history(session_id)
system_msg = next((m for m in history if m.get("role") == "system"), None)
session = await get_session(session_id)
persona_id = (session.get("persona_id") if session else None) or "default"
persona = await get_persona(persona_id) or {}
system_msg = next((m for m in history if m.get("role") == "system"), None)
stored = system_msg.get("content") if system_msg else ""
live_static = await get_system_prompt(persona_id, history, "")
system_prompt = live_static if live_static else stored
quests = await get_quests(session_id)
return {
"system_prompt": system_msg.get("content") if system_msg else "",
"persona_id": persona_id,
"persona_name": persona.get("name", persona_id),
"system_prompt": system_prompt,
"status_quo": session.get("status_quo") if session else "",
"facts_json": session.get("facts_json") if session else "[]",
"plot_arc_json": session.get("plot_arc_json") if session else "{}",
@@ -119,14 +121,21 @@ async def get_system_blob(session_id: str):
@router.post("/init")
async def init_chat(request: ChatRequest):
persona_id = request.persona_id or "default"
await get_or_create_session(request.session_id, persona_id)
await get_or_create_session(
request.session_id,
request.persona_id or "default",
)
persona_id = await resolve_session_persona(
request.session_id,
request.persona_id,
create_persona=request.persona_id,
)
history = await get_history(request.session_id)
if history:
return {"first_mes": None}
system_prompt = await get_system_prompt(persona_id, [], "")
await add_message(request.session_id, "system", system_prompt)
await upsert_static_system_message(request.session_id, system_prompt, [])
first_mes = None
if request.first_mes_override and request.first_mes_override.strip():
@@ -152,53 +161,47 @@ class RpgBootstrapRequest(BaseModel):
genre: str = "adventure"
# Request body for POST /chat/opening/process.
class OpeningProcessRequest(BaseModel):
    # Session that is created if missing and then passed to process_opening.
    session_id: str
    # Requested persona; the handler resolves the effective persona for the session from it.
    persona_id: str = "default"
    # Forwarded to process_opening as its ``rpg`` keyword argument.
    rpg: bool = False
@router.post("/opening/process")
async def opening_process(req: OpeningProcessRequest):
    # Makes sure the session exists, resolves which persona it is bound to and
    # delegates to process_opening. A ValueError raised by process_opening is
    # reported to the client as HTTP 400 with the error text as detail.
    sid = req.session_id
    await get_or_create_session(sid, req.persona_id)
    resolved_persona = await resolve_session_persona(sid, req.persona_id)
    try:
        outcome = await process_opening(sid, resolved_persona, rpg=req.rpg)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return outcome
@router.post("/rpg/bootstrap")
async def rpg_bootstrap(req: RpgBootstrapRequest):
await get_or_create_session(req.session_id, req.persona_id)
session = await get_session(req.session_id)
persona = await get_persona(req.persona_id) or {}
# Save genre
persona_id = await resolve_session_persona(req.session_id, req.persona_id)
await update_session_genre(req.session_id, req.genre)
arc_json = (session.get("plot_arc_json") or "{}") if session else "{}"
try:
arc = json.loads(arc_json) if isinstance(arc_json, str) else {}
except Exception:
arc = {}
if not arc:
facts_block = facts_to_prompt((session or {}).get("facts_json", "[]"))
arc = await generate_plot_arc(
persona.get("name", req.persona_id),
persona.get("description", ""),
persona.get("scenario", ""),
persona.get("first_mes", ""),
facts_block=facts_block,
genre=req.genre,
)
if arc:
from services.memory import update_session_plot_arc
await update_session_plot_arc(req.session_id, json.dumps(arc, ensure_ascii=False))
# Seed quests from beats
for beat in arc.get("beats", []):
title = (beat.get("title") or beat.get("injection", "")).strip()
if title:
await upsert_quest(req.session_id, title[:120])
persona = await get_persona(persona_id) or {}
greeting = await resolve_greeting(req.session_id, persona)
arc = await ensure_plot_arc_and_quests(req.session_id, persona, greeting, req.genre)
quests = await get_quests(req.session_id)
return {"plot_arc": arc, "quests": quests}
@router.post("/stream")
async def chat_stream(request: ChatRequest):
persona_id = request.persona_id or "default"
await get_or_create_session(request.session_id, persona_id)
await get_or_create_session(request.session_id, request.persona_id)
persona_id = await resolve_session_persona(
request.session_id,
request.persona_id,
create_persona=request.persona_id,
)
history = await get_history(request.session_id)
session = await get_session(request.session_id)
system_prompt = await get_system_prompt(persona_id, history, request.message)
static_prompt = await get_system_prompt(persona_id, history, request.message)
runtime_suffix = ""
arc = {}
roll = None
@@ -206,26 +209,27 @@ async def chat_stream(request: ChatRequest):
resolution_text = ""
narrator_msg = None # shown as narrator bubble before assistant reply
rpg_settings = {}
facts_block = ""
if session and session.get("rpg_enabled"):
rpg_settings = get_rpg_settings(session)
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
if facts_block:
system_prompt = system_prompt + "\n\n" + facts_block
runtime_suffix += "\n\n" + facts_block
try:
arc = json.loads(session.get("plot_arc_json") or "{}")
except Exception:
arc = {}
if arc:
system_prompt = system_prompt + "\n\n--- PlotArc ---\n" + json.dumps(
runtime_suffix += "\n\n--- PlotArc ---\n" + json.dumps(
{k: arc.get(k) for k in ("title", "phase", "next_beat_hint")}, ensure_ascii=False
) + "\n---"
status_quo = (session.get("status_quo") or "").strip()
if status_quo:
system_prompt = system_prompt + "\n\n--- Status quo ---\n" + status_quo + "\n---"
runtime_suffix += "\n\n--- Status quo ---\n" + status_quo + "\n---"
if rpg_settings.get("affinity", True):
aff = int(session.get("affinity") or 0)
system_prompt = system_prompt + affinity_prompt_block(aff)
runtime_suffix += affinity_prompt_block(aff)
if rpg_settings.get("narrator", True):
persona = await get_persona(persona_id) or {}
@@ -274,7 +278,7 @@ async def chat_stream(request: ChatRequest):
pre_sq = (pre.get("status_quo_update") or "").strip()
if directives:
system_prompt = system_prompt + "\n\n--- Narrator directives ---\n" + "\n".join(f"- {d}" for d in directives) + "\n---"
runtime_suffix += "\n\n--- Narrator directives ---\n" + "\n".join(f"- {d}" for d in directives) + "\n---"
if pre_sq:
await update_session_status_quo(request.session_id, pre_sq)
session["status_quo"] = pre_sq
@@ -290,50 +294,37 @@ async def chat_stream(request: ChatRequest):
)
narrator_msg = {"roll": roll, "outcome": outcome, "text": resolution_text}
# Inject outcome into system prompt so character reply is consistent
if roll is not None:
system_prompt = (
system_prompt
+ f"\n\n--- Mechanics ---\n"
+ f"Roll d20={roll}. Outcome: {outcome}.\n"
runtime_suffix += (
f"\n\n--- Mechanics ---\n"
f"Roll d20={roll}. Outcome: {outcome}.\n"
+ "Your reply MUST be consistent with this outcome. Do NOT contradict the narrator resolution.\n"
+ "---"
)
# is_narrator_choice: wrap message so LLM understands context
llm_system = static_prompt + runtime_suffix
user_message_content = request.message
if request.is_narrator_choice:
user_message_content = f"[Player chose: {request.message}]"
if not history:
await add_message(request.session_id, "system", system_prompt)
elif history[0]["role"] == "system" and history[0]["content"] != system_prompt:
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
"""UPDATE messages SET content = ?
WHERE session_id = ? AND role = 'system'
AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
(system_prompt, request.session_id, request.session_id),
)
await db.commit()
await upsert_static_system_message(request.session_id, static_prompt, history)
if not request.skip_user_add:
await add_message(request.session_id, "user", user_message_content)
messages = await get_history(request.session_id)
llm_messages = messages_for_llm(messages, llm_system)
full_reply = []
async def generate():
nonlocal arc
# Send narrator BEFORE streaming so it appears above the reply
if narrator_msg:
yield f"data: {json.dumps({'narrator': narrator_msg})}\n\n"
try:
async for chunk in stream_message(
[{"role": m["role"], "content": m["content"]} for m in messages]
):
async for chunk in stream_message(llm_messages):
full_reply.append(chunk)
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
except Exception as e:
@@ -344,97 +335,111 @@ async def chat_stream(request: ChatRequest):
complete = "".join(full_reply)
display_text = strip_image_prompt_tag(complete)
hist_with_reply = await get_history(request.session_id) + [
{"role": "assistant", "content": display_text}
]
sd_result = await generate_sd_prompt(
hist_with_reply, persona_id,
outfit_json=session.get("outfit_json", "[]") if session else "[]"
)
prompt_str = (sd_result[0] if sd_result and sd_result[0] else None) or extract_image_prompt_tag(complete)
if (display_text or complete).strip():
await add_message(request.session_id, "assistant", display_text or complete, image_prompt=prompt_str)
await add_message(request.session_id, "assistant", display_text or complete)
choices = []
debug_blocks = []
quests_updated = []
if session and session.get("rpg_enabled"):
if not arc:
try:
if not arc:
persona = await get_persona(persona_id) or {}
arc = await generate_plot_arc(
persona.get("name", persona_id),
persona.get("description", ""),
persona.get("scenario", ""),
persona.get("first_mes", ""),
facts_block=facts_to_prompt(session.get("facts_json", "[]")),
genre=session.get("genre") or "adventure",
)
if arc:
await update_session_plot_arc(
request.session_id, json.dumps(arc, ensure_ascii=False)
)
debug_blocks.append({
"type": "plot_arc",
"text": json.dumps(arc, ensure_ascii=False, indent=2),
})
if rpg_settings.get("quests", True):
for beat in arc.get("beats", []):
t = (beat.get("title") or beat.get("injection", "")).strip()
if t:
await upsert_quest(request.session_id, t[:120])
trig = should_advance_arc(request.message)
if trig and arc:
arc, beats = pop_matching_beats(arc, trig, max_beats=1)
if beats:
await update_session_plot_arc(
request.session_id, json.dumps(arc, ensure_ascii=False)
)
inj = beats[0].get("injection", "")
if inj:
debug_blocks.append({"type": "narrator_injection", "text": inj})
if rpg_settings.get("choices", True):
choices += beats[0].get("choices") or []
if advance_phase(arc):
await update_session_plot_arc(
request.session_id, json.dumps(arc, ensure_ascii=False)
)
debug_blocks.append({"type": "phase_advance", "text": arc["phase"]})
ctx = [
m for m in (await get_history(request.session_id))
if m["role"] in ("user", "assistant")
][-10:]
new_facts = await extract_facts(ctx)
if new_facts:
merged = merge_facts(session.get("facts_json", "[]"), new_facts)
await update_session_facts(request.session_id, merged)
session["facts_json"] = merged
persona = await get_persona(persona_id) or {}
arc = await generate_plot_arc(
persona.get("name", persona_id),
persona.get("description", ""),
persona.get("scenario", ""),
persona.get("first_mes", ""),
facts_block=facts_to_prompt(session.get("facts_json", "[]")),
genre=session.get("genre") or "adventure",
ctx_txt = "\n".join(
f"{m['role']}: {m['content']}"
for m in ctx[-8:]
if m.get("role") in ("user", "assistant")
)
post = await narrator_post(
persona.get("name", persona_id),
ctx_txt,
json.dumps(arc, ensure_ascii=False) if arc else "",
facts_to_prompt(session.get("facts_json", "[]")),
)
if arc:
await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
debug_blocks.append({"type": "plot_arc", "text": json.dumps(arc, ensure_ascii=False, indent=2)})
if rpg_settings.get("quests", True):
for beat in arc.get("beats", []):
t = (beat.get("title") or beat.get("injection", "")).strip()
if t:
await upsert_quest(request.session_id, t[:120])
trig = should_advance_arc(request.message)
if trig and arc:
arc, beats = pop_matching_beats(arc, trig, max_beats=1)
if beats:
await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
inj = beats[0].get("injection", "")
if inj:
debug_blocks.append({"type": "narrator_injection", "text": inj})
if rpg_settings.get("choices", True):
choices += beats[0].get("choices") or []
if advance_phase(arc):
await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
debug_blocks.append({"type": "phase_advance", "text": arc["phase"]})
sq = (post.get("status_quo_update") or "").strip()
if sq:
await update_session_status_quo(request.session_id, sq)
debug_blocks.append({"type": "status_quo", "text": sq})
ctx = [m for m in (await get_history(request.session_id)) if m["role"] in ("user", "assistant")][-10:]
new_facts = await extract_facts(ctx)
if new_facts:
merged = merge_facts(session.get("facts_json", "[]"), new_facts)
await update_session_facts(request.session_id, merged)
session["facts_json"] = merged
if rpg_settings.get("choices", True):
choices += post.get("choices") or []
persona = await get_persona(persona_id) or {}
ctx_txt = "\n".join(f"{m['role']}: {m['content']}" for m in ctx[-8:] if m.get("role") in ("user", "assistant"))
post = await narrator_post(
persona.get("name", persona_id),
ctx_txt,
json.dumps(arc, ensure_ascii=False) if arc else "",
facts_to_prompt(session.get("facts_json", "[]")),
)
if rpg_settings.get("affinity", True):
delta = int(post.get("affinity_delta") or 0)
if delta:
await update_session_affinity(request.session_id, delta)
sq = (post.get("status_quo_update") or "").strip()
if sq:
await update_session_status_quo(request.session_id, sq)
debug_blocks.append({"type": "status_quo", "text": sq})
outfit_update = post.get("outfit_update")
if isinstance(outfit_update, list) and outfit_update:
outfit_str = json.dumps(outfit_update, ensure_ascii=False)
await update_session_outfit(request.session_id, outfit_str)
session["outfit_json"] = outfit_str
if rpg_settings.get("choices", True):
choices += post.get("choices") or []
if rpg_settings.get("affinity", True):
delta = int(post.get("affinity_delta") or 0)
if delta:
await update_session_affinity(request.session_id, delta)
outfit_update = post.get("outfit_update")
if isinstance(outfit_update, list) and outfit_update:
outfit_str = json.dumps(outfit_update, ensure_ascii=False)
await update_session_outfit(request.session_id, outfit_str)
session["outfit_json"] = outfit_str
if rpg_settings.get("quests", True):
for qu in (post.get("quest_updates") or []):
t = (qu.get("title") or "").strip()
if t:
await upsert_quest(request.session_id, t[:120], qu.get("status", "active"))
quests_updated = await get_quests(request.session_id)
if rpg_settings.get("quests", True):
for qu in (post.get("quest_updates") or []):
t = (qu.get("title") or "").strip()
if t:
await upsert_quest(
request.session_id, t[:120], qu.get("status", "active")
)
quests_updated = await get_quests(request.session_id)
except LLMError as e:
logger.warning("RPG post-process skipped after reply: %s", e)
except Exception as e:
logger.exception("RPG post-process failed after reply: %s", e)
count = await get_message_count(request.session_id)
if count == 2 and not request.skip_user_add:
@@ -443,23 +448,50 @@ async def chat_stream(request: ChatRequest):
if (session or {}).get("title", "Новый чат") in ("", "Новый чат"):
await update_session_title(request.session_id, f"{persona.get('name', persona_id)}{preview}")
image_path = None
image_error = None
if prompt_str and SD_AUTO_GENERATE:
updated_session = await get_session(request.session_id) or session
hist = await get_history(request.session_id)
bundle = await generate_sd_prompt(
hist,
persona_id,
outfit_json=updated_session.get("outfit_json", "[]") if updated_session else "[]",
)
prompt_str = bundle.tag_full if bundle else extract_image_prompt_tag(complete)
msg_id = await get_last_assistant_message_id(request.session_id)
sd_out: dict = {}
if bundle:
yield f"data: {json.dumps({
'image_generating': True,
'image_prompt': bundle.tag_full,
'image_prompt_alt': bundle.desc_full,
})}\n\n"
sd_out = await run_sd_for_message(bundle, msg_id)
elif prompt_str and SD_AUTO_GENERATE:
yield f"data: {json.dumps({'image_generating': True, 'image_prompt': prompt_str})}\n\n"
rel, err = await sd_service.generate_from_full_prompt(prompt_str)
if rel:
image_path = rel
msg_id = await get_last_assistant_message_id(request.session_id)
sd_out["image_path"] = f"/static/{rel}"
if msg_id:
await update_message_image(msg_id, rel)
else:
image_error = err
sd_out["image_error"] = err
sd_out["image_prompt"] = prompt_str
updated_session = await get_session(request.session_id)
affinity = updated_session.get("affinity", 0) if updated_session else 0
yield f"data: {json.dumps({'done': True, 'image_prompt': prompt_str, 'image_path': f'/static/{image_path}' if image_path else None, 'image_error': image_error, 'choices': choices, 'debug': debug_blocks, 'affinity': affinity, 'quests': quests_updated})}\n\n"
yield f"data: {json.dumps({
'done': True,
'image_prompt': sd_out.get('image_prompt') or prompt_str,
'image_prompt_alt': sd_out.get('image_prompt_alt'),
'image_path': sd_out.get('image_path'),
'image_path_alt': sd_out.get('image_path_alt'),
'image_error': sd_out.get('image_error'),
'image_error_alt': sd_out.get('image_error_alt'),
'choices': choices,
'debug': debug_blocks,
'affinity': affinity,
'quests': quests_updated,
})}\n\n"
return StreamingResponse(
generate(),
@@ -470,23 +502,24 @@ async def chat_stream(request: ChatRequest):
@router.post("/", response_model=ChatResponse)
async def chat(request: ChatRequest):
persona_id = request.persona_id or "default"
await get_or_create_session(request.session_id, persona_id)
await get_or_create_session(request.session_id, request.persona_id)
persona_id = await resolve_session_persona(
request.session_id,
request.persona_id,
create_persona=request.persona_id,
)
history = await get_history(request.session_id)
system_prompt = await get_system_prompt(persona_id, history, request.message)
if not history:
await add_message(request.session_id, "system", system_prompt)
static_prompt = await get_system_prompt(persona_id, history, request.message)
await upsert_static_system_message(request.session_id, static_prompt, history)
await add_message(request.session_id, "user", request.message)
messages = await get_history(request.session_id)
reply = await send_message(
[{"role": m["role"], "content": m["content"]} for m in messages]
)
llm_messages = messages_for_llm(messages, static_prompt)
reply = await send_message(llm_messages)
display = strip_image_prompt_tag(reply)
prompt_tuple = await generate_sd_prompt(messages, persona_id)
prompt_str = prompt_tuple[0] if prompt_tuple else extract_image_prompt_tag(reply)
bundle = await generate_sd_prompt(messages, persona_id)
prompt_str = bundle.tag_full if bundle else extract_image_prompt_tag(reply)
await add_message(request.session_id, "assistant", display, image_prompt=prompt_str)
@@ -527,7 +560,6 @@ async def regenerate_chat(req: RegenerateRequest):
stream_req = ChatRequest(
message=user_text,
session_id=req.session_id,
persona_id=req.persona_id,
skip_user_add=True,
)
return await chat_stream(stream_req)
+248
View File
@@ -0,0 +1,248 @@
import json
import os
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from services import sdbackend as sd_service
from services.comfy_models import list_node_types, parse_model_lists
from services.llm import (
CHAT_MODEL,
LLM_FALLBACK_MODEL,
LLMError,
SYSTEM_MODEL,
send_message,
send_message_with_model,
)
from services.personas import get_all_personas
from services.sd_prompt import (
SD_PROMPT_MODEL,
anima_dual_enabled,
run_prompt_builder,
)
router = APIRouter(prefix="/debug", tags=["debug"])
# One chat turn as accepted by the debug endpoints; the handlers convert it
# to a plain dict with model_dump() before handing it on.
class ChatMessage(BaseModel):
    role: str  # e.g. "system" / "user" (the values debug_llm builds itself)
    content: str  # message text
# Request body for POST /debug/sd-prompt. Every field is forwarded to
# run_prompt_builder by debug_sd_prompt.
class SdPromptDebugRequest(BaseModel):
    persona_id: str = "default"
    chat_excerpt: str = ""
    # Optional structured history; passed as None when empty or omitted.
    messages: list[ChatMessage] | None = None
    # JSON text; defaults to an empty JSON list.
    outfit_json: str = "[]"
    appearance_override: str | None = None
    use_prose: bool = False
# Request body for POST /debug/llm.
class LlmDebugRequest(BaseModel):
    # Model name; when blank, debug_llm falls back to SD_PROMPT_MODEL, then SYSTEM_MODEL.
    model: str = ""
    # system / user are only used by debug_llm when ``messages`` is empty or omitted.
    system: str = ""
    user: str = ""
    # Full message list; takes precedence over system / user when non-empty.
    messages: list[ChatMessage] | None = None
# Request body for POST /debug/comfy/raw (pass-through call to the Comfy API).
class ComfyRawRequest(BaseModel):
    # HTTP verb; debug_comfy_raw upper-cases it and accepts GET/POST/PUT/DELETE.
    method: str = "GET"
    # API path; a leading "/" is added by the handler when missing.
    path: str = "/system_stats"
    # Query parameters as JSON text; the handler expects a JSON object.
    params_json: str = "{}"
    # Optional request body as JSON text; blank means no body.
    body_json: str = ""
# Request body for POST /debug/comfy/generate.
class ComfyGenerateRequest(BaseModel):
    # Positive prompt; the handler rejects it with 400 when blank.
    positive: str
    # Optional negative prompt; appended after a "Negative prompt:" marker when non-blank.
    negative: str = ""
    # Optional per-request overrides; only truthy values are sent as overrides.
    unet: str | None = None
    clip: str | None = None
    vae: str | None = None
    checkpoint: str | None = None
@router.get("/config")
async def debug_config():
    # Reports the effective LLM and SD settings. The SD token and ROUTER_KEY
    # are exposed only as booleans ("is it set"), never as values.
    llm_part = {
        "chat_model": CHAT_MODEL,
        "system_model": SYSTEM_MODEL,
        "llm_fallback_model": LLM_FALLBACK_MODEL,
        "sd_prompt_model": SD_PROMPT_MODEL or SYSTEM_MODEL,
    }
    sd_part = {
        "sd_base_url": sd_service.SD_BASE_URL,
        "sd_has_token": bool(sd_service.SD_QUERY_PARAMS.get("token")),
        "sd_anima_dual": anima_dual_enabled(),
        "sd_unet": sd_service.SD_UNET,
        "sd_clip": sd_service.SD_CLIP,
        "sd_vae": sd_service.SD_VAE,
        "sd_checkpoint": sd_service.SD_CHECKPOINT,
        "sd_steps": sd_service.SD_STEPS,
        "sd_cfg": sd_service.SD_CFG,
    }
    router_key_present = bool(os.getenv("ROUTER_KEY"))
    return {**llm_part, **sd_part, "router_key_set": router_key_present}
@router.get("/personas")
async def debug_personas():
    # Lists every known persona as {persona_id, name, appearance_tags};
    # name falls back to the id and appearance_tags to an empty string.
    catalog = await get_all_personas()
    rows = []
    for persona_id, data in catalog.items():
        rows.append(
            {
                "persona_id": persona_id,
                "name": data.get("name", persona_id),
                "appearance_tags": data.get("appearance_tags", ""),
            }
        )
    return rows
@router.post("/sd-prompt")
async def debug_sd_prompt(req: SdPromptDebugRequest):
    # Runs the SD prompt builder with the request's fields and returns its
    # result unchanged. An empty or missing message list is passed as None.
    history = [item.model_dump() for item in req.messages] if req.messages else None
    builder_kwargs = {
        "messages": history,
        "chat_excerpt": req.chat_excerpt,
        "outfit_json": req.outfit_json,
        "appearance_override": req.appearance_override,
        "use_prose": req.use_prose,
    }
    return await run_prompt_builder(req.persona_id, **builder_kwargs)
@router.post("/llm")
async def debug_llm(req: LlmDebugRequest):
    # Sends an ad-hoc conversation to the LLM and returns {"model", "response"}.
    # An explicit message list wins; otherwise the payload is built from the
    # non-blank system / user strings. Nothing to send -> 400; LLMError -> 502.
    if req.messages:
        payload = [item.model_dump() for item in req.messages]
    else:
        pieces = (("system", req.system.strip()), ("user", req.user.strip()))
        payload = [{"role": role, "content": text} for role, text in pieces if text]
    if not payload:
        raise HTTPException(status_code=400, detail="Нужны messages или system/user")

    chosen = (req.model or "").strip() or SD_PROMPT_MODEL or SYSTEM_MODEL
    # The plain sender is used only when no model was requested and the
    # fallback chain ended at the system model (or at an empty name).
    use_plain_sender = not req.model and chosen in (SYSTEM_MODEL, "")
    try:
        if use_plain_sender:
            answer = await send_message(payload)
        else:
            answer = await send_message_with_model(payload, chosen)
    except LLMError as exc:
        raise HTTPException(status_code=502, detail=str(exc))
    return {"model": chosen, "response": answer}
@router.get("/comfy/ping")
async def debug_comfy_ping():
    # Probes the Comfy backend via GET /system_stats. Never raises: any failure
    # is reported as {"ok": False, "error": ...} in a normal response.
    try:
        reply = await sd_service.comfy_api_request("GET", "/system_stats")
        code, payload, hdrs = reply
        report = {"ok": code == 200, "status": code, "body": payload, "headers": hdrs}
        return report
    except Exception as exc:
        return {"ok": False, "error": str(exc)}
@router.get("/comfy/models")
async def debug_comfy_models():
    # Returns the model lists parsed from the backend's object info, the
    # currently configured model names and the number of node types.
    # Any failure is surfaced as HTTP 502 with the error text as detail.
    try:
        object_info = await sd_service.fetch_object_info()
        available = parse_model_lists(object_info)
        configured = {
            "unet": sd_service.SD_UNET,
            "clip": sd_service.SD_CLIP,
            "vae": sd_service.SD_VAE,
            "checkpoint": sd_service.SD_CHECKPOINT,
        }
        node_total = len(list_node_types(object_info))
        return {
            "models": available,
            "configured": configured,
            "node_type_count": node_total,
        }
    except Exception as exc:
        raise HTTPException(status_code=502, detail=str(exc))
@router.get("/comfy/object_info")
async def debug_comfy_object_info(node: str | None = None):
    # Without ``node``: returns the node type list and the parsed model lists.
    # With ``node``: returns that node's entry, or 404 when it is unknown.
    # Any other failure is reported as HTTP 502.
    try:
        object_info = await sd_service.fetch_object_info()
        if not node:
            return {
                "node_types": list_node_types(object_info),
                "models": parse_model_lists(object_info),
            }
        if node in object_info:
            return {node: object_info[node]}
        raise HTTPException(status_code=404, detail=f"Unknown node: {node}")
    except HTTPException:
        # Let the 404 above through instead of rewrapping it as 502.
        raise
    except Exception as exc:
        raise HTTPException(status_code=502, detail=str(exc))
@router.post("/comfy/raw")
async def debug_comfy_raw(req: ComfyRawRequest):
    # Pass-through call to the Comfy API for debugging.
    #
    # Input handling (all validation failures are HTTP 400):
    #   - path: stripped; a leading "/" is added when missing.
    #   - params_json: must be valid JSON holding an object.
    #   - body_json: optional; when non-blank it must be valid JSON.
    #   - method: upper-cased; must be one of GET/POST/PUT/DELETE.
    # Returns {"status", "headers", "body"} from the backend; any failure of the
    # backend call itself is reported as HTTP 502.
    path = req.path.strip()
    if not path.startswith("/"):
        path = "/" + path

    try:
        params = json.loads(req.params_json or "{}")
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"params_json: {e}") from e
    # Checked outside the try block: previously a ValueError was raised inside a
    # handler that only caught JSONDecodeError, so a non-object value (e.g. "[1]")
    # escaped as an unhandled server error instead of a 400.
    if not isinstance(params, dict):
        raise HTTPException(status_code=400, detail="params_json must be object")

    body = None
    if req.body_json.strip():
        try:
            body = json.loads(req.body_json)
        except json.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"body_json: {e}") from e

    method = req.method.upper()
    if method not in ("GET", "POST", "PUT", "DELETE"):
        raise HTTPException(status_code=400, detail="method must be GET|POST|PUT|DELETE")

    try:
        status, resp_body, headers = await sd_service.comfy_api_request(
            method,
            path,
            params=params or None,  # an empty object means "no query parameters"
            json_body=body,
            timeout=120,
        )
    except Exception as e:
        raise HTTPException(status_code=502, detail=str(e)) from e
    return {"status": status, "headers": headers, "body": resp_body}
@router.post("/comfy/generate")
async def debug_comfy_generate(req: ComfyGenerateRequest):
    # Generates one image from an explicit positive / negative prompt pair.
    # Blank positive prompt -> 400. Only truthy model fields are sent as
    # overrides. A generation failure or any backend error -> 502.
    positive = req.positive.strip()
    if not positive:
        raise HTTPException(status_code=400, detail="positive required")

    requested = {
        "unet": req.unet,
        "clip": req.clip,
        "vae": req.vae,
        "checkpoint": req.checkpoint,
    }
    overrides: dict[str, str] = {key: val for key, val in requested.items() if val}

    negative = req.negative.strip()
    full = f"{positive}\n\nNegative prompt: {negative}" if negative else positive

    try:
        rel, err = await sd_service.generate_from_full_prompt(
            full,
            overrides=overrides or None,
        )
        if rel:
            return {"image_path": f"/static/{rel}", "status": "ok"}
        raise HTTPException(status_code=502, detail=err or "generation failed")
    except HTTPException:
        # Keep the explicit 502 above as it is.
        raise
    except Exception as exc:
        raise HTTPException(status_code=502, detail=str(exc))
+1
View File
@@ -57,6 +57,7 @@ class PersonaPatch(BaseModel):
lora_name: Optional[str] = None
lora_weight: Optional[float] = None
appearance_tags: Optional[str] = None
appearance_prose: Optional[str] = None
personality: Optional[str] = None
scenario: Optional[str] = None
first_mes: Optional[str] = None
+39 -2
View File
@@ -3,6 +3,7 @@ from services.memory import (
get_all_sessions,
get_session,
get_or_create_session,
get_history,
delete_session,
update_session_title,
update_session_persona,
@@ -17,7 +18,10 @@ from services.memory import (
get_last_message_preview,
fork_session,
)
from models.schemas import ForkSessionRequest
from models.schemas import ForkSessionRequest, RebindPersonaRequest
from services.chat_prompt import get_system_prompt
from services.memory import rebind_session_persona
from services.personas import get_persona
router = APIRouter(prefix="/sessions", tags=["sessions"])
@@ -46,9 +50,42 @@ async def get_session_route(session_id: str):
return s
@router.post("/{session_id}/rebind-persona")
async def rebind_persona(session_id: str, body: RebindPersonaRequest):
    # Binds an existing session to another persona.
    # Unknown session -> 404, unknown persona -> 400. With clear_history the
    # system prompt is built from an empty history and the persona's non-blank
    # first_mes is passed along; otherwise the stored history is used and no
    # greeting is passed. A ValueError from rebind_session_persona -> 404.
    if not await get_session(session_id):
        raise HTTPException(status_code=404, detail="Сессия не найдена")

    target_id = body.persona_id
    persona = await get_persona(target_id)
    if not persona:
        raise HTTPException(status_code=400, detail="Персонаж не найден")

    wipe = body.clear_history
    if wipe:
        hist = []
    else:
        hist = await get_history(session_id)
    static = await get_system_prompt(target_id, hist, "")

    greeting = None
    if wipe:
        greeting = (persona.get("first_mes") or "").strip() or None

    try:
        await rebind_session_persona(
            session_id,
            target_id,
            clear_history=wipe,
            static_prompt=static,
            first_mes=greeting,
        )
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc))

    return {
        "persona_id": target_id,
        "persona_name": persona.get("name", target_id),
        "system_prompt_preview": static[:500],
        "clear_history": wipe,
    }
@router.patch("/{session_id}")
async def patch_session(session_id: str, data: dict):
await get_or_create_session(session_id, data.get("persona_id", "default"))
create_pid = data.get("persona_id") if "persona_id" in data else None
await get_or_create_session(session_id, create_pid)
if "title" in data:
await update_session_title(session_id, data["title"])
if "persona_id" in data: