Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Query Cancellation #918

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import re
from contextlib import contextmanager
from dataclasses import dataclass, field

from dbt_common.invocation import get_invocation_id

from dbt_common.events.contextvars import get_node_info
import uuid
from mashumaro.helper import pass_through

from functools import lru_cache
Expand All @@ -25,20 +22,25 @@
service_account as GoogleServiceAccountCredentials,
)

from dbt.adapters.bigquery import gcloud
from dbt_common.clients import agate_helper
from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse, Credentials
from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import (
DbtRuntimeError,
DbtConfigError,
)

from dbt_common.exceptions import DbtDatabaseError
from dbt_common.invocation import get_invocation_id
from dbt.adapters.bigquery import gcloud
from dbt.adapters.contracts.connection import (
ConnectionState,
AdapterResponse,
Credentials,
AdapterRequiredConfig
)
from dbt.adapters.exceptions.connection import FailedToConnectError
from dbt.adapters.base import BaseConnectionManager
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.events.types import SQLQuery
from dbt_common.events.functions import fire_event
from dbt.adapters.bigquery import __version__ as dbt_version

from dbt_common.dataclass_schema import ExtensibleDbtClassMixin, StrEnum
Expand Down Expand Up @@ -227,6 +229,11 @@ class BigQueryConnectionManager(BaseConnectionManager):
DEFAULT_INITIAL_DELAY = 1.0 # Seconds
DEFAULT_MAXIMUM_DELAY = 3.0 # Seconds


def __init__(self, profile: AdapterRequiredConfig):
    """Set up the connection manager and the per-thread job registry."""
    super().__init__(profile)
    # thread identifier -> list of BigQuery job_ids launched by that thread;
    # consulted by cancel_open() to cancel in-flight queries.
    self.jobs_by_thread = {}

@classmethod
def handle_error(cls, error, message):
error_msg = "\n".join([item["message"] for item in error.errors])
Expand Down Expand Up @@ -280,11 +287,29 @@ def exception_handler(self, sql):
exc_message = exc_message.split(BQ_QUERY_JOB_SPLIT)[0].strip()
raise DbtRuntimeError(exc_message)

def cancel_open(self) -> None:
pass
def cancel_open(self):
    """Cancel in-flight BigQuery jobs started by other threads and close
    their connections.

    The calling thread's own connection is skipped. Returns the list of
    names of the connections that were visited (excluding unnamed ones).
    """
    names = []
    this_connection = self.get_if_exists()
    with self.lock:
        for thread_id, connection in self.thread_connections.items():
            # Never cancel the connection that is doing the cancelling.
            if connection is this_connection:
                continue

            if connection.handle is not None and connection.state == ConnectionState.OPEN:
                client = connection.handle
                for job_id in self.jobs_by_thread.get(thread_id, []):
                    # Bind loop variables as defaults so the callback does
                    # not depend on late-binding closure semantics (all fn's
                    # would otherwise see the last client/job_id).
                    def fn(client=client, job_id=job_id):
                        return client.cancel_job(job_id)

                    self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn)

                self.close(connection)

            if connection.name is not None:
                names.append(connection.name)
    return names

@classmethod
def close(cls, connection):
    """Close the underlying client handle and mark the connection closed."""
    handle = connection.handle
    handle.close()
    connection.state = ConnectionState.CLOSED
    return connection
