"""FastAPI router for the ``/chat`` endpoints."""

import json
import os
import random

import aiosqlite
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from database.db import DB_PATH
from models.schemas import ChatRequest, ChatResponse
from services.llm import send_message, stream_message
from services.memory import (
    get_history,
    add_message,
    clear_history,
    get_or_create_session,
    get_session,
    update_session_title,
    get_message_count,
    get_last_assistant_message_id,
    update_message_image,
    update_session_facts,
    update_session_status_quo,
    add_action_resolution,
)
from services.personas import get_persona
from services.sd_prompt import (
    generate_sd_prompt,
    strip_image_prompt_tag,
    extract_image_prompt_tag,
)
from services.lorebook import get_lorebook_context
from services.character_card import get_character
from services import sdbackend as sd_service
from services.rpg_facts import extract_facts, merge_facts, facts_to_prompt
from services.rpg_plot import generate_plot_arc, should_advance_arc, pop_matching_beats
from services.rpg_narrator import narrator_pre, narrator_post

router = APIRouter(prefix="/chat", tags=["chat"])

# Fallback system prompt when no persona is found.
# Russian for: "You are a helpful AI assistant. Answer clearly and to the point."
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")

# How many recent user/assistant messages are scanned for lorebook keyword matches.
_LORE_CONTEXT_MESSAGES = 5


def _lore_context(history: list, user_message: str) -> list:
    """Build the message window used for lorebook matching.

    Returns the last ``_LORE_CONTEXT_MESSAGES`` user/assistant messages of
    ``history`` followed by the pending ``user_message`` as a user turn.
    """
    recent = [m for m in history if m["role"] in ("user", "assistant")]
    return recent[-_LORE_CONTEXT_MESSAGES:] + [{"role": "user", "content": user_message}]


def _parse_arc(raw) -> dict:
    """Decode a stored plot-arc JSON value into a dict.

    Returns ``{}`` for empty/non-string values, malformed JSON, or JSON that is
    not an object (e.g. a list), so callers can always use ``arc.get(...)``.
    """
    if not raw or not isinstance(raw, str):
        return {}
    try:
        arc = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {}
    return arc if isinstance(arc, dict) else {}


def _roll_d20() -> tuple:
    """Roll a d20 and map it to an outcome label.

    1 -> "critical failure", 2-8 -> "failure", 9-19 -> "success",
    20 -> "critical success". Returns ``(roll, outcome)``.
    """
    roll = random.randint(1, 20)
    if roll == 1:
        outcome = "critical failure"
    elif roll <= 8:
        outcome = "failure"
    elif roll >= 20:
        outcome = "critical success"
    else:
        outcome = "success"
    return roll, outcome


def _sse(payload: dict) -> str:
    """Format ``payload`` as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


async def _persist_system_prompt(session_id: str, history: list, system_prompt: str) -> None:
    """Store ``system_prompt`` as the session's system message.

    Inserts it when the session has no messages yet. Otherwise, when the first
    message is a system message whose content differs, rewrites that message
    in place. A history that does not start with a system message is left as is.
    """
    if not history:
        await add_message(session_id, "system", system_prompt)
        return
    first = history[0]
    if first["role"] != "system" or first["content"] == system_prompt:
        return
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            """UPDATE messages SET content = ?
               WHERE session_id = ? AND role = 'system'
                 AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
            (system_prompt, session_id, session_id),
        )
        await db.commit()


async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
    """Build the system prompt for ``persona_id``.

    Starts from the persona's ``prompt`` (``DEFAULT_PROMPT`` when the persona is
    unknown) and appends lorebook entries matched against the recent
    conversation plus ``user_message``: first the persona's own lorebook, then,
    for ``card_<id>`` personas, the character card's lorebook. An identical
    lore block is appended only once.
    """
    persona = await get_persona(persona_id)
    if not persona:
        return DEFAULT_PROMPT

    prompt = persona["prompt"]

    lorebooks = []
    if persona.get("lorebook_json"):
        lorebooks.append(persona["lorebook_json"])
    if persona_id.startswith("card_"):
        card = await get_character(persona_id[len("card_"):])
        if card:
            lorebooks.append(card.get("lorebook_json", "[]"))

    if lorebooks:
        context = _lore_context(history, user_message)
        appended = []  # lore blocks already added, to avoid duplicating text
        for lorebook_json in lorebooks:
            lore = get_lorebook_context(lorebook_json, context)
            if lore and lore not in appended:
                prompt = prompt + "\n\n" + lore
                appended.append(lore)
    return prompt


