From 9a9f87ec08eceaf7d13dfcf7a413835a243e209f Mon Sep 17 00:00:00 2001
From: Mike Alfare
Date: Wed, 6 Nov 2024 13:22:25 -0500
Subject: [PATCH] move polling from manual check in python_submissions module into retry_factory

---
 dbt/adapters/bigquery/dataproc/__init__.py  |   0
 dbt/adapters/bigquery/dataproc/batch.py     |  68 ------
 dbt/adapters/bigquery/python_submissions.py | 196 ++++++++++----------
 dbt/adapters/bigquery/retry.py              |  14 +-
 tests/unit/test_configure_dataproc_batch.py |   4 +-
 5 files changed, 111 insertions(+), 171 deletions(-)
 delete mode 100644 dbt/adapters/bigquery/dataproc/__init__.py
 delete mode 100644 dbt/adapters/bigquery/dataproc/batch.py

diff --git a/dbt/adapters/bigquery/dataproc/__init__.py b/dbt/adapters/bigquery/dataproc/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/dbt/adapters/bigquery/dataproc/batch.py b/dbt/adapters/bigquery/dataproc/batch.py
deleted file mode 100644
index 59f40d246..000000000
--- a/dbt/adapters/bigquery/dataproc/batch.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from datetime import datetime
-import time
-from typing import Dict, Union
-
-from google.cloud.dataproc_v1 import (
-    Batch,
-    BatchControllerClient,
-    CreateBatchRequest,
-    GetBatchRequest,
-)
-from google.protobuf.json_format import ParseDict
-
-from dbt.adapters.bigquery.credentials import DataprocBatchConfig
-
-
-_BATCH_RUNNING_STATES = [Batch.State.PENDING, Batch.State.RUNNING]
-DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.13-0.34.0.jar"
-
-
-def create_batch_request(
-    batch: Batch, batch_id: str, project: str, region: str
-) -> CreateBatchRequest:
-    return CreateBatchRequest(
-        parent=f"projects/{project}/locations/{region}",
-        batch_id=batch_id,
-        batch=batch,
-    )
-
-
-def poll_batch_job(
-    parent: str, batch_id: str, job_client: BatchControllerClient, timeout: int
-) -> Batch:
-    batch_name = "".join([parent, "/batches/", batch_id])
-    state = Batch.State.PENDING
-    response = None
-    run_time = 0
-    while state in _BATCH_RUNNING_STATES and run_time < timeout:
-        time.sleep(1)
-        response = job_client.get_batch(
-            request=GetBatchRequest(name=batch_name),
-        )
-        run_time = datetime.now().timestamp() - response.create_time.timestamp()
-        state = response.state
-    if not response:
-        raise ValueError("No response from Dataproc")
-    if state != Batch.State.SUCCEEDED:
-        if run_time >= timeout:
-            raise ValueError(
-                f"Operation did not complete within the designated timeout of {timeout} seconds."
-            )
-        else:
-            raise ValueError(response.state_message)
-    return response
-
-
-def update_batch_from_config(config_dict: Union[Dict, DataprocBatchConfig], target: Batch):
-    try:
-        # updates in place
-        ParseDict(config_dict, target._pb)
-    except Exception as e:
-        docurl = (
-            "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1"
-            "#google.cloud.dataproc.v1.Batch"
-        )
-        raise ValueError(
-            f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. \n {str(e)}"
-        ) from e
-    return target
diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py
index d5b7c1115..e76f6dc13 100644
--- a/dbt/adapters/bigquery/python_submissions.py
+++ b/dbt/adapters/bigquery/python_submissions.py
@@ -1,150 +1,131 @@
-import uuid
 from typing import Dict, Union
+import uuid
 
-from google.api_core import retry
-from google.api_core.client_options import ClientOptions
-from google.api_core.future.polling import POLLING_PREDICATE
-from google.cloud import dataproc_v1
-from google.cloud.dataproc_v1 import BatchControllerClient, JobControllerClient
-from google.cloud.dataproc_v1.types import Batch, Job
+from google.cloud.dataproc_v1 import (
+    Batch,
+    CreateBatchRequest,
+    GetBatchRequest,
+    Job,
+    RuntimeConfig,
+)
 
 from dbt.adapters.base import PythonJobHelper
 from dbt.adapters.events.logging import AdapterLogger