Expand Down Expand Up @@ -478,12 +503,16 @@ def raw_execute(

job_creation_timeout = self.get_job_creation_timeout_seconds(conn)
job_execution_timeout = self.get_job_execution_timeout_seconds(conn)
job_id = str(uuid.uuid4())
thread_id = self.get_thread_identifier()
self.jobs_by_thread[thread_id] = self.jobs_by_thread.get(thread_id, []) + [job_id]

def fn():
return self._query_and_results(
client,
sql,
job_params,
job_id,
job_creation_timeout=job_creation_timeout,
job_execution_timeout=job_execution_timeout,
limit=limit,
Expand Down Expand Up @@ -721,13 +750,15 @@ def _query_and_results(
client,
sql,
job_params,
job_id,
job_creation_timeout=None,
job_execution_timeout=None,
limit: Optional[int] = None,
):
"""Query the client and wait for results."""
# Cannot reuse job_config if destination is set and ddl is used
job_config = google.cloud.bigquery.QueryJobConfig(**job_params)
query_job = client.query(query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout)
query_job = client.query(query=sql, job_config=job_config, timeout=job_creation_timeout)
if (
query_job.location is not None
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/bigquery/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def date_function(cls) -> str:

@classmethod
def is_cancelable(cls) -> bool:
    """Report that this adapter supports query cancellation.

    Cancellation is implemented by BigQueryConnectionManager.cancel_open(),
    which cancels jobs tracked in jobs_by_thread.
    """
    # NOTE: the scraped diff showed both the removed `return False` and the
    # added `return True`; the PR's intent is to enable cancellation.
    return True

def drop_relation(self, relation: BigQueryRelation) -> None:
is_cached = self._schema_is_cached(relation.database, relation.schema) # type: ignore[arg-type]
Expand Down
173 changes: 166 additions & 7 deletions tests/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
inject_adapter,
TestAdapterConversions,
load_internal_manifest_macros,
mock_connection
)


Expand Down Expand Up @@ -368,23 +369,27 @@ def test_acquire_connection_maximum_bytes_billed(self, mock_open_connection):

def test_cancel_open_connections_empty(self):
    """With no thread connections registered, cancellation yields nothing."""
    adapter = self.get_adapter("oauth")
    self.assertEqual(len(list(adapter.cancel_open_connections())), 0)

def test_cancel_open_connections_master(self):
    """The calling thread's own (master) connection is never cancelled."""
    adapter = self.get_adapter("oauth")
    key = adapter.connections.get_thread_identifier()
    adapter.connections.thread_connections[key] = mock_connection("master")
    self.assertEqual(len(list(adapter.cancel_open_connections())), 0)

def test_cancel_open_connections_single(self):
    """One non-master connection is cancelled; the master is skipped."""
    adapter = self.get_adapter("oauth")
    master = mock_connection("master")
    model = mock_connection("model")
    key = adapter.connections.get_thread_identifier()

    adapter.connections.thread_connections.update(
        {
            key: master,
            1: model,
        }
    )
    self.assertEqual(len(list(adapter.cancel_open_connections())), 1)

@patch("dbt.adapters.bigquery.impl.google.auth.default")
@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
Expand Down Expand Up @@ -562,6 +567,160 @@ def test_replace(self):
assert other_schema.quote_policy.database is False


# TODO: move to tests/unit/test_bigquery_connection_manager.py
class TestBigQueryConnectionManager(unittest.TestCase):
    """Unit tests for BigQueryConnectionManager with a mocked BigQuery client."""

    def setUp(self):
        credentials = Mock(BigQueryCredentials)
        profile = Mock(query_comment=None, credentials=credentials)
        self.connections = BigQueryConnectionManager(profile=profile)

        self.mock_client = Mock(dbt.adapters.bigquery.impl.google.cloud.bigquery.Client)
        self.mock_connection = MagicMock()

        self.mock_connection.handle = self.mock_client

        # Stub out connection/retry plumbing so each test exercises only the
        # method under test.
        self.connections.get_thread_connection = lambda: self.mock_connection
        self.connections.get_job_retry_deadline_seconds = lambda x: None
        self.connections.get_job_retries = lambda x: 1

    @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True)
    def test_retry_and_handle(self, is_retryable):
        self.connections.DEFAULT_MAXIMUM_DELAY = 2.0

        @contextmanager
        def dummy_handler(msg):
            yield

        self.connections.exception_handler = dummy_handler

        class DummyException(Exception):
            """Count how many times this exception is raised"""

            count = 0

            def __init__(self):
                DummyException.count += 1

        def raiseDummyException():
            raise DummyException()

        with self.assertRaises(DummyException):
            self.connections._retry_and_handle(
                "some sql", Mock(credentials=Mock(retries=8)), raiseDummyException
            )
        # 1 initial attempt + 8 retries.
        self.assertEqual(DummyException.count, 9)

    @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True)
    def test_retry_connection_reset(self, is_retryable):
        self.connections.open = MagicMock()
        self.connections.close = MagicMock()
        self.connections.DEFAULT_MAXIMUM_DELAY = 2.0

        @contextmanager
        def dummy_handler(msg):
            yield

        self.connections.exception_handler = dummy_handler

        def raiseConnectionResetError():
            raise ConnectionResetError("Connection broke")

        mock_conn = Mock(credentials=Mock(retries=1))
        with self.assertRaises(ConnectionResetError):
            self.connections._retry_and_handle("some sql", mock_conn, raiseConnectionResetError)
        # A reset connection must be recycled (closed then reopened) once.
        self.connections.close.assert_called_once_with(mock_conn)
        self.connections.open.assert_called_once_with(mock_conn)

    def test_is_retryable(self):
        _is_retryable = dbt.adapters.bigquery.connections._is_retryable
        exceptions = dbt.adapters.bigquery.impl.google.cloud.exceptions
        internal_server_error = exceptions.InternalServerError("code broke")
        bad_request_error = exceptions.BadRequest("code broke")
        connection_error = ConnectionError("code broke")
        client_error = exceptions.ClientError("bad code")
        rate_limit_error = exceptions.Forbidden(
            "code broke", errors=[{"reason": "rateLimitExceeded"}]
        )

        self.assertTrue(_is_retryable(internal_server_error))
        self.assertTrue(_is_retryable(bad_request_error))
        self.assertTrue(_is_retryable(connection_error))
        self.assertFalse(_is_retryable(client_error))
        self.assertTrue(_is_retryable(rate_limit_error))

    def test_drop_dataset(self):
        mock_table = Mock()
        mock_table.reference = "table1"
        self.mock_client.list_tables.return_value = [mock_table]

        self.connections.drop_dataset("project", "dataset")

        # drop_dataset should delete the dataset wholesale, not table by table.
        self.mock_client.list_tables.assert_not_called()
        self.mock_client.delete_table.assert_not_called()
        self.mock_client.delete_dataset.assert_called_once()

    @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
    def test_query_and_results(self, mock_bq):
        self.connections._query_and_results(
            self.mock_client,
            "sql",
            {"job_param_1": "blah"},
            job_id=1,
            job_creation_timeout=15,
            job_execution_timeout=100,
        )

        mock_bq.QueryJobConfig.assert_called_once()
        self.mock_client.query.assert_called_once_with(
            query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
        )

    def test_copy_bq_table_appends(self):
        self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND)
        args, kwargs = self.mock_client.copy_table.call_args
        self.mock_client.copy_table.assert_called_once_with(
            [self._table_ref("project", "dataset", "table1")],
            self._table_ref("project", "dataset", "table2"),
            job_config=ANY,
        )
        args, kwargs = self.mock_client.copy_table.call_args
        self.assertEqual(
            kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_APPEND
        )

    def test_copy_bq_table_truncates(self):
        self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_TRUNCATE)
        args, kwargs = self.mock_client.copy_table.call_args
        self.mock_client.copy_table.assert_called_once_with(
            [self._table_ref("project", "dataset", "table1")],
            self._table_ref("project", "dataset", "table2"),
            job_config=ANY,
        )
        args, kwargs = self.mock_client.copy_table.call_args
        self.assertEqual(
            kwargs["job_config"].write_disposition, dbt.adapters.bigquery.impl.WRITE_TRUNCATE
        )

    def test_job_labels_valid_json(self):
        expected = {"key": "value"}
        labels = self.connections._labels_from_query_comment(json.dumps(expected))
        self.assertEqual(labels, expected)

    def test_job_labels_invalid_json(self):
        labels = self.connections._labels_from_query_comment("not json")
        self.assertEqual(labels, {"query_comment": "not_json"})

    def _table_ref(self, proj, ds, table):
        # Thin helper so both copy tests build refs the same way.
        return self.connections.table_ref(proj, ds, table)

    def _copy_table(self, write_disposition):
        source = BigQueryRelation.create(database="project", schema="dataset", identifier="table1")
        destination = BigQueryRelation.create(
            database="project", schema="dataset", identifier="table2"
        )
        self.connections.copy_bq_table(source, destination, write_disposition)


class TestBigQueryAdapter(BaseTestBigQueryAdapter):
def test_copy_table_materialization_table(self):
adapter = self.get_adapter("oauth")
Expand Down
Loading