Merge pull request #2 from whyNLP/dev-cla
Support CLA
why-in-Shanghaitech authored May 28, 2024
2 parents a0ce1f4 + 96a9895 commit 0ac6b7f
Showing 8 changed files with 864 additions and 5 deletions.
84 changes: 80 additions & 4 deletions README.md
@@ -22,6 +22,10 @@ This work is inspired by [Probabilistic Transformer](https://github.com/whyNLP/P
</div>
</details>

## News
- [24/05/28] This code base now also supports Cross-Layer Attention (CLA). The idea is similar to LCKV, but CLA 1) divides the transformer layers into small groups of 2-4 layers each; and 2) pairs the queries of all layers in a group with the keys and values of the group's bottom layer. See details in the paper "[Reducing Transformer Key-Value Cache Size with Cross-Layer Attention](http://arxiv.org/abs/2405.12981)".
- [24/05/20] LCKV initial code release.

## Installation

You may install the dependencies with the following commands:
@@ -36,10 +40,10 @@ where the CUDA version is set to `12.1`. For other CUDA versions, please refer t

## Usage

Our implementation is based on HuggingFace `transformers` where we register a new model `opt-llama` that supports the Layer-Condensed KV Cache.
Our implementation is based on HuggingFace `transformers`. We register a new model `opt-llama` that supports the Layer-Condensed KV Cache, and a new model `cla-llama` that supports CLA. Both are variants of the `llama` transformer model.

```python
import models # register the opt-llama model
import models # register the opt-llama and cla-llama models
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -73,12 +77,70 @@ We've done this for you in the provided training scripts. You may also refer to

We provide some sample configuration files in the `configs` folder. The config settings are defined in [models/configuration_llama.py](models/configuration_llama.py). You may refer to this file for more details.

#### Layer-Condensed KV Cache (LCKV)

Option 1: Modify the configurations in python:

```python
from models import OptLlamaConfig

# we have prepared a sample configuration file
config = OptLlamaConfig.from_pretrained("configs/tinyllama_opt.json")

# you may modify the configuration as you like
config.num_trained_encoders = 1 # see figure below, b-1 in the paper
config.num_encoders = 8 # see figure below, m+b-1 in the paper
config.layer_types = "0_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_2_0" # 0: std tsfm layer; 1: layers use top layer KVs; 2: layers generate the key-value pair for other layers to use
config.target_layer = -2 # the layer to generate the key-value pair for other layers to use
config.train_kv = False # whether to add an MSE loss on the key-value pair, see the paper appendix

# we also support configurations like the following
config.layer_types = "0_0_0_0_0_0_0_0_0_0_2_1_1_1_1_1_1_1_1_1_1_1" # YOCO config.
config.layer_types = "0_0_0_0_0_0_0_0_0_0_1_1_1_1_1_1_2_1_1_1_1_1" # 2 does not necessarily have to be the last layer
```
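
As a quick sanity check, you can build a randomly initialized model directly from the modified config. This is only a minimal sketch; it assumes the snippet runs from the repository root so that `import models` registers the custom classes with the Auto classes:

```python
import models  # registers opt-llama (and cla-llama) with the Auto classes
from models import OptLlamaConfig
from transformers import AutoModelForCausalLM

config = OptLlamaConfig.from_pretrained("configs/tinyllama_opt.json")
config.layer_types = "0_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_1_2_0"
config.target_layer = -2

# build a randomly initialized (untrained) model from the modified config
model = AutoModelForCausalLM.from_config(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```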

Option 2: Modify the configurations in the shell script (via `--config_overrides`):

```sh
accelerate launch run_clm.py \
--config_name configs/tinyllama_opt.json \
--config_overrides model_type=opt-llama,num_encoders=8,num_trained_encoders=1,layer_types=0_1_1_1_1_1_2_0,target_layer=-2,train_kv=false \
...
```

Notice that some of the settings have different names and meanings compared to those in our paper. The following figure explains the correspondence:

<div align="center">
<img width="500" src="https://github.com/whyNLP/LCKV/assets/43395692/74671862-146f-492c-8d17-d0e6a7697170" />
</div>

#### Cross-Layer Attention (CLA)

Option 1: Modify the configurations in python:

```python
from models import ClaLlamaConfig

# we have prepared a sample configuration file
config = ClaLlamaConfig.from_pretrained("configs/tinyllama_cla.json")

# you may modify the configuration as you like
config.layer_types = "2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1" # CLA-2, similar to LCKV, "1" uses the KVs from the nearest previous layer
config.layer_types = "0_2_1_1_2_1_1_2_1_1_2_1_1_2_1_1_2_1_1_2_1_1" # CLA-3, also supports "0"
```
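
If you prefer to generate these strings programmatically, the following helper is a small illustrative sketch (it is not part of the code base):

```python
def cla_layer_types(num_layers: int, group_size: int, leading_std: int = 0) -> str:
    """Build a CLA layer_types string: each group starts with a type-2 layer that
    computes the KVs, followed by group_size - 1 type-1 layers that reuse them.
    `leading_std` optional standard (type-0) layers are placed at the bottom."""
    types = ["0"] * leading_std
    for i in range(num_layers - leading_std):
        types.append("2" if i % group_size == 0 else "1")
    return "_".join(types)

print(cla_layer_types(22, 2))     # reproduces the CLA-2 string above
print(cla_layer_types(22, 3, 1))  # reproduces the CLA-3 string above
```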

Option 2: Modify the configurations in the shell script (via `--config_overrides`):

```sh
accelerate launch run_clm.py \
--config_name configs/tinyllama_cla.json \
--config_overrides layer_types=2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1 \
...
```

> [!WARNING]
> The authors of CLA tuned both the model architecture and the training hyperparameters for their CLA models, so the provided configuration files are not optimal for CLA. You may need to adjust hyperparameters such as `intermediate_size`, `num_key_value_heads`, etc.
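
If you do tune them, the overrides can be set the same way as in the examples above. A minimal sketch (the values below are placeholders, not recommended settings):

```python
from models import ClaLlamaConfig

config = ClaLlamaConfig.from_pretrained("configs/tinyllama_cla.json")
config.intermediate_size = 8192   # placeholder value, not a tuned setting
config.num_key_value_heads = 8    # placeholder value, not a tuned setting
```
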
### Training

We use the same [training script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py) as the original `transformers` library. You may refer to the [official documentation](https://huggingface.co/transformers/training.html) for more details.
@@ -89,7 +151,7 @@ We provide a training script `run_clm.sh` for training a 50M parameter model on
bash run_clm.sh
```

See the script for more details. For pretraining on SlimPajama, please follow the instructions in [tinyllama-zh](https://github.com/whyNLP/tinyllama-zh) and replace the dataset with SlimPajama.
See the script for more details. For CLA, we also provide a sample training script `run_cla.sh`. For pretraining on SlimPajama, please follow the instructions in [tinyllama-zh](https://github.com/whyNLP/tinyllama-zh) and replace the dataset with SlimPajama.

### Inference

@@ -99,7 +161,7 @@ We use the same [inference script](https://github.com/huggingface/transformers/b
bash run_generation.sh
```

See the script for more details.
You may get responses from the trained model given any prompt. See the script for more details.

### Streaming

@@ -168,3 +230,17 @@ pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/
pip install -r requirements.txt
```

### The performance is incredibly poor

Some users have reported that the model's performance is incredibly poor and the loss does not decrease when using `torch_dtype=bfloat16` (required by flash attention). This seems to be a precision problem. Although I have not been able to reproduce the issue, a potential workaround is to use a larger learning rate. To confirm whether precision is indeed the cause, you can disable flash attention and use float32 instead; if the loss then decreases as expected, the issue is likely precision-related.

### The code always raises exceptions

Since we started the project quite early, this code base uses an old version of `transformers` (v4.35.2). Newer versions may not be compatible with the code, though some minor changes would likely fix the issue.


## Questions

> 1. Is it possible to integrate the LCKV with MQA / GQA?
Yes. In fact, we have already done this in our experiments: TinyLlama uses 32 attention heads and 4 KV heads, and we follow the same setting. If you want to experiment with different settings, you may modify `num_attention_heads` and `num_key_value_heads` in the configuration file.
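
For example, a minimal sketch of the GQA setting described above (using the sample LCKV config shipped with the repo):

```python
from models import OptLlamaConfig

config = OptLlamaConfig.from_pretrained("configs/tinyllama_opt.json")
config.num_attention_heads = 32   # query heads, as in TinyLlama
config.num_key_value_heads = 4    # grouped KV heads (GQA); setting this to 1 gives MQA
```
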
25 changes: 25 additions & 0 deletions configs/llama_tiny_cla.json
@@ -0,0 +1,25 @@
{
"_name_or_path": "meta-llama/Llama-2-7b-hf",
"architectures": [
"LlamaForCausalLM"
],
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 512,
"initializer_range": 0.02,
"intermediate_size": 1024,
"max_position_embeddings": 1024,
"model_type": "cla-llama",
"num_attention_heads": 8,
"num_hidden_layers": 8,
"num_key_value_heads": 4,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"tie_word_embeddings": false,
"torch_dtype": "float32",
"transformers_version": "4.31.0.dev0",
"use_cache": true,
"layer_types": "2_1_2_1_2_1_2_1"
}
26 changes: 26 additions & 0 deletions configs/tinyllama_cla.json
@@ -0,0 +1,26 @@
{
"_name_or_path": "meta-llama/Llama-2-7b-hf",
"architectures": [
"LlamaForCausalLM"
],
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 5632,
"max_position_embeddings": 2048,
"model_type": "cla-llama",
"num_attention_heads": 32,
"num_hidden_layers": 22,
"num_key_value_heads": 4,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"tie_word_embeddings": false,
"torch_dtype": "float32",
"transformers_version": "4.31.0.dev0",
"use_cache": true,
"vocab_size": 32000,
"layer_types": "2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1_2_1"
}
15 changes: 15 additions & 0 deletions models/__init__.py
@@ -2,9 +2,14 @@
from .modeling_llama_opt import LlamaForCausalLM as OptLlamaForCausalLM
from .wandb_callback import WandbCallback

from .modeling_llama_cla import LlamaForCausalLM as ClaLlamaForCausalLM
from .configuration_llama import ClaLlamaConfig

from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("opt-llama", OptLlamaConfig)
AutoModelForCausalLM.register(OptLlamaConfig, OptLlamaForCausalLM)
AutoConfig.register("cla-llama", ClaLlamaConfig)
AutoModelForCausalLM.register(ClaLlamaConfig, ClaLlamaForCausalLM)

import os

@@ -14,13 +19,17 @@
transformers.models.llama.modeling_llama.LlamaRMSNorm = RMSNorm
from . import modeling_llama_opt
modeling_llama_opt.LlamaRMSNorm = RMSNorm
from . import modeling_llama_cla
modeling_llama_cla.LlamaRMSNorm = RMSNorm

if os.environ.get('LCKV_FUSED_CROSSENTROPY', False):
import transformers
from flash_attn.losses.cross_entropy import CrossEntropyLoss
transformers.models.llama.modeling_llama.CrossEntropyLoss = CrossEntropyLoss
from . import modeling_llama_opt
modeling_llama_opt.CrossEntropyLoss = CrossEntropyLoss
from . import modeling_llama_cla
modeling_llama_cla.CrossEntropyLoss = CrossEntropyLoss

if os.environ.get('LCKV_FUSED_ROTARY', False):
import transformers
@@ -43,12 +52,18 @@
from . import modeling_llama_opt_streaming
modeling_llama_opt_streaming.apply_rotary_pos_emb_q = fused_apply_rotary_pos_emb_q

from . import modeling_llama_cla
modeling_llama_cla.apply_rotary_pos_emb = fused_apply_rotary_pos_emb
modeling_llama_cla.apply_rotary_pos_emb_q = fused_apply_rotary_pos_emb_q

if os.environ.get('LCKV_FUSED_SWIGLU', False):
import transformers
from .llama_fused_swiglu import LlamaMLP
transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP
from . import modeling_llama_opt
modeling_llama_opt.LlamaMLP = LlamaMLP
from . import modeling_llama_cla
modeling_llama_cla.LlamaMLP = LlamaMLP

try:
from streaming_llm import enable_streaming_llm
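
As a usage sketch (assuming `flash-attn` and its fused ops are installed): the patches above run at import time, so the switches must be set before the first `import models`.

```python
import os

# enable the optional fused kernels before importing the models package,
# since models/__init__.py applies the patches at import time
os.environ["LCKV_FUSED_CROSSENTROPY"] = "1"  # fused cross-entropy from flash-attn
os.environ["LCKV_FUSED_ROTARY"] = "1"        # fused rotary position embedding
os.environ["LCKV_FUSED_SWIGLU"] = "1"        # fused SwiGLU MLP

import models  # noqa: F401  # patches apply to both opt-llama and cla-llama
```
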
44 changes: 44 additions & 0 deletions models/configuration_llama.py
@@ -106,3 +106,47 @@ def __init__(
raise ValueError("The target layer should be the layer of type 2.")
if layer_types.count(2) > 1:
raise ValueError("Only one layer can be type 2.")

class ClaLlamaConfig(_LlamaConfig):
model_type = "cla-llama"

def __init__(
self,
layer_types: str = None,
**kwargs,
):
"""
This is an implementation of the Cross-Layer Attention (CLA) model.
Args:
layer_types (`str`, *optional*, defaults to ""):
The type of each layer. The value should be an underscore-separated string
of integers. The value "0" means the layer will use the key-value pair in
the original layers as the kv cache. The value "1" means the layer will
use the key-value pair in the nearest lower layer as the kv cache.
The value "2" is the same as "0", but to be consistent with LCKV we name
the bottom layer of each group as "2". The default value is all "0".
Example:
- "2_1_2_1_2_1_2_1_2_1" is a CLA-2 model with 10 layers.
- "0_2_1_1_2_1_1_2_1_1" is a CLA-3 model with 10 layers.
See more info in Figure 2 of the paper "Reducing Transformer Key-Value Cache
Size with Cross-Layer Attention", http://arxiv.org/abs/2405.12981
"""
super().__init__(**kwargs)
self.layer_types = layer_types

if self.layer_types is None:
self.layer_types = "_".join(["0"]*self.num_hidden_layers)

# post check
num_hidden_layers = self.num_hidden_layers
layer_types = [int(x) for x in self.layer_types.split("_")]
if len(layer_types) != num_hidden_layers:
raise ValueError("The number of layer types should be equal to the number of hidden layers.")
for i in range(num_hidden_layers):
if layer_types[i] not in (0, 1, 2):
raise ValueError("The layer type should be one of 0, 1 and 2.")
if layer_types[0] == 1:
raise ValueError("The first layer should be type 0 or 2. It must calculates the KV.")
