251 lines
8.4 KiB
Python
251 lines
8.4 KiB
Python
from typing import Any
|
|
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
from app.character.service import CharacterService
|
|
|
|
from app.config import get_settings
|
|
|
|
from app.homelab.anima_prompt import AnimaPromptBundle, build_character_image_prompt, build_scene_tags_prompt
|
|
|
|
from app.homelab.comfyui import ComfyUIClient
|
|
|
|
from app.homelab.scene_tags import extract_scene_tags, looks_like_booru_tags
|
|
|
|
from app.integrations.rp_chat import RpChatClient
|
|
|
|
|
|
|
|
|
|
|
|
def _card_image_settings(db: Session, user_id: int) -> dict[str, Any]:
    """Return the ``data`` section of the user's character card.

    Falls back to ``{}`` when the card has no ``data`` key.
    """
    card = CharacterService(db, user_id).get_card()
    return card.get("data", {})
|
|
|
|
|
|
|
|
|
|
|
|
def _session_messages(db: Session, session_id: int | None, limit: int = 8) -> list[dict[str, str]]:
    """Return up to ``limit`` recent user/assistant messages of a chat session.

    The newest ``limit`` messages with role ``"user"`` or ``"assistant"`` are
    fetched and returned oldest-first as ``{"role": ..., "content": ...}``
    dicts with stripped content. Messages whose content is ``None``, empty or
    whitespace-only are dropped after the limit is applied, so fewer than
    ``limit`` items may be returned.

    Returns ``[]`` without querying when ``session_id`` is falsy.
    """
    if not session_id:
        return []

    # Function-scope imports kept as-is (presumably to avoid an import cycle
    # with the models module — TODO confirm).
    from sqlalchemy import select

    from app.db.models import Message

    newest_first = db.scalars(
        select(Message)
        .where(
            Message.session_id == session_id,
            Message.role.in_(("user", "assistant")),
        )
        .order_by(Message.created_at.desc())
        .limit(limit)
    ).all()

    # ``content`` may be None, so normalise it once and filter on the
    # normalised text; calling ``.strip()`` on the raw value would raise.
    return [
        {"role": message.role, "content": text}
        for message in reversed(newest_first)
        if (text := (message.content or "").strip())
    ]
|
|
|
|
|
|
|
|
|
|
|
|
def _last_user_message(messages: list[dict[str, str]]) -> str:
|
|
|
|
for msg in reversed(messages):
|
|
|
|
if msg.get("role") == "user" and (msg.get("content") or "").strip():
|
|
|
|
return str(msg["content"]).strip()
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
|
|
def _append_lora(positive: str, lora_name: str, lora_weight: float) -> str:
|
|
|
|
if not lora_name or f"<lora:{lora_name}" in positive:
|
|
|
|
return positive
|
|
|
|
return f"{positive} <lora:{lora_name}:{lora_weight}>"
|
|
|
|
|
|
|
|
|
|
|
|
async def _generate_from_bundle(
    bundle: AnimaPromptBundle,
    *,
    backend: str,
    persona_id: str = "",
    prompt_mode: str = "direct",
    tag_source: str = "",
) -> dict[str, Any]:
    """Render a prepared prompt bundle on the selected backend.

    ``backend == "rp_chat"``: generate through ``RpChatClient`` and save the
    image locally. On success, a new dict is returned with ``url``,
    ``filename``, ``prompt``, ``negative_prompt``, ``backend``, ``persona_id``,
    ``prompt_mode`` and ``tag_source``.

    Any other backend: render through ``ComfyUIClient``. On success, the
    client's result is annotated in place with ``backend``, ``prompt_mode``,
    ``negative_prompt`` and ``tag_source``.

    If a step reports ``ok`` falsy, that step's result dict is returned as-is.
    """
    if backend != "rp_chat":
        comfy_result = await ComfyUIClient().generate_image(
            bundle.positive,
            negative_prompt=bundle.negative,
        )
        if comfy_result.get("ok"):
            comfy_result.update(
                backend="comfyui_local",
                prompt_mode=prompt_mode,
                negative_prompt=bundle.negative,
                tag_source=tag_source,
            )
        return comfy_result

    rp_client = RpChatClient()

    generated = await rp_client.generate(bundle.positive, bundle.negative)
    if not generated.get("ok"):
        return generated

    stored = await rp_client.save_image_locally(generated["image_path"])
    if not stored.get("ok"):
        return stored

    return {
        "ok": True,
        "url": stored["url"],
        "filename": stored["filename"],
        "prompt": bundle.positive,
        "negative_prompt": bundle.negative,
        "backend": "rp_chat",
        "persona_id": persona_id,
        "prompt_mode": prompt_mode,
        "tag_source": tag_source,
    }
|
|
|
|
|
|
|
|
|
|
|
|
async def _build_contextual_bundle(
    appearance: str,
    *,
    request: str,
    messages: list[dict[str, str]],
    lora_name: str,
    lora_weight: float,
) -> tuple[AnimaPromptBundle, str]:
    """Turn a free-text request plus chat context into a character prompt bundle.

    Scene tags (action / outfit / environment) are extracted from ``request``
    and ``messages`` with ``appearance`` as a hint. They are combined with
    the appearance tags and LoRA settings into a bundle.

    Returns:
        ``(bundle, source)``, where ``source`` is the extractor's ``"source"``
        value as a string (``""`` when missing or falsy).
    """
    extracted = await extract_scene_tags(request, messages, appearance_tags=appearance)

    # Missing tag groups default to an empty string.
    scene_tags = {
        group: extracted.get(group, "")
        for group in ("action_tags", "outfit_tags", "environment_tags")
    }
    prompt_bundle = build_character_image_prompt(
        appearance,
        **scene_tags,
        lora_name=lora_name,
        lora_weight=lora_weight,
    )

    source_label = extracted.get("source") or ""
    return prompt_bundle, str(source_label)
