283 lines
11 KiB
Python
283 lines
11 KiB
Python
from typing import Optional
|
||
import aiosqlite
|
||
from database.db import DB_PATH
|
||
|
||
DEFAULT_PERSONAS = {
|
||
"default": {
|
||
"name": "AI Ассистент",
|
||
"emoji": "🤖",
|
||
"description": "Универсальный помощник",
|
||
"prompt": "Ты — полезный AI ассистент. Отвечай чётко и по делу.",
|
||
"sd_enabled": False,
|
||
},
|
||
"rpg_master": {
|
||
"name": "Мастер RPG",
|
||
"emoji": "🧙",
|
||
"description": "Ведёт ролевые игры, создаёт атмосферу",
|
||
"prompt": """Ты — опытный Мастер ролевых игр.
|
||
Создавай живые описания, веди нарратив, реагируй на действия игрока.
|
||
Мир детальный, персонажи запоминающиеся.
|
||
Отвечай только текстом сюжета — без тегов изображений.""",
|
||
"sd_enabled": True,
|
||
},
|
||
"villain": {
|
||
"name": "Злодей",
|
||
"emoji": "😈",
|
||
"description": "Харизматичный антагонист",
|
||
"prompt": """Ты — харизматичный злодей с грандиозными планами.
|
||
Говоришь театрально, с сарказмом и превосходством.
|
||
Никогда не выходишь из роли. Называешь собеседника 'герой' с иронией.""",
|
||
"sd_enabled": False,
|
||
},
|
||
"scientist": {
|
||
"name": "Учёный",
|
||
"emoji": "🔬",
|
||
"description": "Объясняет сложное простыми словами",
|
||
"prompt": """Ты — увлечённый учёный. Объясняешь любые темы
|
||
через факты, аналогии и примеры. Любишь уточнять детали.
|
||
Иногда уходишь в интересные отступления.""",
|
||
"sd_enabled": False,
|
||
},
|
||
"samurai": {
|
||
"name": "Самурай",
|
||
"emoji": "⚔️",
|
||
"description": "Мудрый воин феодальной Японии",
|
||
"prompt": """Ты — самурай феодальной Японии.
|
||
Говоришь кратко, мудро, с достоинством.
|
||
Используешь метафоры природы и войны.
|
||
Чтишь кодекс бусидо.""",
|
||
"sd_enabled": True,
|
||
"appearance_tags": "samurai armor, katana, feudal japan",
|
||
},
|
||
}
|
||
|
||
|
||
def _row_to_persona(row: dict) -> dict:
|
||
return {
|
||
"name": row["name"],
|
||
"emoji": row["emoji"],
|
||
"description": row["description"],
|
||
"prompt": row["prompt"],
|
||
"custom": bool(row["custom"]),
|
||
"sd_enabled": bool(row["sd_enabled"]),
|
||
"lora_name": row["lora_name"] or "",
|
||
"lora_weight": row["lora_weight"] if row["lora_weight"] is not None else 0.8,
|
||
"appearance_tags": row["appearance_tags"] or "",
|
||
"appearance_prose": row.get("appearance_prose", "") or "",
|
||
"personality": row.get("personality", "") or "",
|
||
"scenario": row.get("scenario", "") or "",
|
||
"first_mes": row.get("first_mes", "") or "",
|
||
"mes_example": row.get("mes_example", "") or "",
|
||
"lorebook_json": row.get("lorebook_json", "[]") or "[]",
|
||
"avatar_path": row.get("avatar_path", "") or "",
|
||
"alternate_greetings_json": row.get("alternate_greetings_json", "[]") or "[]",
|
||
}
|
||
|
||
|
||
def build_persona_prompt(data: dict) -> str:
|
||
parts = [
|
||
f"You are {data.get('name', '').strip()}." if data.get("name") else "",
|
||
f"Description: {data.get('description', '').strip()}",
|
||
f"Personality: {data.get('personality', '').strip()}",
|
||
f"Scenario: {data.get('scenario', '').strip()}",
|
||
]
|
||
ex = (data.get("mes_example") or "").strip()
|
||
if ex:
|
||
parts.append(f"Example dialogue:\n{ex}")
|
||
parts.append("Stay in character. Reply as the character. Do not add image tags.")
|
||
return "\n\n".join(p for p in parts if p and p.split(": ", 1)[-1].strip())
|
||
|
||
|
||
async def get_all_personas() -> dict:
|
||
async with aiosqlite.connect(DB_PATH) as db:
|
||
db.row_factory = aiosqlite.Row
|
||
async with db.execute("SELECT * FROM personas ORDER BY custom ASC, persona_id ASC") as cur:
|
||
rows = await cur.fetchall()
|
||
return {r["persona_id"]: _row_to_persona(dict(r)) for r in rows}
|
||
|
||
|
||
async def get_persona(persona_id: str) -> Optional[dict]:
|
||
async with aiosqlite.connect(DB_PATH) as db:
|
||
db.row_factory = aiosqlite.Row
|
||
async with db.execute(
|
||
"SELECT * FROM personas WHERE persona_id = ?", (persona_id,)
|
||
) as cur:
|
||
row = await cur.fetchone()
|
||
if not row:
|
||
return None
|
||
return _row_to_persona(dict(row))
|
||
|
||
|
||
async def create_persona(
|
||
persona_id: str,
|
||
name: str,
|
||
emoji: str,
|
||
description: str,
|
||
prompt: str,
|
||
sd_enabled: bool = False,
|
||
lora_name: str = "",
|
||
lora_weight: float = 0.8,
|
||
appearance_tags: str = "",
|
||
appearance_prose: str = "",
|
||
personality: str = "",
|
||
scenario: str = "",
|
||
first_mes: str = "",
|
||
mes_example: str = "",
|
||
lorebook_json: str = "[]",
|
||
avatar_path: str = "",
|
||
alternate_greetings_json: str = "[]",
|
||
) -> dict:
|
||
final_prompt = prompt.strip() or build_persona_prompt(
|
||
{
|
||
"name": name,
|
||
"description": description,
|
||
"personality": personality,
|
||
"scenario": scenario,
|
||
"mes_example": mes_example,
|
||
}
|
||
)
|
||
async with aiosqlite.connect(DB_PATH) as db:
|
||
await db.execute(
|
||
"""INSERT INTO personas
|
||
(persona_id, name, emoji, description, prompt, custom,
|
||
sd_enabled, lora_name, lora_weight, appearance_tags, appearance_prose,
|
||
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
|
||
alternate_greetings_json)
|
||
VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||
(
|
||
persona_id, name, emoji, description, final_prompt,
|
||
1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags, appearance_prose,
|
||
personality, scenario, first_mes, mes_example, lorebook_json, avatar_path,
|
||
alternate_greetings_json,
|
||
),
|
||
)
|
||
await db.commit()
|
||
return {
|
||
"name": name,
|
||
"emoji": emoji,
|
||
"description": description,
|
||
"prompt": final_prompt,
|
||
"custom": True,
|
||
"sd_enabled": sd_enabled,
|
||
"lora_name": lora_name,
|
||
"lora_weight": lora_weight,
|
||
"appearance_tags": appearance_tags,
|
||
"appearance_prose": appearance_prose,
|
||
"personality": personality,
|
||
"scenario": scenario,
|
||
"first_mes": first_mes,
|
||
"mes_example": mes_example,
|
||
"lorebook_json": lorebook_json,
|
||
"avatar_path": avatar_path,
|
||
"alternate_greetings_json": alternate_greetings_json,
|
||
}
|
||
|
||
|
||
async def delete_persona(persona_id: str) -> bool:
|
||
async with aiosqlite.connect(DB_PATH) as db:
|
||
async with db.execute(
|
||
"SELECT custom FROM personas WHERE persona_id = ?", (persona_id,)
|
||
) as cur:
|
||
row = await cur.fetchone()
|
||
if not row or not row[0]:
|
||
return False
|
||
await db.execute("DELETE FROM personas WHERE persona_id = ?", (persona_id,))
|
||
await db.commit()
|
||
|
||
if persona_id.startswith("card_"):
|
||
from services.character_card import delete_character
|
||
await delete_character(persona_id[5:])
|
||
|
||
return True
|
||
|
||
|
||
async def update_persona_appearance(persona_id: str, appearance_tags: str):
|
||
async with aiosqlite.connect(DB_PATH) as db:
|
||
await db.execute(
|
||
"UPDATE personas SET appearance_tags = ? WHERE persona_id = ?",
|
||
(appearance_tags, persona_id),
|
||
)
|
||
await db.commit()
|
||
|
||
|
||
async def update_persona_lora(persona_id: str, lora_name: str | None, lora_weight: float | None):
|
||
fields, vals = [], []
|
||
if lora_name is not None:
|
||
fields.append("lora_name = ?"); vals.append(lora_name)
|
||
if lora_weight is not None:
|
||
fields.append("lora_weight = ?"); vals.append(lora_weight)
|
||
if not fields:
|
||
return
|
||
async with aiosqlite.connect(DB_PATH) as db:
|
||
await db.execute(f"UPDATE personas SET {', '.join(fields)} WHERE persona_id = ?", (*vals, persona_id))
|
||
await db.commit()
|
||
|
||
|
||
async def update_persona_prompt(persona_id: str, prompt: str):
|
||
async with aiosqlite.connect(DB_PATH) as db:
|
||
await db.execute("UPDATE personas SET prompt = ? WHERE persona_id = ?", (prompt, persona_id))
|
||
await db.commit()
|
||
|
||
|
||
async def patch_persona(persona_id: str, fields: dict) -> bool:
|
||
allowed = {
|
||
"name",
|
||
"emoji",
|
||
"description",
|
||
"prompt",
|
||
"sd_enabled",
|
||
"lora_name",
|
||
"lora_weight",
|
||
"appearance_tags",
|
||
"appearance_prose",
|
||
"personality",
|
||
"scenario",
|
||
"first_mes",
|
||
"mes_example",
|
||
"lorebook_json",
|
||
"avatar_path",
|
||
"alternate_greetings_json",
|
||
}
|
||
updates = {k: v for k, v in fields.items() if k in allowed}
|
||
if not updates:
|
||
return False
|
||
|
||
async with aiosqlite.connect(DB_PATH) as db:
|
||
db.row_factory = aiosqlite.Row
|
||
# disallow editing built-in personas
|
||
async with db.execute("SELECT custom FROM personas WHERE persona_id = ?", (persona_id,)) as cur:
|
||
row = await cur.fetchone()
|
||
if not row or not row[0]:
|
||
return False
|
||
|
||
# rebuild prompt if user didn't explicitly set it
|
||
raw_fields = {"name", "description", "personality", "scenario", "mes_example"}
|
||
if "prompt" not in updates and (raw_fields & updates.keys()):
|
||
async with db.execute("SELECT * FROM personas WHERE persona_id = ?", (persona_id,)) as cur:
|
||
existing = await cur.fetchone()
|
||
if existing:
|
||
merged = dict(existing)
|
||
merged.update(updates)
|
||
updates["prompt"] = build_persona_prompt(merged)
|
||
|
||
if "appearance_tags" in updates and "appearance_prose" not in updates:
|
||
tags = updates["appearance_tags"].strip()
|
||
if tags:
|
||
from services.llm import send_message
|
||
try:
|
||
prose = await send_message([
|
||
{"role": "system", "content": "Convert danbooru tags to natural English description. Output only the description, no markdown."},
|
||
{"role": "user", "content": f"Tags: {tags}"}
|
||
])
|
||
updates["appearance_prose"] = prose.strip()
|
||
except Exception:
|
||
pass
|
||
|
||
cols = ", ".join(f"{k} = ?" for k in updates)
|
||
cur2 = await db.execute(
|
||
f"UPDATE personas SET {cols} WHERE persona_id = ?",
|
||
(*updates.values(), persona_id),
|
||
)
|
||
await db.commit()
|
||
return cur2.rowcount > 0
|