540 lines
22 KiB
Python
540 lines
22 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
|
|
import aiosqlite
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from database.db import DB_PATH
|
|
from models.schemas import ChatRequest, ChatResponse, MessageEditRequest, RegenerateRequest
|
|
from services.llm import send_message, stream_message
|
|
from services.memory import (
|
|
get_history,
|
|
add_message,
|
|
clear_history,
|
|
get_or_create_session,
|
|
get_session,
|
|
update_session_title,
|
|
update_session_persona,
|
|
get_message_count,
|
|
get_last_assistant_message_id,
|
|
update_message_image,
|
|
update_session_facts,
|
|
update_session_status_quo,
|
|
update_session_affinity,
|
|
update_session_genre,
|
|
update_session_rpg_settings,
|
|
update_session_outfit,
|
|
update_session_plot_arc,
|
|
upsert_quest,
|
|
get_quests,
|
|
add_action_resolution,
|
|
get_message,
|
|
update_message_content,
|
|
delete_messages_after,
|
|
delete_message,
|
|
)
|
|
from services.personas import get_persona
|
|
from services.sd_prompt import generate_sd_prompt, strip_image_prompt_tag, extract_image_prompt_tag
|
|
from services.lorebook import get_lorebook_context
|
|
from services.character_card import get_character
|
|
from services import sdbackend as sd_service
|
|
from services.rpg_facts import extract_facts, merge_facts, facts_to_prompt
|
|
from services.rpg_plot import generate_plot_arc, should_advance_arc, pop_matching_beats, advance_phase
|
|
from services.rpg_narrator import narrator_pre, narrator_post
|
|
|
|
logger = logging.getLogger(__name__)

# Every endpoint below is mounted under the /chat prefix and grouped under the "chat" tag.
router = APIRouter(tags=["chat"], prefix="/chat")
|
|
|
|
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."

# When truthy ("1"/"true"/"yes", case-insensitive), /stream generates an image after each reply.
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")

# RPG feature toggles used for any key a session does not override.
DEFAULT_RPG_SETTINGS = {"dice": True, "narrator": True, "quests": True, "affinity": True, "choices": True}


def get_rpg_settings(session: dict) -> dict:
    """Return the session's RPG feature toggles merged over ``DEFAULT_RPG_SETTINGS``.

    Keys stored in ``session["rpg_settings_json"]`` override the defaults.

    Args:
        session: Session row as a dict (may be ``None``).

    Returns:
        A new dict on every call. Defaults are used when the session is missing, has no
        settings, holds invalid JSON, or holds JSON that is not an object.
    """
    try:
        overrides = json.loads((session or {}).get("rpg_settings_json") or "{}")
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
        overrides = {}
    if not isinstance(overrides, dict):
        overrides = {}
    # Always build a fresh dict so callers can never mutate the module-level defaults.
    return {**DEFAULT_RPG_SETTINGS, **overrides}
|
|
|
|
|
|
def affinity_prompt_block(affinity: int) -> str:
    """Render the relationship section appended to the system prompt.

    The affinity score is mapped to a tone label: >=10, >=5 and >=1 give
    increasingly warm labels; <=-5 and <=-1 give hostile/wary labels; anything
    in between is "neutral".
    """
    warm_tiers = (
        (10, "very warm, trusting, affectionate"),
        (5, "friendly and open"),
        (1, "slightly positive"),
    )
    cold_tiers = (
        (-5, "hostile or deeply distrustful"),
        (-1, "cold and wary"),
    )

    tone = "neutral"
    for floor, label in warm_tiers:
        if affinity >= floor:
            tone = label
            break
    else:
        # Only consult the negative tiers when no positive tier matched.
        for ceiling, label in cold_tiers:
            if affinity <= ceiling:
                tone = label
                break

    return f"\n\n--- Relationship ---\nAffinity toward player: {affinity} ({tone}). Reflect this in your attitude and word choice.\n---"
