diff --git a/.changes/unreleased/Under the Hood-20240612-195629.yaml b/.changes/unreleased/Under the Hood-20240612-195629.yaml new file mode 100644 index 000000000..c90ebcdab --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240612-195629.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Lazy load agate to improve performance +time: 2024-06-12T19:56:29.943204-07:00 +custom: + Author: versusfacit + Issue: "1049" diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 255ab7806..d33ebde20 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -14,6 +14,7 @@ Callable, Set, FrozenSet, + TYPE_CHECKING, ) from dbt.adapters.base.relation import InformationSchema @@ -24,7 +25,10 @@ from typing_extensions import TypeAlias -import agate +if TYPE_CHECKING: + # Indirectly imported via agate_helper, which is lazy loaded further downfile. + # Used by mypy for earlier type hints. + import agate from dbt.adapters.base import AdapterConfig, PythonJobHelper from dbt.adapters.base.impl import catch_as_completed, ConstraintSupport @@ -127,34 +131,36 @@ def date_function(cls) -> str: return "current_timestamp()" @classmethod - def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "string" @classmethod - def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str: + import agate + decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "double" if decimals else "bigint" @classmethod - def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_integer_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "bigint" @classmethod - def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_date_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "date" @classmethod - def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "time" @classmethod - def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: + def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "timestamp" def quote(self, identifier: str) -> str: return "`{}`".format(identifier) - def _get_relation_information(self, row: agate.Row) -> RelationInfo: + def _get_relation_information(self, row: "agate.Row") -> RelationInfo: """relation info was fetched with SHOW TABLES EXTENDED""" try: _schema, name, _, information = row @@ -165,7 +171,7 @@ def _get_relation_information(self, row: agate.Row) -> RelationInfo: return _schema, name, information - def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo: + def _get_relation_information_using_describe(self, row: "agate.Row") -> RelationInfo: """Relation info fetched using SHOW TABLES and an auxiliary DESCRIBE statement""" try: _schema, name, _ = row @@ -193,8 +199,8 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn def _build_spark_relation_list( self, - row_list: agate.Table, - relation_info_func: Callable[[agate.Row], RelationInfo], + row_list: "agate.Table", + relation_info_func: Callable[["agate.Row"], RelationInfo], ) -> List[BaseRelation]: """Aggregate relations with format metadata included.""" relations = [] @@ -370,7 +376,7 @@ def get_catalog( self, relation_configs: Iterable[RelationConfig], used_schemas: FrozenSet[Tuple[str, str]], - ) -> Tuple[agate.Table, List[Exception]]: + ) -> Tuple["agate.Table", List[Exception]]: schema_map = self._get_catalog_schemas(relation_configs) if len(schema_map) > 1: raise CompilationError( @@ -378,7 +384,7 @@ def get_catalog( ) with executor(self.config) as tpe: - futures: List[Future[agate.Table]] = [] + futures: List[Future["agate.Table"]] = [] for info, schemas in schema_map.items(): for schema in schemas: futures.append( @@ -399,7 +405,7 @@ def _get_one_catalog( information_schema: InformationSchema, schemas: Set[str], used_schemas: FrozenSet[Tuple[str, str]], - ) -> agate.Table: + ) -> "agate.Table": if len(schemas) != 1: raise CompilationError( f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}" @@ -412,6 +418,9 @@ def _get_one_catalog( for relation in self.list_relations(database, schema): logger.debug("Getting table schema for relation {}", str(relation)) columns.extend(self._get_columns_for_catalog(relation)) + + import agate + return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) def check_schema_exists(self, database: str, schema: str) -> bool: @@ -486,7 +495,7 @@ def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]: "all_purpose_cluster": AllPurposeClusterPythonJobHelper, } - def standardize_grants_dict(self, grants_table: agate.Table) -> dict: + def standardize_grants_dict(self, grants_table: "agate.Table") -> dict: grants_dict: Dict[str, List[str]] = {} for row in grants_table: grantee = row["Principal"]