179 lines
5.6 KiB
Python
179 lines
5.6 KiB
Python
import json
|
|
import logging
|
|
|
|
from services.memory import (
|
|
get_history,
|
|
get_session,
|
|
get_last_assistant_message_id,
|
|
update_session_plot_arc,
|
|
update_session_status_quo,
|
|
update_session_affinity,
|
|
update_session_outfit,
|
|
upsert_quest,
|
|
get_quests,
|
|
)
|
|
from services.personas import get_persona
|
|
from services.rpg_facts import facts_to_prompt
|
|
from services.rpg_plot import generate_plot_arc
|
|
from services.rpg_narrator import narrator_post
|
|
from services.sd_prompt import generate_sd_prompt
|
|
from services.sd_images import run_sd_for_message
|
|
|
|
logger = logging.getLogger(__name__)


# Feature toggles used when a session stores no (or unusable) RPG settings.
# get_rpg_settings() always returns a fresh dict, so this module-level object
# is never handed to callers who might mutate it.
DEFAULT_RPG_SETTINGS = {
    "dice": True,
    "narrator": True,
    "quests": True,
    "affinity": True,
    "choices": True,
}


def get_rpg_settings(session: dict) -> dict:
    """Return the effective RPG feature toggles for *session*.

    The session's ``rpg_settings_json`` value (a JSON object string, or an
    already-decoded dict) is layered over ``DEFAULT_RPG_SETTINGS``; any key it
    does not set falls back to the default.

    Args:
        session: Session row as a dict. ``None`` or a missing/empty
            ``rpg_settings_json`` yields the defaults.

    Returns:
        A new dict on every call, so mutating the result never alters
        ``DEFAULT_RPG_SETTINGS``. Unparseable or non-object JSON is logged
        and ignored (the defaults are returned).
    """
    raw = (session or {}).get("rpg_settings_json") or "{}"
    if isinstance(raw, dict):
        # Some storage layers may hand back already-decoded JSON.
        overrides = raw
    else:
        try:
            overrides = json.loads(raw)
        except (TypeError, ValueError):
            logger.warning("Ignoring unparseable rpg_settings_json: %r", raw)
            overrides = {}
    if not isinstance(overrides, dict):
        # e.g. "[1, 2]" or "null": valid JSON, but not a settings object.
        logger.warning("Ignoring non-object rpg_settings_json: %r", raw)
        overrides = {}
    return {**DEFAULT_RPG_SETTINGS, **overrides}
|
|
|
|
|
|
async def resolve_greeting(session_id: str, persona: dict) -> str:
    """Return the greeting text to use for a session.

    Scans the session history from newest to oldest and returns the first
    assistant message whose content is non-blank, stripped of surrounding
    whitespace. If there is none, falls back to the persona's ``first_mes``
    (stripped), or an empty string when that is missing too.
    """
    messages = await get_history(session_id)
    newest_assistant_text = next(
        (
            msg["content"].strip()
            for msg in reversed(messages)
            if msg.get("role") == "assistant" and (msg.get("content") or "").strip()
        ),
        None,
    )
    if newest_assistant_text is not None:
        return newest_assistant_text
    return (persona.get("first_mes") or "").strip()
|
|
|
|
|
|
async def ensure_plot_arc_and_quests(
    session_id: str,
    persona: dict,
    greeting: str,
    genre: str,
    *,
    seed_quests: bool = True,
) -> dict:
    """Return the session's plot arc, generating and persisting one if absent.

    If the session already stores a non-empty plot arc (``plot_arc_json`` as a
    JSON object string or an already-decoded dict), it is returned unchanged
    and nothing is written. Otherwise a new arc is generated from the persona,
    the greeting, the session facts and *genre*. The new arc is saved with
    ``update_session_plot_arc``. When *seed_quests* is true, each beat's title
    (or its injection text if untitled) is upserted as a quest, truncated to
    120 characters.

    Args:
        session_id: Session to read and update.
        persona: Persona dict; ``name``, ``description`` and ``scenario`` are
            forwarded to ``generate_plot_arc``.
        greeting: Opening assistant message used to seed the arc.
        genre: Genre hint forwarded to ``generate_plot_arc``.
        seed_quests: Whether to create quests from the generated beats.

    Returns:
        The existing or newly generated arc dict, or ``{}`` when generation
        produced nothing usable.
    """
    session = await get_session(session_id) or {}

    existing = session.get("plot_arc_json") or "{}"
    if isinstance(existing, str):
        try:
            existing = json.loads(existing)
        except ValueError:
            # Corrupt stored arc: treat it as missing and regenerate below.
            logger.warning("Discarding unparseable plot_arc_json for session %s", session_id)
            existing = {}
    # Only a non-empty JSON object counts as a usable arc.
    if isinstance(existing, dict) and existing:
        return existing

    facts_block = facts_to_prompt(session.get("facts_json") or "[]")
    arc = await generate_plot_arc(
        persona.get("name", "Character"),
        persona.get("description", ""),
        persona.get("scenario", ""),
        greeting,
        facts_block=facts_block,
        genre=genre,
    )
    if not arc or not isinstance(arc, dict):
        return {}

    await update_session_plot_arc(session_id, json.dumps(arc, ensure_ascii=False))

    if seed_quests:
        # `or []` / `or ""` guard against explicit nulls in generated JSON.
        for beat in arc.get("beats") or []:
            if not isinstance(beat, dict):
                continue
            raw_title = beat.get("title") or beat.get("injection") or ""
            title = raw_title.strip() if isinstance(raw_title, str) else ""
            if title:
                await upsert_quest(session_id, title[:120])
    return arc
