Skip to content

Commit

Permalink
Merge pull request #342 from waketzheng/type-hint-ddl
Browse files Browse the repository at this point in the history
Improve type hints for ddl and inspectdb
  • Loading branch information
long2ice authored Jun 6, 2024
2 parents 79a77d3 + c7a3d16 commit e6302a9
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 86 deletions.
137 changes: 66 additions & 71 deletions aerich/ddl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Type
from typing import Any, List, Type, cast

from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
Expand Down Expand Up @@ -35,25 +35,26 @@ class BaseDDL:
)
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'

def __init__(self, client: "BaseDBAsyncClient"):
def __init__(self, client: "BaseDBAsyncClient") -> None:
self.client = client
self.schema_generator = self.schema_generator_cls(client)

def create_table(self, model: "Type[Model]"):
def create_table(self, model: "Type[Model]") -> str:
return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip(
";"
)

def drop_table(self, table_name: str):
def drop_table(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)

def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
):
through = field_describe.get("through")
) -> str:
through = cast(str, field_describe.get("through"))
description = field_describe.get("description")
reference_id = reference_table_describe.get("pk_field").get("db_column")
db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
pk_field = cast(dict, reference_table_describe.get("pk_field"))
reference_id = pk_field.get("db_column")
db_field_types = cast(dict, pk_field.get("db_field_types"))
return self._M2M_TABLE_TEMPLATE.format(
table_name=through,
backward_table=model._meta.db_table,
Expand All @@ -73,15 +74,15 @@ def create_m2m(
),
)

def drop_m2m(self, table_name: str):
def drop_m2m(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)

def _get_default(self, model: "Type[Model]", field_describe: dict):
def _get_default(self, model: "Type[Model]", field_describe: dict) -> Any:
db_table = model._meta.db_table
default = field_describe.get("default")
if isinstance(default, Enum):
default = default.value
db_column = field_describe.get("db_column")
db_column = cast(str, field_describe.get("db_column"))
auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add:
Expand All @@ -106,21 +107,30 @@ def _get_default(self, model: "Type[Model]", field_describe: dict):
default = None
return default

def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
return self._add_or_modify_column(model, field_describe, is_pk)

def _add_or_modify_column(self, model, field_describe: dict, is_pk: bool, modify=False) -> str:
db_table = model._meta.db_table
description = field_describe.get("description")
db_column = field_describe.get("db_column")
db_field_types = field_describe.get("db_field_types")
db_column = cast(str, field_describe.get("db_column"))
db_field_types = cast(dict, field_describe.get("db_field_types"))
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._ADD_COLUMN_TEMPLATE.format(
if modify:
unique = ""
template = self._MODIFY_COLUMN_TEMPLATE
else:
unique = "UNIQUE" if field_describe.get("unique") else ""
template = self._ADD_COLUMN_TEMPLATE
return template.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="UNIQUE" if field_describe.get("unique") else "",
unique=unique,
comment=(
self.schema_generator._column_comment_generator(
table=db_table,
Expand All @@ -135,39 +145,17 @@ def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = F
),
)

def drop_column(self, model: "Type[Model]", column_name: str):
def drop_column(self, model: "Type[Model]", column_name: str) -> str:
return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=column_name
)

def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_describe.get("db_column"),
field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="",
comment=(
self.schema_generator._column_comment_generator(
table=db_table,
column=field_describe.get("db_column"),
comment=field_describe.get("description"),
)
if field_describe.get("description")
else ""
),
is_primary_key=is_pk,
default=default,
),
)
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
return self._add_or_modify_column(model, field_describe, is_pk, modify=True)

def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str):
def rename_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str
) -> str:
return self._RENAME_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
old_column_name=old_column_name,
Expand All @@ -176,15 +164,15 @@ def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_n

def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
):
) -> str:
return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
old_column_name=old_column_name,
new_column_name=new_column_name,
new_column_type=new_column_type,
)

def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "",
index_name=self.schema_generator._generate_index_name(
Expand All @@ -194,53 +182,60 @@ def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
)

def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._DROP_INDEX_TEMPLATE.format(
index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names
),
table_name=model._meta.db_table,
)

def drop_index_by_name(self, model: "Type[Model]", index_name: str):
def drop_index_by_name(self, model: "Type[Model]", index_name: str) -> str:
return self._DROP_INDEX_TEMPLATE.format(
index_name=index_name,
table_name=model._meta.db_table,
)

def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
db_table = model._meta.db_table

