first commit

This commit is contained in:
Grigo
2026-05-28 08:42:46 +03:00
commit e5c0df308f
38 changed files with 2753 additions and 0 deletions
+199
View File
@@ -0,0 +1,199 @@
import json
import os
import aiosqlite
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from database.db import DB_PATH
from models.schemas import ChatRequest, ChatResponse
from services.llm import send_message, stream_message
from services.memory import (
get_history,
add_message,
clear_history,
get_or_create_session,
update_session_title,
get_message_count,
get_last_assistant_message_id,
update_message_image,
)
from services.personas import get_persona
from services.sd_prompt import (
generate_sd_prompt,
strip_image_prompt_tag,
extract_image_prompt_tag,
)
from services.lorebook import get_lorebook_context
from services.character_card import get_character
from services import sdbackend as sd_service
router = APIRouter(prefix="/chat", tags=["chat"])
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
persona = await get_persona(persona_id)
if not persona:
return DEFAULT_PROMPT
prompt = persona["prompt"]
if persona_id.startswith("card_"):
card_id = persona_id[5:]
card = await get_character(card_id)
if card:
# match lorebook against recent context + current message
recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
context = recent + [{"role": "user", "content": user_message}]
lore = get_lorebook_context(card.get("lorebook_json", "[]"), context)
if lore:
prompt = prompt + "\n\n" + lore
return prompt
@router.get("/history/{session_id}")
async def get_chat_history(session_id: str):
return await get_history(session_id)
@router.post("/init")
async def init_chat(request: ChatRequest):
"""Called when opening a new chat. Seeds system prompt and first_mes if card persona."""
persona_id = request.persona_id or "default"
await get_or_create_session(request.session_id, persona_id)
history = await get_history(request.session_id)
if history:
return {"first_mes": None} # already initialized
system_prompt = await get_system_prompt(persona_id, [], "")
await add_message(request.session_id, "system", system_prompt)
first_mes = None
if persona_id.startswith("card_"):
card = await get_character(persona_id[5:])
if card and card.get("first_mes"):
first_mes = card["first_mes"]
await add_message(request.session_id, "assistant", first_mes)
return {"first_mes": first_mes}
@router.post("/stream")
async def chat_stream(request: ChatRequest):
persona_id = request.persona_id or "default"
await get_or_create_session(request.session_id, persona_id)
history = await get_history(request.session_id)
system_prompt = await get_system_prompt(persona_id, history, request.message)
if not history:
await add_message(request.session_id, "system", system_prompt)
elif history[0]["role"] == "system" and history[0]["content"] != system_prompt:
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
"""UPDATE messages SET content = ?
WHERE session_id = ? AND role = 'system'
AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
(system_prompt, request.session_id, request.session_id),
)
await db.commit()
await add_message(request.session_id, "user", request.message)
messages = await get_history(request.session_id)
full_reply = []
async def generate():
async for chunk in stream_message(
[{"role": m["role"], "content": m["content"]} for m in messages]
):
full_reply.append(chunk)
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
complete = "".join(full_reply)
display_text = strip_image_prompt_tag(complete)
hist_with_reply = await get_history(request.session_id) + [
{"role": "assistant", "content": display_text}
]
sd_result = await generate_sd_prompt(hist_with_reply, persona_id)
prompt_str = sd_result[0] if sd_result else None
if not prompt_str:
prompt_str = extract_image_prompt_tag(complete)
await add_message(
request.session_id,
"assistant",
display_text or complete,
image_prompt=prompt_str,
)
count = await get_message_count(request.session_id)
if count == 2:
title = request.message[:40] + ("" if len(request.message) > 40 else "")
await update_session_title(request.session_id, title)
image_path = None
image_error = None
if prompt_str and SD_AUTO_GENERATE:
rel, err = await sd_service.generate_from_full_prompt(prompt_str)
if rel:
image_path = rel
msg_id = await get_last_assistant_message_id(request.session_id)
if msg_id:
await update_message_image(msg_id, rel)
else:
image_error = err
yield f"data: {json.dumps({
'done': True,
'image_prompt': prompt_str,
'image_path': f'/static/{image_path}' if image_path else None,
'image_error': image_error,
})}\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@router.post("/", response_model=ChatResponse)
async def chat(request: ChatRequest):
persona_id = request.persona_id or "default"
await get_or_create_session(request.session_id, persona_id)
history = await get_history(request.session_id)
system_prompt = await get_system_prompt(persona_id, history, request.message)
if not history:
await add_message(request.session_id, "system", system_prompt)
await add_message(request.session_id, "user", request.message)
messages = await get_history(request.session_id)
reply = await send_message(
[{"role": m["role"], "content": m["content"]} for m in messages]
)
display = strip_image_prompt_tag(reply)
prompt_tuple = await generate_sd_prompt(messages, persona_id)
prompt_str = prompt_tuple[0] if prompt_tuple else extract_image_prompt_tag(reply)
await add_message(request.session_id, "assistant", display, image_prompt=prompt_str)
return ChatResponse(
reply=display,
session_id=request.session_id,
image_prompt=prompt_str,
)
@router.delete("/{session_id}")
async def clear_chat(session_id: str):
await clear_history(session_id)
return {"status": "cleared", "session_id": session_id}