Skip to content

Commit

Permalink
Fix ds nvme (#34444)
Browse files Browse the repository at this point in the history
* skip nested deepspeed.zero.Init call

* make fixup

* solve conflict

* solve conflict

* put back local

* use context managers instead of local thread

* Skip recursive calls to deepspeed.zero.Init

* Skip recursive calls to deepspeed.zero.Init

* back to old notebooks

* make style
  • Loading branch information
eljandoubi authored Nov 21, 2024
1 parent ae5cbf8 commit d6a5c23
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@

_init_weights = True
# Checked (`not _is_quantized`) in `_from_config` before activating `deepspeed.zero.Init`.
_is_quantized = False
# Set to True by `set_zero3_state()` for the duration of its `with` block; `_from_config` and
# `from_pretrained` test it so that `deepspeed.zero.Init` is not entered recursively.
_is_ds_init_called = False


def is_fsdp_enabled():
Expand Down Expand Up @@ -226,6 +227,19 @@ def set_quantized_state():
_is_quantized = False


# Skip recursive calls to deepspeed.zero.Init to avoid pinning errors.
# This issue occurs with ZeRO stage 3 when using NVMe offloading.
# For more details, refer to issue #34429.
@contextmanager
def set_zero3_state():
    """Flag that a `deepspeed.zero.Init` context is active while the `with` body runs.

    On entry the module-level `_is_ds_init_called` flag is set to True. On exit
    (including when the body raises) the flag is put back to the value it had
    before entry, instead of being unconditionally reset to False. This keeps
    the flag set if the context manager is entered again while already active,
    so an inner exit cannot clear the state an outer context still relies on.
    """
    global _is_ds_init_called
    # Remember the incoming state so nested use does not clobber an outer context.
    previous_state = _is_ds_init_called
    _is_ds_init_called = True
    try:
        yield
    finally:
        _is_ds_init_called = previous_state


def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
try:
return next(parameter.parameters()).device
Expand Down Expand Up @@ -1473,13 +1487,14 @@ def _from_config(cls, config, **kwargs):
torch_dtype=torch_dtype,
)

if is_deepspeed_zero3_enabled() and not _is_quantized:
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
with ContextManagers(init_contexts):
model = cls(config, **kwargs)

else:
Expand Down Expand Up @@ -4026,11 +4041,14 @@ def from_pretrained(
init_contexts = [no_init_weights(_enable=_fast_init)]
tp_device = None

if is_deepspeed_zero3_enabled() and not is_quantized:
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
init_contexts = [
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
set_zero3_state(),
] + init_contexts
elif low_cpu_mem_usage:
if not is_accelerate_available():
raise ImportError(
Expand Down

0 comments on commit d6a5c23

Please sign in to comment.