182 lines
5.9 KiB
Python
182 lines
5.9 KiB
Python
#!/usr/bin/env python3
|
|
"""Copy assistant data from SQLite to PostgreSQL."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
from sqlalchemy import Boolean, create_engine, inspect, text
|
|
|
|
# Directory two levels up from this file; it is put at the front of sys.path
# so the ``app`` package imported below resolves when the script is run directly.
BACKEND_ROOT = Path(__file__).resolve().parents[1]
_backend_root_entry = str(BACKEND_ROOT)
if _backend_root_entry not in sys.path:
    sys.path.insert(0, _backend_root_entry)
|
|
|
|
from app.db import models # noqa: F401
|
|
from app.db.base import Base
|
|
|
|
|
|
def _row_count(engine, table: str) -> int:
    """Return the number of rows in ``table``, or 0 when the table does not exist.

    The table name is interpolated into the statement as a double-quoted
    identifier, so callers must pass trusted names only.
    """
    if inspect(engine).has_table(table):
        count_stmt = text(f'SELECT COUNT(*) FROM "{table}"')
        with engine.connect() as connection:
            found = connection.execute(count_stmt).scalar()
        # scalar() may yield None; treat that as an empty table.
        return int(found or 0)
    return 0
|
|
|
|
|
|
def _normalize_row(table, row: dict) -> dict:
|
|
"""SQLite stores booleans as 0/1; PostgreSQL needs true/false."""
|
|
normalized = dict(row)
|
|
for column in table.columns:
|
|
name = column.name
|
|
if name not in normalized or normalized[name] is None:
|
|
continue
|
|
if isinstance(column.type, Boolean):
|
|
normalized[name] = bool(normalized[name])
|
|
return normalized
|
|
|
|
|
|
def _copy_table(
    src_engine,
    dst_engine,
    table_name: str,
    columns: list[str],
    *,
    orm_table=None,
) -> int:
    """Copy every row of ``table_name`` from the source to the destination.

    Args:
        src_engine: Engine the rows are read from.
        dst_engine: Engine the rows are inserted into.
        table_name: Table to copy; the same name is used on both sides.
        columns: Column names to select and insert.
        orm_table: Optional table object passed to ``_normalize_row`` for each
            row; when omitted, rows are inserted exactly as they were read.

    Returns:
        The number of rows read from the source (0 for an empty table, in
        which case nothing is written to the destination).
    """
    col_sql = ", ".join(f'"{c}"' for c in columns)
    placeholders = ", ".join(f":{c}" for c in columns)
    select_sql = f'SELECT {col_sql} FROM "{table_name}"'
    insert_sql = f'INSERT INTO "{table_name}" ({col_sql}) VALUES ({placeholders})'

    with src_engine.connect() as src_conn:
        rows = src_conn.execute(text(select_sql)).mappings().all()
    if not rows:
        return 0

    payloads = [dict(row) for row in rows]
    if orm_table is not None:
        payloads = [_normalize_row(orm_table, payload) for payload in payloads]

    # Passing the full list of parameter dicts lets SQLAlchemy run the insert
    # as a single executemany call instead of one execute() call per row.
    # All rows of the table are still written inside one transaction.
    with dst_engine.begin() as dst_conn:
        dst_conn.execute(text(insert_sql), payloads)
    return len(rows)
|
|
|
|
|
|
def _reset_serial(dst_engine, table_name: str) -> None:
    """Set the sequence behind ``table_name``.``id`` to the table's MAX(id).

    Rows are copied with their explicit ids, which does not advance the
    sequence; this moves it forward so later inserts do not reuse an id.
    An empty table falls back to 1.

    Args:
        dst_engine: PostgreSQL engine to run the statement on.
        table_name: Table whose ``id`` sequence is reset.
    """
    # pg_get_serial_sequence() parses its first argument like an SQL
    # identifier, so an unquoted name is folded to lower case. Pass the name
    # double-quoted so the lookup targets the same table as the quoted
    # MAX(id) subquery, including mixed-case or reserved-word table names.
    quoted_table = f'"{table_name}"'
    statement = text(
        "SELECT setval(pg_get_serial_sequence(:quoted_table, 'id'), "
        f"COALESCE((SELECT MAX(id) FROM {quoted_table}), 1), true)"
    )
    with dst_engine.begin() as conn:
        conn.execute(statement, {"quoted_table": quoted_table})
|
|
|
|
|
|
def _truncate_all(dst_engine) -> None:
    """Empty every table registered on ``Base.metadata`` with one TRUNCATE.

    The statement uses RESTART IDENTITY CASCADE. Nothing is executed when the
    metadata holds no tables.
    """
    quoted_names = [f'"{tbl.name}"' for tbl in Base.metadata.sorted_tables]
    if not quoted_names:
        return
    target_list = ", ".join(quoted_names)
    statement = text(f"TRUNCATE TABLE {target_list} RESTART IDENTITY CASCADE")
    with dst_engine.begin() as conn:
        conn.execute(statement)
|
|
|
|
|
|
def _postgres_has_data(dst_engine) -> int:
    """Return the total row count over the model tables present in the destination.

    Despite the boolean-sounding name, the result is a count: 0 means none of
    the ``Base.metadata`` tables that exist in the destination hold any rows.
    Tables that are missing from the destination are skipped.
    """
    # Fetch the destination's table list once instead of once per model table,
    # and use a set so each membership test is a constant-time lookup.
    existing = set(inspect(dst_engine).get_table_names())
    return sum(
        _row_count(dst_engine, table.name)
        for table in Base.metadata.sorted_tables
        if table.name in existing
    )
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options; path and URL defaults come from the environment."""
    parser = argparse.ArgumentParser(description="Migrate SQLite assistant.db to PostgreSQL")
    parser.add_argument(
        "--sqlite-path",
        default=os.environ.get("SQLITE_PATH", "./data/assistant.db"),
        help="Path to SQLite database file",
    )
    parser.add_argument(
        "--database-url",
        default=os.environ.get("DATABASE_URL", ""),
        help="PostgreSQL DATABASE_URL (postgresql+psycopg2://...)",
    )
    parser.add_argument("--dry-run", action="store_true", help="Print row counts only")
    parser.add_argument("--force", action="store_true", help="Truncate PostgreSQL tables before import")
    return parser.parse_args()