|
|
|
|
|
|
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
    """Build the system prompt for a persona, enriched with matching lorebook entries.

    Args:
        persona_id: Persona identifier. Ids prefixed with ``card_`` also pull the
            lorebook of the character card whose id follows that prefix.
        history: Stored messages. The last 5 user/assistant turns are used as the
            lore-matching context.
        user_message: The incoming user message, appended to the matching context.

    Returns:
        ``DEFAULT_PROMPT`` when the persona does not exist. Otherwise the persona's
        prompt followed by any lorebook context, each separated by a blank line.
    """
    persona = await get_persona(persona_id)
    if not persona:
        return DEFAULT_PROMPT

    prompt = persona["prompt"]
    recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
    context = recent + [{"role": "user", "content": user_message}]

    # Persona lorebook first, then (for card personas) the character card's lorebook.
    lorebooks = [persona.get("lorebook_json")]
    if persona_id.startswith("card_"):
        card = await get_character(persona_id[len("card_"):])
        if card:
            lorebooks.append(card.get("lorebook_json"))

    for lorebook_json in lorebooks:
        # Skip empty/NULL lorebooks for both sources alike instead of handing None to the matcher.
        if not lorebook_json:
            continue
        lore = get_lorebook_context(lorebook_json, context)
        if lore:
            prompt += "\n\n" + lore
    return prompt
|
|
|
|
|
|
@router.get("/history/{session_id}")
async def get_chat_history(session_id: str):
    """Return the stored messages of ``session_id`` exactly as ``get_history`` yields them."""
    messages = await get_history(session_id)
    return messages
|
|
|
|
|
|
@router.get("/system/{session_id}")
async def get_system_blob(session_id: str):
    """Return the system prompt and RPG state of a session for inspection.

    The payload combines:
        - the content of the first stored ``system`` message;
        - the session's status quo, facts, plot arc, outfit, affinity, genre,
          RPG settings and RPG flag;
        - the session's quests.

    Missing sessions and NULL columns both fall back to the same defaults:
    ``""`` for text, ``"[]"``/``"{}"`` for JSON blobs, ``0`` for affinity and
    ``False`` for the RPG flag.
    """
    history = await get_history(session_id)
    system_msg = next((m for m in history if m.get("role") == "system"), None)
    # Treat a missing session like an empty row so the defaults below apply uniformly.
    session = await get_session(session_id) or {}
    quests = await get_quests(session_id)

    return {
        "system_prompt": (system_msg.get("content") or "") if system_msg else "",
        "status_quo": session.get("status_quo") or "",
        "facts_json": session.get("facts_json") or "[]",
        "plot_arc_json": session.get("plot_arc_json") or "{}",
        "outfit_json": session.get("outfit_json") or "[]",
        "affinity": session.get("affinity") or 0,
        "genre": session.get("genre") or "",
        "rpg_settings_json": session.get("rpg_settings_json") or "{}",
        "rpg_enabled": bool(session.get("rpg_enabled")),
        "quests": quests,
    }
|
|
|
|
|
|
async def _stored_first_mes(persona_id: str):
    """Return the configured greeting for ``persona_id``, or ``None`` if there is none.

    The persona's own ``first_mes`` wins. For ``card_``-prefixed ids, the character
    card's ``first_mes`` is the fallback.
    """
    persona = await get_persona(persona_id)
    if persona and persona.get("first_mes"):
        return persona["first_mes"]
    if not persona_id.startswith("card_"):
        return None
    card = await get_character(persona_id[5:])
    if card and card.get("first_mes"):
        return card["first_mes"]
    return None


