diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 76969b7810..da21000cf3 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -5,7 +5,7 @@
 
 import math
 import warnings
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -419,6 +419,7 @@ def __init__(
         norm_type: str = 'low_precision_layernorm',
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        bias: bool = True,
     ):
         super().__init__()
 
@@ -450,7 +451,9 @@ def __init__(
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = attn_pdrop
 
-        fc_kwargs = {}
+        fc_kwargs: dict[str, Any] = {
+            'bias': bias,
+        }
         if fc_type != 'te':
             fc_kwargs['device'] = device
         self.Wqkv = FC_CLASS_REGISTRY[fc_type](
@@ -557,6 +560,7 @@ def __init__(
         norm_type: str = 'low_precision_layernorm',
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        bias: bool = True,
     ):
         super().__init__(
             d_model=d_model,
@@ -569,7 +573,9 @@ def __init__(
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
             fc_type=fc_type,
-            device=device)
+            device=device,
+            bias=bias,
+        )
 
 
 class MultiQueryAttention(GroupedQueryAttention):
@@ -591,6 +597,7 @@ def __init__(
         norm_type: str = 'low_precision_layernorm',
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        bias: bool = True,
     ):
         super().__init__(
             d_model=d_model,
@@ -603,7 +610,9 @@ def __init__(
             attn_pdrop=attn_pdrop,
             norm_type=norm_type,
             fc_type=fc_type,
-            device=device)
+            device=device,
+            bias=bias,
+        )
 
 
 def attn_bias_shape(
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 2c5b5d1c7c..a08ef6d77f 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -26,6 +26,7 @@ def __init__(
         norm_type: str = 'low_precision_layernorm',
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        no_bias: bool = False,
         **kwargs: Any,
     ):
         if attn_config is None:
@@ -66,11 +67,14 @@ def __init__(
         }
 
         self.norm_1 = norm_class(d_model, device=device)
-        self.attn = attn_class(d_model=d_model,
-                               n_heads=n_heads,
-                               fc_type=fc_type,
-                               device=device,
-                               **attn_config_subset_for_attn_class)
+        self.attn = attn_class(
+            d_model=d_model,
+            n_heads=n_heads,
+            fc_type=fc_type,
+            device=device,
+            **attn_config_subset_for_attn_class,
+            bias=not no_bias,
+        )
         self.norm_2 = None
         if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
                        False):
@@ -79,6 +83,7 @@
             d_model=d_model,
             expansion_ratio=expansion_ratio,
             device=device,
+            bias=not no_bias,
             **ffn_config,
         )
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index af770a84f7..2f6d05f424 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -24,9 +24,12 @@ def __init__(
         expansion_ratio: int,
         fc_type: str = 'torch',
         device: Optional[str] = None,
+        bias: bool = True,
     ):
         super().__init__()
-        fc_kwargs = {}
+        fc_kwargs: dict[str, Any] = {
+            'bias': bias,
+        }
         if fc_type != 'te':
             fc_kwargs['device'] = device
         self.up_proj = FC_CLASS_REGISTRY[fc_type](
@@ -60,6 +63,7 @@ def build_ffn(
     expansion_ratio: int,
     fc_type: str = 'torch',
     device: Optional[str] = None,
+    bias: bool = True,
     **kwargs: Any,
 ) -> nn.Module:
     ffn_type = kwargs.pop('ffn_type')
@@ -72,12 +76,14 @@ def build_ffn(
             expansion_ratio=expansion_ratio,
             fc_type=fc_type,
             device=device,
+            bias=bias,
         )
     elif ffn_type == 'te_ln_mlp':
         assert te is not None
         return te.LayerNormMLP(
             hidden_size=d_model,
             ffn_hidden_size=d_model * expansion_ratio,
+            bias=bias,
             **kwargs,
         )
 
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 26d564ff8c..26e8daa85b 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -150,6 +150,11 @@ def __init__(self, config: MPTConfig):
                     log.info(f'Removing bias ({module.bias}) from {module}.')
                     module.register_parameter('bias', None)
 
+                # For transformer engine
+                if hasattr(module, 'use_bias'):
+                    log.info(f'Setting use_bias=False for {module}.')
+                    module.use_bias = False
+
         log.debug(self)
         log.debug(f'Using {self.config.init_config["name"]} initialization.')
 
diff --git a/setup.py b/setup.py
index 772dda98a4..df3930c405 100644
--- a/setup.py
+++ b/setup.py
@@ -89,7 +89,7 @@
     'flash-attn==1.0.9',
     'mosaicml-turbo==0.0.4',
     # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
-    'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy',
+    'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.9#subdirectory=csrc/xentropy',
 ]
 
 extra_deps['peft'] = [
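
A minimal usage sketch (not part of the diff) of what this change enables: the existing `no_bias` option on `MPTConfig` is now threaded into the attention and FFN constructors as `bias=not no_bias`, so linear layers are built without a bias term rather than only having it stripped after construction. The hyperparameter values and the `attn_impl='torch'` choice below are illustrative, not prescribed by the PR.

import torch

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Small illustrative config; no_bias=True now reaches Wqkv/out_proj and the FFN
# projections as bias=False at construction time.
config = MPTConfig(
    d_model=256,
    n_heads=8,
    n_layers=2,
    max_seq_len=128,
    vocab_size=50368,
    no_bias=True,
    attn_config={'attn_impl': 'torch'},  # CPU-friendly for this sketch
)
model = MPTForCausalLM(config)

# Every torch Linear in the model should end up without a bias parameter.
assert all(m.bias is None
           for m in model.modules()
           if isinstance(m, torch.nn.Linear))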