|
|
|
|
|
|
def _coerce_affinity_delta(value) -> int:
    """Best-effort conversion of a narrator ``affinity_delta`` to an int.

    The narrator payload is not validated before it reaches this module, so
    anything that is not int-convertible (e.g. ``"2.5"``, ``"a lot"``) is
    treated as 0 rather than aborting the whole opening. ``None``/empty → 0.
    """
    try:
        return int(value or 0)
    except (TypeError, ValueError):
        return 0


async def process_opening(session_id: str, persona_id: str, *, rpg: bool) -> dict:
    """Post-process a session's opening (first_mes) message.

    Uses the most recent assistant message in the history as the opening.

    When *rpg* is true, it also:
    - ensures a plot arc exists, optionally seeding quests;
    - calls ``narrator_post(..., is_opening=True)``;
    - applies the narrator's status-quo, outfit, and (per the session's RPG
      toggles) choices, affinity and quest updates.

    In both modes it generates an SD image prompt from the full history and
    renders it for the last assistant message when a prompt bundle is produced.

    Args:
        session_id: Session being opened.
        persona_id: Persona to load; also the fallback narrator display name.
        rpg: Whether to run the RPG pipeline (arc, narrator, quests, ...).

    Returns:
        Dict with ``plot_arc``, ``quests``, ``outfit_json``, ``status_quo``,
        ``choices``, the ``image_*`` fields from the SD run (``None`` when no
        image was generated) and the session's current ``affinity``.

    Raises:
        ValueError: If the session does not exist, has no assistant message,
            or the latest assistant message is blank.
    """
    session = await get_session(session_id)
    if not session:
        raise ValueError("Session not found")

    history = await get_history(session_id) or []
    assistant_msgs = [m for m in history if m.get("role") == "assistant"]
    if not assistant_msgs:
        raise ValueError("No assistant message (first_mes) found")

    # `or ""` also covers a stored NULL content, which .strip() would crash on.
    first_mes_text = (assistant_msgs[-1].get("content") or "").strip()
    if not first_mes_text:
        raise ValueError("Empty first_mes")

    msg_id = await get_last_assistant_message_id(session_id)
    persona = await get_persona(persona_id) or {}
    rpg_settings = get_rpg_settings(session)

    # Defaults returned as-is for non-RPG openings.
    arc: dict = {}
    choices: list = []
    status_quo = session.get("status_quo") or ""
    outfit_json = session.get("outfit_json") or "[]"

    if rpg:
        genre = session.get("genre") or "adventure"
        arc = await ensure_plot_arc_and_quests(
            session_id,
            persona,
            first_mes_text,
            genre,
            seed_quests=rpg_settings.get("quests", True),
        )

        # Re-read the session after arc generation, keeping the old row if the
        # re-read returns nothing.
        session = await get_session(session_id) or session
        ctx_txt = f"assistant: {first_mes_text}"
        arc_json = json.dumps(arc, ensure_ascii=False) if arc else ""
        facts_block = facts_to_prompt(session.get("facts_json") or "[]")

        # NOTE(review): only the "choices", "affinity" and "quests" toggles are
        # consulted in this function; rpg_settings["narrator"] does not gate this
        # call. Confirm whether disabling the narrator should skip it.
        post = await narrator_post(
            persona.get("name", persona_id),
            ctx_txt,
            arc_json,
            facts_block,
            is_opening=True,
        )
        if not isinstance(post, dict):
            logger.warning("narrator_post returned non-dict for session %s; ignoring", session_id)
            post = {}

        sq_raw = post.get("status_quo_update")
        sq = sq_raw.strip() if isinstance(sq_raw, str) else ""
        if sq:
            await update_session_status_quo(session_id, sq)
            status_quo = sq

        if rpg_settings.get("choices", True):
            raw_choices = post.get("choices") or []
            choices = raw_choices if isinstance(raw_choices, list) else []

        if rpg_settings.get("affinity", True):
            delta = _coerce_affinity_delta(post.get("affinity_delta"))
            if delta:
                await update_session_affinity(session_id, delta)

        outfit_update = post.get("outfit_update")
        if isinstance(outfit_update, list) and outfit_update:
            outfit_json = json.dumps(outfit_update, ensure_ascii=False)
            await update_session_outfit(session_id, outfit_json)

        if rpg_settings.get("quests", True):
            for qu in post.get("quest_updates") or []:
                if not isinstance(qu, dict):
                    continue
                raw_title = qu.get("title")
                title = raw_title.strip() if isinstance(raw_title, str) else ""
                if title:
                    # `or "active"` also covers an explicit null status.
                    await upsert_quest(session_id, title[:120], qu.get("status") or "active")

    quests = await get_quests(session_id)
    messages = await get_history(session_id)
    bundle = await generate_sd_prompt(messages, persona_id, outfit_json=outfit_json)
    # Guard against a None result so the .get() calls below stay safe.
    sd_out = (await run_sd_for_message(bundle, msg_id) if bundle else None) or {}

    updated = await get_session(session_id)
    affinity = updated.get("affinity", 0) if updated else 0

    return {
        "plot_arc": arc,
        "quests": quests,
        "outfit_json": outfit_json,
        "status_quo": status_quo,
        "choices": choices,
        "image_prompt": sd_out.get("image_prompt"),
        "image_prompt_alt": sd_out.get("image_prompt_alt"),
        "image_path": sd_out.get("image_path"),
        "image_path_alt": sd_out.get("image_path_alt"),
        "image_error": sd_out.get("image_error"),
        "image_error_alt": sd_out.get("image_error_alt"),
        "affinity": affinity,
    }
|