Fixed SD RPG
This commit is contained in:
+405
-5
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from services.llm import send_message_with_model, send_message
|
||||
from services.llm import LLMError, send_message_with_model, send_message
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,7 +63,19 @@ async def generate_plot_arc(persona_name: str, persona_desc: str, persona_scenar
|
||||
{"role": "system", "content": ARC_SYSTEM},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
raw = await (send_message_with_model(messages, PLOT_MODEL) if PLOT_MODEL else send_message(messages))
|
||||
try:
|
||||
raw = await (
|
||||
send_message_with_model(messages, PLOT_MODEL)
|
||||
if PLOT_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("generate_plot_arc LLM failed (model=%s): %s", PLOT_MODEL or "SYSTEM", e)
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning("generate_plot_arc unexpected error: %s", e)
|
||||
return {}
|
||||
|
||||
cleaned = raw.strip()
|
||||
# common OpenRouter formatting: fenced json
|
||||
if cleaned.startswith("```"):
|
||||
@@ -79,17 +91,236 @@ async def generate_plot_arc(persona_name: str, persona_desc: str, persona_scenar
|
||||
return {}
|
||||
|
||||
|
||||
def should_advance_arc(user_text: str) -> str | None:
|
||||
BEAT_MATCH_SYSTEM = """You decide whether the player's latest message should fire ONE scripted plot beat.
|
||||
Return ONLY valid JSON (no markdown):
|
||||
{"fire_beat_id": "id from list or null", "confidence": "high|low"}
|
||||
|
||||
Rules:
|
||||
- Fire only if the message clearly matches that beat's narrative intent RIGHT NOW.
|
||||
- event_driven:rest — stopping to rest, sleep, camp, sauna break, recuperate (not mere sitting still in scene).
|
||||
- event_driven:travel — leaving, driving, journey, going to a new place, hitting the road.
|
||||
- event_driven:help_request — explicit plea for help/rescue/assistance.
|
||||
- event_driven:after_fail / after_success — follow-up to a recent failure/success beat.
|
||||
- Casual talk, flirting, exploring the current place without leaving does NOT fire travel.
|
||||
- If nothing fits well, return null.
|
||||
- Pick at most ONE beat; prefer high confidence only."""
|
||||
|
||||
|
||||
def dice_outcome_to_beat_trigger(outcome: str | None) -> str | None:
|
||||
"""Map d20 outcome to event_driven beat trigger (after_fail / after_success)."""
|
||||
o = (outcome or "").strip().lower()
|
||||
if o in ("failure", "critical failure"):
|
||||
return "event_driven:after_fail"
|
||||
if o in ("success", "critical success"):
|
||||
return "event_driven:after_success"
|
||||
return None
|
||||
|
||||
|
||||
def should_advance_arc_keywords(user_text: str) -> str | None:
|
||||
"""Legacy keyword fallback when LLM match is unavailable."""
|
||||
t = (user_text or "").lower()
|
||||
if any(x in t for x in ["отдыха", "ночлег", "спим", "сон", "разбить лагерь", "лагерь", "отдохн"]):
|
||||
if any(x in t for x in ["отдыха", "ночлег", "спим", "сон", "разбить лагерь", "лагерь", "отдохн", "привала", "заправк", "саун"]):
|
||||
return "event_driven:rest"
|
||||
if any(x in t for x in ["идем дальше", "пойдем дальше", "в путь", "продолжаем путь", "уходим", "возвращаемся", "переходим"]):
|
||||
if any(
|
||||
x in t
|
||||
for x in [
|
||||
"идем дальше", "пойдем дальше", "пойдём дальше", "едем дальше", "едем",
|
||||
"поехали", "выезжаем", "выезжаю", "в путь", "продолжаем путь",
|
||||
"уходим", "возвращаемся", "переходим", "за рул", "машин", "автомоб",
|
||||
"дорог", "трас", "шосс", "приех", "прибыва", "стади", "на стадион",
|
||||
"отправляемся", "выдвигаемся", "в дорогу",
|
||||
]
|
||||
):
|
||||
return "event_driven:travel"
|
||||
if any(x in t for x in ["помоги", "помочь", "нужна помощь", "спасите", "help"]):
|
||||
return "event_driven:help_request"
|
||||
return None
|
||||
|
||||
|
||||
def _parse_llm_json(raw: str) -> dict | list | None:
|
||||
cleaned = (raw or "").strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
if cleaned.endswith("```"):
|
||||
cleaned = cleaned.rsplit("```", 1)[0]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
return json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
async def classify_plot_beat(
|
||||
user_text: str,
|
||||
beats: list[dict],
|
||||
recent_context: str = "",
|
||||
last_dice_outcome: str | None = None,
|
||||
) -> str | None:
|
||||
"""LLM: return beat id to fire, or None."""
|
||||
pending = [b for b in beats if isinstance(b, dict) and b.get("id")]
|
||||
if not pending or not (user_text or "").strip():
|
||||
return None
|
||||
|
||||
lines = []
|
||||
for b in pending[:8]:
|
||||
lines.append(
|
||||
json.dumps(
|
||||
{
|
||||
"id": b.get("id"),
|
||||
"title": b.get("title", ""),
|
||||
"trigger": b.get("trigger", ""),
|
||||
"injection": (b.get("injection") or "")[:200],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
user = (
|
||||
f"Player message:\n{user_text.strip()}\n\n"
|
||||
f"Pending beats:\n" + "\n".join(lines)
|
||||
)
|
||||
if recent_context.strip():
|
||||
user += f"\n\nRecent chat:\n{recent_context.strip()[-2500:]}\n"
|
||||
if last_dice_outcome:
|
||||
user += f"\nLast dice outcome this turn: {last_dice_outcome}\n"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": BEAT_MATCH_SYSTEM},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
try:
|
||||
raw = await (
|
||||
send_message_with_model(messages, PLOT_MODEL)
|
||||
if PLOT_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("classify_plot_beat LLM failed: %s", e)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("classify_plot_beat unexpected: %s", e)
|
||||
return None
|
||||
|
||||
data = _parse_llm_json(raw)
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
bid = data.get("fire_beat_id")
|
||||
if bid in (None, "", "null", "none"):
|
||||
return None
|
||||
bid = str(bid).strip()
|
||||
if data.get("confidence") == "low":
|
||||
return None
|
||||
valid_ids = {str(b.get("id")) for b in pending}
|
||||
if bid in valid_ids:
|
||||
logger.info("classify_plot_beat: fired %s", bid)
|
||||
return bid
|
||||
return None
|
||||
|
||||
|
||||
def pop_beat_by_id(arc: dict, beat_id: str) -> tuple[dict, list[dict]]:
|
||||
beats = arc.get("beats") or []
|
||||
matched, remaining = [], []
|
||||
for b in beats:
|
||||
if isinstance(b, dict) and str(b.get("id")) == str(beat_id) and not matched:
|
||||
matched.append(b)
|
||||
else:
|
||||
remaining.append(b)
|
||||
arc["beats"] = remaining
|
||||
return arc, matched
|
||||
|
||||
|
||||
def beat_title(beat: dict) -> str:
|
||||
return ((beat.get("title") or beat.get("injection") or "")[:120]).strip()
|
||||
|
||||
|
||||
def count_active_quests(quests: list | None) -> int:
|
||||
return sum(1 for q in (quests or []) if q.get("status") == "active")
|
||||
|
||||
|
||||
def prune_beats_for_done_quests(arc: dict, quests: list | None) -> tuple[dict, list[dict]]:
|
||||
"""Drop beats whose title already matches a done/failed quest (manual quest close desync)."""
|
||||
done_titles = {
|
||||
(q.get("title") or "").strip().lower()
|
||||
for q in (quests or [])
|
||||
if q.get("status") in ("done", "failed")
|
||||
}
|
||||
if not done_titles:
|
||||
return arc, []
|
||||
removed, kept = [], []
|
||||
for b in arc.get("beats") or []:
|
||||
if isinstance(b, dict) and beat_title(b).lower() in done_titles:
|
||||
removed.append(b)
|
||||
else:
|
||||
kept.append(b)
|
||||
arc["beats"] = kept
|
||||
return arc, removed
|
||||
|
||||
|
||||
def pop_next_beats(arc: dict, max_beats: int = 1) -> tuple[dict, list[dict]]:
|
||||
beats = arc.get("beats") or []
|
||||
if not isinstance(beats, list) or not beats:
|
||||
return arc, []
|
||||
n = min(max_beats, len(beats))
|
||||
matched = [b for b in beats[:n] if isinstance(b, dict)]
|
||||
arc["beats"] = beats[n:]
|
||||
return arc, matched
|
||||
|
||||
|
||||
async def process_arc_beats(
|
||||
arc: dict,
|
||||
quests: list | None,
|
||||
user_text: str,
|
||||
*,
|
||||
recent_context: str = "",
|
||||
last_dice_outcome: str | None = None,
|
||||
allow_stuck_recovery: bool = True,
|
||||
) -> tuple[dict, list[dict], list[dict], str]:
|
||||
"""
|
||||
Prune completed beats, then fire by dice outcome, LLM match, keywords, or stuck recovery.
|
||||
Returns (arc, fired_beats, pruned_beats, mode).
|
||||
mode: '' | 'after_dice' | 'llm' | 'trigger' | 'stuck_recovery' | 'pruned'
|
||||
"""
|
||||
if not arc:
|
||||
return arc, [], [], ""
|
||||
|
||||
arc, pruned = prune_beats_for_done_quests(arc, quests)
|
||||
beats_pending = arc.get("beats") or []
|
||||
|
||||
dice_trig = dice_outcome_to_beat_trigger(last_dice_outcome)
|
||||
if dice_trig and beats_pending:
|
||||
arc, fired = pop_matching_beats(arc, dice_trig, max_beats=1)
|
||||
if fired:
|
||||
logger.info(
|
||||
"process_arc_beats: after_dice %s -> %s",
|
||||
last_dice_outcome,
|
||||
fired[0].get("id"),
|
||||
)
|
||||
return arc, fired, pruned, "after_dice"
|
||||
|
||||
if beats_pending:
|
||||
beat_id = await classify_plot_beat(
|
||||
user_text, beats_pending, recent_context, last_dice_outcome
|
||||
)
|
||||
if beat_id:
|
||||
arc, fired = pop_beat_by_id(arc, beat_id)
|
||||
if fired:
|
||||
return arc, fired, pruned, "llm"
|
||||
|
||||
trig = should_advance_arc_keywords(user_text)
|
||||
if trig:
|
||||
arc, fired = pop_matching_beats(arc, trig, max_beats=1)
|
||||
if fired:
|
||||
return arc, fired, pruned, "trigger"
|
||||
|
||||
if allow_stuck_recovery and arc.get("beats") and count_active_quests(quests) == 0:
|
||||
arc, fired = pop_next_beats(arc, 1)
|
||||
if fired:
|
||||
return arc, fired, pruned, "stuck_recovery"
|
||||
|
||||
if pruned:
|
||||
return arc, [], pruned, "pruned"
|
||||
return arc, [], [], ""
|
||||
|
||||
|
||||
PHASE_ORDER = ["opening", "hook", "complication", "reveal", "climax", "aftermath"]
|
||||
|
||||
|
||||
@@ -108,6 +339,129 @@ def advance_phase(arc: dict) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
BEATS_APPEND_SYSTEM = """You are a narrative designer for an RPG chat.
|
||||
The plot arc has NO remaining scripted beats. Generate 2-3 NEW beats to continue play.
|
||||
Return ONLY valid JSON (no markdown):
|
||||
{
|
||||
"beats": [
|
||||
{"id":"b_new_1","title":"short quest title","trigger":"event_driven:rest|event_driven:travel|event_driven:help_request|event_driven:after_fail|event_driven:after_success",
|
||||
"injection":"1-3 sentences in-world",
|
||||
"choices":[{"id":"a","label":"..."},{"id":"b","label":"..."}]}
|
||||
],
|
||||
"next_beat_hint": "what to push next",
|
||||
"phase": "hook|complication|reveal|climax|aftermath"
|
||||
}
|
||||
Match the current scene and completed quests. Do not restart finished storylines."""
|
||||
|
||||
|
||||
async def replenish_arc_beats(
|
||||
arc: dict,
|
||||
persona_name: str,
|
||||
recent_context: str,
|
||||
quests: list,
|
||||
genre: str = "adventure",
|
||||
) -> dict:
|
||||
"""Append new beats when arc.beats is empty so plot/quest engine can continue."""
|
||||
if arc.get("beats"):
|
||||
return arc
|
||||
|
||||
quest_lines = "\n".join(
|
||||
f" [{q.get('status')}] {q.get('title')}" for q in (quests or [])
|
||||
) or " (none)"
|
||||
user = (
|
||||
f"Character: {persona_name}\n"
|
||||
f"Genre: {format_genres(genre)}\n"
|
||||
f"Current arc title: {arc.get('title', '')}\n"
|
||||
f"Phase: {arc.get('phase', 'aftermath')}\n"
|
||||
f"Boundaries: {json.dumps(arc.get('boundaries', []), ensure_ascii=False)}\n"
|
||||
f"Quests:\n{quest_lines}\n\n"
|
||||
f"Recent chat:\n{recent_context[-4000:]}\n"
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": BEATS_APPEND_SYSTEM},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
try:
|
||||
raw = await (
|
||||
send_message_with_model(messages, PLOT_MODEL)
|
||||
if PLOT_MODEL
|
||||
else send_message(messages)
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.warning("replenish_arc_beats failed: %s", e)
|
||||
return arc
|
||||
except Exception as e:
|
||||
logger.warning("replenish_arc_beats unexpected: %s", e)
|
||||
return arc
|
||||
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned
|
||||
if cleaned.endswith("```"):
|
||||
cleaned = cleaned.rsplit("```", 1)[0]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
except Exception:
|
||||
logger.warning("replenish_arc_beats JSON parse failed. Raw=%.400s", raw)
|
||||
return arc
|
||||
|
||||
new_beats = data.get("beats") if isinstance(data, dict) else []
|
||||
if isinstance(new_beats, list) and new_beats:
|
||||
arc["beats"] = new_beats
|
||||
logger.info("replenish_arc_beats: added %d beats", len(new_beats))
|
||||
if isinstance(data, dict) and data.get("next_beat_hint"):
|
||||
arc["next_beat_hint"] = data["next_beat_hint"]
|
||||
if isinstance(data, dict) and data.get("phase"):
|
||||
arc["phase"] = data["phase"]
|
||||
return arc
|
||||
|
||||
|
||||
async def reconcile_plot_arc(
|
||||
session_id: str,
|
||||
*,
|
||||
replenish_if_empty: bool = True,
|
||||
recent_context: str = "",
|
||||
persona_name: str = "Character",
|
||||
genre: str = "adventure",
|
||||
) -> tuple[dict, bool]:
|
||||
"""
|
||||
Prune beats that match done quests; replenish if empty. Persists arc when changed.
|
||||
Returns (arc, changed).
|
||||
"""
|
||||
from services.memory import get_session, get_quests, update_session_plot_arc, seed_quests_from_arc
|
||||
|
||||
session = await get_session(session_id)
|
||||
if not session or not session.get("rpg_enabled"):
|
||||
return {}, False
|
||||
try:
|
||||
arc = json.loads(session.get("plot_arc_json") or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
arc = {}
|
||||
if not isinstance(arc, dict):
|
||||
arc = {}
|
||||
|
||||
quests = await get_quests(session_id)
|
||||
arc, pruned = prune_beats_for_done_quests(arc, quests)
|
||||
changed = bool(pruned)
|
||||
|
||||
if replenish_if_empty and not arc.get("beats"):
|
||||
arc = await replenish_arc_beats(
|
||||
arc,
|
||||
persona_name,
|
||||
recent_context,
|
||||
quests,
|
||||
genre=session.get("genre") or genre,
|
||||
)
|
||||
if arc.get("beats"):
|
||||
changed = True
|
||||
await seed_quests_from_arc(session_id, arc)
|
||||
|
||||
if changed:
|
||||
await update_session_plot_arc(session_id, json.dumps(arc, ensure_ascii=False))
|
||||
return arc, changed
|
||||
|
||||
|
||||
def pop_matching_beats(arc: dict, trigger: str, max_beats: int = 1) -> tuple[dict, list[dict]]:
|
||||
beats = arc.get("beats", [])
|
||||
if not isinstance(beats, list):
|
||||
@@ -121,3 +475,49 @@ def pop_matching_beats(arc: dict, trigger: str, max_beats: int = 1) -> tuple[dic
|
||||
arc["beats"] = remaining
|
||||
return arc, matched
|
||||
|
||||
|
||||
def normalize_choice(
|
||||
raw: dict,
|
||||
*,
|
||||
source: str = "narrator",
|
||||
beat: dict | None = None,
|
||||
) -> dict | None:
|
||||
"""Normalize a choice dict for storage/UI. Adds source and optional beat metadata."""
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
label = (raw.get("label") or "").strip()
|
||||
if not label:
|
||||
return None
|
||||
cid = (raw.get("id") or label[:1].lower() or "a").strip()
|
||||
out = {"id": cid, "label": label, "source": source}
|
||||
if beat and source == "plot_beat":
|
||||
if beat.get("id"):
|
||||
out["beat_id"] = beat["id"]
|
||||
title = (beat.get("title") or "").strip()
|
||||
if title:
|
||||
out["beat_title"] = title
|
||||
injection = (beat.get("injection") or "").strip()
|
||||
if injection:
|
||||
out["beat_injection"] = injection
|
||||
return out
|
||||
|
||||
|
||||
def choices_from_beat(beat: dict) -> list[dict]:
|
||||
if not isinstance(beat, dict):
|
||||
return []
|
||||
return [
|
||||
c for c in (
|
||||
normalize_choice(item, source="plot_beat", beat=beat)
|
||||
for item in (beat.get("choices") or [])
|
||||
)
|
||||
if c
|
||||
]
|
||||
|
||||
|
||||
def choices_from_narrator(raw_choices: list) -> list[dict]:
|
||||
if not isinstance(raw_choices, list):
|
||||
return []
|
||||
return [
|
||||
c for c in (normalize_choice(item, source="narrator") for item in raw_choices) if c
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user