Merge branch 'main' into test-limit-in-show-sql
MichelleArk authored Sep 25, 2023
2 parents 7bfe1f2 + 4c02c07 commit 093c97e
Showing 7 changed files with 239 additions and 228 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230921-155645.yaml
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Address type annotation issues and remove protected method ref from impl
time: 2023-09-21T15:56:45.329798-07:00
custom:
Author: colin-rogers-dbt
Issue: "933"
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230922-114217.yaml
@@ -0,0 +1,6 @@
kind: Under the Hood
body: update SQLQuery to include node_info
time: 2023-09-22T11:42:17.770033-07:00
custom:
Author: colin-rogers-dbt
Issue: "936"
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230922-125327.yaml
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Fixed a mypy failure by reworking BigQueryAdapter constructor.
time: 2023-09-22T12:53:27.339599-04:00
custom:
Author: peterallenwebb
Issue: "934"
18 changes: 17 additions & 1 deletion dbt/adapters/bigquery/connections.py
@@ -5,6 +5,7 @@
from contextlib import contextmanager
from dataclasses import dataclass, field

from dbt.events.contextvars import get_node_info
from mashumaro.helper import pass_through

from functools import lru_cache
@@ -444,7 +445,8 @@ def raw_execute(
conn = self.get_thread_connection()
client = conn.handle

fire_event(SQLQuery(conn_name=conn.name, sql=sql))
fire_event(SQLQuery(conn_name=conn.name, sql=sql, node_info=get_node_info()))

if (
hasattr(self.profile, "query_comment")
and self.profile.query_comment
@@ -700,6 +702,20 @@ def fn():

self._retry_and_handle(msg="create dataset", conn=conn, fn=fn)

def list_dataset(self, database: str):
# the database string we get here is potentially quoted. Strip that off
# for the API call.
database = database.strip("`")
conn = self.get_thread_connection()
client = conn.handle

def query_schemas():
# this is similar to how we have to deal with listing tables
all_datasets = client.list_datasets(project=database, max_results=10000)
return [ds.dataset_id for ds in all_datasets]

return self._retry_and_handle(msg="list dataset", conn=conn, fn=query_schemas)

def _query_and_results(
self,
client,
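
Two changes land in this file: raw_execute now stamps structured node metadata onto the SQLQuery logging event, and the dataset-listing logic gains a public list_dataset home on the connection manager (exercised from impl.py below). A minimal sketch of the enriched logging call, assuming only what the hunks above show: get_node_info() reads the current node's metadata from dbt's logging contextvars.

# A minimal sketch of the enriched call in raw_execute above. The
# fire_event/SQLQuery module paths are assumed from the dbt-core 1.x
# layout this adapter targets; only the node_info keyword is from the diff.
from dbt.events.contextvars import get_node_info
from dbt.events.functions import fire_event
from dbt.events.types import SQLQuery

fire_event(SQLQuery(conn_name="master", sql="select 1", node_info=get_node_info()))
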
51 changes: 5 additions & 46 deletions dbt/adapters/bigquery/impl.py
@@ -217,6 +217,10 @@ class BigQueryAdapter(BaseAdapter):
ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
}

def __init__(self, config) -> None:
super().__init__(config)
self.connections: BigQueryConnectionManager = self.connections

###
# Implementations of abstract methods
###
@@ -267,18 +271,7 @@ def rename_relation(

@available
def list_schemas(self, database: str) -> List[str]:
# the database string we get here is potentially quoted. Strip that off
# for the API call.
database = database.strip("`")
conn = self.connections.get_thread_connection()
client = conn.handle

def query_schemas():
# this is similar to how we have to deal with listing tables
all_datasets = client.list_datasets(project=database, max_results=10000)
return [ds.dataset_id for ds in all_datasets]

return self.connections._retry_and_handle(msg="list dataset", conn=conn, fn=query_schemas)
return self.connections.list_dataset(database)

@available.parse(lambda *a, **k: False)
def check_schema_exists(self, database: str, schema: str) -> bool:
@@ -481,40 +474,6 @@ def _agate_to_schema(
bq_schema.append(SchemaField(col_name, type_)) # type: ignore[arg-type]
return bq_schema

def _materialize_as_view(self, model: Dict[str, Any]) -> str:
model_database = model.get("database")
model_schema = model.get("schema")
model_alias = model.get("alias")
model_code = model.get("compiled_code")

logger.debug("Model SQL ({}):\n{}".format(model_alias, model_code))
self.connections.create_view(
database=model_database, schema=model_schema, table_name=model_alias, sql=model_code
)
return "CREATE VIEW"

def _materialize_as_table(
self,
model: Dict[str, Any],
model_sql: str,
decorator: Optional[str] = None,
) -> str:
model_database = model.get("database")
model_schema = model.get("schema")
model_alias = model.get("alias")

if decorator is None:
table_name = model_alias
else:
table_name = "{}${}".format(model_alias, decorator)

logger.debug("Model SQL ({}):\n{}".format(table_name, model_sql))
self.connections.create_table(
database=model_database, schema=model_schema, table_name=table_name, sql=model_sql
)

return "CREATE TABLE"

@available.parse(lambda *a, **k: "")
def copy_table(self, source, destination, materialization):
if materialization == "incremental":
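
Most of the deletions here are the unused _materialize_as_view / _materialize_as_table helpers and the inlined body of list_schemas, which now delegates to the public connections.list_dataset instead of reaching into the protected _retry_and_handle. The new __init__ exists so a type checker sees self.connections as a BigQueryConnectionManager rather than the base class's looser type. A standalone sketch of that narrowing pattern, using illustrative stand-in classes rather than the real dbt hierarchy:

# A standalone sketch of the attribute-narrowing pattern in the new
# __init__ above. All names here are illustrative stand-ins; how the real
# base adapter constructs self.connections is not shown in this diff.
from typing import List


class BaseConnections:
    pass


class BQConnections(BaseConnections):
    def list_dataset(self, database: str) -> List[str]:
        # Subclass-only method, mirroring BigQueryConnectionManager.list_dataset.
        return ["dataset_in_" + database.strip("`")]


class BaseAdapter:
    def __init__(self, config) -> None:
        # Stand-in: the base class hands back the loosely typed manager.
        self.connections: BaseConnections = BQConnections()


class BQAdapter(BaseAdapter):
    def __init__(self, config) -> None:
        super().__init__(config)
        # Re-annotate with the concrete type, as the diff above does, so
        # subclass-only calls like list_dataset() type-check at call sites.
        self.connections: BQConnections = self.connections

    def list_schemas(self, database: str) -> List[str]:
        return self.connections.list_dataset(database)


print(BQAdapter(config=None).list_schemas("`my-project`"))  # ['dataset_in_my-project']
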
182 changes: 1 addition & 181 deletions tests/unit/test_bigquery_adapter.py
@@ -1,26 +1,19 @@
import time

import agate
import decimal
import json
import string
import random
import re
import pytest
import unittest
from contextlib import contextmanager
from requests.exceptions import ConnectionError
from unittest.mock import patch, MagicMock, Mock, create_autospec, ANY
from unittest.mock import patch, MagicMock, create_autospec

import dbt.dataclass_schema

from dbt.adapters.bigquery import PartitionConfig
from dbt.adapters.bigquery import BigQueryCredentials
from dbt.adapters.bigquery import BigQueryAdapter
from dbt.adapters.bigquery import BigQueryRelation
from dbt.adapters.bigquery import Plugin as BigQueryPlugin
from google.cloud.bigquery.table import Table
from dbt.adapters.bigquery.connections import BigQueryConnectionManager
from dbt.adapters.bigquery.connections import _sanitize_label, _VALIDATE_LABEL_LENGTH_LIMIT
from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.clients import agate_helper
@@ -543,179 +536,6 @@ def test_replace(self):
assert other_schema.quote_policy.database is False


class TestBigQueryConnectionManager(unittest.TestCase):
def setUp(self):
credentials = Mock(BigQueryCredentials)
profile = Mock(query_comment=None, credentials=credentials)
self.connections = BigQueryConnectionManager(profile=profile)

self.mock_client = Mock(dbt.adapters.bigquery.impl.google.cloud.bigquery.Client)
self.mock_connection = MagicMock()

self.mock_connection.handle = self.mock_client

self.connections.get_thread_connection = lambda: self.mock_connection
self.connections.get_job_retry_deadline_seconds = lambda x: None
self.connections.get_job_retries = lambda x: 1

@patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True)
def test_retry_and_handle(self, is_retryable):
self.connections.DEFAULT_MAXIMUM_DELAY = 2.0

@contextmanager
def dummy_handler(msg):
yield

self.connections.exception_handler = dummy_handler

class DummyException(Exception):
"""Count how many times this exception is raised"""

count = 0

def __init__(self):
DummyException.count += 1

def raiseDummyException():
raise DummyException()

with self.assertRaises(DummyException):
self.connections._retry_and_handle(
"some sql", Mock(credentials=Mock(retries=8)), raiseDummyException
)
self.assertEqual(DummyException.count, 9)

@patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True)
def test_retry_connection_reset(self, is_retryable):
self.connections.open = MagicMock()
self.connections.close = MagicMock()
self.connections.DEFAULT_MAXIMUM_DELAY = 2.0

@contextmanager
def dummy_handler(msg):
yield

self.connections.exception_handler = dummy_handler

def raiseConnectionResetError():
raise ConnectionResetError("Connection broke")

mock_conn = Mock(credentials=Mock(retries=1))
with self.assertRaises(ConnectionResetError):
self.connections._retry_and_handle("some sql", mock_conn, raiseConnectionResetError)
self.connections.close.assert_called_once_with(mock_conn)
self.connections.open.assert_called_once_with(mock_conn)

def test_is_retryable(self):
_is_retryable = dbt.adapters.bigquery.connections._is_retryable
exceptions = dbt.adapters.bigquery.impl.google.cloud.exceptions
internal_server_error = exceptions.InternalServerError("code broke")
bad_request_error = exceptions.BadRequest("code broke")
connection_error = ConnectionError("code broke")
client_error = exceptions.ClientError("bad code")
rate_limit_error = exceptions.Forbidden(
"code broke", errors=[{"reason": "rateLimitExceeded"}]
)

self.assertTrue(_is_retryable(internal_server_error))
self.assertTrue(_is_retryable(bad_request_error))
self.assertTrue(_is_retryable(connection_error))
self.assertFalse(_is_retryable(client_error))
self.assertTrue(_is_retryable(rate_limit_error))

def test_drop_dataset(self):
mock_table = Mock()
mock_table.reference = "table1"
self.mock_client.list_tables.return_value = [mock_table]

self.connections.drop_dataset("project", "dataset")

self.mock_client.list_tables.assert_not_called()
self.mock_client.delete_table.assert_not_called()
self.mock_client.delete_dataset.assert_called_once()

@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
def test_query_and_results(self, mock_bq):
self.mock_client.query = Mock(return_value=Mock(state="DONE"))
self.connections._query_and_results(
self.mock_client,
"sql",
{"job_param_1": "blah"},
job_creation_timeout=15,
job_execution_timeout=3,
)

mock_bq.QueryJobConfig.assert_called_once()
self.mock_client.query.assert_called_once_with(
query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
)

@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
def test_query_and_results_timeout(self, mock_bq):
self.mock_client.query = Mock(
return_value=Mock(result=lambda *args, **kwargs: time.sleep(4))
)
with pytest.raises(dbt.exceptions.DbtRuntimeError) as exc:
self.connections._query_and_results(
self.mock_client,
"sql",
{"job_param_1": "blah"},
job_creation_timeout=15,
job_execution_timeout=1,
)

mock_bq.QueryJobConfig.assert_called_once()
self.mock_client.query.assert_called_once_with(
query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
)
assert "Query exceeded configured timeout of 1s" in str(exc.value)

def test_copy_bq_table_appends(self):
self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND)
args, kwargs = self.mock_client.copy_table.call_args
self.mock_client.copy_table.assert_called_once_with(
[self._table_ref("project", "dataset", "table1")],
self._table_ref("project", "dataset", "table2"),
job_config=ANY,
)
args, kwargs = self.mock_client.copy_table.call_args
self.assertEqual(
kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_APPEND
)

def test_copy_bq_table_truncates(self):
self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_TRUNCATE)
args, kwargs = self.mock_client.copy_table.call_args
self.mock_client.copy_table.assert_called_once_with(
[self._table_ref("project", "dataset", "table1")],
self._table_ref("project", "dataset", "table2"),
job_config=ANY,
)
args, kwargs = self.mock_client.copy_table.call_args
self.assertEqual(
kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_TRUNCATE
)

def test_job_labels_valid_json(self):
expected = {"key": "value"}
labels = self.connections._labels_from_query_comment(json.dumps(expected))
self.assertEqual(labels, expected)

def test_job_labels_invalid_json(self):
labels = self.connections._labels_from_query_comment("not json")
self.assertEqual(labels, {"query_comment": "not_json"})

def _table_ref(self, proj, ds, table):
return self.connections.table_ref(proj, ds, table)

def _copy_table(self, write_disposition):
source = BigQueryRelation.create(database="project", schema="dataset", identifier="table1")
destination = BigQueryRelation.create(
database="project", schema="dataset", identifier="table2"
)
self.connections.copy_bq_table(source, destination, write_disposition)


class TestBigQueryAdapter(BaseTestBigQueryAdapter):
def test_copy_table_materialization_table(self):
adapter = self.get_adapter("oauth")
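
The entire TestBigQueryConnectionManager suite is removed from this module; the diffs loaded here do not show where, or whether, it was re-homed. For illustration only, a sketch of how the new list_dataset seam could be unit-tested in the same Mock-based style as the removed suite, reusing its setUp seams (thread connection, retry knobs):

# Illustrative only, not part of this commit. The mocking seams mirror the
# removed setUp above; the assertions follow list_dataset in connections.py.
import unittest
from unittest.mock import MagicMock, Mock

from dbt.adapters.bigquery import BigQueryCredentials
from dbt.adapters.bigquery.connections import BigQueryConnectionManager


class TestListDataset(unittest.TestCase):
    def setUp(self):
        credentials = Mock(BigQueryCredentials)
        profile = Mock(query_comment=None, credentials=credentials)
        self.connections = BigQueryConnectionManager(profile=profile)
        self.mock_client = MagicMock()
        self.connections.get_thread_connection = lambda: MagicMock(handle=self.mock_client)
        self.connections.get_job_retry_deadline_seconds = lambda x: None
        self.connections.get_job_retries = lambda x: 1

    def test_list_dataset_strips_quoting(self):
        self.mock_client.list_datasets.return_value = [Mock(dataset_id="ds1")]

        result = self.connections.list_dataset("`my-project`")

        # The backticks must be stripped before the API call.
        self.mock_client.list_datasets.assert_called_once_with(
            project="my-project", max_results=10000
        )
        self.assertEqual(result, ["ds1"])
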
(Diffs for the remaining 2 changed files did not load.)
