Skip to content

Commit

Permalink
change: do init checking and marking in one func
Browse files Browse the repository at this point in the history
  • Loading branch information
traincheck-team committed Dec 30, 2024
1 parent 62067cc commit 2c5806b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 63 deletions.
74 changes: 36 additions & 38 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
import types
import json
from typing import Optional, Union
from typing import Callable, Optional, Union
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
Expand All @@ -27,6 +27,8 @@

from .accelerator import get_accelerator
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from .runtime.base_optimizer import DeepSpeedOptimizer
from .runtime.dataloader import DeepSpeedDataLoader, RepeatingLoader
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.hybrid_engine import DeepSpeedHybridEngine
Expand Down Expand Up @@ -65,46 +67,44 @@ def _parse_version(version_str):
# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None

DS_PRIM_TYPES = (DeepSpeedEngine, DeepSpeedHybridEngine, DeepSpeedOptimizer, DeepSpeedDataLoader, RepeatingLoader)


def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Mark a trainobj as initialized by setting the ds_is_inited attribute to True.

    Instances of any type listed in DS_PRIM_TYPES are left untouched: no attribute
    is set on them.

    Args:
        trainobj: the object to mark.
    """
    # The assignment must stay behind the type guard; an unconditional assignment
    # would also tag the DeepSpeed-owned objects that the guard is meant to skip.
    if not isinstance(trainobj, DS_PRIM_TYPES):  # only mark non-DeepSpeed objects
        trainobj.ds_is_inited = True


def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Check if a trainobj has been initialized.

    Args:
        trainobj: the object to inspect.

    Returns:
        True when trainobj is an instance of one of DS_PRIM_TYPES; otherwise the
        value of its ds_is_inited attribute, or False when the attribute is absent.
    """
    # DeepSpeed-owned objects carry no ds_is_inited marker, so they are recognized
    # by type instead of by attribute.
    if isinstance(trainobj, DS_PRIM_TYPES):
        return True
    return getattr(trainobj, 'ds_is_inited', False)


def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
                                                                                  DeepSpeedOptimizerCallable]],
                                 lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
    """Reject a model, optimizer, or lr_scheduler that _is_ds_initialized reports as initialized.

    The model is always checked. The optimizer is checked only when it is an
    Optimizer instance, and the lr_scheduler only when it is an _LRScheduler
    instance, so None and callables pass through unchecked.

    Raises:
        ValueError: on the first object, in the order model, optimizer,
            lr_scheduler, that is already initialized.
    """
    if _is_ds_initialized(model):
        raise ValueError(
            "Model has already been initialized, please make sure to only call deepspeed.initialize on a model once.")

    # isinstance() is False for None, so no separate None test is needed.
    optimizer_reused = isinstance(optimizer, Optimizer) and _is_ds_initialized(optimizer)
    if optimizer_reused:
        raise ValueError(
            "Optimizer has already been initialized, please make sure to only call deepspeed.initialize on an optimizer once."
        )

    scheduler_reused = isinstance(lr_scheduler, _LRScheduler) and _is_ds_initialized(lr_scheduler)
    if scheduler_reused:
        raise ValueError(
            "LR scheduler has already been initialized, please make sure to only call deepspeed.initialize on an LR scheduler once."
        )


def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
                                                                                 DeepSpeedOptimizerCallable]],
                                lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
    """Mark the model, optimizer, and lr_scheduler as initialized.

    The model is always marked. The optimizer is marked only when it is an
    Optimizer instance, and the lr_scheduler only when it is an _LRScheduler
    instance. Callables (DeepSpeedOptimizerCallable / DeepSpeedSchedulerCallable)
    and None are therefore skipped, which keeps reuse of such callables permissible.
    """
    _mark_ds_initialized(model)

    # Each optional object is paired with the only type that qualifies it for marking.
    optional_objs = ((optimizer, Optimizer), (lr_scheduler, _LRScheduler))
    for candidate, stateful_type in optional_objs:
        if isinstance(candidate, stateful_type):
            _mark_ds_initialized(candidate)
if isinstance(trainobj, DS_PRIM_TYPES):
return True
else:
return getattr(trainobj, 'ds_is_inited', False)


def _ensure_and_mark_trainobjs_inited(
    model: torch.nn.Module,
    optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]],
    lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]],
    ensures_not_inited: bool = False,
):
    """Mark the model, optimizer, and lr_scheduler as initialized, optionally rejecting reuse.

    Objects that are None are ignored. An optimizer or lr_scheduler that is neither
    an Optimizer nor an _LRScheduler instance (for example a
    DeepSpeedOptimizerCallable or DeepSpeedSchedulerCallable) is ignored as well.

    Args:
        model: the model (or engine) to mark.
        optimizer: optimizer instance, optimizer callable, or None.
        lr_scheduler: scheduler instance, scheduler callable, or None.
        ensures_not_inited: when True, every eligible object is first checked with
            _is_ds_initialized and a ValueError is raised if any is already
            initialized.

    Raises:
        ValueError: if ensures_not_inited is True and an eligible object is already
            initialized. All checks run before any object is marked, so a failed
            call marks nothing.
    """
    candidates = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}

    eligible = {}
    for name, trainobj in candidates.items():
        if trainobj is None:
            continue
        if name in ("optimizer", "lr_scheduler") and not isinstance(trainobj, (Optimizer, _LRScheduler)):
            # skipping DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable
            continue
        eligible[name] = trainobj

    # Validate everything up front so a rejected call does not leave some of the
    # objects marked as initialized.
    if ensures_not_inited:
        for name, trainobj in eligible.items():
            if _is_ds_initialized(trainobj):
                raise ValueError(
                    f"{name} has already been initialized, please make sure to only call deepspeed.initialize on a {name} once."
                )

    for trainobj in eligible.values():
        _mark_ds_initialized(trainobj)


def initialize(args=None,
Expand Down Expand Up @@ -179,9 +179,7 @@ def initialize(args=None,

assert model is not None, "deepspeed.initialize requires a model"
# enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
_assert_trainobjs_not_inited(model, optimizer, lr_scheduler)
# mark model, optimizer, and lr_scheduler as initialized
_mark_trainobjs_initialized(model, optimizer, lr_scheduler)
_ensure_and_mark_trainobjs_inited(model, optimizer, lr_scheduler, ensures_not_inited=True)

global dist
from deepspeed import comm as dist
Expand Down Expand Up @@ -267,7 +265,7 @@ def initialize(args=None,
zero.partition_parameters.restore_init_context()

# mark engine, optimizer, and lr_scheduler as initialized
_mark_trainobjs_initialized(engine, engine.optimizer, engine.lr_scheduler)
_ensure_and_mark_trainobjs_inited(engine, engine.optimizer, engine.lr_scheduler, ensures_not_inited=False)

return_items = [
engine,
Expand Down
31 changes: 6 additions & 25 deletions tests/unit/runtime/test_ds_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import FusedAdamBuilder
from deepspeed import _assert_trainobjs_not_inited, _is_ds_initialized
from deepspeed import _is_ds_initialized


@pytest.mark.parametrize('zero_stage', [0, 3])
Expand Down Expand Up @@ -459,7 +459,6 @@ def _optimizer_callable(params) -> Optimizer:
client_optimizer = _optimizer_callable

# Initialize DeepSpeed engine
_assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
model_engine, optim, _, _ = deepspeed.initialize(model=model,
optimizer=client_optimizer,
config_params=config_dict)
Expand All @@ -473,33 +472,15 @@ def _optimizer_callable(params) -> Optimizer:
assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"

exception_raised = False
try:
err_msg_pattern = "has already been initialized"
with pytest.raises(ValueError, match=err_msg_pattern):
deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
except ValueError:
exception_raised = True

assert exception_raised, "Repeated initialization should raise an exception"

exception_raised = False
try:
with pytest.raises(ValueError, match=err_msg_pattern):
deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
except ValueError:
exception_raised = True

assert exception_raised, "Initialization on ds types should raise an exception"

exception_raised = False
try:
with pytest.raises(ValueError, match=err_msg_pattern):
deepspeed.initialize(model=model, optimizer=optim, config_params=config_dict)
except ValueError:
exception_raised = True

assert exception_raised, "Initialization on ds types should raise an exception"

exception_raised = False
try:
with pytest.raises(ValueError, match=err_msg_pattern):
deepspeed.initialize(model=model_engine, optimizer=optim, config_params=config_dict)
except ValueError:
exception_raised = True
assert exception_raised, "Initialization on ds types should raise an exception"

0 comments on commit 2c5806b

Please sign in to comment.