diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 8b3f065c8b19..9d8689571c65 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -6,7 +6,7 @@
 import sys
 import types
 import json
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 import torch
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
@@ -27,6 +27,8 @@
 
 from .accelerator import get_accelerator
 from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
+from .runtime.base_optimizer import DeepSpeedOptimizer
+from .runtime.dataloader import DeepSpeedDataLoader, RepeatingLoader
 from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
 from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -65,46 +67,49 @@ def _parse_version(version_str):
 # Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
 dist = None
 
+DS_PRIM_TYPES = (DeepSpeedEngine, DeepSpeedHybridEngine, DeepSpeedOptimizer, DeepSpeedDataLoader, RepeatingLoader)
+
 
 def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
     """Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
-    trainobj.ds_is_inited = True
+    if not isinstance(trainobj, DS_PRIM_TYPES):  # only mark non-DeepSpeed objects
+        trainobj.ds_is_inited = True
 
 
 def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
     """Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
-    return getattr(trainobj, 'ds_is_inited', False)
-
-
-def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
-                                                                                   DeepSpeedOptimizerCallable]],
-                                 lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
-    """Enforce the model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call."""
-    if _is_ds_initialized(model):
-        raise ValueError(
-            "Model has already been initialized, please make sure to only call deepspeed.initialize on a model once.")
-    if optimizer is not None and isinstance(optimizer, Optimizer) and _is_ds_initialized(optimizer):
-        raise ValueError(
-            "Optimizer has already been initialized, please make sure to only call deepspeed.initialize on an optimizer once."
-        )
-    if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler) and _is_ds_initialized(lr_scheduler):
-        raise ValueError(
-            "LR scheduler has already been initialized, please make sure to only call deepspeed.initialize on an LR scheduler once."
-        )
-
-
-def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
-                                                                                  DeepSpeedOptimizerCallable]],
-                                lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
-    """Mark the model, optimizer, and lr_scheduler as initialized.
-    Note that callables of type DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable are not marked
-    as they are not stateful and reuse should be permissible.
-    """
-    _mark_ds_initialized(model)
-    if optimizer is not None and isinstance(optimizer, Optimizer):
-        _mark_ds_initialized(optimizer)
-    if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler):
-        _mark_ds_initialized(lr_scheduler)
+    if isinstance(trainobj, DS_PRIM_TYPES):
+        return True
+    else:
+        return getattr(trainobj, 'ds_is_inited', False)
+
+
+def _ensure_and_mark_trainobjs_inited(
+    model: torch.nn.Module,
+    optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]],
+    lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]],
+    ensures_not_inited: bool = False,
+):
+    """Mark the model, optimizer, and lr_scheduler as initialized by DeepSpeed.
+
+    If ensures_not_inited is True, raise a ValueError when any of them has already been initialized.
+    Callables of type DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable are skipped, since they
+    are not stateful and reuse should be permissible.
+    """
+    trainobjs = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
+
+    for name, trainobj in trainobjs.items():
+        if trainobj is None:
+            continue
+        if name in ("optimizer", "lr_scheduler") and not isinstance(trainobj, (Optimizer, _LRScheduler)):
+            # skipping DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable
+            continue
+        if ensures_not_inited:
+            if _is_ds_initialized(trainobj):
+                raise ValueError(
+                    f"{name} has already been initialized, please make sure to only call deepspeed.initialize on a {name} once."
+                )
+        _mark_ds_initialized(trainobj)
 
 
 def initialize(args=None,
@@ -179,9 +184,7 @@ def initialize(args=None,
 
     assert model is not None, "deepspeed.initialize requires a model"
     # enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
-    _assert_trainobjs_not_inited(model, optimizer, lr_scheduler)
-    # mark model, optimizer, and lr_scheduler as initialized
-    _mark_trainobjs_initialized(model, optimizer, lr_scheduler)
+    _ensure_and_mark_trainobjs_inited(model, optimizer, lr_scheduler, ensures_not_inited=True)
 
     global dist
     from deepspeed import comm as dist
@@ -267,7 +270,7 @@ def initialize(args=None,
     zero.partition_parameters.restore_init_context()
 
     # mark engine, optimizer, and lr_scheduler as initialized
-    _mark_trainobjs_initialized(engine, engine.optimizer, engine.lr_scheduler)
+    _ensure_and_mark_trainobjs_inited(engine, engine.optimizer, engine.lr_scheduler, ensures_not_inited=False)
 
     return_items = [
         engine,
diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py
index 01ab3a5e7d03..8e9566acd4b1 100644
--- a/tests/unit/runtime/test_ds_initialize.py
+++ b/tests/unit/runtime/test_ds_initialize.py
@@ -21,7 +21,7 @@
 from deepspeed.utils.torch import required_torch_version
 from deepspeed.accelerator import get_accelerator
 from deepspeed.ops.op_builder import FusedAdamBuilder
-from deepspeed import _assert_trainobjs_not_inited, _is_ds_initialized
+from deepspeed import _is_ds_initialized
 
 
 @pytest.mark.parametrize('zero_stage', [0, 3])
@@ -459,7 +459,6 @@ def _optimizer_callable(params) -> Optimizer:
             client_optimizer = _optimizer_callable
 
         # Initialize DeepSpeed engine
-        _assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
         model_engine, optim, _, _ = deepspeed.initialize(model=model,
                                                          optimizer=client_optimizer,
                                                          config_params=config_dict)
@@ -473,33 +472,14 @@ def _optimizer_callable(params) -> Optimizer:
         assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
         assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"
 
-        exception_raised = False
-        try:
+        with pytest.raises(ValueError):
             deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
-        except ValueError:
-            exception_raised = True
 
-        assert exception_raised, "Repeated initialization should raise an exception"
-
-        exception_raised = False
-        try:
+        with pytest.raises(ValueError):
             deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
-        except ValueError:
-            exception_raised = True
-
-        assert exception_raised, "Initialization on ds types should raise an exception"
 
-        exception_raised = False
-        try:
+        with pytest.raises(ValueError):
             deepspeed.initialize(model=model, optimizer=optim, config_params=config_dict)
-        except ValueError:
-            exception_raised = True
-
-        assert exception_raised, "Initialization on ds types should raise an exception"
 
-        exception_raised = False
-        try:
+        with pytest.raises(ValueError):
             deepspeed.initialize(model=model_engine, optimizer=optim, config_params=config_dict)
-        except ValueError:
-            exception_raised = True
-        assert exception_raised, "Initialization on ds types should raise an exception"