Files
2026-06-13 20:20:56 +00:00

197 lines
6.8 KiB
Python

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timezone

import aiosqlite
@dataclass
class LinkedUser:
    """One row of the ``users`` table: a Telegram user linked to an account.

    ``api_token`` is declared with ``repr=False`` so the secret is not printed
    when the object is logged or appears in a traceback. It still takes part
    in ``__init__`` and ``__eq__`` exactly as before.
    """

    telegram_id: int  # Telegram user id; primary key of the ``users`` table
    api_token: str = field(repr=False)  # stored API token (secret; hidden from repr)
    ha_user_id: int  # value of the ``ha_user_id`` column
    display_name: str  # may be empty; the column defaults to ''
    username: str  # may be empty; the column defaults to ''
    session_id: int  # current session id; also keys rows in ``session_cursors``
    reminder_seq: int  # per-user counter, reset to 0 when the user is (re)linked
    pomodoro_seq: int  # per-user counter, reset to 0 when the user is (re)linked
class Storage:
    """Async SQLite store for linked Telegram users and per-session cursors.

    Holds a single ``aiosqlite`` connection, opened by :meth:`connect` and
    released by :meth:`close`. Every public write commits before returning.
    If any statement of a write fails, the pending transaction is rolled back,
    so a half-applied change is never committed by a later, unrelated
    ``commit()``.
    """

    # Columns selected by every query that builds a LinkedUser (see _row_to_user).
    _USER_COLUMNS = (
        "telegram_id, api_token, ha_user_id, display_name, username, "
        "session_id, reminder_seq, pomodoro_seq"
    )

    def __init__(self, db_path: str) -> None:
        """Remember *db_path*; no I/O happens until :meth:`connect`."""
        self.db_path = db_path
        self._db: aiosqlite.Connection | None = None

    async def connect(self) -> None:
        """Open the database at ``db_path`` and create any missing tables.

        Does nothing if already connected, so a repeated call cannot orphan
        the previous connection. If schema creation fails, the freshly opened
        connection is closed before the error propagates.
        """
        if self._db is not None:
            return
        db = await aiosqlite.connect(self.db_path)
        try:
            # Name-addressable rows: row["telegram_id"] etc.
            db.row_factory = aiosqlite.Row
            await db.executescript(
                """
                CREATE TABLE IF NOT EXISTS users (
                    telegram_id INTEGER PRIMARY KEY,
                    api_token TEXT NOT NULL,
                    ha_user_id INTEGER NOT NULL,
                    display_name TEXT NOT NULL DEFAULT '',
                    username TEXT NOT NULL DEFAULT '',
                    session_id INTEGER NOT NULL,
                    reminder_seq INTEGER NOT NULL DEFAULT 0,
                    pomodoro_seq INTEGER NOT NULL DEFAULT 0,
                    created_at TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS session_cursors (
                    telegram_id INTEGER NOT NULL,
                    session_id INTEGER NOT NULL,
                    last_message_id INTEGER NOT NULL DEFAULT 0,
                    PRIMARY KEY (telegram_id, session_id)
                );
                """
            )
            await db.commit()
        except BaseException:
            await db.close()
            raise
        self._db = db

    async def close(self) -> None:
        """Close the connection if one is open; safe to call repeatedly.

        The handle is detached before closing, so the instance reports itself
        as disconnected afterwards even if ``close()`` raises.
        """
        db, self._db = self._db, None
        if db is not None:
            await db.close()

    @property
    def db(self) -> aiosqlite.Connection:
        """The open connection.

        Raises:
            RuntimeError: if :meth:`connect` has not been called, or
                :meth:`close` has been called since.
        """
        if self._db is None:
            raise RuntimeError("Storage is not connected")
        return self._db

    @staticmethod
    def _row_to_user(row: aiosqlite.Row) -> LinkedUser:
        """Build a LinkedUser from a row selected with ``_USER_COLUMNS``."""
        return LinkedUser(
            telegram_id=int(row["telegram_id"]),
            api_token=str(row["api_token"]),
            ha_user_id=int(row["ha_user_id"]),
            # NULLs are mapped to "" so callers always get a str.
            display_name=str(row["display_name"] or ""),
            username=str(row["username"] or ""),
            session_id=int(row["session_id"]),
            reminder_seq=int(row["reminder_seq"]),
            pomodoro_seq=int(row["pomodoro_seq"]),
        )

    async def _write(self, *statements: tuple[str, tuple[object, ...]]) -> list[int]:
        """Execute ``(sql, params)`` *statements* in order, then commit.

        Returns each statement's ``rowcount`` in order. If any statement (or
        the commit) fails, the transaction is rolled back and the exception
        re-raised. Atomicity relies on sqlite3's default implicit-transaction
        mode, which opens a transaction before the first DML statement.
        With no statements this simply commits.
        """
        db = self.db
        rowcounts: list[int] = []
        try:
            for sql, params in statements:
                async with db.execute(sql, params) as cursor:
                    rowcounts.append(cursor.rowcount)
            await db.commit()
        except BaseException:
            await db.rollback()
            raise
        return rowcounts

    async def get_user(self, telegram_id: int) -> LinkedUser | None:
        """Return the linked user for *telegram_id*, or None if not linked."""
        async with self.db.execute(
            f"SELECT {self._USER_COLUMNS} FROM users WHERE telegram_id = ?",
            (telegram_id,),
        ) as cursor:
            row = await cursor.fetchone()
        return None if row is None else self._row_to_user(row)

    async def list_linked_users(self) -> list[LinkedUser]:
        """Return every row of ``users`` as a LinkedUser (unspecified order)."""
        async with self.db.execute(f"SELECT {self._USER_COLUMNS} FROM users") as cursor:
            rows = await cursor.fetchall()
        return [self._row_to_user(row) for row in rows]

    async def link_user(
        self,
        *,
        telegram_id: int,
        api_token: str,
        ha_user_id: int,
        display_name: str,
        username: str,
        session_id: int,
    ) -> None:
        """Insert or replace the link for *telegram_id*.

        On re-link, every field except ``created_at`` is overwritten, and both
        ``reminder_seq`` and ``pomodoro_seq`` are reset to 0. ``created_at``
        (UTC ISO-8601) is only set on the first insert.
        """
        now = datetime.now(timezone.utc).isoformat()
        await self._write(
            (
                """
                INSERT INTO users (
                    telegram_id, api_token, ha_user_id, display_name, username,
                    session_id, reminder_seq, pomodoro_seq, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, 0, 0, ?)
                ON CONFLICT(telegram_id) DO UPDATE SET
                    api_token = excluded.api_token,
                    ha_user_id = excluded.ha_user_id,
                    display_name = excluded.display_name,
                    username = excluded.username,
                    session_id = excluded.session_id,
                    reminder_seq = 0,
                    pomodoro_seq = 0
                """,
                (telegram_id, api_token, ha_user_id, display_name, username, session_id, now),
            )
        )

    async def unlink_user(self, telegram_id: int) -> bool:
        """Remove the user and all of their session cursors in one transaction.

        Returns True if a ``users`` row was deleted.
        """
        users_deleted, _ = await self._write(
            ("DELETE FROM users WHERE telegram_id = ?", (telegram_id,)),
            ("DELETE FROM session_cursors WHERE telegram_id = ?", (telegram_id,)),
        )
        return users_deleted > 0

    async def set_session_id(self, telegram_id: int, session_id: int) -> None:
        """Point the user at *session_id* (no-op if the user is not linked)."""
        await self._write(
            (
                "UPDATE users SET session_id = ? WHERE telegram_id = ?",
                (session_id, telegram_id),
            )
        )

    async def update_seq(
        self,
        telegram_id: int,
        *,
        reminder_seq: int | None = None,
        pomodoro_seq: int | None = None,
    ) -> None:
        """Store whichever counters are given; None leaves a counter unchanged.

        Both updates share one transaction, so they are applied together or
        not at all.
        """
        statements: list[tuple[str, tuple[object, ...]]] = []
        if reminder_seq is not None:
            statements.append(
                (
                    "UPDATE users SET reminder_seq = ? WHERE telegram_id = ?",
                    (reminder_seq, telegram_id),
                )
            )
        if pomodoro_seq is not None:
            statements.append(
                (
                    "UPDATE users SET pomodoro_seq = ? WHERE telegram_id = ?",
                    (pomodoro_seq, telegram_id),
                )
            )
        await self._write(*statements)

    async def get_last_message_id(self, telegram_id: int, session_id: int) -> int:
        """Return the stored cursor for (user, session), or 0 if none exists."""
        async with self.db.execute(
            """
            SELECT last_message_id FROM session_cursors
            WHERE telegram_id = ? AND session_id = ?
            """,
            (telegram_id, session_id),
        ) as cursor:
            row = await cursor.fetchone()
        return 0 if row is None else int(row["last_message_id"])

    async def set_last_message_id(self, telegram_id: int, session_id: int, message_id: int) -> None:
        """Upsert the cursor for (user, session) to *message_id*."""
        await self._write(
            (
                """
                INSERT INTO session_cursors (telegram_id, session_id, last_message_id)
                VALUES (?, ?, ?)
                ON CONFLICT(telegram_id, session_id) DO UPDATE SET
                    last_message_id = excluded.last_message_id
                """,
                (telegram_id, session_id, message_id),
            )
        )