fix migration
This commit is contained in:
@@ -29,12 +29,19 @@ def _add_column_if_missing(table: str, column: str, ddl: str) -> None:
|
||||
|
||||
|
||||
def _ensure_schema_migrations_table() -> None:
|
||||
from app.db.dialect import is_postgresql
|
||||
|
||||
applied_type = (
|
||||
"TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
|
||||
if is_postgresql(engine)
|
||||
else "DATETIME DEFAULT CURRENT_TIMESTAMP"
|
||||
)
|
||||
with engine.begin() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"CREATE TABLE IF NOT EXISTS _schema_migrations ("
|
||||
"name TEXT PRIMARY KEY, "
|
||||
"applied_at DATETIME DEFAULT CURRENT_TIMESTAMP)"
|
||||
f"applied_at {applied_type})"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -74,6 +74,24 @@ def _reset_serial(dst_engine, table_name: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _ensure_schema_migrations_table(dst_engine) -> None:
|
||||
from app.db.dialect import is_postgresql
|
||||
|
||||
applied_type = (
|
||||
"TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
|
||||
if is_postgresql(dst_engine)
|
||||
else "DATETIME DEFAULT CURRENT_TIMESTAMP"
|
||||
)
|
||||
with dst_engine.begin() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"CREATE TABLE IF NOT EXISTS _schema_migrations ("
|
||||
"name TEXT PRIMARY KEY, "
|
||||
f"applied_at {applied_type})"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _truncate_all(dst_engine) -> None:
|
||||
table_names = [t.name for t in Base.metadata.sorted_tables]
|
||||
if not table_names:
|
||||
@@ -106,6 +124,11 @@ def main() -> int:
|
||||
)
|
||||
parser.add_argument("--dry-run", action="store_true", help="Print row counts only")
|
||||
parser.add_argument("--force", action="store_true", help="Truncate PostgreSQL tables before import")
|
||||
parser.add_argument(
|
||||
"--extras-only",
|
||||
action="store_true",
|
||||
help="Copy only auxiliary tables (_schema_migrations); skip ORM tables",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.database_url.startswith("postgresql"):
|
||||
@@ -140,13 +163,14 @@ def main() -> int:
|
||||
return 0
|
||||
|
||||
existing_rows = _postgres_has_data(dst_engine)
|
||||
if existing_rows > 0 and not args.force:
|
||||
if existing_rows > 0 and not args.force and not args.extras_only:
|
||||
print(
|
||||
f"ERROR: PostgreSQL already has {existing_rows} row(s). "
|
||||
"Use --force to truncate and re-import (e.g. after a failed partial migration)."
|
||||
"Use --force to truncate and re-import, or --extras-only for _schema_migrations only."
|
||||
)
|
||||
return 1
|
||||
|
||||
if not args.extras_only:
|
||||
Base.metadata.create_all(bind=dst_engine)
|
||||
|
||||
if args.force and existing_rows > 0:
|
||||
@@ -157,14 +181,20 @@ def main() -> int:
|
||||
for table in Base.metadata.sorted_tables:
|
||||
if table.name not in src_tables:
|
||||
continue
|
||||
count = _copy_table(src_engine, dst_engine, table.name, [col.name for col in table.columns], orm_table=table)
|
||||
count = _copy_table(
|
||||
src_engine, dst_engine, table.name, [col.name for col in table.columns], orm_table=table
|
||||
)
|
||||
if count:
|
||||
print(f" {table.name}: {count} rows")
|
||||
copied += count
|
||||
if "id" in [col.name for col in table.columns] and count > 0:
|
||||
_reset_serial(dst_engine, table.name)
|
||||
else:
|
||||
copied = 0
|
||||
print("Extras-only mode: skipping ORM tables")
|
||||
|
||||
for name in extra_tables:
|
||||
_ensure_schema_migrations_table(dst_engine)
|
||||
inspector = inspect(src_engine)
|
||||
cols = [col["name"] for col in inspector.get_columns(name)]
|
||||
count = _copy_table(src_engine, dst_engine, name, cols)
|
||||
|
||||
Reference in New Issue
Block a user