324 lines
12 KiB
Python
324 lines
12 KiB
Python
import json
|
|
import base64
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
import aiosqlite
|
|
from database.db import DB_PATH
|
|
|
|
|
|
def _normalize_alternate_greetings(inner: dict) -> list[str]:
    """Return the card's alternate greetings as a clean, de-duplicated list.

    Reads ``inner["alternate_greetings"]``; a missing/empty value or anything
    other than a list yields ``[]``. Each entry is converted with ``str`` and
    stripped. ``None`` entries, blank strings and repeats are dropped, and the
    first-seen order is preserved.
    """
    raw = inner.get("alternate_greetings") or []
    if not isinstance(raw, list):
        return []

    seen: set[str] = set()  # O(1) membership instead of scanning the output list
    out: list[str] = []
    for item in raw:
        if item is None:
            # A JSON null is not a greeting; str(None) would yield "None".
            continue
        text = str(item).strip()
        if text and text not in seen:
            seen.add(text)
            out.append(text)
    return out
|
|
|
|
|
|
def parse_card_v2(data: dict, card_id: str | None = None) -> dict:
|
|
inner = data.get("data", data)
|
|
if isinstance(inner, str):
|
|
inner = json.loads(inner)
|
|
|
|
book = inner.get("character_book") or {}
|
|
entries = book.get("entries", [])
|
|
if isinstance(entries, dict):
|
|
entries = list(entries.values())
|
|
|
|
alternates = _normalize_alternate_greetings(inner)
|
|
cid = card_id or (
|
|
inner.get("name", "imported").lower().replace(" ", "_")[:48]
|
|
+ "_"
|
|
+ uuid.uuid4().hex[:8]
|
|
)
|
|
|
|
return {
|
|
"card_id": cid,
|
|
"name": inner.get("name", "Character"),
|
|
"description": inner.get("description", ""),
|
|
"personality": inner.get("personality", ""),
|
|
"scenario": inner.get("scenario", ""),
|
|
"first_mes": inner.get("first_mes", ""),
|
|
"mes_example": inner.get("mes_example", ""),
|
|
"appearance_tags": _extract_appearance(inner),
|
|
"appearance_prose": "",
|
|
"lorebook_json": json.dumps(entries, ensure_ascii=False),
|
|
"alternate_greetings": alternates,
|
|
"alternate_greetings_json": json.dumps(alternates, ensure_ascii=False),
|
|
"raw_json": json.dumps(data if "data" in data else {"data": inner}, ensure_ascii=False),
|
|
}
|
|
|
|
|
|
def parse_card_bytes(content: bytes, filename: str) -> dict:
    """Parse an uploaded character card, either a PNG with metadata or JSON.

    A PNG is recognized by a ``.png`` extension or by the PNG signature, so a
    mislabelled upload still parses. Anything else is decoded as UTF-8 JSON;
    a leading byte-order mark is tolerated.

    Returns:
        The parsed card. PNG cards also carry the original file bytes under
        ``"_png_bytes"`` so the caller can store the image.

    Raises:
        ValueError: for a PNG without card metadata, or for undecodable or
            invalid JSON (``UnicodeDecodeError`` and ``json.JSONDecodeError``
            are ``ValueError`` subclasses).
    """
    is_png = (filename or "").lower().endswith(".png") or content.startswith(b"\x89PNG")
    if is_png:
        card = parse_png_card(content)
        if not card:
            raise ValueError("PNG does not contain character card metadata")
        card["_png_bytes"] = content
        return card

    # "utf-8-sig" drops a BOM if present; json.loads rejects a BOM-prefixed str.
    return parse_card_v2(json.loads(content.decode("utf-8-sig")))
|
|
|
|
|
|
def _extract_appearance(inner: dict) -> str:
    """Extract booru-style appearance tags from a card's description.

    Scans ``inner["description"]`` case-insensitively for visual keywords
    (hair/eye/skin words, body types, ears/tail/horns/wings, clothing,
    creature words and colours). Returns at most 20 unique matches,
    lower-cased, in first-seen order, joined with ``", "``. A missing,
    ``None`` or non-string description yields ``""``.
    """
    import re

    desc = inner.get("description")
    if not isinstance(desc, str) or not desc:
        return ""

    pattern = re.compile(
        r'\b(?:'
        r'\w*hair|hair\w*|\w*eyes|eye\w*|\w*skin|skin\w*'
        r'|tall|short|slim|curvy|muscular|petite'
        r'|ears?|tail|horns?|wings?|cloak|dress|outfit|uniform|armor'
        r'|wolf\w*|cat\w*|fox\w*|elf\w*|demon\w*|angel\w*'
        r'|silver|blonde|black|white|red|blue|green|purple|pink|brown|golden'
        r')\b',
        re.IGNORECASE,
    )

    tags: list[str] = []
    for match in pattern.finditer(desc):
        tag = match.group(0).lower()
        if tag not in tags:
            tags.append(tag)
            if len(tags) == 20:  # cap reached; no need to scan further
                break
    return ", ".join(tags)
