import json
import os

import aiosqlite
from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from database.db import DB_PATH
from models.schemas import ChatRequest, ChatResponse
from services.llm import send_message, stream_message
from services.memory import (
    get_history,
    add_message,
    clear_history,
    get_or_create_session,
    update_session_title,
    get_message_count,
    get_last_assistant_message_id,
    update_message_image,
)
from services.personas import get_persona
from services.sd_prompt import (
    generate_sd_prompt,
    strip_image_prompt_tag,
    extract_image_prompt_tag,
)
from services.lorebook import get_lorebook_context
from services.character_card import get_character
from services import sdbackend as sd_service

router = APIRouter(prefix="/chat", tags=["chat"])

DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")

# Persona ids starting with this prefix refer to a character card; the card id follows it.
_CARD_PREFIX = "card_"
# How many recent user/assistant messages are scanned for lorebook matches.
_LORE_CONTEXT_MESSAGES = 5
# Maximum length of an auto-generated session title.
_TITLE_MAX_LEN = 40


def _card_id(persona_id: str):
    """Return the character-card id encoded in *persona_id*, or None if it is not a card persona."""
    if persona_id.startswith(_CARD_PREFIX):
        return persona_id[len(_CARD_PREFIX):]
    return None


def _sse(payload: dict) -> str:
    """Format *payload* as a single server-sent-events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _llm_payload(messages: list) -> list:
    """Reduce stored message rows to the role/content pairs sent to the LLM."""
    return [{"role": m["role"], "content": m["content"]} for m in messages]


async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
    """Build the system prompt for *persona_id*.

    Returns DEFAULT_PROMPT when the persona is unknown. For character-card
    personas, lorebook entries matched against the last few user/assistant
    messages of *history* plus *user_message* are appended to the persona prompt.
    """
    persona = await get_persona(persona_id)
    if not persona:
        return DEFAULT_PROMPT
    prompt = persona["prompt"]

    card_id = _card_id(persona_id)
    if card_id is None:
        return prompt
    card = await get_character(card_id)
    if not card:
        return prompt

    # Match the lorebook against recent context + the current message.
    recent = [m for m in history if m["role"] in ("user", "assistant")][-_LORE_CONTEXT_MESSAGES:]
    context = recent + [{"role": "user", "content": user_message}]
    # `or` also covers a stored None, not just a missing key.
    lore = get_lorebook_context(card.get("lorebook_json") or "[]", context)
    return prompt + "\n\n" + lore if lore else prompt


async def _replace_system_prompt(session_id: str, system_prompt: str) -> None:
    """Overwrite the session's first message, provided it is the system message."""
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            """UPDATE messages SET content = ?
               WHERE session_id = ? AND role = 'system'
               AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
            (system_prompt, session_id, session_id),
        )
        await db.commit()


async def _prepare_turn(request: ChatRequest) -> tuple:
    """Ensure the session exists, sync its system prompt and store the user message.

    Returns ``(persona_id, messages)``. ``messages`` is the stored history
    re-read after the user message was added.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    history = await get_history(request.session_id)
    system_prompt = await get_system_prompt(persona_id, history, request.message)
    if not history:
        await add_message(request.session_id, "system", system_prompt)
    elif history[0]["role"] == "system" and history[0]["content"] != system_prompt:
        # Lorebook matches depend on the latest messages, so the prompt can change between turns.
        await _replace_system_prompt(request.session_id, system_prompt)
    await add_message(request.session_id, "user", request.message)
    messages = await get_history(request.session_id)
    return persona_id, messages


async def _resolve_image_prompt(messages: list, display_text: str, raw_reply: str, persona_id: str):
    """Choose the SD prompt for a reply.

    Prefers the first element of generate_sd_prompt's result, computed over the
    history plus the new assistant reply. Falls back to an image-prompt tag
    embedded in the raw reply.
    """
    history_with_reply = [*messages, {"role": "assistant", "content": display_text}]
    sd_result = await generate_sd_prompt(history_with_reply, persona_id)
    prompt_str = sd_result[0] if sd_result else None
    return prompt_str or extract_image_prompt_tag(raw_reply)


