#!/usr/bin/env python3 """Copy assistant data from SQLite to PostgreSQL.""" from __future__ import annotations import argparse import os import sys from pathlib import Path from sqlalchemy import create_engine, inspect, text BACKEND_ROOT = Path(__file__).resolve().parents[1] if str(BACKEND_ROOT) not in sys.path: sys.path.insert(0, str(BACKEND_ROOT)) from app.db import models # noqa: F401 from app.db.base import Base def _row_count(engine, table: str) -> int: inspector = inspect(engine) if not inspector.has_table(table): return 0 with engine.connect() as conn: return int(conn.execute(text(f'SELECT COUNT(*) FROM "{table}"')).scalar() or 0) def _copy_table(src_engine, dst_engine, table_name: str, columns: list[str]) -> int: col_sql = ", ".join(f'"{c}"' for c in columns) select_sql = f'SELECT {col_sql} FROM "{table_name}"' insert_sql = f'INSERT INTO "{table_name}" ({col_sql}) VALUES ({", ".join(f":{c}" for c in columns)})' with src_engine.connect() as src_conn: rows = src_conn.execute(text(select_sql)).mappings().all() if not rows: return 0 with dst_engine.begin() as dst_conn: for row in rows: dst_conn.execute(text(insert_sql), dict(row)) return len(rows) def _reset_serial(dst_engine, table_name: str) -> None: with dst_engine.begin() as conn: conn.execute( text( f"SELECT setval(pg_get_serial_sequence('{table_name}', 'id'), " f"COALESCE((SELECT MAX(id) FROM \"{table_name}\"), 1), true)" ) ) def _truncate_all(dst_engine) -> None: table_names = [t.name for t in Base.metadata.sorted_tables] if not table_names: return joined = ", ".join(f'"{name}"' for name in table_names) with dst_engine.begin() as conn: conn.execute(text(f"TRUNCATE TABLE {joined} RESTART IDENTITY CASCADE")) def main() -> int: parser = argparse.ArgumentParser(description="Migrate SQLite assistant.db to PostgreSQL") parser.add_argument( "--sqlite-path", default=os.environ.get("SQLITE_PATH", "./data/assistant.db"), help="Path to SQLite database file", ) parser.add_argument( "--database-url", default=os.environ.get("DATABASE_URL", ""), help="PostgreSQL DATABASE_URL (postgresql+psycopg2://...)", ) parser.add_argument("--dry-run", action="store_true", help="Print row counts only") parser.add_argument("--force", action="store_true", help="Truncate PostgreSQL tables before import") args = parser.parse_args() if not args.database_url.startswith("postgresql"): print("ERROR: DATABASE_URL must be a PostgreSQL URL (postgresql+psycopg2://...)") return 1 sqlite_path = Path(args.sqlite_path) if not sqlite_path.is_file(): print(f"ERROR: SQLite file not found: {sqlite_path}") return 1 src_engine = create_engine(f"sqlite:///{sqlite_path.as_posix()}") dst_engine = create_engine(args.database_url) src_tables = set(inspect(src_engine).get_table_names()) extra_tables = [t for t in ("_schema_migrations",) if t in src_tables] if args.dry_run: print(f"Dry run: {sqlite_path} -> PostgreSQL") total = 0 for table in Base.metadata.sorted_tables: count = _row_count(src_engine, table.name) if table.name in src_tables else 0 if count: print(f" {table.name}: {count}") total += count for name in extra_tables: count = _row_count(src_engine, name) if count: print(f" {name}: {count}") total += count print(f"Total rows: {total}") return 0 existing_users = _row_count(dst_engine, "users") if existing_users > 0 and not args.force: print( f"ERROR: PostgreSQL already has {existing_users} user(s). " "Use --force to truncate and re-import, or migrate to an empty database." ) return 1 Base.metadata.create_all(bind=dst_engine) if args.force and existing_users > 0: print("Truncating PostgreSQL tables...") _truncate_all(dst_engine) copied = 0 for table in Base.metadata.sorted_tables: if table.name not in src_tables: continue columns = [col.name for col in table.columns] count = _copy_table(src_engine, dst_engine, table.name, columns) if count: print(f" {table.name}: {count} rows") copied += count if "id" in columns and count > 0: _reset_serial(dst_engine, table.name) for name in extra_tables: inspector = inspect(src_engine) cols = [col["name"] for col in inspector.get_columns(name)] count = _copy_table(src_engine, dst_engine, name, cols) if count: print(f" {name}: {count} rows") copied += count print(f"Migration complete: {copied} rows copied from {sqlite_path}") print("SQLite file kept as backup. Update .env DATABASE_URL if not already pointing to PostgreSQL.") return 0 if __name__ == "__main__": raise SystemExit(main())