db_column = field_describe.get("raw_field")
reference_id = reference_table_describe.get("pk_field").get("db_column")
fk_name = self.schema_generator._generate_fk_name(
def _generate_fk_name(
self, db_table, field_describe: dict, reference_table_describe: dict
) -> str:
"""Generate fk name"""
db_column = cast(str, field_describe.get("raw_field"))
pk_field = cast(dict, reference_table_describe.get("pk_field"))
to_field = cast(str, pk_field.get("db_column"))
to_table = cast(str, reference_table_describe.get("table"))
return self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
to_table=to_table,
to_field=to_field,
)

def add_fk(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table

db_column = field_describe.get("raw_field")
pk_field = cast(dict, reference_table_describe.get("pk_field"))
reference_id = pk_field.get("db_column")
return self._ADD_FK_TEMPLATE.format(
table_name=db_table,
fk_name=fk_name,
fk_name=self._generate_fk_name(db_table, field_describe, reference_table_describe),
db_column=db_column,
table=reference_table_describe.get("table"),
field=reference_id,
on_delete=field_describe.get("on_delete"),
)

def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
def drop_fk(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table
return self._DROP_FK_TEMPLATE.format(
table_name=db_table,
fk_name=self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=field_describe.get("raw_field"),
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
),
)
fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe)
return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name)

def alter_column_default(self, model: "Type[Model]", field_describe: dict):
def alter_column_default(self, model: "Type[Model]", field_describe: dict) -> str:
db_table = model._meta.db_table
default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format(
Expand All @@ -249,13 +244,13 @@ def alter_column_default(self, model: "Type[Model]", field_describe: dict):
default="SET" + default if default is not None else "DROP DEFAULT",
)

def alter_column_null(self, model: "Type[Model]", field_describe: dict):
def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
return self.modify_column(model, field_describe)

def set_comment(self, model: "Type[Model]", field_describe: dict):
def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
return self.modify_column(model, field_describe)

def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str):
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str) -> str:
db_table = model._meta.db_table
return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
Expand Down
10 changes: 5 additions & 5 deletions aerich/ddl/postgres/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Type
from typing import Type, cast

from tortoise import Model
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
Expand All @@ -18,17 +18,17 @@ class PostgresDDL(BaseDDL):
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'

def alter_column_null(self, model: "Type[Model]", field_describe: dict):
def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table,
column=field_describe.get("db_column"),
set_drop="DROP" if field_describe.get("nullable") else "SET",
)

def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
db_field_types = cast(dict, field_describe.get("db_field_types"))
db_column = field_describe.get("db_column")
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
return self._MODIFY_COLUMN_TEMPLATE.format(
Expand All @@ -38,7 +38,7 @@ def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool
using=f' USING "{db_column}"::{datatype}',
)

def set_comment(self, model: "Type[Model]", field_describe: dict):
def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table,
Expand Down
2 changes: 1 addition & 1 deletion aerich/inspectdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Inspect:
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
try:
self.database = conn.database
self.database = conn.database # type:ignore[attr-defined]
except AttributeError:
pass
self.tables = tables
Expand Down
3 changes: 2 additions & 1 deletion aerich/inspectdb/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ async def get_columns(self, table: str) -> List[Column]:
comment=row["COLUMN_COMMENT"],
unique=row["COLUMN_KEY"] == "UNI",
extra=row["EXTRA"],
unque=unique,
# TODO: why `unque`?
unque=unique, # type:ignore
index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"],
Expand Down
11 changes: 6 additions & 5 deletions aerich/inspectdb/postgres.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import List, Optional

from tortoise import BaseDBAsyncClient
from typing import TYPE_CHECKING, List, Optional

from aerich.inspectdb import Column, Inspect

if TYPE_CHECKING:
from tortoise.backends.base_postgres.client import BasePostgresClient


class InspectPostgres(Inspect):
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
def __init__(self, conn: "BasePostgresClient", tables: Optional[List[str]] = None) -> None:
super().__init__(conn, tables)
self.schema = self.conn.server_settings.get("schema") or "public"
self.schema = conn.server_settings.get("schema") or "public"

@property
def field_map(self) -> dict:
Expand Down
6 changes: 3 additions & 3 deletions aerich/inspectdb/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import List
from typing import Callable, Dict, List

from aerich.inspectdb import Column, Inspect


class InspectSQLite(Inspect):
@property
def field_map(self) -> dict:
def field_map(self) -> Dict[str, Callable[..., str]]:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
Expand Down Expand Up @@ -45,7 +45,7 @@ async def get_columns(self, table: str) -> List[Column]:
)
return columns

async def _get_columns_index(self, table: str):
async def _get_columns_index(self, table: str) -> Dict[str, str]:
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
Expand Down

0 comments on commit e6302a9

Please sign in to comment.