197 lines
6.8 KiB
Python
197 lines
6.8 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
|
|
import aiosqlite
|
|
|
|
|
|
@dataclass()
class LinkedUser:
    """One row of the ``users`` table, as returned by ``Storage``.

    Plain mutable value object; field order matches the positional
    constructor, so do not reorder.

    NOTE(review): the generated ``__repr__`` includes ``api_token``; avoid
    logging instances verbatim if the token is sensitive.
    """

    # Primary key of the ``users`` table.
    telegram_id: int
    # Per-user credential stored at link time (presumably for the remote API).
    api_token: str
    # User id on the linked remote account (meaning of ``ha_`` prefix unconfirmed).
    ha_user_id: int
    # Profile fields; the schema defaults both to ''.
    display_name: str
    username: str
    # Session currently associated with this user.
    session_id: int
    # Counters that are reset to 0 whenever the user is (re)linked.
    reminder_seq: int
    pomodoro_seq: int
|
|
|
|
|
|
class Storage:
    """Async SQLite store for linked users and their per-session message cursors.

    Usage: ``await connect()`` before any other coroutine, and ``await close()``
    when finished. Every mutating method commits before returning.
    """

    # Schema created by connect(). ``IF NOT EXISTS`` makes it safe to run
    # against a database file that already has these tables.
    _SCHEMA = """
        CREATE TABLE IF NOT EXISTS users (
            telegram_id INTEGER PRIMARY KEY,
            api_token TEXT NOT NULL,
            ha_user_id INTEGER NOT NULL,
            display_name TEXT NOT NULL DEFAULT '',
            username TEXT NOT NULL DEFAULT '',
            session_id INTEGER NOT NULL,
            reminder_seq INTEGER NOT NULL DEFAULT 0,
            pomodoro_seq INTEGER NOT NULL DEFAULT 0,
            created_at TEXT NOT NULL
        );

        CREATE TABLE IF NOT EXISTS session_cursors (
            telegram_id INTEGER NOT NULL,
            session_id INTEGER NOT NULL,
            last_message_id INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (telegram_id, session_id)
        );
    """

    # Columns needed to build a LinkedUser. Every SELECT shares this list, so
    # the queries and _row_to_user() cannot drift apart. It is a fixed
    # constant, so interpolating it into SQL is safe; values still use ``?``.
    _USER_COLUMNS = (
        "telegram_id, api_token, ha_user_id, display_name, username, "
        "session_id, reminder_seq, pomodoro_seq"
    )

    def __init__(self, db_path: str) -> None:
        """Remember where the database lives; no connection is opened yet.

        Args:
            db_path: Database location passed unchanged to ``aiosqlite.connect``.
        """
        self.db_path = db_path
        self._db: aiosqlite.Connection | None = None

    async def connect(self) -> None:
        """Open the database at ``db_path`` and ensure the schema exists.

        Rows are returned as ``aiosqlite.Row`` so columns can be read by name.
        If already connected, the previous connection is closed first instead
        of being leaked. If schema creation fails, the new connection is
        closed and the exception propagates with ``_db`` left unset.
        """
        if self._db is not None:
            await self.close()
        db = await aiosqlite.connect(self.db_path)
        try:
            db.row_factory = aiosqlite.Row
            await db.executescript(self._SCHEMA)
            await db.commit()
        except BaseException:
            await db.close()
            raise
        # Publish the connection only once it is fully initialised.
        self._db = db

    async def close(self) -> None:
        """Close the open connection, if any. Safe to call repeatedly."""
        if self._db is not None:
            # Detach first so nothing can grab a connection that is mid-close.
            db, self._db = self._db, None
            await db.close()

    @property
    def db(self) -> aiosqlite.Connection:
        """The open connection.

        Raises:
            RuntimeError: if connect() has not been called (or close() was).
        """
        if self._db is None:
            raise RuntimeError("Storage is not connected")
        return self._db

    @staticmethod
    def _row_to_user(row: aiosqlite.Row) -> LinkedUser:
        """Build a LinkedUser from a row selected with ``_USER_COLUMNS``."""
        return LinkedUser(
            telegram_id=int(row["telegram_id"]),
            api_token=str(row["api_token"]),
            ha_user_id=int(row["ha_user_id"]),
            display_name=str(row["display_name"] or ""),
            username=str(row["username"] or ""),
            session_id=int(row["session_id"]),
            reminder_seq=int(row["reminder_seq"]),
            pomodoro_seq=int(row["pomodoro_seq"]),
        )

    async def get_user(self, telegram_id: int) -> LinkedUser | None:
        """Return the linked user for ``telegram_id``, or None if there is none."""
        async with self.db.execute(
            f"SELECT {self._USER_COLUMNS} FROM users WHERE telegram_id = ?",
            (telegram_id,),
        ) as cursor:
            row = await cursor.fetchone()
        return None if row is None else self._row_to_user(row)

    async def list_linked_users(self) -> list[LinkedUser]:
        """Return every linked user (no ORDER BY, so order is unspecified)."""
        async with self.db.execute(f"SELECT {self._USER_COLUMNS} FROM users") as cursor:
            rows = await cursor.fetchall()
        return [self._row_to_user(row) for row in rows]

    async def link_user(
        self,
        *,
        telegram_id: int,
        api_token: str,
        ha_user_id: int,
        display_name: str,
        username: str,
        session_id: int,
    ) -> None:
        """Create or overwrite the link for ``telegram_id``.

        On re-link, the token, remote user id, profile fields and session are
        replaced, and both sequence counters are reset to 0. ``created_at`` is
        only set on first insert; it is not part of the conflict update.
        """
        now = datetime.now(timezone.utc).isoformat()
        await self.db.execute(
            """
            INSERT INTO users (
                telegram_id, api_token, ha_user_id, display_name, username,
                session_id, reminder_seq, pomodoro_seq, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, 0, 0, ?)
            ON CONFLICT(telegram_id) DO UPDATE SET
                api_token = excluded.api_token,
                ha_user_id = excluded.ha_user_id,
                display_name = excluded.display_name,
                username = excluded.username,
                session_id = excluded.session_id,
                reminder_seq = 0,
                pomodoro_seq = 0
            """,
            (telegram_id, api_token, ha_user_id, display_name, username, session_id, now),
        )
        await self.db.commit()

    async def unlink_user(self, telegram_id: int) -> bool:
        """Delete the user and all of their session cursors in one commit.

        Returns:
            True if a ``users`` row was removed, False if none existed.
        """
        db = self.db
        try:
            async with db.execute(
                "DELETE FROM users WHERE telegram_id = ?", (telegram_id,)
            ) as cursor:
                removed = cursor.rowcount
            await db.execute(
                "DELETE FROM session_cursors WHERE telegram_id = ?",
                (telegram_id,),
            )
        except BaseException:
            # Don't leave a half-applied unlink pending for some later commit().
            await db.rollback()
            raise
        await db.commit()
        return removed > 0

    async def set_session_id(self, telegram_id: int, session_id: int) -> None:
        """Point ``telegram_id`` at ``session_id``; a no-op if the user is not linked."""
        await self.db.execute(
            "UPDATE users SET session_id = ? WHERE telegram_id = ?",
            (session_id, telegram_id),
        )
        await self.db.commit()

    async def update_seq(
        self,
        telegram_id: int,
        *,
        reminder_seq: int | None = None,
        pomodoro_seq: int | None = None,
    ) -> None:
        """Store new sequence counters for ``telegram_id``.

        Only the counters passed as non-None are changed. Both are written in
        a single UPDATE, so they move together. Passing neither is a no-op.
        """
        if reminder_seq is None and pomodoro_seq is None:
            return
        # COALESCE(NULL, col) keeps the current value for an omitted counter.
        await self.db.execute(
            """
            UPDATE users
            SET reminder_seq = COALESCE(?, reminder_seq),
                pomodoro_seq = COALESCE(?, pomodoro_seq)
            WHERE telegram_id = ?
            """,
            (reminder_seq, pomodoro_seq, telegram_id),
        )
        await self.db.commit()

    async def get_last_message_id(self, telegram_id: int, session_id: int) -> int:
        """Return the stored cursor for (telegram_id, session_id), or 0 if unset."""
        async with self.db.execute(
            """
            SELECT last_message_id FROM session_cursors
            WHERE telegram_id = ? AND session_id = ?
            """,
            (telegram_id, session_id),
        ) as cursor:
            row = await cursor.fetchone()
        return int(row["last_message_id"]) if row is not None else 0

    async def set_last_message_id(self, telegram_id: int, session_id: int, message_id: int) -> None:
        """Insert or overwrite the cursor for (telegram_id, session_id)."""
        await self.db.execute(
            """
            INSERT INTO session_cursors (telegram_id, session_id, last_message_id)
            VALUES (?, ?, ?)
            ON CONFLICT(telegram_id, session_id) DO UPDATE SET
                last_message_id = excluded.last_message_id
            """,
            (telegram_id, session_id, message_id),
        )
        await self.db.commit()
|