async def _maybe_set_title(session_id: str, first_message: str) -> None:
    """Title the session after *first_message* when get_message_count reports exactly 2.

    NOTE(review): presumably 2 means the first user/assistant exchange; confirm
    what get_message_count counts.
    """
    if await get_message_count(session_id) == 2:
        title = first_message[:_TITLE_MAX_LEN]
        if len(first_message) > _TITLE_MAX_LEN:
            title += "…"
        await update_session_title(session_id, title)


@router.get("/history/{session_id}")
async def get_chat_history(session_id: str):
    """Return the stored message history of *session_id*."""
    return await get_history(session_id)


@router.post("/init")
async def init_chat(request: ChatRequest):
    """Called when opening a new chat.

    Seeds the system prompt and, for card personas, the card's first_mes as the
    first assistant message. Returns ``{"first_mes": None}`` if the session
    already has history.
    """
    persona_id = request.persona_id or "default"
    await get_or_create_session(request.session_id, persona_id)
    if await get_history(request.session_id):
        return {"first_mes": None}  # already initialized

    system_prompt = await get_system_prompt(persona_id, [], "")
    await add_message(request.session_id, "system", system_prompt)

    first_mes = None
    card_id = _card_id(persona_id)
    if card_id is not None:
        card = await get_character(card_id)
        if card and card.get("first_mes"):
            first_mes = card["first_mes"]
            await add_message(request.session_id, "assistant", first_mes)
    return {"first_mes": first_mes}


@router.post("/stream")
async def chat_stream(request: ChatRequest):
    """Stream the assistant reply as server-sent events.

    Sends one ``{"chunk": ...}`` event per LLM chunk. When the stream ends:
    - the reply (image tag stripped) is stored;
    - the session title may be set;
    - if SD_AUTO_GENERATE is on, an image is generated.
    A final ``{"done": true, ...}`` event carries the image prompt, path and error.
    """
    persona_id, messages = await _prepare_turn(request)

    async def generate():
        parts = []
        async for chunk in stream_message(_llm_payload(messages)):
            parts.append(chunk)
            yield _sse({"chunk": chunk})

        complete = "".join(parts)
        display_text = strip_image_prompt_tag(complete)
        prompt_str = await _resolve_image_prompt(messages, display_text, complete, persona_id)
        await add_message(
            request.session_id,
            "assistant",
            display_text or complete,
            image_prompt=prompt_str,
        )
        await _maybe_set_title(request.session_id, request.message)

        image_path = None
        image_error = None
        if prompt_str and SD_AUTO_GENERATE:
            rel, err = await sd_service.generate_from_full_prompt(prompt_str)
            if rel:
                image_path = rel
                msg_id = await get_last_assistant_message_id(request.session_id)
                if msg_id:
                    await update_message_image(msg_id, rel)
            else:
                image_error = err

        # Build the payload first: a multi-line expression inside an f-string needs Python 3.12+.
        yield _sse({
            "done": True,
            "image_prompt": prompt_str,
            "image_path": f"/static/{image_path}" if image_path else None,
            "image_error": image_error,
        })

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@router.post("/", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Non-streaming chat turn.

    Stores the user message, asks the LLM for a full reply and stores it with
    its image prompt. Returns the reply with the image-prompt tag stripped.
    """
    persona_id, messages = await _prepare_turn(request)
    reply = await send_message(_llm_payload(messages))

    display_text = strip_image_prompt_tag(reply)
    prompt_str = await _resolve_image_prompt(messages, display_text, reply, persona_id)
    # Same as the stream path: never store/return an empty reply when it was only a tag.
    stored = display_text or reply
    await add_message(request.session_id, "assistant", stored, image_prompt=prompt_str)
    await _maybe_set_title(request.session_id, request.message)

    return ChatResponse(
        reply=stored,
        session_id=request.session_id,
        image_prompt=prompt_str,
    )


@router.delete("/{session_id}")
async def clear_chat(session_id: str):
    """Delete the stored history of *session_id*."""
    await clear_history(session_id)
    return {"status": "cleared", "session_id": session_id}