Skip to content

Commit

Permalink
Merge branch 'main' into release/v0.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Sep 26, 2023
2 parents 779d8af + 7a8c1a5 commit 2a9b8ac
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 23 deletions.
3 changes: 2 additions & 1 deletion llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ def _build_hf_dataset_from_remote(

# Since we don't know exactly what the extension will be, since it is one of a list
# use a signal file to wait for instead of the desired file
signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed')
signal_file_path = os.path.join(
finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed')
if dist.get_local_rank() == 0:
try:
get_file(path=name, destination=destination, overwrite=True)
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(self, om_model_config: Union[DictConfig,
f'init_device="{init_device}" must be either "cpu" or "meta".'
)

signal_file_path = '.local_rank0_completed_autoresume'
signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')
Expand Down
60 changes: 47 additions & 13 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import math
import warnings
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
Expand All @@ -31,6 +31,23 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
return original_is_causal


def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
"""Perform repeat of kv heads along a particular dimension.
hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
n_rep: amount of repetitions of kv_n_heads
Unlike torch.repeat_interleave, this function avoids allocating new memory.
"""
if n_rep == 1:
return hidden

b, s, kv_n_heads, d = hidden.shape

hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)

return hidden.reshape(b, s, kv_n_heads * n_rep, d)


def scaled_multihead_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
Expand Down Expand Up @@ -84,8 +101,11 @@ def scaled_multihead_dot_product_attention(

# grouped query case
if kv_n_heads > 1 and kv_n_heads < n_heads:
k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)
# necessary to do a transpose to swap (b h s d) -> (b s h d) for repeat_kv_for_gqa function
k = repeat_kv_for_gqa(k.transpose(1, 2),
n_heads // kv_n_heads).transpose(1, 2)
v = repeat_kv_for_gqa(v.transpose(1, 2),
n_heads // kv_n_heads).transpose(1, 2)

if softmax_scale is None:
softmax_scale = 1 / math.sqrt(d)
Expand Down Expand Up @@ -243,10 +263,16 @@ def flash_attn_fn(
elif kv_n_heads < n_heads:
# Each query belong to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels
# done along the head dimension = 1
key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1)
value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads,
dim=1)

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)

dropout_p = dropout_p if training else 0.0

Expand Down Expand Up @@ -383,9 +409,8 @@ def triton_flash_attn_fn(
elif kv_n_heads < n_heads:
# Each query belong to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels
# done along dim = 2, unlike the implementation for flash and torch attn
key = key.repeat_interleave(n_heads // kv_n_heads, dim=2)
value = value.repeat_interleave(n_heads // kv_n_heads, dim=2)
key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
attn_output = flash_attn_func( # type: ignore
Expand Down Expand Up @@ -419,6 +444,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__()

Expand Down Expand Up @@ -450,7 +476,9 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

fc_kwargs = {}
fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
Expand Down Expand Up @@ -557,6 +585,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__(
d_model=d_model,
Expand All @@ -569,7 +598,9 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
device=device)
device=device,
bias=bias,
)


class MultiQueryAttention(GroupedQueryAttention):
Expand All @@ -591,6 +622,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__(
d_model=d_model,
Expand All @@ -603,7 +635,9 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
device=device)
device=device,
bias=bias,
)


def attn_bias_shape(
Expand Down
15 changes: 10 additions & 5 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
no_bias: bool = False,
**kwargs: Any,
):
if attn_config is None:
Expand Down Expand Up @@ -66,11 +67,14 @@ def __init__(
}

self.norm_1 = norm_class(d_model, device=device)
self.attn = attn_class(d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class,
bias=not no_bias,
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
Expand All @@ -79,6 +83,7 @@ def __init__(
d_model=d_model,
expansion_ratio=expansion_ratio,
device=device,
bias=not no_bias,
**ffn_config,
)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
Expand Down
8 changes: 7 additions & 1 deletion llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ def __init__(
expansion_ratio: int,
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__()
fc_kwargs = {}
fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
self.up_proj = FC_CLASS_REGISTRY[fc_type](
Expand Down Expand Up @@ -60,6 +63,7 @@ def build_ffn(
expansion_ratio: int,
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
**kwargs: Any,
) -> nn.Module:
ffn_type = kwargs.pop('ffn_type')
Expand All @@ -72,12 +76,14 @@ def build_ffn(
expansion_ratio=expansion_ratio,
fc_type=fc_type,
device=device,
bias=bias,
)
elif ffn_type == 'te_ln_mlp':
assert te is not None
return te.LayerNormMLP(
hidden_size=d_model,
ffn_hidden_size=d_model * expansion_ratio,
bias=bias,
**kwargs,
)

Expand Down
5 changes: 5 additions & 0 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ def __init__(self, config: MPTConfig):
log.info(f'Removing bias ({module.bias}) from {module}.')
module.register_parameter('bias', None)

# For transformer engine
if hasattr(module, 'use_bias'):
log.info(f'Setting use_bias=False for {module}.')
module.use_bias = False

log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
]

install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.1,<0.17',
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.3,<0.17',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.33,<4.34',
'mosaicml-streaming>=0.6,<0.7',
Expand Down Expand Up @@ -89,7 +89,7 @@
'flash-attn==1.0.9',
'mosaicml-turbo==0.0.4',
# PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected].3#subdirectory=csrc/xentropy',
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected].9#subdirectory=csrc/xentropy',
]

extra_deps['peft'] = [
Expand Down

0 comments on commit 2a9b8ac

Please sign in to comment.