Added RPG
This commit is contained in:
+205
-1
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import aiosqlite
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from database.db import DB_PATH
|
||||
from models.schemas import ChatRequest, ChatResponse
|
||||
@@ -13,10 +15,14 @@ from services.memory import (
|
||||
add_message,
|
||||
clear_history,
|
||||
get_or_create_session,
|
||||
get_session,
|
||||
update_session_title,
|
||||
get_message_count,
|
||||
get_last_assistant_message_id,
|
||||
update_message_image,
|
||||
update_session_facts,
|
||||
update_session_status_quo,
|
||||
add_action_resolution,
|
||||
)
|
||||
from services.personas import get_persona
|
||||
from services.sd_prompt import (
|
||||
@@ -27,6 +33,9 @@ from services.sd_prompt import (
|
||||
from services.lorebook import get_lorebook_context
|
||||
from services.character_card import get_character
|
||||
from services import sdbackend as sd_service
|
||||
from services.rpg_facts import extract_facts, merge_facts, facts_to_prompt
|
||||
from services.rpg_plot import generate_plot_arc, should_advance_arc, pop_matching_beats
|
||||
from services.rpg_narrator import narrator_pre, narrator_post
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
@@ -41,6 +50,14 @@ async def get_system_prompt(persona_id: str, history: list, user_message: str =
|
||||
|
||||
prompt = persona["prompt"]
|
||||
|
||||
# persona-only lorebook
|
||||
if persona.get("lorebook_json"):
|
||||
recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
|
||||
context = recent + [{"role": "user", "content": user_message}]
|
||||
lore = get_lorebook_context(persona.get("lorebook_json", "[]"), context)
|
||||
if lore:
|
||||
prompt = prompt + "\n\n" + lore
|
||||
|
||||
if persona_id.startswith("card_"):
|
||||
card_id = persona_id[5:]
|
||||
card = await get_character(card_id)
|
||||
@@ -60,6 +77,20 @@ async def get_chat_history(session_id: str):
|
||||
return await get_history(session_id)
|
||||
|
||||
|
||||
@router.get("/system/{session_id}")
|
||||
async def get_system_blob(session_id: str):
|
||||
history = await get_history(session_id)
|
||||
system_msg = next((m for m in history if m.get("role") == "system"), None)
|
||||
session = await get_session(session_id)
|
||||
return {
|
||||
"system_prompt": system_msg.get("content") if system_msg else "",
|
||||
"facts_json": session.get("facts_json") if session else "[]",
|
||||
"status_quo": session.get("status_quo") if session else "",
|
||||
"plot_arc_json": session.get("plot_arc_json") if session else "{}",
|
||||
"rpg_enabled": bool(session.get("rpg_enabled")) if session else False,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
async def init_chat(request: ChatRequest):
|
||||
"""Called when opening a new chat. Seeds system prompt and first_mes if card persona."""
|
||||
@@ -73,7 +104,11 @@ async def init_chat(request: ChatRequest):
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
|
||||
first_mes = None
|
||||
if persona_id.startswith("card_"):
|
||||
persona = await get_persona(persona_id)
|
||||
if persona and persona.get("first_mes"):
|
||||
first_mes = persona["first_mes"]
|
||||
await add_message(request.session_id, "assistant", first_mes)
|
||||
elif persona_id.startswith("card_"):
|
||||
card = await get_character(persona_id[5:])
|
||||
if card and card.get("first_mes"):
|
||||
first_mes = card["first_mes"]
|
||||
@@ -82,6 +117,39 @@ async def init_chat(request: ChatRequest):
|
||||
return {"first_mes": first_mes}
|
||||
|
||||
|
||||
class RpgBootstrapRequest(BaseModel):
|
||||
session_id: str
|
||||
persona_id: str = "default"
|
||||
|
||||
|
||||
@router.post("/rpg/bootstrap")
|
||||
async def rpg_bootstrap(req: RpgBootstrapRequest):
|
||||
"""Generate plot arc early for debugging."""
|
||||
await get_or_create_session(req.session_id, req.persona_id)
|
||||
session = await get_session(req.session_id)
|
||||
persona = await get_persona(req.persona_id) or {}
|
||||
|
||||
arc_json = (session.get("plot_arc_json") or "{}") if session else "{}"
|
||||
try:
|
||||
arc = json.loads(arc_json) if isinstance(arc_json, str) else {}
|
||||
except Exception:
|
||||
arc = {}
|
||||
if not arc:
|
||||
facts_block = facts_to_prompt((session or {}).get("facts_json", "[]"))
|
||||
arc = await generate_plot_arc(
|
||||
persona.get("name", req.persona_id),
|
||||
persona.get("description", ""),
|
||||
persona.get("scenario", ""),
|
||||
persona.get("first_mes", ""),
|
||||
facts_block=facts_block,
|
||||
)
|
||||
if arc:
|
||||
from services.memory import update_session_plot_arc
|
||||
await update_session_plot_arc(req.session_id, json.dumps(arc, ensure_ascii=False))
|
||||
|
||||
return {"plot_arc": arc}
|
||||
|
||||
|
||||
@router.post("/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
persona_id = request.persona_id or "default"
|
||||
@@ -89,7 +157,77 @@ async def chat_stream(request: ChatRequest):
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
|
||||
history = await get_history(request.session_id)
|
||||
session = await get_session(request.session_id)
|
||||
system_prompt = await get_system_prompt(persona_id, history, request.message)
|
||||
# Experimental RPG: inject persistent facts + global plot
|
||||
arc = {}
|
||||
if session and session.get("rpg_enabled"):
|
||||
facts_block = facts_to_prompt(session.get("facts_json", "[]"))
|
||||
if facts_block:
|
||||
system_prompt = system_prompt + "\n\n" + facts_block
|
||||
# load plot arc
|
||||
try:
|
||||
arc = json.loads(session.get("plot_arc_json") or "{}")
|
||||
except Exception:
|
||||
arc = {}
|
||||
if arc:
|
||||
system_prompt = system_prompt + "\n\n--- PlotArc ---\n" + json.dumps(
|
||||
{k: arc.get(k) for k in ("title", "phase", "next_beat_hint")}, ensure_ascii=False
|
||||
) + "\n---"
|
||||
status_quo = (session.get("status_quo") or "").strip()
|
||||
if status_quo:
|
||||
system_prompt = system_prompt + "\n\n--- Status quo ---\n" + status_quo + "\n---"
|
||||
|
||||
# d20 outcome directive
|
||||
roll = random.randint(1, 20)
|
||||
if roll == 1:
|
||||
outcome = "critical failure"
|
||||
elif roll <= 8:
|
||||
outcome = "failure"
|
||||
elif roll >= 20:
|
||||
outcome = "critical success"
|
||||
else:
|
||||
outcome = "success"
|
||||
system_prompt = (
|
||||
system_prompt
|
||||
+ f"\n\n--- Mechanics ---\n"
|
||||
+ f"Roll d20={roll}. Outcome: {outcome}.\n"
|
||||
+ "You MUST incorporate this outcome into the narrative result.\n"
|
||||
+ "---"
|
||||
)
|
||||
|
||||
# System/Narrator pre-pass: add directives for the next reply + optional status quo update
|
||||
persona = await get_persona(persona_id) or {}
|
||||
recent_txt = "\n".join(
|
||||
f"{m['role']}: {m['content']}" for m in history[-8:]
|
||||
if m.get("role") in ("user", "assistant")
|
||||
)
|
||||
pre = await narrator_pre(
|
||||
persona.get("name", persona_id),
|
||||
recent_txt,
|
||||
json.dumps(arc, ensure_ascii=False) if arc else "",
|
||||
facts_block,
|
||||
roll,
|
||||
outcome,
|
||||
)
|
||||
directives = pre.get("directives") or []
|
||||
if directives:
|
||||
system_prompt = system_prompt + "\n\n--- Narrator directives ---\n" + "\n".join(f"- {d}" for d in directives) + "\n---"
|
||||
pre_sq = (pre.get("status_quo_update") or "").strip()
|
||||
if pre_sq:
|
||||
await update_session_status_quo(request.session_id, pre_sq)
|
||||
session["status_quo"] = pre_sq
|
||||
|
||||
resolution_text = (pre.get("resolution_text") or "").strip()
|
||||
if resolution_text:
|
||||
await add_action_resolution(
|
||||
request.session_id,
|
||||
intent_text=request.message,
|
||||
roll=roll,
|
||||
outcome=outcome,
|
||||
resolution_text=resolution_text,
|
||||
message_id=None,
|
||||
)
|
||||
|
||||
if not history:
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
@@ -109,6 +247,7 @@ async def chat_stream(request: ChatRequest):
|
||||
full_reply = []
|
||||
|
||||
async def generate():
|
||||
nonlocal arc
|
||||
async for chunk in stream_message(
|
||||
[{"role": m["role"], "content": m["content"]} for m in messages]
|
||||
):
|
||||
@@ -133,6 +272,68 @@ async def chat_stream(request: ChatRequest):
|
||||
image_prompt=prompt_str,
|
||||
)
|
||||
|
||||
# Experimental RPG: facts autosave and plot generation
|
||||
choices = []
|
||||
debug_blocks = []
|
||||
if session and session.get("rpg_enabled"):
|
||||
# generate arc once
|
||||
if not arc:
|
||||
persona = await get_persona(persona_id) or {}
|
||||
arc = await generate_plot_arc(
|
||||
persona.get("name", persona_id),
|
||||
persona.get("description", ""),
|
||||
persona.get("scenario", ""),
|
||||
persona.get("first_mes", ""),
|
||||
facts_block=facts_block,
|
||||
)
|
||||
if arc:
|
||||
from services.memory import update_session_plot_arc
|
||||
await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
|
||||
debug_blocks.append({"type": "plot_arc", "text": json.dumps(arc, ensure_ascii=False, indent=2)})
|
||||
|
||||
# event-driven beat injection
|
||||
trig = should_advance_arc(request.message)
|
||||
if trig and arc:
|
||||
arc, beats = pop_matching_beats(arc, trig, max_beats=1)
|
||||
if beats:
|
||||
from services.memory import update_session_plot_arc
|
||||
await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
|
||||
inj = beats[0].get("injection", "")
|
||||
if inj:
|
||||
debug_blocks.append({"type": "narrator_injection", "text": inj})
|
||||
beat_choices = beats[0].get("choices") or []
|
||||
if beat_choices:
|
||||
choices = (choices or []) + beat_choices
|
||||
# extract facts
|
||||
ctx = [m for m in (await get_history(request.session_id)) if m["role"] in ("user", "assistant")][-10:]
|
||||
new_facts = await extract_facts(ctx)
|
||||
if new_facts:
|
||||
merged = merge_facts(session.get("facts_json", "[]"), new_facts)
|
||||
await update_session_facts(request.session_id, merged)
|
||||
session["facts_json"] = merged
|
||||
debug_blocks.append({"type": "facts", "text": facts_to_prompt(merged)})
|
||||
|
||||
# System/Narrator post-pass: update status quo and optionally produce extra choices
|
||||
persona = await get_persona(persona_id) or {}
|
||||
ctx_txt = "\n".join(
|
||||
f"{m['role']}: {m['content']}" for m in ctx[-8:]
|
||||
if m.get("role") in ("user", "assistant")
|
||||
)
|
||||
post = await narrator_post(
|
||||
persona.get("name", persona_id),
|
||||
ctx_txt,
|
||||
json.dumps(arc, ensure_ascii=False) if arc else "",
|
||||
facts_to_prompt(session.get("facts_json", "[]")),
|
||||
)
|
||||
sq = (post.get("status_quo_update") or "").strip()
|
||||
if sq:
|
||||
await update_session_status_quo(request.session_id, sq)
|
||||
session["status_quo"] = sq
|
||||
debug_blocks.append({"type": "status_quo", "text": f"--- Status quo update ---\n{sq}\n---"})
|
||||
extra_choices = post.get("choices") or []
|
||||
if extra_choices:
|
||||
choices = (choices or []) + extra_choices
|
||||
|
||||
count = await get_message_count(request.session_id)
|
||||
if count == 2:
|
||||
title = request.message[:40] + ("…" if len(request.message) > 40 else "")
|
||||
@@ -155,6 +356,9 @@ async def chat_stream(request: ChatRequest):
|
||||
'image_prompt': prompt_str,
|
||||
'image_path': f'/static/{image_path}' if image_path else None,
|
||||
'image_error': image_error,
|
||||
'choices': choices,
|
||||
'debug': debug_blocks,
|
||||
'resolution': {'roll': roll, 'outcome': outcome, 'text': resolution_text} if (session and session.get('rpg_enabled') and resolution_text) else None,
|
||||
})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
|
||||
Reference in New Issue
Block a user