final update docs
younesbelkada committed Nov 22, 2023
1 parent 4113c45 commit 0bd1b0c
64 changes: 62 additions & 2 deletions docs/source/en/main_classes/quantization.md
@@ -107,13 +107,73 @@ Check out how to use this integration throughout this [Google Colab demo](https:

You can benefit from fused modules by passing an `AwqConfig` with `do_fuse=True` and your expected maximum sequence length for generation in `fuse_max_seq_len`. For architectures whose fusing mapping is not supported out of the box, you can still fuse the modules, but you need to pass a custom `modules_to_fuse` mapping to `AwqConfig()`. Let's dive into these specific use cases.

Note that you cannot combine module fusing with other optimization techniques such as Flash Attention 2.

#### Fusing modules for supported architectures

Currently, we support out-of-the-box AWQ module fusing for the `llama` and `mistral` architectures.

To enable this feature for supported architectures, simply create an `AwqConfig` and pass the arguments `fuse_max_seq_len` and `do_fuse=True`.

For example, to enable module fusing for the model `TheBloke/Mistral-7B-OpenOrca-AWQ`, run:

```python
import torch
from transformers import AwqConfig, AutoModelForCausalLM

model_id = "TheBloke/Mistral-7B-OpenOrca-AWQ"

quantization_config = AwqConfig(
    bits=4,
    fuse_max_seq_len=512,
    do_fuse=True,
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(0)
```

Note that you need to define `fuse_max_seq_len` in `AwqConfig` beforehand. That total sequence length should include the context length and the expected generation length. You can set it to a larger value to be on the safe side.
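Here is a minimal generation sketch that continues from the snippet above; the prompt text and the `max_new_tokens` value are illustrative, and their combined length should stay within `fuse_max_seq_len`:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)

# The prompt tokens plus the generated tokens should fit within `fuse_max_seq_len` (512 here).
inputs = tokenizer("What is AWQ quantization?", return_tensors="pt").to(0)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```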

You can also apply module fusing to architectures that are not supported out of the box.

#### Fusing modules for unsupported architectures

For architectures that do not support out-of-the-box module fusing, you can pass a custom fusing mapping by passing a dictionary `modules_to_fuse` to `AwqConfig`. Let's take the Yi model as an example:


```python
import torch
from transformers import AwqConfig, AutoModelForCausalLM

model_id = "TheBloke/Yi-34B-AWQ"

quantization_config = AwqConfig(
    bits=4,
    fuse_max_seq_len=512,
    modules_to_fuse={
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "layernorm": ["ln1", "ln2", "norm"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "use_alibi": False,
        "num_attention_heads": 56,
        "num_key_value_heads": 8,
        "hidden_size": 7168,
    },
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(0)
```

The `modules_to_fuse` argument needs to have the following fields:

- `"attention"`: The names of the attention layers to fuse - in the order: query, key, value and output projection layer. In case you don't want to fuse the attention layers you can pass an empty list.
- `"layernorm"`: The names of all the layernorm layers you want to replace with a custom fused layer norm. In case you don't want to fuse these layers you can also pass an empty list.
- `"mlp"`: The names of the MLP layers you want to fuse into a single MLP layer in the order: (gate (dense layer post-attention) / up / down layers).
- `"use_alibi"`: If you model uses alibi positional embedding
- `"num_attention_heads"`: The number of attention heads
- `"num_key_value_heads"`: This is the number of key value heads that should be used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, the model will use Multi Head Attention (MHA), if num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used.
- `"hidden_size"`: Dimension of the hidden representations.


#### Benchmarks

