Skip to content

Commit

Permalink
Propagate bias through model (#627)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 authored Sep 26, 2023
1 parent f488ad5 commit 547b313
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 11 deletions.
17 changes: 13 additions & 4 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import math
import warnings
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -419,6 +419,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__()

Expand Down Expand Up @@ -450,7 +451,9 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

fc_kwargs = {}
fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
Expand Down Expand Up @@ -557,6 +560,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__(
d_model=d_model,
Expand All @@ -569,7 +573,9 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
device=device)
device=device,
bias=bias,
)


class MultiQueryAttention(GroupedQueryAttention):
Expand All @@ -591,6 +597,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__(
d_model=d_model,
Expand All @@ -603,7 +610,9 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
device=device)
device=device,
bias=bias,
)


def attn_bias_shape(
Expand Down
15 changes: 10 additions & 5 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
no_bias: bool = False,
**kwargs: Any,
):
if attn_config is None:
Expand Down Expand Up @@ -66,11 +67,14 @@ def __init__(
}

self.norm_1 = norm_class(d_model, device=device)
self.attn = attn_class(d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class,
bias=not no_bias,
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
Expand All @@ -79,6 +83,7 @@ def __init__(
d_model=d_model,
expansion_ratio=expansion_ratio,
device=device,
bias=not no_bias,
**ffn_config,
)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
Expand Down
8 changes: 7 additions & 1 deletion llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ def __init__(
expansion_ratio: int,
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__()
fc_kwargs = {}
fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
self.up_proj = FC_CLASS_REGISTRY[fc_type](
Expand Down Expand Up @@ -60,6 +63,7 @@ def build_ffn(
expansion_ratio: int,
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
**kwargs: Any,
) -> nn.Module:
ffn_type = kwargs.pop('ffn_type')
Expand All @@ -72,12 +76,14 @@ def build_ffn(
expansion_ratio=expansion_ratio,
fc_type=fc_type,
device=device,
bias=bias,
)
elif ffn_type == 'te_ln_mlp':
assert te is not None
return te.LayerNormMLP(
hidden_size=d_model,
ffn_hidden_size=d_model * expansion_ratio,
bias=bias,
**kwargs,
)

Expand Down
5 changes: 5 additions & 0 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ def __init__(self, config: MPTConfig):
log.info(f'Removing bias ({module.bias}) from {module}.')
module.register_parameter('bias', None)

# For transformer engine
if hasattr(module, 'use_bias'):
log.info(f'Setting use_bias=False for {module}.')
module.use_bias = False

log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
'flash-attn==1.0.9',
'mosaicml-turbo==0.0.4',
# PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected].3#subdirectory=csrc/xentropy',
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected].9#subdirectory=csrc/xentropy',
]

extra_deps['peft'] = [
Expand Down

0 comments on commit 547b313

Please sign in to comment.