first commit
This commit is contained in:
@@ -0,0 +1,214 @@
|
||||
import json
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
import aiosqlite
|
||||
from database.db import DB_PATH
|
||||
|
||||
|
||||
def parse_card_v2(data: dict) -> dict:
|
||||
inner = data.get("data", data)
|
||||
if isinstance(inner, str):
|
||||
inner = json.loads(inner)
|
||||
|
||||
book = inner.get("character_book") or {}
|
||||
entries = book.get("entries", [])
|
||||
if isinstance(entries, dict):
|
||||
entries = list(entries.values())
|
||||
|
||||
return {
|
||||
"card_id": (
|
||||
inner.get("name", "imported").lower().replace(" ", "_")[:48]
|
||||
+ "_"
|
||||
+ uuid.uuid4().hex[:8]
|
||||
),
|
||||
"name": inner.get("name", "Character"),
|
||||
"description": inner.get("description", ""),
|
||||
"personality": inner.get("personality", ""),
|
||||
"scenario": inner.get("scenario", ""),
|
||||
"first_mes": inner.get("first_mes", ""),
|
||||
"mes_example": inner.get("mes_example", ""),
|
||||
"appearance_tags": _extract_appearance(inner),
|
||||
"lorebook_json": json.dumps(entries, ensure_ascii=False),
|
||||
"raw_json": json.dumps(data if "data" in data else {"data": inner}, ensure_ascii=False),
|
||||
}
|
||||
|
||||
|
||||
def _extract_appearance(inner: dict) -> str:
|
||||
"""Extract booru-style appearance tags from character fields."""
|
||||
import re
|
||||
# fall back: scan description for visual keywords, skip world-building sentences
|
||||
desc = inner.get("description", "")
|
||||
appearance_keywords = re.findall(
|
||||
r'\b(?:'
|
||||
r'\w*hair|hair\w*|\w*eyes|eye\w*|\w*skin|skin\w*'
|
||||
r'|tall|short|slim|curvy|muscular|petite'
|
||||
r'|ears?|tail|horns?|wings?|cloak|dress|outfit|uniform|armor'
|
||||
r'|wolf\w*|cat\w*|fox\w*|elf\w*|demon\w*|angel\w*'
|
||||
r'|silver|blonde|black|white|red|blue|green|purple|pink|brown|golden'
|
||||
r')\b',
|
||||
desc, re.IGNORECASE
|
||||
)
|
||||
seen = []
|
||||
for kw in appearance_keywords:
|
||||
kw_lower = kw.lower()
|
||||
if kw_lower not in seen:
|
||||
seen.append(kw_lower)
|
||||
return ", ".join(seen[:20])
|
||||
|
||||
|
||||
def parse_png_card(file_bytes: bytes) -> dict | None:
|
||||
if not file_bytes.startswith(b"\x89PNG"):
|
||||
return None
|
||||
idx = 8 # skip PNG file signature
|
||||
while idx < len(file_bytes) - 12:
|
||||
length = int.from_bytes(file_bytes[idx : idx + 4], "big")
|
||||
chunk_type = file_bytes[idx + 4 : idx + 8]
|
||||
chunk_data = file_bytes[idx + 8 : idx + 8 + length]
|
||||
if chunk_type == b"tEXt":
|
||||
try:
|
||||
key, _, val = chunk_data.partition(b"\x00")
|
||||
if key in (b"chara", b"ccv3"):
|
||||
decoded = base64.b64decode(val).decode("utf-8")
|
||||
return parse_card_v2(json.loads(decoded))
|
||||
except Exception:
|
||||
pass
|
||||
elif chunk_type == b"iTXt":
|
||||
try:
|
||||
# iTXt: keyword \x00 compression_flag \x00 compression_method \x00 language \x00 translated_keyword \x00 text
|
||||
key, _, rest = chunk_data.partition(b"\x00")
|
||||
if key in (b"chara", b"ccv3"):
|
||||
# skip compression_flag, compression_method, language tag, translated keyword
|
||||
text = rest[2:].split(b"\x00", 2)[-1].decode("utf-8")
|
||||
# text may be base64 or raw JSON
|
||||
try:
|
||||
return parse_card_v2(json.loads(base64.b64decode(text).decode("utf-8")))
|
||||
except Exception:
|
||||
return parse_card_v2(json.loads(text))
|
||||
except Exception:
|
||||
pass
|
||||
idx += 12 + length
|
||||
return None
|
||||
|
||||
|
||||
def build_system_prompt(card: dict) -> str:
|
||||
parts = [
|
||||
f"You are {card['name']}. Stay in character.",
|
||||
f"Description: {card['description']}",
|
||||
f"Personality: {card['personality']}",
|
||||
f"Scenario: {card['scenario']}",
|
||||
]
|
||||
if card.get("mes_example"):
|
||||
parts.append(f"Example dialogue:\n{card['mes_example']}")
|
||||
parts.append("Reply only as the character. Do not add image tags.")
|
||||
return "\n\n".join(p for p in parts if p.split(": ", 1)[-1].strip())
|
||||
|
||||
|
||||
async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0.8) -> dict:
|
||||
card_id = card["card_id"]
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""INSERT OR REPLACE INTO characters
|
||||
(card_id, name, description, personality, scenario, first_mes,
|
||||
mes_example, raw_json, lora_name, lora_weight, appearance_tags, lorebook_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
card_id,
|
||||
card["name"],
|
||||
card["description"],
|
||||
card["personality"],
|
||||
card["scenario"],
|
||||
card["first_mes"],
|
||||
card["mes_example"],
|
||||
card["raw_json"],
|
||||
lora_name,
|
||||
lora_weight,
|
||||
card.get("appearance_tags", ""),
|
||||
card["lorebook_json"],
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
return {**card, "lora_name": lora_name, "lora_weight": lora_weight}
|
||||
|
||||
|
||||
async def get_character(card_id: str) -> dict | None:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM characters WHERE card_id = ?", (card_id,)
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
async def list_characters() -> list:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT card_id, name, description, lora_name FROM characters ORDER BY name"
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
async def delete_character(card_id: str) -> bool:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
cur = await db.execute(
|
||||
"DELETE FROM characters WHERE card_id = ?", (card_id,)
|
||||
)
|
||||
await db.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
async def update_appearance_tags(card_id: str, appearance_tags: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE characters SET appearance_tags = ? WHERE card_id = ?",
|
||||
(appearance_tags, card_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_character(card_id: str, fields: dict) -> bool:
|
||||
allowed = {"name", "description", "personality", "scenario", "first_mes",
|
||||
"mes_example", "appearance_tags", "lora_name", "lora_weight"}
|
||||
updates = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not updates:
|
||||
return False
|
||||
cols = ", ".join(f"{k} = ?" for k in updates)
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
cur = await db.execute(
|
||||
f"UPDATE characters SET {cols} WHERE card_id = ?",
|
||||
(*updates.values(), card_id),
|
||||
)
|
||||
await db.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
async def import_card_file(content: bytes, filename: str, lora_name: str = "", lora_weight: float = 0.8) -> dict:
|
||||
if filename.lower().endswith(".png"):
|
||||
card = parse_png_card(content)
|
||||
if not card:
|
||||
raise ValueError("PNG does not contain character card metadata")
|
||||
else:
|
||||
card = parse_card_v2(json.loads(content.decode("utf-8")))
|
||||
|
||||
saved = await save_character(card, lora_name=lora_name, lora_weight=lora_weight)
|
||||
|
||||
persona_id = f"card_{saved['card_id']}"
|
||||
from services.personas import create_persona, get_persona
|
||||
|
||||
existing = await get_persona(persona_id)
|
||||
if not existing:
|
||||
await create_persona(
|
||||
persona_id=persona_id,
|
||||
name=saved["name"],
|
||||
emoji="🎭",
|
||||
description=saved["description"][:80] or "Character card",
|
||||
prompt=build_system_prompt(saved),
|
||||
sd_enabled=True,
|
||||
lora_name=lora_name,
|
||||
lora_weight=lora_weight,
|
||||
appearance_tags=saved.get("appearance_tags", ""),
|
||||
)
|
||||
return saved
|
||||
@@ -0,0 +1,63 @@
|
||||
import httpx
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
OPENROUTER_KEY = os.getenv("ROUTER_KEY")
|
||||
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
|
||||
MODEL = "google/gemini-2.5-flash"
|
||||
|
||||
HEADERS = {
|
||||
"Authorization": f"Bearer {OPENROUTER_KEY}",
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": "http://localhost:8000",
|
||||
}
|
||||
|
||||
async def send_message(messages: list) -> str:
|
||||
"""Обычный запрос — используем для внутренних нужд"""
|
||||
payload = {
|
||||
"model": MODEL,
|
||||
"messages": messages,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
response = await client.post(
|
||||
OPENROUTER_URL,
|
||||
headers=HEADERS,
|
||||
json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
async def stream_message(messages: list):
|
||||
"""Стриминг — отдаём чанки по мере получения"""
|
||||
payload = {
|
||||
"model": MODEL,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
OPENROUTER_URL,
|
||||
headers=HEADERS,
|
||||
json=payload
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:] # убираем "data: "
|
||||
if data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
import json
|
||||
chunk = json.loads(data)
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except Exception:
|
||||
continue
|
||||
@@ -0,0 +1,52 @@
|
||||
import json
|
||||
|
||||
|
||||
def _match_entry(entry: dict, text: str) -> bool:
|
||||
keys = entry.get("keys", [])
|
||||
if isinstance(keys, str):
|
||||
keys = [k.strip() for k in keys.split(",") if k.strip()]
|
||||
text_lower = text.lower()
|
||||
for key in keys:
|
||||
if key and key.lower() in text_lower:
|
||||
return True
|
||||
secondary = entry.get("secondary_keys", []) or entry.get("keysecondary", [])
|
||||
if isinstance(secondary, str):
|
||||
secondary = [k.strip() for k in secondary.split(",") if k.strip()]
|
||||
for key in secondary:
|
||||
if key and key.lower() in text_lower:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_lorebook_context(lorebook_json: str, context: str | list, max_entries: int = 5) -> str:
|
||||
"""Match lorebook entries against context.
|
||||
context can be a string or a list of message dicts (role/content).
|
||||
"""
|
||||
try:
|
||||
entries = json.loads(lorebook_json or "[]")
|
||||
except json.JSONDecodeError:
|
||||
return ""
|
||||
|
||||
if isinstance(entries, dict):
|
||||
entries = list(entries.values())
|
||||
|
||||
if isinstance(context, list):
|
||||
text = " ".join(m.get("content", "") for m in context if m.get("role") in ("user", "assistant"))
|
||||
else:
|
||||
text = context
|
||||
|
||||
matched = []
|
||||
for entry in entries:
|
||||
if not entry.get("enabled", True):
|
||||
continue
|
||||
if _match_entry(entry, text):
|
||||
content = entry.get("content", "").strip()
|
||||
if content:
|
||||
name = entry.get("name", entry.get("comment", "Lore"))
|
||||
matched.append(f"[{name}]\n{content}")
|
||||
|
||||
if not matched:
|
||||
return ""
|
||||
|
||||
block = "\n\n".join(matched[:max_entries])
|
||||
return f"--- Lorebook (relevant world info) ---\n{block}\n---"
|
||||
@@ -0,0 +1,142 @@
|
||||
import aiosqlite
|
||||
from database.db import DB_PATH
|
||||
|
||||
|
||||
async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM sessions WHERE session_id = ?", (session_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return dict(row)
|
||||
|
||||
await db.execute(
|
||||
"INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)",
|
||||
(session_id, persona_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async with db.execute(
|
||||
"SELECT * FROM sessions WHERE session_id = ?", (session_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return dict(row)
|
||||
|
||||
|
||||
async def get_all_sessions() -> list:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM sessions ORDER BY updated_at DESC"
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET title = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_session_persona(session_id: str, persona_id: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET persona_id = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(persona_id, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def delete_session(session_id: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
await db.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_history(session_id: str) -> list:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"""SELECT role, content, image_prompt, image_path
|
||||
FROM messages WHERE session_id = ? ORDER BY id""",
|
||||
(session_id,),
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
{
|
||||
"role": r["role"],
|
||||
"content": r["content"],
|
||||
"image_prompt": r["image_prompt"],
|
||||
"image_path": r["image_path"],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
async def add_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
image_prompt: str | None = None,
|
||||
image_path: str | None = None,
|
||||
):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(session_id, role, content, image_prompt, image_path),
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_image(message_id: int, image_path: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_path = ? WHERE id = ?",
|
||||
(image_path, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_last_assistant_message_id(session_id: str) -> int | None:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"""SELECT id FROM messages
|
||||
WHERE session_id = ? AND role = 'assistant'
|
||||
ORDER BY id DESC LIMIT 1""",
|
||||
(session_id,),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return row["id"] if row else None
|
||||
|
||||
|
||||
async def clear_history(session_id: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_message_count(session_id: str) -> int:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) as cnt FROM messages WHERE session_id = ? AND role != 'system'",
|
||||
(session_id,),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return row["cnt"]
|
||||
@@ -0,0 +1,26 @@
|
||||
import aiosqlite
|
||||
from database.db import DB_PATH
|
||||
from services.personas import DEFAULT_PERSONAS
|
||||
|
||||
|
||||
async def seed_default_personas():
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
for pid, data in DEFAULT_PERSONAS.items():
|
||||
await db.execute(
|
||||
"""INSERT OR IGNORE INTO personas
|
||||
(persona_id, name, emoji, description, prompt, custom, sd_enabled,
|
||||
lora_name, lora_weight, appearance_tags)
|
||||
VALUES (?, ?, ?, ?, ?, 0, ?, ?, ?, ?)""",
|
||||
(
|
||||
pid,
|
||||
data["name"],
|
||||
data["emoji"],
|
||||
data["description"],
|
||||
data["prompt"],
|
||||
1 if data.get("sd_enabled") else 0,
|
||||
data.get("lora_name", ""),
|
||||
data.get("lora_weight", 0.8),
|
||||
data.get("appearance_tags", ""),
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
@@ -0,0 +1,168 @@
|
||||
from typing import Optional
|
||||
import aiosqlite
|
||||
from database.db import DB_PATH
|
||||
|
||||
DEFAULT_PERSONAS = {
|
||||
"default": {
|
||||
"name": "AI Ассистент",
|
||||
"emoji": "🤖",
|
||||
"description": "Универсальный помощник",
|
||||
"prompt": "Ты — полезный AI ассистент. Отвечай чётко и по делу.",
|
||||
"sd_enabled": False,
|
||||
},
|
||||
"rpg_master": {
|
||||
"name": "Мастер RPG",
|
||||
"emoji": "🧙",
|
||||
"description": "Ведёт ролевые игры, создаёт атмосферу",
|
||||
"prompt": """Ты — опытный Мастер ролевых игр.
|
||||
Создавай живые описания, веди нарратив, реагируй на действия игрока.
|
||||
Мир детальный, персонажи запоминающиеся.
|
||||
Отвечай только текстом сюжета — без тегов изображений.""",
|
||||
"sd_enabled": True,
|
||||
},
|
||||
"villain": {
|
||||
"name": "Злодей",
|
||||
"emoji": "😈",
|
||||
"description": "Харизматичный антагонист",
|
||||
"prompt": """Ты — харизматичный злодей с грандиозными планами.
|
||||
Говоришь театрально, с сарказмом и превосходством.
|
||||
Никогда не выходишь из роли. Называешь собеседника 'герой' с иронией.""",
|
||||
"sd_enabled": False,
|
||||
},
|
||||
"scientist": {
|
||||
"name": "Учёный",
|
||||
"emoji": "🔬",
|
||||
"description": "Объясняет сложное простыми словами",
|
||||
"prompt": """Ты — увлечённый учёный. Объясняешь любые темы
|
||||
через факты, аналогии и примеры. Любишь уточнять детали.
|
||||
Иногда уходишь в интересные отступления.""",
|
||||
"sd_enabled": False,
|
||||
},
|
||||
"samurai": {
|
||||
"name": "Самурай",
|
||||
"emoji": "⚔️",
|
||||
"description": "Мудрый воин феодальной Японии",
|
||||
"prompt": """Ты — самурай феодальной Японии.
|
||||
Говоришь кратко, мудро, с достоинством.
|
||||
Используешь метафоры природы и войны.
|
||||
Чтишь кодекс бусидо.""",
|
||||
"sd_enabled": True,
|
||||
"appearance_tags": "samurai armor, katana, feudal japan",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _row_to_persona(row: dict) -> dict:
|
||||
return {
|
||||
"name": row["name"],
|
||||
"emoji": row["emoji"],
|
||||
"description": row["description"],
|
||||
"prompt": row["prompt"],
|
||||
"custom": bool(row["custom"]),
|
||||
"sd_enabled": bool(row["sd_enabled"]),
|
||||
"lora_name": row["lora_name"] or "",
|
||||
"lora_weight": row["lora_weight"] if row["lora_weight"] is not None else 0.8,
|
||||
"appearance_tags": row["appearance_tags"] or "",
|
||||
}
|
||||
|
||||
|
||||
async def get_all_personas() -> dict:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute("SELECT * FROM personas ORDER BY custom ASC, persona_id ASC") as cur:
|
||||
rows = await cur.fetchall()
|
||||
return {r["persona_id"]: _row_to_persona(dict(r)) for r in rows}
|
||||
|
||||
|
||||
async def get_persona(persona_id: str) -> Optional[dict]:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM personas WHERE persona_id = ?", (persona_id,)
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return _row_to_persona(dict(row))
|
||||
|
||||
|
||||
async def create_persona(
|
||||
persona_id: str,
|
||||
name: str,
|
||||
emoji: str,
|
||||
description: str,
|
||||
prompt: str,
|
||||
sd_enabled: bool = False,
|
||||
lora_name: str = "",
|
||||
lora_weight: float = 0.8,
|
||||
appearance_tags: str = "",
|
||||
) -> dict:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""INSERT INTO personas
|
||||
(persona_id, name, emoji, description, prompt, custom,
|
||||
sd_enabled, lora_name, lora_weight, appearance_tags)
|
||||
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?)""",
|
||||
(
|
||||
persona_id, name, emoji, description, prompt,
|
||||
1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags,
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
return {
|
||||
"name": name,
|
||||
"emoji": emoji,
|
||||
"description": description,
|
||||
"prompt": prompt,
|
||||
"custom": True,
|
||||
"sd_enabled": sd_enabled,
|
||||
"lora_name": lora_name,
|
||||
"lora_weight": lora_weight,
|
||||
"appearance_tags": appearance_tags,
|
||||
}
|
||||
|
||||
|
||||
async def delete_persona(persona_id: str) -> bool:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
async with db.execute(
|
||||
"SELECT custom FROM personas WHERE persona_id = ?", (persona_id,)
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row or not row[0]:
|
||||
return False
|
||||
await db.execute("DELETE FROM personas WHERE persona_id = ?", (persona_id,))
|
||||
await db.commit()
|
||||
|
||||
if persona_id.startswith("card_"):
|
||||
from services.character_card import delete_character
|
||||
await delete_character(persona_id[5:])
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def update_persona_appearance(persona_id: str, appearance_tags: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE personas SET appearance_tags = ? WHERE persona_id = ?",
|
||||
(appearance_tags, persona_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_persona_lora(persona_id: str, lora_name: str | None, lora_weight: float | None):
|
||||
fields, vals = [], []
|
||||
if lora_name is not None:
|
||||
fields.append("lora_name = ?"); vals.append(lora_name)
|
||||
if lora_weight is not None:
|
||||
fields.append("lora_weight = ?"); vals.append(lora_weight)
|
||||
if not fields:
|
||||
return
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(f"UPDATE personas SET {', '.join(fields)} WHERE persona_id = ?", (*vals, persona_id))
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_persona_prompt(persona_id: str, prompt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute("UPDATE personas SET prompt = ? WHERE persona_id = ?", (prompt, persona_id))
|
||||
await db.commit()
|
||||
@@ -0,0 +1,125 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from services.llm import send_message
|
||||
from services.personas import get_persona
|
||||
|
||||
PROMPT_BUILDER_SYSTEM = """You are a Stable Diffusion prompt engineer for anime illustration models.
|
||||
Given a roleplay chat excerpt and character appearance hints, output ONLY valid JSON (no markdown):
|
||||
{
|
||||
"should_generate": true,
|
||||
"shot_type": "first_person_pov" | "landscape" | "third_person",
|
||||
"appearance_tags": "booru-style tags for character appearance extracted from hints, e.g. 'white hair, wolf ears, wolf tail, yellow eyes'",
|
||||
"action_tags": "booru-style tags for pose/action, e.g. 'sitting, smiling, looking at viewer'",
|
||||
"environment_tags": "booru-style tags for location/lighting, e.g. 'indoors, kitchen, sunlight'"
|
||||
}
|
||||
Rules:
|
||||
- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be written as single tags: 'white hair' not 'white, hair'. 'wolf ears' not 'wolf, ears'.
|
||||
- Do NOT include quality tags, model names, style words, 'pov', or category/metadata words.
|
||||
- Do NOT invent tags. If unsure — omit.
|
||||
- Keep each field to 3-6 tags."""
|
||||
|
||||
|
||||
def extract_image_prompt_tag(text: str) -> str | None:
|
||||
if "[IMAGE_PROMPT:" not in text:
|
||||
return None
|
||||
try:
|
||||
start = text.index("[IMAGE_PROMPT:") + len("[IMAGE_PROMPT:")
|
||||
end = text.index("]", start)
|
||||
return text[start:end].strip()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def strip_image_prompt_tag(text: str) -> str:
|
||||
return re.sub(r"\[IMAGE_PROMPT:.*?\]", "", text, flags=re.DOTALL).strip()
|
||||
|
||||
|
||||
PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"}
|
||||
SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "")
|
||||
PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored"
|
||||
|
||||
def build_positive_prompt(scene: dict, persona: dict | None) -> str:
|
||||
is_pony = SD_CHECKPOINT in PONY_CHECKPOINTS
|
||||
quality = "score_9, score_8_up, score_7_up, source_anime, highres" if is_pony else "masterpiece, best quality, highres"
|
||||
parts = [quality]
|
||||
|
||||
# prefer LLM-extracted appearance over raw persona tags
|
||||
appearance = scene.get("appearance_tags") or (persona or {}).get("appearance_tags", "")
|
||||
if appearance:
|
||||
parts.append(appearance)
|
||||
|
||||
if scene.get("shot_type") == "landscape":
|
||||
parts.append(scene.get("environment_tags", ""))
|
||||
else:
|
||||
if scene.get("shot_type") == "first_person_pov":
|
||||
parts.append("pov, first-person view, looking at viewer")
|
||||
parts.append(scene.get("action_tags", ""))
|
||||
parts.append(scene.get("environment_tags", ""))
|
||||
|
||||
lora = (persona or {}).get("lora_name", "")
|
||||
weight = (persona or {}).get("lora_weight", 0.8)
|
||||
if lora:
|
||||
parts.append(f"<lora:{lora}:{weight}>")
|
||||
|
||||
positive = ", ".join(p.strip() for p in parts if p and p.strip())
|
||||
seen, deduped = set(), []
|
||||
for tag in positive.split(", "):
|
||||
t = tag.strip()
|
||||
if t and t not in seen:
|
||||
seen.add(t)
|
||||
deduped.append(t)
|
||||
return ", ".join(deduped)
|
||||
|
||||
|
||||
async def generate_sd_prompt(
|
||||
messages: list,
|
||||
persona_id: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
persona = await get_persona(persona_id)
|
||||
if not persona or not persona.get("sd_enabled"):
|
||||
return None, None
|
||||
|
||||
recent = [m for m in messages if m["role"] in ("user", "assistant")][-6:]
|
||||
if not recent:
|
||||
return None, None
|
||||
|
||||
excerpt = "\n".join(f"{m['role']}: {strip_image_prompt_tag(m['content'])}" for m in recent)
|
||||
|
||||
appearance = persona.get("appearance_tags", "")
|
||||
# For card personas, also include description for better visual context
|
||||
if persona_id.startswith("card_"):
|
||||
from services.character_card import get_character
|
||||
card = await get_character(persona_id[5:])
|
||||
if card and card.get("description"):
|
||||
appearance = f"{appearance}\nCharacter description: {card['description'][:400]}"
|
||||
|
||||
builder_messages = [
|
||||
{"role": "system", "content": PROMPT_BUILDER_SYSTEM},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Persona appearance hints: {appearance}\n\nChat:\n{excerpt}",
|
||||
},
|
||||
]
|
||||
|
||||
try:
|
||||
raw = await send_message(builder_messages)
|
||||
raw = raw.strip()
|
||||
if raw.startswith("```"):
|
||||
raw = re.sub(r"^```\w*\n?", "", raw)
|
||||
raw = re.sub(r"\n?```$", "", raw)
|
||||
scene = json.loads(raw)
|
||||
except (json.JSONDecodeError, Exception):
|
||||
return None, None
|
||||
|
||||
|
||||
positive = build_positive_prompt(scene, persona)
|
||||
is_pony = SD_CHECKPOINT in PONY_CHECKPOINTS
|
||||
negative = PONY_NEGATIVE if is_pony else "low quality, blurry, bad anatomy, watermark, text"
|
||||
if scene.get("shot_type") == "first_person_pov":
|
||||
negative += ", third person, over the shoulder"
|
||||
|
||||
full = positive
|
||||
if negative:
|
||||
full += f"\n\nNegative prompt: {negative}"
|
||||
return full, negative
|
||||
@@ -0,0 +1,121 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SD_BASE_URL = os.getenv("SD_BASE_URL", "http://127.0.0.1:8188").rstrip("/")
|
||||
SD_STEPS = int(os.getenv("SD_STEPS", "28"))
|
||||
SD_CFG = float(os.getenv("SD_CFG", "7"))
|
||||
SD_SAMPLER = os.getenv("SD_SAMPLER", "euler")
|
||||
SD_SCHEDULER = os.getenv("SD_SCHEDULER", "normal")
|
||||
SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "NetaYumev35_pretrained_all_in_one.safetensors")
|
||||
SD_DEFAULT_NEGATIVE = os.getenv(
|
||||
"SD_DEFAULT_NEGATIVE",
|
||||
"low quality, worst quality, blurry, bad anatomy, watermark, text",
|
||||
)
|
||||
IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images"))
|
||||
|
||||
|
||||
def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]:
|
||||
if "\n\nNegative prompt:" in full_prompt:
|
||||
pos, _, neg = full_prompt.partition("\n\nNegative prompt:")
|
||||
return pos.strip(), neg.strip()
|
||||
return full_prompt.strip(), SD_DEFAULT_NEGATIVE
|
||||
|
||||
|
||||
def _build_workflow(positive: str, negative: str) -> dict:
|
||||
"""Minimal KSampler workflow for ComfyUI API."""
|
||||
return {
|
||||
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": SD_CHECKPOINT}},
|
||||
"5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}},
|
||||
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}},
|
||||
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}},
|
||||
"8": {"class_type": "VAEDecode", "inputs": {"samples": ["10", 0], "vae": ["4", 2]}},
|
||||
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}},
|
||||
"10": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": ["4", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["7", 0],
|
||||
"latent_image": ["5", 0],
|
||||
"seed": int(uuid.uuid4().int % 2**32),
|
||||
"steps": SD_STEPS,
|
||||
"cfg": SD_CFG,
|
||||
"sampler_name": SD_SAMPLER,
|
||||
"scheduler": SD_SCHEDULER,
|
||||
"denoise": 1.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def check_sd() -> bool:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
r = await client.get(f"{SD_BASE_URL}/system_stats")
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[bytes, str]:
|
||||
neg = negative_prompt or SD_DEFAULT_NEGATIVE
|
||||
workflow = _build_workflow(prompt, neg)
|
||||
client_id = uuid.uuid4().hex
|
||||
|
||||
logger.info("ComfyUI request → %s prompt: %.120s", SD_BASE_URL, prompt)
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
# queue the prompt
|
||||
resp = await client.post(
|
||||
f"{SD_BASE_URL}/prompt",
|
||||
json={"prompt": workflow, "client_id": client_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
prompt_id = resp.json()["prompt_id"]
|
||||
logger.info("ComfyUI queued prompt_id=%s", prompt_id)
|
||||
|
||||
# poll until done
|
||||
for _ in range(300):
|
||||
await asyncio.sleep(1)
|
||||
hist = await client.get(f"{SD_BASE_URL}/history/{prompt_id}")
|
||||
data = hist.json()
|
||||
if prompt_id in data:
|
||||
outputs = data[prompt_id]["outputs"]
|
||||
# find first image output
|
||||
for node_output in outputs.values():
|
||||
if "images" in node_output:
|
||||
img_info = node_output["images"][0]
|
||||
img_resp = await client.get(
|
||||
f"{SD_BASE_URL}/view",
|
||||
params={"filename": img_info["filename"], "subfolder": img_info.get("subfolder", ""), "type": img_info.get("type", "output")},
|
||||
)
|
||||
img_resp.raise_for_status()
|
||||
image_bytes = img_resp.content
|
||||
|
||||
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
filename = f"{uuid.uuid4().hex}.png"
|
||||
(IMAGES_DIR / filename).write_bytes(image_bytes)
|
||||
logger.info("ComfyUI done → saved %s", filename)
|
||||
return image_bytes, f"images/{filename}"
|
||||
break
|
||||
|
||||
raise RuntimeError("ComfyUI generation timed out or produced no output")
|
||||
|
||||
|
||||
async def generate_from_full_prompt(full_prompt: str) -> tuple[str | None, str | None]:
|
||||
positive, negative = split_prompt_and_negative(full_prompt)
|
||||
try:
|
||||
_, rel_path = await txt2img(positive, negative)
|
||||
return rel_path, None
|
||||
except Exception as e:
|
||||
logger.error("ComfyUI error: %s", e)
|
||||
return None, str(e)
|
||||
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
LIBRETRANSLATE_URL = os.getenv("LIBRETRANSLATE_URL", "http://192.168.1.109:5100")
|
||||
|
||||
|
||||
async def translate_to_russian(text: str) -> str:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
r = await client.post(
|
||||
f"{LIBRETRANSLATE_URL}/translate",
|
||||
json={"q": text, "source": "auto", "target": "ru", "format": "text"},
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()["translatedText"]
|
||||
Reference in New Issue
Block a user