Fixed SD RPG
This commit is contained in:
@@ -0,0 +1,173 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from services.memory import (
|
||||
get_history,
|
||||
get_session,
|
||||
get_last_assistant_message_id,
|
||||
update_session_plot_arc,
|
||||
update_message_choices,
|
||||
seed_quests_from_arc,
|
||||
get_quests,
|
||||
)
|
||||
from services.rpg_state import apply_narrator_post
|
||||
from services.personas import get_persona
|
||||
from services.rpg_facts import facts_to_prompt
|
||||
from services.rpg_plot import generate_plot_arc, choices_from_narrator
|
||||
from services.rpg_context import format_narrator_context
|
||||
from services.rpg_narrator import narrator_post
|
||||
from services.sd_prompt import generate_sd_prompt
|
||||
from services.sd_images import run_sd_for_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RPG_SETTINGS = {
|
||||
"dice": True,
|
||||
"narrator": True,
|
||||
"quests": True,
|
||||
"affinity": True,
|
||||
"choices": True,
|
||||
"stats": False,
|
||||
}
|
||||
|
||||
|
||||
def get_rpg_settings(session: dict) -> dict:
|
||||
try:
|
||||
return {**DEFAULT_RPG_SETTINGS, **json.loads(session.get("rpg_settings_json") or "{}")}
|
||||
except Exception:
|
||||
return DEFAULT_RPG_SETTINGS
|
||||
|
||||
|
||||
async def resolve_greeting(session_id: str, persona: dict) -> str:
|
||||
history = await get_history(session_id)
|
||||
for m in reversed(history):
|
||||
if m.get("role") == "assistant" and (m.get("content") or "").strip():
|
||||
return m["content"].strip()
|
||||
return (persona.get("first_mes") or "").strip()
|
||||
|
||||
|
||||
async def ensure_plot_arc_and_quests(
|
||||
session_id: str,
|
||||
persona: dict,
|
||||
greeting: str,
|
||||
genre: str,
|
||||
*,
|
||||
seed_quests: bool = True,
|
||||
) -> dict:
|
||||
session = await get_session(session_id) or {}
|
||||
arc_json = session.get("plot_arc_json") or "{}"
|
||||
try:
|
||||
arc = json.loads(arc_json) if isinstance(arc_json, str) else {}
|
||||
except Exception:
|
||||
arc = {}
|
||||
|
||||
if arc:
|
||||
return arc
|
||||
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
arc = await generate_plot_arc(
|
||||
persona.get("name", "Character"),
|
||||
persona.get("description", ""),
|
||||
persona.get("scenario", ""),
|
||||
greeting,
|
||||
facts_block=facts_block,
|
||||
genre=genre,
|
||||
)
|
||||
if not arc:
|
||||
return {}
|
||||
|
||||
await update_session_plot_arc(session_id, json.dumps(arc, ensure_ascii=False))
|
||||
if seed_quests:
|
||||
await seed_quests_from_arc(session_id, arc)
|
||||
return arc
|
||||
|
||||
|
||||
async def process_opening(session_id: str, persona_id: str, *, rpg: bool) -> dict:
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
raise ValueError("Session not found")
|
||||
|
||||
history = await get_history(session_id)
|
||||
assistant_msgs = [m for m in history if m.get("role") == "assistant"]
|
||||
if not assistant_msgs:
|
||||
raise ValueError("No assistant message (first_mes) found")
|
||||
|
||||
first_mes_text = assistant_msgs[-1].get("content", "").strip()
|
||||
if not first_mes_text:
|
||||
raise ValueError("Empty first_mes")
|
||||
|
||||
msg_id = await get_last_assistant_message_id(session_id)
|
||||
persona = await get_persona(persona_id) or {}
|
||||
rpg_settings = get_rpg_settings(session)
|
||||
|
||||
arc: dict = {}
|
||||
choices: list = []
|
||||
status_quo = session.get("status_quo") or ""
|
||||
outfit_json = session.get("outfit_json") or "[]"
|
||||
|
||||
if rpg:
|
||||
genre = session.get("genre") or "adventure"
|
||||
arc = await ensure_plot_arc_and_quests(
|
||||
session_id,
|
||||
persona,
|
||||
first_mes_text,
|
||||
genre,
|
||||
seed_quests=rpg_settings.get("quests", True),
|
||||
)
|
||||
|
||||
session = await get_session(session_id) or session
|
||||
ctx_txt = f"assistant: {first_mes_text}"
|
||||
arc_json = json.dumps(arc, ensure_ascii=False) if arc else ""
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
|
||||
quests_pre = await get_quests(session_id)
|
||||
narr_ctx = format_narrator_context(arc, quests_pre, session.get("status_quo") or "")
|
||||
post = await narrator_post(
|
||||
persona.get("name", persona_id),
|
||||
ctx_txt,
|
||||
arc_json,
|
||||
facts_block,
|
||||
is_opening=True,
|
||||
extra_context=narr_ctx,
|
||||
)
|
||||
|
||||
if rpg_settings.get("choices", True):
|
||||
choices = choices_from_narrator(post.get("choices") or [])
|
||||
|
||||
await apply_narrator_post(session_id, post, rpg_settings, session)
|
||||
session = await get_session(session_id) or session
|
||||
status_quo = session.get("status_quo") or status_quo
|
||||
outfit_json = session.get("outfit_json") or outfit_json
|
||||
|
||||
quests = await get_quests(session_id)
|
||||
messages = await get_history(session_id)
|
||||
bundle = await generate_sd_prompt(
|
||||
messages,
|
||||
persona_id,
|
||||
outfit_json=outfit_json,
|
||||
scene_json=session.get("scene_json", "{}") if session else "{}",
|
||||
)
|
||||
sd_out = await run_sd_for_message(bundle, msg_id) if bundle else {}
|
||||
|
||||
updated = await get_session(session_id)
|
||||
affinity = updated.get("affinity", 0) if updated else 0
|
||||
|
||||
if msg_id and choices:
|
||||
await update_message_choices(
|
||||
msg_id, json.dumps(choices, ensure_ascii=False)
|
||||
)
|
||||
|
||||
return {
|
||||
"plot_arc": arc,
|
||||
"quests": quests,
|
||||
"outfit_json": outfit_json,
|
||||
"status_quo": status_quo,
|
||||
"choices": choices,
|
||||
"image_prompt": sd_out.get("image_prompt"),
|
||||
"image_prompt_alt": sd_out.get("image_prompt_alt"),
|
||||
"image_path": sd_out.get("image_path"),
|
||||
"image_path_alt": sd_out.get("image_path_alt"),
|
||||
"image_error": sd_out.get("image_error"),
|
||||
"image_error_alt": sd_out.get("image_error_alt"),
|
||||
"affinity": affinity,
|
||||
}
|
||||
Reference in New Issue
Block a user