@router.get("/history/{session_id}")
async def get_chat_history(session_id: str):
    """Return the stored message history of ``session_id``."""
    return await get_history(session_id)


@router.get("/system/{session_id}")
async def get_system_blob(session_id: str):
    """Return the stored system prompt and RPG session state for inspection."""
    history = await get_history(session_id)
    system_msg = next((m for m in history if m.get("role") == "system"), None)
    session = await get_session(session_id)
    return {
        "system_prompt": system_msg.get("content") if system_msg else "",
        "facts_json": session.get("facts_json") if session else "[]",
        "status_quo": session.get("status_quo") if session else "",
        "plot_arc_json": session.get("plot_arc_json") if session else "{}",
        "rpg_enabled": bool(session.get("rpg_enabled")) if session else False,
    }


@router.post("/init")
async def init_chat(request: ChatRequest):
    """Initialize a chat when it is first opened.

    Seeds the system prompt. When the persona (or, for ``card_<id>`` personas,
    the character card) defines ``first_mes``, stores it as the opening
    assistant message and returns it. Sessions that already have history are
    left untouched and ``{"first_mes": None}`` is returned.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    history = await get_history(request.session_id)
    if history:
        return {"first_mes": None}  # already initialized

    system_prompt = await get_system_prompt(persona_id, [], "")
    await add_message(request.session_id, "system", system_prompt)

    first_mes = None
    persona = await get_persona(persona_id)
    if persona and persona.get("first_mes"):
        first_mes = persona["first_mes"]
        await add_message(request.session_id, "assistant", first_mes)
    elif persona_id.startswith("card_"):
        card = await get_character(persona_id[5:])
        if card and card.get("first_mes"):
            first_mes = card["first_mes"]
            await add_message(request.session_id, "assistant", first_mes)
    return {"first_mes": first_mes}


class RpgBootstrapRequest(BaseModel):
    session_id: str
    persona_id: str = "default"


@router.post("/rpg/bootstrap")
async def rpg_bootstrap(req: RpgBootstrapRequest):
    """Generate and persist the session's plot arc early (debugging aid).

    Returns ``{"plot_arc": arc}``; an already-stored arc is returned unchanged.
    """
    from services.memory import update_session_plot_arc

    await get_or_create_session(req.session_id, req.persona_id)
    session = await get_session(req.session_id) or {}
    persona = await get_persona(req.persona_id) or {}

    arc = _parse_arc(session.get("plot_arc_json"))
    if not arc:
        arc = await generate_plot_arc(
            persona.get("name", req.persona_id),
            persona.get("description", ""),
            persona.get("scenario", ""),
            persona.get("first_mes", ""),
            facts_block=facts_to_prompt(session.get("facts_json", "[]")),
        )
        if arc:
            await update_session_plot_arc(req.session_id, json.dumps(arc, ensure_ascii=False))
    return {"plot_arc": arc}


@router.post("/stream")
async def chat_stream(request: ChatRequest):
    """Stream an assistant reply as Server-Sent Events.

    Before streaming: rebuilds the system prompt (persona prompt + lorebook).
    When the session has ``rpg_enabled``, it also appends stored facts, a plot
    arc summary, the status quo, a d20 roll/outcome and ``narrator_pre``
    directives. The narrator may also update the status quo and record an
    action resolution. The system message is persisted, the user message
    appended, and model output relayed as ``{"chunk": ...}`` frames.

    After streaming: stores the reply (image tag stripped) with an SD prompt.
    In RPG mode it then generates/advances the plot arc, extracts and merges
    facts, and runs ``narrator_post``, which may update the status quo and add
    choices. A final ``{"done": true, ...}`` frame carries the image
    prompt/path, choices, debug blocks and the roll resolution.
    """
    from services.memory import update_session_plot_arc

    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    history = await get_history(request.session_id)
    session = await get_session(request.session_id)
    system_prompt = await get_system_prompt(persona_id, history, request.message)

    rpg_enabled = bool(session and session.get("rpg_enabled"))
    # Bind everything the generator closure reads, even when RPG mode is off.
    arc = {}
    facts_block = ""
    roll = None
    outcome = None
    resolution_text = ""
    persona = {}

    if rpg_enabled:
        persona = await get_persona(persona_id) or {}

        # Experimental RPG: inject persistent facts + global plot.
        facts_block = facts_to_prompt(session.get("facts_json", "[]"))
        if facts_block:
            system_prompt = system_prompt + "\n\n" + facts_block

        arc = _parse_arc(session.get("plot_arc_json"))
        if arc:
            arc_summary = {k: arc.get(k) for k in ("title", "phase", "next_beat_hint")}
            system_prompt = (
                system_prompt
                + "\n\n--- PlotArc ---\n"
                + json.dumps(arc_summary, ensure_ascii=False)
                + "\n---"
            )

        status_quo = (session.get("status_quo") or "").strip()
        if status_quo:
            system_prompt = system_prompt + "\n\n--- Status quo ---\n" + status_quo + "\n---"

        # d20 outcome directive.
        roll, outcome = _roll_d20()
        system_prompt = (
            system_prompt
            + "\n\n--- Mechanics ---\n"
            + f"Roll d20={roll}. Outcome: {outcome}.\n"
            + "You MUST incorporate this outcome into the narrative result.\n"
            + "---"
        )

        # Narrator pre-pass: directives for the next reply, an optional
        # status-quo update and a resolution text for this roll.
        recent_txt = "\n".join(
            f"{m['role']}: {m['content']}"
            for m in history[-8:]
            if m.get("role") in ("user", "assistant")
        )
        pre = await narrator_pre(
            persona.get("name", persona_id),
            recent_txt,
            json.dumps(arc, ensure_ascii=False) if arc else "",
            facts_block,
            roll,
            outcome,
        ) or {}

        directives = pre.get("directives") or []
        if directives:
            system_prompt = (
                system_prompt
                + "\n\n--- Narrator directives ---\n"
                + "\n".join(f"- {d}" for d in directives)
                + "\n---"
            )

        pre_sq = (pre.get("status_quo_update") or "").strip()
        if pre_sq:
            await update_session_status_quo(request.session_id, pre_sq)
            session["status_quo"] = pre_sq

        resolution_text = (pre.get("resolution_text") or "").strip()
        if resolution_text:
            await add_action_resolution(
                request.session_id,
                intent_text=request.message,
                roll=roll,
                outcome=outcome,
                resolution_text=resolution_text,
                message_id=None,
            )

    await _persist_system_prompt(request.session_id, history, system_prompt)
    await add_message(request.session_id, "user", request.message)
    messages = await get_history(request.session_id)

    async def generate():
        nonlocal arc

        full_reply = []
        async for chunk in stream_message(
            [{"role": m["role"], "content": m["content"]} for m in messages]
        ):
            full_reply.append(chunk)
            yield _sse({"chunk": chunk})

        complete = "".join(full_reply)
        display_text = strip_image_prompt_tag(complete)

        hist_with_reply = await get_history(request.session_id) + [
            {"role": "assistant", "content": display_text}
        ]
        sd_result = await generate_sd_prompt(hist_with_reply, persona_id)
        # Fall back to an image-prompt tag embedded in the raw reply.
        prompt_str = (sd_result[0] if sd_result else None) or extract_image_prompt_tag(complete)

        await add_message(
            request.session_id,
            "assistant",
            display_text or complete,
            image_prompt=prompt_str,
        )

        choices = []
        debug_blocks = []
        if rpg_enabled:
            # Generate the plot arc once per session.
            if not arc:
                arc = await generate_plot_arc(
                    persona.get("name", persona_id),
                    persona.get("description", ""),
                    persona.get("scenario", ""),
                    persona.get("first_mes", ""),
                    facts_block=facts_block,
                ) or {}
                if arc:
                    await update_session_plot_arc(
                        request.session_id, json.dumps(arc, ensure_ascii=False)
                    )
                    debug_blocks.append(
                        {"type": "plot_arc", "text": json.dumps(arc, ensure_ascii=False, indent=2)}
                    )

            # Event-driven beat injection.
            trig = should_advance_arc(request.message)
            if trig and arc:
                arc, beats = pop_matching_beats(arc, trig, max_beats=1)
                if beats:
                    await update_session_plot_arc(
                        request.session_id, json.dumps(arc, ensure_ascii=False)
                    )
                    beat = beats[0]
                    injection = beat.get("injection", "")
                    if injection:
                        debug_blocks.append({"type": "narrator_injection", "text": injection})
                    choices.extend(beat.get("choices") or [])

            # Extract facts from the latest turns and merge them into the session.
            latest = await get_history(request.session_id)
            ctx = [m for m in latest if m["role"] in ("user", "assistant")][-10:]
            new_facts = await extract_facts(ctx)
            if new_facts:
                merged = merge_facts(session.get("facts_json", "[]"), new_facts)
                await update_session_facts(request.session_id, merged)
                session["facts_json"] = merged
                debug_blocks.append({"type": "facts", "text": facts_to_prompt(merged)})

            # Narrator post-pass: status-quo update and optional extra choices.
            ctx_txt = "\n".join(f"{m['role']}: {m['content']}" for m in ctx[-8:])
            post = await narrator_post(
                persona.get("name", persona_id),
                ctx_txt,
                json.dumps(arc, ensure_ascii=False) if arc else "",
                facts_to_prompt(session.get("facts_json", "[]")),
            ) or {}
            sq = (post.get("status_quo_update") or "").strip()
            if sq:
                await update_session_status_quo(request.session_id, sq)
                session["status_quo"] = sq
                debug_blocks.append(
                    {"type": "status_quo", "text": f"--- Status quo update ---\n{sq}\n---"}
                )
            choices.extend(post.get("choices") or [])

        count = await get_message_count(request.session_id)
        if count == 2:
            title = request.message[:40] + ("…" if len(request.message) > 40 else "")
            await update_session_title(request.session_id, title)

        image_path = None
        image_error = None
        if prompt_str and SD_AUTO_GENERATE:
            rel, err = await sd_service.generate_from_full_prompt(prompt_str)
            if rel:
                image_path = rel
                msg_id = await get_last_assistant_message_id(request.session_id)
                if msg_id:
                    await update_message_image(msg_id, rel)
            else:
                image_error = err

        resolution = None
        if rpg_enabled and resolution_text:
            resolution = {"roll": roll, "outcome": outcome, "text": resolution_text}

        yield _sse({
            "done": True,
            "image_prompt": prompt_str,
            "image_path": f"/static/{image_path}" if image_path else None,
            "image_error": image_error,
            "choices": choices,
            "debug": debug_blocks,
            "resolution": resolution,
        })

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@router.post("/", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Non-streaming chat turn.

    Inserts the system prompt only when the session has no history yet, stores
    the user message, requests a full reply and stores it with the image tag
    stripped. As in the streaming endpoint, the SD prompt is generated from
    the history including the new reply, and falls back to an image-prompt
    tag embedded in the reply when ``generate_sd_prompt`` yields nothing.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    history = await get_history(request.session_id)
    if not history:
        system_prompt = await get_system_prompt(persona_id, history, request.message)
        await add_message(request.session_id, "system", system_prompt)

    await add_message(request.session_id, "user", request.message)
    messages = await get_history(request.session_id)

    reply = await send_message(
        [{"role": m["role"], "content": m["content"]} for m in messages]
    )
    display = strip_image_prompt_tag(reply)

    sd_result = await generate_sd_prompt(
        messages + [{"role": "assistant", "content": display}], persona_id
    )
    prompt_str = (sd_result[0] if sd_result else None) or extract_image_prompt_tag(reply)

    await add_message(request.session_id, "assistant", display, image_prompt=prompt_str)
    return ChatResponse(
        reply=display,
        session_id=request.session_id,
        image_prompt=prompt_str,
    )


@router.delete("/{session_id}")
async def clear_chat(session_id: str):
    """Delete the stored history of ``session_id``."""
    await clear_history(session_id)
    return {"status": "cleared", "session_id": session_id}