diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 37de188c5..62e5d0a2a 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -58,6 +58,14 @@ class SparkConfig(AdapterConfig):
     merge_update_columns: Optional[str] = None
 
 
+@dataclass(frozen=True)
+class RelationInfo:
+    table_schema: str
+    table_name: str
+    columns: List[Tuple[str, str]]
+    properties: Dict[str, str]
+
+
 class SparkAdapter(SQLAdapter):
     COLUMN_NAMES = (
         "table_database",
@@ -79,9 +87,7 @@ class SparkAdapter(SQLAdapter):
         "stats:rows:description",
         "stats:rows:include",
     )
-    INFORMATION_COLUMNS_REGEX = re.compile(r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE)
-    INFORMATION_OWNER_REGEX = re.compile(r"^Owner: (.*)$", re.MULTILINE)
-    INFORMATION_STATISTICS_REGEX = re.compile(r"^Statistics: (.*)$", re.MULTILINE)
+    INFORMATION_COLUMN_REGEX = re.compile(r" \|-- (.*): (.*) \(nullable = (.*)\)")
     HUDI_METADATA_COLUMNS = [
         "_hoodie_commit_time",
         "_hoodie_commit_seqno",
@@ -91,7 +97,6 @@ class SparkAdapter(SQLAdapter):
     ]
 
     Relation: TypeAlias = SparkRelation
-    RelationInfo = Tuple[str, str, str]
     Column: TypeAlias = SparkColumn
     ConnectionManager: TypeAlias = SparkConnectionManager
     AdapterSpecificConfigs: TypeAlias = SparkConfig
@@ -139,13 +144,42 @@ def add_schema_to_cache(self, schema) -> str:
     def _get_relation_information(self, row: agate.Row) -> RelationInfo:
         """relation info was fetched with SHOW TABLES EXTENDED"""
        try:
-            _schema, name, _, information = row
+            # Example lines:
+            # Database: dbt_schema
+            # Table: names
+            # Owner: fokkodriesprong
+            # Created Time: Mon May 08 18:06:47 CEST 2023
+            # Last Access: UNKNOWN
+            # Created By: Spark 3.3.2
+            # Type: MANAGED
+            # Provider: hive
+            # Table Properties: [transient_lastDdlTime=1683562007]
+            # Statistics: 16 bytes
+            # Schema: root
+            #  |-- idx: integer (nullable = false)
+            #  |-- name: string (nullable = false)
+            table_properties = {}
+            columns = []
+            _schema, name, _, information_blob = row
+            for line in information_blob.split("\n"):
+                if line:
+                    if line.startswith(" |--"):
+                        # A column
+                        m = self.INFORMATION_COLUMN_REGEX.match(line)
+                        columns.append(
+                            (m[1], m[2])
+                        )
+                    else:
+                        # A property
+                        parts = line.split(": ", maxsplit=2)
+                        table_properties[parts[0]] = parts[1]
+
         except ValueError:
             raise dbt.exceptions.DbtRuntimeError(
                 f'Invalid value from "show tables extended ...", got {len(row)} values, expected 4'
             )
 
-        return _schema, name, information
+        return RelationInfo(_schema, name, columns, table_properties)
 
     def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo:
         """Relation info fetched using SHOW TABLES and an auxiliary DESCRIBE statement"""
@@ -165,13 +199,49 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn
             logger.debug(f"Error while retrieving information about {table_name}: {e.msg}")
             table_results = AttrDict()
 
-        information = ""
-        for info_row in table_results:
-            info_type, info_value, _ = info_row
-            if not info_type.startswith("#"):
-                information += f"{info_type}: {info_value}\n"
+        # Example lines:
+        # idx            int
+        # name           string
+        #
+        # # Partitioning
+        # Not partitioned
+        #
+        # # Metadata Columns
+        # _spec_id       int
+        # _partition     struct<>
+        # _file          string
+        # _pos           bigint
+        # _deleted       boolean
+        #
+        # # Detailed Table Information
+        # Name           sandbox.dbt_tabular3.names
+        # Location       s3://tabular-wh-us-east-1/6efbcaf4-21ae-499d-b340-3bc1a7003f52/d2082e32-d2bd-4484-bb93-7bc445c1c6bb
+        # Provider       iceberg
+
+        # Wrap it in an iter, so we continue reading the properties from where we stopped reading columns
+        table_results_itr = iter(table_results)
+
+        # First the columns
+        columns = []
+        for info_row in table_results_itr:
+            if info_row[0] == '':
+                break
+            columns.append(
+                (info_row[0], info_row[1])
+            )
 