|
|
|
|
|
|
def parse_png_card(file_bytes: bytes) -> dict | None:
|
|
if not file_bytes.startswith(b"\x89PNG"):
|
|
return None
|
|
idx = 8 # skip PNG file signature
|
|
while idx < len(file_bytes) - 12:
|
|
length = int.from_bytes(file_bytes[idx : idx + 4], "big")
|
|
chunk_type = file_bytes[idx + 4 : idx + 8]
|
|
chunk_data = file_bytes[idx + 8 : idx + 8 + length]
|
|
if chunk_type == b"tEXt":
|
|
try:
|
|
key, _, val = chunk_data.partition(b"\x00")
|
|
if key in (b"chara", b"ccv3"):
|
|
decoded = base64.b64decode(val).decode("utf-8")
|
|
return parse_card_v2(json.loads(decoded))
|
|
except Exception:
|
|
pass
|
|
elif chunk_type == b"iTXt":
|
|
try:
|
|
# iTXt: keyword \x00 compression_flag \x00 compression_method \x00 language \x00 translated_keyword \x00 text
|
|
key, _, rest = chunk_data.partition(b"\x00")
|
|
if key in (b"chara", b"ccv3"):
|
|
# skip compression_flag, compression_method, language tag, translated keyword
|
|
text = rest[2:].split(b"\x00", 2)[-1].decode("utf-8")
|
|
# text may be base64 or raw JSON
|
|
try:
|
|
return parse_card_v2(json.loads(base64.b64decode(text).decode("utf-8")))
|
|
except Exception:
|
|
return parse_card_v2(json.loads(text))
|
|
except Exception:
|
|
pass
|
|
idx += 12 + length
|
|
return None
|
|
|
|
|
|
def build_system_prompt(card: dict) -> str:
    """Build the LLM system prompt for role-playing *card*.

    Sections are joined with blank lines: an opening "You are <name>" line,
    then Description / Personality / Scenario / Example dialogue (each only
    when its value is present and not blank), then a closing instruction.
    Missing keys and ``None`` values are skipped rather than rendered as the
    text "None". A missing or empty name falls back to ``"Character"``.
    """

    def has_text(value) -> bool:
        return value is not None and bool(str(value).strip())

    name = card.get("name") or "Character"
    parts = [f"You are {name}. Stay in character."]

    for label, key in (
        ("Description", "description"),
        ("Personality", "personality"),
        ("Scenario", "scenario"),
    ):
        value = card.get(key)
        if has_text(value):
            parts.append(f"{label}: {value}")

    example = card.get("mes_example")
    if has_text(example):
        parts.append(f"Example dialogue:\n{example}")

    parts.append("Reply only as the character. Do not add image tags.")
    return "\n\n".join(parts)
