"""Chat API routes under the ``/chat`` prefix.

Covers plain and SSE-streamed chat, history retrieval, message editing,
reply regeneration, and per-session RPG state (plot arc, quests, affinity,
narrator).
"""

import json
import logging
import os
import random

import aiosqlite
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from database.db import DB_PATH
from models.schemas import ChatRequest, ChatResponse, MessageEditRequest, RegenerateRequest
from services.llm import send_message, stream_message
from services.memory import (
    get_history,
    add_message,
    clear_history,
    get_or_create_session,
    get_session,
    update_session_title,
    update_session_persona,
    get_message_count,
    get_last_assistant_message_id,
    update_message_image,
    update_session_facts,
    update_session_status_quo,
    update_session_affinity,
    update_session_genre,
    update_session_rpg_settings,
    update_session_outfit,
    update_session_plot_arc,
    upsert_quest,
    get_quests,
    add_action_resolution,
    get_message,
    update_message_content,
    delete_messages_after,
    delete_message,
)
from services.personas import get_persona
from services.sd_prompt import generate_sd_prompt, strip_image_prompt_tag, extract_image_prompt_tag
from services.lorebook import get_lorebook_context
from services.character_card import get_character
from services import sdbackend as sd_service
from services.rpg_facts import extract_facts, merge_facts, facts_to_prompt
from services.rpg_plot import generate_plot_arc, should_advance_arc, pop_matching_beats, advance_phase
from services.rpg_narrator import narrator_pre, narrator_post

logger = logging.getLogger(__name__)

# All routes in this module are mounted under /chat.
router = APIRouter(prefix="/chat", tags=["chat"])

# Fallback system prompt used when the persona cannot be found. The Russian text means:
# "You are a helpful AI assistant. Answer clearly and to the point."
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
# Auto-generate an image for every reply when the env flag is truthy.
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")

# Per-session RPG feature toggles; a session overrides them via ``rpg_settings_json``.
DEFAULT_RPG_SETTINGS = {"dice": True, "narrator": True, "quests": True, "affinity": True, "choices": True}

# Placeholder session title ("New chat") that the first-exchange auto-title may overwrite.
_DEFAULT_SESSION_TITLE = "Новый чат"
# Wrapper applied to narrator-choice messages before they are stored as the user turn;
# regenerate_chat strips it again.
_CHOICE_PREFIX = "[Player chose: "
# Quest titles are truncated to this many characters before being stored.
_QUEST_TITLE_MAX = 120


# --------------------------------------------------------------------------- helpers


def _sse(payload: dict) -> str:
    """Encode *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _as_list(value) -> list:
    """Coerce an LLM-supplied field to a list.

    Lists pass through, tuples are copied, falsy values become ``[]`` and any
    other value is wrapped. This keeps a bare string from being iterated
    character by character.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    return [value] if value else []


def _clean_text(value) -> str:
    """Return *value* stripped if it is a string, otherwise ``""``."""
    return value.strip() if isinstance(value, str) else ""


def _safe_int(value) -> int:
    """Best-effort ``int(value or 0)``; unparseable values yield 0 instead of raising."""
    try:
        return int(value or 0)
    except (TypeError, ValueError, OverflowError):
        return 0


def _roll_outcome(roll: int) -> str:
    """Map a d20 roll to an outcome label.

    1 is a critical failure, 2-8 a failure, 9-19 a success and 20 a critical success.
    """
    if roll == 1:
        return "critical failure"
    if roll <= 8:
        return "failure"
    if roll >= 20:
        return "critical success"
    return "success"


def _parse_plot_arc(raw) -> dict:
    """Decode a stored plot-arc JSON string; missing, malformed or non-object data yields ``{}``."""
    if not isinstance(raw, str) or not raw:
        return {}
    try:
        arc = json.loads(raw)
    except ValueError:
        return {}
    return arc if isinstance(arc, dict) else {}


def _resolve_image_prompt(sd_result, raw_reply: str):
    """Pick the image prompt for a reply.

    Prefer the first element of ``generate_sd_prompt``'s result and fall back to
    the prompt tag embedded in the raw LLM reply.
    """
    if sd_result and sd_result[0]:
        return sd_result[0]
    return extract_image_prompt_tag(raw_reply)


async def _seed_quests_from_arc(session_id: str, arc: dict) -> None:
    """Upsert one quest per plot-arc beat, titled by the beat's title or its injection text."""
    for beat in _as_list(arc.get("beats")):
        if not isinstance(beat, dict):
            continue
        title = _clean_text(beat.get("title") or beat.get("injection"))
        if title:
            await upsert_quest(session_id, title[:_QUEST_TITLE_MAX])


