From 4fc2c8e7d583977b17a4ce09cb5f364cbdd229e3 Mon Sep 17 00:00:00 2001
From: Elsa Granger
Date: Thu, 21 Sep 2023 04:30:38 +0800
Subject: [PATCH] Fix llama meta tensor loading in AutoTP and kernel injected
 inference (#3608)

* Adapt to Llama when using meta tensor to load

* Fix gated mlp parameter mp

* Re-enable meta tensor for kernel injection
  Fix layer params loading in meta tensor

* Revert mlp_inter_mp for gated mlp as it is fixed

* Monkey patch for fixing llama output

* Fix formatting

* Add comment

---------

Co-authored-by: Lev Kurilenko
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
---
 deepspeed/module_inject/containers/llama.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/deepspeed/module_inject/containers/llama.py b/deepspeed/module_inject/containers/llama.py
index aa4dbbec4b8a..af99d658017c 100644
--- a/deepspeed/module_inject/containers/llama.py
+++ b/deepspeed/module_inject/containers/llama.py
@@ -4,7 +4,7 @@
 # DeepSpeed Team
 
 from .base import *
-from .features import HybridSplitQKVContainer, HybridGatedMLPContainer
+from .features import HybridSplitQKVContainer, HybridGatedMLPContainer, MetaTensorContainer
 from deepspeed.utils.types import ActivationFuncType, NormType
 from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
 import torch
@@ -20,7 +20,8 @@
 )
 
 
-class DS_LLAMAContainer(HybridGatedMLPContainer, HybridSplitQKVContainer, BaseTransformerContainer):
+class DS_LLAMAContainer(MetaTensorContainer, HybridGatedMLPContainer, HybridSplitQKVContainer,
+                        BaseTransformerContainer):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -85,8 +86,8 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
             'mlp.up_proj.weight', \
             'mlp.gate_proj.weight', \
             'mlp.down_proj.weight', \
-            'input_layernorm.weight', \
-            'post_attention_layernorm.weight'
+            'post_attention_layernorm.weight', \
+            'input_layernorm.weight',
         )
 
         maybe_copy_qkv(module.attention,
@@ -105,6 +106,10 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
         maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[8],
                    prefix + param_names[7])
         maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[10], prefix + param_names[8])
 
+        # This line is necessary for proper output when kernels + meta tensors are used in Llama models
+        # TODO: Investigate root-cause and fix meta tensor loading
+        module.mlp.output_b = None
+
 
 class LLAMALayerPolicy(TransformerPolicy):
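
Note (not part of the patch): below is a minimal sketch of the user-facing path this fix targets, namely building a Llama model on the meta device and letting DeepSpeed kernel injection fill in the real weights through the load_params method patched above. The model name, tp_size, and checkpoints.json path are illustrative assumptions, not values taken from the PR.

    # Hedged sketch: meta-tensor loading + kernel injection for a Llama model.
    # "huggyllama/llama-7b", tp_size=2, and "checkpoints.json" are placeholder
    # assumptions for illustration; they are not specified by the patch.
    import torch
    import deepspeed
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("huggyllama/llama-7b")

    # Instantiate the model on the meta device so no real weights are allocated yet;
    # the MetaTensorContainer added by this patch lets DeepSpeed copy the checkpoint
    # weights in later via DS_LLAMAContainer.load_params.
    with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

    # Kernel injection plus a checkpoint JSON drives the load_params path patched above.
    engine = deepspeed.init_inference(model,
                                      tensor_parallel={"tp_size": 2},
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=True,
                                      checkpoint="checkpoints.json")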