From fd7b1874375606e7f6f9b0aed159673214eec7f8 Mon Sep 17 00:00:00 2001
From: Vincent Chen
Date: Mon, 24 Jun 2024 14:18:00 -0400
Subject: [PATCH] Add Retries to run_query (#1302)

* add retry

* pyright

* slight refactor

---------

Co-authored-by: v-chen_data
---
 scripts/data_prep/convert_delta_to_json.py | 56 +++++++++++++++++-----
 1 file changed, 44 insertions(+), 12 deletions(-)

diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py
index f63f1b0027..a3d3f2d4bb 100644
--- a/scripts/data_prep/convert_delta_to_json.py
+++ b/scripts/data_prep/convert_delta_to_json.py
@@ -19,6 +19,7 @@
 import pyspark.sql.connect.proto as pb2
 import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2
 import requests
+from composer.utils import retry
 from databricks import sql
 from databricks.connect import DatabricksSession
 from databricks.sdk import WorkspaceClient
@@ -347,6 +348,44 @@ def fetch_data(
     )
 
 
+@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5)
+def get_total_rows(
+    tablename: str,
+    method: str,
+    cursor: Optional[Cursor],
+    sparkSession: Optional[SparkSession],
+):
+    ans = run_query(
+        f'SELECT COUNT(*) FROM {tablename}',
+        method,
+        cursor,
+        sparkSession,
+    )
+    nrows = [row.asDict() for row in ans][0].popitem()[1]  # pyright: ignore
+    log.info(f'total_rows = {nrows}')
+    return nrows
+
+
+@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5)
+def get_columns_info(
+    tablename: str,
+    method: str,
+    cursor: Optional[Cursor],
+    sparkSession: Optional[SparkSession],
+):
+    ans = run_query(
+        f'SHOW COLUMNS IN {tablename}',
+        method,
+        cursor,
+        sparkSession,
+    )
+    columns = [row.asDict().popitem()[1] for row in ans]  # pyright: ignore
+    order_by = columns[0]
+    columns_str = ','.join(columns)
+    log.info(f'order by column {order_by}')
+    return columns, order_by, columns_str
+
+
 def fetch(
     method: str,
     tablename: str,
@@ -368,32 +407,25 @@ def fetch(
         dbsql (databricks.sql.connect): dbsql session
     """
     cursor = dbsql.cursor() if dbsql is not None else None
-
     try:
-        ans = run_query(
-            f'SELECT COUNT(*) FROM {tablename}',
+        nrows = get_total_rows(
+            tablename,
             method,
             cursor,
             sparkSession,
         )
-        nrows = [row.asDict() for row in ans][0].popitem()[1]  # pyright: ignore
-        log.info(f'total_rows = {nrows}')
     except Exception as e:
         raise RuntimeError(
-            f'Error in get total rows from {tablename}. Restart sparkSession and try again',
+            f'Error in get rows from {tablename}. Restart sparkSession and try again',
         ) from e
 
     try:
-        ans = run_query(
-            f'SHOW COLUMNS IN {tablename}',
+        columns, order_by, columns_str = get_columns_info(
+            tablename,
             method,
             cursor,
             sparkSession,
         )
-        columns = [row.asDict().popitem()[1] for row in ans]  # pyright: ignore
-        order_by = columns[0]
-        columns_str = ','.join(columns)
-        log.info(f'order by column {order_by}')
     except Exception as e:
         raise RuntimeError(
             f'Error in get columns from {tablename}. Restart sparkSession and try again',
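
Note on the retry semantics above (a minimal sketch, not part of the applied diff; flaky_query and its attempt counter are hypothetical names invented for illustration): composer.utils.retry re-invokes the wrapped callable whenever it raises the given exception type, sleeping with exponential backoff plus random jitter between attempts, so each decorated query helper gets up to num_attempts=5 tries before the exception finally propagates to the try/except blocks in fetch.

    from composer.utils import retry

    attempts = {'count': 0}

    @retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5)
    def flaky_query() -> int:
        # Hypothetical stand-in for run_query: fail twice with a transient
        # error, then succeed on the third attempt. Between the failed tries,
        # retry() sleeps roughly 1s then 2s (initial_backoff growing
        # exponentially, plus up to 0.5s of jitter).
        attempts['count'] += 1
        if attempts['count'] < 3:
            raise RuntimeError('transient Spark/DBSQL failure')
        return 42

    print(flaky_query())  # prints 42 after two retried failures

Because the decorator retries transparently, the existing try/except in fetch only fires once all five attempts are exhausted, which is why the RuntimeError wrapping is left in place.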