+from google.protobuf.json_format import ParseDict
 
-from dbt.adapters.bigquery.credentials import BigQueryCredentials
+from dbt.adapters.bigquery.credentials import BigQueryCredentials, DataprocBatchConfig
 from dbt.adapters.bigquery.clients import (
     batch_controller_client,
     job_controller_client,
     storage_client,
 )
-from dbt.adapters.bigquery.dataproc.batch import (
-    DEFAULT_JAR_FILE_URI,
-    create_batch_request,
-    poll_batch_job,
-    update_batch_from_config,
-)
+from dbt.adapters.bigquery.retry import RetryFactory
+
 
-OPERATION_RETRY_TIME = 10
-logger = AdapterLogger("BigQuery")
+_logger = AdapterLogger("BigQuery")
+
+
+_DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.13-0.34.0.jar"
 
 
 class BaseDataProcHelper(PythonJobHelper):
-    def __init__(self, parsed_model: Dict, credential: BigQueryCredentials) -> None:
+    def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None:
         """_summary_
 
         Args:
-            credential (_type_): _description_
+            credentials (_type_): _description_
         """
         # validate all additional stuff for python is set
-        schema = parsed_model["schema"]
-        identifier = parsed_model["alias"]
-        self.parsed_model = parsed_model
-        python_required_configs = [
-            "dataproc_region",
-            "gcs_bucket",
-        ]
-        for required_config in python_required_configs:
-            if not getattr(credential, required_config):
+        for required_config in ["dataproc_region", "gcs_bucket"]:
+            if not getattr(credentials, required_config):
                 raise ValueError(
                     f"Need to supply {required_config} in profile to submit python job"
                 )
-        self.model_file_name = f"{schema}/{identifier}.py"
-        self.credential = credential
-        self.storage_client = storage_client(self.credential)
-        self.gcs_location = "gs://{}/{}".format(self.credential.gcs_bucket, self.model_file_name)
+
+        self._storage_client = storage_client(credentials)
+        self._project = credentials.execution_project
+        self._region = credentials.dataproc_region
+
+        schema = parsed_model["schema"]
+        identifier = parsed_model["alias"]
+        self._model_file_name = f"{schema}/{identifier}.py"
+        self._gcs_bucket = credentials.gcs_bucket
+        self._gcs_path = f"gs://{credentials.gcs_bucket}/{self._model_file_name}"
 
         # set retry policy, default to timeout after 24 hours
-        self.timeout = self.parsed_model["config"].get(
-            "timeout", self.credential.job_execution_timeout_seconds or 60 * 60 * 24
-        )
-        self.result_polling_policy = retry.Retry(
-            predicate=POLLING_PREDICATE, maximum=10.0, timeout=self.timeout
-        )
-        self.client_options = ClientOptions(
-            api_endpoint="{}-dataproc.googleapis.com:443".format(self.credential.dataproc_region)
-        )
-        self.job_client = self._get_job_client()
+        retry = RetryFactory(credentials)
+        timeout = parsed_model["config"].get("timeout")
+        self._polling_retry = retry.polling(timeout)
 
-    def _upload_to_gcs(self, filename: str, compiled_code: str) -> None:
-        bucket = self.storage_client.get_bucket(self.credential.gcs_bucket)
-        blob = bucket.blob(filename)
+    def _upload_to_gcs(self, compiled_code: str) -> None:
+        bucket = self._storage_client.get_bucket(self._gcs_bucket)
+        blob = bucket.blob(self._model_file_name)
         blob.upload_from_string(compiled_code)
 
     def submit(self, compiled_code: str) -> Job:
-        # upload python file to GCS
-        self._upload_to_gcs(self.model_file_name, compiled_code)
-        # submit dataproc job
+        self._upload_to_gcs(compiled_code)
         return self._submit_dataproc_job()
 
-    def _get_job_client(
-        self,
-    ) -> Union[JobControllerClient, BatchControllerClient]:
-        raise NotImplementedError("_get_job_client not implemented")
-
     def _submit_dataproc_job(self) -> Job:
         raise NotImplementedError("_submit_dataproc_job not implemented")
 
 
 class ClusterDataprocHelper(BaseDataProcHelper):
-    def _get_job_client(self) -> JobControllerClient:
-        if not self._get_cluster_name():
+    def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None:
+        super().__init__(parsed_model, credentials)
+        self._job_controller_client = job_controller_client(credentials)
+        self._cluster_name = parsed_model["config"].get(
+            "dataproc_cluster_name", credentials.dataproc_cluster_name
+        )
+
+        if not self._cluster_name:
             raise ValueError(
                 "Need to supply dataproc_cluster_name in profile or config to submit python job with cluster submission method"
             )
-        return job_controller_client(self.credential)
-
-    def _get_cluster_name(self) -> str:
-        return self.parsed_model["config"].get(
-            "dataproc_cluster_name", self.credential.dataproc_cluster_name
-        )
 
     def _submit_dataproc_job(self) -> Job:
         job = {
-            "placement": {"cluster_name": self._get_cluster_name()},
+            "placement": {"cluster_name": self._cluster_name},
             "pyspark_job": {
-                "main_python_file_uri": self.gcs_location,
+                "main_python_file_uri": self._gcs_path,
             },
         }
