Fixed SD RPG
This commit is contained in:
+308
-26
@@ -1,8 +1,11 @@
|
||||
import json
|
||||
|
||||
import aiosqlite
|
||||
from database.db import DB_PATH
|
||||
|
||||
|
||||
async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict:
|
||||
async def get_or_create_session(session_id: str, persona_id: str | None = None) -> dict:
|
||||
"""Existing sessions keep their persona_id; persona_id applies only on INSERT."""
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
@@ -13,9 +16,10 @@ async def get_or_create_session(session_id: str, persona_id: str = "default") ->
|
||||
if row:
|
||||
return dict(row)
|
||||
|
||||
pid = (persona_id or "default").strip() or "default"
|
||||
await db.execute(
|
||||
"INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)",
|
||||
(session_id, persona_id),
|
||||
(session_id, pid),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@@ -71,24 +75,104 @@ async def update_session_persona(session_id: str, persona_id: str):
|
||||
(persona_id, session_id),
|
||||
)
|
||||
|
||||
# If persona changed, reset RPG state bound to the persona/arc.
|
||||
if prev is not None and prev != persona_id:
|
||||
await db.execute(
|
||||
"""UPDATE sessions
|
||||
SET facts_json = '[]',
|
||||
global_plot = '',
|
||||
status_quo = '',
|
||||
plot_arc_json = '{}'
|
||||
WHERE session_id = ?""",
|
||||
(session_id,),
|
||||
)
|
||||
await db.execute(
|
||||
"DELETE FROM action_resolutions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await _reset_persona_bound_state(db, session_id)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _reset_persona_bound_state(db: aiosqlite.Connection, session_id: str) -> None:
|
||||
from services.rpg_state import DEFAULT_NARRATIVE_STATS
|
||||
|
||||
stats_default = json.dumps(DEFAULT_NARRATIVE_STATS, ensure_ascii=False)
|
||||
await db.execute(
|
||||
"""UPDATE sessions
|
||||
SET facts_json = '[]',
|
||||
global_plot = '',
|
||||
status_quo = '',
|
||||
plot_arc_json = '{}',
|
||||
outfit_json = '[]',
|
||||
affinity = 0,
|
||||
scene_json = '{}',
|
||||
narrative_stats_json = ?
|
||||
WHERE session_id = ?""",
|
||||
(stats_default, session_id),
|
||||
)
|
||||
await db.execute("DELETE FROM action_resolutions WHERE session_id = ?", (session_id,))
|
||||
await db.execute("DELETE FROM rpg_quests WHERE session_id = ?", (session_id,))
|
||||
|
||||
|
||||
async def upsert_static_system_message(
|
||||
session_id: str, static_prompt: str, history: list | None = None
|
||||
) -> bool:
|
||||
"""Store only static persona prompt in messages. Returns True if written."""
|
||||
hist = history if history is not None else await get_history(session_id)
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
if not hist:
|
||||
await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, 'system', ?, NULL, NULL)""",
|
||||
(session_id, static_prompt),
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
if hist[0]["role"] == "system":
|
||||
if hist[0]["content"] == static_prompt:
|
||||
return False
|
||||
await db.execute(
|
||||
"""UPDATE messages SET content = ?
|
||||
WHERE session_id = ? AND role = 'system'
|
||||
AND id = (SELECT MIN(id) FROM messages WHERE session_id = ?)""",
|
||||
(static_prompt, session_id, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, 'system', ?, NULL, NULL)""",
|
||||
(session_id, static_prompt),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def delete_dialog_messages(session_id: str) -> None:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"DELETE FROM messages WHERE session_id = ? AND role IN ('user', 'assistant')",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def rebind_session_persona(
|
||||
session_id: str,
|
||||
persona_id: str,
|
||||
*,
|
||||
clear_history: bool = False,
|
||||
static_prompt: str,
|
||||
first_mes: str | None = None,
|
||||
) -> None:
|
||||
session = await get_session(session_id)
|
||||
if not session:
|
||||
raise ValueError("Session not found")
|
||||
|
||||
await update_session_persona(session_id, persona_id)
|
||||
if clear_history:
|
||||
await delete_dialog_messages(session_id)
|
||||
|
||||
history = await get_history(session_id)
|
||||
await upsert_static_system_message(session_id, static_prompt, history)
|
||||
|
||||
if clear_history and first_mes and first_mes.strip():
|
||||
await add_message(session_id, "assistant", first_mes.strip())
|
||||
|
||||
|
||||
async def update_session_rpg(session_id: str, rpg_enabled: bool):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
@@ -174,25 +258,116 @@ async def delete_session(session_id: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_history(session_id: str) -> list:
|
||||
async def get_action_resolutions_map(session_id: str) -> dict[int, dict]:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"""SELECT id, role, content, image_prompt, image_path
|
||||
"""SELECT message_id, intent_text, roll, outcome, resolution_text
|
||||
FROM action_resolutions
|
||||
WHERE session_id = ? AND message_id IS NOT NULL
|
||||
ORDER BY id""",
|
||||
(session_id,),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
out: dict[int, dict] = {}
|
||||
for r in rows:
|
||||
mid = r["message_id"]
|
||||
if mid is not None:
|
||||
out[int(mid)] = {
|
||||
"intent_text": r["intent_text"],
|
||||
"roll": r["roll"],
|
||||
"outcome": r["outcome"],
|
||||
"resolution_text": r["resolution_text"],
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def narrator_message_content(narrator: dict) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"roll": narrator.get("roll"),
|
||||
"outcome": narrator.get("outcome"),
|
||||
"text": narrator.get("text", ""),
|
||||
"original_intent": narrator.get("original_intent"),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def parse_narrator_message(content: str) -> dict | None:
|
||||
try:
|
||||
data = json.loads(content or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
if not isinstance(data, dict) or not (data.get("text") or "").strip():
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
async def seed_quests_from_arc(session_id: str, arc: dict) -> int:
|
||||
"""Create active quests for arc beats that are not already in rpg_quests."""
|
||||
if not arc:
|
||||
return 0
|
||||
existing = {q["title"] for q in await get_quests(session_id)}
|
||||
added = 0
|
||||
for beat in arc.get("beats", []):
|
||||
title = (beat.get("title") or beat.get("injection", "")).strip()[:120]
|
||||
if title and title not in existing:
|
||||
await upsert_quest(session_id, title, "active")
|
||||
existing.add(title)
|
||||
added += 1
|
||||
return added
|
||||
|
||||
|
||||
async def get_history(session_id: str) -> list:
|
||||
resolutions = await get_action_resolutions_map(session_id)
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"""SELECT id, role, content, image_prompt, image_path,
|
||||
image_prompt_alt, image_path_alt, choices_json
|
||||
FROM messages WHERE session_id = ? ORDER BY id""",
|
||||
(session_id,),
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
{
|
||||
result = []
|
||||
for idx, r in enumerate(rows):
|
||||
item = {
|
||||
"id": r["id"],
|
||||
"role": r["role"],
|
||||
"content": r["content"],
|
||||
"image_prompt": r["image_prompt"],
|
||||
"image_path": r["image_path"],
|
||||
"image_prompt_alt": r["image_prompt_alt"],
|
||||
"image_path_alt": r["image_path_alt"],
|
||||
"choices_json": r["choices_json"],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
if r["role"] == "user" and r["id"] in resolutions:
|
||||
item["action_resolution"] = resolutions[r["id"]]
|
||||
result.append(item)
|
||||
if r["role"] == "user" and r["id"] in resolutions:
|
||||
nxt = rows[idx + 1] if idx + 1 < len(rows) else None
|
||||
if not nxt or nxt["role"] != "narrator":
|
||||
res = resolutions[r["id"]]
|
||||
result.append(
|
||||
{
|
||||
"id": -int(r["id"]),
|
||||
"role": "narrator",
|
||||
"content": narrator_message_content(
|
||||
{
|
||||
"roll": res.get("roll"),
|
||||
"outcome": res.get("outcome"),
|
||||
"text": res.get("resolution_text", ""),
|
||||
}
|
||||
),
|
||||
"image_prompt": None,
|
||||
"image_path": None,
|
||||
"image_prompt_alt": None,
|
||||
"image_path_alt": None,
|
||||
"choices_json": None,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def get_message(message_id: int) -> dict | None:
|
||||
@@ -230,6 +405,38 @@ async def delete_message(message_id: int):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def delete_message_and_following(session_id: str, message_id: int) -> bool:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"DELETE FROM messages WHERE session_id = ? AND id >= ?",
|
||||
(session_id, message_id),
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def update_message_choices(message_id: int, choices_json: str | None):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET choices_json = ? WHERE id = ?",
|
||||
(choices_json, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def clear_choices_for_session(session_id: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET choices_json = NULL WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_last_message_preview(session_id: str, max_len: int = 80) -> str:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
@@ -261,8 +468,9 @@ async def fork_session(source_session_id: str, until_message_id: int) -> str | N
|
||||
await db.execute(
|
||||
"""INSERT INTO sessions
|
||||
(session_id, persona_id, title, rpg_enabled, facts_json, global_plot,
|
||||
status_quo, plot_arc_json, genre, rpg_settings_json, affinity)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
status_quo, plot_arc_json, genre, rpg_settings_json, affinity,
|
||||
outfit_json, scene_json, narrative_stats_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
new_id,
|
||||
source["persona_id"],
|
||||
@@ -275,6 +483,9 @@ async def fork_session(source_session_id: str, until_message_id: int) -> str | N
|
||||
source.get("genre", "adventure"),
|
||||
source.get("rpg_settings_json", "{}"),
|
||||
source.get("affinity", 0),
|
||||
source.get("outfit_json", "[]"),
|
||||
source.get("scene_json", "{}"),
|
||||
source.get("narrative_stats_json", '{"lust":0,"stamina":10,"tension":0}'),
|
||||
),
|
||||
)
|
||||
async with db.execute(
|
||||
@@ -309,18 +520,20 @@ async def add_message(
|
||||
content: str,
|
||||
image_prompt: str | None = None,
|
||||
image_path: str | None = None,
|
||||
):
|
||||
) -> int:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
cur = await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(session_id, role, content, image_prompt, image_path),
|
||||
)
|
||||
msg_id = cur.lastrowid
|
||||
await db.execute(
|
||||
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return msg_id
|
||||
|
||||
|
||||
async def update_message_image(message_id: int, image_path: str):
|
||||
@@ -332,6 +545,33 @@ async def update_message_image(message_id: int, image_path: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_prompt(message_id: int, image_prompt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_prompt = ? WHERE id = ?",
|
||||
(image_prompt, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_prompt_alt(message_id: int, image_prompt_alt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_prompt_alt = ? WHERE id = ?",
|
||||
(image_prompt_alt, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_image_alt(message_id: int, image_path_alt: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_path_alt = ? WHERE id = ?",
|
||||
(image_path_alt, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_last_assistant_message_id(session_id: str) -> int | None:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
@@ -362,6 +602,18 @@ async def update_session_affinity(session_id: str, delta: int):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def set_session_affinity(session_id: str, value: int):
|
||||
"""Debug / admin: set absolute affinity (-30..30)."""
|
||||
clamped = max(-30, min(30, int(value)))
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET affinity = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(clamped, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
return clamped
|
||||
|
||||
|
||||
async def update_session_genre(session_id: str, genre: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
@@ -389,6 +641,24 @@ async def update_session_outfit(session_id: str, outfit_json: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_session_scene(session_id: str, scene_json: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET scene_json = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(scene_json, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_session_narrative_stats(session_id: str, stats_json: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET narrative_stats_json = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(stats_json, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def upsert_quest(session_id: str, title: str, status: str = "active"):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
async with db.execute(
|
||||
@@ -429,6 +699,18 @@ async def update_quest_status(session_id: str, title: str, status: str):
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_quest_by_id(quest_id: int, session_id: str, status: str) -> bool:
|
||||
if status not in ("active", "done", "failed"):
|
||||
return False
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
cur = await db.execute(
|
||||
"UPDATE rpg_quests SET status = ? WHERE id = ? AND session_id = ?",
|
||||
(status, quest_id, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
async def get_message_count(session_id: str) -> int:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
Reference in New Issue
Block a user