# Files
# ChatAIBot/routers/chat.py
# T
# 2026-06-01 07:44:38 +03:00
#
# 540 lines
# 22 KiB
# Python

import json
import logging
import os
import random
import aiosqlite
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from database.db import DB_PATH
from models.schemas import ChatRequest, ChatResponse, MessageEditRequest, RegenerateRequest
from services.llm import send_message, stream_message
from services.memory import (
get_history,
add_message,
clear_history,
get_or_create_session,
get_session,
update_session_title,
update_session_persona,
get_message_count,
get_last_assistant_message_id,
update_message_image,
update_session_facts,
update_session_status_quo,
update_session_affinity,
update_session_genre,
update_session_rpg_settings,
update_session_outfit,
update_session_plot_arc,
upsert_quest,
get_quests,
add_action_resolution,
get_message,
update_message_content,
delete_messages_after,
delete_message,
)
from services.personas import get_persona
from services.sd_prompt import generate_sd_prompt, strip_image_prompt_tag, extract_image_prompt_tag
from services.lorebook import get_lorebook_context
from services.character_card import get_character
from services import sdbackend as sd_service
from services.rpg_facts import extract_facts, merge_facts, facts_to_prompt
from services.rpg_plot import generate_plot_arc, should_advance_arc, pop_matching_beats, advance_phase
from services.rpg_narrator import narrator_pre, narrator_post
logger = logging.getLogger(__name__)
router = APIRouter(tags=["chat"], prefix="/chat")

# System prompt used when the requested persona cannot be found.
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."

# Opt-in switch: when the SD_AUTO_GENERATE env var is "1", "true" or "yes"
# (case-insensitive), /chat/stream renders an image for every reply that
# produced an image prompt.
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in {"1", "true", "yes"}
# Feature toggles for RPG mode; every feature is enabled unless a session
# overrides it via the JSON stored under its "rpg_settings_json" key.
DEFAULT_RPG_SETTINGS = {"dice": True, "narrator": True, "quests": True, "affinity": True, "choices": True}


def get_rpg_settings(session: dict) -> dict:
    """Return the effective RPG feature toggles for *session*.

    Per-session overrides decoded from ``session["rpg_settings_json"]`` are
    layered on top of ``DEFAULT_RPG_SETTINGS``. A missing value, malformed
    JSON, or a JSON payload that is not an object falls back to the defaults.

    The result is always a fresh dict, so callers may mutate it without
    altering ``DEFAULT_RPG_SETTINGS``.
    """
    try:
        overrides = json.loads(session.get("rpg_settings_json") or "{}")
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
        overrides = None
    if not isinstance(overrides, dict):
        # Copy rather than return the module-level dict: handing out the shared
        # default would let one caller's mutation leak into every session.
        return dict(DEFAULT_RPG_SETTINGS)
    return {**DEFAULT_RPG_SETTINGS, **overrides}
def affinity_prompt_block(affinity: int) -> str:
    """Render the "Relationship" system-prompt section for an affinity score.

    The score is mapped to a tone label:
    >= 10 very warm, >= 5 friendly, >= 1 slightly positive,
    <= -5 hostile, <= -1 cold, anything else neutral.
    """
    if affinity >= 1:
        if affinity >= 10:
            tone = "very warm, trusting, affectionate"
        elif affinity >= 5:
            tone = "friendly and open"
        else:
            tone = "slightly positive"
    elif affinity <= -1:
        tone = "hostile or deeply distrustful" if affinity <= -5 else "cold and wary"
    else:
        tone = "neutral"
    return f"\n\n--- Relationship ---\nAffinity toward player: {affinity} ({tone}). Reflect this in your attitude and word choice.\n---"
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
    """Build the system prompt for *persona_id*, enriched with lorebook context.

    Lorebook entries are selected against the last five user/assistant
    messages of *history* plus the incoming *user_message*. Two lorebooks are
    consulted, in order: the persona's own ``lorebook_json`` and, for
    ``card_<id>`` personas, the linked character card's ``lorebook_json``.

    Returns ``DEFAULT_PROMPT`` when the persona cannot be found.
    """
    persona = await get_persona(persona_id)
    if not persona:
        return DEFAULT_PROMPT
    prompt = persona["prompt"]
    recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
    context = recent + [{"role": "user", "content": user_message}]

    lorebooks = [persona.get("lorebook_json")]
    if persona_id.startswith("card_"):
        card = await get_character(persona_id[5:])
        if card:
            lorebooks.append(card.get("lorebook_json"))

    for lorebook_json in lorebooks:
        # Both sources are guarded the same way, so an empty/NULL lorebook is
        # never handed to get_lorebook_context.
        if not lorebook_json:
            continue
        lore = get_lorebook_context(lorebook_json, context)
        if lore:
            prompt += "\n\n" + lore
    return prompt
