Adding support for Rotary Position Embeddings (#675)

* removed the RoFormer implementation of RoPE

* fixed all the lint errors

* added a unit test for rotary embeddings

* Update llmfoundry/models/mpt/modeling_mpt.py

Accepting the suggestion

Co-authored-by: Vitaliy Chiley <[email protected]>

* incorporated some suggestions from the PR

* added mark for gpu in the rotary embedding test

* removed the code for the HF implementation of RoPE

* added tests

* fixed the tests after the merge

* minor change

* Fixed some tests failing due to a transformers library bug

* added check for flash_attention before importing their rotary embedding

* added check for flash_attention in tests before using dail rope

* fixed tests

* temporary fix

* fixed a test

* minor change

* minor changes

* added documentation

* added documentation

* temp commit

* made _set_config_defaults recursive

* minor changes

* reformatted tutorial table

* reformatted tutorial table

* reformatted tutorial table

* added documentation on how to install flash attention 2

* minor changes

* minor changes

* minor changes

* minor changes

* minor changes

* minor changes

* resolved some comments from the PR

* fixed tests

* modified is_flash_v2_installed

* minor changes

* Update TUTORIAL.md

Co-authored-by: Daniel King <[email protected]>

* Update TUTORIAL.md

Co-authored-by: Daniel King <[email protected]>

* Update TUTORIAL.md

Co-authored-by: Daniel King <[email protected]>

* Update TUTORIAL.md

Co-authored-by: Daniel King <[email protected]>

* resolved PR comments

---------

Co-authored-by: Shashank Rajput <[email protected]>
Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Daniel King <[email protected]>
4 people authored Nov 6, 2023
1 parent 58d7cf3 commit 1d504c8
Showing 8 changed files with 952 additions and 187 deletions.
49 changes: 38 additions & 11 deletions TUTORIAL.md
@@ -8,27 +8,42 @@ Forging LLMs can be quite complicated — you have to get your data prepared, se

This tutorial will provide a brief intro to the repo’s structure and underlying tools (all courtesy of MosaicML, of course), will go over a few example workflows and point you to the related resources within the repo, and will finally cover a number of FAQs that we have encountered since release.

- [LLM Foundry Tutorial](#llm-foundry-tutorial)
- [Intro](#intro)
- [How this repo is structured](#how-this-repo-is-structured)
- [Key components](#key-components)
- [Composer](#composer)
- [StreamingDataset](#streamingdataset)
- [MCLI](#mcli)
- [How the YAMLs work](#how-the-yamls-work)
- [Example Workflows](#example-workflows)
- [Workflow 1: I want to play with a HF model like MPT-7B locally](#workflow-1-i-want-to-play-with-a-hf-model-like-mpt-7b-locally)
- [Workflow 2: I want to deploy an inference endpoint with a HF model like MPT-7B](#workflow-2-i-want-to-deploy-an-inference-endpoint-with-a-hf-model-like-mpt-7b)
- [Workflow 3: I want to finetune a HF model like MPT-7B](#workflow-3-i-want-to-finetune-a-hf-model-like-mpt-7b)
- [Supervised FineTuning and Instruction FineTuning](#supervised-finetuning-and-instruction-finetuning)
- [Domain Adaptation and Sequence Length Adaptation](#domain-adaptation-and-sequence-length-adaptation)
- [Data](#data)
- [Modeling](#modeling)
- [Workflow 4: I want to train a new HF model from scratch](#workflow-4-i-want-to-train-a-new-hf-model-from-scratch)
- [FAQs](#faqs)
- [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus)
- [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do)
- [What hardware can I train on?](#what-hardware-can-i-train-on)
- [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on)
- [What is FSDP?](#what-is-fsdp)
- [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton-for-mpt-and-which-one-should-i-use)
- [Can I finetune using PEFT / LORA?](#can-i-finetune-using-peft--lora)
- [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu)
- [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer)
- [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms)
- [Common installation issues](#common-installation-issues)
- [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus)
- [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do)
- [What hardware can I train on?](#what-hardware-can-i-train-on)
- [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on)
- [What hardware can I run inference on?](#what-hardware-can-i-run-inference-on)
- [What is FSDP?](#what-is-fsdp)
- [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton--for-mpt-and-which-one-should-i-use)
- [Limitations](#limitations)
- [What is `triton-pre-mlir`?](#what-is-triton-pre-mlir)
- [Known issue with sm86+ GPUs](#known-issue-with-sm86-gpus)
- [Support for FlashAttention-2](#support-for-flashattention-2)
- [What kinds of positional embeddings does LLM Foundry support?](#what-kinds-of-positional-embeddings-does-llm-foundry-support)
- [Can I finetune using PEFT / LoRA?](#can-i-finetune-using-peft--lora)
- [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu)
- [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer)
- [TransformerEngine and amp\_fp8 support](#transformerengine-and-amp_fp8-support)
- [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms)
- [Common installation issues](#common-installation-issues)

Let’s get started!

@@ -328,6 +343,18 @@ The majority of our training setups use `triton`. -->
Updating to LLVM14 (or LLVM15) cannot be done because there are breaking changes.
What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance.

#### Support for FlashAttention-2
- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM Foundry supports FlashAttention-2. Please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention).
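
As a quick environment check, the snippet below mirrors the `is_flash_v2_installed` helper this commit adds to `llmfoundry/models/layers/attention.py`. It is a minimal sketch of that version gate, not an installation procedure.

```python
# Minimal sketch of the flash-attn version gate added in llmfoundry/models/layers/attention.py.
from packaging import version


def is_flash_v2_installed(v2_version: str = '2.0.0') -> bool:
    """Return True if flash-attn is importable and at least `v2_version`."""
    assert version.parse(v2_version) >= version.parse('2.0.0')
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse(v2_version)


# Example: the dail RoPE implementation requires flash-attn v2.0.1 or higher.
if is_flash_v2_installed(v2_version='2.0.1'):
    print('flash-attn v2 detected; the dail RoPE implementation can be used.')
```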

### What kinds of positional embeddings does LLM Foundry support?
Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get [No Positional Embedding](https://arxiv.org/pdf/2203.16634.pdf).

| Name | YAML Config | Training MFU on MPT-7B trained on 8 A100 80GB GPUs | Notes |
|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Learned Positional Embeddings | <pre>model:<br> learned_pos_emb:&nbsp;True</pre>| 65.7 | |
| ALiBi | <pre>model:<br> attn_config:<br> alibi:&nbsp;True</pre>| 64.5 | Requires Triton or Torch attention. |
| RoPE (Dao-AILab Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_impl:&nbsp;dail</pre>| 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch`, `triton`, or `flash`. |
| RoPE (Hugging<code>&nbsp;</code>Face Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_impl:&nbsp;hf</pre>| 62.3 | |
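
The YAML options in the table map directly onto the `attn_config` defaults this commit adds in `llmfoundry/models/layers/blocks.py` (see the diff further down). As a rough illustration only, the same settings expressed as a Python dict could look like the sketch below; the values shown are the shipped defaults rather than a tuned recipe.

```python
# Illustrative attn_config enabling RoPE with the Dao-AILab ('dail') implementation.
# Key names mirror attn_config_defaults in llmfoundry/models/layers/blocks.py;
# switch rope_impl to 'hf' to use the Hugging Face implementation instead.
attn_config = {
    'attn_type': 'multihead_attention',
    'attn_impl': 'triton',   # the attention kernel is independent of the RoPE choice
    'rope': True,            # turn rotary position embeddings on
    'rope_theta': 10000,
    'rope_impl': 'dail',
    'rope_dail_config': {
        'type': 'original',
        'pos_idx_in_fp32': True,
        'xpos_scale_base': 512,
    },
    'rope_hf_config': {
        'type': 'no_scaling',
        'factor': 1.0,
    },
}
```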

### Can I finetune using PEFT / LoRA?
- The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so:
71 changes: 58 additions & 13 deletions llmfoundry/models/layers/attention.py
@@ -5,7 +5,7 @@

import math
import warnings
from typing import Any, List, Optional, Tuple
from typing import Any, Optional

import torch
import torch.nn as nn
@@ -17,12 +17,13 @@
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


def is_flash_v2_installed():
def is_flash_v2_installed(v2_version: str = '2.0.0'):
assert version.parse(v2_version) >= version.parse('2.0.0')
try:
import flash_attn as flash_attn
except:
return False
return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
return version.parse(flash_attn.__version__) >= version.parse(v2_version)


def is_flash_v1_installed():
@@ -33,6 +34,16 @@ def is_flash_v1_installed():
return version.parse(flash_attn.__version__) < version.parse('2.0.0')


# Before importing any transformers models, we need to disable transformers flash attention if
# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
# gated import otherwise.
if is_flash_v1_installed():
import transformers
transformers.utils.is_flash_attn_available = lambda: False

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
original_is_causal: bool) -> bool:
# disable causal when it is not needed
@@ -70,7 +81,7 @@ def scaled_multihead_dot_product_attention(
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
Expand All @@ -79,7 +90,7 @@ def scaled_multihead_dot_product_attention(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:

if multiquery:
@@ -183,7 +194,7 @@ def scaled_multihead_dot_product_attention(


def check_valid_inputs(*tensors: torch.Tensor,
valid_dtypes: Optional[List[torch.dtype]] = None):
valid_dtypes: Optional[list[torch.dtype]] = None):
if valid_dtypes is None:
valid_dtypes = [torch.float16, torch.bfloat16]
for tensor in tensors:
@@ -199,7 +210,7 @@ def flash_attn_fn(
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
@@ -208,7 +219,7 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
@@ -337,7 +348,7 @@ def triton_flash_attn_fn(
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
@@ -346,7 +357,7 @@ def triton_flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
@@ -552,12 +563,13 @@ def __init__(
def forward(
self,
x: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)

@@ -581,6 +593,39 @@ def forward(
query = self.q_ln(query).to(dtype)
key = self.k_ln(key).to(dtype)

if rotary_emb_w_meta_info is not None:
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
seq_len = rotary_emb_w_meta_info['seq_len']
offset_info = rotary_emb_w_meta_info['offset_info']
bsz, seqlen = query.shape[:2]
query = query.view(bsz, seqlen, -1, self.head_dim)
key = key.view(bsz, seqlen, -1, self.head_dim)

if rotary_emb_w_meta_info['impl'] == 'dail':
value = value.view(bsz, seqlen, -1, self.head_dim)

kv = torch.stack([key, value], dim=2)
query, kv = rotary_emb(query,
kv,
seqlen_offset=offset_info,
max_seqlen=seq_len)
[key, value] = torch.unbind(kv, dim=2)

value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
elif rotary_emb_w_meta_info['impl'] == 'hf':
(cos, sin) = rotary_emb(value, seq_len)
# The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
query = query.transpose(1, 2)
key = key.transpose(1, 2)
query, key = apply_rotary_pos_emb(query, key, cos, sin,
offset_info)
# The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
query = query.transpose(1, 2)
key = key.transpose(1, 2)

query = query.view(bsz, seqlen, self.d_model)
key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)

context, attn_weights, past_key_value = self.attn_fn(
query,
key,
@@ -677,7 +722,7 @@ def __init__(
def attn_bias_shape(
attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
prefix_lm: bool, causal: bool,
use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
if attn_impl == 'flash':
return None
elif attn_impl in ['torch', 'triton']:
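
For readers following the new `rotary_emb_w_meta_info` branch above: the `'hf'` path ultimately calls `apply_rotary_pos_emb` from `transformers`. The sketch below shows the standard rotate-half formulation that style of RoPE is based on; it is a conceptual reference, not the library's exact implementation.

```python
# Conceptual sketch of the rotate-half RoPE formulation used by HF-style apply_rotary_pos_emb.
# Shapes are illustrative: q and k are (batch, n_heads, seq_len, head_dim), and cos/sin are
# per-position tables broadcast over the batch and head dimensions.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in two halves and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor,
               sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Rotate queries and keys by position-dependent angles encoded in cos/sin.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```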
43 changes: 30 additions & 13 deletions llmfoundry/models/layers/blocks.py
@@ -12,6 +12,31 @@
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

attn_config_defaults: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'triton',
'qk_ln': False,
'clip_qkv': None,
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'alibi': False,
'alibi_bias_max': 8,
'rope': False,
'rope_theta': 10000,
'rope_impl': 'dail',
'rope_dail_config': {
'type': 'original',
'pos_idx_in_fp32': True,
'xpos_scale_base': 512,
},
'rope_hf_config': {
'type': 'no_scaling',
'factor': 1.0,
},
}


class MPTBlock(nn.Module):

@@ -30,18 +55,7 @@ def __init__(
**kwargs: Any,
):
if attn_config is None:
attn_config = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'triton',
'qk_ln': False,
'clip_qkv': None,
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'alibi': False,
'alibi_bias_max': 8,
}
attn_config = attn_config_defaults

if ffn_config is None:
ffn_config = {
@@ -58,7 +72,8 @@ def __init__(
# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id',
'alibi_bias_max'
'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl',
'rope_dail_config', 'rope_hf_config'
}
attn_config_subset_for_attn_class = {
k: v
@@ -94,6 +109,7 @@ def forward(
x: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
rotary_emb_w_meta_info: Optional[Dict] = None,
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
@@ -104,6 +120,7 @@ def forward(
a,
past_key_value=past_key_value,
attn_bias=attn_bias,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
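
To connect the pieces, the `rotary_emb_w_meta_info` dictionary that `MPTBlock.forward` now forwards to attention is assembled in `modeling_mpt.py`, one of the changed files not expanded in this view. The sketch below only illustrates the expected shape of that dictionary; the rotary module itself is left as a hypothetical placeholder since its construction is not shown here.

```python
# Sketch of the rotary_emb_w_meta_info dict consumed by the attention forward pass above.
# Keys mirror the diff; `rotary_emb` stands in for the module the MPT model builds
# (flash-attn's rotary module for 'dail', or an HF-style module returning (cos, sin) for 'hf').
import torch

seq_len = 128
rotary_emb = ...  # hypothetical placeholder; built once by the model from attn_config

rotary_emb_w_meta_info = {
    'impl': 'hf',              # matches attn_config['rope_impl']: 'hf' or 'dail'
    'rotary_emb': rotary_emb,  # module applied to query/key inside attention
    'seq_len': seq_len,        # past + current sequence length
    'offset_info': torch.arange(seq_len).unsqueeze(0),  # position ids for 'hf'; an int offset for 'dail'
}
```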