|
|
|
|
|
|
|
|
|
|
|
|
async def generate_image(
    db: Session,
    *,
    user_id: int,
    session_id: int | None = None,
    draw_self: bool = False,
    scene_description: str = "",
) -> dict[str, Any]:
    """Generate an image for the user's character card.

    Two request shapes are accepted.

    ``draw_self=True``: draw the character from its ``appearance_tags``. The
    scene request is ``scene_description``, or failing that the latest user
    message of the session, or ``"portrait"``. Scene tags are extracted from
    the request.

    ``scene_description`` given (and ``draw_self`` false):

    * booru-style tag strings are used literally;
    * free text is converted to tags when the card has ``appearance_tags``;
    * otherwise the text is handed to the RP-chat prompt writer, or sent
      as-is to ComfyUI when RP-chat is disabled.

    Validation failures detected here are returned as
    ``{"ok": False, "error": ...}``. Otherwise the result of the chosen
    generation path is returned.
    """
    card = _card_image_settings(db, user_id)
    settings = get_settings()

    if not card.get("sd_enabled", True):
        return {"ok": False, "error": "Генерация изображений отключена в настройках персонажа"}

    scene = scene_description.strip()
    if not (draw_self or scene):
        return {"ok": False, "error": "Нужен draw_self=true или scene_description"}

    appearance = (card.get("appearance_tags") or "").strip()
    lora_name = (card.get("lora_name") or "").strip()
    lora_weight = float(card.get("lora_weight") or 0.8)
    persona_id = (card.get("rp_persona_id") or "").strip() or "default"
    backend = "rp_chat" if settings.rp_chat_enabled else "comfyui_local"
    history = _session_messages(db, session_id)

    async def _render_contextual(request: str) -> dict[str, Any]:
        # Shared path: tags extracted from request + chat history, then rendered.
        bundle, tag_source = await _build_contextual_bundle(
            appearance,
            request=request,
            messages=history,
            lora_name=lora_name,
            lora_weight=lora_weight,
        )
        return await _generate_from_bundle(
            bundle,
            backend=backend,
            persona_id=persona_id,
            prompt_mode="context_tags",
            tag_source=tag_source,
        )

    if draw_self:
        if not appearance:
            return {
                "ok": False,
                "error": "Заполни appearance_tags в настройках персонажа для «нарисуй себя»",
            }
        return await _render_contextual(scene or _last_user_message(history) or "portrait")

    # From here on ``draw_self`` is false, so the validation above guarantees
    # ``scene`` is non-empty.
    if looks_like_booru_tags(scene):
        literal_bundle = build_scene_tags_prompt(
            scene,
            appearance,
            lora_name=lora_name,
            lora_weight=lora_weight,
        )
        return await _generate_from_bundle(
            literal_bundle,
            backend=backend,
            persona_id=persona_id,
            prompt_mode="booru_literal",
            tag_source="booru_literal",
        )

    if appearance:
        return await _render_contextual(scene)

    # No appearance tags: there is no override to pass and nothing to prefix.
    if settings.rp_chat_enabled:
        return await _generate_via_rp_chat(
            card,
            [*history, {"role": "user", "content": scene}],
            appearance_override=None,
        )
    return await ComfyUIClient().generate_image(scene)
|
|
|
|
|
|
|
|
|
|
|
|
async def _generate_via_rp_chat(
    card: dict[str, Any],
    messages: list[dict[str, str]],
    appearance_override: str | None,
) -> dict[str, Any]:
    """Have the RP-chat service write an SD prompt from chat messages, then render it.

    Steps:

    1. Request a prompt for the card's ``rp_persona_id`` (default
       ``"default"``). The appearance override is ``appearance_override``,
       else the card's ``appearance_tags``, else ``None``.
    2. Take ``hybrid_positive`` (falling back to ``tag_positive``) and
       ``negative`` from the response.
    3. Append the card's LoRA tag when ``lora_name`` is set; the weight
       falls back to 0.8 when ``lora_weight`` is missing or falsy.
    4. Generate the image and save it locally.

    If any step reports ``ok`` falsy, its result is returned unchanged. An
    error dict including the raw response is returned when no positive
    prompt comes back.
    """
    rp_client = RpChatClient()
    persona_id = (card.get("rp_persona_id") or "").strip() or "default"
    override = appearance_override or (card.get("appearance_tags") or "").strip() or None

    prompt_result = await rp_client.sd_prompt(
        persona_id,
        messages,
        appearance_override=override,
    )
    if not prompt_result.get("ok"):
        return prompt_result

    positive_raw = prompt_result.get("hybrid_positive") or prompt_result.get("tag_positive") or ""
    positive = positive_raw.strip()
    negative = (prompt_result.get("negative") or "").strip()
    if not positive:
        return {"ok": False, "error": "RP-чат не вернул промпт", "raw": prompt_result}

    lora = (card.get("lora_name") or "").strip()
    if lora:
        positive = _append_lora(positive, lora, float(card.get("lora_weight") or 0.8))

    generated = await rp_client.generate(positive, negative)
    if not generated.get("ok"):
        return generated

    stored = await rp_client.save_image_locally(generated["image_path"])
    if not stored.get("ok"):
        return stored

    return {
        "ok": True,
        "url": stored["url"],
        "filename": stored["filename"],
        "prompt": positive,
        "negative_prompt": negative,
        "backend": "rp_chat",
        "persona_id": persona_id,
        "prompt_mode": "llm",
    }
|
|
|
|
|