
Commit

move polling from manual check in python_submissions module into retry_factory

mikealfare committed Nov 6, 2024
1 parent bc0fbea commit 9a9f87e
Showing 5 changed files with 111 additions and 171 deletions.
Empty file.
68 changes: 0 additions & 68 deletions dbt/adapters/bigquery/dataproc/batch.py

This file was deleted.

196 changes: 96 additions & 100 deletions dbt/adapters/bigquery/python_submissions.py
@@ -1,150 +1,131 @@
import uuid
from typing import Dict, Union
import uuid

from google.api_core import retry
from google.api_core.client_options import ClientOptions
from google.api_core.future.polling import POLLING_PREDICATE
from google.cloud import dataproc_v1
from google.cloud.dataproc_v1 import BatchControllerClient, JobControllerClient
from google.cloud.dataproc_v1.types import Batch, Job
from google.cloud.dataproc_v1 import (
Batch,
CreateBatchRequest,
GetBatchRequest,
Job,
RuntimeConfig,
)

from dbt.adapters.base import PythonJobHelper
from dbt.adapters.events.logging import AdapterLogger
from google.protobuf.json_format import ParseDict

from dbt.adapters.bigquery.credentials import BigQueryCredentials
from dbt.adapters.bigquery.credentials import BigQueryCredentials, DataprocBatchConfig
from dbt.adapters.bigquery.clients import (
batch_controller_client,
job_controller_client,
storage_client,
)
from dbt.adapters.bigquery.dataproc.batch import (
DEFAULT_JAR_FILE_URI,
create_batch_request,
poll_batch_job,
update_batch_from_config,
)
from dbt.adapters.bigquery.retry import RetryFactory


OPERATION_RETRY_TIME = 10
logger = AdapterLogger("BigQuery")
_logger = AdapterLogger("BigQuery")


_DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.13-0.34.0.jar"


class BaseDataProcHelper(PythonJobHelper):
def __init__(self, parsed_model: Dict, credential: BigQueryCredentials) -> None:
def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None:
"""_summary_
Args:
credential (_type_): _description_
credentials (_type_): _description_
"""
# validate all additional stuff for python is set
schema = parsed_model["schema"]
identifier = parsed_model["alias"]
self.parsed_model = parsed_model
python_required_configs = [
"dataproc_region",
"gcs_bucket",
]
for required_config in python_required_configs:
if not getattr(credential, required_config):
for required_config in ["dataproc_region", "gcs_bucket"]:
if not getattr(credentials, required_config):
raise ValueError(
f"Need to supply {required_config} in profile to submit python job"
)
self.model_file_name = f"{schema}/{identifier}.py"
self.credential = credential
self.storage_client = storage_client(self.credential)
self.gcs_location = "gs://{}/{}".format(self.credential.gcs_bucket, self.model_file_name)

self._storage_client = storage_client(credentials)
self._project = credentials.execution_project
self._region = credentials.dataproc_region

schema = parsed_model["schema"]
identifier = parsed_model["alias"]
self._model_file_name = f"{schema}/{identifier}.py"
self._gcs_bucket = credentials.gcs_bucket
self._gcs_path = f"gs://{credentials.gcs_bucket}/{self._model_file_name}"

# set retry policy, default to timeout after 24 hours
self.timeout = self.parsed_model["config"].get(
"timeout", self.credential.job_execution_timeout_seconds or 60 * 60 * 24
)
self.result_polling_policy = retry.Retry(
predicate=POLLING_PREDICATE, maximum=10.0, timeout=self.timeout
)
self.client_options = ClientOptions(
api_endpoint="{}-dataproc.googleapis.com:443".format(self.credential.dataproc_region)
)
self.job_client = self._get_job_client()
retry = RetryFactory(credentials)
timeout = parsed_model["config"].get("timeout")
self._polling_retry = retry.polling(timeout)

def _upload_to_gcs(self, filename: str, compiled_code: str) -> None:
bucket = self.storage_client.get_bucket(self.credential.gcs_bucket)
blob = bucket.blob(filename)
def _upload_to_gcs(self, compiled_code: str) -> None:
bucket = self._storage_client.get_bucket(self._gcs_bucket)
blob = bucket.blob(self._model_file_name)
blob.upload_from_string(compiled_code)

def submit(self, compiled_code: str) -> Job:
# upload python file to GCS
self._upload_to_gcs(self.model_file_name, compiled_code)
# submit dataproc job
self._upload_to_gcs(compiled_code)
return self._submit_dataproc_job()

def _get_job_client(
self,
) -> Union[JobControllerClient, BatchControllerClient]:
raise NotImplementedError("_get_job_client not implemented")

def _submit_dataproc_job(self) -> Job:
raise NotImplementedError("_submit_dataproc_job not implemented")


class ClusterDataprocHelper(BaseDataProcHelper):
def _get_job_client(self) -> JobControllerClient:
if not self._get_cluster_name():
def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None:
super().__init__(parsed_model, credentials)
self._job_controller_client = job_controller_client(credentials)
self._cluster_name = parsed_model["config"].get(
"dataproc_cluster_name", credentials.dataproc_cluster_name
)

if not self._cluster_name:
raise ValueError(
"Need to supply dataproc_cluster_name in profile or config to submit python job with cluster submission method"
)
return job_controller_client(self.credential)

def _get_cluster_name(self) -> str:
return self.parsed_model["config"].get(
"dataproc_cluster_name", self.credential.dataproc_cluster_name
)

def _submit_dataproc_job(self) -> Job:
job = {
"placement": {"cluster_name": self._get_cluster_name()},
"placement": {"cluster_name": self._cluster_name},
"pyspark_job": {
"main_python_file_uri": self.gcs_location,
"main_python_file_uri": self._gcs_path,
},
}
operation = self.job_client.submit_job_as_operation(
operation = self._job_controller_client.submit_job_as_operation(
request={
"project_id": self.credential.execution_project,
"region": self.credential.dataproc_region,
"project_id": self._project,
"region": self._region,
"job": job,
}
)
# check if job failed
response = operation.result(polling=self.result_polling_policy)
response = operation.result(polling=self._polling_retry)
if response.status.state == 6:
raise ValueError(response.status.details)
return response


class ServerlessDataProcHelper(BaseDataProcHelper):
def _get_job_client(self) -> BatchControllerClient:
return batch_controller_client(self.credential)

def _get_batch_id(self) -> str:
model = self.parsed_model
default_batch_id = str(uuid.uuid4())
return model["config"].get("batch_id", default_batch_id)
def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None:
super().__init__(parsed_model, credentials)
self._batch_controller_client = batch_controller_client(credentials)
self._batch_id = parsed_model["config"].get("batch_id", str(uuid.uuid4()))
self._jar_file_uri = parsed_model["config"].get("jar_file_uri", _DEFAULT_JAR_FILE_URI)
self._dataproc_batch = credentials.dataproc_batch

def _submit_dataproc_job(self) -> Batch:
batch_id = self._get_batch_id()
logger.info(f"Submitting batch job with id: {batch_id}")
request = create_batch_request(
batch=self._configure_batch(),
batch_id=batch_id,
region=self.credential.dataproc_region, # type: ignore
project=self.credential.execution_project, # type: ignore
)
_logger.info(f"Submitting batch job with id: {self._batch_id}")

# make the request
self.job_client.create_batch(request=request)
return poll_batch_job(
parent=request.parent,
batch_id=batch_id,
job_client=self.job_client,
timeout=self.timeout,
request = CreateBatchRequest(
parent=f"projects/{self._project}/locations/{self._region}",
batch=self._configure_batch(),
batch_id=self._batch_id,
)
self._batch_controller_client.create_batch(request=request)

# return the response
batch = GetBatchRequest(f"{request.parent}/batches/{self._batch_id}")
return self._batch_controller_client.get_batch(batch, retry=self._polling_retry)
# there might be useful results here that we can parse and return
# Dataproc job output is saved to the Cloud Storage bucket
# allocated to the job. Use regex to obtain the bucket and blob info.
@@ -156,30 +137,45 @@ def _submit_dataproc_job(self) -> Batch:
# .download_as_string()
# )

def _configure_batch(self):
def _configure_batch(self) -> Batch:
# create the Dataproc Serverless job config
# need to pin dataproc version to 1.1 as it now defaults to 2.0
# https://cloud.google.com/dataproc-serverless/docs/concepts/properties
# https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#runtimeconfig
batch = dataproc_v1.Batch(
batch = Batch(
{
"runtime_config": dataproc_v1.RuntimeConfig(
"runtime_config": RuntimeConfig(
version="1.1",
properties={
"spark.executor.instances": "2",
},
)
),
"pyspark_batch": {
"main_python_file_uri": self._gcs_path,
"jar_file_uris": [self._jar_file_uri],
},
}
)
# Apply defaults
batch.pyspark_batch.main_python_file_uri = self.gcs_location
jar_file_uri = self.parsed_model["config"].get(
"jar_file_uri",
DEFAULT_JAR_FILE_URI,
)
batch.pyspark_batch.jar_file_uris = [jar_file_uri]

# Apply configuration from dataproc_batch key, possibly overriding defaults.
if self.credential.dataproc_batch:
batch = update_batch_from_config(self.credential.dataproc_batch, batch)
if self._dataproc_batch:
batch = _update_batch_from_config(self._dataproc_batch, batch)

return batch


def _update_batch_from_config(
config_dict: Union[Dict, DataprocBatchConfig], target: Batch
) -> Batch:
try:
# updates in place
ParseDict(config_dict, target._pb)
except Exception as e:
docurl = (
"https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1"
"#google.cloud.dataproc.v1.Batch"
)
raise ValueError(
f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. {str(e)}"
) from e
return target
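
For reference, a minimal sketch (not part of the commit) of the merge behavior that the new module-level `_update_batch_from_config` relies on: `ParseDict` updates the underlying protobuf message in place, so profile-level `dataproc_batch` settings override the defaults built in `_configure_batch`. The bucket name, model path, and override values below are illustrative only.

```python
from google.cloud.dataproc_v1 import Batch, RuntimeConfig
from google.protobuf.json_format import ParseDict

# Defaults roughly matching what _configure_batch builds.
batch = Batch(
    {
        "runtime_config": RuntimeConfig(
            version="1.1",
            properties={"spark.executor.instances": "2"},
        ),
        "pyspark_batch": {"main_python_file_uri": "gs://my-bucket/my_schema/my_model.py"},
    }
)

# A hypothetical dataproc_batch mapping as it might appear in a profile.
overrides = {
    "runtime_config": {"properties": {"spark.executor.instances": "4"}},
    "labels": {"team": "analytics"},
}

# ParseDict merges into batch._pb in place; fields not named in the
# overrides keep the default values set above.
ParseDict(overrides, batch._pb)

assert batch.runtime_config.version == "1.1"
assert batch.runtime_config.properties["spark.executor.instances"] == "4"
assert batch.labels["team"] == "analytics"
```
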
14 changes: 13 additions & 1 deletion dbt/adapters/bigquery/retry.py
@@ -1,7 +1,8 @@
from typing import Callable
from typing import Callable, Optional

from google.api_core import retry
from google.api_core.exceptions import Forbidden
from google.api_core.future.polling import POLLING_PREDICATE
from google.cloud.exceptions import BadGateway, BadRequest, ServerError
from requests.exceptions import ConnectionError

@@ -75,6 +76,17 @@ def job_execution_capped(self, connection: Connection) -> retry.Retry:
on_error=_on_error(connection),
)

def polling(self, timeout: Optional[float] = None) -> retry.Retry:
"""
This strategy mimics what was accomplished with _retry_and_handle
"""
return retry.Retry(
predicate=POLLING_PREDICATE,
minimum=1.0,
maximum=10.0,
timeout=timeout or self.job_execution_timeout or 60 * 60 * 24,
)

def _buffered_predicate(self) -> Callable[[Exception], bool]:
class BufferedPredicate:
"""
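As a standalone illustration (the helper name and the 600-second timeout below are made up), the new `RetryFactory.polling()` strategy is simply a `retry.Retry` built on `POLLING_PREDICATE`, handed to `operation.result(polling=...)` or `get_batch(retry=...)` in place of the old manual polling loop:

```python
from typing import Optional

from google.api_core import retry
from google.api_core.future.polling import POLLING_PREDICATE


def make_polling_retry(
    timeout: Optional[float] = None,
    job_execution_timeout: Optional[float] = None,
) -> retry.Retry:
    # Mirrors RetryFactory.polling(): back off between 1 and 10 seconds
    # and give up after the configured timeout, defaulting to 24 hours.
    return retry.Retry(
        predicate=POLLING_PREDICATE,
        minimum=1.0,
        maximum=10.0,
        timeout=timeout or job_execution_timeout or 60 * 60 * 24,
    )


# Hypothetical usage, matching the call sites in python_submissions.py:
#   operation = job_controller_client.submit_job_as_operation(request=...)
#   response = operation.result(polling=make_polling_retry(timeout=600))
```
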
4 changes: 2 additions & 2 deletions tests/unit/test_configure_dataproc_batch.py
@@ -1,6 +1,6 @@
from unittest.mock import patch

from dbt.adapters.bigquery.dataproc.batch import update_batch_from_config
from dbt.adapters.bigquery.python_submissions import _update_batch_from_config
from google.cloud import dataproc_v1

from .test_bigquery_adapter import BaseTestBigQueryAdapter
@@ -39,7 +39,7 @@ def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults):

batch = dataproc_v1.Batch()

batch = update_batch_from_config(raw_batch_config, batch)
batch = _update_batch_from_config(raw_batch_config, batch)

def to_str_values(d):
"""google's protobuf types expose maps as dict[str, str]"""
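
A hypothetical companion test (not included in this commit) sketching how the relocated helper's error path could be covered; the test name and the bogus field are illustrative:

```python
import pytest
from google.cloud.dataproc_v1 import Batch

from dbt.adapters.bigquery.python_submissions import _update_batch_from_config


def test_invalid_dataproc_batch_raises_value_error():
    # An unknown field should fail ParseDict and surface as the ValueError
    # that points users at the Batch reference docs.
    with pytest.raises(ValueError, match="dataproc_batch"):
        _update_batch_from_config({"not_a_real_field": "x"}, Batch())
```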
