Fix llama meta tensor loading in AutoTP and kernel injected inference (#3608)

* Adapt to Llama when using meta tensor to load

* Fix gated mlp parameter mp

* Re-enable meta tensor for kernel injection
Fix layer params loading in meta tensor

* Revert mlp_inter_mp for gated mlp as it is fixed

* Monkey patch for fixing llama output

* Fix formatting

* Add comment

---------

Co-authored-by: Lev Kurilenko <[email protected]>
Co-authored-by: Lev Kurilenko <[email protected]>
3 people authored Sep 20, 2023
1 parent 463dea2 commit 4fc2c8e
Showing 1 changed file with 9 additions and 4 deletions.
deepspeed/module_inject/containers/llama.py (9 additions, 4 deletions)
@@ -4,7 +4,7 @@
 # DeepSpeed Team
 
 from .base import *
-from .features import HybridSplitQKVContainer, HybridGatedMLPContainer
+from .features import HybridSplitQKVContainer, HybridGatedMLPContainer, MetaTensorContainer
 from deepspeed.utils.types import ActivationFuncType, NormType
 from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
 import torch
@@ -20,7 +20,8 @@
 )
 
 
-class DS_LLAMAContainer(HybridGatedMLPContainer, HybridSplitQKVContainer, BaseTransformerContainer):
+class DS_LLAMAContainer(MetaTensorContainer, HybridGatedMLPContainer, HybridSplitQKVContainer,
+                        BaseTransformerContainer):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -85,8 +86,8 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
             'mlp.up_proj.weight', \
             'mlp.gate_proj.weight', \
             'mlp.down_proj.weight', \
-            'input_layernorm.weight', \
-            'post_attention_layernorm.weight'
+            'post_attention_layernorm.weight', \
+            'input_layernorm.weight',
         )
 
         maybe_copy_qkv(module.attention,
@@ -105,6 +106,10 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
         maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[8], prefix + param_names[7])
         maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[10], prefix + param_names[8])
 
+        # This line is necessary for proper output when kernels + meta tensors are used in Llama models
+        # TODO: Investigate root-cause and fix meta tensor loading
+        module.mlp.output_b = None
+
 
 class LLAMALayerPolicy(TransformerPolicy):
 
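For context, this change targets the meta-tensor loading path: the Llama model is first constructed with meta (weight-free) tensors, and DeepSpeed later materializes the real parameters from the checkpoint during kernel injection, which is when load_params above runs. Below is a minimal sketch of that flow, assuming a standard Hugging Face Llama checkpoint; the model name, tensor-parallel size, and checkpoint-description file are placeholders (not part of this commit), and init_inference argument names may differ across DeepSpeed versions.

import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder checkpoint name; any Hugging Face Llama checkpoint would do here.
model_name = "huggyllama/llama-7b"

# Build the model on the meta device: no real weight storage is allocated, so
# even large Llama checkpoints can be constructed quickly on every rank.
config = AutoConfig.from_pretrained(model_name)
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config)

# init_inference swaps in the DeepSpeed transformer containers (DS_LLAMAContainer
# above) and copies the real weights from the checkpoint files while sharding
# them across tensor-parallel ranks.
model = deepspeed.init_inference(
    model,
    mp_size=2,                        # tensor-parallel degree (placeholder)
    dtype=torch.float16,
    replace_with_kernel_inject=True,  # the kernel-injection path touched by this fix
    checkpoint="checkpoints.json",    # JSON describing the original weight files
)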
