commit e5c0df308f2d3372f3d102e89c0a9c90e7751132 Author: Grigo Date: Thu May 28 08:42:46 2026 +0300 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c4d9fa5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,43 @@ +# Python cache +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +*.egg-info/ +.eggs/ +dist/ +build/ + +# Virtual environments +.venv/ +venv/ +env/ + +# Environment & secrets +.env +.env.* +!.env.example + +# Generated images (ComfyUI / SD output) +static/images/ +data/images/ + +# Local database +data/ +*.db +*.sqlite3 + +# IDE / OS +.idea/ +.vscode/ +*.swp +*~ +.DS_Store +Thumbs.db + +# Logs +*.log diff --git a/Luna.png b/Luna.png new file mode 100644 index 0000000..b5466d0 Binary files /dev/null and b/Luna.png differ diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/Rin.png b/Rin.png new file mode 100644 index 0000000..08e4b70 Binary files /dev/null and b/Rin.png differ diff --git a/database/__init__.py b/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/db.py b/database/db.py new file mode 100644 index 0000000..c6a69ff --- /dev/null +++ b/database/db.py @@ -0,0 +1,70 @@ +import aiosqlite +import os + +DB_PATH = os.getenv("DB_PATH", "data/chat.db") + + +async def init_db(): + os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) + async with aiosqlite.connect(DB_PATH) as db: + await db.executescript(""" + CREATE TABLE IF NOT EXISTS sessions ( + session_id TEXT PRIMARY KEY, + persona_id TEXT DEFAULT 'default', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + title TEXT DEFAULT 'Новый чат' + ); + + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES sessions(session_id) + ); + + CREATE INDEX IF NOT EXISTS idx_messages_session + ON messages(session_id); + + CREATE TABLE IF NOT EXISTS personas ( + persona_id TEXT PRIMARY KEY, + name TEXT NOT NULL, + emoji TEXT DEFAULT '🤖', + description TEXT DEFAULT '', + prompt TEXT NOT NULL, + custom INTEGER DEFAULT 1, + sd_enabled INTEGER DEFAULT 0, + lora_name TEXT DEFAULT '', + lora_weight REAL DEFAULT 0.8, + appearance_tags TEXT DEFAULT '' + ); + + CREATE TABLE IF NOT EXISTS characters ( + card_id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT DEFAULT '', + personality TEXT DEFAULT '', + scenario TEXT DEFAULT '', + first_mes TEXT DEFAULT '', + mes_example TEXT DEFAULT '', + raw_json TEXT NOT NULL, + lora_name TEXT DEFAULT '', + lora_weight REAL DEFAULT 0.8, + appearance_tags TEXT DEFAULT '', + lorebook_json TEXT DEFAULT '[]', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """) + await _migrate_messages_columns(db) + await db.commit() + + +async def _migrate_messages_columns(db): + async with db.execute("PRAGMA table_info(messages)") as cur: + cols = {row[1] for row in await cur.fetchall()} + if "image_prompt" not in cols: + await db.execute("ALTER TABLE messages ADD COLUMN image_prompt TEXT") + if "image_path" not in cols: + await db.execute("ALTER TABLE messages ADD COLUMN image_path TEXT") diff --git a/libretranslate/docker-compose.yml b/libretranslate/docker-compose.yml new file mode 100644 index 0000000..57cd110 --- /dev/null +++ b/libretranslate/docker-compose.yml @@ -0,0 +1,13 @@ +services: + libretranslate: + image: libretranslate/libretranslate:latest + ports: + - "5100:5000" + environment: + - LT_LOAD_ONLY=en,ru,ja,zh,ko + volumes: + - lt-data:/home/libretranslate/.local + restart: unless-stopped + +volumes: + lt-data: diff --git a/main.py b/main.py new file mode 100644 index 0000000..802a95a --- /dev/null +++ b/main.py @@ -0,0 +1,39 @@ +import logging +from contextlib import asynccontextmanager +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles +from fastapi.responses import FileResponse +from routers import chat, personas, sessions, characters, images, translate +from database.db import init_db +from services.persona_seed import seed_default_personas + +logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await init_db() + await seed_default_personas() + yield + + +app = FastAPI(title="AI Chat Bot", lifespan=lifespan) + +app.include_router(chat.router) +app.include_router(personas.router) +app.include_router(sessions.router) +app.include_router(characters.router) +app.include_router(images.router) +app.include_router(translate.router) + +app.mount("/static", StaticFiles(directory="static"), name="static") + + +@app.get("/") +async def root(): + return FileResponse("static/index.html") + + +@app.get("/health") +async def health(): + return {"status": "ok"} diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/schemas.py b/models/schemas.py new file mode 100644 index 0000000..85066c5 --- /dev/null +++ b/models/schemas.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel +from typing import Optional + +class ChatRequest(BaseModel): + message: str + session_id: str + persona_id: Optional[str] = "default" + +class ChatResponse(BaseModel): + reply: str + session_id: str + image_prompt: Optional[str] = None + +class PersonaCreate(BaseModel): + persona_id: str + name: str + emoji: str = "🤖" + description: str = "" + prompt: str + sd_enabled: bool = False + lora_name: str = "" + lora_weight: float = 0.8 + appearance_tags: str = "" + +class PersonaResponse(BaseModel): + persona_id: str + name: str + emoji: str + description: str + prompt: str + custom: bool = False + sd_enabled: bool = False + lora_name: str = "" + lora_weight: float = 0.8 + appearance_tags: str = "" diff --git a/pull.sh b/pull.sh new file mode 100644 index 0000000..8fa0078 --- /dev/null +++ b/pull.sh @@ -0,0 +1,5 @@ +#!/bin/bash +rsync -avz -e "ssh -p 22022" --exclude='__pycache__' --exclude='*.pyc' --exclude='data/' \ + grigo@grigowashere.ru:/home/grigo/to_services/aiChatBot/ \ + /mnt/t/sources/aiChatBot/ +echo "✅ Скачано!" \ No newline at end of file diff --git a/push.sh b/push.sh new file mode 100644 index 0000000..ae43c6f --- /dev/null +++ b/push.sh @@ -0,0 +1,5 @@ +#!/bin/bash +rsync -avz -e "ssh -p 22022" --exclude='__pycache__' --exclude='*.pyc' --exclude='data/' \ + /mnt/t/sources/aiChatBot/ \ + grigo@grigowashere.ru:/home/grigo/to_services/aiChatBot/ +echo "✅ Залито на сервер!" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..37f6db5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,182 @@ +acme==4.0.0 +aiofiles==25.1.0 +aiosqlite==0.22.1 +annotated-types==0.7.0 +anyio==4.11.0 +argcomplete==3.6.3 +attrs==25.4.0 +autocommand==2.2.2 +Automat==25.4.16 +babel==2.17.0 +bcc==0.35.0 +bcrypt==5.0.0 +beautifulsoup4==4.14.3 +blinker==1.9.0 +boto3==1.40.72 +botocore==1.40.72 +Brotli==1.2.0 +certbot==4.0.0 +certbot-nginx==4.0.0 +certifi==2026.1.4 +chardet==5.2.0 +click==8.1.8 +command-not-found==0.3 +ConfigArgParse==1.7 +configobj==5.0.9 +constantly==23.10.4 +contourpy==1.3.3 +cryptography==46.0.5 +cssselect==1.4.0 +cycler==0.12.1 +dbus-python==1.4.0 +defusedxml==0.7.1 +distro==1.9.0 +distro-info==1.15 +dnspython==2.8.0 +docker==7.1.0 +email_validator==2.2.0 +fastapi==0.118.0 +fonttools==4.61.1 +ghp-import==2.1.0 +Glances==4.5.4 +h11==0.16.0 +html5lib-modern==1.2 +httpcore==1.0.9 +httplib2==0.22.0 +httptools==0.8.0 +httpx==0.28.1 +hyperlink==21.0.0 +idna==3.11 +incremental==24.7.2 +inflect==7.5.0 +influxdb==5.3.2 +iniconfig==2.1.0 +itsdangerous==2.2.0 +jaraco.context==6.0.1 +jaraco.functools==4.1.0 +jaraco.text==4.0.0 +Jinja2==3.1.6 +jmespath==1.0.1 +joblib==1.5.2 +josepy==2.2.0 +jsonpatch==1.32 +jsonpointer==2.4 +jsonschema==4.19.2 +jsonschema-specifications==2023.12.1 +kiwisolver==1.4.10rc0 +launchpadlib==2.1.0 +lazr.restfulclient==0.14.6 +lazr.uri==1.0.6 +libpass==1.9.3 +libvirt-python==12.0.0 +linkify-it-py==2.0.3 +livereload==2.7.1 +lunr==0.8.0 +lxml==6.0.2 +lz4==4.4.5+dfsg +Markdown==3.10.2 +markdown-it-py==3.0.0 +MarkupSafe==3.0.3 +matplotlib==3.10.7+dfsg1 +mdurl==0.1.2 +mergedeep==1.3.4 +mkdocs==1.6.1 +mkdocs-get-deps==0.2.0 +more-itertools==10.8.0 +mpmath==1.3.0 +msgpack==1.1.2 +munkres==1.1.4 +mutagen==1.47.0 +netaddr==1.3.0 +netifaces==0.11.0 +nltk==3.9.2 +numpy==2.3.5 +oauthlib==3.3.1 +olefile==0.47 +orjson==3.11.5 +packaging==26.0 +parsedatetime==2.6 +pathspec==1.0.4 +pexpect==4.9.0 +pillow==12.1.1 +pipx==1.8.0 +platformdirs==4.9.4 +pluggy==1.6.0 +psutil==7.2.2 +psycopg2==2.9.11 +ptyprocess==0.7.0 +pyasn1==0.6.3 +pyasn1_modules==0.4.1 +pyasyncore==1.0.2 +pydantic==2.12.5 +pydantic_core==2.41.5 +pyelftools==0.32 +Pygments==2.19.2 +PyGObject==3.56.2 +PyHamcrest==2.1.0 +pyicu==2.16.1 +pyinotify==0.9.6 +pyinstrument==5.1.2 +PyJWT==2.10.1 +pyOpenSSL==25.3.0 +pyparsing==3.3.2 +pyRFC3339==2.0.1 +pyserial==3.5 +pysnmp==7.1.21 +pystache==0.6.8 +pytest==9.0.2 +python-apt==3.1.0+ubuntu1 +python-dateutil==2.9.0 +python-debian==1.0.1+ubuntu2 +python-dotenv==1.2.2 +python-magic==0.4.27 +python-multipart==0.0.20 +pytz==2025.2 +PyYAML==6.0.3 +pyyaml_env_tag==1.1 +referencing==0.36.2 +regex==2025.9.18 +requests==2.32.5 +rich==13.9.4 +rpds-py==0.27.1 +s3transfer==0.14.0 +screen-resolution-extra==0.0.0 +service-identity==24.2.0 +setuptools==78.1.1 +shtab==1.8.0 +six==1.17.0 +sniffio==1.3.1 +sos==4.10.2 +soupsieve==2.8.3 +speedtest-cli==2.1.3 +ssh-import-id==5.11 +starlette==0.48.0 +sympy==1.14.0 +systemd-python==235 +tornado==6.5.4 +tqdm==4.67.3 +Twisted==25.5.0 +typeguard==4.4.4 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +ubuntu-drivers-common==0.0.0 +ubuntu-pro-client==8001 +uc-micro-py==1.0.3 +ufoLib2==0.18.1 +unattended-upgrades==0.1 +unicodedata2==16.0.0 +urllib3==2.6.3 +userpath==1.9.2 +uvicorn==0.38.0 +uvloop==0.22.1 +wadllib==2.0.0 +watchdog==6.0.0 +watchfiles==1.2.0 +webencodings==0.5.1 +websockets==16.0 +wheel==0.46.3 +wsproto==1.3.2 +xkit==0.0.0 +zipp==3.23.0 +zope.interface==8.2 +zopfli==0.4.1 diff --git a/routers/__init__.py b/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/routers/characters.py b/routers/characters.py new file mode 100644 index 0000000..65a37eb --- /dev/null +++ b/routers/characters.py @@ -0,0 +1,90 @@ +from fastapi import APIRouter, File, Form, HTTPException, UploadFile +from pydantic import BaseModel +from typing import Optional + +from services.character_card import list_characters, get_character, import_card_file, update_character, update_appearance_tags + +router = APIRouter(prefix="/characters", tags=["characters"]) + + +class CardPatch(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + personality: Optional[str] = None + scenario: Optional[str] = None + first_mes: Optional[str] = None + mes_example: Optional[str] = None + appearance_tags: Optional[str] = None + lora_name: Optional[str] = None + lora_weight: Optional[float] = None + + +@router.get("/") +async def list_all(): + return await list_characters() + + +@router.get("/{card_id}") +async def get_one(card_id: str): + card = await get_character(card_id) + if not card: + raise HTTPException(status_code=404, detail="Карточка не найдена") + return card + + +@router.patch("/{card_id}") +async def patch_card(card_id: str, body: CardPatch): + card = await get_character(card_id) + if not card: + raise HTTPException(status_code=404, detail="Карточка не найдена") + fields = {k: v for k, v in body.model_dump().items() if v is not None} + await update_character(card_id, fields) + # sync appearance_tags and lora to persona + from services.personas import update_persona_appearance + if "appearance_tags" in fields: + await update_persona_appearance(f"card_{card_id}", fields["appearance_tags"]) + if {"lora_name", "lora_weight"} & fields.keys(): + from services.personas import update_persona_lora + await update_persona_lora(f"card_{card_id}", fields.get("lora_name"), fields.get("lora_weight")) + # rebuild system prompt if character fields changed + char_fields = {"name", "description", "personality", "scenario", "first_mes", "mes_example"} + if char_fields & fields.keys(): + updated = await get_character(card_id) + from services.character_card import build_system_prompt + from services.personas import update_persona_prompt + await update_persona_prompt(f"card_{card_id}", build_system_prompt(updated)) + return await get_character(card_id) + + +@router.post("/import") +async def import_card( + file: UploadFile = File(...), + lora_name: str = Form(""), + lora_weight: float = Form(0.8), +): + content = await file.read() + try: + card = await import_card_file( + content, + file.filename or "card.json", + lora_name=lora_name, + lora_weight=lora_weight, + ) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + return { + "status": "imported", + "card_id": card["card_id"], + "persona_id": f"card_{card['card_id']}", + "name": card["name"], + } + + +@router.delete("/{card_id}") +async def remove_card(card_id: str): + from services.personas import delete_persona + + if not await delete_persona(f"card_{card_id}"): + raise HTTPException(status_code=404, detail="Карточка не найдена") + return {"status": "deleted", "card_id": card_id} + diff --git a/routers/chat.py b/routers/chat.py new file mode 100644 index 0000000..3a28984 --- /dev/null +++ b/routers/chat.py @@ -0,0 +1,199 @@ +import json +import os + +import aiosqlite +from fastapi import APIRouter +from fastapi.responses import StreamingResponse + +from database.db import DB_PATH +from models.schemas import ChatRequest, ChatResponse +from services.llm import send_message, stream_message +from services.memory import ( + get_history, + add_message, + clear_history, + get_or_create_session, + update_session_title, + get_message_count, + get_last_assistant_message_id, + update_message_image, +) +from services.personas import get_persona +from services.sd_prompt import ( + generate_sd_prompt, + strip_image_prompt_tag, + extract_image_prompt_tag, +) +from services.lorebook import get_lorebook_context +from services.character_card import get_character +from services import sdbackend as sd_service + +router = APIRouter(prefix="/chat", tags=["chat"]) + +DEFAULT_PROMPT = "Ты — полезный AI ассистент. Отвечай чётко и по делу." +SD_AUTO_GENERATE = os.getenv("SD_AUTO_GENERATE", "false").lower() in ("1", "true", "yes") + + +async def get_system_prompt(persona_id: str, history: list, user_message: str = "") -> str: + persona = await get_persona(persona_id) + if not persona: + return DEFAULT_PROMPT + + prompt = persona["prompt"] + + if persona_id.startswith("card_"): + card_id = persona_id[5:] + card = await get_character(card_id) + if card: + # match lorebook against recent context + current message + recent = [m for m in history if m["role"] in ("user", "assistant")][-5:] + context = recent + [{"role": "user", "content": user_message}] + lore = get_lorebook_context(card.get("lorebook_json", "[]"), context) + if lore: + prompt = prompt + "\n\n" + lore + + return prompt + + +@router.get("/history/{session_id}") +async def get_chat_history(session_id: str): + return await get_history(session_id) + + +@router.post("/init") +async def init_chat(request: ChatRequest): + """Called when opening a new chat. Seeds system prompt and first_mes if card persona.""" + persona_id = request.persona_id or "default" + await get_or_create_session(request.session_id, persona_id) + history = await get_history(request.session_id) + if history: + return {"first_mes": None} # already initialized + + system_prompt = await get_system_prompt(persona_id, [], "") + await add_message(request.session_id, "system", system_prompt) + + first_mes = None + if persona_id.startswith("card_"): + card = await get_character(persona_id[5:]) + if card and card.get("first_mes"): + first_mes = card["first_mes"] + await add_message(request.session_id, "assistant", first_mes) + + return {"first_mes": first_mes} + + +@router.post("/stream") +async def chat_stream(request: ChatRequest): + persona_id = request.persona_id or "default" + + await get_or_create_session(request.session_id, persona_id) + + history = await get_history(request.session_id) + system_prompt = await get_system_prompt(persona_id, history, request.message) + + if not history: + await add_message(request.session_id, "system", system_prompt) + elif history[0]["role"] == "system" and history[0]["content"] != system_prompt: + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + """UPDATE messages SET content = ? + WHERE session_id = ? AND role = 'system' + AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""", + (system_prompt, request.session_id, request.session_id), + ) + await db.commit() + + await add_message(request.session_id, "user", request.message) + messages = await get_history(request.session_id) + + full_reply = [] + + async def generate(): + async for chunk in stream_message( + [{"role": m["role"], "content": m["content"]} for m in messages] + ): + full_reply.append(chunk) + yield f"data: {json.dumps({'chunk': chunk})}\n\n" + + complete = "".join(full_reply) + display_text = strip_image_prompt_tag(complete) + + hist_with_reply = await get_history(request.session_id) + [ + {"role": "assistant", "content": display_text} + ] + sd_result = await generate_sd_prompt(hist_with_reply, persona_id) + prompt_str = sd_result[0] if sd_result else None + if not prompt_str: + prompt_str = extract_image_prompt_tag(complete) + + await add_message( + request.session_id, + "assistant", + display_text or complete, + image_prompt=prompt_str, + ) + + count = await get_message_count(request.session_id) + if count == 2: + title = request.message[:40] + ("…" if len(request.message) > 40 else "") + await update_session_title(request.session_id, title) + + image_path = None + image_error = None + if prompt_str and SD_AUTO_GENERATE: + rel, err = await sd_service.generate_from_full_prompt(prompt_str) + if rel: + image_path = rel + msg_id = await get_last_assistant_message_id(request.session_id) + if msg_id: + await update_message_image(msg_id, rel) + else: + image_error = err + + yield f"data: {json.dumps({ + 'done': True, + 'image_prompt': prompt_str, + 'image_path': f'/static/{image_path}' if image_path else None, + 'image_error': image_error, + })}\n\n" + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@router.post("/", response_model=ChatResponse) +async def chat(request: ChatRequest): + persona_id = request.persona_id or "default" + await get_or_create_session(request.session_id, persona_id) + + history = await get_history(request.session_id) + system_prompt = await get_system_prompt(persona_id, history, request.message) + + if not history: + await add_message(request.session_id, "system", system_prompt) + + await add_message(request.session_id, "user", request.message) + messages = await get_history(request.session_id) + reply = await send_message( + [{"role": m["role"], "content": m["content"]} for m in messages] + ) + display = strip_image_prompt_tag(reply) + prompt_tuple = await generate_sd_prompt(messages, persona_id) + prompt_str = prompt_tuple[0] if prompt_tuple else extract_image_prompt_tag(reply) + + await add_message(request.session_id, "assistant", display, image_prompt=prompt_str) + + return ChatResponse( + reply=display, + session_id=request.session_id, + image_prompt=prompt_str, + ) + + +@router.delete("/{session_id}") +async def clear_chat(session_id: str): + await clear_history(session_id) + return {"status": "cleared", "session_id": session_id} diff --git a/routers/images.py b/routers/images.py new file mode 100644 index 0000000..dde08f2 --- /dev/null +++ b/routers/images.py @@ -0,0 +1,34 @@ +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from services import sdbackend as sd_service +from services.memory import get_last_assistant_message_id, update_message_image + +router = APIRouter(prefix="/images", tags=["images"]) + + +class GenerateRequest(BaseModel): + session_id: str + prompt: str + + +@router.get("/health") +async def sd_health(): + ok = await sd_service.check_sd() + return {"sd_available": ok, "url": sd_service.SD_BASE_URL} + + +@router.post("/generate") +async def generate_image(req: GenerateRequest): + if not req.prompt.strip(): + raise HTTPException(status_code=400, detail="Пустой промпт") + + rel, err = await sd_service.generate_from_full_prompt(req.prompt) + if not rel: + raise HTTPException(status_code=502, detail=err or "SD backend недоступен") + + msg_id = await get_last_assistant_message_id(req.session_id) + if msg_id: + await update_message_image(msg_id, rel) + + return {"image_path": f"/static/{rel}", "status": "ok"} diff --git a/routers/personas.py b/routers/personas.py new file mode 100644 index 0000000..4fccb94 --- /dev/null +++ b/routers/personas.py @@ -0,0 +1,42 @@ +from fastapi import APIRouter, HTTPException +from models.schemas import PersonaCreate +from services.personas import get_all_personas, get_persona, create_persona, delete_persona + +router = APIRouter(prefix="/personas", tags=["personas"]) + + +@router.get("/") +async def list_personas(): + personas = await get_all_personas() + return [{"persona_id": pid, **data} for pid, data in personas.items()] + + +@router.get("/{persona_id}") +async def get_one_persona(persona_id: str): + persona = await get_persona(persona_id) + if not persona: + raise HTTPException(status_code=404, detail="Персонаж не найден") + return {"persona_id": persona_id, **persona} + + +@router.post("/") +async def create_new_persona(data: PersonaCreate): + persona = await create_persona( + persona_id=data.persona_id, + name=data.name, + emoji=data.emoji, + description=data.description, + prompt=data.prompt, + sd_enabled=data.sd_enabled, + lora_name=data.lora_name, + lora_weight=data.lora_weight, + appearance_tags=data.appearance_tags, + ) + return {"persona_id": data.persona_id, **persona} + + +@router.delete("/{persona_id}") +async def remove_persona(persona_id: str): + if not await delete_persona(persona_id): + raise HTTPException(status_code=400, detail="Нельзя удалить встроенного персонажа") + return {"status": "deleted", "persona_id": persona_id} diff --git a/routers/sessions.py b/routers/sessions.py new file mode 100644 index 0000000..15b47b5 --- /dev/null +++ b/routers/sessions.py @@ -0,0 +1,48 @@ +from fastapi import APIRouter, HTTPException +from services.memory import ( + get_all_sessions, + get_or_create_session, + delete_session, + update_session_title, + update_session_persona, + get_history, + get_message_count +) + +router = APIRouter(prefix="/sessions", tags=["sessions"]) + + +@router.get("/") +async def list_sessions(): + sessions = await get_all_sessions() + result = [] + for s in sessions: + count = await get_message_count(s["session_id"]) + result.append({**s, "message_count": count}) + return result + + +@router.get("/{session_id}") +async def get_session(session_id: str): + sessions = await get_all_sessions() + s = next((x for x in sessions if x["session_id"] == session_id), None) + if not s: + raise HTTPException(status_code=404, detail="Сессия не найдена") + return s + + +@router.patch("/{session_id}") +async def patch_session(session_id: str, data: dict): + # ensure session exists before patching + await get_or_create_session(session_id, data.get("persona_id", "default")) + if "title" in data: + await update_session_title(session_id, data["title"]) + if "persona_id" in data: + await update_session_persona(session_id, data["persona_id"]) + return {"status": "updated"} + + +@router.delete("/{session_id}") +async def remove_session(session_id: str): + await delete_session(session_id) + return {"status": "deleted", "session_id": session_id} diff --git a/routers/translate.py b/routers/translate.py new file mode 100644 index 0000000..84bbfd4 --- /dev/null +++ b/routers/translate.py @@ -0,0 +1,18 @@ +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +from services.translate import translate_to_russian + +router = APIRouter(prefix="/translate", tags=["translate"]) + + +class TranslateRequest(BaseModel): + text: str + + +@router.post("/") +async def translate(req: TranslateRequest): + try: + result = await translate_to_russian(req.text) + return {"translated": result} + except Exception as e: + raise HTTPException(status_code=502, detail=str(e)) diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/character_card.py b/services/character_card.py new file mode 100644 index 0000000..5144aa0 --- /dev/null +++ b/services/character_card.py @@ -0,0 +1,214 @@ +import json +import base64 +import uuid + +import aiosqlite +from database.db import DB_PATH + + +def parse_card_v2(data: dict) -> dict: + inner = data.get("data", data) + if isinstance(inner, str): + inner = json.loads(inner) + + book = inner.get("character_book") or {} + entries = book.get("entries", []) + if isinstance(entries, dict): + entries = list(entries.values()) + + return { + "card_id": ( + inner.get("name", "imported").lower().replace(" ", "_")[:48] + + "_" + + uuid.uuid4().hex[:8] + ), + "name": inner.get("name", "Character"), + "description": inner.get("description", ""), + "personality": inner.get("personality", ""), + "scenario": inner.get("scenario", ""), + "first_mes": inner.get("first_mes", ""), + "mes_example": inner.get("mes_example", ""), + "appearance_tags": _extract_appearance(inner), + "lorebook_json": json.dumps(entries, ensure_ascii=False), + "raw_json": json.dumps(data if "data" in data else {"data": inner}, ensure_ascii=False), + } + + +def _extract_appearance(inner: dict) -> str: + """Extract booru-style appearance tags from character fields.""" + import re + # fall back: scan description for visual keywords, skip world-building sentences + desc = inner.get("description", "") + appearance_keywords = re.findall( + r'\b(?:' + r'\w*hair|hair\w*|\w*eyes|eye\w*|\w*skin|skin\w*' + r'|tall|short|slim|curvy|muscular|petite' + r'|ears?|tail|horns?|wings?|cloak|dress|outfit|uniform|armor' + r'|wolf\w*|cat\w*|fox\w*|elf\w*|demon\w*|angel\w*' + r'|silver|blonde|black|white|red|blue|green|purple|pink|brown|golden' + r')\b', + desc, re.IGNORECASE + ) + seen = [] + for kw in appearance_keywords: + kw_lower = kw.lower() + if kw_lower not in seen: + seen.append(kw_lower) + return ", ".join(seen[:20]) + + +def parse_png_card(file_bytes: bytes) -> dict | None: + if not file_bytes.startswith(b"\x89PNG"): + return None + idx = 8 # skip PNG file signature + while idx < len(file_bytes) - 12: + length = int.from_bytes(file_bytes[idx : idx + 4], "big") + chunk_type = file_bytes[idx + 4 : idx + 8] + chunk_data = file_bytes[idx + 8 : idx + 8 + length] + if chunk_type == b"tEXt": + try: + key, _, val = chunk_data.partition(b"\x00") + if key in (b"chara", b"ccv3"): + decoded = base64.b64decode(val).decode("utf-8") + return parse_card_v2(json.loads(decoded)) + except Exception: + pass + elif chunk_type == b"iTXt": + try: + # iTXt: keyword \x00 compression_flag \x00 compression_method \x00 language \x00 translated_keyword \x00 text + key, _, rest = chunk_data.partition(b"\x00") + if key in (b"chara", b"ccv3"): + # skip compression_flag, compression_method, language tag, translated keyword + text = rest[2:].split(b"\x00", 2)[-1].decode("utf-8") + # text may be base64 or raw JSON + try: + return parse_card_v2(json.loads(base64.b64decode(text).decode("utf-8"))) + except Exception: + return parse_card_v2(json.loads(text)) + except Exception: + pass + idx += 12 + length + return None + + +def build_system_prompt(card: dict) -> str: + parts = [ + f"You are {card['name']}. Stay in character.", + f"Description: {card['description']}", + f"Personality: {card['personality']}", + f"Scenario: {card['scenario']}", + ] + if card.get("mes_example"): + parts.append(f"Example dialogue:\n{card['mes_example']}") + parts.append("Reply only as the character. Do not add image tags.") + return "\n\n".join(p for p in parts if p.split(": ", 1)[-1].strip()) + + +async def save_character(card: dict, lora_name: str = "", lora_weight: float = 0.8) -> dict: + card_id = card["card_id"] + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + """INSERT OR REPLACE INTO characters + (card_id, name, description, personality, scenario, first_mes, + mes_example, raw_json, lora_name, lora_weight, appearance_tags, lorebook_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + card_id, + card["name"], + card["description"], + card["personality"], + card["scenario"], + card["first_mes"], + card["mes_example"], + card["raw_json"], + lora_name, + lora_weight, + card.get("appearance_tags", ""), + card["lorebook_json"], + ), + ) + await db.commit() + return {**card, "lora_name": lora_name, "lora_weight": lora_weight} + + +async def get_character(card_id: str) -> dict | None: + async with aiosqlite.connect(DB_PATH) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT * FROM characters WHERE card_id = ?", (card_id,) + ) as cur: + row = await cur.fetchone() + return dict(row) if row else None + + +async def list_characters() -> list: + async with aiosqlite.connect(DB_PATH) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT card_id, name, description, lora_name FROM characters ORDER BY name" + ) as cur: + rows = await cur.fetchall() + return [dict(r) for r in rows] + + +async def delete_character(card_id: str) -> bool: + async with aiosqlite.connect(DB_PATH) as db: + cur = await db.execute( + "DELETE FROM characters WHERE card_id = ?", (card_id,) + ) + await db.commit() + return cur.rowcount > 0 + + +async def update_appearance_tags(card_id: str, appearance_tags: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "UPDATE characters SET appearance_tags = ? WHERE card_id = ?", + (appearance_tags, card_id), + ) + await db.commit() + + +async def update_character(card_id: str, fields: dict) -> bool: + allowed = {"name", "description", "personality", "scenario", "first_mes", + "mes_example", "appearance_tags", "lora_name", "lora_weight"} + updates = {k: v for k, v in fields.items() if k in allowed} + if not updates: + return False + cols = ", ".join(f"{k} = ?" for k in updates) + async with aiosqlite.connect(DB_PATH) as db: + cur = await db.execute( + f"UPDATE characters SET {cols} WHERE card_id = ?", + (*updates.values(), card_id), + ) + await db.commit() + return cur.rowcount > 0 + + +async def import_card_file(content: bytes, filename: str, lora_name: str = "", lora_weight: float = 0.8) -> dict: + if filename.lower().endswith(".png"): + card = parse_png_card(content) + if not card: + raise ValueError("PNG does not contain character card metadata") + else: + card = parse_card_v2(json.loads(content.decode("utf-8"))) + + saved = await save_character(card, lora_name=lora_name, lora_weight=lora_weight) + + persona_id = f"card_{saved['card_id']}" + from services.personas import create_persona, get_persona + + existing = await get_persona(persona_id) + if not existing: + await create_persona( + persona_id=persona_id, + name=saved["name"], + emoji="🎭", + description=saved["description"][:80] or "Character card", + prompt=build_system_prompt(saved), + sd_enabled=True, + lora_name=lora_name, + lora_weight=lora_weight, + appearance_tags=saved.get("appearance_tags", ""), + ) + return saved diff --git a/services/llm.py b/services/llm.py new file mode 100644 index 0000000..17dfa8b --- /dev/null +++ b/services/llm.py @@ -0,0 +1,63 @@ +import httpx +import os +from dotenv import load_dotenv + +load_dotenv() + +OPENROUTER_KEY = os.getenv("ROUTER_KEY") +OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions" +MODEL = "google/gemini-2.5-flash" + +HEADERS = { + "Authorization": f"Bearer {OPENROUTER_KEY}", + "Content-Type": "application/json", + "HTTP-Referer": "http://localhost:8000", +} + +async def send_message(messages: list) -> str: + """Обычный запрос — используем для внутренних нужд""" + payload = { + "model": MODEL, + "messages": messages, + } + async with httpx.AsyncClient(timeout=60) as client: + response = await client.post( + OPENROUTER_URL, + headers=HEADERS, + json=payload + ) + response.raise_for_status() + data = response.json() + return data["choices"][0]["message"]["content"] + + +async def stream_message(messages: list): + """Стриминг — отдаём чанки по мере получения""" + payload = { + "model": MODEL, + "messages": messages, + "stream": True, + } + async with httpx.AsyncClient(timeout=60) as client: + async with client.stream( + "POST", + OPENROUTER_URL, + headers=HEADERS, + json=payload + ) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if not line.startswith("data: "): + continue + data = line[6:] # убираем "data: " + if data == "[DONE]": + break + try: + import json + chunk = json.loads(data) + delta = chunk["choices"][0]["delta"] + content = delta.get("content", "") + if content: + yield content + except Exception: + continue diff --git a/services/lorebook.py b/services/lorebook.py new file mode 100644 index 0000000..4d461b6 --- /dev/null +++ b/services/lorebook.py @@ -0,0 +1,52 @@ +import json + + +def _match_entry(entry: dict, text: str) -> bool: + keys = entry.get("keys", []) + if isinstance(keys, str): + keys = [k.strip() for k in keys.split(",") if k.strip()] + text_lower = text.lower() + for key in keys: + if key and key.lower() in text_lower: + return True + secondary = entry.get("secondary_keys", []) or entry.get("keysecondary", []) + if isinstance(secondary, str): + secondary = [k.strip() for k in secondary.split(",") if k.strip()] + for key in secondary: + if key and key.lower() in text_lower: + return True + return False + + +def get_lorebook_context(lorebook_json: str, context: str | list, max_entries: int = 5) -> str: + """Match lorebook entries against context. + context can be a string or a list of message dicts (role/content). + """ + try: + entries = json.loads(lorebook_json or "[]") + except json.JSONDecodeError: + return "" + + if isinstance(entries, dict): + entries = list(entries.values()) + + if isinstance(context, list): + text = " ".join(m.get("content", "") for m in context if m.get("role") in ("user", "assistant")) + else: + text = context + + matched = [] + for entry in entries: + if not entry.get("enabled", True): + continue + if _match_entry(entry, text): + content = entry.get("content", "").strip() + if content: + name = entry.get("name", entry.get("comment", "Lore")) + matched.append(f"[{name}]\n{content}") + + if not matched: + return "" + + block = "\n\n".join(matched[:max_entries]) + return f"--- Lorebook (relevant world info) ---\n{block}\n---" diff --git a/services/memory.py b/services/memory.py new file mode 100644 index 0000000..4668dfe --- /dev/null +++ b/services/memory.py @@ -0,0 +1,142 @@ +import aiosqlite +from database.db import DB_PATH + + +async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict: + async with aiosqlite.connect(DB_PATH) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT * FROM sessions WHERE session_id = ?", (session_id,) + ) as cursor: + row = await cursor.fetchone() + + if row: + return dict(row) + + await db.execute( + "INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)", + (session_id, persona_id), + ) + await db.commit() + + async with db.execute( + "SELECT * FROM sessions WHERE session_id = ?", (session_id,) + ) as cursor: + row = await cursor.fetchone() + return dict(row) + + +async def get_all_sessions() -> list: + async with aiosqlite.connect(DB_PATH) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT * FROM sessions ORDER BY updated_at DESC" + ) as cursor: + rows = await cursor.fetchall() + return [dict(r) for r in rows] + + +async def update_session_title(session_id: str, title: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "UPDATE sessions SET title = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?", + (title, session_id), + ) + await db.commit() + + +async def update_session_persona(session_id: str, persona_id: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "UPDATE sessions SET persona_id = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?", + (persona_id, session_id), + ) + await db.commit() + + +async def delete_session(session_id: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) + await db.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,)) + await db.commit() + + +async def get_history(session_id: str) -> list: + async with aiosqlite.connect(DB_PATH) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + """SELECT role, content, image_prompt, image_path + FROM messages WHERE session_id = ? ORDER BY id""", + (session_id,), + ) as cursor: + rows = await cursor.fetchall() + return [ + { + "role": r["role"], + "content": r["content"], + "image_prompt": r["image_prompt"], + "image_path": r["image_path"], + } + for r in rows + ] + + +async def add_message( + session_id: str, + role: str, + content: str, + image_prompt: str | None = None, + image_path: str | None = None, +): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + """INSERT INTO messages (session_id, role, content, image_prompt, image_path) + VALUES (?, ?, ?, ?, ?)""", + (session_id, role, content, image_prompt, image_path), + ) + await db.execute( + "UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?", + (session_id,), + ) + await db.commit() + + +async def update_message_image(message_id: int, image_path: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "UPDATE messages SET image_path = ? WHERE id = ?", + (image_path, message_id), + ) + await db.commit() + + +async def get_last_assistant_message_id(session_id: str) -> int | None: + async with aiosqlite.connect(DB_PATH) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + """SELECT id FROM messages + WHERE session_id = ? AND role = 'assistant' + ORDER BY id DESC LIMIT 1""", + (session_id,), + ) as cursor: + row = await cursor.fetchone() + return row["id"] if row else None + + +async def clear_history(session_id: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "DELETE FROM messages WHERE session_id = ?", (session_id,) + ) + await db.commit() + + +async def get_message_count(session_id: str) -> int: + async with aiosqlite.connect(DB_PATH) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT COUNT(*) as cnt FROM messages WHERE session_id = ? AND role != 'system'", + (session_id,), + ) as cursor: + row = await cursor.fetchone() + return row["cnt"] diff --git a/services/persona_seed.py b/services/persona_seed.py new file mode 100644 index 0000000..975d897 --- /dev/null +++ b/services/persona_seed.py @@ -0,0 +1,26 @@ +import aiosqlite +from database.db import DB_PATH +from services.personas import DEFAULT_PERSONAS + + +async def seed_default_personas(): + async with aiosqlite.connect(DB_PATH) as db: + for pid, data in DEFAULT_PERSONAS.items(): + await db.execute( + """INSERT OR IGNORE INTO personas + (persona_id, name, emoji, description, prompt, custom, sd_enabled, + lora_name, lora_weight, appearance_tags) + VALUES (?, ?, ?, ?, ?, 0, ?, ?, ?, ?)""", + ( + pid, + data["name"], + data["emoji"], + data["description"], + data["prompt"], + 1 if data.get("sd_enabled") else 0, + data.get("lora_name", ""), + data.get("lora_weight", 0.8), + data.get("appearance_tags", ""), + ), + ) + await db.commit() diff --git a/services/personas.py b/services/personas.py new file mode 100644 index 0000000..af92817 --- /dev/null +++ b/services/personas.py @@ -0,0 +1,168 @@ +from typing import Optional +import aiosqlite +from database.db import DB_PATH + +DEFAULT_PERSONAS = { + "default": { + "name": "AI Ассистент", + "emoji": "🤖", + "description": "Универсальный помощник", + "prompt": "Ты — полезный AI ассистент. Отвечай чётко и по делу.", + "sd_enabled": False, + }, + "rpg_master": { + "name": "Мастер RPG", + "emoji": "🧙", + "description": "Ведёт ролевые игры, создаёт атмосферу", + "prompt": """Ты — опытный Мастер ролевых игр. +Создавай живые описания, веди нарратив, реагируй на действия игрока. +Мир детальный, персонажи запоминающиеся. +Отвечай только текстом сюжета — без тегов изображений.""", + "sd_enabled": True, + }, + "villain": { + "name": "Злодей", + "emoji": "😈", + "description": "Харизматичный антагонист", + "prompt": """Ты — харизматичный злодей с грандиозными планами. +Говоришь театрально, с сарказмом и превосходством. +Никогда не выходишь из роли. Называешь собеседника 'герой' с иронией.""", + "sd_enabled": False, + }, + "scientist": { + "name": "Учёный", + "emoji": "🔬", + "description": "Объясняет сложное простыми словами", + "prompt": """Ты — увлечённый учёный. Объясняешь любые темы +через факты, аналогии и примеры. Любишь уточнять детали. +Иногда уходишь в интересные отступления.""", + "sd_enabled": False, + }, + "samurai": { + "name": "Самурай", + "emoji": "⚔️", + "description": "Мудрый воин феодальной Японии", + "prompt": """Ты — самурай феодальной Японии. +Говоришь кратко, мудро, с достоинством. +Используешь метафоры природы и войны. +Чтишь кодекс бусидо.""", + "sd_enabled": True, + "appearance_tags": "samurai armor, katana, feudal japan", + }, +} + + +def _row_to_persona(row: dict) -> dict: + return { + "name": row["name"], + "emoji": row["emoji"], + "description": row["description"], + "prompt": row["prompt"], + "custom": bool(row["custom"]), + "sd_enabled": bool(row["sd_enabled"]), + "lora_name": row["lora_name"] or "", + "lora_weight": row["lora_weight"] if row["lora_weight"] is not None else 0.8, + "appearance_tags": row["appearance_tags"] or "", + } + + +async def get_all_personas() -> dict: + async with aiosqlite.connect(DB_PATH) as db: + db.row_factory = aiosqlite.Row + async with db.execute("SELECT * FROM personas ORDER BY custom ASC, persona_id ASC") as cur: + rows = await cur.fetchall() + return {r["persona_id"]: _row_to_persona(dict(r)) for r in rows} + + +async def get_persona(persona_id: str) -> Optional[dict]: + async with aiosqlite.connect(DB_PATH) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT * FROM personas WHERE persona_id = ?", (persona_id,) + ) as cur: + row = await cur.fetchone() + if not row: + return None + return _row_to_persona(dict(row)) + + +async def create_persona( + persona_id: str, + name: str, + emoji: str, + description: str, + prompt: str, + sd_enabled: bool = False, + lora_name: str = "", + lora_weight: float = 0.8, + appearance_tags: str = "", +) -> dict: + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + """INSERT INTO personas + (persona_id, name, emoji, description, prompt, custom, + sd_enabled, lora_name, lora_weight, appearance_tags) + VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?, ?)""", + ( + persona_id, name, emoji, description, prompt, + 1 if sd_enabled else 0, lora_name, lora_weight, appearance_tags, + ), + ) + await db.commit() + return { + "name": name, + "emoji": emoji, + "description": description, + "prompt": prompt, + "custom": True, + "sd_enabled": sd_enabled, + "lora_name": lora_name, + "lora_weight": lora_weight, + "appearance_tags": appearance_tags, + } + + +async def delete_persona(persona_id: str) -> bool: + async with aiosqlite.connect(DB_PATH) as db: + async with db.execute( + "SELECT custom FROM personas WHERE persona_id = ?", (persona_id,) + ) as cur: + row = await cur.fetchone() + if not row or not row[0]: + return False + await db.execute("DELETE FROM personas WHERE persona_id = ?", (persona_id,)) + await db.commit() + + if persona_id.startswith("card_"): + from services.character_card import delete_character + await delete_character(persona_id[5:]) + + return True + + +async def update_persona_appearance(persona_id: str, appearance_tags: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute( + "UPDATE personas SET appearance_tags = ? WHERE persona_id = ?", + (appearance_tags, persona_id), + ) + await db.commit() + + +async def update_persona_lora(persona_id: str, lora_name: str | None, lora_weight: float | None): + fields, vals = [], [] + if lora_name is not None: + fields.append("lora_name = ?"); vals.append(lora_name) + if lora_weight is not None: + fields.append("lora_weight = ?"); vals.append(lora_weight) + if not fields: + return + async with aiosqlite.connect(DB_PATH) as db: + await db.execute(f"UPDATE personas SET {', '.join(fields)} WHERE persona_id = ?", (*vals, persona_id)) + await db.commit() + + +async def update_persona_prompt(persona_id: str, prompt: str): + async with aiosqlite.connect(DB_PATH) as db: + await db.execute("UPDATE personas SET prompt = ? WHERE persona_id = ?", (prompt, persona_id)) + await db.commit() diff --git a/services/sd_prompt.py b/services/sd_prompt.py new file mode 100644 index 0000000..398596e --- /dev/null +++ b/services/sd_prompt.py @@ -0,0 +1,125 @@ +import json +import os +import re +from services.llm import send_message +from services.personas import get_persona + +PROMPT_BUILDER_SYSTEM = """You are a Stable Diffusion prompt engineer for anime illustration models. +Given a roleplay chat excerpt and character appearance hints, output ONLY valid JSON (no markdown): +{ + "should_generate": true, + "shot_type": "first_person_pov" | "landscape" | "third_person", + "appearance_tags": "booru-style tags for character appearance extracted from hints, e.g. 'white hair, wolf ears, wolf tail, yellow eyes'", + "action_tags": "booru-style tags for pose/action, e.g. 'sitting, smiling, looking at viewer'", + "environment_tags": "booru-style tags for location/lighting, e.g. 'indoors, kitchen, sunlight'" +} +Rules: +- ONLY use real danbooru/e621 tags. Multi-word concepts MUST be written as single tags: 'white hair' not 'white, hair'. 'wolf ears' not 'wolf, ears'. +- Do NOT include quality tags, model names, style words, 'pov', or category/metadata words. +- Do NOT invent tags. If unsure — omit. +- Keep each field to 3-6 tags.""" + + +def extract_image_prompt_tag(text: str) -> str | None: + if "[IMAGE_PROMPT:" not in text: + return None + try: + start = text.index("[IMAGE_PROMPT:") + len("[IMAGE_PROMPT:") + end = text.index("]", start) + return text[start:end].strip() + except ValueError: + return None + + +def strip_image_prompt_tag(text: str) -> str: + return re.sub(r"\[IMAGE_PROMPT:.*?\]", "", text, flags=re.DOTALL).strip() + + +PONY_CHECKPOINTS = {"ponyDiffusionV6XL_v6StartWithThisOne.safetensors"} +SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "") +PONY_NEGATIVE = "score_1, score_2, score_3, score_4, worst quality, low quality, blurry, bad anatomy, watermark, text, censored" + +def build_positive_prompt(scene: dict, persona: dict | None) -> str: + is_pony = SD_CHECKPOINT in PONY_CHECKPOINTS + quality = "score_9, score_8_up, score_7_up, source_anime, highres" if is_pony else "masterpiece, best quality, highres" + parts = [quality] + + # prefer LLM-extracted appearance over raw persona tags + appearance = scene.get("appearance_tags") or (persona or {}).get("appearance_tags", "") + if appearance: + parts.append(appearance) + + if scene.get("shot_type") == "landscape": + parts.append(scene.get("environment_tags", "")) + else: + if scene.get("shot_type") == "first_person_pov": + parts.append("pov, first-person view, looking at viewer") + parts.append(scene.get("action_tags", "")) + parts.append(scene.get("environment_tags", "")) + + lora = (persona or {}).get("lora_name", "") + weight = (persona or {}).get("lora_weight", 0.8) + if lora: + parts.append(f"") + + positive = ", ".join(p.strip() for p in parts if p and p.strip()) + seen, deduped = set(), [] + for tag in positive.split(", "): + t = tag.strip() + if t and t not in seen: + seen.add(t) + deduped.append(t) + return ", ".join(deduped) + + +async def generate_sd_prompt( + messages: list, + persona_id: str, +) -> tuple[str | None, str | None]: + persona = await get_persona(persona_id) + if not persona or not persona.get("sd_enabled"): + return None, None + + recent = [m for m in messages if m["role"] in ("user", "assistant")][-6:] + if not recent: + return None, None + + excerpt = "\n".join(f"{m['role']}: {strip_image_prompt_tag(m['content'])}" for m in recent) + + appearance = persona.get("appearance_tags", "") + # For card personas, also include description for better visual context + if persona_id.startswith("card_"): + from services.character_card import get_character + card = await get_character(persona_id[5:]) + if card and card.get("description"): + appearance = f"{appearance}\nCharacter description: {card['description'][:400]}" + + builder_messages = [ + {"role": "system", "content": PROMPT_BUILDER_SYSTEM}, + { + "role": "user", + "content": f"Persona appearance hints: {appearance}\n\nChat:\n{excerpt}", + }, + ] + + try: + raw = await send_message(builder_messages) + raw = raw.strip() + if raw.startswith("```"): + raw = re.sub(r"^```\w*\n?", "", raw) + raw = re.sub(r"\n?```$", "", raw) + scene = json.loads(raw) + except (json.JSONDecodeError, Exception): + return None, None + + + positive = build_positive_prompt(scene, persona) + is_pony = SD_CHECKPOINT in PONY_CHECKPOINTS + negative = PONY_NEGATIVE if is_pony else "low quality, blurry, bad anatomy, watermark, text" + if scene.get("shot_type") == "first_person_pov": + negative += ", third person, over the shoulder" + + full = positive + if negative: + full += f"\n\nNegative prompt: {negative}" + return full, negative diff --git a/services/sdbackend.py b/services/sdbackend.py new file mode 100644 index 0000000..fbd0691 --- /dev/null +++ b/services/sdbackend.py @@ -0,0 +1,121 @@ +import asyncio +import logging +import os +import uuid +from pathlib import Path + +import httpx +from dotenv import load_dotenv + +load_dotenv() + +logger = logging.getLogger(__name__) + +SD_BASE_URL = os.getenv("SD_BASE_URL", "http://127.0.0.1:8188").rstrip("/") +SD_STEPS = int(os.getenv("SD_STEPS", "28")) +SD_CFG = float(os.getenv("SD_CFG", "7")) +SD_SAMPLER = os.getenv("SD_SAMPLER", "euler") +SD_SCHEDULER = os.getenv("SD_SCHEDULER", "normal") +SD_CHECKPOINT = os.getenv("SD_CHECKPOINT", "NetaYumev35_pretrained_all_in_one.safetensors") +SD_DEFAULT_NEGATIVE = os.getenv( + "SD_DEFAULT_NEGATIVE", + "low quality, worst quality, blurry, bad anatomy, watermark, text", +) +IMAGES_DIR = Path(os.getenv("IMAGES_DIR", "static/images")) + + +def split_prompt_and_negative(full_prompt: str) -> tuple[str, str]: + if "\n\nNegative prompt:" in full_prompt: + pos, _, neg = full_prompt.partition("\n\nNegative prompt:") + return pos.strip(), neg.strip() + return full_prompt.strip(), SD_DEFAULT_NEGATIVE + + +def _build_workflow(positive: str, negative: str) -> dict: + """Minimal KSampler workflow for ComfyUI API.""" + return { + "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": SD_CHECKPOINT}}, + "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 832, "height": 1216, "batch_size": 1}}, + "6": {"class_type": "CLIPTextEncode", "inputs": {"text": positive, "clip": ["4", 1]}}, + "7": {"class_type": "CLIPTextEncode", "inputs": {"text": negative, "clip": ["4", 1]}}, + "8": {"class_type": "VAEDecode", "inputs": {"samples": ["10", 0], "vae": ["4", 2]}}, + "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "chatbot", "images": ["8", 0]}}, + "10": { + "class_type": "KSampler", + "inputs": { + "model": ["4", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["5", 0], + "seed": int(uuid.uuid4().int % 2**32), + "steps": SD_STEPS, + "cfg": SD_CFG, + "sampler_name": SD_SAMPLER, + "scheduler": SD_SCHEDULER, + "denoise": 1.0, + }, + }, + } + + +async def check_sd() -> bool: + try: + async with httpx.AsyncClient(timeout=5) as client: + r = await client.get(f"{SD_BASE_URL}/system_stats") + return r.status_code == 200 + except Exception: + return False + + +async def txt2img(prompt: str, negative_prompt: str | None = None) -> tuple[bytes, str]: + neg = negative_prompt or SD_DEFAULT_NEGATIVE + workflow = _build_workflow(prompt, neg) + client_id = uuid.uuid4().hex + + logger.info("ComfyUI request → %s prompt: %.120s", SD_BASE_URL, prompt) + async with httpx.AsyncClient(timeout=300) as client: + # queue the prompt + resp = await client.post( + f"{SD_BASE_URL}/prompt", + json={"prompt": workflow, "client_id": client_id}, + ) + resp.raise_for_status() + prompt_id = resp.json()["prompt_id"] + logger.info("ComfyUI queued prompt_id=%s", prompt_id) + + # poll until done + for _ in range(300): + await asyncio.sleep(1) + hist = await client.get(f"{SD_BASE_URL}/history/{prompt_id}") + data = hist.json() + if prompt_id in data: + outputs = data[prompt_id]["outputs"] + # find first image output + for node_output in outputs.values(): + if "images" in node_output: + img_info = node_output["images"][0] + img_resp = await client.get( + f"{SD_BASE_URL}/view", + params={"filename": img_info["filename"], "subfolder": img_info.get("subfolder", ""), "type": img_info.get("type", "output")}, + ) + img_resp.raise_for_status() + image_bytes = img_resp.content + + IMAGES_DIR.mkdir(parents=True, exist_ok=True) + filename = f"{uuid.uuid4().hex}.png" + (IMAGES_DIR / filename).write_bytes(image_bytes) + logger.info("ComfyUI done → saved %s", filename) + return image_bytes, f"images/{filename}" + break + + raise RuntimeError("ComfyUI generation timed out or produced no output") + + +async def generate_from_full_prompt(full_prompt: str) -> tuple[str | None, str | None]: + positive, negative = split_prompt_and_negative(full_prompt) + try: + _, rel_path = await txt2img(positive, negative) + return rel_path, None + except Exception as e: + logger.error("ComfyUI error: %s", e) + return None, str(e) diff --git a/services/translate.py b/services/translate.py new file mode 100644 index 0000000..123d5fa --- /dev/null +++ b/services/translate.py @@ -0,0 +1,17 @@ +import os +import httpx +from dotenv import load_dotenv + +load_dotenv() + +LIBRETRANSLATE_URL = os.getenv("LIBRETRANSLATE_URL", "http://192.168.1.109:5100") + + +async def translate_to_russian(text: str) -> str: + async with httpx.AsyncClient(timeout=30) as client: + r = await client.post( + f"{LIBRETRANSLATE_URL}/translate", + json={"q": text, "source": "auto", "target": "ru", "format": "text"}, + ) + r.raise_for_status() + return r.json()["translatedText"] diff --git a/static/css/app.css b/static/css/app.css new file mode 100644 index 0000000..2fc4c95 --- /dev/null +++ b/static/css/app.css @@ -0,0 +1,306 @@ +* { margin: 0; padding: 0; box-sizing: border-box; } + +body { + background: #1a1a2e; + color: #e0e0e0; + font-family: 'Segoe UI', sans-serif; + height: 100vh; + display: flex; + flex-direction: column; + overflow: hidden; +} + +header { + width: 100%; + padding: 12px 20px; + background: #16213e; + border-bottom: 1px solid #0f3460; + display: flex; + align-items: center; + gap: 12px; + flex-shrink: 0; + z-index: 10; +} + +header h1 { font-size: 1.1rem; color: #e94560; } + +#sidebarToggle { + background: none; + border: 1px solid #0f3460; + border-radius: 8px; + color: #888; + padding: 4px 10px; + cursor: pointer; + font-size: 1rem; + transition: all 0.2s; +} +#sidebarToggle:hover { border-color: #e94560; color: #e94560; } + +.header-title { + flex: 1; + font-size: 0.9rem; + color: #ccc; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.app-body { display: flex; flex: 1; overflow: hidden; } + +.sidebar { + width: 260px; + background: #16213e; + border-right: 1px solid #0f3460; + display: flex; + flex-direction: column; + flex-shrink: 0; + transition: width 0.25s ease, opacity 0.25s ease; + overflow: hidden; +} +.sidebar.collapsed { width: 0; opacity: 0; pointer-events: none; } + +.sidebar-header { + padding: 12px 14px; + display: flex; + align-items: center; + justify-content: space-between; + border-bottom: 1px solid #0f3460; +} +.sidebar-header span { + font-size: 0.8rem; + color: #888; + text-transform: uppercase; + letter-spacing: 0.05em; +} + +#newChatBtn { + background: #e94560; + border: none; + border-radius: 8px; + color: white; + padding: 5px 12px; + font-size: 0.8rem; + cursor: pointer; +} +#newChatBtn:hover { background: #c73652; } + +.session-list { flex: 1; overflow-y: auto; padding: 8px 0; } +.session-item { + display: flex; + align-items: center; + gap: 8px; + padding: 9px 14px; + cursor: pointer; + border-left: 3px solid transparent; +} +.session-item:hover { background: #1a1a2e; } +.session-item.active { background: #1a1a2e; border-left-color: #e94560; } +.session-item .s-title { flex: 1; font-size: 0.82rem; color: #ccc; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; } +.session-item .s-meta { font-size: 0.7rem; color: #555; } +.session-item .s-del { background: none; border: none; color: #555; cursor: pointer; opacity: 0; } +.session-item:hover .s-del { opacity: 1; } +.session-item .s-del:hover { color: #e94560; } + +.main { flex: 1; display: flex; flex-direction: column; overflow: hidden; } + +.persona-bar { + display: flex; + gap: 8px; + padding: 8px 16px; + overflow-x: auto; + border-bottom: 1px solid #0f3460; +} + +.persona-card { + display: flex; + flex-direction: column; + align-items: center; + gap: 2px; + padding: 6px 12px; + background: #16213e; + border: 1px solid #0f3460; + border-radius: 10px; + cursor: pointer; + position: relative; + flex-shrink: 0; +} +.persona-card:hover { border-color: #e94560; } +.persona-card.active { border-color: #e94560; background: #1f1535; } +.persona-card .emoji { font-size: 1.2rem; } +.persona-card .pname { font-size: 0.7rem; color: #ccc; } +.persona-card .del-btn { + position: absolute; top: -5px; right: -5px; + width: 14px; height: 14px; + background: #e94560; border: none; border-radius: 50%; + color: white; font-size: 0.55rem; cursor: pointer; + display: none; +} +.persona-card .edit-btn { + position: absolute; top: -5px; left: -5px; + width: 16px; height: 16px; + background: #0f3460; border: 1px solid #e94560; border-radius: 50%; + color: white; font-size: 0.55rem; cursor: pointer; + display: none; align-items: center; justify-content: center; +} +.persona-card:hover .del-btn { display: flex; align-items: center; justify-content: center; } +.persona-card:hover .edit-btn { display: flex; } + +.persona-add, .card-import-btn { + display: flex; flex-direction: column; align-items: center; + padding: 6px 12px; + background: transparent; + border: 1px dashed #0f3460; border-radius: 10px; + cursor: pointer; color: #555; font-size: 0.7rem; + flex-shrink: 0; +} +.persona-add:hover, .card-import-btn:hover { border-color: #e94560; color: #e94560; } + +.messages { + flex: 1; + overflow-y: auto; + display: flex; + flex-direction: column; + gap: 12px; + padding: 16px; +} + +.message { display: flex; flex-direction: column; max-width: 75%; animation: fadeIn 0.2s ease; } +@keyframes fadeIn { from { opacity: 0; transform: translateY(6px); } to { opacity: 1; transform: translateY(0); } } +.message.user { align-self: flex-end; } +.message.assistant { align-self: flex-start; } + +.bubble { + padding: 10px 14px; + border-radius: 16px; + line-height: 1.5; + font-size: 0.95rem; + white-space: pre-wrap; + word-break: break-word; +} +.bubble.typing-active::after { content: '▋'; animation: blink 0.7s infinite; color: #e94560; } +@keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } } + +.message.user .bubble { background: #0f3460; border-bottom-right-radius: 4px; } +.message.assistant .bubble { background: #16213e; border: 1px solid #0f3460; border-bottom-left-radius: 4px; } + +.label { font-size: 0.7rem; color: #888; margin-bottom: 4px; padding: 0 4px; } +.message.user .label { text-align: right; } + +.image-prompt-block { + margin-top: 8px; padding: 8px 12px; + background: #1a1a2e; + border: 1px dashed #e94560; + border-radius: 8px; + font-size: 0.8rem; color: #e94560; +} +.image-prompt-header { display: flex; align-items: center; justify-content: space-between; gap: 8px; } +.image-prompt-block .prompt-text { display: block; color: #aaa; margin-top: 4px; font-style: italic; white-space: pre-wrap; } + +.copy-prompt-btn, .gen-image-btn { + background: #0f3460; + border: 1px solid #e94560; + border-radius: 6px; + color: #e94560; + font-size: 0.7rem; + padding: 2px 8px; + cursor: pointer; +} +.copy-prompt-btn:hover, .gen-image-btn:hover { background: #e94560; color: white; } + +.translate-btn { + align-self: flex-end; + background: #0f3460; + border: 1px solid #4a90d9; + border-radius: 6px; + color: #4a90d9; + font-size: 0.7rem; + padding: 2px 8px; + cursor: pointer; + margin-top: 4px; +} +.translate-btn:hover { background: #4a90d9; color: white; } +.translate-btn:disabled { opacity: 0.5; cursor: default; } + +.chat-image { margin-top: 8px; max-width: 100%; border-radius: 8px; border: 1px solid #0f3460; } +.image-error { margin-top: 6px; font-size: 0.75rem; color: #888; } + +.typing { + align-self: flex-start; + display: flex; gap: 4px; + padding: 12px 16px; + background: #16213e; + border: 1px solid #0f3460; + border-radius: 16px; +} +.typing span { width: 6px; height: 6px; background: #888; border-radius: 50%; animation: bounce 1.2s infinite; } +.typing span:nth-child(2) { animation-delay: 0.2s; } +.typing span:nth-child(3) { animation-delay: 0.4s; } +@keyframes bounce { 0%, 60%, 100% { transform: translateY(0); } 30% { transform: translateY(-6px); } } + +.input-area { + display: flex; gap: 10px; + padding: 12px 16px; + border-top: 1px solid #0f3460; +} + +textarea { + flex: 1; + background: #16213e; + border: 1px solid #0f3460; + border-radius: 12px; + color: #e0e0e0; + font-size: 0.95rem; + padding: 10px 14px; + resize: none; outline: none; + font-family: inherit; + max-height: 120px; +} +textarea:focus { border-color: #e94560; } + +#sendBtn { + background: #e94560; border: none; + border-radius: 12px; color: white; + padding: 0 20px; cursor: pointer; +} +#sendBtn:disabled { background: #555; cursor: not-allowed; } + +#clearBtn { + background: transparent; + border: 1px solid #0f3460; + border-radius: 12px; color: #888; + padding: 0 14px; cursor: pointer; +} +#clearBtn:hover { border-color: #e94560; color: #e94560; } + +.modal-overlay { + display: none; position: fixed; inset: 0; + background: rgba(0,0,0,0.7); + z-index: 100; align-items: center; justify-content: center; +} +.modal-overlay.open { display: flex; } + +.modal { + background: #16213e; border: 1px solid #0f3460; + border-radius: 16px; padding: 24px; + width: 100%; max-width: 440px; + display: flex; flex-direction: column; gap: 12px; +} +.modal h2 { font-size: 1.1rem; color: #e94560; } +.modal label { display: flex; flex-direction: column; gap: 4px; font-size: 0.8rem; color: #888; } +.modal input, .modal textarea { + background: #1a1a2e; border: 1px solid #0f3460; + border-radius: 8px; color: #e0e0e0; + padding: 8px 10px; outline: none; font-family: inherit; +} +.modal-buttons { display: flex; gap: 8px; justify-content: flex-end; } +.modal-buttons button { padding: 8px 18px; border-radius: 8px; border: none; cursor: pointer; } +#modalCancel, #cardModalCancel { background: #0f3460; color: #aaa; } +#modalSave, #cardModalImport { background: #e94560; color: white; } + +.empty-state { + flex: 1; display: flex; + align-items: center; justify-content: center; + color: #444; flex-direction: column; gap: 8px; +} +.empty-state .big { font-size: 2.5rem; } +.hidden { display: none !important; } diff --git a/static/index.html b/static/index.html new file mode 100644 index 0000000..cc25218 --- /dev/null +++ b/static/index.html @@ -0,0 +1,116 @@ + + + + + + AI Chat + + + + +
+ +

