174 lines
5.4 KiB
Python
174 lines
5.4 KiB
Python
import json
|
|
import logging
|
|
|
|
from services.memory import (
|
|
get_history,
|
|
get_session,
|
|
get_last_assistant_message_id,
|
|
update_session_plot_arc,
|
|
update_message_choices,
|
|
seed_quests_from_arc,
|
|
get_quests,
|
|
)
|
|
from services.rpg_state import apply_narrator_post
|
|
from services.personas import get_persona
|
|
from services.rpg_facts import facts_to_prompt
|
|
from services.rpg_plot import generate_plot_arc, choices_from_narrator
|
|
from services.rpg_context import format_narrator_context
|
|
from services.rpg_narrator import narrator_post
|
|
from services.sd_prompt import generate_sd_prompt
|
|
from services.sd_images import run_sd_for_message
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Feature toggles applied to a session when it stores no override for a key.
DEFAULT_RPG_SETTINGS = {
    "dice": True,
    "narrator": True,
    "quests": True,
    "affinity": True,
    "choices": True,
    "stats": False,
}


def get_rpg_settings(session: dict) -> dict:
    """Return the effective RPG settings for *session*.

    The session's ``rpg_settings_json`` value (a JSON object encoded as a
    string) is layered on top of ``DEFAULT_RPG_SETTINGS``. A missing, empty,
    malformed or non-object value yields the defaults unchanged.

    Args:
        session: Session mapping; only the ``rpg_settings_json`` key is read.

    Returns:
        A freshly built dict on every call, so callers may mutate the result
        without altering ``DEFAULT_RPG_SETTINGS``.
    """
    # Always start from a copy: handing out the module-level dict itself would
    # let a caller's mutation leak into every later call.
    settings = dict(DEFAULT_RPG_SETTINGS)
    try:
        overrides = json.loads(session.get("rpg_settings_json") or "{}")
    except (AttributeError, TypeError, ValueError):
        # AttributeError: session is None or not mapping-like.
        # TypeError: the stored value is not str/bytes.
        # ValueError: malformed JSON (json.JSONDecodeError subclasses it).
        return settings
    # Valid JSON that is not an object (list, number, string) carries no
    # key/value overrides, so it is ignored.
    if isinstance(overrides, dict):
        settings.update(overrides)
    return settings
|
|
|
|
|
|
async def resolve_greeting(session_id: str, persona: dict) -> str:
    """Pick the greeting text for a session.

    Scans the session history from newest to oldest and returns the stripped
    content of the first assistant message that is not blank. When no such
    message exists, the persona's ``first_mes`` (stripped, possibly empty) is
    returned instead.
    """
    messages = await get_history(session_id)

    for entry in reversed(messages):
        if entry.get("role") != "assistant":
            continue
        text = (entry.get("content") or "").strip()
        if text:
            return text

    fallback = persona.get("first_mes") or ""
    return fallback.strip()
|
|
|
|
|
|
async def ensure_plot_arc_and_quests(
    session_id: str,
    persona: dict,
    greeting: str,
    genre: str,
    *,
    seed_quests: bool = True,
) -> dict:
    """Return the session's plot arc, generating and storing one if absent.

    If the session's ``plot_arc_json`` string decodes to a non-empty value,
    that value is returned and nothing else happens. Otherwise a new arc is
    requested from ``generate_plot_arc`` using the persona's name, description
    and scenario, the greeting, the session's facts and the genre. A non-empty
    result is saved via ``update_session_plot_arc`` and, when *seed_quests* is
    true, passed to ``seed_quests_from_arc``.

    Returns:
        The existing or newly generated arc, or ``{}`` when generation
        produced nothing.
    """
    current = await get_session(session_id) or {}

    # Only a string payload is decoded; any other stored type counts as "no arc".
    stored = current.get("plot_arc_json") or "{}"
    existing = {}
    if isinstance(stored, str):
        try:
            existing = json.loads(stored)
        except Exception:
            existing = {}
    if existing:
        return existing

    facts_block = facts_to_prompt(current.get("facts_json", "[]"))
    char_name = persona.get("name", "Character")
    char_description = persona.get("description", "")
    char_scenario = persona.get("scenario", "")
    generated = await generate_plot_arc(
        char_name,
        char_description,
        char_scenario,
        greeting,
        facts_block=facts_block,
        genre=genre,
    )
    if not generated:
        return {}

    serialized = json.dumps(generated, ensure_ascii=False)
    await update_session_plot_arc(session_id, serialized)
    if seed_quests:
        await seed_quests_from_arc(session_id, generated)
    return generated
