diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
index 44e8651cdf..2321d306ff 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -23,6 +23,7 @@
     ClusterInvalidAccessMode,
     FailedToConnectToDatabricksError,
     FailedToCreateSQLConnectionError,
+    FaultyDataPrepCluster,
     InsufficientPermissionsError,
 )
 
@@ -660,16 +661,26 @@ def fetch_DT(
     )
 
     formatted_delta_table_name = format_tablename(delta_table_name)
-
-    fetch(
-        method,
-        formatted_delta_table_name,
-        json_output_folder,
-        batch_size,
-        processes,
-        sparkSession,
-        dbsql,
-    )
+    import grpc
+    try:
+        fetch(
+            method,
+            formatted_delta_table_name,
+            json_output_folder,
+            batch_size,
+            processes,
+            sparkSession,
+            dbsql,
+        )
+    except grpc.RpcError as e:
+        if e.code(
+        ) == grpc.StatusCode.INTERNAL and 'Job aborted due to stage failure' in e.details(
+        ):
+            raise FaultyDataPrepCluster(
+                message=
+                f'Faulty data prep cluster, please try swapping data prep cluster: {e.details()}',
+            ) from e
+        raise e
 
     if dbsql is not None:
         dbsql.close()
diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 242ac4f32c..9cbea2cac8 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -466,3 +466,18 @@ def __reduce__(self):
 
     def __str__(self):
         return self.message
+
+
+class FaultyDataPrepCluster(UserError):
+    """Error thrown when the user uses faulty data prep cluster."""
+
+    def __init__(self, message: str) -> None:
+        self.message = message
+        super().__init__(message)
+
+    def __reduce__(self):
+        # Return a tuple of class, a tuple of arguments, and optionally state
+        return (FaultyDataPrepCluster, (self.message,))
+
+    def __str__(self):
+        return self.message
diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
index 981f5c1ed6..34a5b5ca55 100644
--- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py
+++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py
@@ -7,7 +7,10 @@
 from typing import Any
 from unittest.mock import MagicMock, mock_open, patch
 
+import grpc
+
 from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
+    FaultyDataPrepCluster,
     InsufficientPermissionsError,
     download,
     fetch,
@@ -524,3 +527,58 @@ def test_format_tablename(self):
             format_tablename('hyphenated-catalog.schema.test_table'),
             '`hyphenated-catalog`.`schema`.`test_table`',
         )
+
+    @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch')
+    @patch(
+        'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
+    )
+    def test_fetch_DT_grpc_error_handling(
+        self,
+        mock_validate_cluster_info: MagicMock,
+        mock_fetch: MagicMock,
+    ):
+        # Arrange
+        # Mock the validate_and_get_cluster_info to return test values
+        mock_validate_cluster_info.return_value = ('dbconnect', None, None)
+
+        # Create a grpc.RpcError with StatusCode.INTERNAL and specific details
+        grpc_error = grpc.RpcError()
+        grpc_error.code = lambda: grpc.StatusCode.INTERNAL
+        grpc_error.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'
+
+        # Configure the fetch function to raise the grpc.RpcError
+        mock_fetch.side_effect = grpc_error
+
+        # Test inputs
+        delta_table_name = 'test_table'
+        json_output_folder = '/tmp/to/jsonl'
+        http_path = None
+        cluster_id = None
+        use_serverless = False
+        DATABRICKS_HOST = 'https://test-host'
+        DATABRICKS_TOKEN = 'test-token'
+
+        # Act & Assert
+        with self.assertRaises(FaultyDataPrepCluster) as context:
+            fetch_DT(
+                delta_table_name=delta_table_name,
+                json_output_folder=json_output_folder,
+                http_path=http_path,
+                cluster_id=cluster_id,
+                use_serverless=use_serverless,
+                DATABRICKS_HOST=DATABRICKS_HOST,
+                DATABRICKS_TOKEN=DATABRICKS_TOKEN,
+            )
+
+        # Verify that the FaultyDataPrepCluster contains the expected message
+        self.assertIn(
+            'Faulty data prep cluster, please try swapping data prep cluster: ',
+            str(context.exception),
+        )
+        self.assertIn(
+            'Job aborted due to stage failure',
+            str(context.exception),
+        )
+
+        # Verify that fetch was called
+        mock_fetch.assert_called_once()