first commit
This commit is contained in:
@@ -0,0 +1,90 @@
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
from services.character_card import list_characters, get_character, import_card_file, update_character, update_appearance_tags
|
||||
|
||||
router = APIRouter(prefix="/characters", tags=["characters"])
|
||||
|
||||
|
||||
class CardPatch(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
personality: Optional[str] = None
|
||||
scenario: Optional[str] = None
|
||||
first_mes: Optional[str] = None
|
||||
mes_example: Optional[str] = None
|
||||
appearance_tags: Optional[str] = None
|
||||
lora_name: Optional[str] = None
|
||||
lora_weight: Optional[float] = None
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_all():
|
||||
return await list_characters()
|
||||
|
||||
|
||||
@router.get("/{card_id}")
|
||||
async def get_one(card_id: str):
|
||||
card = await get_character(card_id)
|
||||
if not card:
|
||||
raise HTTPException(status_code=404, detail="Карточка не найдена")
|
||||
return card
|
||||
|
||||
|
||||
@router.patch("/{card_id}")
|
||||
async def patch_card(card_id: str, body: CardPatch):
|
||||
card = await get_character(card_id)
|
||||
if not card:
|
||||
raise HTTPException(status_code=404, detail="Карточка не найдена")
|
||||
fields = {k: v for k, v in body.model_dump().items() if v is not None}
|
||||
await update_character(card_id, fields)
|
||||
# sync appearance_tags and lora to persona
|
||||
from services.personas import update_persona_appearance
|
||||
if "appearance_tags" in fields:
|
||||
await update_persona_appearance(f"card_{card_id}", fields["appearance_tags"])
|
||||
if {"lora_name", "lora_weight"} & fields.keys():
|
||||
from services.personas import update_persona_lora
|
||||
await update_persona_lora(f"card_{card_id}", fields.get("lora_name"), fields.get("lora_weight"))
|
||||
# rebuild system prompt if character fields changed
|
||||
char_fields = {"name", "description", "personality", "scenario", "first_mes", "mes_example"}
|
||||
if char_fields & fields.keys():
|
||||
updated = await get_character(card_id)
|
||||
from services.character_card import build_system_prompt
|
||||
from services.personas import update_persona_prompt
|
||||
await update_persona_prompt(f"card_{card_id}", build_system_prompt(updated))
|
||||
return await get_character(card_id)
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
async def import_card(
|
||||
file: UploadFile = File(...),
|
||||
lora_name: str = Form(""),
|
||||
lora_weight: float = Form(0.8),
|
||||
):
|
||||
content = await file.read()
|
||||
try:
|
||||
card = await import_card_file(
|
||||
content,
|
||||
file.filename or "card.json",
|
||||
lora_name=lora_name,
|
||||
lora_weight=lora_weight,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return {
|
||||
"status": "imported",
|
||||
"card_id": card["card_id"],
|
||||
"persona_id": f"card_{card['card_id']}",
|
||||
"name": card["name"],
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{card_id}")
|
||||
async def remove_card(card_id: str):
|
||||
from services.personas import delete_persona
|
||||
|
||||
if not await delete_persona(f"card_{card_id}"):
|
||||
raise HTTPException(status_code=404, detail="Карточка не найдена")
|
||||
return {"status": "deleted", "card_id": card_id}
|
||||
|
||||
+199
@@ -0,0 +1,199 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import aiosqlite
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from database.db import DB_PATH
|
||||
from models.schemas import ChatRequest, ChatResponse
|
||||
from services.llm import send_message, stream_message
|
||||
from services.memory import (
|
||||
get_history,
|
||||
add_message,
|
||||
clear_history,
|
||||
get_or_create_session,
|
||||
update_session_title,
|
||||
get_message_count,
|
||||
get_last_assistant_message_id,
|
||||
update_message_image,
|
||||
)
|
||||
from services.personas import get_persona
|
||||
from services.sd_prompt import (
|
||||
generate_sd_prompt,
|
||||
strip_image_prompt_tag,
|
||||
extract_image_prompt_tag,
|
||||
)
|
||||
from services.lorebook import get_lorebook_context
|
||||
from services.character_card import get_character
|
||||
from services import sdbackend as sd_service
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу."
|
||||
SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str:
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona:
|
||||
return DEFAULT_PROMPT
|
||||
|
||||
prompt = persona["prompt"]
|
||||
|
||||
if persona_id.startswith("card_"):
|
||||
card_id = persona_id[5:]
|
||||
card = await get_character(card_id)
|
||||
if card:
|
||||
# match lorebook against recent context + current message
|
||||
recent = [m for m in history if m["role"] in ("user", "assistant")][-5:]
|
||||
context = recent + [{"role": "user", "content": user_message}]
|
||||
lore = get_lorebook_context(card.get("lorebook_json", "[]"), context)
|
||||
if lore:
|
||||
prompt = prompt + "\n\n" + lore
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
@router.get("/history/{session_id}")
|
||||
async def get_chat_history(session_id: str):
|
||||
return await get_history(session_id)
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
async def init_chat(request: ChatRequest):
|
||||
"""Called when opening a new chat. Seeds system prompt and first_mes if card persona."""
|
||||
persona_id = request.persona_id or "default"
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
history = await get_history(request.session_id)
|
||||
if history:
|
||||
return {"first_mes": None} # already initialized
|
||||
|
||||
system_prompt = await get_system_prompt(persona_id, [], "")
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
|
||||
first_mes = None
|
||||
if persona_id.startswith("card_"):
|
||||
card = await get_character(persona_id[5:])
|
||||
if card and card.get("first_mes"):
|
||||
first_mes = card["first_mes"]
|
||||
await add_message(request.session_id, "assistant", first_mes)
|
||||
|
||||
return {"first_mes": first_mes}
|
||||
|
||||
|
||||
@router.post("/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
persona_id = request.persona_id or "default"
|
||||
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
|
||||
history = await get_history(request.session_id)
|
||||
system_prompt = await get_system_prompt(persona_id, history, request.message)
|
||||
|
||||
if not history:
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
elif history[0]["role"] == "system" and history[0]["content"] != system_prompt:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""UPDATE messages SET content = ?
|
||||
WHERE session_id = ? AND role = 'system'
|
||||
AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
|
||||
(system_prompt, request.session_id, request.session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
await add_message(request.session_id, "user", request.message)
|
||||
messages = await get_history(request.session_id)
|
||||
|
||||
full_reply = []
|
||||
|
||||
async def generate():
|
||||
async for chunk in stream_message(
|
||||
[{"role": m["role"], "content": m["content"]} for m in messages]
|
||||
):
|
||||
full_reply.append(chunk)
|
||||
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
|
||||
|
||||
complete = "".join(full_reply)
|
||||
display_text = strip_image_prompt_tag(complete)
|
||||
|
||||
hist_with_reply = await get_history(request.session_id) + [
|
||||
{"role": "assistant", "content": display_text}
|
||||
]
|
||||
sd_result = await generate_sd_prompt(hist_with_reply, persona_id)
|
||||
prompt_str = sd_result[0] if sd_result else None
|
||||
if not prompt_str:
|
||||
prompt_str = extract_image_prompt_tag(complete)
|
||||
|
||||
await add_message(
|
||||
request.session_id,
|
||||
"assistant",
|
||||
display_text or complete,
|
||||
image_prompt=prompt_str,
|
||||
)
|
||||
|
||||
count = await get_message_count(request.session_id)
|
||||
if count == 2:
|
||||
title = request.message[:40] + ("…" if len(request.message) > 40 else "")
|
||||
await update_session_title(request.session_id, title)
|
||||
|
||||
image_path = None
|
||||
image_error = None
|
||||
if prompt_str and SD_AUTO_GENERATE:
|
||||
rel, err = await sd_service.generate_from_full_prompt(prompt_str)
|
||||
if rel:
|
||||
image_path = rel
|
||||
msg_id = await get_last_assistant_message_id(request.session_id)
|
||||
if msg_id:
|
||||
await update_message_image(msg_id, rel)
|
||||
else:
|
||||
image_error = err
|
||||
|
||||
yield f"data: {json.dumps({
|
||||
'done': True,
|
||||
'image_prompt': prompt_str,
|
||||
'image_path': f'/static/{image_path}' if image_path else None,
|
||||
'image_error': image_error,
|
||||
})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
persona_id = request.persona_id or "default"
|
||||
await get_or_create_session(request.session_id, persona_id)
|
||||
|
||||
history = await get_history(request.session_id)
|
||||
system_prompt = await get_system_prompt(persona_id, history, request.message)
|
||||
|
||||
if not history:
|
||||
await add_message(request.session_id, "system", system_prompt)
|
||||
|
||||
await add_message(request.session_id, "user", request.message)
|
||||
messages = await get_history(request.session_id)
|
||||
reply = await send_message(
|
||||
[{"role": m["role"], "content": m["content"]} for m in messages]
|
||||
)
|
||||
display = strip_image_prompt_tag(reply)
|
||||
prompt_tuple = await generate_sd_prompt(messages, persona_id)
|
||||
prompt_str = prompt_tuple[0] if prompt_tuple else extract_image_prompt_tag(reply)
|
||||
|
||||
await add_message(request.session_id, "assistant", display, image_prompt=prompt_str)
|
||||
|
||||
return ChatResponse(
|
||||
reply=display,
|
||||
session_id=request.session_id,
|
||||
image_prompt=prompt_str,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{session_id}")
|
||||
async def clear_chat(session_id: str):
|
||||
await clear_history(session_id)
|
||||
return {"status": "cleared", "session_id": session_id}
|
||||
@@ -0,0 +1,34 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services import sdbackend as sd_service
|
||||
from services.memory import get_last_assistant_message_id, update_message_image
|
||||
|
||||
router = APIRouter(prefix="/images", tags=["images"])
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
session_id: str
|
||||
prompt: str
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def sd_health():
|
||||
ok = await sd_service.check_sd()
|
||||
return {"sd_available": ok, "url": sd_service.SD_BASE_URL}
|
||||
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate_image(req: GenerateRequest):
|
||||
if not req.prompt.strip():
|
||||
raise HTTPException(status_code=400, detail="Пустой промпт")
|
||||
|
||||
rel, err = await sd_service.generate_from_full_prompt(req.prompt)
|
||||
if not rel:
|
||||
raise HTTPException(status_code=502, detail=err or "SD backend недоступен")
|
||||
|
||||
msg_id = await get_last_assistant_message_id(req.session_id)
|
||||
if msg_id:
|
||||
await update_message_image(msg_id, rel)
|
||||
|
||||
return {"image_path": f"/static/{rel}", "status": "ok"}
|
||||
@@ -0,0 +1,42 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from models.schemas import PersonaCreate
|
||||
from services.personas import get_all_personas, get_persona, create_persona, delete_persona
|
||||
|
||||
router = APIRouter(prefix="/personas", tags=["personas"])
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_personas():
|
||||
personas = await get_all_personas()
|
||||
return [{"persona_id": pid, **data} for pid, data in personas.items()]
|
||||
|
||||
|
||||
@router.get("/{persona_id}")
|
||||
async def get_one_persona(persona_id: str):
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Персонаж не найден")
|
||||
return {"persona_id": persona_id, **persona}
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_new_persona(data: PersonaCreate):
|
||||
persona = await create_persona(
|
||||
persona_id=data.persona_id,
|
||||
name=data.name,
|
||||
emoji=data.emoji,
|
||||
description=data.description,
|
||||
prompt=data.prompt,
|
||||
sd_enabled=data.sd_enabled,
|
||||
lora_name=data.lora_name,
|
||||
lora_weight=data.lora_weight,
|
||||
appearance_tags=data.appearance_tags,
|
||||
)
|
||||
return {"persona_id": data.persona_id, **persona}
|
||||
|
||||
|
||||
@router.delete("/{persona_id}")
|
||||
async def remove_persona(persona_id: str):
|
||||
if not await delete_persona(persona_id):
|
||||
raise HTTPException(status_code=400, detail="Нельзя удалить встроенного персонажа")
|
||||
return {"status": "deleted", "persona_id": persona_id}
|
||||
@@ -0,0 +1,48 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from services.memory import (
|
||||
get_all_sessions,
|
||||
get_or_create_session,
|
||||
delete_session,
|
||||
update_session_title,
|
||||
update_session_persona,
|
||||
get_history,
|
||||
get_message_count
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/sessions", tags=["sessions"])
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_sessions():
|
||||
sessions = await get_all_sessions()
|
||||
result = []
|
||||
for s in sessions:
|
||||
count = await get_message_count(s["session_id"])
|
||||
result.append({**s, "message_count": count})
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{session_id}")
|
||||
async def get_session(session_id: str):
|
||||
sessions = await get_all_sessions()
|
||||
s = next((x for x in sessions if x["session_id"] == session_id), None)
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Сессия не найдена")
|
||||
return s
|
||||
|
||||
|
||||
@router.patch("/{session_id}")
|
||||
async def patch_session(session_id: str, data: dict):
|
||||
# ensure session exists before patching
|
||||
await get_or_create_session(session_id, data.get("persona_id", "default"))
|
||||
if "title" in data:
|
||||
await update_session_title(session_id, data["title"])
|
||||
if "persona_id" in data:
|
||||
await update_session_persona(session_id, data["persona_id"])
|
||||
return {"status": "updated"}
|
||||
|
||||
|
||||
@router.delete("/{session_id}")
|
||||
async def remove_session(session_id: str):
|
||||
await delete_session(session_id)
|
||||
return {"status": "deleted", "session_id": session_id}
|
||||
@@ -0,0 +1,18 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from services.translate import translate_to_russian
|
||||
|
||||
router = APIRouter(prefix="/translate", tags=["translate"])
|
||||
|
||||
|
||||
class TranslateRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def translate(req: TranslateRequest):
|
||||
try:
|
||||
result = await translate_to_russian(req.text)
|
||||
return {"translated": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=str(e))
|
||||
Reference in New Issue
Block a user