From 9c09152bbf90a0ce90eefaaf8f0d846c366282b4 Mon Sep 17 00:00:00 2001 From: grigo Date: Tue, 16 Jun 2026 09:40:14 +0300 Subject: [PATCH] fix migration --- backend/app/db/migrate_fitness.py | 9 ++- backend/scripts/migrate_sqlite_to_postgres.py | 62 ++++++++++++++----- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/backend/app/db/migrate_fitness.py b/backend/app/db/migrate_fitness.py index d711351..ea9abf3 100644 --- a/backend/app/db/migrate_fitness.py +++ b/backend/app/db/migrate_fitness.py @@ -29,12 +29,19 @@ def _add_column_if_missing(table: str, column: str, ddl: str) -> None: def _ensure_schema_migrations_table() -> None: + from app.db.dialect import is_postgresql + + applied_type = ( + "TIMESTAMP DEFAULT CURRENT_TIMESTAMP" + if is_postgresql(engine) + else "DATETIME DEFAULT CURRENT_TIMESTAMP" + ) with engine.begin() as conn: conn.execute( text( "CREATE TABLE IF NOT EXISTS _schema_migrations (" "name TEXT PRIMARY KEY, " - "applied_at DATETIME DEFAULT CURRENT_TIMESTAMP)" + f"applied_at {applied_type})" ) ) diff --git a/backend/scripts/migrate_sqlite_to_postgres.py b/backend/scripts/migrate_sqlite_to_postgres.py index f37cc51..e392e26 100644 --- a/backend/scripts/migrate_sqlite_to_postgres.py +++ b/backend/scripts/migrate_sqlite_to_postgres.py @@ -74,6 +74,24 @@ def _reset_serial(dst_engine, table_name: str) -> None: ) +def _ensure_schema_migrations_table(dst_engine) -> None: + from app.db.dialect import is_postgresql + + applied_type = ( + "TIMESTAMP DEFAULT CURRENT_TIMESTAMP" + if is_postgresql(dst_engine) + else "DATETIME DEFAULT CURRENT_TIMESTAMP" + ) + with dst_engine.begin() as conn: + conn.execute( + text( + "CREATE TABLE IF NOT EXISTS _schema_migrations (" + "name TEXT PRIMARY KEY, " + f"applied_at {applied_type})" + ) + ) + + def _truncate_all(dst_engine) -> None: table_names = [t.name for t in Base.metadata.sorted_tables] if not table_names: @@ -106,6 +124,11 @@ def main() -> int: ) parser.add_argument("--dry-run", action="store_true", help="Print row counts only") parser.add_argument("--force", action="store_true", help="Truncate PostgreSQL tables before import") + parser.add_argument( + "--extras-only", + action="store_true", + help="Copy only auxiliary tables (_schema_migrations); skip ORM tables", + ) args = parser.parse_args() if not args.database_url.startswith("postgresql"): @@ -140,31 +163,38 @@ def main() -> int: return 0 existing_rows = _postgres_has_data(dst_engine) - if existing_rows > 0 and not args.force: + if existing_rows > 0 and not args.force and not args.extras_only: print( f"ERROR: PostgreSQL already has {existing_rows} row(s). " - "Use --force to truncate and re-import (e.g. after a failed partial migration)." + "Use --force to truncate and re-import, or --extras-only for _schema_migrations only." ) return 1 - Base.metadata.create_all(bind=dst_engine) + if not args.extras_only: + Base.metadata.create_all(bind=dst_engine) - if args.force and existing_rows > 0: - print("Truncating PostgreSQL tables...") - _truncate_all(dst_engine) + if args.force and existing_rows > 0: + print("Truncating PostgreSQL tables...") + _truncate_all(dst_engine) - copied = 0 - for table in Base.metadata.sorted_tables: - if table.name not in src_tables: - continue - count = _copy_table(src_engine, dst_engine, table.name, [col.name for col in table.columns], orm_table=table) - if count: - print(f" {table.name}: {count} rows") - copied += count - if "id" in [col.name for col in table.columns] and count > 0: - _reset_serial(dst_engine, table.name) + copied = 0 + for table in Base.metadata.sorted_tables: + if table.name not in src_tables: + continue + count = _copy_table( + src_engine, dst_engine, table.name, [col.name for col in table.columns], orm_table=table + ) + if count: + print(f" {table.name}: {count} rows") + copied += count + if "id" in [col.name for col in table.columns] and count > 0: + _reset_serial(dst_engine, table.name) + else: + copied = 0 + print("Extras-only mode: skipping ORM tables") for name in extra_tables: + _ensure_schema_migrations_table(dst_engine) inspector = inspect(src_engine) cols = [col["name"] for col in inspector.get_columns(name)] count = _copy_table(src_engine, dst_engine, name, cols)