Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adap 1049/lazy load agate #1050

Merged
merged 3 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240612-195629.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Lazy load agate to improve performance
time: 2024-06-12T19:56:29.943204-07:00
custom:
Author: versusfacit
Issue: "1049"
39 changes: 24 additions & 15 deletions dbt/adapters/spark/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Callable,
Set,
FrozenSet,
TYPE_CHECKING,
)

from dbt.adapters.base.relation import InformationSchema
Expand All @@ -24,7 +25,10 @@

from typing_extensions import TypeAlias

import agate
if TYPE_CHECKING:
# Indirectly imported via agate_helper, which is lazily loaded further down in the file.
# Used by mypy for earlier type hints.
import agate

from dbt.adapters.base import AdapterConfig, PythonJobHelper
from dbt.adapters.base.impl import catch_as_completed, ConstraintSupport
Expand Down Expand Up @@ -127,34 +131,36 @@ def date_function(cls) -> str:
return "current_timestamp()"

@classmethod
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "string"

@classmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
import agate

decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "double" if decimals else "bigint"

@classmethod
def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_integer_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "bigint"

@classmethod
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_date_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "date"

@classmethod
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "time"

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "timestamp"

def quote(self, identifier: str) -> str:
return "`{}`".format(identifier)

def _get_relation_information(self, row: agate.Row) -> RelationInfo:
def _get_relation_information(self, row: "agate.Row") -> RelationInfo:
"""relation info was fetched with SHOW TABLES EXTENDED"""
try:
_schema, name, _, information = row
Expand All @@ -165,7 +171,7 @@ def _get_relation_information(self, row: agate.Row) -> RelationInfo:

return _schema, name, information

def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo:
def _get_relation_information_using_describe(self, row: "agate.Row") -> RelationInfo:
"""Relation info fetched using SHOW TABLES and an auxiliary DESCRIBE statement"""
try:
_schema, name, _ = row
Expand Down Expand Up @@ -193,8 +199,8 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo:

def _build_spark_relation_list(
self,
row_list: agate.Table,
relation_info_func: Callable[[agate.Row], RelationInfo],
row_list: "agate.Table",
relation_info_func: Callable[["agate.Row"], RelationInfo],
) -> List[BaseRelation]:
"""Aggregate relations with format metadata included."""
relations = []
Expand Down Expand Up @@ -370,15 +376,15 @@ def get_catalog(
self,
relation_configs: Iterable[RelationConfig],
used_schemas: FrozenSet[Tuple[str, str]],
) -> Tuple[agate.Table, List[Exception]]:
) -> Tuple["agate.Table", List[Exception]]:
schema_map = self._get_catalog_schemas(relation_configs)
if len(schema_map) > 1:
raise CompilationError(
f"Expected only one database in get_catalog, found " f"{list(schema_map)}"
)

with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = []
futures: List[Future["agate.Table"]] = []
for info, schemas in schema_map.items():
for schema in schemas:
futures.append(
Expand All @@ -399,7 +405,7 @@ def _get_one_catalog(
information_schema: InformationSchema,
schemas: Set[str],
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:
) -> "agate.Table":
if len(schemas) != 1:
raise CompilationError(
f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}"
Expand All @@ -412,6 +418,9 @@ def _get_one_catalog(
for relation in self.list_relations(database, schema):
logger.debug("Getting table schema for relation {}", str(relation))
columns.extend(self._get_columns_for_catalog(relation))

import agate

return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)

def check_schema_exists(self, database: str, schema: str) -> bool:
Expand Down Expand Up @@ -486,7 +495,7 @@ def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
"all_purpose_cluster": AllPurposeClusterPythonJobHelper,
}

def standardize_grants_dict(self, grants_table: agate.Table) -> dict:
def standardize_grants_dict(self, grants_table: "agate.Table") -> dict:
grants_dict: Dict[str, List[str]] = {}
for row in grants_table:
grantee = row["Principal"]
Expand Down
Loading