first commit
This commit is contained in:
@@ -0,0 +1,142 @@
|
||||
import aiosqlite
|
||||
from database.db import DB_PATH
|
||||
|
||||
|
||||
async def get_or_create_session(session_id: str, persona_id: str = "default") -> dict:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM sessions WHERE session_id = ?", (session_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return dict(row)
|
||||
|
||||
await db.execute(
|
||||
"INSERT INTO sessions (session_id, persona_id) VALUES (?, ?)",
|
||||
(session_id, persona_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async with db.execute(
|
||||
"SELECT * FROM sessions WHERE session_id = ?", (session_id,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return dict(row)
|
||||
|
||||
|
||||
async def get_all_sessions() -> list:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT * FROM sessions ORDER BY updated_at DESC"
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET title = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_session_persona(session_id: str, persona_id: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE sessions SET persona_id = ?, updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(persona_id, session_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def delete_session(session_id: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
await db.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_history(session_id: str) -> list:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"""SELECT role, content, image_prompt, image_path
|
||||
FROM messages WHERE session_id = ? ORDER BY id""",
|
||||
(session_id,),
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
{
|
||||
"role": r["role"],
|
||||
"content": r["content"],
|
||||
"image_prompt": r["image_prompt"],
|
||||
"image_path": r["image_path"],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
async def add_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
image_prompt: str | None = None,
|
||||
image_path: str | None = None,
|
||||
):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, image_prompt, image_path)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(session_id, role, content, image_prompt, image_path),
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE sessions SET updated_at = CURRENT_TIMESTAMP WHERE session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_message_image(message_id: int, image_path: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"UPDATE messages SET image_path = ? WHERE id = ?",
|
||||
(image_path, message_id),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_last_assistant_message_id(session_id: str) -> int | None:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"""SELECT id FROM messages
|
||||
WHERE session_id = ? AND role = 'assistant'
|
||||
ORDER BY id DESC LIMIT 1""",
|
||||
(session_id,),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return row["id"] if row else None
|
||||
|
||||
|
||||
async def clear_history(session_id: str):
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
await db.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def get_message_count(session_id: str) -> int:
|
||||
async with aiosqlite.connect(DB_PATH) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) as cnt FROM messages WHERE session_id = ? AND role != 'system'",
|
||||
(session_id,),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return row["cnt"]
|
||||
Reference in New Issue
Block a user