@router.get("/history/{session_id}")
async def get_chat_history(session_id: str):
    """Return the stored message history of *session_id* as given by get_history."""
    messages = await get_history(session_id)
    return messages
@router.get("/system/{session_id}")
async def get_system_blob(session_id: str):
    """Return a session's stored system prompt together with its RPG state.

    ``system_prompt`` is the content of the first ``system`` message in the
    history (``""`` if there is none). Session fields are returned as stored;
    if the session does not exist, fixed defaults are returned instead.
    ``quests`` is always the session's quest list.
    """
    history = await get_history(session_id)
    system_prompt = ""
    for entry in history:
        if entry.get("role") == "system":
            system_prompt = entry.get("content")
            break

    session = await get_session(session_id)
    quests = await get_quests(session_id)

    if session:
        state = {
            "status_quo": session.get("status_quo"),
            "facts_json": session.get("facts_json"),
            "plot_arc_json": session.get("plot_arc_json"),
            "outfit_json": session.get("outfit_json"),
            "affinity": session.get("affinity", 0),
            "genre": session.get("genre", ""),
            "rpg_settings_json": session.get("rpg_settings_json"),
            "rpg_enabled": bool(session.get("rpg_enabled")),
        }
    else:
        state = {
            "status_quo": "",
            "facts_json": "[]",
            "plot_arc_json": "{}",
            "outfit_json": "[]",
            "affinity": 0,
            "genre": "",
            "rpg_settings_json": "{}",
            "rpg_enabled": False,
        }
    return {"system_prompt": system_prompt, **state, "quests": quests}
