Fixed SD Promt

This commit is contained in:
2026-06-02 15:03:39 +03:00
parent d4cd8f02f4
commit 03cbda5dce
46 changed files with 3285 additions and 429 deletions
+124 -17
View File
@@ -2,7 +2,8 @@ import aiosqlite
from database.db import DB_PATH
async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict:
async def get_or_create_session(session_id: str, persona_id: str | None = None) -> dict:
"""Existing sessions keep their persona_id; persona_id applies only on INSERT."""
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
@@ -13,9 +14,10 @@ async def get_or_create_session(session_id: str, persona_id: str = "default") ->
if row:
return dict(row)
pid = (persona_id or "default").strip() or "default"
await db.execute(
"INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)",
(session_id, persona_id),
(session_id, pid),
)
await db.commit()
@@ -71,24 +73,99 @@ async def update_session_persona(session_id: str, persona_id: str):
(persona_id, session_id),
)
# If persona changed, reset RPG state bound to the persona/arc.
if prev is not None and prev != persona_id:
await db.execute(
"""UPDATE sessions
SET facts_json = '[]',
global_plot = '',
status_quo = '',
plot_arc_json = '{}'
WHERE session_id = ?""",
(session_id,),
)
await db.execute(
"DELETE FROM action_resolutions WHERE session_id = ?",
(session_id,),
)
await _reset_persona_bound_state(db, session_id)
await db.commit()
async def _reset_persona_bound_state(db: aiosqlite.Connection, session_id: str) -> None:
await db.execute(
"""UPDATE sessions
SET facts_json = '[]',
global_plot = '',
status_quo = '',
plot_arc_json = '{}',
outfit_json = '[]',
affinity = 0
WHERE session_id = ?""",
(session_id,),
)
await db.execute("DELETE FROM action_resolutions WHERE session_id = ?", (session_id,))
await db.execute("DELETE FROM rpg_quests WHERE session_id = ?", (session_id,))
async def upsert_static_system_message(
session_id: str, static_prompt: str, history: list | None = None
) -> bool:
"""Store only static persona prompt in messages. Returns True if written."""
hist = history if history is not None else await get_history(session_id)
async with aiosqlite.connect(DB_PATH) as db:
if not hist:
await db.execute(
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
VALUES (?, 'system', ?, NULL, NULL)""",
(session_id, static_prompt),
)
await db.execute(
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
(session_id,),
)
await db.commit()
return True
if hist[0]["role"] == "system":
if hist[0]["content"] == static_prompt:
return False
await db.execute(
"""UPDATE messages SET content = ?
WHERE session_id = ? AND role = 'system'
AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
(static_prompt, session_id, session_id),
)
await db.commit()
return True
await db.execute(
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
VALUES (?, 'system', ?, NULL, NULL)""",
(session_id, static_prompt),
)
await db.commit()
return True
async def delete_dialog_messages(session_id: str) -> None:
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
"DELETE FROM messages WHERE session_id = ? AND role IN ('user', 'assistant')",
(session_id,),
)
await db.commit()
async def rebind_session_persona(
session_id: str,
persona_id: str,
*,
clear_history: bool = False,
static_prompt: str,
first_mes: str | None = None,
) -> None:
session = await get_session(session_id)
if not session:
raise ValueError("Session not found")
await update_session_persona(session_id, persona_id)
if clear_history:
await delete_dialog_messages(session_id)
history = await get_history(session_id)
await upsert_static_system_message(session_id, static_prompt, history)
if clear_history and first_mes and first_mes.strip():
await add_message(session_id, "assistant", first_mes.strip())
async def update_session_rpg(session_id: str, rpg_enabled: bool):
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
@@ -178,7 +255,8 @@ async def get_history(session_id: str) -> list:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"""SELECT id, role, content, image_prompt, image_path
"""SELECT id, role, content, image_prompt, image_path,
image_prompt_alt, image_path_alt
FROM messages WHERE session_id = ? ORDER BY id""",
(session_id,),
) as cursor:
@@ -190,6 +268,8 @@ async def get_history(session_id: str) -> list:
"content": r["content"],
"image_prompt": r["image_prompt"],
"image_path": r["image_path"],
"image_prompt_alt": r["image_prompt_alt"],
"image_path_alt": r["image_path_alt"],
}
for r in rows
]
@@ -332,6 +412,33 @@ async def update_message_image(message_id: int, image_path: str):
await db.commit()
async def update_message_prompt(message_id: int, image_prompt: str):
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
"UPDATE messages SET image_prompt = ? WHERE id = ?",
(image_prompt, message_id),
)
await db.commit()
async def update_message_prompt_alt(message_id: int, image_prompt_alt: str):
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
"UPDATE messages SET image_prompt_alt = ? WHERE id = ?",
(image_prompt_alt, message_id),
)
await db.commit()
async def update_message_image_alt(message_id: int, image_path_alt: str):
async with aiosqlite.connect(DB_PATH) as db:
await db.execute(
"UPDATE messages SET image_path_alt = ? WHERE id = ?",
(image_path_alt, message_id),
)
await db.commit()
async def get_last_assistant_message_id(session_id: str) -> int | None:
async with aiosqlite.connect(DB_PATH) as db:
db.row_factory = aiosqlite.Row