Adap 1049/lazy load agate #1050

Merged 3 commits on Jun 14, 2024
Changes from 2 commits
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240612-195629.yaml
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Lazy load agate to improve performance
time: 2024-06-12T19:56:29.943204-07:00
custom:
Author: versusfacit
Issue: "1049"
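
The changelog entry describes the pattern applied in the diff below: agate is imported only under typing.TYPE_CHECKING so mypy still sees the type, while the annotations become strings and the real import is deferred until it is needed. A minimal, self-contained sketch of that pattern (the function name echoes the adapter's converters, but the snippet is illustrative rather than code from this PR):

import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; at runtime TYPE_CHECKING is False,
    # so this import never executes and adds no module-load cost.
    import agate


def convert_text_type(agate_table: "agate.Table", col_idx: int) -> str:
    # The quoted annotation is just a string at runtime, so agate does not
    # have to be importable for this function to be defined or called.
    return "string"


print(TYPE_CHECKING)                      # False
print("agate" in sys.modules)             # False in a fresh process: nothing was imported
print(convert_text_type.__annotations__)  # {'agate_table': 'agate.Table', ...}
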
29 changes: 17 additions & 12 deletions dbt/adapters/spark/impl.py
@@ -14,6 +14,7 @@
Callable,
Set,
FrozenSet,
TYPE_CHECKING,
)

from dbt.adapters.base.relation import InformationSchema
@@ -24,7 +25,10 @@

from typing_extensions import TypeAlias

import agate
if TYPE_CHECKING:
# Indirectly imported via agate_helper, which is lazily loaded further down in the file.
# Used only by mypy for the quoted type hints below.
import agate

from dbt.adapters.base import AdapterConfig, PythonJobHelper
from dbt.adapters.base.impl import catch_as_completed, ConstraintSupport
@@ -127,28 +131,28 @@ def date_function(cls) -> str:
return "current_timestamp()"

@classmethod
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "string"

@classmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "double" if decimals else "bigint"

@classmethod
def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_integer_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "bigint"

@classmethod
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_date_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "date"

@classmethod
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "time"

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "timestamp"

def quote(self, identifier: str) -> str:
@@ -193,7 +197,7 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn

def _build_spark_relation_list(
self,
row_list: agate.Table,
row_list: "agate.Table",
relation_info_func: Callable[[agate.Row], RelationInfo],
) -> List[BaseRelation]:
"""Aggregate relations with format metadata included."""
@@ -370,15 +374,15 @@ def get_catalog(
self,
relation_configs: Iterable[RelationConfig],
used_schemas: FrozenSet[Tuple[str, str]],
) -> Tuple[agate.Table, List[Exception]]:
) -> Tuple["agate.Table", List[Exception]]:
schema_map = self._get_catalog_schemas(relation_configs)
if len(schema_map) > 1:
raise CompilationError(
f"Expected only one database in get_catalog, found " f"{list(schema_map)}"
)

with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = []
futures: List[Future["agate.Table"]] = []
for info, schemas in schema_map.items():
for schema in schemas:
futures.append(
@@ -399,7 +403,8 @@ def _get_one_catalog(
information_schema: InformationSchema,
schemas: Set[str],
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:
) -> "agate.Table":
import agate
if len(schemas) != 1:
raise CompilationError(
f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}"
@@ -486,7 +491,7 @@ def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
"all_purpose_cluster": AllPurposeClusterPythonJobHelper,
}

def standardize_grants_dict(self, grants_table: agate.Table) -> dict:
def standardize_grants_dict(self, grants_table: "agate.Table") -> dict:
grants_dict: Dict[str, List[str]] = {}
for row in grants_table:
grantee = row["Principal"]
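
The other half of the change is the function-local import in _get_one_catalog: agate is pulled in only when a catalog table is actually materialized. A rough sketch of the effect, assuming nothing else in the process has imported agate yet (build_catalog_table is a hypothetical stand-in for the adapter's method; only the deferred import agate is taken from the diff):

import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import agate


def build_catalog_table(rows: list, column_names: list) -> "agate.Table":
    # Deferred import: the cost of loading agate is paid on the first call,
    # not when the adapter module is imported during dbt start-up.
    import agate

    return agate.Table(rows, column_names=column_names)


print("agate" in sys.modules)                        # False before first use
table = build_catalog_table([("x", 1)], ["name", "count"])
print("agate" in sys.modules)                        # True only after the call
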