fix migration
This commit is contained in:
@@ -190,8 +190,9 @@ deploy/ примеры nginx
|
|||||||
# 1. Бэкап
|
# 1. Бэкап
|
||||||
cp -a data data.bak.$(date +%Y%m%d)
|
cp -a data data.bak.$(date +%Y%m%d)
|
||||||
|
|
||||||
# 2. Поднять postgres
|
# 2. Поднять postgres и пересобрать backend (скрипт миграции в образе)
|
||||||
docker compose up -d postgres
|
docker compose up -d postgres
|
||||||
|
docker compose build backend
|
||||||
|
|
||||||
# 3. Dry-run (подсчёт строк)
|
# 3. Dry-run (подсчёт строк)
|
||||||
docker compose run --rm backend python scripts/migrate_sqlite_to_postgres.py --dry-run
|
docker compose run --rm backend python scripts/migrate_sqlite_to_postgres.py --dry-run
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy import create_engine, inspect, text
|
from sqlalchemy import Boolean, create_engine, inspect, text
|
||||||
|
|
||||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||||
if str(BACKEND_ROOT) not in sys.path:
|
if str(BACKEND_ROOT) not in sys.path:
|
||||||
@@ -26,7 +26,26 @@ def _row_count(engine, table: str) -> int:
|
|||||||
return int(conn.execute(text(f'SELECT COUNT(*) FROM "{table}"')).scalar() or 0)
|
return int(conn.execute(text(f'SELECT COUNT(*) FROM "{table}"')).scalar() or 0)
|
||||||
|
|
||||||
|
|
||||||
def _copy_table(src_engine, dst_engine, table_name: str, columns: list[str]) -> int:
|
def _normalize_row(table, row: dict) -> dict:
|
||||||
|
"""SQLite stores booleans as 0/1; PostgreSQL needs true/false."""
|
||||||
|
normalized = dict(row)
|
||||||
|
for column in table.columns:
|
||||||
|
name = column.name
|
||||||
|
if name not in normalized or normalized[name] is None:
|
||||||
|
continue
|
||||||
|
if isinstance(column.type, Boolean):
|
||||||
|
normalized[name] = bool(normalized[name])
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _copy_table(
|
||||||
|
src_engine,
|
||||||
|
dst_engine,
|
||||||
|
table_name: str,
|
||||||
|
columns: list[str],
|
||||||
|
*,
|
||||||
|
orm_table=None,
|
||||||
|
) -> int:
|
||||||
col_sql = ", ".join(f'"{c}"' for c in columns)
|
col_sql = ", ".join(f'"{c}"' for c in columns)
|
||||||
select_sql = f'SELECT {col_sql} FROM "{table_name}"'
|
select_sql = f'SELECT {col_sql} FROM "{table_name}"'
|
||||||
insert_sql = f'INSERT INTO "{table_name}" ({col_sql}) VALUES ({", ".join(f":{c}" for c in columns)})'
|
insert_sql = f'INSERT INTO "{table_name}" ({col_sql}) VALUES ({", ".join(f":{c}" for c in columns)})'
|
||||||
@@ -38,7 +57,10 @@ def _copy_table(src_engine, dst_engine, table_name: str, columns: list[str]) ->
|
|||||||
|
|
||||||
with dst_engine.begin() as dst_conn:
|
with dst_engine.begin() as dst_conn:
|
||||||
for row in rows:
|
for row in rows:
|
||||||
dst_conn.execute(text(insert_sql), dict(row))
|
payload = dict(row)
|
||||||
|
if orm_table is not None:
|
||||||
|
payload = _normalize_row(orm_table, payload)
|
||||||
|
dst_conn.execute(text(insert_sql), payload)
|
||||||
return len(rows)
|
return len(rows)
|
||||||
|
|
||||||
|
|
||||||
@@ -61,6 +83,15 @@ def _truncate_all(dst_engine) -> None:
|
|||||||
conn.execute(text(f"TRUNCATE TABLE {joined} RESTART IDENTITY CASCADE"))
|
conn.execute(text(f"TRUNCATE TABLE {joined} RESTART IDENTITY CASCADE"))
|
||||||
|
|
||||||
|
|
||||||
|
def _postgres_has_data(dst_engine) -> int:
|
||||||
|
total = 0
|
||||||
|
inspector = inspect(dst_engine)
|
||||||
|
for table in Base.metadata.sorted_tables:
|
||||||
|
if table.name in inspector.get_table_names():
|
||||||
|
total += _row_count(dst_engine, table.name)
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
parser = argparse.ArgumentParser(description="Migrate SQLite assistant.db to PostgreSQL")
|
parser = argparse.ArgumentParser(description="Migrate SQLite assistant.db to PostgreSQL")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -108,17 +139,17 @@ def main() -> int:
|
|||||||
print(f"Total rows: {total}")
|
print(f"Total rows: {total}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
existing_users = _row_count(dst_engine, "users")
|
existing_rows = _postgres_has_data(dst_engine)
|
||||||
if existing_users > 0 and not args.force:
|
if existing_rows > 0 and not args.force:
|
||||||
print(
|
print(
|
||||||
f"ERROR: PostgreSQL already has {existing_users} user(s). "
|
f"ERROR: PostgreSQL already has {existing_rows} row(s). "
|
||||||
"Use --force to truncate and re-import, or migrate to an empty database."
|
"Use --force to truncate and re-import (e.g. after a failed partial migration)."
|
||||||
)
|
)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
Base.metadata.create_all(bind=dst_engine)
|
Base.metadata.create_all(bind=dst_engine)
|
||||||
|
|
||||||
if args.force and existing_users > 0:
|
if args.force and existing_rows > 0:
|
||||||
print("Truncating PostgreSQL tables...")
|
print("Truncating PostgreSQL tables...")
|
||||||
_truncate_all(dst_engine)
|
_truncate_all(dst_engine)
|
||||||
|
|
||||||
@@ -126,12 +157,11 @@ def main() -> int:
|
|||||||
for table in Base.metadata.sorted_tables:
|
for table in Base.metadata.sorted_tables:
|
||||||
if table.name not in src_tables:
|
if table.name not in src_tables:
|
||||||
continue
|
continue
|
||||||
columns = [col.name for col in table.columns]
|
count = _copy_table(src_engine, dst_engine, table.name, [col.name for col in table.columns], orm_table=table)
|
||||||
count = _copy_table(src_engine, dst_engine, table.name, columns)
|
|
||||||
if count:
|
if count:
|
||||||
print(f" {table.name}: {count} rows")
|
print(f" {table.name}: {count} rows")
|
||||||
copied += count
|
copied += count
|
||||||
if "id" in columns and count > 0:
|
if "id" in [col.name for col in table.columns] and count > 0:
|
||||||
_reset_serial(dst_engine, table.name)
|
_reset_serial(dst_engine, table.name)
|
||||||
|
|
||||||
for name in extra_tables:
|
for name in extra_tables:
|
||||||
|
|||||||
Reference in New Issue
Block a user