Skip to content

Commit

Permalink
Azure checkpointing support (mosaicml#2893)
Browse files Browse the repository at this point in the history
* v1

* fix

* fix

* logs

* dump env

* fix

* logs

* force logs

* bucket support

* typo

* more logs

* logs

* more logs

* fix autoresume

* logs

* fix

* fix

* lint

* more logs

* logs

* fix autoresume

* fix

* lint

* fix

* fix lstrip

* strip prefix

* muck around

* logs

* azure

* timestamp

* fix

* state

* logs

* logs

* remove

* game

* fix

* lint
  • Loading branch information
mvpatel2000 authored Jan 23, 2024
1 parent a91573d commit 4a53dfe
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
2 changes: 1 addition & 1 deletion composer/loggers/remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class RemoteUploaderDownloader(LoggerDestination):
backend_kwargs={
'provider': 's3',
'container': 'my-bucket',
'provider_kwargs=': {
'provider_kwargs': {
'key': 'AKIA...',
'secret': '*********',
'region': 'ap-northeast-1',
Expand Down
2 changes: 1 addition & 1 deletion composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def download_checkpoint(path: str,
raise FileNotFoundError(
(f'Checkpoint {_format_path_with_current_rank(path)} does not exist, '
f'but is required for sharded checkpointing on rank {dist.get_global_rank()}. '
'Please ensure that the checkpoint exists and your load_path was specified as a format string'
'Please ensure that the checkpoint exists and your load_path was specified as a format string '
'with the {rank} argument.')) from e

if extracted_checkpoint_folder is not None:
Expand Down
25 changes: 20 additions & 5 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from composer.utils import dist
from composer.utils.iter_helpers import iterate_with_callback
from composer.utils.misc import partial_format
from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore,
UCObjectStore)
from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
OCIObjectStore, S3ObjectStore, UCObjectStore)
from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX

if TYPE_CHECKING:
Expand Down Expand Up @@ -319,6 +319,7 @@ def parse_uri(uri: str) -> Tuple[str, str, str]:
Tuple[str, str, str]: A tuple containing the backend (e.g. s3), bucket name, and path.
Backend name will be empty string if the input is a local path
"""
uri = uri.replace('AZURE_BLOBS', 'azure') # urlparse does not support _ in scheme
parse_result = urlparse(uri)
backend, net_loc, path = parse_result.scheme, parse_result.netloc, parse_result.path
bucket_name = net_loc if '@' not in net_loc else net_loc.split('@')[0]
Expand Down Expand Up @@ -354,6 +355,13 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
return GCSObjectStore(bucket=bucket_name)
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
elif backend == 'azure':
return LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket_name,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
)
elif backend == 'dbfs':
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
store = None
Expand Down Expand Up @@ -411,14 +419,21 @@ def maybe_create_remote_uploader_downloader_from_uri(
return None
if backend in ['s3', 'oci', 'gs']:
return RemoteUploaderDownloader(bucket_uri=f'{backend}://{bucket_name}')

elif backend == 'azure':
return RemoteUploaderDownloader(
bucket_uri=f'libcloud://{bucket_name}',
backend_kwargs={
'provider': 'AZURE_BLOBS',
'container': bucket_name,
'key_environ': 'AZURE_ACCOUNT_NAME',
'secret_environ': 'AZURE_ACCOUNT_ACCESS_KEY',
},
)
elif backend == 'dbfs':
return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})

elif backend == 'wandb':
raise NotImplementedError(f'There is no implementation for WandB via URI. Please use '
'WandBLogger with log_artifacts set to True')

else:
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported RemoteUploaderDownloader object stores')
Expand Down

0 comments on commit 4a53dfe

Please sign in to comment.