Skip to content

Commit

Permalink
[MLFlowObjectStore] [2/2] Support checkpointing with MLFlow (mosaicml…
Browse files Browse the repository at this point in the history
…#2810)

* Support checkpoint uploads to MLFlow (untested)

Use MLFlow run tag for autoresume

Add MLFlowLogger test for existing composer run tag

* Try formatting mlflow save folder after INIT

Make MLFlow experiment and run ID available on all ranks

Fix path issue

Format mlflow placeholders in remote filenames

* Unit tests for partial_format

* Log mlflow info as hyperparams

* partial_format doc update

* Fix formatting

* Pull distributed logic out of MLFlowObjectStore

Add debug tracebacks

Bugfix

Add path to debug info

Try fixing RUD object store init

Pyright

* Partial format in format_name helpers

* Fix import

* Add extra partial_format test

* Fix mlflow RUD check

* Fix test

pyright

No longer expect KeyError for format_with_dist using partial_format

Refactor partial_format for readability

* Max iters on partial_format

* Fix partial_format

* Clean up

* fix test import

* Fix test
  • Loading branch information
jerrychen109 authored Jan 12, 2024
1 parent 2ff7c27 commit 56fa4bd
Show file tree
Hide file tree
Showing 12 changed files with 254 additions and 56 deletions.
29 changes: 27 additions & 2 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from typing import Any, Callable, Dict, List, Optional, Union

from composer.core import Callback, Event, State, Time, Timestamp
from composer.loggers import Logger
from composer.loggers import Logger, MLFlowLogger
from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath,
checkpoint, create_interval_scheduler, create_symlink_file, dist,
ensure_folder_has_no_conflicting_files, format_name_with_dist,
format_name_with_dist_and_time, is_model_deepspeed, using_torch_2)
format_name_with_dist_and_time, is_model_deepspeed, partial_format, using_torch_2)
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -270,6 +271,30 @@ def __init__(
self.start_batch = None

def init(self, state: State, logger: Logger) -> None:
# If MLFlowLogger is being used, format MLFlow-specific placeholders in the save folder and paths.
# Assumes that MLFlowLogger comes before CheckpointSaver in the list of loggers.
for destination in logger.destinations:
if isinstance(destination, MLFlowLogger):
mlflow_format_kwargs = {
MLFLOW_EXPERIMENT_ID_FORMAT_KEY: destination._experiment_id,
MLFLOW_RUN_ID_FORMAT_KEY: destination._run_id
}
self.folder = partial_format(self.folder, **mlflow_format_kwargs)

self.filename.folder = self.folder
if self.latest_filename is not None:
self.latest_filename.folder = self.folder

# The remote paths have the placeholders in their filename rather than folder
if self.remote_file_name is not None:
self.remote_file_name.filename = partial_format(self.remote_file_name.filename,
**mlflow_format_kwargs)
if self.latest_remote_file_name is not None:
self.latest_remote_file_name.filename = partial_format(self.latest_remote_file_name.filename,
**mlflow_format_kwargs)

break

folder = format_name_with_dist(self.folder, state.run_name)
os.makedirs(folder, exist_ok=True)

Expand Down
35 changes: 30 additions & 5 deletions composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def __init__(
self._rank_zero_only = rank_zero_only
self._last_flush_time = time.time()
self._flush_interval = flush_interval

self._experiment_id = None
self._run_id = None

if self._enabled:
self.tracking_uri = str(tracking_uri or mlflow.get_tracking_uri())
mlflow.set_tracking_uri(self.tracking_uri)
Expand Down Expand Up @@ -128,6 +132,10 @@ def init(self, state: State, logger: Logger) -> None:
if self.run_name is None:
self.run_name = state.run_name

# Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume.
self.tags = self.tags or {}
self.tags['composer_run_name'] = state.run_name

# Adjust name and group based on `rank_zero_only`.
if not self._rank_zero_only:
self.run_name += f'-rank{dist.get_global_rank()}'
Expand All @@ -141,17 +149,34 @@ def init(self, state: State, logger: Logger) -> None:
if env_run_id is not None:
self._run_id = env_run_id
else:
new_run = self._mlflow_client.create_run(
experiment_id=self._experiment_id,
run_name=self.run_name,
)
self._run_id = new_run.info.run_id
# Search for an existing run tagged with this Composer run.
existing_runs = mlflow.search_runs(experiment_ids=[self._experiment_id],
filter_string=f'tags.composer_run_name = "{state.run_name}"',
output_format='list')
if len(existing_runs) > 0:
self._run_id = existing_runs[0].info.run_id
else:
new_run = self._mlflow_client.create_run(
experiment_id=self._experiment_id,
run_name=self.run_name,
)
self._run_id = new_run.info.run_id
mlflow.start_run(
run_id=self._run_id,
tags=self.tags,
log_system_metrics=self.log_system_metrics,
)

# If rank zero only, broadcast the MLFlow experiment and run IDs to other ranks, so the MLFlow run info is
# available to other ranks during runtime.
if self._rank_zero_only:
mlflow_ids_list = [self._experiment_id, self._run_id]
dist.broadcast_object_list(mlflow_ids_list, src=0)
self._experiment_id, self._run_id = mlflow_ids_list

def after_load(self, state: State, logger: Logger) -> None:
    """Log the MLFlow experiment and run IDs so they are recorded with the run's hyperparameters."""
    mlflow_info = {
        'mlflow_experiment_id': self._experiment_id,
        'mlflow_run_id': self._run_id,
    }
    logger.log_hyperparameters(mlflow_info)

def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Table') -> None:
if self._enabled:
try:
Expand Down
45 changes: 37 additions & 8 deletions composer/loggers/remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@

from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import (GCSObjectStore, LibcloudObjectStore, ObjectStore, ObjectStoreTransientError, OCIObjectStore,
S3ObjectStore, SFTPObjectStore, UCObjectStore, dist, format_name_with_dist, get_file, retry)
from composer.utils import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore,
dist, format_name_with_dist, get_file, retry)
from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX

if TYPE_CHECKING:
from composer.core import State
Expand All @@ -37,19 +39,32 @@


def _build_remote_backend(remote_backend_name: str, backend_kwargs: Dict[str, Any]):
remote_backend_cls = None
remote_backend_name_to_cls = {
's3': S3ObjectStore,
'oci': OCIObjectStore,
'sftp': SFTPObjectStore,
'libcloud': LibcloudObjectStore,
'gs': GCSObjectStore,
'dbfs': UCObjectStore,
}
remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None)
if remote_backend_cls is None:
raise ValueError(
f'The remote backend {remote_backend_name} is not supported. Please use one of ({list(remote_backend_name_to_cls.keys())})'
)

# Handle `dbfs` backend as a special case, since it can map to either :class:`.UCObjectStore`
# or :class:`.MLFlowObjectStore`.
if remote_backend_name == 'dbfs':
path = backend_kwargs['path']
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
remote_backend_cls = MLFlowObjectStore
else:
# Validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
remote_backend_cls = UCObjectStore
else:
remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None)
if remote_backend_cls is None:
supported_remote_backends = list(remote_backend_name_to_cls.keys()) + ['dbfs']
raise ValueError(
f'The remote backend {remote_backend_name} is not supported. Please use one of ({supported_remote_backends})'
)

return remote_backend_cls(**backend_kwargs)

Expand Down Expand Up @@ -322,6 +337,20 @@ def init(self, state: State, logger: Logger) -> None:
if dist.get_global_rank() == 0:
retry(ObjectStoreTransientError,
self.num_attempts)(lambda: _validate_credentials(self.remote_backend, file_name_to_test))()

# If the remote backend is an `MLFlowObjectStore`, the original path kwarg may have placeholders that can be
# updated with information generated at runtime, i.e., the MLFlow experiment and run IDs. This information
# must be propagated across all ranks before the workers are started so that all workers use the same
# MLFlow run.
if self.backend_kwargs.get('path', '').startswith(MLFLOW_DBFS_PATH_PREFIX):
if dist.get_global_rank() == 0:
assert isinstance(self.remote_backend, MLFlowObjectStore)
self.backend_kwargs['path'] = self.remote_backend.get_dbfs_path(self.backend_kwargs['path'])

path_list = [self.backend_kwargs['path']]
dist.broadcast_object_list(path_list, src=0)
self.backend_kwargs['path'] = path_list[0]

assert len(self._workers) == 0, 'workers should be empty if self._worker_flag was None'
for _ in range(self._num_concurrent_uploads):
worker = self._proc_class(
Expand Down
34 changes: 32 additions & 2 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec,
ensure_evaluator, ensure_time, get_precision_context, validate_eval_automicrobatching)
from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceTPU
from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MosaicMLLogger, ProgressBarLogger,
from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MLFlowLogger, MosaicMLLogger, ProgressBarLogger,
RemoteUploaderDownloader, WandBLogger)
from composer.loggers.mosaicml_logger import MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR
from composer.models import ComposerModel
Expand All @@ -54,8 +54,9 @@
ensure_tuple, export_with_logger, extract_hparams, format_name_with_dist,
get_composer_env_dict, get_device, get_file, is_tpu_installed, map_collection,
maybe_create_object_store_from_uri, maybe_create_remote_uploader_downloader_from_uri,
model_eval_mode, parse_uri, reproducibility, using_torch_2)
model_eval_mode, parse_uri, partial_format, reproducibility, using_torch_2)
from composer.utils.misc import is_model_deepspeed
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

if is_tpu_installed():
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -1085,6 +1086,11 @@ def __init__(
mosaicml_logger = MosaicMLLogger()
loggers.append(mosaicml_logger)

# Remote Uploader Downloader
# Keep the ``RemoteUploaderDownloader`` below client-provided loggers so the loggers init callbacks run before
# the ``RemoteUploaderDownloader`` init. This is necessary to use an ``MLFlowObjectStore`` to log objects to a
# run managed by an ``MLFlowLogger``, as the ``MLFlowObjectStore`` relies on the ``MLFlowLogger`` to initialize
# the active MLFlow run.
if save_folder is not None:
remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers)
if remote_ud is not None:
Expand Down Expand Up @@ -1158,6 +1164,30 @@ def __init__(
# Run Event.INIT
self.engine.run_event(Event.INIT)

# If the experiment is being tracked with an `MLFlowLogger`, then MLFlow experiment and run are available
# after Event.INIT.
if save_folder is not None:
mlflow_logger = None
for destination in self.logger.destinations:
if isinstance(destination, MLFlowLogger):
mlflow_logger = destination
break

if mlflow_logger is not None:
mlflow_experiment_id = mlflow_logger._experiment_id
mlflow_run_id = mlflow_logger._run_id

# The save folder and related paths/filenames may contain format placeholders for the MLFlow IDs, so
# populate them now.
mlflow_format_kwargs = {
MLFLOW_EXPERIMENT_ID_FORMAT_KEY: mlflow_experiment_id,
MLFLOW_RUN_ID_FORMAT_KEY: mlflow_run_id
}

save_folder = partial_format(save_folder, **mlflow_format_kwargs)
if latest_remote_file_name is not None:
latest_remote_file_name = partial_format(latest_remote_file_name, **mlflow_format_kwargs)

# Log hparams.
if self.auto_log_hparams:
self.local_hparams = extract_hparams(locals())
Expand Down
3 changes: 2 additions & 1 deletion composer/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from composer.utils.inference import ExportFormat, Transform, export_for_inference, export_with_logger, quantize_dynamic
from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection
from composer.utils.misc import (create_interval_scheduler, get_free_tcp_port, is_model_deepspeed, is_model_fsdp,
is_notebook, model_eval_mode, using_torch_2)
is_notebook, model_eval_mode, partial_format, using_torch_2)
from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore,
UCObjectStore)
Expand Down Expand Up @@ -92,4 +92,5 @@
'LambdaEvalClient',
'LocalEvalClient',
'MosaicMLLambdaEvalClient',
'partial_format',
]
44 changes: 34 additions & 10 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

from composer.utils import dist
from composer.utils.iter_helpers import iterate_with_callback
from composer.utils.object_store import GCSObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, UCObjectStore
from composer.utils.misc import partial_format
from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore,
UCObjectStore)
from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX

if TYPE_CHECKING:
from composer.core import Timestamp
Expand Down Expand Up @@ -166,7 +169,8 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]


def format_name_with_dist(format_str: str, run_name: str, **extra_format_kwargs: object): # noqa: D103
formatted_str = format_str.format(
formatted_str = partial_format(
format_str,
run_name=run_name,
**_get_dist_config(strict=False),
**extra_format_kwargs,
Expand Down Expand Up @@ -259,7 +263,8 @@ def format_name_with_dist_and_time(
timestamp: Timestamp,
**extra_format_kwargs: object,
): # noqa: D103
formatted_str = format_str.format(
formatted_str = partial_format(
format_str,
run_name=run_name,
epoch=int(timestamp.epoch),
batch=int(timestamp.batch),
Expand Down Expand Up @@ -350,9 +355,28 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
elif backend == 'dbfs':
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return UCObjectStore(path=path)
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
store = None
if dist.get_global_rank() == 0:
store = MLFlowObjectStore(path)

# The path may have had placeholders, so update it with the experiment/run IDs initialized by the store
path = store.get_dbfs_path(path)

# Broadcast the rank 0 updated path to all ranks for their own object stores
path_list = [path]
dist.broadcast_object_list(path_list, src=0)
path = path_list[0]

# Create the object store for all other ranks
if dist.get_global_rank() != 0:
store = MLFlowObjectStore(path)

return store
else:
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return UCObjectStore(path=path)
else:
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported object stores')
Expand Down Expand Up @@ -388,13 +412,13 @@ def maybe_create_remote_uploader_downloader_from_uri(
if backend in ['s3', 'oci', 'gs']:
return RemoteUploaderDownloader(bucket_uri=f'{backend}://{bucket_name}')

elif backend == 'dbfs':
return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})

elif backend == 'wandb':
raise NotImplementedError(f'There is no implementation for WandB via URI. Please use '
'WandBLogger with log_artifacts set to True')
elif backend == 'dbfs':
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})

else:
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported RemoteUploaderDownloader object stores')
Expand Down
20 changes: 20 additions & 0 deletions composer/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,23 @@ def using_torch_2_0_1() -> bool:
bool: Return True if current version is greater than or equal to 2.0.1 else False
"""
return version.parse(torch.__version__) >= version.parse('2.0.1')


def partial_format(s: str, *args: object, **kwargs: object) -> str:
    """Format a string with a partial set of arguments.

    Since ``str.format()`` raises an error when a placeholder has no matching
    argument, this function instead re-inserts any unmatched placeholder as
    itself (``{}`` for positional, ``{key}`` for keyword), so the returned
    string can be formatted again later with the remaining arguments.

    Args:
        s: The format string.
        *args: Positional format arguments; may cover fewer placeholders than
            ``s`` contains.
        **kwargs: Keyword format arguments; may be missing some keys that
            ``s`` references.

    Returns:
        str: The formatted string, with unmatched placeholders left as-is.

    Raises:
        RuntimeError: If formatting has not converged after ``max_iters``
            retries (defensive guard against an unexpected infinite loop).
    """
    max_iters = 10_000  # Just in case we get stuck in a loop somehow.
    for _ in range(max_iters):
        try:
            return s.format(*args, **kwargs)
        except IndexError:  # Missing positional arg: pad args with a literal placeholder.
            args += ('{}',)
        except KeyError as e:  # Missing keyword arg: map the key back to its placeholder text.
            key = e.args[0]
            kwargs[key] = '{' + key + '}'

    raise RuntimeError(f'Failed to format string {s} after {max_iters} iterations.')
Loading

0 comments on commit 56fa4bd

Please sign in to comment.