fix migration

This commit is contained in:
2026-06-16 09:40:14 +03:00
parent 81c8117520
commit 9c09152bbf
2 changed files with 54 additions and 17 deletions
+8 -1
View File
@@ -29,12 +29,19 @@ def _add_column_if_missing(table: str, column: str, ddl: str) -> None:
def _ensure_schema_migrations_table() -> None: def _ensure_schema_migrations_table() -> None:
from app.db.dialect import is_postgresql
applied_type = (
"TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
if is_postgresql(engine)
else "DATETIME DEFAULT CURRENT_TIMESTAMP"
)
with engine.begin() as conn: with engine.begin() as conn:
conn.execute( conn.execute(
text( text(
"CREATE TABLE IF NOT EXISTS _schema_migrations (" "CREATE TABLE IF NOT EXISTS _schema_migrations ("
"name TEXT PRIMARY KEY, " "name TEXT PRIMARY KEY, "
"applied_at DATETIME DEFAULT CURRENT_TIMESTAMP)" f"applied_at {applied_type})"
) )
) )
+46 -16
View File
@@ -74,6 +74,24 @@ def _reset_serial(dst_engine, table_name: str) -> None:
) )
def _ensure_schema_migrations_table(dst_engine) -> None:
from app.db.dialect import is_postgresql
applied_type = (
"TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
if is_postgresql(dst_engine)
else "DATETIME DEFAULT CURRENT_TIMESTAMP"
)
with dst_engine.begin() as conn:
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS _schema_migrations ("
"name TEXT PRIMARY KEY, "
f"applied_at {applied_type})"
)
)
def _truncate_all(dst_engine) -> None: def _truncate_all(dst_engine) -> None:
table_names = [t.name for t in Base.metadata.sorted_tables] table_names = [t.name for t in Base.metadata.sorted_tables]
if not table_names: if not table_names:
@@ -106,6 +124,11 @@ def main() -> int:
) )
parser.add_argument("--dry-run", action="store_true", help="Print row counts only") parser.add_argument("--dry-run", action="store_true", help="Print row counts only")
parser.add_argument("--force", action="store_true", help="Truncate PostgreSQL tables before import") parser.add_argument("--force", action="store_true", help="Truncate PostgreSQL tables before import")
parser.add_argument(
"--extras-only",
action="store_true",
help="Copy only auxiliary tables (_schema_migrations); skip ORM tables",
)
args = parser.parse_args() args = parser.parse_args()
if not args.database_url.startswith("postgresql"): if not args.database_url.startswith("postgresql"):
@@ -140,31 +163,38 @@ def main() -> int:
return 0 return 0
existing_rows = _postgres_has_data(dst_engine) existing_rows = _postgres_has_data(dst_engine)
if existing_rows > 0 and not args.force: if existing_rows > 0 and not args.force and not args.extras_only:
print( print(
f"ERROR: PostgreSQL already has {existing_rows} row(s). " f"ERROR: PostgreSQL already has {existing_rows} row(s). "
"Use --force to truncate and re-import (e.g. after a failed partial migration)." "Use --force to truncate and re-import, or --extras-only for _schema_migrations only."
) )
return 1 return 1
Base.metadata.create_all(bind=dst_engine) if not args.extras_only:
Base.metadata.create_all(bind=dst_engine)
if args.force and existing_rows > 0: if args.force and existing_rows > 0:
print("Truncating PostgreSQL tables...") print("Truncating PostgreSQL tables...")
_truncate_all(dst_engine) _truncate_all(dst_engine)
copied = 0 copied = 0
for table in Base.metadata.sorted_tables: for table in Base.metadata.sorted_tables:
if table.name not in src_tables: if table.name not in src_tables:
continue continue
count = _copy_table(src_engine, dst_engine, table.name, [col.name for col in table.columns], orm_table=table) count = _copy_table(
if count: src_engine, dst_engine, table.name, [col.name for col in table.columns], orm_table=table
print(f" {table.name}: {count} rows") )
copied += count if count:
if "id" in [col.name for col in table.columns] and count > 0: print(f" {table.name}: {count} rows")
_reset_serial(dst_engine, table.name) copied += count
if "id" in [col.name for col in table.columns] and count > 0:
_reset_serial(dst_engine, table.name)
else:
copied = 0
print("Extras-only mode: skipping ORM tables")
for name in extra_tables: for name in extra_tables:
_ensure_schema_migrations_table(dst_engine)
inspector = inspect(src_engine) inspector = inspect(src_engine)
cols = [col["name"] for col in inspector.get_columns(name)] cols = [col["name"] for col in inspector.get_columns(name)]
count = _copy_table(src_engine, dst_engine, name, cols) count = _copy_table(src_engine, dst_engine, name, cols)