From 81c8117520f8bb667518597e879a5cbf853d1507 Mon Sep 17 00:00:00 2001 From: grigo Date: Tue, 16 Jun 2026 09:26:31 +0300 Subject: [PATCH] fix migration --- README.md | 3 +- backend/scripts/migrate_sqlite_to_postgres.py | 52 +++++++++++++++---- 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 4856bf1..7c9c240 100644 --- a/README.md +++ b/README.md @@ -190,8 +190,9 @@ deploy/ примеры nginx # 1. Бэкап cp -a data data.bak.$(date +%Y%m%d) -# 2. Поднять postgres +# 2. Поднять postgres и пересобрать backend (скрипт миграции в образе) docker compose up -d postgres +docker compose build backend # 3. Dry-run (подсчёт строк) docker compose run --rm backend python scripts/migrate_sqlite_to_postgres.py --dry-run diff --git a/backend/scripts/migrate_sqlite_to_postgres.py b/backend/scripts/migrate_sqlite_to_postgres.py index b1ac9f4..f37cc51 100644 --- a/backend/scripts/migrate_sqlite_to_postgres.py +++ b/backend/scripts/migrate_sqlite_to_postgres.py @@ -8,7 +8,7 @@ import os import sys from pathlib import Path -from sqlalchemy import create_engine, inspect, text +from sqlalchemy import Boolean, create_engine, inspect, text BACKEND_ROOT = Path(__file__).resolve().parents[1] if str(BACKEND_ROOT) not in sys.path: @@ -26,7 +26,26 @@ def _row_count(engine, table: str) -> int: return int(conn.execute(text(f'SELECT COUNT(*) FROM "{table}"')).scalar() or 0) -def _copy_table(src_engine, dst_engine, table_name: str, columns: list[str]) -> int: +def _normalize_row(table, row: dict) -> dict: + """SQLite stores booleans as 0/1; PostgreSQL needs true/false.""" + normalized = dict(row) + for column in table.columns: + name = column.name + if name not in normalized or normalized[name] is None: + continue + if isinstance(column.type, Boolean): + normalized[name] = bool(normalized[name]) + return normalized + + +def _copy_table( + src_engine, + dst_engine, + table_name: str, + columns: list[str], + *, + orm_table=None, +) -> int: col_sql = ", ".join(f'"{c}"' for c in columns) select_sql = f'SELECT {col_sql} FROM "{table_name}"' insert_sql = f'INSERT INTO "{table_name}" ({col_sql}) VALUES ({", ".join(f":{c}" for c in columns)})' @@ -38,7 +57,10 @@ def _copy_table(src_engine, dst_engine, table_name: str, columns: list[str]) -> with dst_engine.begin() as dst_conn: for row in rows: - dst_conn.execute(text(insert_sql), dict(row)) + payload = dict(row) + if orm_table is not None: + payload = _normalize_row(orm_table, payload) + dst_conn.execute(text(insert_sql), payload) return len(rows) @@ -61,6 +83,15 @@ def _truncate_all(dst_engine) -> None: conn.execute(text(f"TRUNCATE TABLE {joined} RESTART IDENTITY CASCADE")) +def _postgres_has_data(dst_engine) -> int: + total = 0 + inspector = inspect(dst_engine) + for table in Base.metadata.sorted_tables: + if table.name in inspector.get_table_names(): + total += _row_count(dst_engine, table.name) + return total + + def main() -> int: parser = argparse.ArgumentParser(description="Migrate SQLite assistant.db to PostgreSQL") parser.add_argument( @@ -108,17 +139,17 @@ def main() -> int: print(f"Total rows: {total}") return 0 - existing_users = _row_count(dst_engine, "users") - if existing_users > 0 and not args.force: + existing_rows = _postgres_has_data(dst_engine) + if existing_rows > 0 and not args.force: print( - f"ERROR: PostgreSQL already has {existing_users} user(s). " - "Use --force to truncate and re-import, or migrate to an empty database." + f"ERROR: PostgreSQL already has {existing_rows} row(s). " + "Use --force to truncate and re-import (e.g. after a failed partial migration)." ) return 1 Base.metadata.create_all(bind=dst_engine) - if args.force and existing_users > 0: + if args.force and existing_rows > 0: print("Truncating PostgreSQL tables...") _truncate_all(dst_engine) @@ -126,12 +157,11 @@ def main() -> int: for table in Base.metadata.sorted_tables: if table.name not in src_tables: continue - columns = [col.name for col in table.columns] - count = _copy_table(src_engine, dst_engine, table.name, columns) + count = _copy_table(src_engine, dst_engine, table.name, [col.name for col in table.columns], orm_table=table) if count: print(f" {table.name}: {count} rows") copied += count - if "id" in columns and count > 0: + if "id" in [col.name for col in table.columns] and count > 0: _reset_serial(dst_engine, table.name) for name in extra_tables: