align new retries with original methods, simplify retry factory
mikealfare committed Nov 7, 2024
1 parent e90c24d commit a2db35b
Showing 4 changed files with 102 additions and 161 deletions.
27 changes: 12 additions & 15 deletions dbt/adapters/bigquery/connections.py
@@ -140,12 +140,12 @@ def cancel_open(self):
                     continue
 
                 if connection.handle is not None and connection.state == ConnectionState.OPEN:
-                    client = connection.handle
+                    client: Client = connection.handle
                     for job_id in self.jobs_by_thread.get(thread_id, []):
                         with self.exception_handler(f"Cancel job: {job_id}"):
                             client.cancel_job(
                                 job_id,
-                                retry=self._retry.deadline(connection),
+                                retry=self._retry.reopen_with_deadline(connection),
                             )
                     self.close(connection)
 
@@ -444,9 +444,8 @@ def copy_bq_table(self, source, destination, write_disposition) -> None:
                 source_ref_array,
                 destination_ref,
                 job_config=CopyJobConfig(write_disposition=write_disposition),
-                retry=self._retry.deadline(conn),
             )
-            copy_job.result(retry=self._retry.job_execution_capped(conn))
+            copy_job.result(timeout=self._retry.job_execution_timeout(300))
 
     def load_dataframe(
         self,
@@ -497,7 +496,7 @@ def _load_table_from_file(
         with open(file_path, "rb") as f:
             job = client.load_table_from_file(f, table, rewind=True, job_config=config)
 
-        if not job.done(retry=self._retry.polling_done(fallback_timeout=fallback_timeout)):
+        if not job.done(retry=self._retry.retry(fallback_timeout=fallback_timeout)):
             raise DbtRuntimeError("BigQuery Timeout Exceeded")
 
         elif job.error_result:
@@ -522,7 +521,7 @@ def get_bq_table(self, database, schema, identifier) -> Table:
         schema = schema or conn.credentials.schema
         return client.get_table(
             table=self.table_ref(database, schema, identifier),
-            retry=self._retry.deadline(conn),
+            retry=self._retry.reopen_with_deadline(conn),
         )
 
     def drop_dataset(self, database, schema) -> None:
@@ -533,7 +532,7 @@ def drop_dataset(self, database, schema) -> None:
             dataset=self.dataset_ref(database, schema),
             delete_contents=True,
             not_found_ok=True,
-            retry=self._retry.deadline(conn),
+            retry=self._retry.reopen_with_deadline(conn),
         )
 
     def create_dataset(self, database, schema) -> Dataset:
@@ -543,7 +542,7 @@ def create_dataset(self, database, schema) -> Dataset:
         return client.create_dataset(
             dataset=self.dataset_ref(database, schema),
             exists_ok=True,
-            retry=self._retry.deadline(conn),
+            retry=self._retry.reopen_with_deadline(conn),
         )
 
     def list_dataset(self, database: str):
@@ -556,7 +555,7 @@ def list_dataset(self, database: str):
         all_datasets = client.list_datasets(
             project=database.strip("`"),
             max_results=10000,
-            retry=self._retry.deadline(conn),
+            retry=self._retry.reopen_with_deadline(conn),
         )
         return [ds.dataset_id for ds in all_datasets]
 
@@ -571,13 +570,11 @@ def _query_and_results(
         client: Client = conn.handle
         """Query the client and wait for results."""
         # Cannot reuse job_config if destination is set and ddl is used
-        job_factory = QueryJobConfig
-        job_config = job_factory(**job_params)
         query_job = client.query(
             query=sql,
-            job_config=job_config,
+            job_config=QueryJobConfig(**job_params),
             job_id=job_id,  # note, this disables retry since the job_id will have been used
-            timeout=self._retry.job_creation_timeout,
+            timeout=self._retry.job_creation_timeout(),
         )
         if (
             query_job.location is not None
@@ -589,11 +586,11 @@ def _query_and_results(
             )
         try:
             iterator = query_job.result(
-                max_results=limit, timeout=self._retry.job_execution_timeout
+                max_results=limit, timeout=self._retry.job_execution_timeout()
            )
            return query_job, iterator
        except TimeoutError:
-            exc = f"Operation did not complete within the designated timeout of {self._retry.job_execution_timeout} seconds."
+            exc = f"Operation did not complete within the designated timeout of {self._retry.job_execution_timeout()} seconds."
            raise TimeoutError(exc)
 
    def _labels_from_query_comment(self, comment: str) -> Dict:
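Taken together, the connections.py changes route every client call through one retry factory: `self._retry.deadline(conn)` becomes `self._retry.reopen_with_deadline(conn)`, the one-off helpers (`job_execution_capped`, `polling_done`) are replaced by `retry(fallback_timeout=...)` and plain timeouts, and `job_creation_timeout` / `job_execution_timeout` are now called as methods that take a fallback rather than read as attributes. The factory itself (dbt/adapters/bigquery/retry.py) is not part of this rendered diff, so the following is only a rough sketch of an interface that would satisfy these call sites; the class name, credential fields, defaults, and retry predicate are assumptions, not code from this commit.

# Hypothetical sketch only -- the real factory is not shown in the diff above.
from typing import Optional

from google.api_core.retry import Retry, if_exception_type

_ONE_DAY = 60 * 60 * 24  # assumed fallback for job execution


class RetryFactory:
    def __init__(self, credentials) -> None:
        # assumed to come from profile configuration
        self._job_creation_timeout: Optional[float] = credentials.job_creation_timeout_seconds
        self._job_execution_timeout: Optional[float] = credentials.job_execution_timeout_seconds
        self._job_deadline: Optional[float] = credentials.job_retry_deadline_seconds

    def job_creation_timeout(self, fallback: float = 60) -> float:
        # call site: client.query(..., timeout=self._retry.job_creation_timeout())
        return self._job_creation_timeout or fallback

    def job_execution_timeout(self, fallback: float = _ONE_DAY) -> float:
        # call site: copy_job.result(timeout=self._retry.job_execution_timeout(300))
        return self._job_execution_timeout or fallback

    def retry(self, fallback_timeout: Optional[float] = None) -> Retry:
        # call site: job.done(retry=self._retry.retry(fallback_timeout=...))
        return Retry(
            predicate=if_exception_type(ConnectionError),  # assumed retryable errors
            timeout=self.job_execution_timeout(fallback_timeout or 300),
        )

    def reopen_with_deadline(self, connection) -> Retry:
        # call site: client.get_table(..., retry=self._retry.reopen_with_deadline(conn));
        # the name suggests the real version also reopens `connection` between attempts,
        # which is omitted in this sketch
        return Retry(
            predicate=if_exception_type(ConnectionError),  # assumed retryable errors
            timeout=self._job_deadline or _ONE_DAY,
        )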
63 changes: 26 additions & 37 deletions dbt/adapters/bigquery/python_submissions.py
@@ -4,7 +4,6 @@
 from google.cloud.dataproc_v1 import (
     Batch,
     CreateBatchRequest,
-    GetBatchRequest,
     Job,
     RuntimeConfig,
 )
@@ -30,11 +29,6 @@
 
 class BaseDataProcHelper(PythonJobHelper):
     def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None:
-        """_summary_
-
-        Args:
-            credentials (_type_): _description_
-        """
         # validate all additional stuff for python is set
         for required_config in ["dataproc_region", "gcs_bucket"]:
             if not getattr(credentials, required_config):
@@ -83,23 +77,26 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None
         )
 
     def _submit_dataproc_job(self) -> Job:
-        job = {
-            "placement": {"cluster_name": self._cluster_name},
-            "pyspark_job": {
-                "main_python_file_uri": self._gcs_path,
+        request = {
+            "project_id": self._project,
+            "region": self._region,
+            "job": {
+                "placement": {"cluster_name": self._cluster_name},
+                "pyspark_job": {
+                    "main_python_file_uri": self._gcs_path,
+                },
             },
         }
-        operation = self._job_controller_client.submit_job_as_operation(
-            request={
-                "project_id": self._project,
-                "region": self._region,
-                "job": job,
-            }
-        )
-        # check if job failed
+
+        # submit the job
+        operation = self._job_controller_client.submit_job_as_operation(request)
+
+        # wait for the job to complete
         response = operation.result(polling=self._polling_retry)
+
         if response.status.state == 6:
             raise ValueError(response.status.details)
+
         return response
 
 
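The rewritten cluster path above builds one request dict, submits it with `submit_job_as_operation`, and waits on the returned long-running operation with a polling retry instead of checking the job separately. Below is a self-contained sketch of that submit-and-poll pattern against the Dataproc JobControllerClient; the endpoint, project, region, cluster, and GCS path are placeholders, and using `JobStatus.State.ERROR` for the literal `6` is an assumption about the enum, not something stated in the commit.

# Standalone sketch of the submit-and-poll pattern (all identifiers are placeholders).
from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import JobControllerClient, JobStatus

client = JobControllerClient(
    client_options={"api_endpoint": "us-central1-dataproc.googleapis.com:443"}
)

request = {
    "project_id": "my-project",
    "region": "us-central1",
    "job": {
        "placement": {"cluster_name": "my-cluster"},
        "pyspark_job": {"main_python_file_uri": "gs://my-bucket/model.py"},
    },
}

# submit_job_as_operation returns a long-running Operation rather than a finished Job
operation = client.submit_job_as_operation(request)

# poll until the job reaches a terminal state, or the polling retry gives up
response = operation.result(polling=Retry(timeout=60 * 60 * 24))

# the diff's `response.status.state == 6` appears to be this enum member
if response.status.state == JobStatus.State.ERROR:
    raise ValueError(response.status.details)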
@@ -114,29 +111,21 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None
     def _submit_dataproc_job(self) -> Batch:
         _logger.info(f"Submitting batch job with id: {self._batch_id}")
 
-        # make the request
         request = CreateBatchRequest(
             parent=f"projects/{self._project}/locations/{self._region}",
-            batch=self._configure_batch(),
+            batch=self._batch(),
             batch_id=self._batch_id,
         )
-        self._batch_controller_client.create_batch(request=request)
-
-        # return the response
-        batch = GetBatchRequest(f"{request.parent}/batches/{self._batch_id}")
-        return self._batch_controller_client.get_batch(batch, retry=self._polling_retry)
-        # there might be useful results here that we can parse and return
-        # Dataproc job output is saved to the Cloud Storage bucket
-        # allocated to the job. Use regex to obtain the bucket and blob info.
-        # matches = re.match("gs://(.*?)/(.*)", response.driver_output_resource_uri)
-        # output = (
-        #     self.storage_client
-        #     .get_bucket(matches.group(1))
-        #     .blob(f"{matches.group(2)}.000000000")
-        #     .download_as_string()
-        # )
-
-    def _configure_batch(self) -> Batch:
+
+        # submit the batch
+        operation = self._batch_controller_client.create_batch(request)
+
+        # wait for the batch to complete
+        response = operation.result(polling=self._polling_retry)
+
+        return response
+
+    def _batch(self) -> Batch:
         # create the Dataproc Serverless job config
         # need to pin dataproc version to 1.1 as it now defaults to 2.0
         # https://cloud.google.com/dataproc-serverless/docs/concepts/properties
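The serverless path now has the same shape: `create_batch` starts a long-running operation and the helper waits on it with the polling retry, instead of issuing a separate `GetBatchRequest` afterwards. A standalone sketch of that flow follows; the endpoint, project, region, bucket, batch id, and the minimal batch config are placeholders, since the commit's actual config is built in `_batch()`, which is truncated in this view.

# Standalone sketch of the serverless create-and-poll flow (all identifiers are placeholders).
from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import (
    Batch,
    BatchControllerClient,
    CreateBatchRequest,
    PySparkBatch,
    RuntimeConfig,
)

client = BatchControllerClient(
    client_options={"api_endpoint": "us-central1-dataproc.googleapis.com:443"}
)

batch = Batch(
    runtime_config=RuntimeConfig(version="1.1"),  # pin noted in the diff's comment
    pyspark_batch=PySparkBatch(main_python_file_uri="gs://my-bucket/model.py"),
)

request = CreateBatchRequest(
    parent="projects/my-project/locations/us-central1",
    batch=batch,
    batch_id="dbt-batch-0001",
)

# create_batch starts a long-running operation; result() polls until the batch
# succeeds, fails, or the polling retry times out
operation = client.create_batch(request)
response = operation.result(polling=Retry(timeout=60 * 60 * 24))
print(response.state)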
(The diffs for the remaining two changed files did not load in this view.)
