"""Async persistence helpers for chat sessions and their messages (aiosqlite)."""

import aiosqlite

from database.db import DB_PATH

_SELECT_SESSION_SQL = "SELECT * FROM sessions WHERE session_id = ?"


async def _fetch_session(db, session_id: str):
    """Return the ``sessions`` row for *session_id*, or ``None`` if absent."""
    async with db.execute(_SELECT_SESSION_SQL, (session_id,)) as cursor:
        return await cursor.fetchone()


async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict:
    """Return the session row as a dict, inserting it first when missing.

    Args:
        session_id: Identifier of the session to look up or create.
        persona_id: Persona stored on the row when it has to be created;
            ignored when the session already exists.

    Returns:
        All columns of the session row, keyed by column name.

    Raises:
        aiosqlite.IntegrityError: If the insert violates a constraint and the
            row still does not exist afterwards.
    """
    async with aiosqlite.connect(DB_PATH) as db:
        db.row_factory = aiosqlite.Row

        row = await _fetch_session(db, session_id)
        if row:
            return dict(row)

        try:
            await db.execute(
                "INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)",
                (session_id, persona_id),
            )
            await db.commit()
        except aiosqlite.IntegrityError:
            # Another caller may have created the same session between our
            # SELECT and INSERT (each call uses its own connection).
            # NOTE(review): assumes sessions.session_id is UNIQUE / PRIMARY KEY
            # -- confirm against the schema.
            await db.rollback()
            row = await _fetch_session(db, session_id)
            if row is None:
                # Not a lost race: surface the real constraint failure.
                raise
            return dict(row)

        # Re-read so that column values filled in by the database are included.
        row = await _fetch_session(db, session_id)
        return dict(row)


async def get_all_sessions() -> list[dict]:
    """Return every session as a dict, most recently updated first."""
    async with aiosqlite.connect(DB_PATH) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            "SELECT * FROM sessions ORDER BY updated_at DESC"
        ) as cursor:
            rows = await cursor.fetchall()
            return [dict(r) for r in rows]


async def update_session_title(session_id: str, title: str) -> None:
    """Set the title of a session and bump its ``updated_at`` timestamp."""
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            "UPDATE sessions SET title = ?, updated_at = CURRENT_TIMESTAMP "
            "WHERE session_id = ?",
            (title, session_id),
        )
        await db.commit()


async def update_session_persona(session_id: str, persona_id: str) -> None:
    """Set the persona of a session and bump its ``updated_at`` timestamp."""
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            "UPDATE sessions SET persona_id = ?, updated_at = CURRENT_TIMESTAMP "
            "WHERE session_id = ?",
            (persona_id, session_id),
        )
        await db.commit()


async def delete_session(session_id: str) -> None:
    """Delete a session together with all of its messages.

    Both deletes are committed together, messages first.
    """
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
        await db.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
        await db.commit()


async def get_history(session_id: str) -> list[dict]:
    """Return the messages of a session ordered by ascending message id.

    Each item has the keys ``role``, ``content``, ``image_prompt`` and
    ``image_path``.
    """
    async with aiosqlite.connect(DB_PATH) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            # The embedded newline is plain SQL whitespace.
            "SELECT role, content, image_prompt, image_path FROM messages "
            "WHERE session_id = ? \n"
            "ORDER BY id",
            (session_id,),
        ) as cursor:
            rows = await cursor.fetchall()
            return [
                {
                    "role": r["role"],
                    "content": r["content"],
                    "image_prompt": r["image_prompt"],
                    "image_path": r["image_path"],
                }
                for r in rows
            ]


async def add_message(
    session_id: str,
    role: str,
    content: str,
    image_prompt: str | None = None,
    image_path: str | None = None,
) -> None:
    """Insert a message and bump the owning session's ``updated_at``.

    Args:
        session_id: Session the message belongs to.
        role: Value stored in the ``role`` column.
        content: Message text.
        image_prompt: Optional value for the ``image_prompt`` column.
        image_path: Optional value for the ``image_path`` column.
    """
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            "INSERT INTO messages (session_id, role, content, image_prompt, image_path) "
            "VALUES (?, ?, ?, ?, ?)",
            (session_id, role, content, image_prompt, image_path),
        )
        await db.execute(
            "UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
            (session_id,),
        )
        # Single commit so the message and the timestamp bump land together.
        await db.commit()


async def update_message_image(message_id: int, image_path: str) -> None:
    """Set ``image_path`` on the message whose id is *message_id*."""
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            "UPDATE messages SET image_path = ? WHERE id = ?",
            (image_path, message_id),
        )
        await db.commit()


async def get_last_assistant_message_id(session_id: str) -> int | None:
    """Return the highest message id with role ``'assistant'`` in the session.

    Returns ``None`` when the session has no such message.
    """
    async with aiosqlite.connect(DB_PATH) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            "SELECT id FROM messages WHERE session_id = ? AND role = 'assistant' "
            "ORDER BY id DESC LIMIT 1",
            (session_id,),
        ) as cursor:
            row = await cursor.fetchone()
            return row["id"] if row else None


async def clear_history(session_id: str) -> None:
    """Delete every message of a session; the session row itself is kept."""
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(
            "DELETE FROM messages WHERE session_id = ?", (session_id,)
        )
        await db.commit()


async def get_message_count(session_id: str) -> int:
    """Return the number of messages in a session whose role is not ``'system'``."""
    async with aiosqlite.connect(DB_PATH) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            "SELECT COUNT(*) as cnt FROM messages WHERE session_id = ? AND role != 'system'",
            (session_id,),
        ) as cursor:
            row = await cursor.fetchone()
            return row["cnt"]