Skip to content

Commit

Permalink
Improve type hints of inspectdb (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng authored Dec 3, 2024
1 parent 4e46d9d commit b2f4029
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 28 deletions.
38 changes: 27 additions & 11 deletions aerich/inspectdb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
from typing import Any, List, Optional
from __future__ import annotations

from typing import Any, Callable, Dict, Optional, TypedDict

from pydantic import BaseModel
from tortoise import BaseDBAsyncClient


class ColumnInfoDict(TypedDict):
name: str
pk: str
index: str
null: str
default: str
length: str
comment: str


FieldMapDict = Dict[str, Callable[..., str]]


class Column(BaseModel):
name: str
data_type: str
Expand All @@ -18,7 +33,7 @@ class Column(BaseModel):
decimal_places: Optional[int] = None
max_digits: Optional[int] = None

def translate(self) -> dict:
def translate(self) -> ColumnInfoDict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "pk=True, "
Expand All @@ -28,23 +43,24 @@ def translate(self) -> dict:
else:
if self.index:
index = "index=True, "
if self.data_type in ["varchar", "VARCHAR"]:
if self.data_type in ("varchar", "VARCHAR"):
length = f"max_length={self.length}, "
if self.data_type in ["decimal", "numeric"]:
elif self.data_type in ("decimal", "numeric"):
length_parts = []
if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}")
length = ", ".join(length_parts)+", "
if length_parts:
length = ", ".join(length_parts) + ", "
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type in ["tinyint", "INT"]:
if self.data_type in ("tinyint", "INT"):
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, "
Expand Down Expand Up @@ -76,7 +92,7 @@ def translate(self) -> dict:
class Inspect:
_table_template = "class {table}(Model):\n"

def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
self.conn = conn
try:
self.database = conn.database # type:ignore[attr-defined]
Expand All @@ -85,7 +101,7 @@ def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.tables = tables

@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
raise NotImplementedError

async def inspect(self) -> str:
Expand All @@ -103,10 +119,10 @@ async def inspect(self) -> str:
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)

async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
raise NotImplementedError

async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
raise NotImplementedError

@classmethod
Expand Down
10 changes: 5 additions & 5 deletions aerich/inspectdb/mysql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import List
from __future__ import annotations

from aerich.inspectdb import Column, Inspect
from aerich.inspectdb import Column, FieldMapDict, Inspect


class InspectMySQL(Inspect):
@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
return {
"int": self.int_field,
"smallint": self.smallint_field,
Expand All @@ -24,12 +24,12 @@ def field_map(self) -> dict:
"longblob": self.binary_field,
}

async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret))

async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c
Expand Down
14 changes: 8 additions & 6 deletions aerich/inspectdb/postgres.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from typing import TYPE_CHECKING, List, Optional
from __future__ import annotations

from aerich.inspectdb import Column, Inspect
from typing import TYPE_CHECKING

from aerich.inspectdb import Column, FieldMapDict, Inspect

if TYPE_CHECKING:
from tortoise.backends.base_postgres.client import BasePostgresClient


class InspectPostgres(Inspect):
def __init__(self, conn: "BasePostgresClient", tables: Optional[List[str]] = None) -> None:
def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
super().__init__(conn, tables)
self.schema = conn.server_settings.get("schema") or "public"

@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
return {
"int4": self.int_field,
"int8": self.int_field,
Expand All @@ -34,12 +36,12 @@ def field_map(self) -> dict:
"timestamp": self.datetime_field,
}

async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))

async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
Expand Down
12 changes: 6 additions & 6 deletions aerich/inspectdb/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Callable, Dict, List
from __future__ import annotations

from aerich.inspectdb import Column, Inspect
from aerich.inspectdb import Column, FieldMapDict, Inspect


class InspectSQLite(Inspect):
@property
def field_map(self) -> Dict[str, Callable[..., str]]:
def field_map(self) -> FieldMapDict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
Expand All @@ -21,7 +21,7 @@ def field_map(self) -> Dict[str, Callable[..., str]]:
"BLOB": self.binary_field,
}

async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
Expand All @@ -45,7 +45,7 @@ async def get_columns(self, table: str) -> List[Column]:
)
return columns

async def _get_columns_index(self, table: str) -> Dict[str, str]:
async def _get_columns_index(self, table: str) -> dict[str, str]:
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
Expand All @@ -55,7 +55,7 @@ async def _get_columns_index(self, table: str) -> Dict[str, str]:
ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret

async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))

0 comments on commit b2f4029

Please sign in to comment.