@router.post("/init")
async def init_chat(request: ChatRequest):
    """Initialise an empty chat session.

    If the session already has messages nothing is written and
    ``{"first_mes": None}`` is returned. Otherwise the system prompt is stored,
    then an opening assistant message picked from the first available of:
    the request's non-blank ``first_mes_override`` (stripped), the persona's
    ``first_mes``, or, for ``card_<id>`` personas, the character card's
    ``first_mes``. The chosen greeting (or ``None``) is returned.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    existing = await get_history(request.session_id)
    if existing:
        return {"first_mes": None}

    prompt = await get_system_prompt(persona_id, [], "")
    await add_message(request.session_id, "system", prompt)

    greeting = None
    override = request.first_mes_override
    if override and override.strip():
        greeting = override.strip()
    else:
        persona = await get_persona(persona_id)
        if persona and persona.get("first_mes"):
            greeting = persona["first_mes"]
        elif persona_id.startswith("card_"):
            card = await get_character(persona_id[5:])
            if card and card.get("first_mes"):
                greeting = card["first_mes"]

    # Every branch above only assigns a non-empty greeting.
    if greeting:
        await add_message(request.session_id, "assistant", greeting)
    return {"first_mes": greeting}
class RpgBootstrapRequest(BaseModel):
    # Request body for POST /chat/rpg/bootstrap.
    session_id: str
    persona_id: str = "default"  # persona whose profile seeds the generated plot arc
    genre: str = "adventure"  # saved on the session and passed to generate_plot_arc
@router.post("/rpg/bootstrap")
async def rpg_bootstrap(req: RpgBootstrapRequest):
    """Prepare a session for RPG mode.

    Ensures the session exists and stores ``req.genre`` on it. If the session
    has no usable plot arc yet, one is generated from the persona profile and
    known facts, persisted, and one quest is seeded per beat title.

    Returns ``{"plot_arc": <arc dict, {} if generation failed>,
    "quests": <session quest list>}``.
    """
    await get_or_create_session(req.session_id, req.persona_id)
    session = await get_session(req.session_id) or {}
    persona = await get_persona(req.persona_id) or {}
    # Save genre
    await update_session_genre(req.session_id, req.genre)

    raw_arc = session.get("plot_arc_json") or "{}"
    try:
        arc = json.loads(raw_arc) if isinstance(raw_arc, str) else {}
    except ValueError:  # malformed JSON -> treat as "no arc yet"
        arc = {}
    if not isinstance(arc, dict):
        # A non-object payload (e.g. a JSON list) would break arc.get() below.
        arc = {}

    if not arc:
        arc = await generate_plot_arc(
            persona.get("name", req.persona_id),
            persona.get("description", ""),
            persona.get("scenario", ""),
            persona.get("first_mes", ""),
            facts_block=facts_to_prompt(session.get("facts_json", "[]")),
            genre=req.genre,
        ) or {}
        if arc:
            await update_session_plot_arc(req.session_id, json.dumps(arc, ensure_ascii=False))
            # Seed quests from the beats of the freshly generated arc only, so
            # re-running bootstrap on an existing arc leaves quest state alone.
            for beat in arc.get("beats") or []:
                if not isinstance(beat, dict):
                    continue
                title = (beat.get("title") or beat.get("injection") or "").strip()
                if title:
                    await upsert_quest(req.session_id, title[:120])

    quests = await get_quests(req.session_id)
    return {"plot_arc": arc, "quests": quests}
@router.post("/stream")
async def chat_stream(request: ChatRequest):
    """Stream an assistant reply to ``request.message`` as Server-Sent Events.

    Before streaming, the system prompt is rebuilt (persona + lorebook). For
    RPG-enabled sessions it is then extended with facts, the plot-arc summary,
    the status quo, affinity, and narrator output, including an optional d20
    check. If the stored leading system message differs, it is overwritten. The
    user message is persisted unless ``request.skip_user_add`` is set.

    Emitted events (each ``data: <json>``), in order:
      * ``{"narrator": {...}}``: only when a narrator resolution was produced;
      * ``{"chunk": str}``: one per streamed LLM chunk;
      * ``{"error": str}``: if streaming fails, after which the stream ends;
      * ``{"image_generating": True, "image_prompt": str}``: only when
        ``SD_AUTO_GENERATE`` is on and an image prompt exists;
      * ``{"done": True, ...}``: image info, choices, debug blocks, current
        affinity and quests.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    history = await get_history(request.session_id)
    session = await get_session(request.session_id)
    system_prompt = await get_system_prompt(persona_id, history, request.message)
    arc = {}
    roll = None
    outcome = None
    resolution_text = ""
    narrator_msg = None  # shown as narrator bubble before assistant reply
    rpg_settings = {}
    if session and session.get("rpg_enabled"):
        rpg_settings = get_rpg_settings(session)
        facts_block = facts_to_prompt(session.get("facts_json", "[]"))
        if facts_block:
            system_prompt = system_prompt + "\n\n" + facts_block
        try:
            arc = json.loads(session.get("plot_arc_json") or "{}")
        except (TypeError, ValueError):
            arc = {}
        if not isinstance(arc, dict):
            # A non-object payload (e.g. a JSON list) would break arc.get() below.
            arc = {}
        if arc:
            arc_summary = {k: arc.get(k) for k in ("title", "phase", "next_beat_hint")}
            system_prompt = (
                system_prompt
                + "\n\n--- PlotArc ---\n"
                + json.dumps(arc_summary, ensure_ascii=False)
                + "\n---"
            )
        status_quo = (session.get("status_quo") or "").strip()
        if status_quo:
            system_prompt = system_prompt + "\n\n--- Status quo ---\n" + status_quo + "\n---"
        if rpg_settings.get("affinity", True):
            aff = int(session.get("affinity") or 0)
            system_prompt = system_prompt + affinity_prompt_block(aff)
        if rpg_settings.get("narrator", True):
            persona = await get_persona(persona_id) or {}
            recent_txt = "\n".join(
                f"{m['role']}: {m['content']}"
                for m in history[-8:]
                if m.get("role") in ("user", "assistant")
            )
            arc_txt = json.dumps(arc, ensure_ascii=False) if arc else ""
            # Phase 1: ask the narrator whether this action needs a check (no roll yet).
            pre = await narrator_pre(
                persona.get("name", persona_id),
                recent_txt,
                arc_txt,
                facts_block,
                request.message,
            )
            needs_check = pre.get("needs_check", False) and rpg_settings.get("dice", True)
            if needs_check:
                # Phase 2: roll a d20 and ask the narrator to resolve the action.
                roll = random.randint(1, 20)
                if roll == 1:
                    outcome = "critical failure"
                elif roll <= 8:
                    outcome = "failure"
                elif roll >= 20:
                    outcome = "critical success"
                else:
                    outcome = "success"
                pre2 = await narrator_pre(
                    persona.get("name", persona_id),
                    recent_txt,
                    arc_txt,
                    facts_block,
                    request.message,
                    roll=roll,
                    outcome=outcome,
                )
                resolution_text = (pre2.get("resolution_text") or "").strip()
                directives = pre2.get("directives") or []
                pre_sq = (pre2.get("status_quo_update") or "").strip()
            else:
                directives = pre.get("directives") or []
                pre_sq = (pre.get("status_quo_update") or "").strip()
            if directives:
                system_prompt = (
                    system_prompt
                    + "\n\n--- Narrator directives ---\n"
                    + "\n".join(f"- {d}" for d in directives)
                    + "\n---"
                )
            if pre_sq:
                await update_session_status_quo(request.session_id, pre_sq)
                session["status_quo"] = pre_sq
            if resolution_text:
                await add_action_resolution(
                    request.session_id,
                    intent_text=request.message,
                    roll=roll,
                    outcome=outcome,
                    resolution_text=resolution_text,
                    message_id=None,
                )
                narrator_msg = {"roll": roll, "outcome": outcome, "text": resolution_text}
            # Inject the outcome into the system prompt so the character reply is consistent.
            if roll is not None:
                system_prompt = (
                    system_prompt
                    + "\n\n--- Mechanics ---\n"
                    + f"Roll d20={roll}. Outcome: {outcome}.\n"
                    + "Your reply MUST be consistent with this outcome. Do NOT contradict the narrator resolution.\n"
                    + "---"
                )

    # A pick from the narrator's choice list is wrapped so the LLM understands its origin.
    user_message_content = request.message
    if request.is_narrator_choice:
        user_message_content = f"[Player chose: {request.message}]"

    if not history:
        await add_message(request.session_id, "system", system_prompt)
    elif history[0]["role"] == "system" and history[0]["content"] != system_prompt:
        # Keep the session's leading system message in sync with the rebuilt prompt.
        async with aiosqlite.connect(DB_PATH) as db:
            await db.execute(
                """UPDATE messages SET content = ?
                WHERE session_id = ? AND role = 'system'
                AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
                (system_prompt, request.session_id, request.session_id),
            )
            await db.commit()
    if not request.skip_user_add:
        await add_message(request.session_id, "user", user_message_content)
    messages = await get_history(request.session_id)
    full_reply = []

    async def generate():
        nonlocal arc
        # Send the narrator bubble BEFORE streaming so it appears above the reply.
        if narrator_msg:
            yield f"data: {json.dumps({'narrator': narrator_msg})}\n\n"
        try:
            async for chunk in stream_message(
                [{"role": m["role"], "content": m["content"]} for m in messages]
            ):
                full_reply.append(chunk)
                yield f"data: {json.dumps({'chunk': chunk})}\n\n"
        except Exception as e:
            # Top-level boundary for the LLM stream: log with traceback, report to client.
            logger.exception("stream_message failed: %s", e)
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
            return

        complete = "".join(full_reply)
        display_text = strip_image_prompt_tag(complete)
        hist_with_reply = await get_history(request.session_id) + [
            {"role": "assistant", "content": display_text}
        ]
        sd_result = await generate_sd_prompt(
            hist_with_reply,
            persona_id,
            outfit_json=session.get("outfit_json", "[]") if session else "[]",
        )
        # Prefer the SD generator's prompt; fall back to a tag embedded in the reply.
        prompt_str = (sd_result[0] if sd_result and sd_result[0] else None) or extract_image_prompt_tag(complete)
        if (display_text or complete).strip():
            await add_message(request.session_id, "assistant", display_text or complete, image_prompt=prompt_str)

        choices = []
        debug_blocks = []
        quests_updated = []
        if session and session.get("rpg_enabled"):
            if not arc:
                persona = await get_persona(persona_id) or {}
                arc = await generate_plot_arc(
                    persona.get("name", persona_id),
                    persona.get("description", ""),
                    persona.get("scenario", ""),
                    persona.get("first_mes", ""),
                    facts_block=facts_to_prompt(session.get("facts_json", "[]")),
                    genre=session.get("genre") or "adventure",
                )
                if arc:
                    await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
                    debug_blocks.append({"type": "plot_arc", "text": json.dumps(arc, ensure_ascii=False, indent=2)})
                    if rpg_settings.get("quests", True):
                        for beat in arc.get("beats", []):
                            t = (beat.get("title") or beat.get("injection", "")).strip()
                            if t:
                                await upsert_quest(request.session_id, t[:120])
            trig = should_advance_arc(request.message)
            if trig and arc:
                arc, beats = pop_matching_beats(arc, trig, max_beats=1)
                if beats:
                    await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
                    inj = beats[0].get("injection", "")
                    if inj:
                        debug_blocks.append({"type": "narrator_injection", "text": inj})
                    if rpg_settings.get("choices", True):
                        choices += beats[0].get("choices") or []
                if advance_phase(arc):
                    await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
                    debug_blocks.append({"type": "phase_advance", "text": arc["phase"]})
            ctx = [m for m in (await get_history(request.session_id)) if m["role"] in ("user", "assistant")][-10:]
            new_facts = await extract_facts(ctx)
            if new_facts:
                merged = merge_facts(session.get("facts_json", "[]"), new_facts)
                await update_session_facts(request.session_id, merged)
                session["facts_json"] = merged
            persona = await get_persona(persona_id) or {}
            # ctx is already limited to user/assistant turns.
            ctx_txt = "\n".join(f"{m['role']}: {m['content']}" for m in ctx[-8:])
            post = await narrator_post(
                persona.get("name", persona_id),
                ctx_txt,
                json.dumps(arc, ensure_ascii=False) if arc else "",
                facts_to_prompt(session.get("facts_json", "[]")),
            )
            sq = (post.get("status_quo_update") or "").strip()
            if sq:
                await update_session_status_quo(request.session_id, sq)
                debug_blocks.append({"type": "status_quo", "text": sq})
            if rpg_settings.get("choices", True):
                choices += post.get("choices") or []
            if rpg_settings.get("affinity", True):
                delta = int(post.get("affinity_delta") or 0)
                if delta:
                    await update_session_affinity(request.session_id, delta)
            outfit_update = post.get("outfit_update")
            if isinstance(outfit_update, list) and outfit_update:
                outfit_str = json.dumps(outfit_update, ensure_ascii=False)
                await update_session_outfit(request.session_id, outfit_str)
                session["outfit_json"] = outfit_str
            if rpg_settings.get("quests", True):
                for qu in (post.get("quest_updates") or []):
                    t = (qu.get("title") or "").strip()
                    if t:
                        await upsert_quest(request.session_id, t[:120], qu.get("status", "active"))
                quests_updated = await get_quests(request.session_id)

        count = await get_message_count(request.session_id)
        if count == 2 and not request.skip_user_add:
            persona = await get_persona(persona_id) or {}
            # Truncate long first messages and mark the cut with an ellipsis
            # (previously both branches appended "", so truncation was invisible).
            preview = request.message[:40] + ("…" if len(request.message) > 40 else "")
            if (session or {}).get("title", "Новый чат") in ("", "Новый чат"):
                await update_session_title(request.session_id, f"{persona.get('name', persona_id)} — {preview}")

        image_path = None
        image_error = None
        if prompt_str and SD_AUTO_GENERATE:
            yield f"data: {json.dumps({'image_generating': True, 'image_prompt': prompt_str})}\n\n"
            rel, err = await sd_service.generate_from_full_prompt(prompt_str)
            if rel:
                image_path = rel
                msg_id = await get_last_assistant_message_id(request.session_id)
                if msg_id:
                    await update_message_image(msg_id, rel)
            else:
                image_error = err
        updated_session = await get_session(request.session_id)
        affinity = updated_session.get("affinity", 0) if updated_session else 0
        done_event = {
            "done": True,
            "image_prompt": prompt_str,
            "image_path": f"/static/{image_path}" if image_path else None,
            "image_error": image_error,
            "choices": choices,
            "debug": debug_blocks,
            "affinity": affinity,
            "quests": quests_updated,
        }
        yield f"data: {json.dumps(done_event)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
@router.post("/", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Non-streaming chat turn: store the user message, get the full reply, store it.

    No RPG mechanics are applied here. The image prompt is derived the same
    way as in ``/chat/stream``, so both endpoints agree:
      * the SD prompt generator sees the history *including* the new reply and
        the session's outfit;
      * if it yields no prompt, an image-prompt tag embedded in the reply is
        used instead.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    history = await get_history(request.session_id)
    system_prompt = await get_system_prompt(persona_id, history, request.message)
    if not history:
        await add_message(request.session_id, "system", system_prompt)
    await add_message(request.session_id, "user", request.message)
    messages = await get_history(request.session_id)
    reply = await send_message(
        [{"role": m["role"], "content": m["content"]} for m in messages]
    )
    display = strip_image_prompt_tag(reply)
    session = await get_session(request.session_id)
    sd_result = await generate_sd_prompt(
        messages + [{"role": "assistant", "content": display}],
        persona_id,
        outfit_json=session.get("outfit_json", "[]") if session else "[]",
    )
    # Previously a (None, ...) / ("", ...) result skipped the tag fallback entirely.
    prompt_str = (sd_result[0] if sd_result and sd_result[0] else None) or extract_image_prompt_tag(reply)
    await add_message(request.session_id, "assistant", display, image_prompt=prompt_str)
    return ChatResponse(
        reply=display,
        session_id=request.session_id,
        image_prompt=prompt_str,
    )
@router.patch("/messages/{message_id}")
async def edit_message(message_id: int, req: MessageEditRequest):
    """Overwrite a message's content, optionally dropping every later message.

    Raises HTTP 404 when *message_id* does not exist.
    """
    found = await get_message(message_id)
    if not found:
        raise HTTPException(status_code=404, detail="Сообщение не найдено")

    await update_message_content(message_id, req.content)
    if req.truncate_after:
        # Messages after the edited one belong to the same session.
        await delete_messages_after(found["session_id"], message_id)

    return dict(status="updated", message_id=message_id)
@router.post("/regenerate")
async def regenerate_chat(req: RegenerateRequest):
    """Replace an assistant reply with a freshly streamed one.

    The target is ``req.message_id`` or, if unset, the session's last assistant
    message. A new reply is streamed (via ``chat_stream``) for the session's
    most recent user message. A ``[Player chose: ...]`` wrapper is removed
    first, and that user message is not stored again (``skip_user_add=True``).

    Raises:
        HTTPException(400): there is no target message; the target is not an
            assistant message of ``req.session_id``; or the session has no user
            message to answer. All checks run before anything is deleted.
    """
    msg_id = req.message_id or await get_last_assistant_message_id(req.session_id)
    if not msg_id:
        raise HTTPException(status_code=400, detail="Нет сообщения для перегенерации")
    msg = await get_message(msg_id)
    # The ownership check stops a request from deleting another session's message by id.
    if not msg or msg.get("role") != "assistant" or msg.get("session_id") != req.session_id:
        raise HTTPException(status_code=400, detail="Неверное сообщение")

    # Validate before deleting so a failed request does not lose the reply.
    # The target is an assistant message, so removing it cannot change which
    # user message is last.
    history = await get_history(req.session_id)
    last_user = next((m for m in reversed(history) if m["role"] == "user"), None)
    if not last_user:
        raise HTTPException(status_code=400, detail="Нет сообщения пользователя")
    await delete_message(msg_id)

    user_text = last_user["content"]
    prefix = "[Player chose: "
    if user_text.startswith(prefix) and user_text.endswith("]"):
        user_text = user_text[len(prefix):-1]
    stream_req = ChatRequest(
        message=user_text,
        session_id=req.session_id,
        persona_id=req.persona_id,
        skip_user_add=True,
    )
    return await chat_stream(stream_req)
@router.delete("/{session_id}")
async def clear_chat(session_id: str):
    """Clear the stored history of *session_id* via clear_history and confirm."""
    await clear_history(session_id)
    return dict(status="cleared", session_id=session_id)