From 884b85002bd23782d9c6964083aebad32835d26a Mon Sep 17 00:00:00 2001
From: tanconghui <6702866+tanconghui@users.noreply.github.com>
Date: Tue, 12 Mar 2024 11:46:10 +0800
Subject: [PATCH] disable torch.nn.init when counting parameters in
 initializing PipelineModule

---
 deepspeed/runtime/pipe/module.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index c11379b0a0d7..6cf0d847fd62 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -3,6 +3,7 @@
 
 # DeepSpeed Team
 
+import contextlib
 import os
 import glob
 
@@ -83,6 +84,24 @@ def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_att
         self.tied_weight_attr = [tied_weight_attr] if type(tied_weight_attr) == str else tied_weight_attr
 
 
+if hasattr(torch.overrides, "TorchFunctionMode"):
+
+    class _DisableInit(torch.overrides.TorchFunctionMode):
+
+        def __torch_function__(self, func, types, args=(), kwargs=None):
+            kwargs = kwargs or {}
+            if getattr(func, '__module__', None) == 'torch.nn.init':
+                if 'tensor' in kwargs:
+                    return kwargs['tensor']
+                else:
+                    return args[0]
+            else:
+                return func(*args, **kwargs)
+
+else:
+    _DisableInit = contextlib.suppress
+
+
 class PipelineModule(nn.Module):
     """Modules to be parallelized with pipeline parallelism.
 
@@ -269,7 +288,8 @@ def _get_frozen_parameter_names(self, layer):
             A list of frozen parameter names
         """
         if isinstance(layer, LayerSpec):
-            l = layer.build()
+            with _DisableInit():
+                l = layer.build()
             return [n for n, p in l.named_parameters() if not p.requires_grad]
         elif isinstance(layer, nn.Module):
             return [n for n, p in layer.named_parameters() if not p.requires_grad]
@@ -287,7 +307,8 @@ def _count_layer_params(self):
         param_counts = [0] * len(self._layer_specs)
         for idx, layer in enumerate(self._layer_specs):
            if isinstance(layer, LayerSpec):
-                l = layer.build()
+                with _DisableInit():
+                    l = layer.build()
                 params = filter(lambda p: p.requires_grad, l.parameters())
                 param_counts[idx] = sum(p.numel() for p in params)
             elif isinstance(layer, nn.Module):
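
For reference, a minimal standalone sketch of the mechanism the patch relies on: an active torch.overrides.TorchFunctionMode intercepts every overridable torch call, so answering anything from torch.nn.init with its target tensor turns initialization into a no-op while layer.build() still constructs the module. The class below mirrors the one added in the patch; the nn.Linear layer and the parameter count are illustrative only, and the sketch assumes a PyTorch recent enough to expose TorchFunctionMode (the patch falls back to contextlib.suppress otherwise).

    import torch
    import torch.nn as nn

    class _DisableInit(torch.overrides.TorchFunctionMode):

        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # Any function defined in torch.nn.init is answered with its
            # target tensor, untouched; everything else dispatches normally.
            if getattr(func, '__module__', None) == 'torch.nn.init':
                return kwargs['tensor'] if 'tensor' in kwargs else args[0]
            return func(*args, **kwargs)

    with _DisableInit():
        # nn.Linear.reset_parameters() calls kaiming_uniform_ and uniform_;
        # under the mode both return immediately, so only the cheap
        # torch.empty() allocations run -- the saving the patch is after
        # when layers are built solely to count their parameters.
        layer = nn.Linear(4096, 4096)

    print(sum(p.numel() for p in layer.parameters()))  # 16781312

Returning the tensor unchanged preserves the in-place contract of torch.nn.init (callers keep working with the same tensor object), so module construction still succeeds; only the random-number generation is skipped.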