Fix: forbid repeated deepspeed.initialize on training objects #6874

Open · wants to merge 5 commits into master
48 changes: 47 additions & 1 deletion deepspeed/__init__.py
@@ -6,7 +6,7 @@
import sys
import types
import json
from typing import Optional, Union
from typing import Callable, Optional, Union
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
@@ -27,6 +27,8 @@

from .accelerator import get_accelerator
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from .runtime.base_optimizer import DeepSpeedOptimizer
from .runtime.dataloader import DeepSpeedDataLoader, RepeatingLoader
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -65,6 +67,45 @@ def _parse_version(version_str):
# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None

DS_PRIM_TYPES = (DeepSpeedEngine, DeepSpeedHybridEngine, DeepSpeedOptimizer, DeepSpeedDataLoader, RepeatingLoader)


def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Mark a trainobj as initialized by setting its ds_is_inited attribute to True."""
    if not isinstance(trainobj, DS_PRIM_TYPES):  # only mark non-DeepSpeed objects
        trainobj.ds_is_inited = True


def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Check whether a trainobj has been initialized by reading its ds_is_inited attribute."""
    if isinstance(trainobj, DS_PRIM_TYPES):
        return True
    else:
        return getattr(trainobj, 'ds_is_inited', False)


def _ensure_and_mark_trainobjs_inited(
    model: torch.nn.Module,
    optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]],
    lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]],
    ensures_not_inited: bool = False,
):
    trainobjs = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}

    for name, trainobj in trainobjs.items():
        if trainobj is None:
            continue
        if name in ("optimizer", "lr_scheduler") and not isinstance(trainobj, (Optimizer, _LRScheduler)):
            # skip DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable:
            # callables construct fresh objects, so there is nothing to mark or check
            continue
        if ensures_not_inited and _is_ds_initialized(trainobj):
            raise ValueError(
                f"{name} has already been initialized; please make sure deepspeed.initialize is called only once per {name}."
            )
        _mark_ds_initialized(trainobj)

def initialize(args=None,
               model: torch.nn.Module = None,
@@ -137,6 +178,8 @@ def initialize(args=None,
        zero.partition_parameters.shutdown_init_context()

    assert model is not None, "deepspeed.initialize requires a model"
    # enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
    _ensure_and_mark_trainobjs_inited(model, optimizer, lr_scheduler, ensures_not_inited=True)

    global dist
    from deepspeed import comm as dist
@@ -221,6 +264,9 @@ def initialize(args=None,
    # Restore zero.Init context if necessary
    zero.partition_parameters.restore_init_context()

    # mark the engine, optimizer, and lr_scheduler as initialized
    _ensure_and_mark_trainobjs_inited(engine, engine.optimizer, engine.lr_scheduler, ensures_not_inited=False)

    return_items = [
        engine,
        engine.optimizer,
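For reviewers, a minimal sketch of the user-facing behavior this file change enforces. The toy model, Adam config, and single-process launch below are illustrative assumptions, not part of this diff:

# repro sketch: assumes a working single-GPU DeepSpeed environment,
# e.g. launched with `deepspeed repro.py`
import torch
import deepspeed

model = torch.nn.Linear(10, 10)  # stand-in for any client model
config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

engine, optimizer, _, _ = deepspeed.initialize(model=model, config=config)

# Re-initializing either the original model or the returned engine now
# raises instead of silently building a second engine.
try:
    deepspeed.initialize(model=model, config=config)
except ValueError as err:
    print(err)  # "model has already been initialized; ..."

try:
    deepspeed.initialize(model=engine, optimizer=optimizer, config=config)
except ValueError as err:
    print(err)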
50 changes: 50 additions & 0 deletions tests/unit/runtime/test_ds_initialize.py
@@ -21,6 +21,7 @@
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import FusedAdamBuilder
from deepspeed import _is_ds_initialized


@pytest.mark.parametrize('zero_stage', [0, 3])
@@ -434,3 +435,52 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler:
        else:
            # callable
            assert isinstance(ds_lr_scheduler, OneCycleLR)


# https://github.com/microsoft/DeepSpeed/issues/6770
class TestNoRepeatedInitializationAllowed(DistributedTest):
    world_size = 1

    @pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
    def test(self, optimizer_type):
        hidden_dim = 10
        model = SimpleModel(hidden_dim)

        def _optimizer_callable(params) -> Optimizer:
            return AdamW(params=params)

        config_dict = {'train_batch_size': 1}
        if optimizer_type is None:
            client_optimizer = None
            config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
        elif optimizer_type is Optimizer:
            client_optimizer = Adam(model.parameters())
        else:
            client_optimizer = _optimizer_callable

        # Initialize the DeepSpeed engine
        model_engine, optim, _, _ = deepspeed.initialize(model=model,
                                                         optimizer=client_optimizer,
                                                         config_params=config_dict)

        # the arguments should be marked as initialized now
        assert _is_ds_initialized(model), "Client model should be marked as initialized"
        if optimizer_type is Optimizer:
            assert _is_ds_initialized(client_optimizer), "Client optimizer should be marked as initialized"

        # the return values should also be marked as initialized
        assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
        assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"

        err_msg_pattern = "has already been initialized"
        with pytest.raises(ValueError, match=err_msg_pattern):
            deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)

        with pytest.raises(ValueError, match=err_msg_pattern):
            deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)

        with pytest.raises(ValueError, match=err_msg_pattern):
            deepspeed.initialize(model=model, optimizer=optim, config_params=config_dict)

        with pytest.raises(ValueError, match=err_msg_pattern):
            deepspeed.initialize(model=model_engine, optimizer=optim, config_params=config_dict)
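As a closing note on the mechanism these tests exercise: client objects are tracked via a plain ds_is_inited attribute, while DeepSpeed's own types (DS_PRIM_TYPES) are treated as always-initialized. A minimal standalone sketch of that distinction, using only the helpers added in this PR (the Linear/SGD objects are illustrative assumptions):

# illustrative sketch, not part of the diff
import torch
from torch.optim import SGD
from deepspeed import _is_ds_initialized, _mark_ds_initialized

model = torch.nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=0.1)

# fresh client objects start out unmarked
assert not _is_ds_initialized(model)
assert not _is_ds_initialized(optimizer)

# marking sets ds_is_inited on non-DeepSpeed objects
_mark_ds_initialized(model)
assert model.ds_is_inited is True
assert _is_ds_initialized(model)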