def _print_dry_run(src_engine, sqlite_path: Path, src_tables: set[str], extra_tables: list[str]) -> None:
    """Print the per-table and total source row counts without writing anything."""
    print(f"Dry run: {sqlite_path} -> PostgreSQL")
    # Model tables first (in metadata order), then the extra tables.
    names = [t.name for t in Base.metadata.sorted_tables if t.name in src_tables]
    names.extend(extra_tables)
    total = 0
    for name in names:
        count = _row_count(src_engine, name)
        if count:
            print(f" {name}: {count}")
        total += count
    print(f"Total rows: {total}")


def _copy_all(src_engine, dst_engine, src_inspector, src_tables: set[str], extra_tables: list[str]) -> int:
    """Copy the model tables, then the extra tables; return the total rows copied."""
    # Local import: only this helper needs it, and sqlalchemy is already a
    # dependency of the file.
    from sqlalchemy import Integer

    copied = 0
    for table in Base.metadata.sorted_tables:
        if table.name not in src_tables:
            continue
        column_names = [col.name for col in table.columns]
        count = _copy_table(src_engine, dst_engine, table.name, column_names, orm_table=table)
        if count:
            print(f" {table.name}: {count} rows")
        copied += count

        # Only integer ids can be backed by a sequence. For any other id type
        # the COALESCE(MAX(id), 1) inside _reset_serial would mix incompatible
        # types and abort the migration half-way.
        id_column = next((col for col in table.columns if col.name == "id"), None)
        if count > 0 and id_column is not None and isinstance(id_column.type, Integer):
            _reset_serial(dst_engine, table.name)

    # NOTE(review): extra tables are neither created by create_all() nor
    # emptied by --force in this script; the insert presumably relies on the
    # destination already having an empty table of that name. Confirm.
    for name in extra_tables:
        cols = [col["name"] for col in src_inspector.get_columns(name)]
        count = _copy_table(src_engine, dst_engine, name, cols)
        if count:
            print(f" {name}: {count} rows")
        copied += count
    return copied


def main() -> int:
    """Run the SQLite -> PostgreSQL copy described by the command line.

    Returns:
        0 after a successful migration or dry run; 1 when the database URL is
        not a PostgreSQL URL, the SQLite file does not exist, or PostgreSQL
        already holds rows and ``--force` was not given.
    """
    args = _parse_args()

    if not args.database_url.startswith("postgresql"):
        print("ERROR: DATABASE_URL must be a PostgreSQL URL (postgresql+psycopg2://...)")
        return 1

    sqlite_path = Path(args.sqlite_path)
    if not sqlite_path.is_file():
        print(f"ERROR: SQLite file not found: {sqlite_path}")
        return 1

    src_engine = create_engine(f"sqlite:///{sqlite_path.as_posix()}")
    dst_engine = create_engine(args.database_url)
    try:
        # One inspector serves both the table listing and the later column
        # lookups for the extra tables.
        src_inspector = inspect(src_engine)
        src_tables = set(src_inspector.get_table_names())
        extra_tables = [t for t in ("_schema_migrations",) if t in src_tables]

        if args.dry_run:
            _print_dry_run(src_engine, sqlite_path, src_tables, extra_tables)
            return 0

        existing_rows = _postgres_has_data(dst_engine)
        if existing_rows > 0 and not args.force:
            print(
                f"ERROR: PostgreSQL already has {existing_rows} row(s). "
                "Use --force to truncate and re-import (e.g. after a failed partial migration)."
            )
            return 1

        Base.metadata.create_all(bind=dst_engine)

        if args.force and existing_rows > 0:
            print("Truncating PostgreSQL tables...")
            _truncate_all(dst_engine)

        copied = _copy_all(src_engine, dst_engine, src_inspector, src_tables, extra_tables)

        print(f"Migration complete: {copied} rows copied from {sqlite_path}")
        print("SQLite file kept as backup. Update .env DATABASE_URL if not already pointing to PostgreSQL.")
        return 0
    finally:
        # Release pooled connections on every exit path, including errors.
        src_engine.dispose()
        dst_engine.dispose()
|
|
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|