From 9f6c449619cf245f1788fcf7e5cf13f3e3a855a6 Mon Sep 17 00:00:00 2001 From: colin-rogers-dbt <111200756+colin-rogers-dbt@users.noreply.github.com> Date: Fri, 22 Sep 2023 11:24:14 -0700 Subject: [PATCH 1/2] Address type annotation issues and clean up impl (#933) * use dynamic schema in test_grant_access_to.py * use dynamic schema in test_grant_access_to.py * update impl ConnectionManager typing and move list_datasets into BigQueryConnectionManager * refactor unit tests and add one to cover list_datasets * accidental commit rever * add changie * Rework constructor for mypy. Remove unused functions. * Add changelog entry. * merge paw/type-fix --------- Co-authored-by: Peter Allen Webb --- .../Under the Hood-20230921-155645.yaml | 6 + .../Under the Hood-20230922-125327.yaml | 6 + dbt/adapters/bigquery/connections.py | 14 ++ dbt/adapters/bigquery/impl.py | 51 +---- tests/unit/test_bigquery_adapter.py | 182 +--------------- .../unit/test_bigquery_connection_manager.py | 198 ++++++++++++++++++ 6 files changed, 230 insertions(+), 227 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20230921-155645.yaml create mode 100644 .changes/unreleased/Under the Hood-20230922-125327.yaml create mode 100644 tests/unit/test_bigquery_connection_manager.py diff --git a/.changes/unreleased/Under the Hood-20230921-155645.yaml b/.changes/unreleased/Under the Hood-20230921-155645.yaml new file mode 100644 index 000000000..12cd663f8 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20230921-155645.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Address type annotation issues and remove protected method ref from impl +time: 2023-09-21T15:56:45.329798-07:00 +custom: + Author: colin-rogers-dbt + Issue: "933" diff --git a/.changes/unreleased/Under the Hood-20230922-125327.yaml b/.changes/unreleased/Under the Hood-20230922-125327.yaml new file mode 100644 index 000000000..9ce871321 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20230922-125327.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Fixed a mypy failure by reworking BigQueryAdapter constructor. +time: 2023-09-22T12:53:27.339599-04:00 +custom: + Author: peterallenwebb + Issue: "934" diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index c136042c3..7799ecb8a 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -700,6 +700,20 @@ def fn(): self._retry_and_handle(msg="create dataset", conn=conn, fn=fn) + def list_dataset(self, database: str): + # the database string we get here is potentially quoted. Strip that off + # for the API call. + database = database.strip("`") + conn = self.get_thread_connection() + client = conn.handle + + def query_schemas(): + # this is similar to how we have to deal with listing tables + all_datasets = client.list_datasets(project=database, max_results=10000) + return [ds.dataset_id for ds in all_datasets] + + return self._retry_and_handle(msg="list dataset", conn=conn, fn=query_schemas) + def _query_and_results( self, client, diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index bb04c78b8..8fc1b69bb 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -217,6 +217,10 @@ class BigQueryAdapter(BaseAdapter): ConstraintType.foreign_key: ConstraintSupport.ENFORCED, } + def __init__(self, config) -> None: + super().__init__(config) + self.connections: BigQueryConnectionManager = self.connections + ### # Implementations of abstract methods ### @@ -267,18 +271,7 @@ def rename_relation( @available def list_schemas(self, database: str) -> List[str]: - # the database string we get here is potentially quoted. Strip that off - # for the API call. - database = database.strip("`") - conn = self.connections.get_thread_connection() - client = conn.handle - - def query_schemas(): - # this is similar to how we have to deal with listing tables - all_datasets = client.list_datasets(project=database, max_results=10000) - return [ds.dataset_id for ds in all_datasets] - - return self.connections._retry_and_handle(msg="list dataset", conn=conn, fn=query_schemas) + return self.connections.list_dataset(database) @available.parse(lambda *a, **k: False) def check_schema_exists(self, database: str, schema: str) -> bool: @@ -481,40 +474,6 @@ def _agate_to_schema( bq_schema.append(SchemaField(col_name, type_)) # type: ignore[arg-type] return bq_schema - def _materialize_as_view(self, model: Dict[str, Any]) -> str: - model_database = model.get("database") - model_schema = model.get("schema") - model_alias = model.get("alias") - model_code = model.get("compiled_code") - - logger.debug("Model SQL ({}):\n{}".format(model_alias, model_code)) - self.connections.create_view( - database=model_database, schema=model_schema, table_name=model_alias, sql=model_code - ) - return "CREATE VIEW" - - def _materialize_as_table( - self, - model: Dict[str, Any], - model_sql: str, - decorator: Optional[str] = None, - ) -> str: - model_database = model.get("database") - model_schema = model.get("schema") - model_alias = model.get("alias") - - if decorator is None: - table_name = model_alias - else: - table_name = "{}${}".format(model_alias, decorator) - - logger.debug("Model SQL ({}):\n{}".format(table_name, model_sql)) - self.connections.create_table( - database=model_database, schema=model_schema, table_name=table_name, sql=model_sql - ) - - return "CREATE TABLE" - @available.parse(lambda *a, **k: "") def copy_table(self, source, destination, materialization): if materialization == "incremental": diff --git a/tests/unit/test_bigquery_adapter.py b/tests/unit/test_bigquery_adapter.py index 10cb3f530..4db2ce83d 100644 --- a/tests/unit/test_bigquery_adapter.py +++ b/tests/unit/test_bigquery_adapter.py @@ -1,26 +1,19 @@ -import time - import agate import decimal -import json import string import random import re import pytest import unittest -from contextlib import contextmanager -from requests.exceptions import ConnectionError -from unittest.mock import patch, MagicMock, Mock, create_autospec, ANY +from unittest.mock import patch, MagicMock, create_autospec import dbt.dataclass_schema from dbt.adapters.bigquery import PartitionConfig -from dbt.adapters.bigquery import BigQueryCredentials from dbt.adapters.bigquery import BigQueryAdapter from dbt.adapters.bigquery import BigQueryRelation from dbt.adapters.bigquery import Plugin as BigQueryPlugin from google.cloud.bigquery.table import Table -from dbt.adapters.bigquery.connections import BigQueryConnectionManager from dbt.adapters.bigquery.connections import _sanitize_label, _VALIDATE_LABEL_LENGTH_LIMIT from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.clients import agate_helper @@ -543,179 +536,6 @@ def test_replace(self): assert other_schema.quote_policy.database is False -class TestBigQueryConnectionManager(unittest.TestCase): - def setUp(self): - credentials = Mock(BigQueryCredentials) - profile = Mock(query_comment=None, credentials=credentials) - self.connections = BigQueryConnectionManager(profile=profile) - - self.mock_client = Mock(dbt.adapters.bigquery.impl.google.cloud.bigquery.Client) - self.mock_connection = MagicMock() - - self.mock_connection.handle = self.mock_client - - self.connections.get_thread_connection = lambda: self.mock_connection - self.connections.get_job_retry_deadline_seconds = lambda x: None - self.connections.get_job_retries = lambda x: 1 - - @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True) - def test_retry_and_handle(self, is_retryable): - self.connections.DEFAULT_MAXIMUM_DELAY = 2.0 - - @contextmanager - def dummy_handler(msg): - yield - - self.connections.exception_handler = dummy_handler - - class DummyException(Exception): - """Count how many times this exception is raised""" - - count = 0 - - def __init__(self): - DummyException.count += 1 - - def raiseDummyException(): - raise DummyException() - - with self.assertRaises(DummyException): - self.connections._retry_and_handle( - "some sql", Mock(credentials=Mock(retries=8)), raiseDummyException - ) - self.assertEqual(DummyException.count, 9) - - @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True) - def test_retry_connection_reset(self, is_retryable): - self.connections.open = MagicMock() - self.connections.close = MagicMock() - self.connections.DEFAULT_MAXIMUM_DELAY = 2.0 - - @contextmanager - def dummy_handler(msg): - yield - - self.connections.exception_handler = dummy_handler - - def raiseConnectionResetError(): - raise ConnectionResetError("Connection broke") - - mock_conn = Mock(credentials=Mock(retries=1)) - with self.assertRaises(ConnectionResetError): - self.connections._retry_and_handle("some sql", mock_conn, raiseConnectionResetError) - self.connections.close.assert_called_once_with(mock_conn) - self.connections.open.assert_called_once_with(mock_conn) - - def test_is_retryable(self): - _is_retryable = dbt.adapters.bigquery.connections._is_retryable - exceptions = dbt.adapters.bigquery.impl.google.cloud.exceptions - internal_server_error = exceptions.InternalServerError("code broke") - bad_request_error = exceptions.BadRequest("code broke") - connection_error = ConnectionError("code broke") - client_error = exceptions.ClientError("bad code") - rate_limit_error = exceptions.Forbidden( - "code broke", errors=[{"reason": "rateLimitExceeded"}] - ) - - self.assertTrue(_is_retryable(internal_server_error)) - self.assertTrue(_is_retryable(bad_request_error)) - self.assertTrue(_is_retryable(connection_error)) - self.assertFalse(_is_retryable(client_error)) - self.assertTrue(_is_retryable(rate_limit_error)) - - def test_drop_dataset(self): - mock_table = Mock() - mock_table.reference = "table1" - self.mock_client.list_tables.return_value = [mock_table] - - self.connections.drop_dataset("project", "dataset") - - self.mock_client.list_tables.assert_not_called() - self.mock_client.delete_table.assert_not_called() - self.mock_client.delete_dataset.assert_called_once() - - @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") - def test_query_and_results(self, mock_bq): - self.mock_client.query = Mock(return_value=Mock(state="DONE")) - self.connections._query_and_results( - self.mock_client, - "sql", - {"job_param_1": "blah"}, - job_creation_timeout=15, - job_execution_timeout=3, - ) - - mock_bq.QueryJobConfig.assert_called_once() - self.mock_client.query.assert_called_once_with( - query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15 - ) - - @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") - def test_query_and_results_timeout(self, mock_bq): - self.mock_client.query = Mock( - return_value=Mock(result=lambda *args, **kwargs: time.sleep(4)) - ) - with pytest.raises(dbt.exceptions.DbtRuntimeError) as exc: - self.connections._query_and_results( - self.mock_client, - "sql", - {"job_param_1": "blah"}, - job_creation_timeout=15, - job_execution_timeout=1, - ) - - mock_bq.QueryJobConfig.assert_called_once() - self.mock_client.query.assert_called_once_with( - query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15 - ) - assert "Query exceeded configured timeout of 1s" in str(exc.value) - - def test_copy_bq_table_appends(self): - self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND) - args, kwargs = self.mock_client.copy_table.call_args - self.mock_client.copy_table.assert_called_once_with( - [self._table_ref("project", "dataset", "table1")], - self._table_ref("project", "dataset", "table2"), - job_config=ANY, - ) - args, kwargs = self.mock_client.copy_table.call_args - self.assertEqual( - kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_APPEND - ) - - def test_copy_bq_table_truncates(self): - self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_TRUNCATE) - args, kwargs = self.mock_client.copy_table.call_args - self.mock_client.copy_table.assert_called_once_with( - [self._table_ref("project", "dataset", "table1")], - self._table_ref("project", "dataset", "table2"), - job_config=ANY, - ) - args, kwargs = self.mock_client.copy_table.call_args - self.assertEqual( - kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_TRUNCATE - ) - - def test_job_labels_valid_json(self): - expected = {"key": "value"} - labels = self.connections._labels_from_query_comment(json.dumps(expected)) - self.assertEqual(labels, expected) - - def test_job_labels_invalid_json(self): - labels = self.connections._labels_from_query_comment("not json") - self.assertEqual(labels, {"query_comment": "not_json"}) - - def _table_ref(self, proj, ds, table): - return self.connections.table_ref(proj, ds, table) - - def _copy_table(self, write_disposition): - source = BigQueryRelation.create(database="project", schema="dataset", identifier="table1") - destination = BigQueryRelation.create( - database="project", schema="dataset", identifier="table2" - ) - self.connections.copy_bq_table(source, destination, write_disposition) - - class TestBigQueryAdapter(BaseTestBigQueryAdapter): def test_copy_table_materialization_table(self): adapter = self.get_adapter("oauth") diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py new file mode 100644 index 000000000..d6c3f64fc --- /dev/null +++ b/tests/unit/test_bigquery_connection_manager.py @@ -0,0 +1,198 @@ +import time +import json +import pytest +import unittest +from contextlib import contextmanager +from requests.exceptions import ConnectionError +from unittest.mock import patch, MagicMock, Mock, ANY + +import dbt.dataclass_schema + +from dbt.adapters.bigquery import BigQueryCredentials +from dbt.adapters.bigquery import BigQueryRelation +from dbt.adapters.bigquery.connections import BigQueryConnectionManager +import dbt.exceptions +from dbt.logger import GLOBAL_LOGGER as logger # noqa + + +class TestBigQueryConnectionManager(unittest.TestCase): + def setUp(self): + credentials = Mock(BigQueryCredentials) + profile = Mock(query_comment=None, credentials=credentials) + self.connections = BigQueryConnectionManager(profile=profile) + + self.mock_client = Mock(dbt.adapters.bigquery.impl.google.cloud.bigquery.Client) + self.mock_connection = MagicMock() + + self.mock_connection.handle = self.mock_client + + self.connections.get_thread_connection = lambda: self.mock_connection + self.connections.get_job_retry_deadline_seconds = lambda x: None + self.connections.get_job_retries = lambda x: 1 + + @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True) + def test_retry_and_handle(self, is_retryable): + self.connections.DEFAULT_MAXIMUM_DELAY = 2.0 + + @contextmanager + def dummy_handler(msg): + yield + + self.connections.exception_handler = dummy_handler + + class DummyException(Exception): + """Count how many times this exception is raised""" + + count = 0 + + def __init__(self): + DummyException.count += 1 + + def raiseDummyException(): + raise DummyException() + + with self.assertRaises(DummyException): + self.connections._retry_and_handle( + "some sql", Mock(credentials=Mock(retries=8)), raiseDummyException + ) + self.assertEqual(DummyException.count, 9) + + @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True) + def test_retry_connection_reset(self, is_retryable): + self.connections.open = MagicMock() + self.connections.close = MagicMock() + self.connections.DEFAULT_MAXIMUM_DELAY = 2.0 + + @contextmanager + def dummy_handler(msg): + yield + + self.connections.exception_handler = dummy_handler + + def raiseConnectionResetError(): + raise ConnectionResetError("Connection broke") + + mock_conn = Mock(credentials=Mock(retries=1)) + with self.assertRaises(ConnectionResetError): + self.connections._retry_and_handle("some sql", mock_conn, raiseConnectionResetError) + self.connections.close.assert_called_once_with(mock_conn) + self.connections.open.assert_called_once_with(mock_conn) + + def test_is_retryable(self): + _is_retryable = dbt.adapters.bigquery.connections._is_retryable + exceptions = dbt.adapters.bigquery.impl.google.cloud.exceptions + internal_server_error = exceptions.InternalServerError("code broke") + bad_request_error = exceptions.BadRequest("code broke") + connection_error = ConnectionError("code broke") + client_error = exceptions.ClientError("bad code") + rate_limit_error = exceptions.Forbidden( + "code broke", errors=[{"reason": "rateLimitExceeded"}] + ) + + self.assertTrue(_is_retryable(internal_server_error)) + self.assertTrue(_is_retryable(bad_request_error)) + self.assertTrue(_is_retryable(connection_error)) + self.assertFalse(_is_retryable(client_error)) + self.assertTrue(_is_retryable(rate_limit_error)) + + def test_drop_dataset(self): + mock_table = Mock() + mock_table.reference = "table1" + self.mock_client.list_tables.return_value = [mock_table] + + self.connections.drop_dataset("project", "dataset") + + self.mock_client.list_tables.assert_not_called() + self.mock_client.delete_table.assert_not_called() + self.mock_client.delete_dataset.assert_called_once() + + @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") + def test_query_and_results(self, mock_bq): + self.mock_client.query = Mock(return_value=Mock(state="DONE")) + self.connections._query_and_results( + self.mock_client, + "sql", + {"job_param_1": "blah"}, + job_creation_timeout=15, + job_execution_timeout=3, + ) + + mock_bq.QueryJobConfig.assert_called_once() + self.mock_client.query.assert_called_once_with( + query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15 + ) + + @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") + def test_query_and_results_timeout(self, mock_bq): + self.mock_client.query = Mock( + return_value=Mock(result=lambda *args, **kwargs: time.sleep(4)) + ) + with pytest.raises(dbt.exceptions.DbtRuntimeError) as exc: + self.connections._query_and_results( + self.mock_client, + "sql", + {"job_param_1": "blah"}, + job_creation_timeout=15, + job_execution_timeout=1, + ) + + mock_bq.QueryJobConfig.assert_called_once() + self.mock_client.query.assert_called_once_with( + query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15 + ) + assert "Query exceeded configured timeout of 1s" in str(exc.value) + + def test_copy_bq_table_appends(self): + self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND) + args, kwargs = self.mock_client.copy_table.call_args + self.mock_client.copy_table.assert_called_once_with( + [self._table_ref("project", "dataset", "table1")], + self._table_ref("project", "dataset", "table2"), + job_config=ANY, + ) + args, kwargs = self.mock_client.copy_table.call_args + self.assertEqual( + kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_APPEND + ) + + def test_copy_bq_table_truncates(self): + self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_TRUNCATE) + args, kwargs = self.mock_client.copy_table.call_args + self.mock_client.copy_table.assert_called_once_with( + [self._table_ref("project", "dataset", "table1")], + self._table_ref("project", "dataset", "table2"), + job_config=ANY, + ) + args, kwargs = self.mock_client.copy_table.call_args + self.assertEqual( + kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_TRUNCATE + ) + + def test_job_labels_valid_json(self): + expected = {"key": "value"} + labels = self.connections._labels_from_query_comment(json.dumps(expected)) + self.assertEqual(labels, expected) + + def test_job_labels_invalid_json(self): + labels = self.connections._labels_from_query_comment("not json") + self.assertEqual(labels, {"query_comment": "not_json"}) + + def test_list_dataset_correctly_calls_lists_datasets(self): + mock_dataset = Mock(dataset_id="d1") + mock_list_dataset = Mock(return_value=[mock_dataset]) + self.mock_client.list_datasets = mock_list_dataset + result = self.connections.list_dataset("project") + self.mock_client.list_datasets.assert_called_once_with( + project="project", max_results=10000 + ) + assert result == ["d1"] + + def _table_ref(self, proj, ds, table): + return self.connections.table_ref(proj, ds, table) + + def _copy_table(self, write_disposition): + source = BigQueryRelation.create(database="project", schema="dataset", identifier="table1") + destination = BigQueryRelation.create( + database="project", schema="dataset", identifier="table2" + ) + self.connections.copy_bq_table(source, destination, write_disposition) From 4c02c07b95f8c1c582320558384a349e1d4fd51a Mon Sep 17 00:00:00 2001 From: colin-rogers-dbt <111200756+colin-rogers-dbt@users.noreply.github.com> Date: Fri, 22 Sep 2023 14:03:18 -0700 Subject: [PATCH 2/2] Update SQLQuery to include node_info (#936) * use dynamic schema in test_grant_access_to.py * use dynamic schema in test_grant_access_to.py * update SQLQuery to include node_info * add changie * revert setup --- .changes/unreleased/Under the Hood-20230922-114217.yaml | 6 ++++++ dbt/adapters/bigquery/connections.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 .changes/unreleased/Under the Hood-20230922-114217.yaml diff --git a/.changes/unreleased/Under the Hood-20230922-114217.yaml b/.changes/unreleased/Under the Hood-20230922-114217.yaml new file mode 100644 index 000000000..78fee33c4 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20230922-114217.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: update SQLQuery to include node_info +time: 2023-09-22T11:42:17.770033-07:00 +custom: + Author: colin-rogers-dbt + Issue: "936" diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 7799ecb8a..a5c7b9355 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field +from dbt.events.contextvars import get_node_info from mashumaro.helper import pass_through from functools import lru_cache @@ -444,7 +445,8 @@ def raw_execute( conn = self.get_thread_connection() client = conn.handle - fire_event(SQLQuery(conn_name=conn.name, sql=sql)) + fire_event(SQLQuery(conn_name=conn.name, sql=sql, node_info=get_node_info())) + if ( hasattr(self.profile, "query_comment") and self.profile.query_comment