From b2f4029a4a5527881a59e530764abd07498b606e Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Tue, 3 Dec 2024 12:40:28 +0800 Subject: [PATCH] Improve type hints of inspectdb (#371) --- aerich/inspectdb/__init__.py | 38 +++++++++++++++++++++++++----------- aerich/inspectdb/mysql.py | 10 +++++----- aerich/inspectdb/postgres.py | 14 +++++++------ aerich/inspectdb/sqlite.py | 12 ++++++------ 4 files changed, 46 insertions(+), 28 deletions(-) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index 94c26bf..bc996a1 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -1,9 +1,24 @@ -from typing import Any, List, Optional +from __future__ import annotations + +from typing import Any, Callable, Dict, Optional, TypedDict from pydantic import BaseModel from tortoise import BaseDBAsyncClient +class ColumnInfoDict(TypedDict): + name: str + pk: str + index: str + null: str + default: str + length: str + comment: str + + +FieldMapDict = Dict[str, Callable[..., str]] + + class Column(BaseModel): name: str data_type: str @@ -18,7 +33,7 @@ class Column(BaseModel): decimal_places: Optional[int] = None max_digits: Optional[int] = None - def translate(self) -> dict: + def translate(self) -> ColumnInfoDict: comment = default = length = index = null = pk = "" if self.pk: pk = "pk=True, " @@ -28,23 +43,24 @@ def translate(self) -> dict: else: if self.index: index = "index=True, " - if self.data_type in ["varchar", "VARCHAR"]: + if self.data_type in ("varchar", "VARCHAR"): length = f"max_length={self.length}, " - if self.data_type in ["decimal", "numeric"]: + elif self.data_type in ("decimal", "numeric"): length_parts = [] if self.max_digits: length_parts.append(f"max_digits={self.max_digits}") if self.decimal_places: length_parts.append(f"decimal_places={self.decimal_places}") - length = ", ".join(length_parts)+", " + if length_parts: + length = ", ".join(length_parts) + ", " if self.null: null = "null=True, " if self.default is not None: - if self.data_type in ["tinyint", "INT"]: + if self.data_type in ("tinyint", "INT"): default = f"default={'True' if self.default == '1' else 'False'}, " elif self.data_type == "bool": default = f"default={'True' if self.default == 'true' else 'False'}, " - elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]: + elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"): if "CURRENT_TIMESTAMP" == self.default: if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra: default = "auto_now=True, " @@ -76,7 +92,7 @@ def translate(self) -> dict: class Inspect: _table_template = "class {table}(Model):\n" - def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): + def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None: self.conn = conn try: self.database = conn.database # type:ignore[attr-defined] @@ -85,7 +101,7 @@ def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): self.tables = tables @property - def field_map(self) -> dict: + def field_map(self) -> FieldMapDict: raise NotImplementedError async def inspect(self) -> str: @@ -103,10 +119,10 @@ async def inspect(self) -> str: tables.append(model + "\n".join(fields)) return result + "\n\n\n".join(tables) - async def get_columns(self, table: str) -> List[Column]: + async def get_columns(self, table: str) -> list[Column]: raise NotImplementedError - async def get_all_tables(self) -> List[str]: + async def get_all_tables(self) -> list[str]: raise NotImplementedError @classmethod diff --git a/aerich/inspectdb/mysql.py b/aerich/inspectdb/mysql.py index db83c16..1d53f07 100644 --- a/aerich/inspectdb/mysql.py +++ b/aerich/inspectdb/mysql.py @@ -1,11 +1,11 @@ -from typing import List +from __future__ import annotations -from aerich.inspectdb import Column, Inspect +from aerich.inspectdb import Column, FieldMapDict, Inspect class InspectMySQL(Inspect): @property - def field_map(self) -> dict: + def field_map(self) -> FieldMapDict: return { "int": self.int_field, "smallint": self.smallint_field, @@ -24,12 +24,12 @@ def field_map(self) -> dict: "longblob": self.binary_field, } - async def get_all_tables(self) -> List[str]: + async def get_all_tables(self) -> list[str]: sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s" ret = await self.conn.execute_query_dict(sql, [self.database]) return list(map(lambda x: x["TABLE_NAME"], ret)) - async def get_columns(self, table: str) -> List[Column]: + async def get_columns(self, table: str) -> list[Column]: columns = [] sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME from information_schema.COLUMNS c diff --git a/aerich/inspectdb/postgres.py b/aerich/inspectdb/postgres.py index f77bc29..19fa031 100644 --- a/aerich/inspectdb/postgres.py +++ b/aerich/inspectdb/postgres.py @@ -1,18 +1,20 @@ -from typing import TYPE_CHECKING, List, Optional +from __future__ import annotations -from aerich.inspectdb import Column, Inspect +from typing import TYPE_CHECKING + +from aerich.inspectdb import Column, FieldMapDict, Inspect if TYPE_CHECKING: from tortoise.backends.base_postgres.client import BasePostgresClient class InspectPostgres(Inspect): - def __init__(self, conn: "BasePostgresClient", tables: Optional[List[str]] = None) -> None: + def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None: super().__init__(conn, tables) self.schema = conn.server_settings.get("schema") or "public" @property - def field_map(self) -> dict: + def field_map(self) -> FieldMapDict: return { "int4": self.int_field, "int8": self.int_field, @@ -34,12 +36,12 @@ def field_map(self) -> dict: "timestamp": self.datetime_field, } - async def get_all_tables(self) -> List[str]: + async def get_all_tables(self) -> list[str]: sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2" ret = await self.conn.execute_query_dict(sql, [self.database, self.schema]) return list(map(lambda x: x["table_name"], ret)) - async def get_columns(self, table: str) -> List[Column]: + async def get_columns(self, table: str) -> list[Column]: columns = [] sql = f"""select c.column_name, col_description('public.{table}'::regclass, ordinal_position) as column_comment, diff --git a/aerich/inspectdb/sqlite.py b/aerich/inspectdb/sqlite.py index 8cafa95..b729c73 100644 --- a/aerich/inspectdb/sqlite.py +++ b/aerich/inspectdb/sqlite.py @@ -1,11 +1,11 @@ -from typing import Callable, Dict, List +from __future__ import annotations -from aerich.inspectdb import Column, Inspect +from aerich.inspectdb import Column, FieldMapDict, Inspect class InspectSQLite(Inspect): @property - def field_map(self) -> Dict[str, Callable[..., str]]: + def field_map(self) -> FieldMapDict: return { "INTEGER": self.int_field, "INT": self.bool_field, @@ -21,7 +21,7 @@ def field_map(self) -> Dict[str, Callable[..., str]]: "BLOB": self.binary_field, } - async def get_columns(self, table: str) -> List[Column]: + async def get_columns(self, table: str) -> list[Column]: columns = [] sql = f"PRAGMA table_info({table})" ret = await self.conn.execute_query_dict(sql) @@ -45,7 +45,7 @@ async def get_columns(self, table: str) -> List[Column]: ) return columns - async def _get_columns_index(self, table: str) -> Dict[str, str]: + async def _get_columns_index(self, table: str) -> dict[str, str]: sql = f"PRAGMA index_list ({table})" indexes = await self.conn.execute_query_dict(sql) ret = {} @@ -55,7 +55,7 @@ async def _get_columns_index(self, table: str) -> Dict[str, str]: ret[index_info["name"]] = "unique" if index["unique"] else "index" return ret - async def get_all_tables(self) -> List[str]: + async def get_all_tables(self) -> list[str]: sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'" ret = await self.conn.execute_query_dict(sql) return list(map(lambda x: x["tbl_name"], ret))