|
|
|
|
|
|
async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0.8) -> dict:
    """Insert *card* into the ``characters`` table.

    Alternate greetings are stored as JSON text. ``card["alternate_greetings_json"]``
    is used when present (not None). Otherwise ``card["alternate_greetings"]``
    is serialized if it is a list, else ``"[]"`` is stored.

    NOTE(review): this is a plain INSERT; saving a ``card_id`` that already
    exists presumably fails on a uniqueness constraint — confirm against the
    table schema.

    Returns:
        A copy of *card* plus ``lora_name``, ``lora_weight`` and the
        ``alternate_greetings_json`` value that was actually stored.
    """
    alt_json = card.get("alternate_greetings_json")
    if alt_json is None:
        alts = card.get("alternate_greetings") or []
        alt_json = json.dumps(alts, ensure_ascii=False) if isinstance(alts, list) else "[]"

    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            """INSERT INTO characters
               (card_id, name, description, personality, scenario, first_mes, mes_example,
                raw_json, lora_name, lora_weight, appearance_tags, appearance_prose, lorebook_json, avatar_path,
                alternate_greetings_json)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                card["card_id"],
                card["name"],
                card["description"],
                card["personality"],
                card["scenario"],
                card["first_mes"],
                card["mes_example"],
                card["raw_json"],
                lora_name,
                lora_weight,
                card["appearance_tags"],
                card.get("appearance_prose", ""),
                card["lorebook_json"],
                card.get("avatar_path", ""),
                # Use the resolved value; the list-only fallback above was
                # previously computed and then ignored.
                alt_json,
            ),
        )
        await db.commit()

    return {
        **card,
        "lora_name": lora_name,
        "lora_weight": lora_weight,
        "alternate_greetings_json": alt_json,
    }
|
|
|
|
|
|
async def get_character(card_id: str) -> dict | None:
    """Load a single character by ``card_id``.

    Returns the full row converted with ``card_to_api`` (which adds a decoded
    ``alternate_greetings`` list), or ``None`` when no row matches.
    """
    query = "SELECT * FROM characters WHERE card_id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(query, (card_id,)) as cursor:
            found = await cursor.fetchone()
        if not found:
            return None
        return card_to_api(dict(found))
|
|
|
|
|
|
async def list_characters() -> list:
    """Return summary rows for every character, ordered by name.

    Each item is a plain dict with ``card_id``, ``name``, ``description`` and
    ``lora_name``.
    """
    sql = "SELECT card_id, name, description, lora_name FROM characters ORDER BY name"
    async with aiosqlite.connect(DB_PATH) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(sql) as cursor:
            records = await cursor.fetchall()
        return list(map(dict, records))
|
|
|
|
|
|
async def delete_character(card_id: str) -> bool:
    """Delete the character identified by ``card_id``.

    Returns ``True`` when at least one row was removed, ``False`` otherwise.
    """
    async with aiosqlite.connect(DB_PATH) as conn:
        cursor = await conn.execute("DELETE FROM characters WHERE card_id = ?", (card_id,))
        await conn.commit()
        removed = cursor.rowcount
    return removed > 0
|
|
|
|
|
|
async def update_appearance_tags(card_id: str, appearance_tags: str):
    """Replace a character's appearance tags and clear its appearance prose."""
    params = (appearance_tags, "", card_id)
    statement = "UPDATE characters SET appearance_tags = ?, appearance_prose = ? WHERE card_id = ?"
    async with aiosqlite.connect(DB_PATH) as conn:
        await conn.execute(statement, params)
        await conn.commit()
|
|
|
|
|
|
def card_to_api(card: dict) -> dict:
    """Return a shallow copy of *card* whose ``alternate_greetings`` is always a list.

    An existing ``alternate_greetings`` value takes precedence. Otherwise the
    list is decoded from ``alternate_greetings_json`` (a missing or empty
    value counts as ``"[]"``). Undecodable JSON, or any decoded or existing
    value that is not a list, becomes ``[]``. *card* itself is not modified.
    """
    greetings = card.get("alternate_greetings")
    if greetings is None:
        encoded = card.get("alternate_greetings_json") or "[]"
        try:
            greetings = json.loads(encoded)
        except Exception:
            greetings = []

    result = dict(card)
    result["alternate_greetings"] = greetings if isinstance(greetings, list) else []
    return result
|
|
|
|
|
|
async def preview_card_file(content: bytes, filename: str) -> dict:
    """Parse an uploaded card without saving it and return an API-shaped preview.

    The raw PNG bytes are removed from the result. Two extra keys are added:
    ``is_png`` (whether the parse produced non-empty PNG bytes) and
    ``alternate_count`` (number of alternate greetings).
    """
    parsed = parse_card_bytes(content, filename)
    had_png = bool(parsed.pop("_png_bytes", None))
    preview = card_to_api(parsed)
    greetings = preview.get("alternate_greetings") or []
    preview.update(is_png=had_png, alternate_count=len(greetings))
    return preview
|
|
|
|
|
|
async def update_character(card_id: str, fields: dict) -> bool:
    """Update allow-listed columns of one character.

    Keys of *fields* outside the allow-list are ignored. Only allow-listed
    column names are interpolated into the SQL text; all values are bound as
    parameters.

    Returns:
        ``False`` if no applicable field was supplied or no row matched
        *card_id*; ``True`` otherwise.
    """
    editable = frozenset((
        "name", "description", "personality", "scenario", "first_mes",
        "mes_example", "appearance_tags", "appearance_prose", "lora_name",
        "lora_weight", "avatar_path", "alternate_greetings_json",
    ))
    changes = [(column, value) for column, value in fields.items() if column in editable]
    if not changes:
        return False

    assignments = ", ".join(f"{column} = ?" for column, _ in changes)
    params = [value for _, value in changes]
    params.append(card_id)

    async with aiosqlite.connect(DB_PATH) as conn:
        cursor = await conn.execute(
            f"UPDATE characters SET {assignments} WHERE card_id = ?",
            tuple(params),
        )
        await conn.commit()
        return cursor.rowcount > 0
