diff --git a/deepspeed/module_inject/containers/llama.py b/deepspeed/module_inject/containers/llama.py
index aa4dbbec4b8a..af99d658017c 100644
--- a/deepspeed/module_inject/containers/llama.py
+++ b/deepspeed/module_inject/containers/llama.py
@@ -4,7 +4,7 @@
 # DeepSpeed Team
 
 from .base import *
-from .features import HybridSplitQKVContainer, HybridGatedMLPContainer
+from .features import HybridSplitQKVContainer, HybridGatedMLPContainer, MetaTensorContainer
 from deepspeed.utils.types import ActivationFuncType, NormType
 from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
 import torch
@@ -20,7 +20,8 @@
 )
 
 
-class DS_LLAMAContainer(HybridGatedMLPContainer, HybridSplitQKVContainer, BaseTransformerContainer):
+class DS_LLAMAContainer(MetaTensorContainer, HybridGatedMLPContainer, HybridSplitQKVContainer,
+                        BaseTransformerContainer):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -85,8 +86,8 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
             'mlp.up_proj.weight', \
             'mlp.gate_proj.weight', \
             'mlp.down_proj.weight', \
-            'input_layernorm.weight', \
-            'post_attention_layernorm.weight'
+            'post_attention_layernorm.weight', \
+            'input_layernorm.weight',
         )
 
         maybe_copy_qkv(module.attention,
@@ -105,6 +106,10 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
 
         maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[8], prefix + param_names[7])
         maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[10], prefix + param_names[8])
 
+        # This line is necessary for proper output when kernels + meta tensors are used in Llama models
+        # TODO: Investigate root-cause and fix meta tensor loading
+        module.mlp.output_b = None
+
 
 class LLAMALayerPolicy(TransformerPolicy):
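
For context, a minimal sketch of the inference path this change targets: building a Llama model on meta tensors and letting DeepSpeed's kernel injection (which routes Llama layers through DS_LLAMAContainer) materialize the real weights at load time. This is an assumed usage example, not part of the PR; the model id and the "checkpoints.json" path are illustrative placeholders.

```python
# Hypothetical usage sketch: meta-tensor init + kernel injection, the combination
# that the MetaTensorContainer support and the output_b workaround above address.
import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("huggyllama/llama-7b")  # illustrative model id

# Construct the model on the meta device so no real weights are allocated yet.
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config)

# Kernel injection replaces the transformer layers via LLAMALayerPolicy / DS_LLAMAContainer;
# "checkpoints.json" is a placeholder DeepSpeed checkpoint description that supplies
# the real weights for the meta-initialized parameters.
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": 1},
                                  dtype=torch.float16,
                                  replace_with_kernel_inject=True,
                                  checkpoint="checkpoints.json")
```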