@router.post("/init")
async def init_chat(request: ChatRequest):
    """Seed a brand-new chat with its system prompt and, if available, a greeting.

    Sessions that already have history are left untouched and get
    ``{"first_mes": None}``. Otherwise the greeting is chosen as follows:
        1. a non-blank ``first_mes_override`` (stripped);
        2. else the persona's ``first_mes``;
        3. else the character card's ``first_mes``.
    Any chosen greeting is stored as the first assistant message and returned.
    """
    sid = request.session_id
    pid = request.persona_id or "default"
    await get_or_create_session(sid, pid)

    if await get_history(sid):
        return {"first_mes": None}

    await add_message(sid, "system", await get_system_prompt(pid, [], ""))

    override = request.first_mes_override
    if override and override.strip():
        greeting = override.strip()
    else:
        greeting = await _stored_first_mes(pid)

    if greeting:
        await add_message(sid, "assistant", greeting)
    return {"first_mes": greeting}
|
|
|
|
|
|
class RpgBootstrapRequest(BaseModel):
    """Request body of ``POST /chat/rpg/bootstrap``."""

    # Chat session to prepare for RPG play (created if it does not exist yet).
    session_id: str
    # Persona used to create the session and, when no arc exists, to generate one.
    persona_id: str = "default"
    # Genre stored on the session and passed to plot-arc generation.
    genre: str = "adventure"
|
|
|
|
|
|
@router.post("/rpg/bootstrap")
async def rpg_bootstrap(req: RpgBootstrapRequest):
    """Prepare a session for RPG play.

    Steps:
        1. Store ``req.genre`` on the session.
        2. If the session has no usable stored plot arc, generate one and persist it.
        3. Upsert one quest per arc beat, titled by the beat's ``title`` or
           ``injection`` (truncated to 120 characters).

    Returns:
        ``{"plot_arc": dict, "quests": list}``. ``plot_arc`` is ``{}`` when no arc
        exists and none could be generated.
    """
    await get_or_create_session(req.session_id, req.persona_id)
    session = await get_session(req.session_id) or {}
    persona = await get_persona(req.persona_id) or {}

    await update_session_genre(req.session_id, req.genre)

    raw_arc = session.get("plot_arc_json") or "{}"
    try:
        arc = json.loads(raw_arc) if isinstance(raw_arc, str) else {}
    except ValueError:
        arc = {}
    if not isinstance(arc, dict):
        # Only a JSON object is a valid arc; anything else means "no arc yet".
        arc = {}

    if not arc:
        facts_block = facts_to_prompt(session.get("facts_json", "[]"))
        # generate_plot_arc may yield a falsy value; normalise so .get() below is safe.
        arc = await generate_plot_arc(
            persona.get("name", req.persona_id),
            persona.get("description", ""),
            persona.get("scenario", ""),
            persona.get("first_mes", ""),
            facts_block=facts_block,
            genre=req.genre,
        ) or {}
        if arc:
            await update_session_plot_arc(req.session_id, json.dumps(arc, ensure_ascii=False))

    # Seed quests from beats.
    for beat in arc.get("beats") or []:
        title = (beat.get("title") or beat.get("injection") or "").strip()
        if title:
            await upsert_quest(req.session_id, title[:120])

    quests = await get_quests(req.session_id)
    return {"plot_arc": arc, "quests": quests}
