From b4f21500c486fe822d3651b2613fade8775a0cbb Mon Sep 17 00:00:00 2001 From: Mila Page Date: Wed, 12 Jun 2024 20:05:51 -0700 Subject: [PATCH] More comments on types and lint. --- dbt/adapters/spark/impl.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 16a9ce665..d33ebde20 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -136,6 +136,8 @@ def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str: @classmethod def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str: + import agate + decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "double" if decimals else "bigint" @@ -158,7 +160,7 @@ def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str: def quote(self, identifier: str) -> str: return "`{}`".format(identifier) - def _get_relation_information(self, row: agate.Row) -> RelationInfo: + def _get_relation_information(self, row: "agate.Row") -> RelationInfo: """relation info was fetched with SHOW TABLES EXTENDED""" try: _schema, name, _, information = row @@ -169,7 +171,7 @@ def _get_relation_information(self, row: agate.Row) -> RelationInfo: return _schema, name, information - def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo: + def _get_relation_information_using_describe(self, row: "agate.Row") -> RelationInfo: """Relation info fetched using SHOW TABLES and an auxiliary DESCRIBE statement""" try: _schema, name, _ = row @@ -198,7 +200,7 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn def _build_spark_relation_list( self, row_list: "agate.Table", - relation_info_func: Callable[[agate.Row], RelationInfo], + relation_info_func: Callable[["agate.Row"], RelationInfo], ) -> List[BaseRelation]: """Aggregate relations with format metadata included.""" relations = [] @@ -404,7 +406,6 @@ def _get_one_catalog( schemas: Set[str], used_schemas: FrozenSet[Tuple[str, str]], ) -> "agate.Table": - import agate if len(schemas) != 1: raise CompilationError( f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}" @@ -417,6 +418,9 @@ def _get_one_catalog( for relation in self.list_relations(database, schema): logger.debug("Getting table schema for relation {}", str(relation)) columns.extend(self._get_columns_for_catalog(relation)) + + import agate + return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) def check_schema_exists(self, database: str, schema: str) -> bool: