diff --git a/.changes/unreleased/Fixes-20230721-101041.yaml b/.changes/unreleased/Fixes-20230721-101041.yaml new file mode 100644 index 000000000..6db81cf50 --- /dev/null +++ b/.changes/unreleased/Fixes-20230721-101041.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Serverless Spark to Poll with .GetBatch() instead of using operation.result() +time: 2023-07-21T10:10:41.64843-07:00 +custom: + Author: wazi55 + Issue: "734" diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 99f78e33d..bb0211b35 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -64,6 +64,7 @@ jobs: outputs: matrix: ${{ steps.generate-matrix.outputs.result }} + run-python-tests: ${{ steps.filter.outputs.bigquery-python }} steps: - name: Check out the repository (non-PR) @@ -96,6 +97,11 @@ jobs: - 'dbt/**' - 'tests/**' - 'dev-requirements.txt' + bigquery-python: + - 'dbt/adapters/bigquery/dataproc/**' + - 'dbt/adapters/bigquery/python_submissions.py' + - 'dbt/include/bigquery/python_model/**' + - name: Generate integration test matrix id: generate-matrix uses: actions/github-script@v6 @@ -186,6 +192,21 @@ jobs: GCS_BUCKET: dbt-ci run: tox -- --ddtrace + # python models tests are slow so we only want to run them if we're changing them + - name: Run tox (python models) + if: needs.test-metadata.outputs.run-python-tests == 'true' + env: + BIGQUERY_TEST_SERVICE_ACCOUNT_JSON: ${{ secrets.BIGQUERY_TEST_SERVICE_ACCOUNT_JSON }} + BIGQUERY_TEST_ALT_DATABASE: ${{ secrets.BIGQUERY_TEST_ALT_DATABASE }} + BIGQUERY_TEST_NO_ACCESS_DATABASE: ${{ secrets.BIGQUERY_TEST_NO_ACCESS_DATABASE }} + DBT_TEST_USER_1: group:buildbot@dbtlabs.com + DBT_TEST_USER_2: group:engineering-core-team@dbtlabs.com + DBT_TEST_USER_3: serviceAccount:dbt-integration-test-user@dbt-test-env.iam.gserviceaccount.com + DATAPROC_REGION: us-central1 + DATAPROC_CLUSTER_NAME: dbt-test-1 + GCS_BUCKET: dbt-ci + run: tox -e python-tests -- --ddtrace + - uses: actions/upload-artifact@v3 
if: always() with:
diff --git a/dbt/adapters/bigquery/dataproc/__init__.py b/dbt/adapters/bigquery/dataproc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dbt/adapters/bigquery/dataproc/batch.py b/dbt/adapters/bigquery/dataproc/batch.py
new file mode 100644
index 000000000..0dc54aa78
--- /dev/null
+++ b/dbt/adapters/bigquery/dataproc/batch.py
@@ -0,0 +1,87 @@
+from typing import Union, Dict
+
+import time
+from datetime import datetime
+from google.cloud.dataproc_v1 import (
+    CreateBatchRequest,
+    BatchControllerClient,
+    Batch,
+    GetBatchRequest,
+)
+from google.protobuf.json_format import ParseDict
+
+from dbt.adapters.bigquery.connections import DataprocBatchConfig
+
+# States in which a batch is still in flight and polling must continue.
+_BATCH_RUNNING_STATES = [Batch.State.PENDING, Batch.State.RUNNING]
+DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.21.1.jar"
+
+
+def create_batch_request(
+    batch: Batch, batch_id: str, project: str, region: str
+) -> CreateBatchRequest:
+    """Build the request that creates ``batch`` as ``batch_id`` in ``project``/``region``."""
+    return CreateBatchRequest(
+        parent=f"projects/{project}/locations/{region}",  # type: ignore
+        batch_id=batch_id,  # type: ignore
+        batch=batch,  # type: ignore
+    )
+
+
+def poll_batch_job(
+    parent: str, batch_id: str, job_client: BatchControllerClient, timeout: int
+) -> Batch:
+    """Poll a batch, sleeping one second before each fetch, until it stops running.
+
+    Polling also stops once ``timeout`` seconds have passed since the batch's
+    ``create_time``, measured against the local clock.
+
+    Returns the last fetched ``Batch`` when its state is ``SUCCEEDED``.
+
+    Raises ``ValueError`` when nothing was fetched, when the batch is still running
+    at the timeout, or when it ended in any other state (with its ``state_message``).
+    """
+    batch_name = f"{parent}/batches/{batch_id}"
+    state = Batch.State.PENDING
+    response = None
+    run_time = 0
+    while state in _BATCH_RUNNING_STATES and run_time < timeout:
+        time.sleep(1)
+        response = job_client.get_batch(  # type: ignore
+            request=GetBatchRequest(name=batch_name),  # type: ignore
+        )
+        run_time = datetime.now().timestamp() - response.create_time.timestamp()  # type: ignore
+        state = response.state
+    # None means nothing was fetched, e.g. the loop never ran because timeout <= 0
+    if response is None:
+        raise ValueError("No response from Dataproc")
+    if state == Batch.State.SUCCEEDED:
+        return response
+    # only a batch that is still in flight has timed out; one that already left the
+    # running states ended on its own, so surface its state_message instead
+    if state in _BATCH_RUNNING_STATES:
+        raise ValueError(
+            f"Operation did not complete within the designated timeout of {timeout} seconds."
+        )
+    raise ValueError(response.state_message)
+
+
+def update_batch_from_config(
+    config_dict: Union[Dict, DataprocBatchConfig], target: Batch
+) -> Batch:
+    """Parse ``config_dict`` into ``target`` in place and return ``target``.
+
+    Raises ``ValueError``, chained to the underlying error, when parsing fails.
+    """
+    try:
+        # updates in place
+        ParseDict(config_dict, target._pb)
+    except Exception as e:
+        docurl = (
+            "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1"
+            "#google.cloud.dataproc.v1.Batch"
+        )
+        raise ValueError(
+            f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. {str(e)}"
+        ) from e
+    return target
diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 6e5a11e52..8fd354eb5 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -4,11 +4,17 @@ from google.api_core.future.polling import POLLING_PREDICATE from dbt.adapters.bigquery import BigQueryConnectionManager, BigQueryCredentials -from dbt.adapters.bigquery.connections import DataprocBatchConfig from google.api_core import retry from google.api_core.client_options import ClientOptions from google.cloud import storage, dataproc_v1 # type: ignore -from google.protobuf.json_format import ParseDict +from google.cloud.dataproc_v1.types.batches import Batch + +from dbt.adapters.bigquery.dataproc.batch import ( + create_batch_request, + poll_batch_job, + DEFAULT_JAR_FILE_URI, + update_batch_from_config, +) OPERATION_RETRY_TIME = 10 @@ -102,8 +108,8 @@ def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job: "job": job, } ) - response = operation.result(polling=self.result_polling_policy) # check if job failed + response = operation.result(polling=self.result_polling_policy) if response.status.state == 6: raise ValueError(response.status.details) return response @@ -118,21 +124,22 @@ def _get_job_client(self) -> dataproc_v1.BatchControllerClient: def _get_batch_id(self) -> str: return self.parsed_model["config"].get("batch_id") - def _submit_dataproc_job(self) -> 
dataproc_v1.types.jobs.Job: - batch = self._configure_batch() - parent = f"projects/{self.credential.execution_project}/locations/{self.credential.dataproc_region}" - - request = dataproc_v1.CreateBatchRequest( - parent=parent, - batch=batch, - batch_id=self._get_batch_id(), - ) + def _submit_dataproc_job(self) -> Batch: + batch_id = self._get_batch_id() + request = create_batch_request( + batch=self._configure_batch(), + batch_id=batch_id, + region=self.credential.dataproc_region, # type: ignore + project=self.credential.execution_project, # type: ignore + ) # type: ignore # make the request - operation = self.job_client.create_batch(request=request) # type: ignore - # this takes quite a while, waiting on GCP response to resolve - # (not a google-api-core issue, more likely a dataproc serverless issue) - response = operation.result(polling=self.result_polling_policy) - return response + self.job_client.create_batch(request=request) # type: ignore + return poll_batch_job( + parent=request.parent, + batch_id=batch_id, + job_client=self.job_client, # type: ignore + timeout=self.timeout, + ) # there might be useful results here that we can parse and return # Dataproc job output is saved to the Cloud Storage bucket # allocated to the job. Use regex to obtain the bucket and blob info. @@ -163,27 +170,11 @@ def _configure_batch(self): batch.pyspark_batch.main_python_file_uri = self.gcs_location jar_file_uri = self.parsed_model["config"].get( "jar_file_uri", - "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.21.1.jar", + DEFAULT_JAR_FILE_URI, ) batch.pyspark_batch.jar_file_uris = [jar_file_uri] # Apply configuration from dataproc_batch key, possibly overriding defaults. 
if self.credential.dataproc_batch: - self._update_batch_from_config(self.credential.dataproc_batch, batch) + batch = update_batch_from_config(self.credential.dataproc_batch, batch) return batch - - @classmethod - def _update_batch_from_config( - cls, config_dict: Union[Dict, DataprocBatchConfig], target: dataproc_v1.Batch - ): - try: - # updates in place - ParseDict(config_dict, target._pb) - except Exception as e: - docurl = ( - "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1" - "#google.cloud.dataproc.v1.Batch" - ) - raise ValueError( - f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. {str(e)}" - ) from e diff --git a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py index 241082cdb..b389fe8aa 100644 --- a/tests/functional/adapter/test_python_model.py +++ b/tests/functional/adapter/test_python_model.py @@ -66,7 +66,7 @@ def model(dbt, spark): """ models__python_array_batch_id_python = """ -import pandas +import pandas as pd def model(dbt, spark): random_array = [ diff --git a/tests/unit/test_configure_dataproc_batch.py b/tests/unit/test_configure_dataproc_batch.py index 58ff52bab..94cb28efb 100644 --- a/tests/unit/test_configure_dataproc_batch.py +++ b/tests/unit/test_configure_dataproc_batch.py @@ -1,6 +1,6 @@ from unittest.mock import patch -from dbt.adapters.bigquery.python_submissions import ServerlessDataProcHelper +from dbt.adapters.bigquery.dataproc.batch import update_batch_from_config from google.cloud import dataproc_v1 from .test_bigquery_adapter import BaseTestBigQueryAdapter @@ -39,7 +39,7 @@ def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults): batch = dataproc_v1.Batch() - ServerlessDataProcHelper._update_batch_from_config(raw_batch_config, batch) + batch = update_batch_from_config(raw_batch_config, batch) def to_str_values(d): """google's protobuf types expose maps as dict[str, str]"""