From ccefad836761bbee93930d86becd3b80488a047a Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 28 Aug 2023 17:23:14 -0500 Subject: [PATCH] fix: sync env.py and add `sa_orm_sentinel` to the sorted columns --- .../alembic/templates/asyncio/env.py | 2 +- .../sqlalchemy/alembic/templates/sync/env.py | 31 ++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/litestar/contrib/sqlalchemy/alembic/templates/asyncio/env.py b/litestar/contrib/sqlalchemy/alembic/templates/asyncio/env.py index f7bc76d4c0..0ec0055b41 100644 --- a/litestar/contrib/sqlalchemy/alembic/templates/asyncio/env.py +++ b/litestar/contrib/sqlalchemy/alembic/templates/asyncio/env.py @@ -39,7 +39,7 @@ @writer.rewrites(ops.CreateTableOp) def order_columns(context: EnvironmentContext, revision: tuple[str, ...], op: ops.CreateTableOp) -> ops.CreateTableOp: """Orders ID first and the audit columns at the end.""" - special_names = {"id": -100, "created_at": 1001, "updated_at": 1002} + special_names = {"id": -100, "sa_orm_sentinel": 1001, "created_at": 1002, "updated_at": 1002} cols_by_key = [ ( special_names.get(col.key, index) if isinstance(col, Column) else 2000, diff --git a/litestar/contrib/sqlalchemy/alembic/templates/sync/env.py b/litestar/contrib/sqlalchemy/alembic/templates/sync/env.py index 8d27b2f176..6b505c4492 100644 --- a/litestar/contrib/sqlalchemy/alembic/templates/sync/env.py +++ b/litestar/contrib/sqlalchemy/alembic/templates/sync/env.py @@ -3,11 +3,14 @@ from typing import TYPE_CHECKING from alembic import context -from sqlalchemy import engine_from_config, pool +from alembic.autogenerate import rewriter +from alembic.operations import ops +from sqlalchemy import Column, engine_from_config, pool from litestar.contrib.sqlalchemy.base import orm_registry if TYPE_CHECKING: + from alembic.runtime.environment import EnvironmentContext from sqlalchemy.engine import Connection from litestar.contrib.sqlalchemy.alembic.commands import AlembicCommandConfig @@ -28,6 +31,30 @@ # can be acquired: # ... etc. +writer = rewriter.Rewriter() + + +@writer.rewrites(ops.CreateTableOp) +def order_columns(context: EnvironmentContext, revision: tuple[str, ...], op: ops.CreateTableOp) -> ops.CreateTableOp: + """Orders ID first and the audit columns at the end.""" + special_names = {"id": -100, "sa_orm_sentinel": 1001, "created_at": 1002, "updated_at": 1002} + cols_by_key = [ + ( + special_names.get(col.key, index) if isinstance(col, Column) else 2000, + col.copy(), # type: ignore[attr-defined] + ) + for index, col in enumerate(op.columns) + ] + columns = [col for _, col in sorted(cols_by_key, key=lambda entry: entry[0])] + return ops.CreateTableOp( + op.table_name, + columns, + schema=op.schema, + # TODO: Remove when https://github.com/sqlalchemy/alembic/issues/1193 is fixed + _namespace_metadata=op._namespace_metadata, + **op.kw, + ) + def run_migrations_offline() -> None: """Run migrations in 'offline' mode. @@ -50,6 +77,7 @@ def run_migrations_offline() -> None: version_table_pk=config.version_table_pk, user_module_prefix=config.user_module_prefix, render_as_batch=config.render_as_batch, + process_revision_directives=writer, # type: ignore[arg-type] ) with context.begin_transaction(): @@ -66,6 +94,7 @@ def do_run_migrations(connection: Connection) -> None: version_table_pk=config.version_table_pk, user_module_prefix=config.user_module_prefix, render_as_batch=config.render_as_batch, + process_revision_directives=writer, # type: ignore[arg-type] ) with context.begin_transaction():