From 6e55ec6759f16db5b057782967078672433251b1 Mon Sep 17 00:00:00 2001
From: v-chen_data
Date: Mon, 17 Jun 2024 14:10:24 -0700
Subject: [PATCH] add retry

---
 scripts/data_prep/convert_delta_to_json.py | 75 +++++++++++++---------
 1 file changed, 46 insertions(+), 29 deletions(-)

diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py
index f63f1b0027..edbfd56a87 100644
--- a/scripts/data_prep/convert_delta_to_json.py
+++ b/scripts/data_prep/convert_delta_to_json.py
@@ -42,6 +42,7 @@
 
 MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
 MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2'
+MAX_RETRY = 5
 
 log = logging.getLogger(__name__)
 
@@ -368,36 +369,52 @@ def fetch(
         dbsql (databricks.sql.connect): dbsql session
     """
     cursor = dbsql.cursor() if dbsql is not None else None
+    for row_retry in range(MAX_RETRY):
+        try:
+            ans = run_query(
+                f'SELECT COUNT(*) FROM {tablename}',
+                method,
+                cursor,
+                sparkSession,
+            )
+            nrows = [row.asDict() for row in ans
+                    ][0].popitem()[1]  # pyright: ignore
+            log.info(f'total_rows = {nrows}')
+            break
+        except Exception as e:
+            if row_retry == MAX_RETRY - 1:
+                raise RuntimeError(
+                    f'Error in get total rows from {tablename}. Restart sparkSession and try again',
+                ) from e
+            else:
+                log.warning(
+                    f'Error in get total rows from {tablename}, trying again...'
+                )
-    try:
-        ans = run_query(
-            f'SELECT COUNT(*) FROM {tablename}',
-            method,
-            cursor,
-            sparkSession,
-        )
-        nrows = [row.asDict() for row in ans][0].popitem()[1]  # pyright: ignore
-        log.info(f'total_rows = {nrows}')
-    except Exception as e:
-        raise RuntimeError(
-            f'Error in get total rows from {tablename}. Restart sparkSession and try again',
-        ) from e
-
-    try:
-        ans = run_query(
-            f'SHOW COLUMNS IN {tablename}',
-            method,
-            cursor,
-            sparkSession,
-        )
-        columns = [row.asDict().popitem()[1] for row in ans]  # pyright: ignore
-        order_by = columns[0]
-        columns_str = ','.join(columns)
-        log.info(f'order by column {order_by}')
-    except Exception as e:
-        raise RuntimeError(
-            f'Error in get columns from {tablename}. Restart sparkSession and try again',
-        ) from e
+    for row_retry in range(MAX_RETRY):
+        try:
+            ans = run_query(
+                f'SHOW COLUMNS IN {tablename}',
+                method,
+                cursor,
+                sparkSession,
+            )
+            columns = [
+                row.asDict().popitem()[1] for row in ans
+            ]  # pyright: ignore
+            order_by = columns[0]
+            columns_str = ','.join(columns)
+            log.info(f'order by column {order_by}')
+            break
+        except Exception as e:
+            if row_retry == MAX_RETRY - 1:
+                raise RuntimeError(
+                    f'Error in get columns from {tablename}. Restart sparkSession and try again',
+                ) from e
+            else:
+                log.warning(
+                    f'Error in get columns from {tablename}, trying again...'
+                )
 
     if method == 'dbconnect' and sparkSession is not None:
         log.info(f'{processes=}')
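
Note (outside the diff): both hunks in fetch() repeat the same retry shape -- run the query, break on success, warn and loop on failure, raise only on the final attempt. A reusable helper could express that pattern once. The sketch below is illustrative only: with_retry, its description parameter, and the backoff_seconds delay are hypothetical additions (the patch itself retries immediately, with no backoff); only run_query, MAX_RETRY, and the log/error message wording come from the patch above.

    import logging
    import time
    from typing import Callable, TypeVar

    log = logging.getLogger(__name__)

    MAX_RETRY = 5  # same constant the patch introduces

    T = TypeVar('T')


    def with_retry(
        fn: Callable[[], T],
        description: str,
        max_retry: int = MAX_RETRY,
        backoff_seconds: float = 0.0,  # assumption: the patch sleeps 0s between attempts
    ) -> T:
        """Call fn(), retrying up to max_retry times before raising."""
        for attempt in range(max_retry):
            try:
                return fn()
            except Exception as e:
                if attempt == max_retry - 1:
                    # Final attempt failed: surface the error, as the patch does.
                    raise RuntimeError(
                        f'Error in {description}. Restart sparkSession and try again',
                    ) from e
                # Non-final attempt: log and retry, as the patch does.
                log.warning(f'Error in {description}, trying again...')
                time.sleep(backoff_seconds)
        raise AssertionError('unreachable: max_retry must be >= 1')

With such a helper, the first retry loop in fetch() could collapse to a single hypothetical call:

    ans = with_retry(
        lambda: run_query(
            f'SELECT COUNT(*) FROM {tablename}',
            method,
            cursor,
            sparkSession,
        ),
        description=f'get total rows from {tablename}',
    )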