44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
import json
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.character.card import DEFAULT_CARD, build_system_prompt, normalize_card
|
|
from app.db.models import CharacterCard
|
|
|
|
|
|
class CharacterService:
    """Read and persist the character card owned by a single user.

    Cards are stored as JSON text in ``CharacterCard.card_json`` and looked up
    by ``CharacterCard.user_id``. Cards are passed through ``normalize_card``
    both when read and when written.
    """

    def __init__(self, db: Session, user_id: int):
        """Bind the service to a database session and a user.

        Args:
            db: SQLAlchemy session used for all queries. ``save_card`` commits
                (or, on failure, rolls back) this session.
            user_id: Owner of the card being read or written.
        """
        self.db = db
        self.user_id = user_id

    def _get_row(self) -> "CharacterCard | None":
        """Return this user's ``CharacterCard`` row, or ``None`` if there is none."""
        return self.db.scalar(
            select(CharacterCard).where(CharacterCard.user_id == self.user_id).limit(1)
        )

    def get_card(self) -> dict[str, Any]:
        """Return the user's card, normalized.

        Falls back to the normalized ``DEFAULT_CARD`` when the user has no row,
        when the stored text is not valid JSON, or when it decodes to something
        other than a JSON object. An empty/NULL ``card_json`` is read as ``{}``.
        """
        row = self._get_row()
        if row is None:
            return normalize_card(DEFAULT_CARD)

        try:
            stored = json.loads(row.card_json or "{}")
        except json.JSONDecodeError:
            return normalize_card(DEFAULT_CARD)

        # Valid JSON is not necessarily an object ("null", a list, a bare
        # string). Treat such rows as corrupt, exactly like undecodable text,
        # rather than handing a non-dict to normalize_card.
        if not isinstance(stored, dict):
            return normalize_card(DEFAULT_CARD)
        return normalize_card(stored)

    def save_card(self, raw: dict[str, Any]) -> dict[str, Any]:
        """Normalize *raw*, store it as the user's card, and commit.

        Creates the user's row if it does not exist yet and stamps
        ``updated_at`` with the current UTC time. If anything fails while
        writing, the session is rolled back before the exception propagates,
        so ``self.db`` remains usable by the caller.

        Args:
            raw: Card data to normalize and store.

        Returns:
            The normalized card that was stored.
        """
        card = normalize_card(raw)
        try:
            row = self._get_row()
            if row is None:
                row = CharacterCard(user_id=self.user_id, card_json="{}")
                self.db.add(row)
                # Emit the INSERT now; a constraint error surfaces here,
                # inside the rollback guard below.
                self.db.flush()
            row.card_json = json.dumps(card, ensure_ascii=False)
            row.updated_at = datetime.now(timezone.utc)
            self.db.commit()
        except Exception:
            # A failed flush/commit leaves the session needing a rollback
            # before it can be used again; restore it, then propagate.
            self.db.rollback()
            raise
        return card

    def get_system_prompt(self) -> str:
        """Build the system prompt from the user's current (normalized) card."""
        return build_system_prompt(self.get_card())
|