Skip to content

Commit

Permalink
fix: sync env.py and add sa_orm_sentinel to the sorted columns
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Aug 28, 2023
1 parent cffc81b commit ccefad8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
@writer.rewrites(ops.CreateTableOp)
def order_columns(context: EnvironmentContext, revision: tuple[str, ...], op: ops.CreateTableOp) -> ops.CreateTableOp:
"""Orders ID first and the audit columns at the end."""
special_names = {"id": -100, "created_at": 1001, "updated_at": 1002}
special_names = {"id": -100, "sa_orm_sentinel": 1001, "created_at": 1002, "updated_at": 1002}
cols_by_key = [
(
special_names.get(col.key, index) if isinstance(col, Column) else 2000,
Expand Down
31 changes: 30 additions & 1 deletion litestar/contrib/sqlalchemy/alembic/templates/sync/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from typing import TYPE_CHECKING

from alembic import context
from sqlalchemy import engine_from_config, pool
from alembic.autogenerate import rewriter
from alembic.operations import ops
from sqlalchemy import Column, engine_from_config, pool

from litestar.contrib.sqlalchemy.base import orm_registry

if TYPE_CHECKING:
from alembic.runtime.environment import EnvironmentContext
from sqlalchemy.engine import Connection

from litestar.contrib.sqlalchemy.alembic.commands import AlembicCommandConfig
Expand All @@ -28,6 +31,30 @@
# can be acquired:
# ... etc.

writer = rewriter.Rewriter()


@writer.rewrites(ops.CreateTableOp)
def order_columns(context: EnvironmentContext, revision: tuple[str, ...], op: ops.CreateTableOp) -> ops.CreateTableOp:
"""Orders ID first and the audit columns at the end."""
special_names = {"id": -100, "sa_orm_sentinel": 1001, "created_at": 1002, "updated_at": 1002}
cols_by_key = [
(
special_names.get(col.key, index) if isinstance(col, Column) else 2000,
col.copy(), # type: ignore[attr-defined]
)
for index, col in enumerate(op.columns)
]
columns = [col for _, col in sorted(cols_by_key, key=lambda entry: entry[0])]
return ops.CreateTableOp(
op.table_name,
columns,
schema=op.schema,
# TODO: Remove when https://github.com/sqlalchemy/alembic/issues/1193 is fixed
_namespace_metadata=op._namespace_metadata,
**op.kw,
)


def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
Expand All @@ -50,6 +77,7 @@ def run_migrations_offline() -> None:
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer, # type: ignore[arg-type]
)

with context.begin_transaction():
Expand All @@ -66,6 +94,7 @@ def do_run_migrations(connection: Connection) -> None:
version_table_pk=config.version_table_pk,
user_module_prefix=config.user_module_prefix,
render_as_batch=config.render_as_batch,
process_revision_directives=writer, # type: ignore[arg-type]
)

with context.begin_transaction():
Expand Down

0 comments on commit ccefad8

Please sign in to comment.