diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py
index 132854748..63b7055ce 100644
--- a/dbt/adapters/bigquery/connections.py
+++ b/dbt/adapters/bigquery/connections.py
@@ -11,7 +11,7 @@
 from functools import lru_cache
 import agate
 from requests.exceptions import ConnectionError
-from typing import Optional, Any, Dict, Tuple
+from typing import Optional, Any, Dict, List, Tuple
 
 import google.auth
 import google.auth.exceptions
@@ -25,10 +25,11 @@
 )
 
 from dbt.adapters.bigquery import gcloud
+from dbt.adapters.bigquery.jobs import define_job_id
 from dbt.clients import agate_helper
 from dbt.config.profile import INVALID_PROFILE_MESSAGE
 from dbt.tracking import active_user
-from dbt.contracts.connection import ConnectionState, AdapterResponse
+from dbt.contracts.connection import ConnectionState, AdapterResponse, AdapterRequiredConfig
 from dbt.exceptions import (
     FailedToConnectError,
     DbtRuntimeError,
@@ -229,6 +230,10 @@ class BigQueryConnectionManager(BaseConnectionManager):
     DEFAULT_INITIAL_DELAY = 1.0  # Seconds
     DEFAULT_MAXIMUM_DELAY = 3.0  # Seconds
 
+    def __init__(self, profile: AdapterRequiredConfig):
+        super().__init__(profile)
+        self.jobs_by_thread: Dict[Any, Any] = {}
+
     @classmethod
     def handle_error(cls, error, message):
         error_msg = "\n".join([item["message"] for item in error.errors])
@@ -282,8 +287,28 @@ def exception_handler(self, sql):
                 exc_message = exc_message.split(BQ_QUERY_JOB_SPLIT)[0].strip()
             raise DbtRuntimeError(exc_message)
 
-    def cancel_open(self) -> None:
-        pass
+    def cancel_open(self) -> List[str]:
+        names = []
+        this_connection = self.get_if_exists()
+        with self.lock:
+            for thread_id, connection in self.thread_connections.items():
+                if connection is this_connection:
+                    continue
+
+                if connection.handle is not None and connection.state == ConnectionState.OPEN:
+                    client = connection.handle
+                    for job_id in self.jobs_by_thread.get(thread_id, []):
+
+                        def fn():
+                            return client.cancel_job(job_id)
+
+                        self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn)
+
+                    self.close(connection)
+
+                if connection.name is not None:
+                    names.append(connection.name)
+        return names
 
     @classmethod
     def close(cls, connection):
@@ -481,18 +506,26 @@ def raw_execute(
 
         job_creation_timeout = self.get_job_creation_timeout_seconds(conn)
         job_execution_timeout = self.get_job_execution_timeout_seconds(conn)
+        # build out a deterministic job_id
+        model_name = client.connection.name
+        invocation_id = client.connection.client_info.invocation_id
+        job_id = define_job_id(model_name, invocation_id)
+        thread_id = self.get_thread_identifier()
+        self.jobs_by_thread[thread_id] = self.jobs_by_thread.get(thread_id, []) + [job_id]
 
         def fn():
            return self._query_and_results(
                 client,
                 sql,
                 job_params,
+                job_id,
                 job_creation_timeout=job_creation_timeout,
                 job_execution_timeout=job_execution_timeout,
                 limit=limit,
             )
 
         query_job, iterator = self._retry_and_handle(msg=sql, conn=conn, fn=fn)
+        self.jobs_by_thread.get(thread_id, []).remove(job_id)
 
         return query_job, iterator
 
@@ -724,6 +757,7 @@ def _query_and_results(
         client,
         sql,
         job_params,
+        job_id,
         job_creation_timeout=None,
         job_execution_timeout=None,
         limit: Optional[int] = None,
@@ -731,7 +765,9 @@
         """Query the client and wait for results."""
         # Cannot reuse job_config if destination is set and ddl is used
         job_config = google.cloud.bigquery.QueryJobConfig(**job_params)
-        query_job = client.query(query=sql, job_config=job_config, timeout=job_creation_timeout)
+        query_job = client.query(
+            query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout
+        )
         if (
             query_job.location is not None
             and query_job.job_id is not None
diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py
index 50ce21f11..41bbf6532 100644
--- a/dbt/adapters/bigquery/impl.py
+++ b/dbt/adapters/bigquery/impl.py
@@ -130,7 +130,7 @@ def date_function(cls) -> str:
 
     @classmethod
     def is_cancelable(cls) -> bool:
-        return False
+        return True
 
     def drop_relation(self, relation: BigQueryRelation) -> None:
         is_cached = self._schema_is_cached(relation.database, relation.schema)  # type: ignore[arg-type]
diff --git a/dbt/adapters/bigquery/jobs.py b/dbt/adapters/bigquery/jobs.py
new file mode 100644
index 000000000..d1abe5ad4
--- /dev/null
+++ b/dbt/adapters/bigquery/jobs.py
@@ -0,0 +1,3 @@
+def define_job_id(model_name, invocation_id):
+    job_id = f"{model_name}_{invocation_id}"
+    return job_id
diff --git a/tests/unit/test_bigquery_adapter.py b/tests/unit/test_bigquery_adapter.py
index 926547e10..3e16ecd82 100644
--- a/tests/unit/test_bigquery_adapter.py
+++ b/tests/unit/test_bigquery_adapter.py
@@ -21,7 +21,11 @@
 
 from google.cloud.bigquery import AccessEntry
 
-from .utils import config_from_parts_or_dicts, inject_adapter, TestAdapterConversions
+from .utils import (
+    config_from_parts_or_dicts,
+    inject_adapter,
+    TestAdapterConversions,
+)
 
 
 def _bq_conn():
diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py
index d6c3f64fc..3cc81607d 100644
--- a/tests/unit/test_bigquery_connection_manager.py
+++ b/tests/unit/test_bigquery_connection_manager.py
@@ -113,13 +113,14 @@ def test_query_and_results(self, mock_bq):
             self.mock_client,
             "sql",
             {"job_param_1": "blah"},
+            job_id=1,
             job_creation_timeout=15,
             job_execution_timeout=3,
         )
 
         mock_bq.QueryJobConfig.assert_called_once()
         self.mock_client.query.assert_called_once_with(
-            query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
+            query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
         )
 
     @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
@@ -127,18 +128,20 @@ def test_query_and_results_timeout(self, mock_bq):
         self.mock_client.query = Mock(
             return_value=Mock(result=lambda *args, **kwargs: time.sleep(4))
         )
+
         with pytest.raises(dbt.exceptions.DbtRuntimeError) as exc:
             self.connections._query_and_results(
                 self.mock_client,
                 "sql",
                 {"job_param_1": "blah"},
+                job_id=1,
                 job_creation_timeout=15,
                 job_execution_timeout=1,
             )
 
         mock_bq.QueryJobConfig.assert_called_once()
         self.mock_client.query.assert_called_once_with(
-            query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
+            query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
         )
         assert "Query exceeded configured timeout of 1s" in str(exc.value)
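
For reference, a minimal standalone sketch of the job-id scheme this patch introduces and the cancellation it enables. define_job_id mirrors the new dbt/adapters/bigquery/jobs.py helper; the model name and invocation id below are hypothetical placeholders, not values supplied by dbt, and the BigQuery calls use only the public Client.query and Client.cancel_job APIs that cancel_open relies on.

# Sketch only, not part of the patch. Assumes application-default credentials.
import uuid

from google.cloud import bigquery


def define_job_id(model_name, invocation_id):
    # Same formula as dbt/adapters/bigquery/jobs.py: "<model>_<invocation>".
    return f"{model_name}_{invocation_id}"


client = bigquery.Client()

model_name = "my_model"        # placeholder for the dbt connection name
invocation_id = uuid.uuid4()   # placeholder for dbt's invocation id
job_id = define_job_id(model_name, invocation_id)

# Because the job id is known before the query starts, another thread can
# address the same job later; that is what cancel_open() does per thread.
query_job = client.query("SELECT 1", job_id=job_id)

# Best-effort cancellation by id, matching the client.cancel_job(job_id)
# call wrapped in _retry_and_handle inside cancel_open().
client.cancel_job(job_id)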