fix migration
This commit is contained in:
@@ -8,7 +8,7 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import create_engine, inspect, text
|
||||
from sqlalchemy import Boolean, create_engine, inspect, text
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
@@ -26,7 +26,26 @@ def _row_count(engine, table: str) -> int:
|
||||
return int(conn.execute(text(f'SELECT COUNT(*) FROM "{table}"')).scalar() or 0)
|
||||
|
||||
|
||||
def _copy_table(src_engine, dst_engine, table_name: str, columns: list[str]) -> int:
|
||||
def _normalize_row(table, row: dict) -> dict:
|
||||
"""SQLite stores booleans as 0/1; PostgreSQL needs true/false."""
|
||||
normalized = dict(row)
|
||||
for column in table.columns:
|
||||
name = column.name
|
||||
if name not in normalized or normalized[name] is None:
|
||||
continue
|
||||
if isinstance(column.type, Boolean):
|
||||
normalized[name] = bool(normalized[name])
|
||||
return normalized
|
||||
|
||||
|
||||
def _copy_table(
|
||||
src_engine,
|
||||
dst_engine,
|
||||
table_name: str,
|
||||
columns: list[str],
|
||||
*,
|
||||
orm_table=None,
|
||||
) -> int:
|
||||
col_sql = ", ".join(f'"{c}"' for c in columns)
|
||||
select_sql = f'SELECT {col_sql} FROM "{table_name}"'
|
||||
insert_sql = f'INSERT INTO "{table_name}" ({col_sql}) VALUES ({", ".join(f":{c}" for c in columns)})'
|
||||
@@ -38,7 +57,10 @@ def _copy_table(src_engine, dst_engine, table_name: str, columns: list[str]) ->
|
||||
|
||||
with dst_engine.begin() as dst_conn:
|
||||
for row in rows:
|
||||
dst_conn.execute(text(insert_sql), dict(row))
|
||||
payload = dict(row)
|
||||
if orm_table is not None:
|
||||
payload = _normalize_row(orm_table, payload)
|
||||
dst_conn.execute(text(insert_sql), payload)
|
||||
return len(rows)
|
||||
|
||||
|
||||
@@ -61,6 +83,15 @@ def _truncate_all(dst_engine) -> None:
|
||||
conn.execute(text(f"TRUNCATE TABLE {joined} RESTART IDENTITY CASCADE"))
|
||||
|
||||
|
||||
def _postgres_has_data(dst_engine) -> int:
|
||||
total = 0
|
||||
inspector = inspect(dst_engine)
|
||||
for table in Base.metadata.sorted_tables:
|
||||
if table.name in inspector.get_table_names():
|
||||
total += _row_count(dst_engine, table.name)
|
||||
return total
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Migrate SQLite assistant.db to PostgreSQL")
|
||||
parser.add_argument(
|
||||
@@ -108,17 +139,17 @@ def main() -> int:
|
||||
print(f"Total rows: {total}")
|
||||
return 0
|
||||
|
||||
existing_users = _row_count(dst_engine, "users")
|
||||
if existing_users > 0 and not args.force:
|
||||
existing_rows = _postgres_has_data(dst_engine)
|
||||
if existing_rows > 0 and not args.force:
|
||||
print(
|
||||
f"ERROR: PostgreSQL already has {existing_users} user(s). "
|
||||
"Use --force to truncate and re-import, or migrate to an empty database."
|
||||
f"ERROR: PostgreSQL already has {existing_rows} row(s). "
|
||||
"Use --force to truncate and re-import (e.g. after a failed partial migration)."
|
||||
)
|
||||
return 1
|
||||
|
||||
Base.metadata.create_all(bind=dst_engine)
|
||||
|
||||
if args.force and existing_users > 0:
|
||||
if args.force and existing_rows > 0:
|
||||
print("Truncating PostgreSQL tables...")
|
||||
_truncate_all(dst_engine)
|
||||
|
||||
@@ -126,12 +157,11 @@ def main() -> int:
|
||||
for table in Base.metadata.sorted_tables:
|
||||
if table.name not in src_tables:
|
||||
continue
|
||||
columns = [col.name for col in table.columns]
|
||||
count = _copy_table(src_engine, dst_engine, table.name, columns)
|
||||
count = _copy_table(src_engine, dst_engine, table.name, [col.name for col in table.columns], orm_table=table)
|
||||
if count:
|
||||
print(f" {table.name}: {count} rows")
|
||||
copied += count
|
||||
if "id" in columns and count > 0:
|
||||
if "id" in [col.name for col in table.columns] and count > 0:
|
||||
_reset_serial(dst_engine, table.name)
|
||||
|
||||
for name in extra_tables:
|
||||
|
||||
Reference in New Issue
Block a user