async def _sync_system_prompt(session_id: str, history: list, system_prompt: str) -> None:
    """Store *system_prompt* as the session's first message.

    If the history is empty, insert the prompt. If the first message is a
    system message with different text, rewrite it in place.
    """
    if not history:
        await add_message(session_id, "system", system_prompt)
    elif history[0]["role"] == "system" and history[0]["content"] != system_prompt:
        async with aiosqlite.connect(DB_PATH) as db:
            await db.execute(
                """UPDATE messages SET content = ? WHERE session_id = ? AND role = 'system' AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
                (system_prompt, session_id, session_id),
            )
            await db.commit()


# --------------------------------------------------------------------------- prompt building


def get_rpg_settings(session: dict) -> dict:
    """Return the session's RPG toggles merged over ``DEFAULT_RPG_SETTINGS``.

    ``session["rpg_settings_json"]`` is expected to hold a JSON object of
    overrides. Malformed or non-object JSON yields the defaults. A fresh dict is
    always returned, so callers cannot mutate the module-level defaults.
    """
    try:
        overrides = json.loads(session.get("rpg_settings_json") or "{}")
    except (TypeError, ValueError):
        overrides = {}
    if not isinstance(overrides, dict):
        overrides = {}
    return {**DEFAULT_RPG_SETTINGS, **overrides}


def affinity_prompt_block(affinity: int) -> str:
    """Render the system-prompt block describing the character's attitude for *affinity*."""
    if affinity >= 10:
        tone = "very warm, trusting, affectionate"
    elif affinity >= 5:
        tone = "friendly and open"
    elif affinity >= 1:
        tone = "slightly positive"
    elif affinity <= -5:
        tone = "hostile or deeply distrustful"
    elif affinity <= -1:
        tone = "cold and wary"
    else:
        tone = "neutral"
    return f"\n\n--- Relationship ---\nAffinity toward player: {affinity} ({tone}). Reflect this in your attitude and word choice.\n---"


async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
    """Build the base system prompt for *persona_id*.

    The prompt is the persona prompt plus lorebook entries triggered by the last
    five user/assistant messages and *user_message*. ``card_``-prefixed personas
    also pull lore from their character card. Unknown personas get
    ``DEFAULT_PROMPT``.
    """
    persona = await get_persona(persona_id)
    if not persona:
        return DEFAULT_PROMPT
    prompt = persona["prompt"]
    recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
    context = recent + [{"role": "user", "content": user_message}]
    if persona.get("lorebook_json"):
        lore = get_lorebook_context(persona.get("lorebook_json", "[]"), context)
        if lore:
            prompt += "\n\n" + lore
    if persona_id.startswith("card_"):
        card = await get_character(persona_id[5:])
        if card:
            lore = get_lorebook_context(card.get("lorebook_json") or "[]", context)
            if lore:
                prompt += "\n\n" + lore
    return prompt


# --------------------------------------------------------------------------- read endpoints


@router.get("/history/{session_id}")
async def get_chat_history(session_id: str):
    """Return the stored message history of *session_id*."""
    return await get_history(session_id)


@router.get("/system/{session_id}")
async def get_system_blob(session_id: str):
    """Return the current system prompt plus the session's RPG state and quests."""
    history = await get_history(session_id)
    system_msg = next((m for m in history if m.get("role") == "system"), None)
    session = await get_session(session_id)
    quests = await get_quests(session_id)
    return {
        "system_prompt": system_msg.get("content") if system_msg else "",
        "status_quo": session.get("status_quo") if session else "",
        "facts_json": session.get("facts_json") if session else "[]",
        "plot_arc_json": session.get("plot_arc_json") if session else "{}",
        "outfit_json": session.get("outfit_json") if session else "[]",
        "affinity": session.get("affinity", 0) if session else 0,
        "genre": session.get("genre", "") if session else "",
        "rpg_settings_json": session.get("rpg_settings_json") if session else "{}",
        "rpg_enabled": bool(session.get("rpg_enabled")) if session else False,
        "quests": quests,
    }


# --------------------------------------------------------------------------- session setup


@router.post("/init")
async def init_chat(request: ChatRequest):
    """Seed an empty session with its system prompt and opening message.

    The opening message comes from, in order of preference: the request
    override, the persona's ``first_mes``, or the character card's
    ``first_mes``. Sessions that already have history are left untouched and
    return ``{"first_mes": None}``.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    history = await get_history(request.session_id)
    if history:
        return {"first_mes": None}
    system_prompt = await get_system_prompt(persona_id, [], "")
    await add_message(request.session_id, "system", system_prompt)
    first_mes = None
    if request.first_mes_override and request.first_mes_override.strip():
        first_mes = request.first_mes_override.strip()
        await add_message(request.session_id, "assistant", first_mes)
    else:
        persona = await get_persona(persona_id)
        if persona and persona.get("first_mes"):
            first_mes = persona["first_mes"]
            await add_message(request.session_id, "assistant", first_mes)
        elif persona_id.startswith("card_"):
            card = await get_character(persona_id[5:])
            if card and card.get("first_mes"):
                first_mes = card["first_mes"]
                await add_message(request.session_id, "assistant", first_mes)
    return {"first_mes": first_mes}


class RpgBootstrapRequest(BaseModel):
    """Payload for ``POST /chat/rpg/bootstrap``."""

    session_id: str
    persona_id: str = "default"
    genre: str = "adventure"


@router.post("/rpg/bootstrap")
async def rpg_bootstrap(req: RpgBootstrapRequest):
    """Store the genre and ensure the session has a plot arc.

    If the session has no arc yet, one is generated. When generation succeeds,
    the arc is stored and one quest is seeded per beat.
    Returns the arc (``None`` if generation failed) and the session's quests.
    """
    await get_or_create_session(req.session_id, req.persona_id)
    session = await get_session(req.session_id)
    persona = await get_persona(req.persona_id) or {}

    await update_session_genre(req.session_id, req.genre)

    arc = _parse_plot_arc(session.get("plot_arc_json") if session else None)
    if not arc:
        facts_block = facts_to_prompt((session or {}).get("facts_json", "[]"))
        arc = await generate_plot_arc(
            persona.get("name", req.persona_id),
            persona.get("description", ""),
            persona.get("scenario", ""),
            persona.get("first_mes", ""),
            facts_block=facts_block,
            genre=req.genre,
        )
        if arc:
            await update_session_plot_arc(req.session_id, json.dumps(arc, ensure_ascii=False))
            await _seed_quests_from_arc(req.session_id, arc)

    quests = await get_quests(req.session_id)
    return {"plot_arc": arc, "quests": quests}


# --------------------------------------------------------------------------- streaming chat


async def _apply_rpg_context(request, session: dict, persona_id: str, history: list,
                             system_prompt: str, rpg_settings: dict) -> tuple:
    """Extend *system_prompt* with RPG state and run the pre-reply narrator.

    The prompt gains facts, plot-arc summary, status quo and affinity blocks,
    then narrator directives and (if a check was rolled) the d20 mechanics.

    Side effects:
      * may persist a narrator status-quo update, also mirrored into *session*;
      * may record an action resolution for a rolled check.

    Returns ``(system_prompt, arc, narrator_msg)``. ``narrator_msg`` is the
    ``{"roll", "outcome", "text"}`` bubble to emit before the reply, or ``None``.
    """
    facts_block = facts_to_prompt(session.get("facts_json", "[]"))
    if facts_block:
        system_prompt += "\n\n" + facts_block

    arc = _parse_plot_arc(session.get("plot_arc_json"))
    if arc:
        arc_summary = {k: arc.get(k) for k in ("title", "phase", "next_beat_hint")}
        system_prompt += "\n\n--- PlotArc ---\n" + json.dumps(arc_summary, ensure_ascii=False) + "\n---"

    status_quo = (session.get("status_quo") or "").strip()
    if status_quo:
        system_prompt += "\n\n--- Status quo ---\n" + status_quo + "\n---"

    if rpg_settings.get("affinity", True):
        system_prompt += affinity_prompt_block(_safe_int(session.get("affinity")))

    roll = None
    outcome = None
    narrator_msg = None
    if rpg_settings.get("narrator", True):
        persona = await get_persona(persona_id) or {}
        persona_name = persona.get("name", persona_id)
        recent_txt = "\n".join(
            f"{m['role']}: {m['content']}"
            for m in history[-8:]
            if m.get("role") in ("user", "assistant")
        )
        arc_txt = json.dumps(arc, ensure_ascii=False) if arc else ""

        # Phase 1: ask the narrator whether the action needs a check (no roll yet).
        verdict = await narrator_pre(persona_name, recent_txt, arc_txt, facts_block, request.message) or {}
        resolution_text = ""
        if verdict.get("needs_check", False) and rpg_settings.get("dice", True):
            # Phase 2: roll, then ask again so the narrator resolves the outcome.
            roll = random.randint(1, 20)
            outcome = _roll_outcome(roll)
            verdict = await narrator_pre(
                persona_name, recent_txt, arc_txt, facts_block, request.message,
                roll=roll, outcome=outcome,
            ) or {}
            resolution_text = _clean_text(verdict.get("resolution_text"))

        directives = _as_list(verdict.get("directives"))
        if directives:
            system_prompt += (
                "\n\n--- Narrator directives ---\n"
                + "\n".join(f"- {d}" for d in directives)
                + "\n---"
            )

        pre_sq = _clean_text(verdict.get("status_quo_update"))
        if pre_sq:
            await update_session_status_quo(request.session_id, pre_sq)
            session["status_quo"] = pre_sq

        if resolution_text:
            await add_action_resolution(
                request.session_id,
                intent_text=request.message,
                roll=roll,
                outcome=outcome,
                resolution_text=resolution_text,
                message_id=None,
            )
            narrator_msg = {"roll": roll, "outcome": outcome, "text": resolution_text}

    # Pin the rolled outcome into the prompt so the character's reply cannot contradict it.
    if roll is not None:
        system_prompt += (
            "\n\n--- Mechanics ---\n"
            + f"Roll d20={roll}. Outcome: {outcome}.\n"
            + "Your reply MUST be consistent with this outcome. Do NOT contradict the narrator resolution.\n"
            + "---"
        )
    return system_prompt, arc, narrator_msg


async def _finish_rpg_turn(request, session: dict, persona_id: str, arc, rpg_settings: dict,
                           choices: list, debug_blocks: list) -> list:
    """Run the post-reply RPG bookkeeping and return the session's quests.

    Steps, in order:
      1. generate an arc if missing;
      2. advance beats and phase;
      3. extract facts;
      4. apply the narrator's post-turn updates (status quo, choices, affinity,
         outfit, quests).

    Player choices are appended to *choices* and debug info to *debug_blocks*
    in place, so anything gathered before a failure is still delivered.
    """
    sid = request.session_id
    quests_enabled = rpg_settings.get("quests", True)
    choices_enabled = rpg_settings.get("choices", True)

    if not arc:
        persona = await get_persona(persona_id) or {}
        arc = await generate_plot_arc(
            persona.get("name", persona_id),
            persona.get("description", ""),
            persona.get("scenario", ""),
            persona.get("first_mes", ""),
            facts_block=facts_to_prompt(session.get("facts_json", "[]")),
            genre=session.get("genre") or "adventure",
        )
        if arc:
            await update_session_plot_arc(sid, json.dumps(arc, ensure_ascii=False))
            debug_blocks.append({"type": "plot_arc", "text": json.dumps(arc, ensure_ascii=False, indent=2)})
            if quests_enabled:
                await _seed_quests_from_arc(sid, arc)

    trig = should_advance_arc(request.message)
    if trig and arc:
        arc, beats = pop_matching_beats(arc, trig, max_beats=1)
        if beats:
            await update_session_plot_arc(sid, json.dumps(arc, ensure_ascii=False))
            beat = beats[0]
            inj = beat.get("injection", "")
            if inj:
                debug_blocks.append({"type": "narrator_injection", "text": inj})
            if choices_enabled:
                choices.extend(_as_list(beat.get("choices")))
    if arc and advance_phase(arc):
        await update_session_plot_arc(sid, json.dumps(arc, ensure_ascii=False))
        debug_blocks.append({"type": "phase_advance", "text": arc.get("phase", "")})

    ctx = [m for m in (await get_history(sid)) if m["role"] in ("user", "assistant")][-10:]
    new_facts = await extract_facts(ctx)
    if new_facts:
        merged = merge_facts(session.get("facts_json", "[]"), new_facts)
        await update_session_facts(sid, merged)
        session["facts_json"] = merged

    persona = await get_persona(persona_id) or {}
    ctx_txt = "\n".join(f"{m['role']}: {m['content']}" for m in ctx[-8:])
    post = await narrator_post(
        persona.get("name", persona_id),
        ctx_txt,
        json.dumps(arc, ensure_ascii=False) if arc else "",
        facts_to_prompt(session.get("facts_json", "[]")),
    ) or {}

    sq = _clean_text(post.get("status_quo_update"))
    if sq:
        await update_session_status_quo(sid, sq)
        debug_blocks.append({"type": "status_quo", "text": sq})

    if choices_enabled:
        choices.extend(_as_list(post.get("choices")))

    if rpg_settings.get("affinity", True):
        delta = _safe_int(post.get("affinity_delta"))
        if delta:
            await update_session_affinity(sid, delta)

    outfit_update = post.get("outfit_update")
    if isinstance(outfit_update, list) and outfit_update:
        outfit_str = json.dumps(outfit_update, ensure_ascii=False)
        await update_session_outfit(sid, outfit_str)
        session["outfit_json"] = outfit_str

    if quests_enabled:
        for update in _as_list(post.get("quest_updates")):
            if not isinstance(update, dict):
                continue
            title = _clean_text(update.get("title"))
            if title:
                await upsert_quest(sid, title[:_QUEST_TITLE_MAX], update.get("status") or "active")

    return await get_quests(sid)


@router.post("/stream")
async def chat_stream(request: ChatRequest):
    """Stream an assistant reply as Server-Sent Events.

    Frames, in order:
      * an optional ``{"narrator": ...}`` bubble;
      * ``{"chunk": ...}`` frames;
      * an optional ``{"image_generating": ...}``;
      * a final ``{"done": True, ...}`` carrying the image, choices, debug
        blocks, affinity and quests.
    If the LLM stream fails, a single ``{"error": ...}`` frame ends the stream
    instead. With ``skip_user_add`` the user turn is assumed to be already stored.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    history = await get_history(request.session_id)
    session = await get_session(request.session_id)
    system_prompt = await get_system_prompt(persona_id, history, request.message)

    rpg_enabled = bool(session and session.get("rpg_enabled"))
    rpg_settings = {}
    arc = {}
    narrator_msg = None  # shown as a narrator bubble before the assistant reply
    if rpg_enabled:
        rpg_settings = get_rpg_settings(session)
        system_prompt, arc, narrator_msg = await _apply_rpg_context(
            request, session, persona_id, history, system_prompt, rpg_settings
        )

    # Narrator choices are wrapped so the LLM understands the message is a picked option.
    if request.is_narrator_choice:
        user_message_content = f"{_CHOICE_PREFIX}{request.message}]"
    else:
        user_message_content = request.message

    await _sync_system_prompt(request.session_id, history, system_prompt)
    if not request.skip_user_add:
        await add_message(request.session_id, "user", user_message_content)

    messages = await get_history(request.session_id)
    llm_messages = [{"role": m["role"], "content": m["content"]} for m in messages]

    async def generate():
        # Send the narrator bubble BEFORE streaming so it appears above the reply.
        if narrator_msg:
            yield _sse({"narrator": narrator_msg})

        full_reply = []
        try:
            async for chunk in stream_message(llm_messages):
                full_reply.append(chunk)
                yield _sse({"chunk": chunk})
        except Exception as e:
            logger.exception("stream_message failed: %s", e)
            yield _sse({"error": str(e)})
            return

        complete = "".join(full_reply)
        display_text = strip_image_prompt_tag(complete)
        hist_with_reply = await get_history(request.session_id) + [
            {"role": "assistant", "content": display_text}
        ]
        sd_result = await generate_sd_prompt(
            hist_with_reply, persona_id, outfit_json=session.get("outfit_json", "[]") if session else "[]"
        )
        prompt_str = _resolve_image_prompt(sd_result, complete)
        stored_reply = display_text or complete
        if stored_reply.strip():
            await add_message(request.session_id, "assistant", stored_reply, image_prompt=prompt_str)

        choices = []
        debug_blocks = []
        quests_updated = []
        if rpg_enabled:
            # The reply is already streamed and stored; RPG bookkeeping is best-effort,
            # so a failure must not prevent the final "done" frame.
            try:
                quests_updated = await _finish_rpg_turn(
                    request, session, persona_id, arc, rpg_settings, choices, debug_blocks
                )
            except Exception:
                logger.exception("RPG post-processing failed for session %s", request.session_id)

        count = await get_message_count(request.session_id)
        if count == 2 and not request.skip_user_add:
            current_title = (session or {}).get("title") or ""
            if current_title in ("", _DEFAULT_SESSION_TITLE):
                persona = await get_persona(persona_id) or {}
                preview = request.message[:40] + ("…" if len(request.message) > 40 else "")
                await update_session_title(request.session_id, f"{persona.get('name', persona_id)} — {preview}")

        image_path = None
        image_error = None
        if prompt_str and SD_AUTO_GENERATE:
            yield _sse({"image_generating": True, "image_prompt": prompt_str})
            rel, err = await sd_service.generate_from_full_prompt(prompt_str)
            if rel:
                image_path = rel
                msg_id = await get_last_assistant_message_id(request.session_id)
                if msg_id:
                    await update_message_image(msg_id, rel)
            else:
                image_error = err

        updated_session = await get_session(request.session_id)
        affinity = updated_session.get("affinity", 0) if updated_session else 0
        yield _sse({
            "done": True,
            "image_prompt": prompt_str,
            "image_path": f"/static/{image_path}" if image_path else None,
            "image_error": image_error,
            "choices": choices,
            "debug": debug_blocks,
            "affinity": affinity,
            "quests": quests_updated,
        })

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