|
|
|
|
|
|
async def process_opening(session_id: str, persona_id: str, *, rpg: bool) -> dict:
    """Post-process a session's opening assistant message.

    The most recent assistant message in the history is treated as the
    opening text. When *rpg* is true, a plot arc is ensured (seeding quests if
    the ``quests`` setting is on), the narrator is run with
    ``is_opening=True``, choices are extracted if the ``choices`` setting is
    on, and the narrator result is applied to the session. In every mode an
    SD prompt bundle is generated and, when one is produced, run for the last
    assistant message. Any extracted choices are stored on that message.

    Args:
        session_id: Identifier of the session to process.
        persona_id: Identifier of the persona; also used as the narrator's
            character name when the persona has no ``name``.
        rpg: Whether to run the plot-arc / narrator steps.

    Returns:
        A dict with keys ``plot_arc``, ``quests``, ``outfit_json``,
        ``status_quo``, ``choices``, ``image_prompt``, ``image_prompt_alt``,
        ``image_path``, ``image_path_alt``, ``image_error``,
        ``image_error_alt`` and ``affinity``.

    Raises:
        ValueError: If the session does not exist, the history has no
            assistant message, or that message's content is empty.
    """
    session = await get_session(session_id)
    if not session:
        raise ValueError("Session not found")

    history = await get_history(session_id)
    assistant_msgs = [m for m in history if m.get("role") == "assistant"]
    if not assistant_msgs:
        raise ValueError("No assistant message (first_mes) found")

    # dict.get's default only covers a *missing* key; a stored None must also
    # fall back to "" so it reaches the ValueError below rather than failing
    # on None.strip().
    first_mes_text = (assistant_msgs[-1].get("content") or "").strip()
    if not first_mes_text:
        raise ValueError("Empty first_mes")

    msg_id = await get_last_assistant_message_id(session_id)
    persona = await get_persona(persona_id) or {}
    rpg_settings = get_rpg_settings(session)

    arc: dict = {}
    choices: list = []
    status_quo = session.get("status_quo") or ""
    outfit_json = session.get("outfit_json") or "[]"

    if rpg:
        genre = session.get("genre") or "adventure"
        arc = await ensure_plot_arc_and_quests(
            session_id,
            persona,
            first_mes_text,
            genre,
            seed_quests=rpg_settings.get("quests", True),
        )

        # Re-read the session after the arc step; keep the previous snapshot
        # if the lookup comes back empty.
        session = await get_session(session_id) or session
        ctx_txt = f"assistant: {first_mes_text}"
        arc_json = json.dumps(arc, ensure_ascii=False) if arc else ""
        facts_block = facts_to_prompt(session.get("facts_json", "[]"))

        quests_pre = await get_quests(session_id)
        narr_ctx = format_narrator_context(arc, quests_pre, session.get("status_quo") or "")
        post = await narrator_post(
            persona.get("name", persona_id),
            ctx_txt,
            arc_json,
            facts_block,
            is_opening=True,
            extra_context=narr_ctx,
        )

        if rpg_settings.get("choices", True):
            choices = choices_from_narrator(post.get("choices") or [])

        await apply_narrator_post(session_id, post, rpg_settings, session)
        # Re-read again after applying the narrator result; earlier values
        # are kept when the refreshed session has none.
        session = await get_session(session_id) or session
        status_quo = session.get("status_quo") or status_quo
        outfit_json = session.get("outfit_json") or outfit_json

    quests = await get_quests(session_id)
    messages = await get_history(session_id)
    bundle = await generate_sd_prompt(
        messages,
        persona_id,
        outfit_json=outfit_json,
        # `session` is always truthy here (checked on entry, and every
        # refresh falls back to the prior value), so no None guard is needed.
        scene_json=session.get("scene_json", "{}"),
    )
    sd_out = await run_sd_for_message(bundle, msg_id) if bundle else {}

    updated = await get_session(session_id)
    affinity = updated.get("affinity", 0) if updated else 0

    if msg_id and choices:
        await update_message_choices(
            msg_id, json.dumps(choices, ensure_ascii=False)
        )

    return {
        "plot_arc": arc,
        "quests": quests,
        "outfit_json": outfit_json,
        "status_quo": status_quo,
        "choices": choices,
        "image_prompt": sd_out.get("image_prompt"),
        "image_prompt_alt": sd_out.get("image_prompt_alt"),
        "image_path": sd_out.get("image_path"),
        "image_path_alt": sd_out.get("image_path_alt"),
        "image_error": sd_out.get("image_error"),
        "image_error_alt": sd_out.get("image_error_alt"),
        "affinity": affinity,
    }
|