Skip to content

Commit

Permalink
[auto_schema] cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lolopinto authored Jun 7, 2024
1 parent 4adcfb9 commit 6cb99f4
Show file tree
Hide file tree
Showing 13 changed files with 175 additions and 221 deletions.
12 changes: 7 additions & 5 deletions python/auto_schema/auto_schema/clause_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ def get_clause_text(server_default, col_type):
if server_default is None:
return server_default

if isinstance(server_default, TextClause):
return normalize_clause_text(server_default.text, col_type)
match server_default:
case TextClause():
return normalize_clause_text(server_default.text, col_type)

if isinstance(server_default, DefaultClause):
return normalize_clause_text(server_default.arg, col_type)
case DefaultClause():
return normalize_clause_text(server_default.arg, col_type)

return normalize_clause_text(server_default, col_type)
case _:
return normalize_clause_text(server_default, col_type)
8 changes: 3 additions & 5 deletions python/auto_schema/auto_schema/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def upgrade(self, revision='head', sql=False):
# Simulates running the `alembic downgrade` command

def downgrade(self, revision='', delete_files=True):
paths = []
if delete_files:
paths = self._get_paths_to_delete(revision)
paths = self._get_paths_to_delete(revision) if delete_files else []
command.downgrade(self.alembic_cfg, revision)

# if downgrade worked, delete files
Expand Down Expand Up @@ -125,9 +123,9 @@ def history(self, verbose=False, last=None, rev_range=None):
"cannot pass both last and rev_range. please pick one")
if last is not None:
revs = list(self.get_script_directory().revision_map.iterate_revisions(
self.get_heads(), '-%d' % int(last), select_for_downgrade=True
self.get_heads(), f"-{int(last)}", select_for_downgrade=True
))
rev_range = '%s:current' % revs[-1].revision
rev_range = f"{revs[-1].revision}:current"

command.history(self.alembic_cfg,
indicate_current=True, verbose=verbose, rev_range=rev_range)
Expand Down
37 changes: 18 additions & 19 deletions python/auto_schema/auto_schema/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import re
from sqlalchemy.dialects import postgresql
import alembic.operations.ops as alembicops
from typing import Optional, Union, Any
from typing import Any
import functools


Expand Down Expand Up @@ -67,9 +67,8 @@ def _edges_equal(edge1, edge2):
'edge_table',
'inverse_edge_type'
]
for f in fields:
if str(edge1.get(f, None)) != str(edge2.get(f, None)):
return False
if not all(str(edge1.get(f, None)) == str(edge2.get(f, None)) for f in fields):
return False

# sqlite stores 1 as bool. comparing as strings no bueno
return bool(edge1.get('symmetric_edge', None)) == bool(edge2.get('symmetric_edge', None))
Expand Down Expand Up @@ -131,17 +130,17 @@ def _table_exists(autogen_context: AutogenContext):

def _execute_postgres_dialect(connection: sa.engine.Connection):
row = connection.execute(sa.text(
"SELECT to_regclass('%s') IS NOT NULL as exists" % (
"assoc_edge_config"))
"SELECT to_regclass('assoc_edge_config') IS NOT NULL as exists"
)
)
res = row.first()._asdict()
return res["exists"]


def _execute_sqlite_dialect(connection: sa.engine.Connection):
row = connection.execute(sa.text(
"SELECT name FROM sqlite_master WHERE type='table' AND name='%s'" % (
"assoc_edge_config"))
"SELECT name FROM sqlite_master WHERE type='table' AND name='assoc_edge_config'",
)
)
res = row.first()
return res is not None
Expand All @@ -166,7 +165,7 @@ def _create_tuple_key(row, pkeys):
l = []
for key in pkeys:
if not key in row:
raise ValueError("pkey %s was not found in row" % key)
raise ValueError(f"pkey {key} was not found in row")
l.append(row[key])
return tuple(l)

Expand Down Expand Up @@ -215,7 +214,7 @@ def compare_data(autogen_context, upgrade_ops, schemas):

def _compare_db_values(autogen_context, upgrade_ops, table_name, pkeys, data_rows):
connection = autogen_context.connection
query = 'SELECT * FROM %s' % table_name
query = f"SELECT * FROM {table_name}"

db_rows = {}
for row in connection.execute(sa.text(query)):
Expand Down Expand Up @@ -443,7 +442,7 @@ def _compare_indexes(autogen_context: AutogenContext,
modify_table_ops: alembicops.ModifyTableOps,
schema,
tname: str,
conn_table: Optional[sa.Table],
conn_table: sa.Table | None,
metadata_table: sa.Table,
):

Expand Down Expand Up @@ -556,8 +555,8 @@ def _compare_generated_column(autogen_context: AutogenContext,
modify_table_ops: alembicops.ModifyTableOps,
schema,
tname: str,
conn_table: Optional[sa.Table],
metadata_table: Optional[sa.Table],
conn_table: sa.Table | None,
metadata_table: sa.Table | None,
) -> None:

