From 547b313790fd1d4f3b80269617f740342d60df6f Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 25 Sep 2023 21:18:49 -0400 Subject: [PATCH 1/4] Propagate bias through model (#627) --- llmfoundry/models/layers/attention.py | 17 +++++++++++++---- llmfoundry/models/layers/blocks.py | 15 ++++++++++----- llmfoundry/models/layers/ffn.py | 8 +++++++- llmfoundry/models/mpt/modeling_mpt.py | 5 +++++ setup.py | 2 +- 5 files changed, 36 insertions(+), 11 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 76969b7810..da21000cf3 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -5,7 +5,7 @@ import math import warnings -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch import torch.nn as nn @@ -419,6 +419,7 @@ def __init__( norm_type: str = 'low_precision_layernorm', fc_type: str = 'torch', device: Optional[str] = None, + bias: bool = True, ): super().__init__() @@ -450,7 +451,9 @@ def __init__( self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = attn_pdrop - fc_kwargs = {} + fc_kwargs: dict[str, Any] = { + 'bias': bias, + } if fc_type != 'te': fc_kwargs['device'] = device self.Wqkv = FC_CLASS_REGISTRY[fc_type]( @@ -557,6 +560,7 @@ def __init__( norm_type: str = 'low_precision_layernorm', fc_type: str = 'torch', device: Optional[str] = None, + bias: bool = True, ): super().__init__( d_model=d_model, @@ -569,7 +573,9 @@ def __init__( attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, - device=device) + device=device, + bias=bias, + ) class MultiQueryAttention(GroupedQueryAttention): @@ -591,6 +597,7 @@ def __init__( norm_type: str = 'low_precision_layernorm', fc_type: str = 'torch', device: Optional[str] = None, + bias: bool = True, ): super().__init__( d_model=d_model, @@ -603,7 +610,9 @@ def __init__( attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, - device=device) + device=device, + bias=bias, + ) def attn_bias_shape( diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 2c5b5d1c7c..a08ef6d77f 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -26,6 +26,7 @@ def __init__( norm_type: str = 'low_precision_layernorm', fc_type: str = 'torch', device: Optional[str] = None, + no_bias: bool = False, **kwargs: Any, ): if attn_config is None: @@ -66,11 +67,14 @@ def __init__( } self.norm_1 = norm_class(d_model, device=device) - self.attn = attn_class(d_model=d_model, - n_heads=n_heads, - fc_type=fc_type, - device=device, - **attn_config_subset_for_attn_class) + self.attn = attn_class( + d_model=d_model, + n_heads=n_heads, + fc_type=fc_type, + device=device, + **attn_config_subset_for_attn_class, + bias=not no_bias, + ) self.norm_2 = None if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False): @@ -79,6 +83,7 @@ def __init__( d_model=d_model, expansion_ratio=expansion_ratio, device=device, + bias=not no_bias, **ffn_config, ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index af770a84f7..2f6d05f424 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -24,9 +24,12 @@ def __init__( expansion_ratio: int, fc_type: str = 'torch', device: Optional[str] = None, + bias: bool = True, ): super().__init__() - fc_kwargs = {} + fc_kwargs: dict[str, Any] = { + 'bias': bias, + } if fc_type != 'te': fc_kwargs['device'] = device self.up_proj = FC_CLASS_REGISTRY[fc_type]( @@ -60,6 +63,7 @@ def build_ffn( expansion_ratio: int, fc_type: str = 'torch', device: Optional[str] = None, + bias: bool = True, **kwargs: Any, ) -> nn.Module: ffn_type = kwargs.pop('ffn_type') @@ -72,12 +76,14 @@ def build_ffn( expansion_ratio=expansion_ratio, fc_type=fc_type, device=device, + bias=bias, ) elif ffn_type == 'te_ln_mlp': assert te is not None return te.LayerNormMLP( hidden_size=d_model, ffn_hidden_size=d_model * expansion_ratio, + bias=bias, **kwargs, ) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 26d564ff8c..26e8daa85b 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -150,6 +150,11 @@ def __init__(self, config: MPTConfig): log.info(f'Removing bias ({module.bias}) from {module}.') module.register_parameter('bias', None) + # For transformer engine + if hasattr(module, 'use_bias'): + log.info(f'Setting use_bias=False for {module}.') + module.use_bias = False + log.debug(self) log.debug(f'Using {self.config.init_config["name"]} initialization.') diff --git a/setup.py b/setup.py index 772dda98a4..df3930c405 100644 --- a/setup.py +++ b/setup.py @@ -89,7 +89,7 @@ 'flash-attn==1.0.9', 'mosaicml-turbo==0.0.4', # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI - 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy', + 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.9#subdirectory=csrc/xentropy', ] extra_deps['peft'] = [ From 61dfbd6e93452c5ba8972fa921b914f7e02e1725 Mon Sep 17 00:00:00 2001 From: Sasha Doubov Date: Mon, 25 Sep 2023 21:46:06 -0400 Subject: [PATCH 2/4] Change repeat to expand in GQA (#628) --- llmfoundry/models/layers/attention.py | 43 +++++++++++++++++++++------ 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index da21000cf3..bea6284fb5 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -31,6 +31,23 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, return original_is_causal +def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor: + """Perform repeat of kv heads along a particular dimension. + + hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim) + n_rep: amount of repetitions of kv_n_heads + Unlike torch.repeat_interleave, this function avoids allocating new memory. + """ + if n_rep == 1: + return hidden + + b, s, kv_n_heads, d = hidden.shape + + hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d) + + return hidden.reshape(b, s, kv_n_heads * n_rep, d) + + def scaled_multihead_dot_product_attention( query: torch.Tensor, key: torch.Tensor, @@ -84,8 +101,11 @@ def scaled_multihead_dot_product_attention( # grouped query case if kv_n_heads > 1 and kv_n_heads < n_heads: - k = k.repeat_interleave(n_heads // kv_n_heads, dim=1) - v = v.repeat_interleave(n_heads // kv_n_heads, dim=1) + # necessary to do a transpose to swap (b h s d) -> (b s h d) for repeat_kv_for_gqa function + k = repeat_kv_for_gqa(k.transpose(1, 2), + n_heads // kv_n_heads).transpose(1, 2) + v = repeat_kv_for_gqa(v.transpose(1, 2), + n_heads // kv_n_heads).transpose(1, 2) if softmax_scale is None: softmax_scale = 1 / math.sqrt(d) @@ -243,10 +263,16 @@ def flash_attn_fn( elif kv_n_heads < n_heads: # Each query belong to a group of kv heads of group size n_heads // kv_n_heads # We repeat each kv head by the group size number to use the underlying MHA kernels - # done along the head dimension = 1 - key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1) - value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads, - dim=1) + + # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d) + # we use .view to modify {key, value}_unpad appropriately + + key_unpad = repeat_kv_for_gqa( + key_unpad.view(batch_size, seqlen, kv_n_heads, -1), + n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1) + value_unpad = repeat_kv_for_gqa( + value_unpad.view(batch_size, seqlen, kv_n_heads, -1), + n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1) dropout_p = dropout_p if training else 0.0 @@ -383,9 +409,8 @@ def triton_flash_attn_fn( elif kv_n_heads < n_heads: # Each query belong to a group of kv heads of group size n_heads // kv_n_heads # We repeat each kv head by the group size number to use the underlying MHA kernels - # done along dim = 2, unlike the implementation for flash and torch attn - key = key.repeat_interleave(n_heads // kv_n_heads, dim=2) - value = value.repeat_interleave(n_heads // kv_n_heads, dim=2) + key = repeat_kv_for_gqa(key, n_heads // kv_n_heads) + value = repeat_kv_for_gqa(value, n_heads // kv_n_heads) reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) attn_output = flash_attn_func( # type: ignore From 36b8ccbf78425f3d97402a093860eb585ad4cb25 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 26 Sep 2023 14:07:36 -0400 Subject: [PATCH 3/4] Add node rank to signal paths (#629) * add node rank * lint --- llmfoundry/data/finetuning/dataloader.py | 3 ++- llmfoundry/models/hf/hf_causal_lm.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 5d4dfdbf85..ebb7991dde 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -313,7 +313,8 @@ def _build_hf_dataset_from_remote( # Since we don't know exactly what the extension will be, since it is one of a list # use a signal file to wait for instead of the desired file - signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed') + signal_file_path = os.path.join( + finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed') if dist.get_local_rank() == 0: try: get_file(path=name, destination=destination, overwrite=True) diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 3100478a27..bf6f5288e4 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -164,7 +164,7 @@ def __init__(self, om_model_config: Union[DictConfig, f'init_device="{init_device}" must be either "cpu" or "meta".' ) - signal_file_path = '.local_rank0_completed_autoresume' + signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed' if dist.get_local_rank() == 0: with open(signal_file_path, 'wb') as f: f.write(b'local_rank0_completed_download') From 7a8c1a5126a456bd34e4fb74bc01b7c929fb221a Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Tue, 26 Sep 2023 14:53:00 -0700 Subject: [PATCH 4/4] bump composer version (#630) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index df3930c405..be5b6708a3 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.1,<0.17', + 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.3,<0.17', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'transformers>=4.33,<4.34', 'mosaicml-streaming>=0.6,<0.7',