Add cancel (#1251)
* Add query cancellation.

* clean up merge + linting

* Add back mp_context

* generating a fresh job_id for every _query_and_results call

* add cancellation test

* add cancellation test

* add seed cancellation

* remove type ignore

* add changie

* use defaultdict to simplify code

---------

Co-authored-by: Daniel Cole <[email protected]>
Co-authored-by: Jeremy Cohen <[email protected]>
Co-authored-by: Colin Rogers <[email protected]>
Co-authored-by: Colin <[email protected]>
5 people authored Aug 1, 2024
1 parent 8c0a192 commit 3839953
Showing 6 changed files with 213 additions and 29 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240730-135911.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Add support for cancelling queries on keyboard interrupt
time: 2024-07-30T13:59:11.585452-07:00
custom:
  Author: d-cole MichelleArk colin-rogers-dbt
  Issue: "917"
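
How the feature surfaces to users: pressing Ctrl-C sends SIGINT, and dbt-core's runner asks the adapter whether cancellation is supported before attempting it. A minimal sketch of that hand-off, assuming (not verbatim from dbt-core) a runner hook named handle_keyboard_interrupt:

def handle_keyboard_interrupt(adapter):
    # BigQuery's is_cancelable() flips to True in this commit (see impl.py below),
    # which is what lets the runner call into the connection manager at all.
    if adapter.is_cancelable():
        # cancel_open_connections() yields the names of the connections whose
        # jobs were cancelled; dbt logs each one as "CANCEL query <name>",
        # which the functional test below asserts on.
        for name in adapter.cancel_open_connections():
            print(f"CANCEL query {name}")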
74 changes: 61 additions & 13 deletions dbt/adapters/bigquery/connections.py
@@ -1,17 +1,17 @@
+from collections import defaultdict
 from concurrent.futures import TimeoutError
 import json
 import re
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-
-from dbt_common.invocation import get_invocation_id
-
-from dbt_common.events.contextvars import get_node_info
+import uuid
 from mashumaro.helper import pass_through
 
 from functools import lru_cache
 from requests.exceptions import ConnectionError
-from typing import Optional, Any, Dict, Tuple, TYPE_CHECKING
 
+from multiprocessing.context import SpawnContext
+from typing import Optional, Any, Dict, Tuple, Hashable, List, TYPE_CHECKING
+
 import google.auth
 import google.auth.exceptions
@@ -24,19 +24,25 @@
     service_account as GoogleServiceAccountCredentials,
 )
 
-from dbt.adapters.bigquery import gcloud
-from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse, Credentials
+from dbt_common.events.contextvars import get_node_info
+from dbt_common.events.functions import fire_event
 from dbt_common.exceptions import (
     DbtRuntimeError,
     DbtConfigError,
+    DbtDatabaseError,
 )
+from dbt_common.invocation import get_invocation_id
+from dbt.adapters.bigquery import gcloud
+from dbt.adapters.contracts.connection import (
+    ConnectionState,
+    AdapterResponse,
+    Credentials,
+    AdapterRequiredConfig,
+)
-
-from dbt_common.exceptions import DbtDatabaseError
 from dbt.adapters.exceptions.connection import FailedToConnectError
 from dbt.adapters.base import BaseConnectionManager
 from dbt.adapters.events.logging import AdapterLogger
 from dbt.adapters.events.types import SQLQuery
-from dbt_common.events.functions import fire_event
 from dbt.adapters.bigquery import __version__ as dbt_version
 from dbt.adapters.bigquery.utility import is_base64, base64_to_string
 
@@ -231,6 +237,10 @@ class BigQueryConnectionManager(BaseConnectionManager):
     DEFAULT_INITIAL_DELAY = 1.0  # Seconds
     DEFAULT_MAXIMUM_DELAY = 3.0  # Seconds
 
+    def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
+        super().__init__(profile, mp_context)
+        self.jobs_by_thread: Dict[Hashable, List[str]] = defaultdict(list)
+
     @classmethod
     def handle_error(cls, error, message):
         error_msg = "\n".join([item["message"] for item in error.errors])
@@ -284,11 +294,31 @@ def exception_handler(self, sql):
             exc_message = exc_message.split(BQ_QUERY_JOB_SPLIT)[0].strip()
             raise DbtRuntimeError(exc_message)
 
-    def cancel_open(self) -> None:
-        pass
+    def cancel_open(self):
+        names = []
+        this_connection = self.get_if_exists()
+        with self.lock:
+            for thread_id, connection in self.thread_connections.items():
+                if connection is this_connection:
+                    continue
+                if connection.handle is not None and connection.state == ConnectionState.OPEN:
+                    client = connection.handle
+                    for job_id in self.jobs_by_thread.get(thread_id, []):
+
+                        def fn():
+                            return client.cancel_job(job_id)
+
+                        self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn)
+
+                    self.close(connection)
+
+                if connection.name is not None:
+                    names.append(connection.name)
+        return names
 
     @classmethod
     def close(cls, connection):
         connection.handle.close()
         connection.state = ConnectionState.CLOSED
 
         return connection
@@ -452,6 +482,18 @@ def get_labels_from_query_comment(cls):
 
         return {}
 
+    def generate_job_id(self) -> str:
+        # Generating a fresh job_id for every _query_and_results call to avoid job_id reuse.
+        # Generating a job id instead of persisting a BigQuery-generated one after client.query is called.
+        # Using BigQuery's job_id can lead to a race condition if a job has been started and a termination
+        # is sent before the job_id was stored, leading to a failure to cancel the job.
+        # By predetermining job_ids (uuid4), we can persist the job_id before the job has been kicked off.
+        # Doing this, the race condition only leads to attempting to cancel a job that doesn't exist.
+        job_id = str(uuid.uuid4())
+        thread_id = self.get_thread_identifier()
+        self.jobs_by_thread[thread_id].append(job_id)
+        return job_id
+
     def raw_execute(
         self,
         sql,
@@ -488,10 +530,13 @@ def raw_execute(
         job_execution_timeout = self.get_job_execution_timeout_seconds(conn)
 
         def fn():
+            job_id = self.generate_job_id()
+
             return self._query_and_results(
                 client,
                 sql,
                 job_params,
+                job_id,
                 job_creation_timeout=job_creation_timeout,
                 job_execution_timeout=job_execution_timeout,
                 limit=limit,
@@ -731,14 +776,17 @@ def _query_and_results(
         client,
         sql,
         job_params,
+        job_id,
         job_creation_timeout=None,
         job_execution_timeout=None,
         limit: Optional[int] = None,
     ):
         """Query the client and wait for results."""
         # Cannot reuse job_config if destination is set and ddl is used
         job_config = google.cloud.bigquery.QueryJobConfig(**job_params)
-        query_job = client.query(query=sql, job_config=job_config, timeout=job_creation_timeout)
+        query_job = client.query(
+            query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout
+        )
         if (
             query_job.location is not None
             and query_job.job_id is not None
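
Taken together, the connections.py changes implement a predetermine-then-cancel protocol. A condensed sketch of that flow against the real google-cloud-bigquery client API (the jobs_by_thread bookkeeping is simplified to a plain dict here; in the adapter it is a defaultdict keyed by thread identifier):

import uuid

from google.cloud import bigquery

client = bigquery.Client()

# Persist a client-generated job id *before* the job starts, so a concurrent
# cancel can never observe a running job whose id was not yet recorded.
job_id = str(uuid.uuid4())
jobs_by_thread = {"thread-1": [job_id]}

# Start the query under the predetermined id.
query_job = client.query("SELECT 1", job_id=job_id)

# From the cancelling thread, look up and cancel by id. Cancelling an id that
# never started merely raises NotFound, which is the benign side of the race
# described in generate_job_id().
for jid in jobs_by_thread["thread-1"]:
    client.cancel_job(jid)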
7 changes: 5 additions & 2 deletions dbt/adapters/bigquery/impl.py
@@ -164,7 +164,7 @@ def date_function(cls) -> str:
 
     @classmethod
     def is_cancelable(cls) -> bool:
-        return False
+        return True
 
     def drop_relation(self, relation: BigQueryRelation) -> None:
         is_cached = self._schema_is_cached(relation.database, relation.schema)
@@ -693,8 +693,11 @@ def load_dataframe(
         load_config.skip_leading_rows = 1
         load_config.schema = bq_schema
         load_config.field_delimiter = field_delimiter
+        job_id = self.connections.generate_job_id()
         with open(agate_table.original_abspath, "rb") as f:  # type: ignore
-            job = client.load_table_from_file(f, table_ref, rewind=True, job_config=load_config)
+            job = client.load_table_from_file(
+                f, table_ref, rewind=True, job_config=load_config, job_id=job_id
+            )
 
         timeout = self.connections.get_job_execution_timeout_seconds(conn) or 300
         with self.connections.exception_handler("LOAD TABLE"):
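
Seeds get the same treatment: load_dataframe now pre-generates the job id and passes it to the load job, so an in-flight seed is cancellable by the same id-based mechanism. A hedged sketch using the public client API (project, dataset, and file paths are illustrative):

from google.cloud import bigquery

client = bigquery.Client()
job_id = "predetermined-uuid"  # in dbt, this comes from connections.generate_job_id()
table_ref = bigquery.TableReference.from_string("my-project.my_dataset.seed")

with open("seeds/seed.csv", "rb") as f:
    job = client.load_table_from_file(
        f, table_ref, rewind=True, job_config=bigquery.LoadJobConfig(), job_id=job_id
    )

# Another thread holding the same job_id can now issue:
# client.cancel_job(job_id)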
127 changes: 127 additions & 0 deletions tests/functional/test_cancel.py
@@ -0,0 +1,127 @@
import time

import os
import signal
import subprocess

import pytest

from dbt.tests.util import get_connection

_SEED_CSV = """
id, name, astrological_sign, moral_alignment
1, Alice, Aries, Lawful Good
2, Bob, Taurus, Neutral Good
3, Thaddeus, Gemini, Chaotic Neutral
4, Zebulon, Cancer, Lawful Evil
5, Yorick, Leo, True Neutral
6, Xavier, Virgo, Chaotic Evil
7, Wanda, Libra, Lawful Neutral
"""

_LONG_RUNNING_MODEL_SQL = """
{{ config(materialized='table') }}
with array_1 as (
select generated_ids from UNNEST(GENERATE_ARRAY(1, 200000)) AS generated_ids
),
array_2 as (
select generated_ids from UNNEST(GENERATE_ARRAY(2, 200000)) AS generated_ids
)
SELECT array_1.generated_ids
FROM array_1
LEFT JOIN array_1 as jnd on 1=1
LEFT JOIN array_2 as jnd2 on 1=1
LEFT JOIN array_1 as jnd3 on jnd3.generated_ids >= jnd2.generated_ids
"""


def _get_info_schema_jobs_query(project_id, dataset_id, table_id):
    """
    Running this query requires roles/bigquery.resourceViewer on the project,
    see: https://cloud.google.com/bigquery/docs/information-schema-jobs#required_role
    :param project_id:
    :param dataset_id:
    :param table_id:
    :return: a single job id that matches the model we tried to create and was cancelled
    """
    return f"""
    SELECT job_id
    FROM `region-us`.`INFORMATION_SCHEMA.JOBS_BY_PROJECT`
    WHERE creation_time > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 5 HOUR)
    AND statement_type = 'CREATE_TABLE_AS_SELECT'
    AND state = 'DONE'
    AND job_id IS NOT NULL
    AND project_id = '{project_id}'
    AND error_result.reason = 'stopped'
    AND error_result.message = 'Job execution was cancelled: User requested cancellation'
    AND destination_table.table_id = '{table_id}'
    AND destination_table.dataset_id = '{dataset_id}'
    """


def _run_dbt_in_subprocess(project, dbt_command):
    os.chdir(project.project_root)
    run_dbt_process = subprocess.Popen(
        [
            "dbt",
            dbt_command,
            "--profiles-dir",
            project.profiles_dir,
            "--project-dir",
            project.project_root,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=False,
    )
    std_out_log = ""
    while True:
        std_out_line = run_dbt_process.stdout.readline().decode("utf-8")
        std_out_log += std_out_line
        if std_out_line != "":
            print(std_out_line)
            if "1 of 1 START" in std_out_line:
                time.sleep(1)
                run_dbt_process.send_signal(signal.SIGINT)

        if run_dbt_process.poll():
            break

    return std_out_log


def _get_job_id(project, table_name):
    # Because we run this in a subprocess we have to actually call Bigquery and look up the job id
    with get_connection(project.adapter):
        job_id = project.run_sql(
            _get_info_schema_jobs_query(project.database, project.test_schema, table_name)
        )

    return job_id


class TestBigqueryCancelsQueriesOnKeyboardInterrupt:
    @pytest.fixture(scope="class", autouse=True)
    def models(self):
        return {
            "model.sql": _LONG_RUNNING_MODEL_SQL,
        }

    @pytest.fixture(scope="class", autouse=True)
    def seeds(self):
        return {
            "seed.csv": _SEED_CSV,
        }

    def test_bigquery_cancels_queries_for_model_on_keyboard_interrupt(self, project):
        std_out_log = _run_dbt_in_subprocess(project, "run")

        assert "CANCEL query model.test.model" in std_out_log
        assert len(_get_job_id(project, "model")) == 1

    def test_bigquery_cancels_queries_for_seed_on_keyboard_interrupt(self, project):
        std_out_log = _run_dbt_in_subprocess(project, "seed")

        assert "CANCEL query seed.test.seed" in std_out_log
        # we can't assert the job id since we can't kill the seed process fast enough to cancel it
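
The verification step above leans on INFORMATION_SCHEMA rather than dbt's own state, since only BigQuery knows whether the job actually ended as user-cancelled. A condensed sketch of that check, reusing the module's query helper directly against the client (project and dataset names illustrative):

from google.cloud import bigquery

client = bigquery.Client()
rows = client.query(
    _get_info_schema_jobs_query("my-project", "my_dataset", "model")
).result()

# Exactly one CREATE_TABLE_AS_SELECT job for this model should have finished
# in state DONE with error_result.reason == 'stopped'.
assert len([row.job_id for row in rows]) == 1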
22 changes: 11 additions & 11 deletions tests/unit/test_bigquery_adapter.py
@@ -32,6 +32,7 @@
     inject_adapter,
     TestAdapterConversions,
     load_internal_manifest_macros,
+    mock_connection,
 )
 
@@ -368,23 +369,22 @@ def test_acquire_connection_maximum_bytes_billed(self, mock_open_connection):
 
     def test_cancel_open_connections_empty(self):
         adapter = self.get_adapter("oauth")
-        self.assertEqual(adapter.cancel_open_connections(), None)
+        self.assertEqual(len(list(adapter.cancel_open_connections())), 0)
 
     def test_cancel_open_connections_master(self):
         adapter = self.get_adapter("oauth")
-        adapter.connections.thread_connections[0] = object()
-        self.assertEqual(adapter.cancel_open_connections(), None)
+        key = adapter.connections.get_thread_identifier()
+        adapter.connections.thread_connections[key] = mock_connection("master")
+        self.assertEqual(len(list(adapter.cancel_open_connections())), 0)
 
     def test_cancel_open_connections_single(self):
         adapter = self.get_adapter("oauth")
-        adapter.connections.thread_connections.update(
-            {
-                0: object(),
-                1: object(),
-            }
-        )
-        # actually does nothing
-        self.assertEqual(adapter.cancel_open_connections(), None)
+        master = mock_connection("master")
+        model = mock_connection("model")
+        key = adapter.connections.get_thread_identifier()
+
+        adapter.connections.thread_connections.update({key: master, 1: model})
+        self.assertEqual(len(list(adapter.cancel_open_connections())), 1)
 
     @patch("dbt.adapters.bigquery.impl.google.auth.default")
     @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
6 changes: 3 additions & 3 deletions tests/unit/test_bigquery_connection_manager.py
@@ -104,18 +104,18 @@ def test_drop_dataset(self):
 
     @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
     def test_query_and_results(self, mock_bq):
         self.mock_client.query = Mock(return_value=Mock(state="DONE"))
         self.connections._query_and_results(
             self.mock_client,
             "sql",
             {"job_param_1": "blah"},
+            job_id=1,
             job_creation_timeout=15,
-            job_execution_timeout=3,
+            job_execution_timeout=100,
         )
 
         mock_bq.QueryJobConfig.assert_called_once()
         self.mock_client.query.assert_called_once_with(
-            query="sql", job_config=mock_bq.QueryJobConfig(), timeout=15
+            query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
         )
 
     def test_copy_bq_table_appends(self):
