Fix: forbid repeated deepspeed.initialize on training objects #6874

Status: Open. Wants to merge 5 commits into base: master.
Changes from 2 commits.
51 changes: 51 additions & 0 deletions deepspeed/__init__.py
@@ -66,6 +66,50 @@ def _parse_version(version_str):
dist = None


def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Mark a trainobj as initialized by setting its ds_is_inited attribute to True."""
    if hasattr(trainobj, 'ds_is_inited'):
        assert trainobj.ds_is_inited, "`ds_is_inited` should never be False once set; make sure you did not set it manually or call deepspeed.initialize on the same object more than once."
        return

    trainobj.ds_is_inited = True


def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Check whether a trainobj has been initialized by inspecting its ds_is_inited attribute."""
    if hasattr(trainobj, 'ds_is_inited'):
        # we should never hit this assert, but keep it as a safeguard
        assert trainobj.ds_is_inited, "`ds_is_inited` should never be False once set; make sure you did not set it manually or call deepspeed.initialize on the same object more than once."
        return True
    return False


def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Optimizer],
                                 lr_scheduler: Optional[_LRScheduler]):
    """Enforce that the model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call."""
    if _is_initialized(model):
        raise ValueError(
            "Model has already been initialized; make sure deepspeed.initialize is called on a model only once.")
    if optimizer is not None and _is_initialized(optimizer):
[Reviewer comment, Contributor] Note that optimizer could be a Callable, not an object:

    optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,

(See the sketch after the matching lr_scheduler comment below.)

        raise ValueError(
            "Optimizer has already been initialized; make sure deepspeed.initialize is called on an optimizer only once."
        )
    if lr_scheduler is not None and _is_initialized(lr_scheduler):
[Reviewer comment, Contributor] Ditto for lr_scheduler:

    lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
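Both comments point at the same subtlety: deepspeed.initialize also accepts factory callables, not just concrete objects, so the narrow type hints in the helpers above miss those cases. A minimal sketch of the callable forms, assuming the signatures quoted in the two comments (illustrative only, not code from this PR):

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

def optimizer_callable(params) -> Adam:
    # DeepSpeed invokes this with the model parameters it ends up managing
    return Adam(params, lr=1e-3)

def scheduler_callable(optimizer) -> LambdaLR:
    # DeepSpeed invokes this with the optimizer it constructs
    return LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

# engine, optim, _, sched = deepspeed.initialize(model=model,
#                                                optimizer=optimizer_callable,
#                                                lr_scheduler=scheduler_callable,
#                                                config_params={'train_batch_size': 1})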

        raise ValueError(
            "LR scheduler has already been initialized; make sure deepspeed.initialize is called on an LR scheduler only once."
        )


def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Optimizer],
                                lr_scheduler: Optional[_LRScheduler]):
    """Mark the model, optimizer, and lr_scheduler as initialized."""
    _mark_initialized(model)
    if optimizer is not None:
        _mark_initialized(optimizer)
    if lr_scheduler is not None:
        _mark_initialized(lr_scheduler)


def initialize(args=None,
               model: torch.nn.Module = None,
               optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
@@ -137,6 +181,10 @@ def initialize(args=None,
    zero.partition_parameters.shutdown_init_context()

    assert model is not None, "deepspeed.initialize requires a model"
    # enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
    _assert_trainobjs_not_inited(model, optimizer, lr_scheduler)
[Reviewer comment, Contributor] I think this call should be moved into `_mark_trainobjs_initialized()` (a sketch of the suggested refactor follows).
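A hedged sketch of what the reviewer may have in mind; hypothetical, not code from this PR. The `check` flag is an assumption: the second call in initialize() marks the engine's objects, which may alias the already-marked client objects, so that call would need to skip the assertion.

def _mark_trainobjs_initialized(model, optimizer=None, lr_scheduler=None, check=True):
    """Mark the training objects, optionally asserting they were not marked before."""
    if check:
        _assert_trainobjs_not_inited(model, optimizer, lr_scheduler)
    _mark_initialized(model)
    if optimizer is not None:
        _mark_initialized(optimizer)
    if lr_scheduler is not None:
        _mark_initialized(lr_scheduler)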
    # mark model, optimizer, and lr_scheduler as initialized
    _mark_trainobjs_initialized(model, optimizer, lr_scheduler)

    global dist
    from deepspeed import comm as dist
@@ -221,6 +269,9 @@ def initialize(args=None,
    # Restore zero.Init context if necessary
    zero.partition_parameters.restore_init_context()

    # mark engine, optimizer, and lr_scheduler as initialized
    _mark_trainobjs_initialized(engine, engine.optimizer, engine.lr_scheduler)

    return_items = [
        engine,
        engine.optimizer,
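Taken together, the guard makes a repeated deepspeed.initialize call fail fast instead of silently rebuilding engine state. A minimal usage sketch, assuming a working DeepSpeed install run under its usual distributed launcher (the model and config here are placeholders):

import torch
import deepspeed

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
config = {'train_batch_size': 1}

engine, optim, _, _ = deepspeed.initialize(model=model, optimizer=optimizer,
                                           config_params=config)

# with this PR, a second call on the same (or the returned) objects raises
try:
    deepspeed.initialize(model=model, optimizer=optimizer, config_params=config)
except ValueError as err:
    print(f"caught expected error: {err}")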
58 changes: 58 additions & 0 deletions tests/unit/runtime/test_ds_initialize.py
@@ -21,6 +21,7 @@
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import FusedAdamBuilder
from deepspeed import _assert_trainobjs_not_inited, _is_initialized


@pytest.mark.parametrize('zero_stage', [0, 3])
@@ -434,3 +435,60 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler:
    else:
        # callable
        assert isinstance(ds_lr_scheduler, OneCycleLR)


# https://github.com/microsoft/DeepSpeed/issues/6770
class TestNoRepeatedInitializationAllowed(DistributedTest):
    world_size = 1

    def test_no_repeated_init(self):
        hidden_dim = 10
        model = SimpleModel(hidden_dim)
        client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        # a minimal DeepSpeed configuration
        config_dict = {'train_batch_size': 1}

        # initialize the DeepSpeed engine; fresh objects must pass the guard
        _assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
        model_engine, optim, _, _ = deepspeed.initialize(model=model,
                                                         optimizer=client_optimizer,
                                                         config_params=config_dict)

        # the client arguments should be marked as initialized now
        assert _is_initialized(model), "Client model should be marked as initialized"
        assert _is_initialized(client_optimizer), "Client optimizer should be marked as initialized"

        # the returned objects should also be marked as initialized
        assert _is_initialized(model_engine), "Model engine should be marked as initialized"
        assert _is_initialized(optim), "Optimizer should be marked as initialized"
        # repeated initialization on the client objects should fail
        with pytest.raises(ValueError):
            deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)

        # initialization on the DeepSpeed-returned objects should fail as well
        with pytest.raises(ValueError):
            deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
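A possible follow-up, given the reviewer comments on the callable forms: the guard marks objects by attribute, and a DeepSpeedOptimizerCallable takes a different path through initialize, so it may deserve its own case. A hedged sketch of such a test (hypothetical, not part of this PR):

    def test_no_repeated_init_with_callable(self):
        model = SimpleModel(10)

        def optimizer_callable(params):
            return torch.optim.Adam(params, lr=0.01)

        engine, optim, _, _ = deepspeed.initialize(model=model,
                                                   optimizer=optimizer_callable,
                                                   config_params={'train_batch_size': 1})
        # the model is still guarded even when the optimizer is a callable
        with pytest.raises(ValueError):
            deepspeed.initialize(model=model,
                                 optimizer=optimizer_callable,
                                 config_params={'train_batch_size': 1})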