added RAG, Multiuser, TG bot
This commit is contained in:
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinkedUser:
|
||||
telegram_id: int
|
||||
api_token: str
|
||||
ha_user_id: int
|
||||
display_name: str
|
||||
username: str
|
||||
session_id: int
|
||||
reminder_seq: int
|
||||
pomodoro_seq: int
|
||||
|
||||
|
||||
class Storage:
|
||||
def __init__(self, db_path: str) -> None:
|
||||
self.db_path = db_path
|
||||
self._db: aiosqlite.Connection | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._db = await aiosqlite.connect(self.db_path)
|
||||
self._db.row_factory = aiosqlite.Row
|
||||
await self._db.executescript(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
telegram_id INTEGER PRIMARY KEY,
|
||||
api_token TEXT NOT NULL,
|
||||
ha_user_id INTEGER NOT NULL,
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
username TEXT NOT NULL DEFAULT '',
|
||||
session_id INTEGER NOT NULL,
|
||||
reminder_seq INTEGER NOT NULL DEFAULT 0,
|
||||
pomodoro_seq INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS session_cursors (
|
||||
telegram_id INTEGER NOT NULL,
|
||||
session_id INTEGER NOT NULL,
|
||||
last_message_id INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (telegram_id, session_id)
|
||||
);
|
||||
"""
|
||||
)
|
||||
await self._db.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._db:
|
||||
await self._db.close()
|
||||
self._db = None
|
||||
|
||||
@property
|
||||
def db(self) -> aiosqlite.Connection:
|
||||
if not self._db:
|
||||
raise RuntimeError("Storage is not connected")
|
||||
return self._db
|
||||
|
||||
async def get_user(self, telegram_id: int) -> LinkedUser | None:
|
||||
cursor = await self.db.execute(
|
||||
"""
|
||||
SELECT telegram_id, api_token, ha_user_id, display_name, username,
|
||||
session_id, reminder_seq, pomodoro_seq
|
||||
FROM users WHERE telegram_id = ?
|
||||
""",
|
||||
(telegram_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return LinkedUser(
|
||||
telegram_id=int(row["telegram_id"]),
|
||||
api_token=str(row["api_token"]),
|
||||
ha_user_id=int(row["ha_user_id"]),
|
||||
display_name=str(row["display_name"] or ""),
|
||||
username=str(row["username"] or ""),
|
||||
session_id=int(row["session_id"]),
|
||||
reminder_seq=int(row["reminder_seq"]),
|
||||
pomodoro_seq=int(row["pomodoro_seq"]),
|
||||
)
|
||||
|
||||
async def list_linked_users(self) -> list[LinkedUser]:
|
||||
cursor = await self.db.execute(
|
||||
"""
|
||||
SELECT telegram_id, api_token, ha_user_id, display_name, username,
|
||||
session_id, reminder_seq, pomodoro_seq
|
||||
FROM users
|
||||
"""
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
LinkedUser(
|
||||
telegram_id=int(row["telegram_id"]),
|
||||
api_token=str(row["api_token"]),
|
||||
ha_user_id=int(row["ha_user_id"]),
|
||||
display_name=str(row["display_name"] or ""),
|
||||
username=str(row["username"] or ""),
|
||||
session_id=int(row["session_id"]),
|
||||
reminder_seq=int(row["reminder_seq"]),
|
||||
pomodoro_seq=int(row["pomodoro_seq"]),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def link_user(
|
||||
self,
|
||||
*,
|
||||
telegram_id: int,
|
||||
api_token: str,
|
||||
ha_user_id: int,
|
||||
display_name: str,
|
||||
username: str,
|
||||
session_id: int,
|
||||
) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await self.db.execute(
|
||||
"""
|
||||
INSERT INTO users (
|
||||
telegram_id, api_token, ha_user_id, display_name, username,
|
||||
session_id, reminder_seq, pomodoro_seq, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, 0, 0, ?)
|
||||
ON CONFLICT(telegram_id) DO UPDATE SET
|
||||
api_token = excluded.api_token,
|
||||
ha_user_id = excluded.ha_user_id,
|
||||
display_name = excluded.display_name,
|
||||
username = excluded.username,
|
||||
session_id = excluded.session_id,
|
||||
reminder_seq = 0,
|
||||
pomodoro_seq = 0
|
||||
""",
|
||||
(telegram_id, api_token, ha_user_id, display_name, username, session_id, now),
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
async def unlink_user(self, telegram_id: int) -> bool:
|
||||
cursor = await self.db.execute("DELETE FROM users WHERE telegram_id = ?", (telegram_id,))
|
||||
await self.db.execute(
|
||||
"DELETE FROM session_cursors WHERE telegram_id = ?",
|
||||
(telegram_id,),
|
||||
)
|
||||
await self.db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
async def set_session_id(self, telegram_id: int, session_id: int) -> None:
|
||||
await self.db.execute(
|
||||
"UPDATE users SET session_id = ? WHERE telegram_id = ?",
|
||||
(session_id, telegram_id),
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
async def update_seq(
|
||||
self,
|
||||
telegram_id: int,
|
||||
*,
|
||||
reminder_seq: int | None = None,
|
||||
pomodoro_seq: int | None = None,
|
||||
) -> None:
|
||||
if reminder_seq is not None:
|
||||
await self.db.execute(
|
||||
"UPDATE users SET reminder_seq = ? WHERE telegram_id = ?",
|
||||
(reminder_seq, telegram_id),
|
||||
)
|
||||
if pomodoro_seq is not None:
|
||||
await self.db.execute(
|
||||
"UPDATE users SET pomodoro_seq = ? WHERE telegram_id = ?",
|
||||
(pomodoro_seq, telegram_id),
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
async def get_last_message_id(self, telegram_id: int, session_id: int) -> int:
|
||||
cursor = await self.db.execute(
|
||||
"""
|
||||
SELECT last_message_id FROM session_cursors
|
||||
WHERE telegram_id = ? AND session_id = ?
|
||||
""",
|
||||
(telegram_id, session_id),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return int(row["last_message_id"]) if row else 0
|
||||
|
||||
async def set_last_message_id(self, telegram_id: int, session_id: int, message_id: int) -> None:
|
||||
await self.db.execute(
|
||||
"""
|
||||
INSERT INTO session_cursors (telegram_id, session_id, last_message_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(telegram_id, session_id) DO UPDATE SET
|
||||
last_message_id = excluded.last_message_id
|
||||
""",
|
||||
(telegram_id, session_id, message_id),
|
||||
)
|
||||
await self.db.commit()
|
||||
Reference in New Issue
Block a user