|
|
|
|
|
|
@router.post("/stream")
async def chat_stream(request: ChatRequest):
    """Generate the next assistant reply and stream it as Server-Sent Events.

    Pre-stream work, for RPG-enabled sessions only:
        - extend the system prompt with known facts, a plot-arc summary, the
          status quo and the affinity tone;
        - when the narrator is enabled, ask ``narrator_pre`` whether the player's
          action needs a check; if so (and dice are enabled), roll a d20, record the
          resolution and inject the outcome into the prompt.

    The leading system message is then inserted or refreshed, the user message is
    stored unless ``request.skip_user_add`` is set, and the full history is
    streamed through the LLM.

    SSE events, each a ``data: <json>`` line, in order:
        - ``{"narrator": {"roll", "outcome", "text"}}``: only when a resolution exists.
        - ``{"chunk": str}``: one per streamed piece of the reply.
        - ``{"error": str}``: if streaming fails; the stream ends right after it.
        - ``{"image_generating": True, "image_prompt": str}``: only with
          SD_AUTO_GENERATE and an image prompt.
        - ``{"done": True, ...}``: carries image info, choices, debug blocks,
          affinity and quests.
    """
    persona_id = request.persona_id or "default"

    await get_or_create_session(request.session_id, persona_id)

    history = await get_history(request.session_id)
    session = await get_session(request.session_id)
    system_prompt = await get_system_prompt(persona_id, history, request.message)

    arc = {}
    roll = None
    outcome = None
    resolution_text = ""
    narrator_msg = None  # shown as narrator bubble before assistant reply
    rpg_settings = {}

    if session and session.get("rpg_enabled"):
        rpg_settings = get_rpg_settings(session)
        facts_block = facts_to_prompt(session.get("facts_json", "[]"))
        if facts_block:
            system_prompt = system_prompt + "\n\n" + facts_block
        try:
            arc = json.loads(session.get("plot_arc_json") or "{}")
        except (TypeError, ValueError):
            arc = {}
        if not isinstance(arc, dict):
            # Only a JSON object is a valid arc; anything else would break arc.get() below.
            arc = {}
        if arc:
            arc_summary = {k: arc.get(k) for k in ("title", "phase", "next_beat_hint")}
            system_prompt = (
                system_prompt
                + "\n\n--- PlotArc ---\n"
                + json.dumps(arc_summary, ensure_ascii=False)
                + "\n---"
            )
        status_quo = (session.get("status_quo") or "").strip()
        if status_quo:
            system_prompt = system_prompt + "\n\n--- Status quo ---\n" + status_quo + "\n---"
        if rpg_settings.get("affinity", True):
            aff = int(session.get("affinity") or 0)
            system_prompt = system_prompt + affinity_prompt_block(aff)

        if rpg_settings.get("narrator", True):
            persona = await get_persona(persona_id) or {}
            persona_name = persona.get("name", persona_id)
            recent_txt = "\n".join(
                f"{m['role']}: {m['content']}"
                for m in history[-8:]
                if m.get("role") in ("user", "assistant")
            )
            arc_txt = json.dumps(arc, ensure_ascii=False) if arc else ""

            # Phase 1: ask the narrator whether a check is needed (no roll yet).
            pre = await narrator_pre(persona_name, recent_txt, arc_txt, facts_block, request.message)
            needs_check = pre.get("needs_check", False) and rpg_settings.get("dice", True)

            if needs_check:
                # Phase 2: roll a d20 and ask the narrator to resolve the action with it.
                roll = random.randint(1, 20)
                if roll == 1:
                    outcome = "critical failure"
                elif roll <= 8:
                    outcome = "failure"
                elif roll >= 20:
                    outcome = "critical success"
                else:
                    outcome = "success"

                pre2 = await narrator_pre(
                    persona_name,
                    recent_txt,
                    arc_txt,
                    facts_block,
                    request.message,
                    roll=roll,
                    outcome=outcome,
                )
                resolution_text = (pre2.get("resolution_text") or "").strip()
                directives = pre2.get("directives") or []
                pre_sq = (pre2.get("status_quo_update") or "").strip()
            else:
                directives = pre.get("directives") or []
                pre_sq = (pre.get("status_quo_update") or "").strip()

            if directives:
                system_prompt = (
                    system_prompt
                    + "\n\n--- Narrator directives ---\n"
                    + "\n".join(f"- {d}" for d in directives)
                    + "\n---"
                )
            if pre_sq:
                await update_session_status_quo(request.session_id, pre_sq)
                session["status_quo"] = pre_sq

            if resolution_text:
                await add_action_resolution(
                    request.session_id,
                    intent_text=request.message,
                    roll=roll,
                    outcome=outcome,
                    resolution_text=resolution_text,
                    message_id=None,
                )
                narrator_msg = {"roll": roll, "outcome": outcome, "text": resolution_text}

    # Inject the roll outcome so the character's reply stays consistent with the narrator.
    if roll is not None:
        system_prompt = (
            system_prompt
            + "\n\n--- Mechanics ---\n"
            + f"Roll d20={roll}. Outcome: {outcome}.\n"
            + "Your reply MUST be consistent with this outcome. Do NOT contradict the narrator resolution.\n"
            + "---"
        )

    # is_narrator_choice: wrap the message so the LLM understands the context.
    user_message_content = request.message
    if request.is_narrator_choice:
        user_message_content = f"[Player chose: {request.message}]"

    if not history:
        await add_message(request.session_id, "system", system_prompt)
    elif history[0]["role"] == "system" and history[0]["content"] != system_prompt:
        # Refresh the stored system message (the session's lowest-id message) in place.
        async with aiosqlite.connect(DB_PATH) as db:
            await db.execute(
                """UPDATE messages SET content = ?
                WHERE session_id = ? AND role = 'system'
                AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
                (system_prompt, request.session_id, request.session_id),
            )
            await db.commit()

    if not request.skip_user_add:
        await add_message(request.session_id, "user", user_message_content)
    messages = await get_history(request.session_id)

    full_reply = []

    def sse(payload: dict) -> str:
        """Format one Server-Sent Events data frame."""
        return f"data: {json.dumps(payload)}\n\n"

    async def generate():
        nonlocal arc

        # Send the narrator bubble before streaming so it appears above the reply.
        if narrator_msg:
            yield sse({"narrator": narrator_msg})

        try:
            async for chunk in stream_message(
                [{"role": m["role"], "content": m["content"]} for m in messages]
            ):
                full_reply.append(chunk)
                yield sse({"chunk": chunk})
        except Exception as e:
            logger.exception("stream_message failed: %s", e)
            yield sse({"error": str(e)})
            return

        complete = "".join(full_reply)
        display_text = strip_image_prompt_tag(complete)

        hist_with_reply = await get_history(request.session_id) + [
            {"role": "assistant", "content": display_text}
        ]
        # Image-prompt generation is best-effort. A failure here must not stop the
        # reply below from being saved.
        try:
            sd_result = await generate_sd_prompt(
                hist_with_reply,
                persona_id,
                outfit_json=session.get("outfit_json", "[]") if session else "[]",
            )
        except Exception:
            logger.exception("generate_sd_prompt failed for session %s", request.session_id)
            sd_result = None
        prompt_str = (sd_result[0] if sd_result and sd_result[0] else None) or extract_image_prompt_tag(complete)

        if (display_text or complete).strip():
            await add_message(request.session_id, "assistant", display_text or complete, image_prompt=prompt_str)

        choices = []
        debug_blocks = []
        quests_updated = []

        if session and session.get("rpg_enabled"):
            # RPG bookkeeping relies on LLM output and several DB writes. If any step
            # raises, log it and still emit the final "done" event instead of cutting
            # the stream off.
            try:
                if not arc:
                    persona = await get_persona(persona_id) or {}
                    arc = await generate_plot_arc(
                        persona.get("name", persona_id),
                        persona.get("description", ""),
                        persona.get("scenario", ""),
                        persona.get("first_mes", ""),
                        facts_block=facts_to_prompt(session.get("facts_json", "[]")),
                        genre=session.get("genre") or "adventure",
                    )
                    if arc:
                        await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
                        debug_blocks.append({"type": "plot_arc", "text": json.dumps(arc, ensure_ascii=False, indent=2)})
                        if rpg_settings.get("quests", True):
                            for beat in arc.get("beats", []):
                                t = (beat.get("title") or beat.get("injection") or "").strip()
                                if t:
                                    await upsert_quest(request.session_id, t[:120])

                trig = should_advance_arc(request.message)
                if trig and arc:
                    arc, beats = pop_matching_beats(arc, trig, max_beats=1)
                    if beats:
                        await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
                        inj = beats[0].get("injection", "")
                        if inj:
                            debug_blocks.append({"type": "narrator_injection", "text": inj})
                        if rpg_settings.get("choices", True):
                            choices += beats[0].get("choices") or []
                    # NOTE(review): nesting reconstructed from a whitespace-stripped source;
                    # confirm the phase check belongs under the trigger branch.
                    if advance_phase(arc):
                        await update_session_plot_arc(request.session_id, json.dumps(arc, ensure_ascii=False))
                        debug_blocks.append({"type": "phase_advance", "text": arc["phase"]})

                ctx = [
                    m for m in (await get_history(request.session_id))
                    if m["role"] in ("user", "assistant")
                ][-10:]
                new_facts = await extract_facts(ctx)
                if new_facts:
                    merged = merge_facts(session.get("facts_json", "[]"), new_facts)
                    await update_session_facts(request.session_id, merged)
                    session["facts_json"] = merged

                persona = await get_persona(persona_id) or {}
                ctx_txt = "\n".join(
                    f"{m['role']}: {m['content']}"
                    for m in ctx[-8:]
                    if m.get("role") in ("user", "assistant")
                )
                post = await narrator_post(
                    persona.get("name", persona_id),
                    ctx_txt,
                    json.dumps(arc, ensure_ascii=False) if arc else "",
                    facts_to_prompt(session.get("facts_json", "[]")),
                )

                sq = (post.get("status_quo_update") or "").strip()
                if sq:
                    await update_session_status_quo(request.session_id, sq)
                    debug_blocks.append({"type": "status_quo", "text": sq})

                if rpg_settings.get("choices", True):
                    choices += post.get("choices") or []

                if rpg_settings.get("affinity", True):
                    delta = int(post.get("affinity_delta") or 0)
                    if delta:
                        await update_session_affinity(request.session_id, delta)

                outfit_update = post.get("outfit_update")
                if isinstance(outfit_update, list) and outfit_update:
                    outfit_str = json.dumps(outfit_update, ensure_ascii=False)
                    await update_session_outfit(request.session_id, outfit_str)
                    session["outfit_json"] = outfit_str

                if rpg_settings.get("quests", True):
                    for qu in (post.get("quest_updates") or []):
                        if not isinstance(qu, dict):
                            continue
                        t = (qu.get("title") or "").strip()
                        if t:
                            await upsert_quest(request.session_id, t[:120], qu.get("status", "active"))
                    quests_updated = await get_quests(request.session_id)
            except Exception as exc:
                logger.exception("RPG post-processing failed for session %s", request.session_id)
                debug_blocks.append({"type": "error", "text": f"RPG post-processing failed: {exc}"})

        count = await get_message_count(request.session_id)
        if count == 2 and not request.skip_user_add:
            persona = await get_persona(persona_id) or {}
            preview = request.message[:40] + ("…" if len(request.message) > 40 else "")
            if (session or {}).get("title", "Новый чат") in ("", "Новый чат"):
                await update_session_title(request.session_id, f"{persona.get('name', persona_id)} — {preview}")

        image_path = None
        image_error = None
        if prompt_str and SD_AUTO_GENERATE:
            yield sse({"image_generating": True, "image_prompt": prompt_str})
            rel, err = await sd_service.generate_from_full_prompt(prompt_str)
            if rel:
                image_path = rel
                msg_id = await get_last_assistant_message_id(request.session_id)
                if msg_id:
                    await update_message_image(msg_id, rel)
            else:
                image_error = err

        updated_session = await get_session(request.session_id)
        affinity = updated_session.get("affinity", 0) if updated_session else 0

        yield sse({
            "done": True,
            "image_prompt": prompt_str,
            "image_path": f"/static/{image_path}" if image_path else None,
            "image_error": image_error,
            "choices": choices,
            "debug": debug_blocks,
            "affinity": affinity,
            "quests": quests_updated,
        })

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|
|
|
|
|
|
@router.post("/", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Non-streaming chat: store the user message, get a full reply, and return it.

    The system message is stored only for a brand-new session. The image prompt is
    derived the same way as in ``/stream``:
        - the reply is included in the SD context, along with the session's outfit;
        - if no prompt comes back, the reply's inline image tag is used instead.

    A blank reply is not stored.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)

    history = await get_history(request.session_id)
    system_prompt = await get_system_prompt(persona_id, history, request.message)

    if not history:
        await add_message(request.session_id, "system", system_prompt)

    await add_message(request.session_id, "user", request.message)
    messages = await get_history(request.session_id)
    reply = await send_message(
        [{"role": m["role"], "content": m["content"]} for m in messages]
    )
    display = strip_image_prompt_tag(reply)

    # Describe the scene the reply just produced, matching the /stream behaviour.
    session = await get_session(request.session_id)
    sd_result = await generate_sd_prompt(
        messages + [{"role": "assistant", "content": display}],
        persona_id,
        outfit_json=session.get("outfit_json", "[]") if session else "[]",
    )
    prompt_str = (sd_result[0] if sd_result and sd_result[0] else None) or extract_image_prompt_tag(reply)

    if (display or reply).strip():
        await add_message(request.session_id, "assistant", display or reply, image_prompt=prompt_str)

    return ChatResponse(
        reply=display,
        session_id=request.session_id,
        image_prompt=prompt_str,
    )
|
|
|
|
|
|
@router.patch("/messages/{message_id}")
async def edit_message(message_id: int, req: MessageEditRequest):
    """Overwrite a message's content.

    When ``req.truncate_after`` is set, every later message in the same session is
    deleted as well.

    Raises:
        HTTPException: 404 if the message does not exist.
    """
    existing = await get_message(message_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Сообщение не найдено")

    await update_message_content(message_id, req.content)

    if req.truncate_after:
        owner_session = existing["session_id"]
        await delete_messages_after(owner_session, message_id)

    return dict(status="updated", message_id=message_id)
|
|
|
|
|
|
@router.post("/regenerate")
async def regenerate_chat(req: RegenerateRequest):
    """Replace an assistant reply with a freshly streamed one.

    The target is ``req.message_id``, or the session's last assistant message when no
    id is given. The new reply is produced by ``chat_stream`` from the latest user
    message, without storing that message again. A ``[Player chose: ...]`` wrapper is
    removed from it first.

    Raises:
        HTTPException: 400 if there is no message to regenerate, if the message is not
            an assistant message of this session, or if the session has no user message.
            All checks run before anything is deleted.
    """
    msg_id = req.message_id or await get_last_assistant_message_id(req.session_id)
    if not msg_id:
        raise HTTPException(status_code=400, detail="Нет сообщения для перегенерации")

    msg = await get_message(msg_id)
    # Also reject ids belonging to another session, so they cannot be deleted through this one.
    if not msg or msg.get("role") != "assistant" or msg.get("session_id") != req.session_id:
        raise HTTPException(status_code=400, detail="Неверное сообщение")

    # Deleting an assistant message cannot change which user message is latest,
    # so validate first and delete only once regeneration can proceed.
    history = await get_history(req.session_id)
    last_user = next((m for m in reversed(history) if m["role"] == "user"), None)
    if not last_user:
        raise HTTPException(status_code=400, detail="Нет сообщения пользователя")

    await delete_message(msg_id)

    user_text = last_user["content"]
    choice_prefix = "[Player chose: "
    if user_text.startswith(choice_prefix) and user_text.endswith("]"):
        user_text = user_text[len(choice_prefix):-1]

    stream_req = ChatRequest(
        message=user_text,
        session_id=req.session_id,
        persona_id=req.persona_id,
        skip_user_add=True,
    )
    return await chat_stream(stream_req)
|
|
|
|
|
|
@router.delete("/{session_id}")
async def clear_chat(session_id: str):
    """Clear the history of ``session_id`` via ``clear_history`` and acknowledge it."""
    await clear_history(session_id)
    result = {"status": "cleared"}
    result["session_id"] = session_id
    return result
|