Skip to content

Commit

Permalink
Deterministic gRPC Errors (#1559)
Browse files Browse the repository at this point in the history
Co-authored-by: v-chen_data <[email protected]>
  • Loading branch information
KuuCi and v-chen_data authored Sep 30, 2024
1 parent 4202a06 commit 0ad6ab4
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 10 deletions.
31 changes: 21 additions & 10 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ClusterInvalidAccessMode,
FailedToConnectToDatabricksError,
FailedToCreateSQLConnectionError,
FaultyDataPrepCluster,
InsufficientPermissionsError,
)

Expand Down Expand Up @@ -660,16 +661,26 @@ def fetch_DT(
)

formatted_delta_table_name = format_tablename(delta_table_name)

fetch(
method,
formatted_delta_table_name,
json_output_folder,
batch_size,
processes,
sparkSession,
dbsql,
)
import grpc
try:
fetch(
method,
formatted_delta_table_name,
json_output_folder,
batch_size,
processes,
sparkSession,
dbsql,
)
except grpc.RpcError as e:
if e.code(
) == grpc.StatusCode.INTERNAL and 'Job aborted due to stage failure' in e.details(
):
raise FaultyDataPrepCluster(
message=
f'Faulty data prep cluster, please try swapping data prep cluster: {e.details()}',
) from e
raise e

if dbsql is not None:
dbsql.close()
Expand Down
15 changes: 15 additions & 0 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,18 @@ def __reduce__(self):

def __str__(self):
return self.message


class FaultyDataPrepCluster(UserError):
    """Error raised when the user is using a faulty data prep cluster.

    The message is kept on the instance: it is what ``str()`` returns and
    what ``__reduce__`` hands back as the constructor argument.
    """

    def __init__(self, message: str) -> None:
        # Store the message locally, then let the base class initialise
        # itself with the same text.
        self.message = message
        super().__init__(message)

    def __reduce__(self):
        # Pickle protocol: (callable, args) used to rebuild this error.
        init_args = (self.message,)
        return (FaultyDataPrepCluster, init_args)

    def __str__(self):
        # The string form is exactly the stored message.
        return self.message
58 changes: 58 additions & 0 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from typing import Any
from unittest.mock import MagicMock, mock_open, patch

import grpc

from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
FaultyDataPrepCluster,
InsufficientPermissionsError,
download,
fetch,
Expand Down Expand Up @@ -524,3 +527,58 @@ def test_format_tablename(self):
format_tablename('hyphenated-catalog.schema.test_table'),
'`hyphenated-catalog`.`schema`.`test_table`',
)

@patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch')
@patch(
    'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
)
def test_fetch_DT_grpc_error_handling(
    self,
    mock_validate_cluster_info: MagicMock,
    mock_fetch: MagicMock,
):
    # Checks that fetch_DT surfaces an INTERNAL "stage failure" RpcError
    # raised by fetch() as a FaultyDataPrepCluster error.

    # Stub cluster validation so it hands back fixed test values.
    mock_validate_cluster_info.return_value = ('dbconnect', None, None)

    # Hand-build an RpcError whose code()/details() report the failure
    # we want to simulate, and make the patched fetch() raise it.
    stage_failure = grpc.RpcError()
    stage_failure.code = lambda: grpc.StatusCode.INTERNAL
    stage_failure.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'
    mock_fetch.side_effect = stage_failure

    # Keyword arguments passed to fetch_DT.
    call_kwargs = {
        'delta_table_name': 'test_table',
        'json_output_folder': '/tmp/to/jsonl',
        'http_path': None,
        'cluster_id': None,
        'use_serverless': False,
        'DATABRICKS_HOST': 'https://test-host',
        'DATABRICKS_TOKEN': 'test-token',
    }

    with self.assertRaises(FaultyDataPrepCluster) as context:
        fetch_DT(**call_kwargs)

    # The raised error must carry both the guidance prefix and the
    # original gRPC detail text.
    raised_text = str(context.exception)
    expected_fragments = (
        'Faulty data prep cluster, please try swapping data prep cluster: ',
        'Job aborted due to stage failure',
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, raised_text)

    # fetch() must have been invoked exactly once.
    mock_fetch.assert_called_once()

0 comments on commit 0ad6ab4

Please sign in to comment.