🤖 AI Chat

+ Новый чат +
+ +
+ + +
+
+
+
+ 💬 + Начни новый чат +
+
+
+ + + +
+
+
+ + + + + + + + + + diff --git a/static/js/app.js b/static/js/app.js new file mode 100644 index 0000000..41f143d --- /dev/null +++ b/static/js/app.js @@ -0,0 +1,30 @@ +import { toggleSidebar, dom } from './state.js'; +import { initSessions, createNewChat } from './sessions.js'; +import { loadPersonas, initPersonaModals } from './personas.js'; +import { sendMessage, clearHistory } from './chat.js'; + +document.getElementById('sidebarToggle').addEventListener('click', () => { + const open = toggleSidebar(); + document.getElementById('sidebar').classList.toggle('collapsed', !open); +}); + +document.getElementById('newChatBtn').addEventListener('click', createNewChat); + +dom.inputEl.addEventListener('input', () => { + dom.inputEl.style.height = 'auto'; + dom.inputEl.style.height = dom.inputEl.scrollHeight + 'px'; +}); + +dom.inputEl.addEventListener('keydown', (e) => { + if (e.key === 'Enter' && !e.shiftKey) { + e.preventDefault(); + sendMessage(); + } +}); + +dom.sendBtn.addEventListener('click', sendMessage); +dom.clearBtn.addEventListener('click', clearHistory); + +initPersonaModals(); +await initSessions(); +loadPersonas(); diff --git a/static/js/chat.js b/static/js/chat.js new file mode 100644 index 0000000..bcfd75f --- /dev/null +++ b/static/js/chat.js @@ -0,0 +1,260 @@ +import { sessionId, currentPersona, dom } from './state.js'; +import { parseImagePromptFromContent, copyToClipboard } from './utils.js'; + +export async function initChat() { + if (!sessionId || !currentPersona) return; + const res = await fetch('/chat/init', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ message: '', session_id: sessionId, persona_id: currentPersona }), + }); + if (!res.ok) return; + const data = await res.json(); + if (data.first_mes) addMessage('assistant', data.first_mes); +} + +export function updateEmptyState() { + const hasMessages = dom.messagesEl.querySelector('.message'); + dom.emptyState?.classList.toggle('hidden', !!hasMessages); +} + +export function createImagePromptBlock(promptText) { + const block = document.createElement('div'); + block.className = 'image-prompt-block'; + + const header = document.createElement('div'); + header.className = 'image-prompt-header'; + header.innerHTML = '🎨 SD prompt'; + + const copyBtn = document.createElement('button'); + copyBtn.type = 'button'; + copyBtn.className = 'copy-prompt-btn'; + copyBtn.textContent = 'Копировать'; + copyBtn.addEventListener('click', async () => { + const ok = await copyToClipboard(promptText); + copyBtn.textContent = ok ? 'Скопировано' : 'Ошибка'; + setTimeout(() => { copyBtn.textContent = 'Копировать'; }, 1500); + }); + header.appendChild(copyBtn); + + const genBtn = document.createElement('button'); + genBtn.type = 'button'; + genBtn.className = 'gen-image-btn'; + genBtn.textContent = '🖼 Генерировать'; + genBtn.addEventListener('click', () => generateImageViaA1111(promptText, block)); + header.appendChild(genBtn); + + const textEl = document.createElement('span'); + textEl.className = 'prompt-text'; + textEl.textContent = promptText; + + block.appendChild(header); + block.appendChild(textEl); + return block; +} + +async function generateImageViaA1111(promptText, block) { + block.parentElement.querySelector('.chat-image')?.remove(); + block.parentElement.querySelector('.image-error')?.remove(); + + try { + const res = await fetch('/images/generate', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ session_id: sessionId, prompt: promptText }), + }); + const data = await res.json(); + if (!res.ok) throw new Error(data.detail || res.statusText); + + const img = document.createElement('img'); + img.className = 'chat-image'; + img.src = data.image_path; + block.parentElement.appendChild(img); + } catch (e) { + const err = document.createElement('div'); + err.className = 'image-error'; + err.textContent = '🖼 ' + e.message; + block.parentElement.appendChild(err); + } +} + +export function appendChatImage(wrapper, imagePath) { + if (!imagePath) return; + const img = document.createElement('img'); + img.className = 'chat-image'; + img.src = imagePath; + wrapper.appendChild(img); +} + +export function addMessage(role, content = '', imagePrompt = null, imagePath = null) { + updateEmptyState(); + + const wrapper = document.createElement('div'); + wrapper.className = `message ${role}`; + + const label = document.createElement('div'); + label.className = 'label'; + label.textContent = role === 'user' ? 'Вы' : 'AI'; + wrapper.appendChild(label); + + let displayContent = content; + let prompt = imagePrompt; + if (role === 'assistant' && !prompt) { + const parsed = parseImagePromptFromContent(content); + displayContent = parsed.text; + prompt = parsed.prompt; + } + + const bubble = document.createElement('div'); + bubble.className = 'bubble'; + bubble.textContent = displayContent; + wrapper.appendChild(bubble); + + if (role === 'assistant') { + const translateBtn = document.createElement('button'); + translateBtn.type = 'button'; + translateBtn.className = 'translate-btn'; + translateBtn.textContent = '🌐 RU'; + let originalText = null; + translateBtn.addEventListener('click', async () => { + if (originalText !== null) { + bubble.textContent = originalText; + originalText = null; + translateBtn.textContent = '🌐 RU'; + return; + } + originalText = bubble.textContent; + translateBtn.disabled = true; + translateBtn.textContent = '…'; + try { + const res = await fetch('/translate/', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ text: originalText }), + }); + if (!res.ok) throw new Error(res.statusText); + const data = await res.json(); + bubble.textContent = data.translated; + translateBtn.textContent = '↩ Оригинал'; + } catch { + originalText = null; + translateBtn.textContent = '⚠️'; + setTimeout(() => { translateBtn.textContent = '🌐 RU'; }, 2000); + } + translateBtn.disabled = false; + }); + wrapper.appendChild(translateBtn); + } + + if (prompt) wrapper.appendChild(createImagePromptBlock(prompt)); + if (imagePath) appendChatImage(wrapper, imagePath); + + dom.messagesEl.appendChild(wrapper); + dom.messagesEl.scrollTop = dom.messagesEl.scrollHeight; + return bubble; +} + +export function showTyping() { + const typing = document.createElement('div'); + typing.className = 'typing'; + typing.id = 'typing'; + typing.innerHTML = ''; + dom.messagesEl.appendChild(typing); + dom.messagesEl.scrollTop = dom.messagesEl.scrollHeight; +} + +export function removeTyping() { + document.getElementById('typing')?.remove(); +} + +export function clearMessages() { + dom.messagesEl.innerHTML = ''; + if (dom.emptyState) { + dom.messagesEl.appendChild(dom.emptyState); + dom.emptyState.classList.remove('hidden'); + } +} + +export async function sendMessage() { + const text = dom.inputEl.value.trim(); + if (!text || !sessionId) return; + + dom.inputEl.value = ''; + dom.inputEl.style.height = 'auto'; + dom.sendBtn.disabled = true; + + addMessage('user', text); + showTyping(); + + try { + const res = await fetch('/chat/stream', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ message: text, session_id: sessionId, persona_id: currentPersona }), + }); + if (!res.ok) throw new Error('Ошибка сервера: ' + res.status); + + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + let bubble = null; + + removeTyping(); + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split('\n'); + buffer = lines.pop(); + + for (const line of lines) { + if (!line.startsWith('data: ')) continue; + try { + const data = JSON.parse(line.slice(6)); + + if (data.chunk !== undefined) { + if (!bubble) { + bubble = addMessage('assistant', ''); + bubble.classList.add('typing-active'); + } + bubble.textContent += data.chunk; + bubble.textContent = bubble.textContent.replace(/\[IMAGE_PROMPT:.*?\]/gs, '').trim(); + dom.messagesEl.scrollTop = dom.messagesEl.scrollHeight; + } + + if (data.done) { + bubble?.classList.remove('typing-active'); + if (data.image_prompt && bubble) { + bubble.parentElement.appendChild(createImagePromptBlock(data.image_prompt)); + } + if (data.image_path && bubble) { + appendChatImage(bubble.parentElement, data.image_path); + } + if (data.image_error && bubble) { + const err = document.createElement('div'); + err.className = 'image-error'; + err.textContent = '🖼 ' + data.image_error; + bubble.parentElement.appendChild(err); + } + const { loadSessions } = await import('./sessions.js'); + loadSessions(); + } + } catch { /* skip */ } + } + } + } catch (err) { + removeTyping(); + addMessage('assistant', '⚠️ Ошибка: ' + err.message); + } finally { + dom.sendBtn.disabled = false; + dom.inputEl.focus(); + } +} + +export async function clearHistory() { + if (!sessionId) return; + await fetch(`/chat/${sessionId}`, { method: 'DELETE' }); + clearMessages(); +} diff --git a/static/js/personas.js b/static/js/personas.js new file mode 100644 index 0000000..e36ca0f --- /dev/null +++ b/static/js/personas.js @@ -0,0 +1,164 @@ +import { currentPersona, setCurrentPersona, sessionId } from './state.js'; +import { initChat } from './chat.js'; + +export function highlightPersona(personaId) { + document.querySelectorAll('.persona-card').forEach(c => { + c.classList.toggle('active', c.dataset.id === personaId); + }); +} + +export async function loadPersonas() { + const res = await fetch('/personas/'); + const personas = await res.json(); + const bar = document.getElementById('personaBar'); + bar.innerHTML = ''; + + personas.forEach(p => { + const card = document.createElement('div'); + card.className = 'persona-card' + (p.persona_id === currentPersona ? ' active' : ''); + card.dataset.id = p.persona_id; + const isCard = p.persona_id.startsWith('card_'); + card.innerHTML = ` + ${p.emoji} + ${p.name} + ${p.custom ? `` : ''} + ${isCard ? `` : ''} + `; + card.addEventListener('click', () => selectPersona(p.persona_id)); + card.querySelector('.del-btn')?.addEventListener('click', async (e) => { + e.stopPropagation(); + await fetch(`/personas/${p.persona_id}`, { method: 'DELETE' }); + if (currentPersona === p.persona_id) await selectPersona('default'); + loadPersonas(); + }); + card.querySelector('.edit-btn')?.addEventListener('click', async (e) => { + e.stopPropagation(); + const cardId = p.persona_id.slice(5); + const r = await fetch(`/characters/${cardId}`); + const data = await r.json(); + document.getElementById('editCardId').value = cardId; + document.getElementById('editName').value = data.name || ''; + document.getElementById('editDescription').value = data.description || ''; + document.getElementById('editPersonality').value = data.personality || ''; + document.getElementById('editScenario').value = data.scenario || ''; + document.getElementById('editFirstMes').value = data.first_mes || ''; + document.getElementById('editMesExample').value = data.mes_example || ''; + document.getElementById('editAppearance').value = data.appearance_tags || ''; + document.getElementById('editLora').value = data.lora_name || ''; + document.getElementById('editLoraWeight').value = data.lora_weight ?? 0.8; + document.getElementById('cardEditOverlay').classList.add('open'); + }); + bar.appendChild(card); + }); + + const addBtn = document.createElement('button'); + addBtn.type = 'button'; + addBtn.className = 'persona-add'; + addBtn.innerHTML = '➕Создать'; + addBtn.addEventListener('click', () => document.getElementById('modalOverlay').classList.add('open')); + bar.appendChild(addBtn); + + const importBtn = document.createElement('button'); + importBtn.type = 'button'; + importBtn.className = 'card-import-btn'; + importBtn.innerHTML = '📥Chub'; + importBtn.addEventListener('click', () => document.getElementById('cardModalOverlay').classList.add('open')); + bar.appendChild(importBtn); +} + +export async function selectPersona(personaId) { + setCurrentPersona(personaId); + highlightPersona(personaId); + if (sessionId) { + await fetch(`/sessions/${sessionId}`, { + method: 'PATCH', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ persona_id: personaId }), + }); + await initChat(); + } +} + +export function initPersonaModals() { + document.getElementById('modalCancel').addEventListener('click', () => { + document.getElementById('modalOverlay').classList.remove('open'); + }); + document.getElementById('cardModalCancel').addEventListener('click', () => { + document.getElementById('cardModalOverlay').classList.remove('open'); + }); + document.getElementById('cardEditCancel').addEventListener('click', () => { + document.getElementById('cardEditOverlay').classList.remove('open'); + }); + + document.getElementById('modalSave').addEventListener('click', async () => { + const data = { + persona_id: document.getElementById('pId').value.trim(), + name: document.getElementById('pName').value.trim(), + emoji: document.getElementById('pEmoji').value.trim() || '🤖', + description: document.getElementById('pDesc').value.trim(), + prompt: document.getElementById('pPrompt').value.trim(), + sd_enabled: document.getElementById('pSdEnabled').checked, + lora_name: document.getElementById('pLora').value.trim(), + appearance_tags: document.getElementById('pAppearance').value.trim(), + }; + if (!data.persona_id || !data.name || !data.prompt) { + alert('Заполни ID, имя и промт'); + return; + } + await fetch('/personas/', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(data), + }); + document.getElementById('modalOverlay').classList.remove('open'); + await loadPersonas(); + await selectPersona(data.persona_id); + }); + + document.getElementById('cardEditSave').addEventListener('click', async () => { + const cardId = document.getElementById('editCardId').value; + const body = { + name: document.getElementById('editName').value.trim() || undefined, + description: document.getElementById('editDescription').value.trim() || undefined, + personality: document.getElementById('editPersonality').value.trim() || undefined, + scenario: document.getElementById('editScenario').value.trim() || undefined, + first_mes: document.getElementById('editFirstMes').value.trim() || undefined, + mes_example: document.getElementById('editMesExample').value.trim() || undefined, + appearance_tags: document.getElementById('editAppearance').value.trim() || undefined, + lora_name: document.getElementById('editLora').value.trim() || undefined, + lora_weight: parseFloat(document.getElementById('editLoraWeight').value) || undefined, + }; + const res = await fetch(`/characters/${cardId}`, { + method: 'PATCH', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(body), + }); + if (!res.ok) { alert('Ошибка сохранения'); return; } + document.getElementById('cardEditOverlay').classList.remove('open'); + await loadPersonas(); + }); + + document.getElementById('cardModalImport').addEventListener('click', async () => { + const fileInput = document.getElementById('cardFile'); + if (!fileInput.files?.length) { + alert('Выберите файл карточки (JSON или PNG)'); + return; + } + const form = new FormData(); + form.append('file', fileInput.files[0]); + form.append('lora_name', document.getElementById('cardLora').value.trim()); + form.append('lora_weight', document.getElementById('cardLoraWeight').value || '0.8'); + + const res = await fetch('/characters/import', { method: 'POST', body: form }); + const data = await res.json(); + if (!res.ok) { + alert(data.detail || 'Ошибка импорта'); + return; + } + document.getElementById('cardModalOverlay').classList.remove('open'); + fileInput.value = ''; + await loadPersonas(); + await selectPersona(data.persona_id); + }); +} + diff --git a/static/js/sessions.js b/static/js/sessions.js new file mode 100644 index 0000000..330ccc2 --- /dev/null +++ b/static/js/sessions.js @@ -0,0 +1,86 @@ +import { sessionId, setSessionId, setCurrentPersona, currentPersona, dom } from './state.js'; +import { clearMessages, addMessage, initChat } from './chat.js'; +import { highlightPersona } from './personas.js'; + +function escapeTitle(t) { + const d = document.createElement('div'); + d.textContent = t; + return d.innerHTML; +} + +export async function loadSessions() { + const res = await fetch('/sessions/'); + const sessions = await res.json(); + dom.sessionList.innerHTML = ''; + + sessions.forEach(s => { + const item = document.createElement('div'); + item.className = 'session-item' + (s.session_id === sessionId ? ' active' : ''); + item.innerHTML = ` +
${escapeTitle(s.title || 'Новый чат')}
+
${s.message_count} сообщ.
+ + `; + item.addEventListener('click', () => switchSession(s.session_id)); + item.querySelector('.s-del').addEventListener('click', async (e) => { + e.stopPropagation(); + await fetch(`/sessions/${s.session_id}`, { method: 'DELETE' }); + if (s.session_id === sessionId) createNewChat(); + else loadSessions(); + }); + dom.sessionList.appendChild(item); + }); +} + +export async function switchSession(id) { + setSessionId(id); + clearMessages(); + await loadSessions(); + await loadChatHistory(id); +} + +export async function loadChatHistory(id) { + const sessionRes = await fetch(`/sessions/${id}`); + if (sessionRes.ok) { + const s = await sessionRes.json(); + dom.headerTitle.textContent = s.title || 'Новый чат'; + if (s.persona_id) { + setCurrentPersona(s.persona_id); + highlightPersona(s.persona_id); + } + } + + const histRes = await fetch(`/chat/history/${id}`); + if (!histRes.ok) return; + + const messages = await histRes.json(); + clearMessages(); + messages.filter(m => m.role !== 'system').forEach(m => { + addMessage( + m.role === 'user' ? 'user' : 'assistant', + m.content, + m.image_prompt, + m.image_path ? `/static/${m.image_path}` : null, + ); + }); +} + +export async function createNewChat() { + setSessionId('sess_' + Math.random().toString(36).slice(2, 10)); + clearMessages(); + dom.headerTitle.textContent = 'Новый чат'; + highlightPersona(currentPersona); + await initChat(); + loadSessions(); +} + +export async function initSessions() { + await loadSessions(); + if (sessionId) { + const check = await fetch(`/sessions/${sessionId}`); + if (check.ok) await switchSession(sessionId); + else createNewChat(); + } else { + createNewChat(); + } +} diff --git a/static/js/state.js b/static/js/state.js new file mode 100644 index 0000000..b518c71 --- /dev/null +++ b/static/js/state.js @@ -0,0 +1,24 @@ +export let sessionId = localStorage.getItem('chat_session_id') || null; +export let currentPersona = localStorage.getItem('persona_id') || 'default'; +export let sidebarOpen = true; +export function toggleSidebar() { sidebarOpen = !sidebarOpen; return sidebarOpen; } + +export function setSessionId(id) { + sessionId = id; + if (id) localStorage.setItem('chat_session_id', id); +} + +export function setCurrentPersona(id) { + currentPersona = id; + localStorage.setItem('persona_id', id); +} + +export const dom = { + messagesEl: document.getElementById('messages'), + inputEl: document.getElementById('input'), + sendBtn: document.getElementById('sendBtn'), + clearBtn: document.getElementById('clearBtn'), + sessionList: document.getElementById('sessionList'), + headerTitle: document.getElementById('headerTitle'), + emptyState: document.getElementById('emptyState'), +}; diff --git a/static/js/utils.js b/static/js/utils.js new file mode 100644 index 0000000..7e43ecc --- /dev/null +++ b/static/js/utils.js @@ -0,0 +1,16 @@ +export function parseImagePromptFromContent(content) { + if (!content || !content.includes('[IMAGE_PROMPT:')) return { text: content, prompt: null }; + const match = content.match(/\[IMAGE_PROMPT:(.*?)\]/s); + const prompt = match ? match[1].trim() : null; + const text = content.replace(/\[IMAGE_PROMPT:.*?\]/gs, '').trim(); + return { text, prompt }; +} + +export async function copyToClipboard(text) { + try { + await navigator.clipboard.writeText(text); + return true; + } catch { + return false; + } +}