From a2db35badbe8ab82a7fcf4e8a72a02269657fd36 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 6 Nov 2024 22:31:23 -0500 Subject: [PATCH] align new retries with original methods, simplify retry factory --- dbt/adapters/bigquery/connections.py | 27 ++- dbt/adapters/bigquery/python_submissions.py | 63 +++---- dbt/adapters/bigquery/retry.py | 157 +++++++----------- .../unit/test_bigquery_connection_manager.py | 16 +- 4 files changed, 102 insertions(+), 161 deletions(-) diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 856e69485..952a83b04 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -140,12 +140,12 @@ def cancel_open(self): continue if connection.handle is not None and connection.state == ConnectionState.OPEN: - client = connection.handle + client: Client = connection.handle for job_id in self.jobs_by_thread.get(thread_id, []): with self.exception_handler(f"Cancel job: {job_id}"): client.cancel_job( job_id, - retry=self._retry.deadline(connection), + retry=self._retry.reopen_with_deadline(connection), ) self.close(connection) @@ -444,9 +444,8 @@ def copy_bq_table(self, source, destination, write_disposition) -> None: source_ref_array, destination_ref, job_config=CopyJobConfig(write_disposition=write_disposition), - retry=self._retry.deadline(conn), ) - copy_job.result(retry=self._retry.job_execution_capped(conn)) + copy_job.result(timeout=self._retry.job_execution_timeout(300)) def load_dataframe( self, @@ -497,7 +496,7 @@ def _load_table_from_file( with open(file_path, "rb") as f: job = client.load_table_from_file(f, table, rewind=True, job_config=config) - if not job.done(retry=self._retry.polling_done(fallback_timeout=fallback_timeout)): + if not job.done(retry=self._retry.retry(fallback_timeout=fallback_timeout)): raise DbtRuntimeError("BigQuery Timeout Exceeded") elif job.error_result: @@ -522,7 +521,7 @@ def get_bq_table(self, database, schema, identifier) -> Table: schema = schema or conn.credentials.schema return client.get_table( table=self.table_ref(database, schema, identifier), - retry=self._retry.deadline(conn), + retry=self._retry.reopen_with_deadline(conn), ) def drop_dataset(self, database, schema) -> None: @@ -533,7 +532,7 @@ def drop_dataset(self, database, schema) -> None: dataset=self.dataset_ref(database, schema), delete_contents=True, not_found_ok=True, - retry=self._retry.deadline(conn), + retry=self._retry.reopen_with_deadline(conn), ) def create_dataset(self, database, schema) -> Dataset: @@ -543,7 +542,7 @@ def create_dataset(self, database, schema) -> Dataset: return client.create_dataset( dataset=self.dataset_ref(database, schema), exists_ok=True, - retry=self._retry.deadline(conn), + retry=self._retry.reopen_with_deadline(conn), ) def list_dataset(self, database: str): @@ -556,7 +555,7 @@ def list_dataset(self, database: str): all_datasets = client.list_datasets( project=database.strip("`"), max_results=10000, - retry=self._retry.deadline(conn), + retry=self._retry.reopen_with_deadline(conn), ) return [ds.dataset_id for ds in all_datasets] @@ -571,13 +570,11 @@ def _query_and_results( client: Client = conn.handle """Query the client and wait for results.""" # Cannot reuse job_config if destination is set and ddl is used - job_factory = QueryJobConfig - job_config = job_factory(**job_params) query_job = client.query( query=sql, - job_config=job_config, + job_config=QueryJobConfig(**job_params), job_id=job_id, # note, this disables retry since the job_id will have been used - timeout=self._retry.job_creation_timeout, + timeout=self._retry.job_creation_timeout(), ) if ( query_job.location is not None @@ -589,11 +586,11 @@ def _query_and_results( ) try: iterator = query_job.result( - max_results=limit, timeout=self._retry.job_execution_timeout + max_results=limit, timeout=self._retry.job_execution_timeout() ) return query_job, iterator except TimeoutError: - exc = f"Operation did not complete within the designated timeout of {self._retry.job_execution_timeout} seconds." + exc = f"Operation did not complete within the designated timeout of {self._retry.job_execution_timeout()} seconds." raise TimeoutError(exc) def _labels_from_query_comment(self, comment: str) -> Dict: diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 98a8bee25..7118a67cc 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -4,7 +4,6 @@ from google.cloud.dataproc_v1 import ( Batch, CreateBatchRequest, - GetBatchRequest, Job, RuntimeConfig, ) @@ -30,11 +29,6 @@ class BaseDataProcHelper(PythonJobHelper): def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: - """_summary_ - - Args: - credentials (_type_): _description_ - """ # validate all additional stuff for python is set for required_config in ["dataproc_region", "gcs_bucket"]: if not getattr(credentials, required_config): @@ -83,23 +77,26 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None ) def _submit_dataproc_job(self) -> Job: - job = { - "placement": {"cluster_name": self._cluster_name}, - "pyspark_job": { - "main_python_file_uri": self._gcs_path, + request = { + "project_id": self._project, + "region": self._region, + "job": { + "placement": {"cluster_name": self._cluster_name}, + "pyspark_job": { + "main_python_file_uri": self._gcs_path, + }, }, } - operation = self._job_controller_client.submit_job_as_operation( - request={ - "project_id": self._project, - "region": self._region, - "job": job, - } - ) - # check if job failed + + # submit the job + operation = self._job_controller_client.submit_job_as_operation(request) + + # wait for the job to complete response = operation.result(polling=self._polling_retry) + if response.status.state == 6: raise ValueError(response.status.details) + return response @@ -114,29 +111,21 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None def _submit_dataproc_job(self) -> Batch: _logger.info(f"Submitting batch job with id: {self._batch_id}") - # make the request request = CreateBatchRequest( parent=f"projects/{self._project}/locations/{self._region}", - batch=self._configure_batch(), + batch=self._batch(), batch_id=self._batch_id, ) - self._batch_controller_client.create_batch(request=request) - - # return the response - batch = GetBatchRequest(f"{request.parent}/batches/{self._batch_id}") - return self._batch_controller_client.get_batch(batch, retry=self._polling_retry) - # there might be useful results here that we can parse and return - # Dataproc job output is saved to the Cloud Storage bucket - # allocated to the job. Use regex to obtain the bucket and blob info. - # matches = re.match("gs://(.*?)/(.*)", response.driver_output_resource_uri) - # output = ( - # self.storage_client - # .get_bucket(matches.group(1)) - # .blob(f"{matches.group(2)}.000000000") - # .download_as_string() - # ) - - def _configure_batch(self) -> Batch: + + # submit the batch + operation = self._batch_controller_client.create_batch(request) + + # wait for the batch to complete + response = operation.result(polling=self._polling_retry) + + return response + + def _batch(self) -> Batch: # create the Dataproc Serverless job config # need to pin dataproc version to 1.1 as it now defaults to 2.0 # https://cloud.google.com/dataproc-serverless/docs/concepts/properties diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py index e5c6abbb2..0f1be6f81 100644 --- a/dbt/adapters/bigquery/retry.py +++ b/dbt/adapters/bigquery/retry.py @@ -1,8 +1,8 @@ from typing import Callable, Optional -from google.api_core.retry import Retry from google.api_core.exceptions import Forbidden -from google.api_core.future.polling import POLLING_PREDICATE +from google.api_core.future.polling import DEFAULT_POLLING +from google.api_core.retry import Retry from google.cloud.bigquery.retry import DEFAULT_RETRY from google.cloud.exceptions import BadGateway, BadRequest, ServerError from requests.exceptions import ConnectionError @@ -17,6 +17,14 @@ _logger = AdapterLogger("BigQuery") +_ONE_DAY = 60 * 60 * 24 # seconds + + +_DEFAULT_INITIAL_DELAY = 1.0 # seconds +_DEFAULT_MAXIMUM_DELAY = 3.0 # seconds +_DEFAULT_POLLING_MAXIMUM_DELAY = 10.0 # seconds + + _REOPENABLE_ERRORS = ( ConnectionResetError, ConnectionError, @@ -32,53 +40,24 @@ ) -_ONE_DAY = 60 * 60 * 24 - - class RetryFactory: - _DEFAULT_INITIAL_DELAY = 1.0 # seconds - _DEFAULT_MAXIMUM_DELAY = 3.0 # seconds - def __init__(self, credentials: BigQueryCredentials) -> None: self._retries = credentials.job_retries or 0 - self.job_creation_timeout = credentials.job_creation_timeout_seconds - self.job_execution_timeout = credentials.job_execution_timeout_seconds - self.job_deadline = credentials.job_retry_deadline_seconds + self._job_creation_timeout = credentials.job_creation_timeout_seconds + self._job_execution_timeout = credentials.job_execution_timeout_seconds + self._job_deadline = credentials.job_retry_deadline_seconds - def deadline(self, connection: Connection) -> Retry: - """ - This strategy mimics what was accomplished with _retry_and_handle - """ - return Retry( - predicate=self._buffered_predicate(), - initial=self._DEFAULT_INITIAL_DELAY, - maximum=self._DEFAULT_MAXIMUM_DELAY, - timeout=self.job_deadline, - on_error=_on_error(connection), - ) + def job_creation_timeout(self, fallback: Optional[float] = None) -> Optional[float]: + return self._job_creation_timeout or fallback or _ONE_DAY - def job_execution(self, connection: Connection) -> Retry: - """ - This strategy mimics what was accomplished with _retry_and_handle - """ - return Retry( - predicate=self._buffered_predicate(), - initial=self._DEFAULT_INITIAL_DELAY, - maximum=self._DEFAULT_MAXIMUM_DELAY, - timeout=self.job_execution_timeout, - on_error=_on_error(connection), - ) + def job_execution_timeout(self, fallback: Optional[float] = None) -> Optional[float]: + return self._job_execution_timeout or fallback or _ONE_DAY - def job_execution_capped(self, connection: Connection) -> Retry: - """ - This strategy mimics what was accomplished with _retry_and_handle - """ - return Retry( - predicate=self._buffered_predicate(), - timeout=self.job_execution_timeout or 300, - on_error=_on_error(connection), - ) + def retry( + self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None + ) -> Retry: + return DEFAULT_RETRY.with_timeout(timeout or self.job_execution_timeout(fallback_timeout)) def polling( self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None @@ -86,55 +65,53 @@ def polling( """ This strategy mimics what was accomplished with _retry_and_handle """ - return Retry( - predicate=POLLING_PREDICATE, - minimum=1.0, - maximum=10.0, - timeout=timeout or self.job_execution_timeout or fallback_timeout or _ONE_DAY, + return DEFAULT_POLLING.with_timeout( + timeout or self.job_execution_timeout(fallback_timeout) ) - def polling_done( - self, timeout: Optional[float] = None, fallback_timeout: Optional[float] = None - ) -> Retry: - return DEFAULT_RETRY.with_timeout( - timeout or self.job_execution_timeout or fallback_timeout or _ONE_DAY + def reopen_with_deadline(self, connection: Connection) -> Retry: + """ + This strategy mimics what was accomplished with _retry_and_handle + """ + return Retry( + predicate=_BufferedPredicate(self._retries), + initial=_DEFAULT_INITIAL_DELAY, + maximum=_DEFAULT_MAXIMUM_DELAY, + deadline=self._job_deadline, + on_error=_reopen_on_error(connection), ) - def _buffered_predicate(self) -> Callable[[Exception], bool]: - class BufferedPredicate: - """ - Count ALL errors, not just retryable errors, up to a threshold - then raises the next error, regardless of whether it is retryable. - - Was previously called _ErrorCounter. - """ - def __init__(self, retries: int) -> None: - self._retries: int = retries - self._error_count = 0 +class _BufferedPredicate: + """ + Count ALL errors, not just retryable errors, up to a threshold. + Raise the next error, regardless of whether it is retryable. + """ - def __call__(self, error: Exception) -> bool: - # exit immediately if the user does not want retries - if self._retries == 0: - return False + def __init__(self, retries: int) -> None: + self._retries: int = retries + self._error_count = 0 - # count all errors - self._error_count += 1 + def __call__(self, error: Exception) -> bool: + # exit immediately if the user does not want retries + if self._retries == 0: + return False - # if the error is retryable and we haven't breached the threshold, log and continue - if _is_retryable(error) and self._error_count <= self._retries: - _logger.debug( - f"Retry attempt {self._error_count} of { self._retries} after error: {repr(error)}" - ) - return True + # count all errors + self._error_count += 1 - # otherwise raise - return False + # if the error is retryable, and we haven't breached the threshold, log and continue + if _is_retryable(error) and self._error_count <= self._retries: + _logger.debug( + f"Retry attempt {self._error_count} of { self._retries} after error: {repr(error)}" + ) + return True - return BufferedPredicate(self._retries) + # otherwise raise + return False -def _on_error(connection: Connection) -> Callable[[Exception], None]: +def _reopen_on_error(connection: Connection) -> Callable[[Exception], None]: def on_error(error: Exception): if isinstance(error, _REOPENABLE_ERRORS): @@ -165,25 +142,3 @@ def _is_retryable(error: Exception) -> bool: ): return True return False - - -class _BufferedPredicate: - """Counts errors seen up to a threshold then raises the next error.""" - - def __init__(self, retries: int) -> None: - self._retries = retries - self._error_count = 0 - - def count_error(self, error): - if self._retries == 0: - return False # Don't log - self._error_count += 1 - if _is_retryable(error) and self._error_count <= self._retries: - _logger.debug( - "Retry attempt {} of {} after error: {}".format( - self._error_count, self._retries, repr(error) - ) - ) - return True - else: - return False diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 19e0e1ab4..6775445b9 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -17,7 +17,7 @@ def setUp(self): self.credentials = Mock(BigQueryCredentials) self.credentials.method = "oauth" self.credentials.job_retries = 1 - self.credentials.job_execution_timeout_seconds = 1 + self.credentials.job_retry_deadline_seconds = 1 self.credentials.scopes = tuple() self.mock_client = Mock(google.cloud.bigquery.Client) @@ -36,18 +36,21 @@ def setUp(self): "dbt.adapters.bigquery.retry.bigquery_client", return_value=Mock(google.cloud.bigquery.Client), ) - def test_retry_connection_reset(self, mock_bigquery_client): - original_handle = self.mock_connection.handle + def test_retry_connection_reset(self, mock_client_factory): + new_mock_client = mock_client_factory.return_value - @self.connections._retry.job_execution(self.mock_connection) + @self.connections._retry.reopen_with_deadline(self.mock_connection) def generate_connection_reset_error(): raise ConnectionResetError + assert self.mock_connection.handle is self.mock_client + with self.assertRaises(ConnectionResetError): # this will always raise the error, we just want to test that the connection was reopening in between generate_connection_reset_error() - assert not self.mock_connection.handle is original_handle + assert self.mock_connection.handle is new_mock_client + assert new_mock_client is not self.mock_client def test_is_retryable(self): _is_retryable = dbt.adapters.bigquery.retry._is_retryable @@ -98,12 +101,10 @@ def test_query_and_results(self, MockQueryJobConfig): def test_copy_bq_table_appends(self): self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND) - args, kwargs = self.mock_client.copy_table.call_args self.mock_client.copy_table.assert_called_once_with( [self._table_ref("project", "dataset", "table1")], self._table_ref("project", "dataset", "table2"), job_config=ANY, - retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual( @@ -117,7 +118,6 @@ def test_copy_bq_table_truncates(self): [self._table_ref("project", "dataset", "table1")], self._table_ref("project", "dataset", "table2"), job_config=ANY, - retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual(