Skip to content

Commit

Permalink
initial work on bigquery canceling a retry if a job did in fact succe…
Browse files Browse the repository at this point in the history
…ed in previous run
  • Loading branch information
McKnight-42 committed Oct 23, 2023
1 parent bfc5dc4 commit a4e2b74
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 9 deletions.
46 changes: 41 additions & 5 deletions dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from functools import lru_cache
import agate
from requests.exceptions import ConnectionError
from typing import Optional, Any, Dict, Tuple
from typing import Optional, Any, Dict, List, Tuple

import google.auth
import google.auth.exceptions
Expand All @@ -25,10 +25,11 @@
)

from dbt.adapters.bigquery import gcloud
from dbt.adapters.bigquery.jobs import define_job_id
from dbt.clients import agate_helper
from dbt.config.profile import INVALID_PROFILE_MESSAGE
from dbt.tracking import active_user
from dbt.contracts.connection import ConnectionState, AdapterResponse
from dbt.contracts.connection import ConnectionState, AdapterResponse, AdapterRequiredConfig
from dbt.exceptions import (
FailedToConnectError,
DbtRuntimeError,
Expand Down Expand Up @@ -229,6 +230,10 @@ class BigQueryConnectionManager(BaseConnectionManager):
DEFAULT_INITIAL_DELAY = 1.0 # Seconds
DEFAULT_MAXIMUM_DELAY = 3.0 # Seconds

def __init__(self, profile: AdapterRequiredConfig):
super().__init__(profile)
self.jobs_by_thread: Dict[Any, Any] = {}

@classmethod
def handle_error(cls, error, message):
error_msg = "\n".join([item["message"] for item in error.errors])
Expand Down Expand Up @@ -282,8 +287,28 @@ def exception_handler(self, sql):
exc_message = exc_message.split(BQ_QUERY_JOB_SPLIT)[0].strip()
raise DbtRuntimeError(exc_message)

def cancel_open(self) -> None:
pass
def cancel_open(self) -> List[str]:
names = []
this_connection = self.get_if_exists()
with self.lock:
for thread_id, connection in self.thread_connections.items():
if connection is this_connection:
continue

if connection.handle is not None and connection.state == ConnectionState.OPEN:
client = connection.handle
for job_id in self.jobs_by_thread.get(thread_id, []):

def fn():
return client.cancel_job(job_id)

self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn)

self.close(connection)

if connection.name is not None:
names.append(connection.name)
return names

@classmethod
def close(cls, connection):
Expand Down Expand Up @@ -481,18 +506,26 @@ def raw_execute(

job_creation_timeout = self.get_job_creation_timeout_seconds(conn)
job_execution_timeout = self.get_job_execution_timeout_seconds(conn)
# build out determinsitic_id
model_name = client.connection.name
invocation_id = client.connection.clent_info.invocation_id
job_id = define_job_id(model_name, invocation_id)
thread_id = self.get_thread_identifier()
self.jobs_by_thread[thread_id] = self.jobs_by_thread.get(thread_id, []) + [job_id]

def fn():
return self._query_and_results(
client,
sql,
job_params,
job_id,
job_creation_timeout=job_creation_timeout,
job_execution_timeout=job_execution_timeout,
limit=limit,
)

query_job, iterator = self._retry_and_handle(msg=sql, conn=conn, fn=fn)
self.jobs_by_thread.get(thread_id, []).remove(job_id)

return query_job, iterator

Expand Down Expand Up @@ -724,14 +757,17 @@ def _query_and_results(
client,
sql,
job_params,
job_id,
job_creation_timeout=None,
job_execution_timeout=None,
limit: Optional[int] = None,
):
"""Query the client and wait for results."""
# Cannot reuse job_config if destination is set and ddl is used
job_config = google.cloud.bigquery.QueryJobConfig(**job_params)
query_job = client.query(query=sql, job_config=job_config, timeout=job_creation_timeout)
query_job = client.query(
query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout
)
if (
query_job.location is not None
and query_job.job_id is not None
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/bigquery/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def date_function(cls) -> str:

@classmethod
def is_cancelable(cls) -> bool:
return False
return True

def drop_relation(self, relation: BigQueryRelation) -> None:
is_cached = self._schema_is_cached(relation.database, relation.schema) # type: ignore[arg-type]
Expand Down
3 changes: 3 additions & 0 deletions dbt/adapters/bigquery/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def define_job_id(model_name, invocation_id):
job_id = f"{model_name}_{invocation_id}"
return job_id
6 changes: 5 additions & 1 deletion tests/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@

from google.cloud.bigquery import AccessEntry

from .utils import config_from_parts_or_dicts, inject_adapter, TestAdapterConversions
from .utils import (
config_from_parts_or_dicts,
inject_adapter,
TestAdapterConversions,
)


def _bq_conn():
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/test_bigquery_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,32 +113,35 @@ def test_query_and_results(self, mock_bq):
self.mock_client,
"sql",
{"job_param_1": "blah"},
job_id=1,
job_creation_timeout=15,
job_execution_timeout=3,
)

mock_bq.QueryJobConfig.assert_called_once()
self.mock_client.query.assert_called_once_with(
query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
)

@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
def test_query_and_results_timeout(self, mock_bq):
self.mock_client.query = Mock(
return_value=Mock(result=lambda *args, **kwargs: time.sleep(4))
)

with pytest.raises(dbt.exceptions.DbtRuntimeError) as exc:
self.connections._query_and_results(
self.mock_client,
"sql",
{"job_param_1": "blah"},
job_id=1,
job_creation_timeout=15,
job_execution_timeout=1,
)

mock_bq.QueryJobConfig.assert_called_once()
self.mock_client.query.assert_called_once_with(
query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
)
assert "Query exceeded configured timeout of 1s" in str(exc.value)

Expand Down

0 comments on commit a4e2b74

Please sign in to comment.