Skip to content

Commit

Permalink
Adap 1162/merge agate lazy load (#1250)
Browse files Browse the repository at this point in the history
* lazy load agate

* Add test and documentation.

* Fix test.

* Don't need a test for this.

---------

Co-authored-by: dwreeves <[email protected]>
Co-authored-by: Mila Page <[email protected]>
  • Loading branch information
3 people authored Jun 13, 2024
1 parent 995ebcb commit e678489
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 23 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240331-101418.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Lazy load `agate`
time: 2024-03-31T10:14:18.260074-04:00
custom:
Author: dwreeves
Issue: "1162"
16 changes: 11 additions & 5 deletions dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from mashumaro.helper import pass_through

from functools import lru_cache
import agate
from requests.exceptions import ConnectionError
from typing import Optional, Any, Dict, Tuple
from typing import Optional, Any, Dict, Tuple, TYPE_CHECKING

import google.auth
import google.auth.exceptions
Expand All @@ -26,7 +25,6 @@
)

from dbt.adapters.bigquery import gcloud
from dbt_common.clients import agate_helper
from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse, Credentials
from dbt_common.exceptions import (
DbtRuntimeError,
Expand All @@ -44,6 +42,10 @@

from dbt_common.dataclass_schema import ExtensibleDbtClassMixin, StrEnum

if TYPE_CHECKING:
# Indirectly imported via agate_helper, which is lazy loaded further downfile.
# Used by mypy for earlier type hints.
import agate

logger = AdapterLogger("BigQuery")

Expand Down Expand Up @@ -432,7 +434,9 @@ def get_job_retry_deadline_seconds(cls, conn):
return credentials.job_retry_deadline_seconds

@classmethod
def get_table_from_response(cls, resp):
def get_table_from_response(cls, resp) -> "agate.Table":
from dbt_common.clients import agate_helper

column_names = [field.name for field in resp.schema]
return agate_helper.table_from_data_flat(resp, column_names)

Expand Down Expand Up @@ -499,14 +503,16 @@ def fn():

def execute(
self, sql, auto_begin=False, fetch=None, limit: Optional[int] = None
) -> Tuple[BigQueryAdapterResponse, agate.Table]:
) -> Tuple[BigQueryAdapterResponse, "agate.Table"]:
sql = self._add_query_comment(sql)
# auto_begin is ignored on bigquery, and only included for consistency
query_job, iterator = self.raw_execute(sql, limit=limit)

if fetch:
table = self.get_table_from_response(iterator)
else:
from dbt_common.clients import agate_helper

table = agate_helper.empty_table()

message = "OK"
Expand Down
52 changes: 37 additions & 15 deletions dbt/adapters/bigquery/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,20 @@
from multiprocessing.context import SpawnContext

import time
from typing import Any, Dict, List, Optional, Type, Set, Union, FrozenSet, Tuple, Iterable
from typing import (
Any,
Dict,
List,
Optional,
Type,
Set,
Union,
FrozenSet,
Tuple,
Iterable,
TYPE_CHECKING,
)

import agate
from dbt.adapters.contracts.relation import RelationConfig

import dbt_common.exceptions.base
Expand All @@ -24,7 +35,6 @@
from dbt.adapters.base.impl import FreshnessResponse
from dbt.adapters.cache import _make_ref_key_dict # type: ignore
from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support
import dbt_common.clients.agate_helper
from dbt.adapters.contracts.connection import AdapterResponse
from dbt.adapters.contracts.macros import MacroResolverProtocol
from dbt_common.contracts.constraints import ColumnLevelConstraint, ConstraintType, ModelLevelConstraint # type: ignore
Expand Down Expand Up @@ -58,6 +68,10 @@
)
from dbt.adapters.bigquery.utility import sql_escape

if TYPE_CHECKING:
# Indirectly imported via agate_helper, which is lazy loaded further downfile.
# Used by mypy for earlier type hints.
import agate

logger = AdapterLogger("BigQuery")

Expand Down Expand Up @@ -334,32 +348,34 @@ def quote(cls, identifier: str) -> str:
return "`{}`".format(identifier)

@classmethod
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "string"

@classmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
import agate

decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) # type: ignore[attr-defined]
return "float64" if decimals else "int64"

@classmethod
def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_integer_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "int64"

@classmethod
def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_boolean_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "bool"

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "datetime"

@classmethod
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_date_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "date"

@classmethod
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "time"

###
Expand Down Expand Up @@ -387,7 +403,7 @@ def _get_dbt_columns_from_bq_table(self, table) -> List[BigQueryColumn]:
return columns

def _agate_to_schema(
self, agate_table: agate.Table, column_override: Dict[str, str]
self, agate_table: "agate.Table", column_override: Dict[str, str]
) -> List[SchemaField]:
"""Convert agate.Table with column names to a list of bigquery schemas."""
bq_schema = []
Expand Down Expand Up @@ -655,7 +671,13 @@ def alter_table_add_columns(self, relation, columns):

@available.parse_none
def load_dataframe(
self, database, schema, table_name, agate_table, column_override, field_delimiter
self,
database,
schema,
table_name,
agate_table: "agate.Table",
column_override,
field_delimiter,
):
bq_schema = self._agate_to_schema(agate_table, column_override)
conn = self.connections.get_thread_connection()
Expand All @@ -667,7 +689,7 @@ def load_dataframe(
load_config.skip_leading_rows = 1
load_config.schema = bq_schema
load_config.field_delimiter = field_delimiter
with open(agate_table.original_abspath, "rb") as f:
with open(agate_table.original_abspath, "rb") as f: # type: ignore
job = client.load_table_from_file(f, table_ref, rewind=True, job_config=load_config)

timeout = self.connections.get_job_execution_timeout_seconds(conn) or 300
Expand Down Expand Up @@ -699,8 +721,8 @@ def upload_file(

@classmethod
def _catalog_filter_table(
cls, table: agate.Table, used_schemas: FrozenSet[Tuple[str, str]]
) -> agate.Table:
cls, table: "agate.Table", used_schemas: FrozenSet[Tuple[str, str]]
) -> "agate.Table":
table = table.rename(
column_names={col.name: col.name.replace("__", ":") for col in table.columns}
)
Expand Down
12 changes: 9 additions & 3 deletions dbt/adapters/bigquery/relation_configs/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Optional, Dict
from typing import Optional, Dict, TYPE_CHECKING

import agate
from dbt.adapters.base.relation import Policy
from dbt.adapters.relation_configs import RelationConfigBase
from google.cloud.bigquery import Table as BigQueryTable
Expand All @@ -13,6 +12,11 @@
)
from dbt.adapters.contracts.relation import ComponentName, RelationConfig

if TYPE_CHECKING:
# Indirectly imported via agate_helper, which is lazy loaded further downfile.
# Used by mypy for earlier type hints.
import agate


@dataclass(frozen=True, eq=True, unsafe_hash=True)
class BigQueryBaseRelationConfig(RelationConfigBase):
Expand Down Expand Up @@ -55,8 +59,10 @@ def _render_part(cls, component: ComponentName, value: Optional[str]) -> Optiona
return None

@classmethod
def _get_first_row(cls, results: agate.Table) -> agate.Row:
def _get_first_row(cls, results: "agate.Table") -> "agate.Row":
try:
return results.rows[0]
except IndexError:
import agate

return agate.Row(values=set())

0 comments on commit e678489

Please sign in to comment.