# --------------------------------------------------------------------------- non-streaming chat & edits


@router.post("/", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Send one message and return the full reply (non-streaming ``/stream`` counterpart).

    As in ``/stream``, the image prompt is derived from the history including
    the new reply and the session outfit. If the generator gives nothing, it
    falls back to the prompt tag embedded in the reply.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    history = await get_history(request.session_id)
    system_prompt = await get_system_prompt(persona_id, history, request.message)
    if not history:
        await add_message(request.session_id, "system", system_prompt)
    await add_message(request.session_id, "user", request.message)
    messages = await get_history(request.session_id)
    reply = await send_message(
        [{"role": m["role"], "content": m["content"]} for m in messages]
    )
    display = strip_image_prompt_tag(reply)
    session = await get_session(request.session_id)
    sd_result = await generate_sd_prompt(
        messages + [{"role": "assistant", "content": display}],
        persona_id,
        outfit_json=session.get("outfit_json", "[]") if session else "[]",
    )
    prompt_str = _resolve_image_prompt(sd_result, reply)
    await add_message(request.session_id, "assistant", display, image_prompt=prompt_str)
    return ChatResponse(
        reply=display,
        session_id=request.session_id,
        image_prompt=prompt_str,
    )


@router.patch("/messages/{message_id}")
async def edit_message(message_id: int, req: MessageEditRequest):
    """Replace a message's content; with ``truncate_after``, also drop later messages.

    Raises:
        HTTPException: 404 if the message does not exist.
    """
    msg = await get_message(message_id)
    if not msg:
        raise HTTPException(status_code=404, detail="Сообщение не найдено")
    await update_message_content(message_id, req.content)
    if req.truncate_after:
        await delete_messages_after(msg["session_id"], message_id)
    return {"status": "updated", "message_id": message_id}


@router.post("/regenerate")
async def regenerate_chat(req: RegenerateRequest):
    """Delete an assistant reply and stream a new one for the latest user message.

    Targets ``req.message_id``, or the session's last assistant message when it
    is not given.

    Raises:
        HTTPException: 400 in these cases, all checked before anything is deleted:
          * there is nothing to regenerate;
          * the target is not an assistant message of this session;
          * the session has no user message.
    """
    msg_id = req.message_id or await get_last_assistant_message_id(req.session_id)
    if not msg_id:
        raise HTTPException(status_code=400, detail="Нет сообщения для перегенерации")
    msg = await get_message(msg_id)
    # Reject messages from other sessions so a foreign message_id cannot be deleted here.
    if (
        not msg
        or msg.get("role") != "assistant"
        or str(msg.get("session_id")) != str(req.session_id)
    ):
        raise HTTPException(status_code=400, detail="Неверное сообщение")

    # Deleting an assistant message does not change the user turns, so look the
    # user message up first and only delete once regeneration can proceed.
    history = await get_history(req.session_id)
    last_user = next((m for m in reversed(history) if m["role"] == "user"), None)
    if not last_user:
        raise HTTPException(status_code=400, detail="Нет сообщения пользователя")
    await delete_message(msg_id)

    user_text = last_user["content"]
    if user_text.startswith(_CHOICE_PREFIX) and user_text.endswith("]"):
        user_text = user_text[len(_CHOICE_PREFIX):-1]
    stream_req = ChatRequest(
        message=user_text,
        session_id=req.session_id,
        persona_id=req.persona_id,
        skip_user_add=True,
    )
    return await chat_stream(stream_req)


@router.delete("/{session_id}")
async def clear_chat(session_id: str):
    """Delete the stored history of *session_id*."""
    await clear_history(session_id)
    return {"status": "cleared", "session_id": session_id}