-        return _schema, name, information
+        # Next all the properties
+        table_properties = {}
+        for info_row in table_results_itr:
+            info_type, info_value, _ = info_row
+            if not info_type.startswith("#") and info_type != '':
+                table_properties[info_type] = info_value
+
+        return RelationInfo(
+            _schema,
+            name,
+            columns,
+            table_properties
+        )
 
     def _build_spark_relation_list(
         self,
@@ -181,23 +251,24 @@ def _build_spark_relation_list(
         """Aggregate relations with format metadata included."""
         relations = []
         for row in row_list:
-            _schema, name, information = relation_info_func(row)
+            relation = relation_info_func(row)
 
             rel_type: RelationType = (
-                RelationType.View if "Type: VIEW" in information else RelationType.Table
+                RelationType.View if relation.properties.get("type") == "VIEW" else RelationType.Table
             )
-            is_delta: bool = "Provider: delta" in information
-            is_hudi: bool = "Provider: hudi" in information
-            is_iceberg: bool = "Provider: iceberg" in information
+            is_delta: bool = relation.properties.get("provider") == "delta"
+            is_hudi: bool = relation.properties.get("provider") == "hudi"
+            is_iceberg: bool = relation.properties.get("provider") == "iceberg"
 
             relation: BaseRelation = self.Relation.create(  # type: ignore
-                schema=_schema,
-                identifier=name,
+                schema=relation.table_schema,
+                identifier=relation.table_name,
                 type=rel_type,
-                information=information,
                 is_delta=is_delta,
                 is_iceberg=is_iceberg,
                 is_hudi=is_hudi,
+                columns=relation.columns,
+                properties=relation.properties,
             )
             relations.append(relation)
 
@@ -250,19 +321,10 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[
         return super().get_relation(database, schema, identifier)
 
     def parse_describe_extended(
-        self, relation: BaseRelation, raw_rows: AttrDict
+        self, relation: SparkRelation, raw_rows: AttrDict
     ) -> List[SparkColumn]:
-        # Convert the Row to a dict
-        dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows]
-        # Find the separator between the rows and the metadata provided
-        # by the DESCRIBE TABLE EXTENDED statement
-        pos = self.find_table_information_separator(dict_rows)
-
-        # Remove rows that start with a hash, they are comments
-        rows = [row for row in raw_rows[0:pos] if not row["col_name"].startswith("#")]
-        metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1 :]}
-
-        raw_table_stats = metadata.get(KEY_TABLE_STATISTICS)
+        raw_table_stats = relation.properties.get(KEY_TABLE_STATISTICS)
         table_stats = SparkColumn.convert_table_stats(raw_table_stats)
         return [
             SparkColumn(
@@ -270,24 +332,15 @@ def parse_describe_extended(
                 table_schema=relation.schema,
                 table_name=relation.name,
                 table_type=relation.type,
-                table_owner=str(metadata.get(KEY_TABLE_OWNER)),
+                table_owner=relation.properties.get(KEY_TABLE_OWNER, ""),
                 table_stats=table_stats,
-                column=column["col_name"],
+                column=column_name,
                 column_index=idx,
-                dtype=column["data_type"],
+                dtype=column_type,
             )
-            for idx, column in enumerate(rows)
+            for idx, (column_name, column_type) in enumerate(relation.columns)
         ]
 
-    @staticmethod
-    def find_table_information_separator(rows: List[dict]) -> int:
-        pos = 0
-        for row in rows:
-            if not row["col_name"] or row["col_name"].startswith("#"):
-                break
-            pos += 1
-        return pos
-
     def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
         columns = []
         try:
@@ -309,20 +362,11 @@ def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
         columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS]
         return columns
 
-    def parse_columns_from_information(self, relation: BaseRelation) -> List[SparkColumn]:
-        if hasattr(relation, "information"):
-            information = relation.information or ""
-        else:
-            information = ""
-        owner_match = re.findall(self.INFORMATION_OWNER_REGEX, information)
-        owner = owner_match[0] if owner_match else None
-        matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, information)
+    def parse_columns_from_information(self, relation: SparkRelation) -> List[SparkColumn]:
+        owner = relation.properties.get(KEY_TABLE_OWNER, "")
         columns = []
-        stats_match = re.findall(self.INFORMATION_STATISTICS_REGEX, information)
-        raw_table_stats = stats_match[0] if stats_match else None
-        table_stats = SparkColumn.convert_table_stats(raw_table_stats)
-        for match_num, match in enumerate(matches):
-            column_name, column_type, nullable = match.groups()
+        table_stats = SparkColumn.convert_table_stats(relation.properties.get(KEY_TABLE_STATISTICS))
+        for match_num, (column_name, column_type) in enumerate(relation.columns):
             column = SparkColumn(
                 table_database=None,
                 table_schema=relation.schema,
diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py
index f5a3e3e15..164b41809 100644
--- a/dbt/adapters/spark/relation.py
+++ b/dbt/adapters/spark/relation.py
@@ -1,4 +1,4 @@
-from typing import Optional, TypeVar
+from typing import Optional, TypeVar, List, Tuple, Dict
 from dataclasses import dataclass, field
 
 from dbt.adapters.base.relation import BaseRelation, Policy
@@ -33,8 +33,8 @@ class SparkRelation(BaseRelation):
     is_delta: Optional[bool] = None
     is_hudi: Optional[bool] = None
     is_iceberg: Optional[bool] = None
-    # TODO: make this a dict everywhere
-    information: Optional[str] = None
+    columns: List[Tuple[str, str]] = field(default_factory=list)
+    properties: Dict[str, str] = field(default_factory=dict)
 
     def __post_init__(self):
         if self.database != self.schema and self.database: