diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f679f7a190f0a0..a4de8abed03df4 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -139,6 +139,7 @@ _init_weights = True _is_quantized = False +_is_ds_init_called = False def is_fsdp_enabled(): @@ -226,6 +227,19 @@ def set_quantized_state(): _is_quantized = False +# Skip recursive calls to deepspeed.zero.Init to avoid pinning errors. +# This issue occurs with ZeRO stage 3 when using NVMe offloading. +# For more details, refer to issue #34429. +@contextmanager +def set_zero3_state(): + global _is_ds_init_called + _is_ds_init_called = True + try: + yield + finally: + _is_ds_init_called = False + + def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]): try: return next(parameter.parameters()).device @@ -1473,13 +1487,14 @@ def _from_config(cls, config, **kwargs): torch_dtype=torch_dtype, ) - if is_deepspeed_zero3_enabled() and not _is_quantized: + if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called: import deepspeed logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") # this immediately partitions the model across all gpus, to avoid the overhead in time # and memory copying it on CPU or each GPU first - with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()): + init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()] + with ContextManagers(init_contexts): model = cls(config, **kwargs) else: @@ -4026,11 +4041,14 @@ def from_pretrained( init_contexts = [no_init_weights(_enable=_fast_init)] tp_device = None - if is_deepspeed_zero3_enabled() and not is_quantized: + if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called: import deepspeed logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") - init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts + init_contexts = [ + deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), + set_zero3_state(), + ] + init_contexts elif low_cpu_mem_usage: if not is_accelerate_available(): raise ImportError(