-        operation = self.job_client.submit_job_as_operation(
+        operation = self._job_controller_client.submit_job_as_operation(
             request={
-                "project_id": self.credential.execution_project,
-                "region": self.credential.dataproc_region,
+                "project_id": self._project,
+                "region": self._region,
                 "job": job,
             }
         )
         # check if job failed
-        response = operation.result(polling=self.result_polling_policy)
+        response = operation.result(polling=self._polling_retry)
         if response.status.state == 6:
             raise ValueError(response.status.details)
         return response
 
 
 class ServerlessDataProcHelper(BaseDataProcHelper):
-    def _get_job_client(self) -> BatchControllerClient:
-        return batch_controller_client(self.credential)
-
-    def _get_batch_id(self) -> str:
-        model = self.parsed_model
-        default_batch_id = str(uuid.uuid4())
-        return model["config"].get("batch_id", default_batch_id)
+    def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None:
+        super().__init__(parsed_model, credentials)
+        self._batch_controller_client = batch_controller_client(credentials)
+        self._batch_id = parsed_model["config"].get("batch_id", str(uuid.uuid4()))
+        self._jar_file_uri = parsed_model["config"].get("jar_file_uri", _DEFAULT_JAR_FILE_URI)
+        self._dataproc_batch = credentials.dataproc_batch
 
     def _submit_dataproc_job(self) -> Batch:
-        batch_id = self._get_batch_id()
-        logger.info(f"Submitting batch job with id: {batch_id}")
-        request = create_batch_request(
-            batch=self._configure_batch(),
-            batch_id=batch_id,
-            region=self.credential.dataproc_region,  # type: ignore
-            project=self.credential.execution_project,  # type: ignore
-        )
+        _logger.info(f"Submitting batch job with id: {self._batch_id}")
+
         # make the request
-        self.job_client.create_batch(request=request)
-        return poll_batch_job(
-            parent=request.parent,
-            batch_id=batch_id,
-            job_client=self.job_client,
-            timeout=self.timeout,
+        request = CreateBatchRequest(
+            parent=f"projects/{self._project}/locations/{self._region}",
+            batch=self._configure_batch(),
+            batch_id=self._batch_id,
         )
+        self._batch_controller_client.create_batch(request=request)
+
+        # return the response
+        batch = GetBatchRequest(f"{request.parent}/batches/{self._batch_id}")
+        return self._batch_controller_client.get_batch(batch, retry=self._polling_retry)
         # there might be useful results here that we can parse and return
         # Dataproc job output is saved to the Cloud Storage bucket
         # allocated to the job. Use regex to obtain the bucket and blob info.
@@ -156,30 +137,45 @@ def _submit_dataproc_job(self) -> Batch:
         #     .download_as_string()
         # )
 
-    def _configure_batch(self):
+    def _configure_batch(self) -> Batch:
         # create the Dataproc Serverless job config
         # need to pin dataproc version to 1.1 as it now defaults to 2.0
         # https://cloud.google.com/dataproc-serverless/docs/concepts/properties
        # https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#runtimeconfig
-        batch = dataproc_v1.Batch(
+        batch = Batch(
             {
-                "runtime_config": dataproc_v1.RuntimeConfig(
+                "runtime_config": RuntimeConfig(
                     version="1.1",
                     properties={
                         "spark.executor.instances": "2",
                     },
-                )
+                ),
+                "pyspark_batch": {
+                    "main_python_file_uri": self._gcs_path,
+                    "jar_file_uris": [self._jar_file_uri],
+                },
             }
         )
 
-        # Apply defaults
-        batch.pyspark_batch.main_python_file_uri = self.gcs_location
-        jar_file_uri = self.parsed_model["config"].get(
-            "jar_file_uri",
-            DEFAULT_JAR_FILE_URI,
-        )
-        batch.pyspark_batch.jar_file_uris = [jar_file_uri]
-
         # Apply configuration from dataproc_batch key, possibly overriding defaults.
-        if self.credential.dataproc_batch:
-            batch = update_batch_from_config(self.credential.dataproc_batch, batch)
+        if self._dataproc_batch:
+            batch = _update_batch_from_config(self._dataproc_batch, batch)
+
         return batch
+
+
+def _update_batch_from_config(
+    config_dict: Union[Dict, DataprocBatchConfig], target: Batch
+) -> Batch:
+    try:
+        # updates in place
+        ParseDict(config_dict, target._pb)
+    except Exception as e:
+        docurl = (
+            "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1"
+            "#google.cloud.dataproc.v1.Batch"
+        )
+        raise ValueError(
+            f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. \n {str(e)}"
+        ) from e
+    return target
diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py
index 0b1541805..4143c54e5 100644
--- a/dbt/adapters/bigquery/retry.py
+++ b/dbt/adapters/bigquery/retry.py
@@ -1,7 +1,8 @@
-from typing import Callable
+from typing import Callable, Optional
 
 from google.api_core import retry
 from google.api_core.exceptions import Forbidden
+from google.api_core.future.polling import POLLING_PREDICATE
 from google.cloud.exceptions import BadGateway, BadRequest, ServerError
 from requests.exceptions import ConnectionError
 
@@ -75,6 +76,17 @@ def job_execution_capped(self, connection: Connection) -> retry.Retry:
             on_error=_on_error(connection),
         )
 
+    def polling(self, timeout: Optional[float] = None) -> retry.Retry:
+        """
+        This strategy mimics what was accomplished with _retry_and_handle
+        """
+        return retry.Retry(
+            predicate=POLLING_PREDICATE,
+            minimum=1.0,
+            maximum=10.0,
+            timeout=timeout or self.job_execution_timeout or 60 * 60 * 24,
+        )
+
     def _buffered_predicate(self) -> Callable[[Exception], bool]:
         class BufferedPredicate:
             """
diff --git a/tests/unit/test_configure_dataproc_batch.py b/tests/unit/test_configure_dataproc_batch.py
index 19a0d3012..e73e5b845 100644
--- a/tests/unit/test_configure_dataproc_batch.py
+++ b/tests/unit/test_configure_dataproc_batch.py
@@ -1,6 +1,6 @@
 from unittest.mock import patch
 
-from dbt.adapters.bigquery.dataproc.batch import update_batch_from_config
+from dbt.adapters.bigquery.python_submissions import _update_batch_from_config
 from google.cloud import dataproc_v1
 
 from .test_bigquery_adapter import BaseTestBigQueryAdapter
@@ -39,7 +39,7 @@ def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults):
 
         batch = dataproc_v1.Batch()
 
-        batch = update_batch_from_config(raw_batch_config, batch)
+        batch = _update_batch_from_config(raw_batch_config, batch)
 
         def to_str_values(d):
             """google's protobuf types expose maps as dict[str, str]"""