if conn_table is None or metadata_table is None:
Expand Down Expand Up @@ -626,12 +625,12 @@ def _compare_generated_column(autogen_context: AutogenContext,
# sqlalchemy doesn't reflect postgres indexes that have expressions in them so have to manually
# fetch these indices from pg_indices to find them
# warning: "Skipped unsupported reflection of expression-based index accounts_full_text_idx"
def _get_raw_db_indexes(autogen_context: AutogenContext, conn_table: Optional[sa.Table]):
def _get_raw_db_indexes(autogen_context: AutogenContext, conn_table: sa.Table | None):
if conn_table is None or _dialect_name(autogen_context) != 'postgresql':
return {'missing': {}, 'all': {}}

missing = {}
all = {}
all_indices = {}
# we cache the db hit but the table seems to change across the same call and so we're
# just paying the CPU price. can probably be fixed in some way...
names = set([index.name for index in conn_table.indexes] +
Expand All @@ -648,7 +647,7 @@ def _get_raw_db_indexes(autogen_context: AutogenContext, conn_table: Optional[sa
continue
r = m.groups()

all[name] = {
all_indices[name] = {
'postgresql_using': r[1],
'postgresql_using_internals': r[2],
# TODO don't have columns|column to pass to FullTextIndex
Expand All @@ -662,14 +661,14 @@ def _get_raw_db_indexes(autogen_context: AutogenContext, conn_table: Optional[sa
# TODO don't have columns|column to pass to FullTextIndex
}

return {'missing': missing, 'all': all}
return {'missing': missing, 'all': all_indices}


# use a cache so we only hit the db once for each table
# @functools.lru_cache()
def get_db_indexes_for_table(connection: sa.engine.Connection, tname: str):
res = connection.execute(sa.text(
"SELECT indexname, indexdef from pg_indexes where tablename = '%s'" % tname))
f"SELECT indexname, indexdef from pg_indexes where tablename = '{tname}'"))
return res


Expand Down Expand Up @@ -736,7 +735,7 @@ def _compare_server_default_nullable(
modify_table_ops: alembicops.ModifyTableOps,
schema,
tname: str,
conn_table: Optional[sa.Table],
conn_table: sa.Table | None,
metadata_table: sa.Table,
):
if conn_table is None or metadata_table is None:
Expand Down
6 changes: 3 additions & 3 deletions python/auto_schema/auto_schema/csv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
def render_list_csv(l):
str_l = ["'%s'" % (v) for v in l]
str_l = [f"'{v}'" for v in l]
return ", ".join(str_l)


def render_list_csv_as_list(l):
str_l = ["'%s'" % (v) for v in l]
return "[%s]" % ", ".join(str_l)
str_l = [f"'{v}'" for v in l]
return f"[{', '.join(str_l)}]"
63 changes: 30 additions & 33 deletions python/auto_schema/auto_schema/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@

import alembic.operations.ops as alembicops

from typing import List, Sequence
from typing import Sequence, defaultdict


class Diff(object):

def __init__(self: Diff, diff: Sequence[alembicops.MigrateOperation], group_by_table=True):
self._diff = diff
self._changes = {}
self._changes = defaultdict(list)
self._changes_list = []
self._group_by_table = group_by_table

self._exec_ops(self._diff)

def list_changes(self: Diff) -> List[Change]:
def list_changes(self: Diff) -> list[Change]:
return self._changes_list

def changes(self: Diff):
Expand Down Expand Up @@ -57,13 +57,13 @@ def _custom_migrate_op(self: Diff, op: MigrateOpInterface):
def _create_table(self: Diff, op: alembicops.CreateTableOp):
self._append_change(op.table_name, {
"change": ChangeType.ADD_TABLE,
"desc": 'add %s table' % op.table_name,
"desc": f"add {op.table_name} table",
})

def _drop_table(self: Diff, op: alembicops.DropTableOp):
self._append_change(op.table_name, {
"change": ChangeType.DROP_TABLE,
"desc": 'drop %s table' % op.table_name,
"desc": f"drop {op.table_name} table",
})

def _modify_table(self: Diff, op: alembicops.ModifyTableOps):
Expand All @@ -72,90 +72,91 @@ def _modify_table(self: Diff, op: alembicops.ModifyTableOps):
def _add_column(self: Diff, op: alembicops.AddColumnOp):
self._append_change(op.table_name, {
"change": ChangeType.ADD_COLUMN,
"desc": "add column %s to table %s" % (op.column.name, op.table_name),
"desc": f"add column {op.column.name} to table {op.table_name}",
"col": op.column.name,
})

def _drop_column(self: Diff, op: alembicops.DropColumnOp):
self._append_change(op.table_name, {
"change": ChangeType.DROP_COLUMN,
"desc": "drop column %s from table %s" % (op.column_name, op.table_name),
"desc": f"drop column {op.column_name} from table {op.table_name}",
"col": op.column_name,
})

def _create_index(self: Diff, op: alembicops.CreateIndexOp):
self._append_change(op.table_name, {
"change": ChangeType.CREATE_INDEX,
"index": op.index_name,
"desc": 'add index %s to %s' % (op.index_name, op.table_name),
"desc": f"add index {op.index_name} to {op.table_name}",
})

def _drop_index(self: Diff, op: alembicops.DropIndexOp):
self._append_change(op.table_name, {
"change": ChangeType.DROP_INDEX,
"index": op.index_name,
"desc": 'drop index %s from %s' % (op.index_name, op.table_name),
"desc": f"drop index {op.index_name} from {op.table_name}",
})

def _create_full_text_index(self: Diff, op: CreateFullTextIndexOp):
self._append_change(op.table_name, {
"change": ChangeType.CREATE_FULL_TEXT_INDEX,
"desc": 'add full text index %s to %s' % (op.index_name, op.table_name),
"desc": f"add full text index {op.index_name} to {op.table_name}",
})

def _drop_full_text_index(self: Diff, op: DropFullTextIndexOp):
self._append_change(op.table_name, {
"change": ChangeType.DROP_FULL_TEXT_INDEX,
"desc": 'drop full text index %s from %s' % (op.index_name, op.table_name),
"desc": f"drop full text index {op.index_name} from {op.table_name}",
})

def _create_foreign_key(self: Diff, op: alembicops.CreateForeignKeyOp):
self._append_change(op.source_table, {
"change": ChangeType.CREATE_FOREIGN_KEY,
"desc": 'create fk constraint %s on %s' % (op.constraint_name, op.source_table),
"desc": f"create fk constraint {op.constraint_name} on {op.source_table}",
})

def _create_unique_constraint(self: Diff, op: alembicops.CreateUniqueConstraintOp):
self._append_change(op.table_name, {
"change": ChangeType.CREATE_UNIQUE_CONSTRAINT,
"desc": 'add unique constraint %s' % op.constraint_name,
"desc": f"add unique constraint {op.constraint_name}",
})

def _drop_constraint(self: Diff, op: alembicops.DropConstraintOp):
self._append_change(op.table_name, {
"change": ChangeType.DROP_CHECK_CONSTRAINT,
"desc": 'drop constraint %s from %s' % (op.constraint_name, op.table_name)
"desc": f"drop constraint {op.constraint_name} from {op.table_name}",
})

def _create_check_constraint(self: Diff, op: alembicops.CreateCheckConstraintOp):
self._append_change(op.table_name, {
"change": ChangeType.CREATE_CHECK_CONSTRAINT,
"desc": 'add constraint %s to %s' % (op.constraint_name, op.table_name)
"desc": f"add constraint {op.constraint_name} to {op.table_name}",
})

def _alter_column(self: Diff, op: alembicops.AlterColumnOp):
def get_desc(op: alembicops.AlterColumnOp):
if op.modify_type is not None:
return 'modify column %s type from %s to %s' % (op.column_name, op.existing_type, op.modify_type)
return f"modify column {op.column_name} type from {op.existing_type} to {op.modify_type}"
elif op.modify_nullable is not None:
return 'modify nullable value of column %s from %s to %s' % (op.column_name, op.existing_nullable, op.modify_nullable)
return f"modify nullable value of column {op.column_name} from {op.existing_nullable} to {op.modify_nullable}"
elif op.modify_server_default is not None:
# these 3 here could flag it to be rendered differently
return 'modify server_default value of column %s from %s to %s' % (
op.column_name,
get_clause_text(op.existing_server_default,
op.existing_type),
get_clause_text(op.modify_server_default, op.modify_type))
existing_clause_text = get_clause_text(
op.existing_server_default,
op.existing_type,
)
modified_clause_text = get_clause_text(op.modify_server_default, op.modify_type)
return f"modify server_default value of column {op.column_name} from {existing_clause_text} to {modified_clause_text}"
elif op.modify_comment:
return "modify comment of column %s"
return f"modify comment of column {op.modify_comment}"
elif op.modify_name:
return "modify name of column %s"
return f"modify name of column {op.modify_name}"
elif op.modify_server_default is None and op.existing_server_default is not None:
return 'modify server_default value of column %s from %s to None' % (
op.column_name,
get_clause_text(op.existing_server_default,
op.existing_type)
existing_clause_text = get_clause_text(
op.existing_server_default,
op.existing_type,
)
return f"modify server_default value of column {op.column_name} from {existing_clause_text} to None"
else:
raise ValueError("unsupported alter_column op")

Expand All @@ -169,8 +170,4 @@ def _append_change(self: Diff, table_name: String, change: Change):
self._changes_list.append(change)
return

changes = []
if table_name in self._changes:
changes = self._changes[table_name]
changes.append(change)
self._changes[table_name] = changes
self._changes[table_name].append(change)
Loading

0 comments on commit 6cb99f4

Please sign in to comment.