first commit
This commit is contained in:
+199
@@ -0,0 +1,199 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import aiosqlite
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from database.db import DB_PATH
|
||||
from models.schemas import ChatRequest, ChatResponse
|
||||
from services.llm import send_message, stream_message
|
||||
from services.memory import (
|
||||
get_history,
|
||||
add_message,
|
||||
clear_history,
|
||||
get_or_create_session,
|
||||
update_session_title,
|
||||
get_message_count,
|
||||
get_last_assistant_message_id,
|
||||
update_message_image,
|
||||
)
|
||||
from services.personas import get_persona
|
||||
from services.sd_prompt import (
|
||||
generate_sd_prompt,
|
||||
strip_image_prompt_tag,
|
||||
extract_image_prompt_tag,
|
||||
)
|
||||
from services.lorebook import get_lorebook_context
|
||||
from services.character_card import get_character
|
||||
from services import sdbackend as sd_service
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
|
||||
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona:
|
||||
return DEFAULT_PROMPT
|
||||
|
||||
prompt = persona["prompt"]
|
||||
|
||||
if persona_id.startswith("card_"):
|
||||
card_id = persona_id[5:]
|
||||
card = await get_character(card_id)
|
||||
if card:
|
||||
# match lorebook against recent context + current message
|
||||
recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
|
||||
context = recent + [{"role": "user", "content": user_message}]
|
||||
lore = get_lorebook_context(card.get("lorebook_json", "[]"), context)
|
||||
if lore:
|
||||
prompt = prompt + "\n\n" + lore
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
@router.get("/history/{session_id}")
|
||||
async def get_chat_history(session_id: str):
|
||||
return await get_history(session_id)
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
async def init_chat(request: ChatRequest):
|
||||
"""Called when opening a new chat. Seeds system prompt and first_mes if card persona."""
|
||||
persona_id = request.persona_id or "default"
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
history = await get_history(request.session_id)
|
||||
if history:
|
||||
return {"first_mes": None} # already initialized
|
||||
|
||||
system_prompt = await get_system_prompt(persona_id, [], "")
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
|
||||
first_mes = None
|
||||
if persona_id.startswith("card_"):
|
||||
card = await get_character(persona_id[5:])
|
||||
if card and card.get("first_mes"):
|
||||
first_mes = card["first_mes"]
|
||||
await add_message(request.session_id, "assistant", first_mes)
|
||||
|
||||
return {"first_mes": first_mes}
|
||||
|
||||
|
||||
@router.post("/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
persona_id = request.persona_id or "default"
|
||||
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
|
||||
history = await get_history(request.session_id)
|
||||
system_prompt = await get_system_prompt(persona_id, history, request.message)
|
||||
|
||||
if not history:
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
elif history[0]["role"] == "system" and history[0]["content"] != system_prompt:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""UPDATE messages SET content = ?
|
||||
WHERE session_id = ? AND role = 'system'
|
||||
AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
|
||||
(system_prompt, request.session_id, request.session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
await add_message(request.session_id, "user", request.message)
|
||||
messages = await get_history(request.session_id)
|
||||
|
||||
full_reply = []
|
||||
|
||||
async def generate():
|
||||
async for chunk in stream_message(
|
||||
[{"role": m["role"], "content": m["content"]} for m in messages]
|
||||
):
|
||||
full_reply.append(chunk)
|
||||
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
|
||||
|
||||
complete = "".join(full_reply)
|
||||
display_text = strip_image_prompt_tag(complete)
|
||||
|
||||
hist_with_reply = await get_history(request.session_id) + [
|
||||
{"role": "assistant", "content": display_text}
|
||||
]
|
||||
sd_result = await generate_sd_prompt(hist_with_reply, persona_id)
|
||||
prompt_str = sd_result[0] if sd_result else None
|
||||
if not prompt_str:
|
||||
prompt_str = extract_image_prompt_tag(complete)
|
||||
|
||||
await add_message(
|
||||
request.session_id,
|
||||
"assistant",
|
||||
display_text or complete,
|
||||
image_prompt=prompt_str,
|
||||
)
|
||||
|
||||
count = await get_message_count(request.session_id)
|
||||
if count == 2:
|
||||
title = request.message[:40] + ("…" if len(request.message) > 40 else "")
|
||||
await update_session_title(request.session_id, title)
|
||||
|
||||
image_path = None
|
||||
image_error = None
|
||||
if prompt_str and SD_AUTO_GENERATE:
|
||||
rel, err = await sd_service.generate_from_full_prompt(prompt_str)
|
||||
if rel:
|
||||
image_path = rel
|
||||
msg_id = await get_last_assistant_message_id(request.session_id)
|
||||
if msg_id:
|
||||
await update_message_image(msg_id, rel)
|
||||
else:
|
||||
image_error = err
|
||||
|
||||
yield f"data: {json.dumps({
|
||||
'done': True,
|
||||
'image_prompt': prompt_str,
|
||||
'image_path': f'/static/{image_path}' if image_path else None,
|
||||
'image_error': image_error,
|
||||
})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
persona_id = request.persona_id or "default"
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
|
||||
history = await get_history(request.session_id)
|
||||
system_prompt = await get_system_prompt(persona_id, history, request.message)
|
||||
|
||||
if not history:
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
|
||||
await add_message(request.session_id, "user", request.message)
|
||||
messages = await get_history(request.session_id)
|
||||
reply = await send_message(
|
||||
[{"role": m["role"], "content": m["content"]} for m in messages]
|
||||
)
|
||||
display = strip_image_prompt_tag(reply)
|
||||
prompt_tuple = await generate_sd_prompt(messages, persona_id)
|
||||
prompt_str = prompt_tuple[0] if prompt_tuple else extract_image_prompt_tag(reply)
|
||||
|
||||
await add_message(request.session_id, "assistant", display, image_prompt=prompt_str)
|
||||
|
||||
return ChatResponse(
|
||||
reply=display,
|
||||
session_id=request.session_id,
|
||||
image_prompt=prompt_str,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{session_id}")
|
||||
async def clear_chat(session_id: str):
|
||||
await clear_history(session_id)
|
||||
return {"status": "cleared", "session_id": session_id}
|
||||
Reference in New Issue
Block a user