diff --git a/composer/utils/object_store/uc_object_store.py b/composer/utils/object_store/uc_object_store.py index 2578e7151b..9b08530888 100644 --- a/composer/utils/object_store/uc_object_store.py +++ b/composer/utils/object_store/uc_object_store.py @@ -24,8 +24,9 @@ def _wrap_errors(uri: str, e: Exception): from databricks.sdk.core import DatabricksError + from databricks.sdk.errors.mapping import NotFound if isinstance(e, DatabricksError): - if e.error_code == _NOT_FOUND_ERROR_CODE: # type: ignore + if isinstance(e, NotFound) or e.error_code == _NOT_FOUND_ERROR_CODE: # type: ignore raise FileNotFoundError(f'Object {uri} not found') from e raise ObjectStoreTransientError from e @@ -48,6 +49,7 @@ class UCObjectStore(ObjectStore): """ _UC_VOLUME_LIST_API_ENDPOINT = '/api/2.0/fs/list' + _UC_VOLUME_FILES_API_ENDPOINT = '/api/2.0/fs/files' def __init__(self, path: str) -> None: try: @@ -206,13 +208,15 @@ def get_object_size(self, object_name: str) -> int: """ from databricks.sdk.core import DatabricksError try: - file_info = self.client.files.get_status(self._get_object_path(object_name)) - if file_info.is_dir: - raise IsADirectoryError(f'{object_name} is a UC directory, not a file.') - - assert file_info.file_size is not None - return file_info.file_size + # Note: The UC team is working on changes to fix the files.get_status API, but it currently + # does not work. Once fixed, we will call the files API endpoint. We currently only use this + # function in Composer and LLM-foundry to check the UC object's existence. 
+ self.client.api_client.do(method='HEAD', + path=f'{self._UC_VOLUME_FILES_API_ENDPOINT}/{self.prefix}/{object_name}', + headers={'Source': 'mosaicml/composer'}) + return 1000000 # Dummy value, as we don't have a way to get the size of the file except DatabricksError as e: + # Any DatabricksError lands here; _wrap_errors raises FileNotFoundError for not-found errors and ObjectStoreTransientError otherwise _wrap_errors(self.get_uri(object_name), e) return -1 diff --git a/tests/utils/object_store/test_uc_object_store.py b/tests/utils/object_store/test_uc_object_store.py index 0ca3dbbbd3..6d047d42ee 100644 --- a/tests/utils/object_store/test_uc_object_store.py +++ b/tests/utils/object_store/test_uc_object_store.py @@ -78,13 +78,12 @@ def test_uc_object_store_invalid_prefix(monkeypatch): @pytest.mark.parametrize('result', ['success', 'not_found']) def test_get_object_size(ws_client, uc_object_store, result: str): if result == 'success': - db_files = pytest.importorskip('databricks.sdk.service.files') - ws_client.files.get_status.return_value = db_files.FileInfo(file_size=100) - assert uc_object_store.get_object_size('train.txt') == 100 + ws_client.api_client.do.return_value = {} + assert uc_object_store.get_object_size('train.txt') == 1000000 elif result == 'not_found': db_core = pytest.importorskip('databricks.sdk.core', reason='requires databricks') - ws_client.files.get_status.side_effect = db_core.DatabricksError('The file being accessed is not found', - error_code='NOT_FOUND') + ws_client.api_client.do.side_effect = db_core.DatabricksError('The file being accessed is not found', + error_code='NOT_FOUND') with pytest.raises(FileNotFoundError): uc_object_store.get_object_size('train.txt') else: