diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py index 76d70bdf3e..0ee65c832b 100644 --- a/composer/loggers/remote_uploader_downloader.py +++ b/composer/loggers/remote_uploader_downloader.py @@ -103,7 +103,7 @@ class RemoteUploaderDownloader(LoggerDestination): backend_kwargs={ 'provider': 's3', 'container': 'my-bucket', - 'provider_kwargs=': { + 'provider_kwargs': { 'key': 'AKIA...', 'secret': '*********', 'region': 'ap-northeast-1', diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 63e87f57fe..2af494e68b 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -641,7 +641,7 @@ def download_checkpoint(path: str, raise FileNotFoundError( (f'Checkpoint {_format_path_with_current_rank(path)} does not exist, ' f'but is required for sharded checkpointing on rank {dist.get_global_rank()}. ' - 'Please ensure that the checkpoint exists and your load_path was specified as a format string' + 'Please ensure that the checkpoint exists and your load_path was specified as a format string ' 'with the {rank} argument.')) from e if extracted_checkpoint_folder is not None: diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index d62487e106..c42aa7ce6f 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -21,8 +21,8 @@ from composer.utils import dist from composer.utils.iter_helpers import iterate_with_callback from composer.utils.misc import partial_format -from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, - UCObjectStore) +from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore, + OCIObjectStore, S3ObjectStore, UCObjectStore) from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX if TYPE_CHECKING: @@ -319,6 +319,7 @@ def parse_uri(uri: str) -> Tuple[str, str, str]: Tuple[str, str, str]: A tuple containing the backend (e.g. s3), bucket name, and path. Backend name will be empty string if the input is a local path """ + uri = uri.replace('AZURE_BLOBS', 'azure') # urlparse does not support _ in scheme parse_result = urlparse(uri) backend, net_loc, path = parse_result.scheme, parse_result.netloc, parse_result.path bucket_name = net_loc if '@' not in net_loc else net_loc.split('@')[0] @@ -354,6 +355,13 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]: return GCSObjectStore(bucket=bucket_name) elif backend == 'oci': return OCIObjectStore(bucket=bucket_name) + elif backend == 'azure': + return LibcloudObjectStore( + provider='AZURE_BLOBS', + container=bucket_name, + key_environ='AZURE_ACCOUNT_NAME', + secret_environ='AZURE_ACCOUNT_ACCESS_KEY', + ) elif backend == 'dbfs': if path.startswith(MLFLOW_DBFS_PATH_PREFIX): store = None @@ -411,14 +419,21 @@ def maybe_create_remote_uploader_downloader_from_uri( return None if backend in ['s3', 'oci', 'gs']: return RemoteUploaderDownloader(bucket_uri=f'{backend}://{bucket_name}') - + elif backend == 'azure': + return RemoteUploaderDownloader( + bucket_uri=f'libcloud://{bucket_name}', + backend_kwargs={ + 'provider': 'AZURE_BLOBS', + 'container': bucket_name, + 'key_environ': 'AZURE_ACCOUNT_NAME', + 'secret_environ': 'AZURE_ACCOUNT_ACCESS_KEY', + }, + ) elif backend == 'dbfs': return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path}) - elif backend == 'wandb': raise NotImplementedError(f'There is no implementation for WandB via URI. Please use ' 'WandBLogger with log_artifacts set to True') - else: raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use ' 'one of the supported RemoteUploaderDownloader object stores')