From 238ba1f38ea3ac107374d58b33a1ae6c63dcc2ed Mon Sep 17 00:00:00 2001
From: traincheck-team
Date: Sun, 15 Dec 2024 19:11:41 -0500
Subject: [PATCH 1/4] fix: forbid repeated deepspeed.initialize on training
 objects

---
 deepspeed/__init__.py                    | 51 ++++++++++++++++++++++++
 tests/unit/runtime/test_ds_initialize.py | 39 ++++++++++++++++++
 2 files changed, 90 insertions(+)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index a8d15cd5332b..6bc5642ec8ef 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -66,6 +66,50 @@ def _parse_version(version_str):
 dist = None


+def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
+    """Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
+    # we shouldn't hit the assert below, but just in case
+    assert not hasattr(
+        trainobj, 'ds_is_inited'
+    ), "Model has already been initialized; please make sure to only call deepspeed.initialize on a model once."
+    trainobj.ds_is_inited = True
+
+
+def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
+    """Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
+    if hasattr(trainobj, 'ds_is_inited'):
+        # we shouldn't hit the assert below, but just in case
+        assert trainobj.ds_is_inited, "The model's `ds_is_inited` attribute should not be False if it exists; make sure you didn't set it to False or call deepspeed.initialize on the model more than once."
+        return True
+    return False
+
+
+def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Optimizer],
+                                 lr_scheduler: Optional[_LRScheduler]):
+    """Enforce that the model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call."""
+    if _is_initialized(model):
+        raise ValueError(
+            "Model has already been initialized; please make sure to only call deepspeed.initialize on a model once.")
+    if optimizer is not None and _is_initialized(optimizer):
+        raise ValueError(
+            "Optimizer has already been initialized; please make sure to only call deepspeed.initialize on an optimizer once."
+        )
+    if lr_scheduler is not None and _is_initialized(lr_scheduler):
+        raise ValueError(
+            "LR scheduler has already been initialized; please make sure to only call deepspeed.initialize on an LR scheduler once."
+        )
+
+
+def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Optimizer],
+                                lr_scheduler: Optional[_LRScheduler]):
+    """Mark the model, optimizer, and lr_scheduler as initialized."""
+    _mark_initialized(model)
+    if optimizer is not None:
+        _mark_initialized(optimizer)
+    if lr_scheduler is not None:
+        _mark_initialized(lr_scheduler)
+
+
 def initialize(args=None,
                model: torch.nn.Module = None,
                optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
@@ -137,6 +181,10 @@ def initialize(args=None,
         zero.partition_parameters.shutdown_init_context()

     assert model is not None, "deepspeed.initialize requires a model"
+    # enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
+    _assert_trainobjs_not_inited(model, optimizer, lr_scheduler)
+    # mark model, optimizer, and lr_scheduler as initialized
+    _mark_trainobjs_initialized(model, optimizer, lr_scheduler)

     global dist
     from deepspeed import comm as dist
@@ -221,6 +269,9 @@ def initialize(args=None,
         # Restore zero.Init context if necessary
         zero.partition_parameters.restore_init_context()

+    # mark engine, optimizer, and lr_scheduler as initialized
+    _mark_trainobjs_initialized(engine, engine.optimizer, engine.lr_scheduler)
+
     return_items = [
         engine,
         engine.optimizer,
diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py
index a30f81cedde9..2c9ad701bfff 100644
--- a/tests/unit/runtime/test_ds_initialize.py
+++ b/tests/unit/runtime/test_ds_initialize.py
@@ -21,6 +21,7 @@
 from deepspeed.utils.torch import required_torch_version
 from deepspeed.accelerator import get_accelerator
 from deepspeed.ops.op_builder import FusedAdamBuilder
+from deepspeed import _assert_trainobjs_not_inited, _is_initialized


 @pytest.mark.parametrize('zero_stage', [0, 3])
@@ -434,3 +435,41 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler:
     else:
         # callable
         assert isinstance(ds_lr_scheduler, OneCycleLR)
+
+
+# https://github.com/microsoft/DeepSpeed/issues/6770
+class TestNoRepeatedInitializationAllowed(DistributedTest):
+    world_size = 1
+
+    def test_no_repeated_init(self):
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+        model = SimpleModel()
+        # Minimal DeepSpeed configuration
+        config_dict = {'train_batch_size': 1}
+
+        client_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
+        # Initialize DeepSpeed engine
+        _assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
+        model_engine, optim, dataloader, scheduler = deepspeed.initialize(model=model,
+                                                                          optimizer=client_optimizer,
+                                                                          config_params=config_dict)
+
+        # arguments should be marked as initialized now
+        assert _is_initialized(model), "Client model should be marked as initialized"
+        assert _is_initialized(client_optimizer), "Client optimizer should be marked as initialized"
+
+        # return values should also be marked as initialized
+        assert _is_initialized(model_engine), "Model engine should be marked as initialized"
+        assert _is_initialized(optim), "Optimizer should be marked as initialized"
+        assert _is_initialized(scheduler), "Scheduler should be marked as initialized"
+
+        exception_raised = False
+        try:
+            deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
+        except ValueError:
+            exception_raised = True
+
+        assert exception_raised, "Repeated initialization should raise an exception"
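Note on the contract this patch establishes: deepspeed.initialize now fails fast instead of silently re-wrapping objects that already went through engine construction. A minimal sketch of the intended user-visible behavior (the toy model, config, and single-process launch are illustrative assumptions, not part of the patch):

    import torch
    import deepspeed

    # Hypothetical toy setup; any nn.Module / Optimizer pair would do.
    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    config = {'train_batch_size': 1}

    engine, opt, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config_params=config)

    # Calling initialize again on the same client objects should now raise
    # ValueError instead of quietly producing a second engine (see
    # https://github.com/microsoft/DeepSpeed/issues/6770).
    try:
        deepspeed.initialize(model=model, optimizer=optimizer, config_params=config)
    except ValueError as err:
        print(f"rejected as expected: {err}")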
From d1e7777b8df6fb949c4680e6770d3925995c5168 Mon Sep 17 00:00:00 2001
From: TrainCheck Team
Date: Mon, 16 Dec 2024 15:59:25 -0500
Subject: [PATCH 2/4] fix: remove the mark-time check for the flag's
 non-existence, since DeepSpeedEngine propagates the flag from the wrapped
 model

---
 deepspeed/__init__.py                    | 10 +++----
 tests/unit/runtime/test_ds_initialize.py | 33 +++++++++++++++++++----
 2 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 6bc5642ec8ef..eb245b9492c3 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -68,10 +68,10 @@ def _parse_version(version_str):

 def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
     """Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
-    # we shouldn't hit the assert below, but just in case
-    assert not hasattr(
-        trainobj, 'ds_is_inited'
-    ), "Model has already been initialized; please make sure to only call deepspeed.initialize on a model once."
+    if hasattr(trainobj, 'ds_is_inited'):
+        assert trainobj.ds_is_inited, "The training object's `ds_is_inited` attribute should not be False if it exists; make sure you didn't set it to False or call deepspeed.initialize on it more than once."
+        return
+
     trainobj.ds_is_inited = True


@@ -79,7 +79,7 @@ def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
     """Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
     if hasattr(trainobj, 'ds_is_inited'):
         # we shouldn't hit the assert below, but just in case
-        assert trainobj.ds_is_inited, "The model's `ds_is_inited` attribute should not be False if it exists; make sure you didn't set it to False or call deepspeed.initialize on the model more than once."
+        assert trainobj.ds_is_inited, "The training object's `ds_is_inited` attribute should not be False if it exists; make sure you didn't set it to False or call deepspeed.initialize on it more than once."
         return True
     return False

diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py
index 2c9ad701bfff..0da24dc2ba32 100644
--- a/tests/unit/runtime/test_ds_initialize.py
+++ b/tests/unit/runtime/test_ds_initialize.py
@@ -445,17 +445,14 @@ def test_no_repeated_init(self):
         hidden_dim = 10
         model = SimpleModel(hidden_dim)
         client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-
-        model = SimpleModel()
         # Minimal DeepSpeed configuration
         config_dict = {'train_batch_size': 1}

-        client_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
         # Initialize DeepSpeed engine
         _assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
-        model_engine, optim, dataloader, scheduler = deepspeed.initialize(model=model,
-                                                                          optimizer=client_optimizer,
-                                                                          config_params=config_dict)
+        model_engine, optim, _, _ = deepspeed.initialize(model=model,
+                                                         optimizer=client_optimizer,
+                                                         config_params=config_dict)

         # arguments should be marked as initialized now
         assert _is_initialized(model), "Client model should be marked as initialized"
@@ -464,7 +461,6 @@ def test_no_repeated_init(self):
         # return values should also be marked as initialized
         assert _is_initialized(model_engine), "Model engine should be marked as initialized"
         assert _is_initialized(optim), "Optimizer should be marked as initialized"
-        assert _is_initialized(scheduler), "Scheduler should be marked as initialized"

         exception_raised = False
         try:
@@ -473,3 +469,26 @@ def test_no_repeated_init(self):
             exception_raised = True

         assert exception_raised, "Repeated initialization should raise an exception"
+
+        exception_raised = False
+        try:
+            deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
+        except ValueError:
+            exception_raised = True
+
+        assert exception_raised, "Initialization on ds types should raise an exception"
+
+        exception_raised = False
+        try:
+            deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
+        except ValueError:
+            exception_raised = True
+
+        assert exception_raised, "Repeated initialization should raise an exception"
+
+        exception_raised = False
+        try:
+            deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
+        except ValueError:
+            exception_raised = True
+        assert exception_raised, "Initialization on ds types should raise an exception"
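Why the mark-time non-existence assert had to go: DeepSpeedEngine is an nn.Module wrapper, and attribute lookups that miss on the wrapper can resolve against the wrapped client module, so a freshly built engine can already appear to carry ds_is_inited copied up from the model it wraps. A rough stand-in for that propagation effect (a generic wrapper with a __getattr__ fallback; this is not the actual engine code):

    import torch

    class EngineLike(torch.nn.Module):
        """Stand-in for DeepSpeedEngine: unknown attributes fall through to the wrapped module."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def __getattr__(self, name):
            try:
                return super().__getattr__(name)  # parameters/buffers/submodules
            except AttributeError:
                return getattr(self.module, name)  # fall through to the client model

    inner = torch.nn.Linear(2, 2)
    inner.ds_is_inited = True   # client model was marked by a previous initialize call
    engine = EngineLike(inner)
    print(engine.ds_is_inited)  # True: the flag "propagates", so asserting its absence
                                # on the new engine object would fire spuriously

Hence marking is now idempotent: if the flag is already present and True, _mark_initialized simply returns.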
From 62067ccd5a1f53ec3cd4ce2eca24593031270096 Mon Sep 17 00:00:00 2001
From: traincheck-team
Date: Thu, 19 Dec 2024 13:37:00 -0500
Subject: [PATCH 3/4] fix: handle callable optimizer and scheduler types when
 marking initialization

---
 deepspeed/__init__.py                    | 45 +++++++++++-------------
 tests/unit/runtime/test_ds_initialize.py | 31 ++++++++++++-------
 2 files changed, 42 insertions(+), 34 deletions(-)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index eb245b9492c3..8b3f065c8b19 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -66,48 +66,45 @@ def _parse_version(version_str):
 dist = None


-def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
+def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
     """Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
-    if hasattr(trainobj, 'ds_is_inited'):
-        assert trainobj.ds_is_inited, "The training object's `ds_is_inited` attribute should not be False if it exists; make sure you didn't set it to False or call deepspeed.initialize on it more than once."
-        return
-
     trainobj.ds_is_inited = True


-def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
+def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
     """Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
-    if hasattr(trainobj, 'ds_is_inited'):
-        # we shouldn't hit the assert below, but just in case
-        assert trainobj.ds_is_inited, "The training object's `ds_is_inited` attribute should not be False if it exists; make sure you didn't set it to False or call deepspeed.initialize on it more than once."
-        return True
-    return False
+    return getattr(trainobj, 'ds_is_inited', False)


-def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Optimizer],
-                                 lr_scheduler: Optional[_LRScheduler]):
+def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
+                                                                                   DeepSpeedOptimizerCallable]],
+                                 lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
     """Enforce that the model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call."""
-    if _is_initialized(model):
+    if _is_ds_initialized(model):
         raise ValueError(
             "Model has already been initialized; please make sure to only call deepspeed.initialize on a model once.")
-    if optimizer is not None and _is_initialized(optimizer):
+    if optimizer is not None and isinstance(optimizer, Optimizer) and _is_ds_initialized(optimizer):
         raise ValueError(
             "Optimizer has already been initialized; please make sure to only call deepspeed.initialize on an optimizer once."
         )
-    if lr_scheduler is not None and _is_initialized(lr_scheduler):
+    if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler) and _is_ds_initialized(lr_scheduler):
         raise ValueError(
             "LR scheduler has already been initialized; please make sure to only call deepspeed.initialize on an LR scheduler once."
         )


-def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Optimizer],
-                                lr_scheduler: Optional[_LRScheduler]):
-    """Mark the model, optimizer, and lr_scheduler as initialized."""
-    _mark_initialized(model)
-    if optimizer is not None:
-        _mark_initialized(optimizer)
-    if lr_scheduler is not None:
-        _mark_initialized(lr_scheduler)
+def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
+                                                                                  DeepSpeedOptimizerCallable]],
+                                lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
+    """Mark the model, optimizer, and lr_scheduler as initialized.
+    Note that callables of type DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable are not marked,
+    as they are not stateful and reuse should be permissible.
+    """
+    _mark_ds_initialized(model)
+    if optimizer is not None and isinstance(optimizer, Optimizer):
+        _mark_ds_initialized(optimizer)
+    if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler):
+        _mark_ds_initialized(lr_scheduler)


 def initialize(args=None,
diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py
index 0da24dc2ba32..01ab3a5e7d03 100644
--- a/tests/unit/runtime/test_ds_initialize.py
+++ b/tests/unit/runtime/test_ds_initialize.py
@@ -21,7 +21,7 @@
 from deepspeed.utils.torch import required_torch_version
 from deepspeed.accelerator import get_accelerator
 from deepspeed.ops.op_builder import FusedAdamBuilder
-from deepspeed import _assert_trainobjs_not_inited, _is_initialized
+from deepspeed import _assert_trainobjs_not_inited, _is_ds_initialized


 @pytest.mark.parametrize('zero_stage', [0, 3])
@@ -441,12 +441,22 @@ class TestNoRepeatedInitializationAllowed(DistributedTest):
     world_size = 1

-    def test_no_repeated_init(self):
+    @pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
+    def test(self, optimizer_type):
         hidden_dim = 10
         model = SimpleModel(hidden_dim)
-        client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-        # Minimal DeepSpeed configuration
+
+        def _optimizer_callable(params) -> Optimizer:
+            return AdamW(params=params)
+
         config_dict = {'train_batch_size': 1}
+        if optimizer_type is None:
+            client_optimizer = None
+            config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
+        elif optimizer_type is Optimizer:
+            client_optimizer = Adam(model.parameters())
+        else:
+            client_optimizer = _optimizer_callable

         # Initialize DeepSpeed engine
         _assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
         model_engine, optim, _, _ = deepspeed.initialize(model=model,
@@ -455,12 +465,13 @@ def test_no_repeated_init(self):
                                                          optimizer=client_optimizer,
                                                          config_params=config_dict)

         # arguments should be marked as initialized now
-        assert _is_initialized(model), "Client model should be marked as initialized"
-        assert _is_initialized(client_optimizer), "Client optimizer should be marked as initialized"
+        assert _is_ds_initialized(model), "Client model should be marked as initialized"
+        if optimizer_type is Optimizer:
+            assert _is_ds_initialized(client_optimizer), "Client optimizer should be marked as initialized"

         # return values should also be marked as initialized
-        assert _is_initialized(model_engine), "Model engine should be marked as initialized"
-        assert _is_initialized(optim), "Optimizer should be marked as initialized"
+        assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
+        assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"

         exception_raised = False
@@ -480,7 +491,7 @@ def test_no_repeated_init(self):
         exception_raised = False
         try:
-            deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
+            deepspeed.initialize(model=model, optimizer=optim, config_params=config_dict)
         except ValueError:
             exception_raised = True

@@ -488,7 +499,7 @@ def test_no_repeated_init(self):
         exception_raised = False
         try:
-            deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
+            deepspeed.initialize(model=model_engine, optimizer=optim, config_params=config_dict)
         except ValueError:
             exception_raised = True
         assert exception_raised, "Initialization on ds types should raise an exception"
From 2c5806b82d82d555a30db8c2e7b8724e9dddb80a Mon Sep 17 00:00:00 2001
From: traincheck-team
Date: Sun, 29 Dec 2024 21:02:06 -0500
Subject: [PATCH 4/4] change: do init checking and marking in one function

---
 deepspeed/__init__.py                    | 74 ++++++++++++------------
 tests/unit/runtime/test_ds_initialize.py | 31 ++--------
 2 files changed, 42 insertions(+), 63 deletions(-)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 8b3f065c8b19..9d8689571c65 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -6,7 +6,7 @@
 import sys
 import types
 import json
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 import torch
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
@@ -27,6 +27,8 @@
 from .accelerator import get_accelerator
 from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT

+from .runtime.base_optimizer import DeepSpeedOptimizer
+from .runtime.dataloader import DeepSpeedDataLoader, RepeatingLoader
 from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
 from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -65,46 +67,44 @@ def _parse_version(version_str):
 # Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
 dist = None

+DS_PRIM_TYPES = (DeepSpeedEngine, DeepSpeedHybridEngine, DeepSpeedOptimizer, DeepSpeedDataLoader, RepeatingLoader)
+

 def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
     """Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
-    trainobj.ds_is_inited = True
+    if not isinstance(trainobj, DS_PRIM_TYPES):  # only mark non-DeepSpeed objects
+        trainobj.ds_is_inited = True


 def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
     """Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
-    return getattr(trainobj, 'ds_is_inited', False)
-
-
-def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
-                                                                                   DeepSpeedOptimizerCallable]],
-                                 lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
-    """Enforce that the model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call."""
-    if _is_ds_initialized(model):
-        raise ValueError(
-            "Model has already been initialized; please make sure to only call deepspeed.initialize on a model once.")
-    if optimizer is not None and isinstance(optimizer, Optimizer) and _is_ds_initialized(optimizer):
-        raise ValueError(
-            "Optimizer has already been initialized; please make sure to only call deepspeed.initialize on an optimizer once."
-        )
-    if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler) and _is_ds_initialized(lr_scheduler):
-        raise ValueError(
-            "LR scheduler has already been initialized; please make sure to only call deepspeed.initialize on an LR scheduler once."
-        )
-
-
-def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
-                                                                                  DeepSpeedOptimizerCallable]],
-                                lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
-    """Mark the model, optimizer, and lr_scheduler as initialized.
-    Note that callables of type DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable are not marked,
-    as they are not stateful and reuse should be permissible.
-    """
-    _mark_ds_initialized(model)
-    if optimizer is not None and isinstance(optimizer, Optimizer):
-        _mark_ds_initialized(optimizer)
-    if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler):
-        _mark_ds_initialized(lr_scheduler)
+    if isinstance(trainobj, DS_PRIM_TYPES):
+        return True
+    else:
+        return getattr(trainobj, 'ds_is_inited', False)
+
+
+def _ensure_and_mark_trainobjs_inited(
+    model: torch.nn.Module,
+    optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]],
+    lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]],
+    ensures_not_inited: bool = False,
+):
+    """Optionally ensure the trainobjs have not been initialized before, then mark them as initialized."""
+    trainobjs = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
+
+    for name, trainobj in trainobjs.items():
+        if trainobj is None:
+            continue
+        if name in ("optimizer", "lr_scheduler") and not isinstance(trainobj, (Optimizer, _LRScheduler)):
+            # skipping DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable
+            continue
+        if ensures_not_inited:
+            if _is_ds_initialized(trainobj):
+                raise ValueError(
+                    f"{name} has already been initialized; please make sure to only call deepspeed.initialize on a {name} once."
+                )
+        _mark_ds_initialized(trainobj)


 def initialize(args=None,
@@ -179,9 +179,7 @@ def initialize(args=None,
         zero.partition_parameters.shutdown_init_context()

     assert model is not None, "deepspeed.initialize requires a model"
     # enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
-    _assert_trainobjs_not_inited(model, optimizer, lr_scheduler)
-    # mark model, optimizer, and lr_scheduler as initialized
-    _mark_trainobjs_initialized(model, optimizer, lr_scheduler)
+    _ensure_and_mark_trainobjs_inited(model, optimizer, lr_scheduler, ensures_not_inited=True)

     global dist
@@ -267,7 +265,7 @@ def initialize(args=None,
         zero.partition_parameters.restore_init_context()

     # mark engine, optimizer, and lr_scheduler as initialized
-    _mark_trainobjs_initialized(engine, engine.optimizer, engine.lr_scheduler)
+    _ensure_and_mark_trainobjs_inited(engine, engine.optimizer, engine.lr_scheduler, ensures_not_inited=False)

     return_items = [
         engine,
diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py
index 01ab3a5e7d03..f445fffd6323 100644
--- a/tests/unit/runtime/test_ds_initialize.py
+++ b/tests/unit/runtime/test_ds_initialize.py
@@ -21,7 +21,7 @@
 from deepspeed.utils.torch import required_torch_version
 from deepspeed.accelerator import get_accelerator
 from deepspeed.ops.op_builder import FusedAdamBuilder
-from deepspeed import _assert_trainobjs_not_inited, _is_ds_initialized
+from deepspeed import _is_ds_initialized


 @pytest.mark.parametrize('zero_stage', [0, 3])
@@ -459,7 +459,6 @@ def _optimizer_callable(params) -> Optimizer:
             client_optimizer = _optimizer_callable

         # Initialize DeepSpeed engine
-        _assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
         model_engine, optim, _, _ = deepspeed.initialize(model=model,
                                                          optimizer=client_optimizer,
                                                          config_params=config_dict)
@@ -473,33 +472,15 @@ def test(self, optimizer_type):
         assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
         assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"

-        exception_raised = False
-        try:
+        err_msg_pattern = "has already been initialized"
+        with pytest.raises(ValueError, match=err_msg_pattern):
             deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
-        except ValueError:
-            exception_raised = True
-
-        assert exception_raised, "Repeated initialization should raise an exception"

-        exception_raised = False
-        try:
+        with pytest.raises(ValueError, match=err_msg_pattern):
             deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
-        except ValueError:
-            exception_raised = True
-
-        assert exception_raised, "Initialization on ds types should raise an exception"

-        exception_raised = False
-        try:
+        with pytest.raises(ValueError, match=err_msg_pattern):
             deepspeed.initialize(model=model, optimizer=optim, config_params=config_dict)
-        except ValueError:
-            exception_raised = True
-
-        assert exception_raised, "Repeated initialization should raise an exception"

-        exception_raised = False
-        try:
+        with pytest.raises(ValueError, match=err_msg_pattern):
             deepspeed.initialize(model=model_engine, optimizer=optim, config_params=config_dict)
-        except ValueError:
-            exception_raised = True
-        assert exception_raised, "Initialization on ds types should raise an exception"
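Taken together, the series leaves deepspeed.initialize with a simple contract: each client object may be passed in once, and anything DeepSpeed has produced or already consumed is rejected on reuse. A condensed sketch of that end state, shown standalone for brevity (a real run would go through the distributed test harness and launcher; the model and config below are toy assumptions):

    import pytest
    import torch
    import deepspeed

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters())
    config = {'train_batch_size': 1}

    engine, ds_opt, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config_params=config)

    # Any mix of an already-used client object and a DeepSpeed-produced one is refused.
    for m, o in ((model, optimizer), (engine, optimizer), (model, ds_opt), (engine, ds_opt)):
        with pytest.raises(ValueError, match="has already been initialized"):
            deepspeed.initialize(model=m, optimizer=o, config_params=config)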