From 2a4d56c8b1481875cbfd5a6ac281bb4e329d77cf Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 25 Sep 2023 14:37:22 -0400
Subject: [PATCH] propagate bias

---
 llmfoundry/models/layers/attention.py | 15 ++++++++++++---
 llmfoundry/models/layers/blocks.py    | 14 +++++++++-----
 llmfoundry/models/layers/ffn.py       |  8 +++++++-
 llmfoundry/models/mpt/modeling_mpt.py |  5 +++++
 4 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 76969b7810..ea99ba81f8 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -419,6 +419,7 @@ def __init__(
         norm_type: str = 'low_precision_layernorm',
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        no_bias: bool = False,
     ):
         super().__init__()
 
@@ -450,7 +451,9 @@ def __init__(
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = attn_pdrop
 
-        fc_kwargs = {}
+        fc_kwargs = {
+            'bias': not no_bias,
+        }
         if fc_type != 'te':
             fc_kwargs['device'] = device
         self.Wqkv = FC_CLASS_REGISTRY[fc_type](
@@ -557,6 +560,7 @@ def __init__(
         norm_type: str = 'low_precision_layernorm',
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        no_bias: bool = False
     ):
         super().__init__(
             d_model=d_model,
@@ -569,7 +573,9 @@ def __init__(
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
             fc_type=fc_type,
-            device=device)
+            device=device,
+            no_bias=no_bias,
+        )
 
 
 class MultiQueryAttention(GroupedQueryAttention):
@@ -591,6 +597,7 @@ def __init__(
         norm_type: str = 'low_precision_layernorm',
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        no_bias: bool = False
     ):
         super().__init__(
             d_model=d_model,
@@ -603,7 +610,9 @@ def __init__(
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
             fc_type=fc_type,
-            device=device)
+            device=device,
+            no_bias=no_bias,
+        )
 
 
 def attn_bias_shape(
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 2c5b5d1c7c..6acc234fb8 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -26,6 +26,7 @@ def __init__(
         norm_type: str = 'low_precision_layernorm',
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        no_bias: bool = False,
         **kwargs: Any,
     ):
         if attn_config is None:
@@ -66,11 +67,14 @@ def __init__(
         }
 
         self.norm_1 = norm_class(d_model, device=device)
-        self.attn = attn_class(d_model=d_model,
-                               n_heads=n_heads,
-                               fc_type=fc_type,
-                               device=device,
-                               **attn_config_subset_for_attn_class)
+        self.attn = attn_class(
+            d_model=d_model,
+            n_heads=n_heads,
+            fc_type=fc_type,
+            device=device,
+            **attn_config_subset_for_attn_class,
+            no_bias=no_bias,
+        )
         self.norm_2 = None
         if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
                        False):
diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index af770a84f7..e839f7ca14 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -24,9 +24,12 @@ def __init__(
         expansion_ratio: int,
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        no_bias: bool = False,
     ):
         super().__init__()
-        fc_kwargs = {}
+        fc_kwargs = {
+            'bias': not no_bias,
+        }
         if fc_type != 'te':
             fc_kwargs['device'] = device
         self.up_proj = FC_CLASS_REGISTRY[fc_type](
@@ -60,6 +63,7 @@ def build_ffn(
     expansion_ratio: int,
     fc_type: str = 'torch',
     device: Optional[str] = None,
+    no_bias: bool = False,
     **kwargs: Any,
 ) -> nn.Module:
     ffn_type = kwargs.pop('ffn_type')
@@ -72,12 +76,14 @@ def build_ffn(
             expansion_ratio=expansion_ratio,
             fc_type=fc_type,
             device=device,
+            no_bias=no_bias,
         )
     elif ffn_type == 'te_ln_mlp':
         assert te is not None
         return te.LayerNormMLP(
             hidden_size=d_model,
             ffn_hidden_size=d_model * expansion_ratio,
+            bias=not no_bias,
             **kwargs,
         )
 
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 26d564ff8c..26e8daa85b 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -150,6 +150,11 @@ def __init__(self, config: MPTConfig):
                     log.info(f'Removing bias ({module.bias}) from {module}.')
                     module.register_parameter('bias', None)
 
+                # For transformer engine
+                if hasattr(module, 'use_bias'):
+                    log.info(f'Setting use_bias=False for {module}.')
+                    module.use_bias = False
+
         log.debug(self)
         log.debug(f'Using {self.config.init_config["name"]} initialization.')
 
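
Note (not part of the patch): a minimal sketch to sanity-check the propagated flag, assuming llmfoundry is installed with the change above and that fc_type='torch' maps to torch.nn.Linear in FC_CLASS_REGISTRY, so a module built with no_bias=True should expose bias=None on its projections.

# Illustrative check only; MPTMLP comes from the patched
# llmfoundry/models/layers/ffn.py shown above.
from llmfoundry.models.layers.ffn import MPTMLP

# With no_bias=True, fc_kwargs carries bias=False, so the torch Linear
# layers are created without bias parameters.
mlp = MPTMLP(d_model=8, expansion_ratio=4, fc_type='torch', no_bias=True)
assert mlp.up_proj.bias is None

# The default (no_bias=False) preserves the old behaviour: bias is present.
mlp_default = MPTMLP(d_model=8, expansion_ratio=4, fc_type='torch')
assert mlp_default.up_proj.bias is not None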