from __future__ import annotations from dataclasses import dataclass from datetime import datetime, timezone import aiosqlite @dataclass class LinkedUser: telegram_id: int api_token: str ha_user_id: int display_name: str username: str session_id: int reminder_seq: int pomodoro_seq: int class Storage: def __init__(self, db_path: str) -> None: self.db_path = db_path self._db: aiosqlite.Connection | None = None async def connect(self) -> None: self._db = await aiosqlite.connect(self.db_path) self._db.row_factory = aiosqlite.Row await self._db.executescript( """ CREATE TABLE IF NOT EXISTS users ( telegram_id INTEGER PRIMARY KEY, api_token TEXT NOT NULL, ha_user_id INTEGER NOT NULL, display_name TEXT NOT NULL DEFAULT '', username TEXT NOT NULL DEFAULT '', session_id INTEGER NOT NULL, reminder_seq INTEGER NOT NULL DEFAULT 0, pomodoro_seq INTEGER NOT NULL DEFAULT 0, created_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS session_cursors ( telegram_id INTEGER NOT NULL, session_id INTEGER NOT NULL, last_message_id INTEGER NOT NULL DEFAULT 0, PRIMARY KEY (telegram_id, session_id) ); """ ) await self._db.commit() async def close(self) -> None: if self._db: await self._db.close() self._db = None @property def db(self) -> aiosqlite.Connection: if not self._db: raise RuntimeError("Storage is not connected") return self._db async def get_user(self, telegram_id: int) -> LinkedUser | None: cursor = await self.db.execute( """ SELECT telegram_id, api_token, ha_user_id, display_name, username, session_id, reminder_seq, pomodoro_seq FROM users WHERE telegram_id = ? """, (telegram_id,), ) row = await cursor.fetchone() if not row: return None return LinkedUser( telegram_id=int(row["telegram_id"]), api_token=str(row["api_token"]), ha_user_id=int(row["ha_user_id"]), display_name=str(row["display_name"] or ""), username=str(row["username"] or ""), session_id=int(row["session_id"]), reminder_seq=int(row["reminder_seq"]), pomodoro_seq=int(row["pomodoro_seq"]), ) async def list_linked_users(self) -> list[LinkedUser]: cursor = await self.db.execute( """ SELECT telegram_id, api_token, ha_user_id, display_name, username, session_id, reminder_seq, pomodoro_seq FROM users """ ) rows = await cursor.fetchall() return [ LinkedUser( telegram_id=int(row["telegram_id"]), api_token=str(row["api_token"]), ha_user_id=int(row["ha_user_id"]), display_name=str(row["display_name"] or ""), username=str(row["username"] or ""), session_id=int(row["session_id"]), reminder_seq=int(row["reminder_seq"]), pomodoro_seq=int(row["pomodoro_seq"]), ) for row in rows ] async def link_user( self, *, telegram_id: int, api_token: str, ha_user_id: int, display_name: str, username: str, session_id: int, ) -> None: now = datetime.now(timezone.utc).isoformat() await self.db.execute( """ INSERT INTO users ( telegram_id, api_token, ha_user_id, display_name, username, session_id, reminder_seq, pomodoro_seq, created_at ) VALUES (?, ?, ?, ?, ?, ?, 0, 0, ?) ON CONFLICT(telegram_id) DO UPDATE SET api_token = excluded.api_token, ha_user_id = excluded.ha_user_id, display_name = excluded.display_name, username = excluded.username, session_id = excluded.session_id, reminder_seq = 0, pomodoro_seq = 0 """, (telegram_id, api_token, ha_user_id, display_name, username, session_id, now), ) await self.db.commit() async def unlink_user(self, telegram_id: int) -> bool: cursor = await self.db.execute("DELETE FROM users WHERE telegram_id = ?", (telegram_id,)) await self.db.execute( "DELETE FROM session_cursors WHERE telegram_id = ?", (telegram_id,), ) await self.db.commit() return cursor.rowcount > 0 async def set_session_id(self, telegram_id: int, session_id: int) -> None: await self.db.execute( "UPDATE users SET session_id = ? WHERE telegram_id = ?", (session_id, telegram_id), ) await self.db.commit() async def update_seq( self, telegram_id: int, *, reminder_seq: int | None = None, pomodoro_seq: int | None = None, ) -> None: if reminder_seq is not None: await self.db.execute( "UPDATE users SET reminder_seq = ? WHERE telegram_id = ?", (reminder_seq, telegram_id), ) if pomodoro_seq is not None: await self.db.execute( "UPDATE users SET pomodoro_seq = ? WHERE telegram_id = ?", (pomodoro_seq, telegram_id), ) await self.db.commit() async def get_last_message_id(self, telegram_id: int, session_id: int) -> int: cursor = await self.db.execute( """ SELECT last_message_id FROM session_cursors WHERE telegram_id = ? AND session_id = ? """, (telegram_id, session_id), ) row = await cursor.fetchone() return int(row["last_message_id"]) if row else 0 async def set_last_message_id(self, telegram_id: int, session_id: int, message_id: int) -> None: await self.db.execute( """ INSERT INTO session_cursors (telegram_id, session_id, last_message_id) VALUES (?, ?, ?) ON CONFLICT(telegram_id, session_id) DO UPDATE SET last_message_id = excluded.last_message_id """, (telegram_id, session_id, message_id), ) await self.db.commit()