Skip to content

Commit

Permalink
Migrate to dbt-adapter and common (#1071)
Browse files Browse the repository at this point in the history
* use dynamic schema in test_grant_access_to.py

* use dynamic schema in test_grant_access_to.py

* revert setup

* replace dbt.common with dbt_common

* add dbt-adapters

* delete dbt/adapters

* fix Credentials import and test fixtures

* remove global exceptions import
  • Loading branch information
colin-rogers-dbt authored Jan 22, 2024
1 parent f2804c0 commit e86609a
Show file tree
Hide file tree
Showing 13 changed files with 73 additions and 62 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240116-154305.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Migrate to dbt-common and dbt-adapters package
time: 2024-01-16T15:43:05.046735-08:00
custom:
Author: colin-rogers-dbt
Issue: "1071"
18 changes: 9 additions & 9 deletions dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from contextlib import contextmanager
from dataclasses import dataclass, field

from dbt.common.invocation import get_invocation_id
from dbt_common.invocation import get_invocation_id

from dbt.common.events.contextvars import get_node_info
from dbt_common.events.contextvars import get_node_info
from mashumaro.helper import pass_through

from functools import lru_cache
Expand All @@ -27,21 +27,21 @@
)

from dbt.adapters.bigquery import gcloud
from dbt.common.clients import agate_helper
from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse
from dbt.common.exceptions import (
from dbt_common.clients import agate_helper
from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse, Credentials
from dbt_common.exceptions import (
DbtRuntimeError,
DbtConfigError,
)
from dbt.common.exceptions import DbtDatabaseError
from dbt_common.exceptions import DbtDatabaseError
from dbt.adapters.exceptions.connection import FailedToConnectError
from dbt.adapters.base import BaseConnectionManager, Credentials
from dbt.adapters.base import BaseConnectionManager
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.events.types import SQLQuery
from dbt.common.events.functions import fire_event
from dbt_common.events.functions import fire_event
from dbt.adapters.bigquery import __version__ as dbt_version

from dbt.common.dataclass_schema import ExtensibleDbtClassMixin, StrEnum
from dbt_common.dataclass_schema import ExtensibleDbtClassMixin, StrEnum

logger = AdapterLogger("BigQuery")

Expand Down
7 changes: 4 additions & 3 deletions dbt/adapters/bigquery/gcloud.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dbt_common.exceptions import DbtRuntimeError

from dbt.adapters.events.logging import AdapterLogger
import dbt.common.exceptions
from dbt.common.clients.system import run_cmd
from dbt_common.clients.system import run_cmd

NOT_INSTALLED_MSG = """
dbt requires the gcloud SDK to be installed to authenticate with BigQuery.
Expand All @@ -25,4 +26,4 @@ def setup_default_credentials():
if gcloud_installed():
run_cmd(".", ["gcloud", "auth", "application-default", "login"])
else:
raise dbt.common.exceptions.DbtRuntimeError(NOT_INSTALLED_MSG)
raise DbtRuntimeError(NOT_INSTALLED_MSG)
30 changes: 15 additions & 15 deletions dbt/adapters/bigquery/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import agate
from dbt.adapters.contracts.relation import RelationConfig

import dbt.common.exceptions.base
import dbt_common.exceptions.base
from dbt.adapters.base import ( # type: ignore
AdapterConfig,
BaseAdapter,
Expand All @@ -21,15 +21,15 @@
available,
)
from dbt.adapters.cache import _make_ref_key_dict # type: ignore
import dbt.common.clients.agate_helper
import dbt_common.clients.agate_helper
from dbt.adapters.contracts.connection import AdapterResponse
from dbt.common.contracts.constraints import ColumnLevelConstraint, ConstraintType, ModelLevelConstraint # type: ignore
from dbt.common.dataclass_schema import dbtClassMixin
from dbt_common.contracts.constraints import ColumnLevelConstraint, ConstraintType, ModelLevelConstraint # type: ignore
from dbt_common.dataclass_schema import dbtClassMixin
from dbt.adapters.events.logging import AdapterLogger
from dbt.common.events.functions import fire_event
from dbt_common.events.functions import fire_event
from dbt.adapters.events.types import SchemaCreation, SchemaDrop
import dbt.common.exceptions
from dbt.common.utils import filter_null_values
import dbt_common.exceptions
from dbt_common.utils import filter_null_values
import google.api_core
import google.auth
import google.oauth2
Expand Down Expand Up @@ -147,7 +147,7 @@ def drop_relation(self, relation: BigQueryRelation) -> None:
conn.handle.delete_table(table_ref, not_found_ok=True)

def truncate_relation(self, relation: BigQueryRelation) -> None:
raise dbt.common.exceptions.base.NotImplementedError(
raise dbt_common.exceptions.base.NotImplementedError(
"`truncate` is not implemented for this adapter!"
)

Expand All @@ -164,7 +164,7 @@ def rename_relation(
or from_relation.type == RelationType.View
or to_relation.type == RelationType.View
):
raise dbt.common.exceptions.DbtRuntimeError(
raise dbt_common.exceptions.DbtRuntimeError(
"Renaming of views is not currently supported in BigQuery"
)

Expand Down Expand Up @@ -390,7 +390,7 @@ def copy_table(self, source, destination, materialization):
elif materialization == "table":
write_disposition = WRITE_TRUNCATE
else:
raise dbt.common.exceptions.CompilationError(
raise dbt_common.exceptions.CompilationError(
'Copy table materialization must be "copy" or "table", but '
f"config.get('copy_materialization', 'table') was "
f"{materialization}"
Expand Down Expand Up @@ -437,11 +437,11 @@ def poll_until_job_completes(cls, job, timeout):
job.reload()

if job.state != "DONE":
raise dbt.common.exceptions.DbtRuntimeError("BigQuery Timeout Exceeded")
raise dbt_common.exceptions.DbtRuntimeError("BigQuery Timeout Exceeded")

elif job.error_result:
message = "\n".join(error["message"].strip() for error in job.errors)
raise dbt.common.exceptions.DbtRuntimeError(message)
raise dbt_common.exceptions.DbtRuntimeError(message)

def _bq_table_to_relation(self, bq_table) -> Union[BigQueryRelation, None]:
if bq_table is None:
Expand All @@ -465,7 +465,7 @@ def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
if self.nice_connection_name() in ["on-run-start", "on-run-end"]:
self.warning_on_hooks(self.nice_connection_name())
else:
raise dbt.common.exceptions.base.NotImplementedError(
raise dbt_common.exceptions.base.NotImplementedError(
"`add_query` is not implemented for this adapter!"
)

Expand Down Expand Up @@ -777,7 +777,7 @@ def describe_relation(
bq_table = self.get_bq_table(relation)
parser = BigQueryMaterializedViewConfig
else:
raise dbt.common.exceptions.DbtRuntimeError(
raise dbt_common.exceptions.DbtRuntimeError(
f"The method `BigQueryAdapter.describe_relation` is not implemented "
f"for the relation type: {relation.type}"
)
Expand Down Expand Up @@ -843,7 +843,7 @@ def string_add_sql(
elif location == "prepend":
return f"concat('{value}', {add_to})"
else:
raise dbt.common.exceptions.DbtRuntimeError(
raise dbt_common.exceptions.DbtRuntimeError(
f'Got an unexpected location value of "{location}"'
)

Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/bigquery/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
BigQueryPartitionConfigChange,
)
from dbt.adapters.contracts.relation import RelationType, RelationConfig
from dbt.common.exceptions import CompilationError
from dbt.common.utils.dict import filter_null_values
from dbt_common.exceptions import CompilationError
from dbt_common.utils.dict import filter_null_values


Self = TypeVar("Self", bound="BigQueryRelation")
Expand Down
8 changes: 4 additions & 4 deletions dbt/adapters/bigquery/relation_configs/_partition.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import dbt.common.exceptions
import dbt_common.exceptions
from dbt.adapters.relation_configs import RelationConfigChange
from dbt.adapters.contracts.relation import RelationConfig
from dbt.common.dataclass_schema import dbtClassMixin, ValidationError
from dbt_common.dataclass_schema import dbtClassMixin, ValidationError
from google.cloud.bigquery.table import Table as BigQueryTable


Expand Down Expand Up @@ -92,11 +92,11 @@ def parse(cls, raw_partition_by) -> Optional["PartitionConfig"]:
}
)
except ValidationError as exc:
raise dbt.common.exceptions.base.DbtValidationError(
raise dbt_common.exceptions.base.DbtValidationError(
"Could not parse partition config"
) from exc
except TypeError:
raise dbt.common.exceptions.CompilationError(
raise dbt_common.exceptions.CompilationError(
f"Invalid partition_by config:\n"
f" Got: {raw_partition_by}\n"
f' Expected a dictionary with "field" and "data_type" keys'
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/bigquery/utility.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Any, Optional

import dbt.common.exceptions
import dbt_common.exceptions


def bool_setting(value: Optional[Any] = None) -> Optional[bool]:
Expand Down Expand Up @@ -41,5 +41,5 @@ def float_setting(value: Optional[Any] = None) -> Optional[float]:

def sql_escape(string):
if not isinstance(string, str):
raise dbt.common.exceptions.CompilationError(f"cannot escape a non-string: {string}")
raise dbt_common.exceptions.CompilationError(f"cannot escape a non-string: {string}")
return json.dumps(string)[1:-1]
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def _dbt_core_version(plugin_version: str) -> str:
packages=find_namespace_packages(include=["dbt", "dbt.*"]),
include_package_data=True,
install_requires=[
f"dbt-core~={_dbt_core_version(_dbt_bigquery_version())}",
"dbt-common<1.0",
"dbt-adapters~=0.1.0a1",
"google-cloud-bigquery~=3.0",
"google-cloud-storage~=2.4",
"google-cloud-dataproc~=5.0",
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/adapter/column_types/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
version: 2
models:
- name: model
tests:
data_tests:
- is_type:
column_map:
int64_col: ['integer', 'number']
Expand All @@ -39,7 +39,7 @@
version: 2
models:
- name: model
tests:
data_tests:
- is_type:
column_map:
int64_col: ['string', 'not number']
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/adapter/test_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
version: 2
models:
- name: model_a
tests:
data_tests:
- expect_value:
field: tablename
value: duped_alias
- name: model_b
tests:
data_tests:
- expect_value:
field: tablename
value: duped_alias
Expand Down
36 changes: 19 additions & 17 deletions tests/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
import unittest
from unittest.mock import patch, MagicMock, create_autospec

import dbt.common.dataclass_schema
import dbt.common.exceptions.base
import dbt_common.dataclass_schema
import dbt_common.exceptions.base

import dbt.adapters
from dbt.adapters.bigquery.relation_configs import PartitionConfig
from dbt.adapters.bigquery import BigQueryAdapter, BigQueryRelation
from google.cloud.bigquery.table import Table
from dbt.adapters.bigquery.connections import _sanitize_label, _VALIDATE_LABEL_LENGTH_LIMIT
from dbt.common.clients import agate_helper
import dbt.common.exceptions
from dbt_common.clients import agate_helper
import dbt_common.exceptions
from dbt.context.manifest import generate_query_header_context
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import ManifestStateCheck
Expand Down Expand Up @@ -214,7 +216,7 @@ def test_acquire_connection_oauth_no_project_validations(
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.common.exceptions.base.DbtValidationError as e:
except dbt_common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
Expand All @@ -231,7 +233,7 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection):
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.common.exceptions.base.DbtValidationError as e:
except dbt_common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
Expand All @@ -255,7 +257,7 @@ def test_acquire_connection_dataproc_serverless(
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.common.exceptions.ValidationException as e:
except dbt_common.exceptions.ValidationException as e:
self.fail("got ValidationException: {}".format(str(e)))

except BaseException:
Expand All @@ -272,7 +274,7 @@ def test_acquire_connection_service_account_validations(self, mock_open_connecti
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.common.exceptions.base.DbtValidationError as e:
except dbt_common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
Expand All @@ -289,7 +291,7 @@ def test_acquire_connection_oauth_token_validations(self, mock_open_connection):
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.common.exceptions.base.DbtValidationError as e:
except dbt_common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
Expand All @@ -306,7 +308,7 @@ def test_acquire_connection_oauth_credentials_validations(self, mock_open_connec
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.common.exceptions.base.DbtValidationError as e:
except dbt_common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
Expand All @@ -325,7 +327,7 @@ def test_acquire_connection_impersonated_service_account_validations(
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.common.exceptions.base.DbtValidationError as e:
except dbt_common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
Expand All @@ -343,7 +345,7 @@ def test_acquire_connection_priority(self, mock_open_connection):
self.assertEqual(connection.type, "bigquery")
self.assertEqual(connection.credentials.priority, "batch")

except dbt.common.exceptions.base.DbtValidationError as e:
except dbt_common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

mock_open_connection.assert_not_called()
Expand All @@ -358,7 +360,7 @@ def test_acquire_connection_maximum_bytes_billed(self, mock_open_connection):
self.assertEqual(connection.type, "bigquery")
self.assertEqual(connection.credentials.maximum_bytes_billed, 0)

except dbt.common.exceptions.base.DbtValidationError as e:
except dbt_common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

mock_open_connection.assert_not_called()
Expand Down Expand Up @@ -509,7 +511,7 @@ def test_invalid_relation(self):
},
"quote_policy": {"identifier": False, "schema": True},
}
with self.assertRaises(dbt.common.dataclass_schema.ValidationError):
with self.assertRaises(dbt_common.dataclass_schema.ValidationError):
BigQueryRelation.validate(kwargs)


Expand Down Expand Up @@ -581,10 +583,10 @@ def test_copy_table_materialization_incremental(self):
def test_parse_partition_by(self):
adapter = self.get_adapter("oauth")

with self.assertRaises(dbt.common.exceptions.base.DbtValidationError):
with self.assertRaises(dbt_common.exceptions.base.DbtValidationError):
adapter.parse_partition_by("date(ts)")

with self.assertRaises(dbt.common.exceptions.base.DbtValidationError):
with self.assertRaises(dbt_common.exceptions.base.DbtValidationError):
adapter.parse_partition_by("ts")

self.assertEqual(
Expand Down Expand Up @@ -736,7 +738,7 @@ def test_parse_partition_by(self):
)

# Invalid, should raise an error
with self.assertRaises(dbt.common.exceptions.base.DbtValidationError):
with self.assertRaises(dbt_common.exceptions.base.DbtValidationError):
adapter.parse_partition_by({})

# passthrough
Expand Down
Loading

0 comments on commit e86609a

Please sign in to comment.