Skip DeepSpeed ZeRO Stage 3 model initialization when bnb (#34395)
* Skip DeepSpeed ZeRO Stage 3 model initialization when it is intended to be quantized.

* Propagate the quantization state using a context manager

* make fixup
eljandoubi authored Nov 5, 2024
1 parent eb81144 commit d0b1d8d
Showing 1 changed file with 15 additions and 1 deletion.
src/transformers/modeling_utils.py (15 additions & 1 deletion)
@@ -136,6 +136,7 @@
 
 
 _init_weights = True
+_is_quantized = False
 
 
 def is_fsdp_enabled():
@@ -213,6 +214,16 @@ def _skip_init(*args, **kwargs):
                 setattr(torch.nn.init, name, init_func)
 
 
+@contextmanager
+def set_quantized_state():
+    global _is_quantized
+    _is_quantized = True
+    try:
+        yield
+    finally:
+        _is_quantized = False
+
+
 def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
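
The new context manager above propagates a transient "this model will be quantized" signal through a module-level flag, so code deeper in the call stack can read it without threading an extra argument through every function. A minimal, self-contained sketch of the same pattern (the names `build_model` and the print messages are illustrative, not part of the library):

```python
from contextlib import contextmanager

_is_quantized = False  # module-level flag read by code deeper in the call stack


@contextmanager
def set_quantized_state():
    """Temporarily mark that the model being built will be quantized."""
    global _is_quantized
    _is_quantized = True
    try:
        yield
    finally:
        # Always reset, even if model construction raises.
        _is_quantized = False


def build_model():
    # Stand-in for _from_config(): decide whether to use ZeRO-3 zero.init().
    if _is_quantized:
        print("quantized: constructing without zero.Init()")
    else:
        print("would wrap construction in deepspeed.zero.Init()")


with set_quantized_state():
    build_model()  # prints the quantized branch
build_model()      # flag restored, prints the zero.Init() branch
```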
@@ -1531,7 +1542,7 @@ def _from_config(cls, config, **kwargs):
                 torch_dtype=torch_dtype,
             )
 
-        if is_deepspeed_zero3_enabled():
+        if is_deepspeed_zero3_enabled() and not _is_quantized:
             import deepspeed
 
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
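
This guard decides whether model construction in `_from_config` is wrapped in DeepSpeed's ZeRO-3 partitioned initialization; when the model is about to be quantized, the full weights are kept available instead. A hedged sketch of that branching, assuming `deepspeed` and `transformers` are installed (the helper name `build_under_zero3` and the `will_be_quantized` parameter are illustrative):

```python
from contextlib import nullcontext


def build_under_zero3(model_cls, config, zero3_enabled: bool, will_be_quantized: bool):
    """Construct a model, using ZeRO-3 partitioned init only when appropriate."""
    if zero3_enabled and not will_be_quantized:
        import deepspeed
        from transformers.integrations.deepspeed import deepspeed_config

        # Parameters created inside this context are partitioned across ranks.
        init_ctx = deepspeed.zero.Init(config_dict_or_path=deepspeed_config())
    else:
        # Quantized loading needs the unpartitioned weights available.
        init_ctx = nullcontext()

    with init_ctx:
        return model_cls(config)
```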
@@ -4086,6 +4097,9 @@ def from_pretrained(
                 )
             init_contexts.append(init_empty_weights())
 
+        if is_deepspeed_zero3_enabled() and is_quantized:
+            init_contexts.append(set_quantized_state())
+
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
         if not getattr(config, "_attn_implementation_autoset", False):
             config = cls._autoset_attn_implementation(
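
The appended `set_quantized_state()` is later entered together with the other `init_contexts` for the duration of model instantiation, which is how the flag set here becomes visible inside `_from_config`. A minimal sketch of that "enter a list of context managers" pattern using `contextlib.ExitStack`; the `tag` helper is illustrative only (transformers uses its own helper for this purpose):

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def tag(name, log):
    log.append(f"enter {name}")
    try:
        yield
    finally:
        log.append(f"exit {name}")


log = []
init_contexts = [tag("no_init_weights", log), tag("set_quantized_state", log)]

# Enter every collected context manager for the duration of model construction.
with ExitStack() as stack:
    for ctx in init_contexts:
        stack.enter_context(ctx)
    log.append("construct model here")

print(log)
# ['enter no_init_weights', 'enter set_quantized_state',
#  'construct model here',
#  'exit set_quantized_state', 'exit no_init_weights']
```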