|
|
|
|
|
|
def _apply_card_overrides(card: dict, overrides: dict) -> None:
    """Copy non-None override values onto *card* in place (helper for import_card_file)."""
    for key in (
        "name", "description", "personality", "scenario", "first_mes",
        "mes_example", "appearance_tags", "lorebook_json",
    ):
        if overrides.get(key) is not None:
            card[key] = overrides[key]

    # Alternate greetings: the JSON form takes precedence over the list form.
    raw = overrides.get("alternate_greetings_json")
    if raw is None:
        raw = overrides.get("alternate_greetings")
    if raw is None:
        return
    if isinstance(raw, (str, bytes, bytearray)):
        try:
            raw = json.loads(raw)
        except ValueError:
            raw = []
    # Normalize exactly like greetings read from a card file, and update BOTH
    # representations; card_to_api prefers the list, so leaving the parsed
    # list in place would hide the override in the response.
    alts = _normalize_alternate_greetings({"alternate_greetings": raw})
    card["alternate_greetings"] = alts
    card["alternate_greetings_json"] = json.dumps(alts, ensure_ascii=False)


async def _sync_card_persona(saved: dict, lora_name: str, lora_weight: float) -> None:
    """Create or patch the persona ``card_<card_id>`` mirroring *saved* (helper for import_card_file)."""
    # Imported lazily as in the original code; presumably to avoid an import
    # cycle with services.personas — TODO confirm.
    from services.personas import create_persona, get_persona, patch_persona

    persona_id = f"card_{saved['card_id']}"
    existing = await get_persona(persona_id)
    persona_fields = {
        "name": saved["name"],
        "emoji": "🎭",
        "description": (saved["description"] or "")[:80] or "Character card",
        "prompt": build_system_prompt(saved),
        "sd_enabled": True,
        "lora_name": lora_name,
        "lora_weight": lora_weight,
        "appearance_tags": saved.get("appearance_tags", ""),
        "appearance_prose": saved.get("appearance_prose", ""),
        "avatar_path": saved.get("avatar_path", ""),
        "personality": saved.get("personality", ""),
        "scenario": saved.get("scenario", ""),
        "first_mes": saved.get("first_mes", ""),
        "mes_example": saved.get("mes_example", ""),
        "lorebook_json": saved.get("lorebook_json", "[]"),
        "alternate_greetings_json": saved.get("alternate_greetings_json", "[]"),
    }
    if not existing:
        await create_persona(persona_id=persona_id, **persona_fields)
    else:
        await patch_persona(persona_id, persona_fields)


async def import_card_file(
    content: bytes,
    filename: str,
    lora_name: str = "",
    lora_weight: float = 0.8,
    overrides: dict | None = None,
    card_id: str | None = None,
) -> dict:
    """Parse, store and register an uploaded character card.

    Steps:
      1. parse the upload (PNG or JSON);
      2. apply *card_id* and any non-None *overrides*;
      3. write embedded PNG bytes as the avatar;
      4. insert the character row;
      5. create or patch the companion persona ``card_<card_id>``.

    Alternate-greeting overrides may be given as ``alternate_greetings_json``
    (preferred) or ``alternate_greetings`` (a list or JSON text). Invalid
    JSON becomes an empty list.

    Returns:
        The saved card in API shape (see ``card_to_api``).
    """
    card = parse_card_bytes(content, filename)
    png_bytes = card.pop("_png_bytes", None)

    if card_id:
        card["card_id"] = card_id

    if overrides:
        _apply_card_overrides(card, overrides)

    if png_bytes:
        card["avatar_path"] = _save_avatar_bytes(png_bytes, f"card_{card['card_id']}")

    saved = await save_character(card, lora_name=lora_name, lora_weight=lora_weight)
    await _sync_card_persona(saved, lora_name, lora_weight)
    return card_to_api(saved)
|
|
|
|
|
|
def _save_avatar_bytes(png_bytes: bytes, prefix: str) -> str:
    """Write *png_bytes* under ``static/avatars`` and return ``avatars/<file name>``.

    The file name is ``<prefix>_<8 random hex digits>.png``. In *prefix*,
    every character other than a letter, digit, ``-`` or ``_`` is replaced by
    ``_``. A card id containing path separators therefore cannot point at a
    missing sub-directory or outside the avatars directory.

    NOTE(review): ``static/avatars`` is resolved against the process working
    directory.
    """
    safe_prefix = "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in prefix)
    avatars_dir = Path("static/avatars")
    avatars_dir.mkdir(parents=True, exist_ok=True)
    fname = f"{safe_prefix}_{uuid.uuid4().hex[:8]}.png"
    (avatars_dir / fname).write_bytes(png_bytes)
    return f"avatars/{fname}"
|