Generate: Add new decoding strategy "DoLa" in .generate() (#29619)
Co-authored-by: Joao Gante <[email protected]>
voidism and gante authored Jul 9, 2024
1 parent 99c0e55 commit d094d8d
Showing 7 changed files with 530 additions and 5 deletions.
64 changes: 60 additions & 4 deletions docs/source/en/generation_strategies.md
@@ -178,7 +178,7 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te

The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.

KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750)
and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper.
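
For reference, a minimal sketch of enabling the quantized KV cache through `generate()` is shown below. The `cache_implementation` and `cache_config` arguments follow the description in this section, and the checkpoint and prompt are only examples; treat the exact argument names as assumptions to verify against the API reference.

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
>>> model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16)
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> model.to(device)
>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(device)

>>> # Quantize the KV cache to 4 bits with the `quanto` backend (assumed to be installed)
>>> out = model.generate(
...     **inputs, do_sample=False, max_new_tokens=20,
...     cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"},
... )
>>> text = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
```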
@@ -213,11 +213,11 @@ I like rock music because it's loud and energetic. I like to listen to it when I

## Watermarking

The `generate()` method supports watermarking the generated text by randomly marking a portion of tokens as "green".
When generating, the "green" tokens will have a small 'bias' value added to their logits, thus having a higher chance of being generated.
The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is
statistically to obtain that amount of "green" tokens in human-generated text. This watermarking strategy was proposed in the paper
["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on
the inner functioning of watermarking, it is recommended to refer to the paper.

Watermarking can be used with any generative model in `transformers` and does not require an extra classification model
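
As a rough illustration, here is a hedged sketch of generating and then detecting watermarked text. It assumes the `WatermarkingConfig` and `WatermarkDetector` utilities and a small example checkpoint; the exact constructor signatures and parameter names should be checked against the API reference.

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkingConfig, WatermarkDetector

>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer("Alice and Bob are", return_tensors="pt")

>>> # Bias the "green" tokens while generating
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
>>> out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_new_tokens=20)

>>> # Detect the watermark from the proportion of "green" tokens
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
>>> detection = detector(out, return_dict=True)
>>> predictions = detection.prediction  # one boolean per sequence
```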
@@ -484,3 +484,59 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t

Alternatively, you can also set `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model-based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
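
A minimal sketch of prompt-lookup (n-gram based) assisted decoding follows; the checkpoint, prompt, and the value of `prompt_lookup_num_tokens` are only illustrative choices.

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer("The keys to a good summary are brevity and accuracy. The keys to", return_tensors="pt")

>>> # Candidate continuations are looked up as n-grams in the prompt itself, so no assistant model is needed
>>> out = model.generate(**inputs, prompt_lookup_num_tokens=3, do_sample=False, max_new_tokens=20)
>>> text = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
```
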
### DoLa Decoding

**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy for improving factuality and reducing the
hallucinations of LLMs, as described in the ICLR 2024 paper [DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models](https://arxiv.org/abs/2309.03883).

DoLa works by contrasting the logits obtained from the final layer against those from earlier layers, thus amplifying the factual
knowledge localized in particular parts of the transformer layers.
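
Conceptually, the contrast boils down to a difference of log-probabilities between the final ("mature") layer and a selected earlier ("premature") layer, both read out through the LM head. The sketch below only illustrates that idea; it omits the paper's adaptive plausibility constraint and dynamic layer selection and is not the library's implementation.

```python
>>> import torch

>>> # Toy logits standing in for the LM-head outputs of two layers (vocabulary size 32000)
>>> mature_logits = torch.randn(1, 32000)     # final layer
>>> premature_logits = torch.randn(1, 32000)  # selected earlier layer

>>> # DoLa-style contrast: difference of log-probabilities between the two layers
>>> contrast = torch.log_softmax(mature_logits, dim=-1) - torch.log_softmax(premature_logits, dim=-1)
>>> next_token = contrast.argmax(dim=-1)  # greedy pick on the contrasted scores
```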

To activate DoLa decoding, follow these two steps when calling the `model.generate` function:
1. Set the `dola_layers` argument, which can be either a string or a list of integers.
    - If set to a string, it can be one of `low` or `high`.
    - If set to a list of integers, it should be a list of layer indices between 0 and the total number of layers in the model. The 0-th layer is the word embedding, the 1st layer is the first transformer layer, and so on.
2. Set `repetition_penalty=1.2`; this is suggested to reduce repetition in DoLa decoding.

See the following examples for DoLa decoding with the 32-layer LLaMA-7B model.

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
>>> model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16)
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> model.to(device)
>>> set_seed(42)

>>> text = "On what date was the Declaration of Independence officially signed?"
>>> inputs = tokenizer(text, return_tensors="pt").to(device)

# Vanilla greedy decoding
>>> vanilla_output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
>>> tokenizer.batch_decode(vanilla_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nThe Declaration of Independence was signed on July 4, 1776.\nWhat was the date of the signing of the Declaration of Independence?\nThe Declaration of Independence was signed on July 4,']

# DoLa decoding with contrasting higher part of layers (layers 16,18,...,30)
>>> dola_high_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers='high')
>>> tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nJuly 4, 1776, when the Continental Congress voted to separate from Great Britain. The 56 delegates to the Continental Congress signed the Declaration on August 2, 1776.']

# DoLa decoding with contrasting specific layers (layers 28 and 30)
>>> dola_custom_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers=[28,30], repetition_penalty=1.2)
>>> tokenizer.batch_decode(dola_custom_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nIt was officially signed on 2 August 1776, when 56 members of the Second Continental Congress, representing the original 13 American colonies, voted unanimously for the resolution for independence. The 2']
```

#### Understanding the `dola_layers` argument

`dola_layers` stands for the candidate layers in premature layer selection, as described in the DoLa paper. The selected premature layer will be contrasted with the final layer.

Setting `dola_layers` to `'low'` or `'high'` will select the lower or higher part of the layers to contrast, respectively (see the short sketch after this list).
- For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` are used for `'low'` and `'high'`, respectively.
- For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for `'low'` and `'high'`, respectively.
- If the model has tied word embeddings, the word embedding (0-th) layer is skipped and the candidates start from the 2nd layer, as an early exit from the word embeddings would amount to an identity function.
- Set `dola_layers` to a list of integers to contrast manually specified layers. For example, setting `dola_layers=[28,30]` will contrast the final layer (the 32nd layer) with the 28th and 30th layers.
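
For intuition, here is a small sketch of how those candidate sets work out; it simply mirrors the ranges listed above (ignoring the tied-embedding adjustment) and is not the library's code:

```python
>>> def dola_candidate_layers(n_layers, mode):
...     """Illustrative only: candidate premature layers for dola_layers='low'/'high'."""
...     if n_layers <= 40:
...         return list(range(0, n_layers // 2, 2)) if mode == "low" else list(range(n_layers // 2, n_layers, 2))
...     return list(range(0, 20, 2)) if mode == "low" else list(range(n_layers - 20, n_layers, 2))

>>> dola_candidate_layers(32, "high")  # the "layers 16,18,...,30" used in the LLaMA-7B example above
[16, 18, 20, 22, 24, 26, 28, 30]
```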

The paper suggests contrasting `'high'` layers to improve short-answer tasks like TruthfulQA, and contrasting `'low'` layers to improve all the other long-answer reasoning tasks, such as GSM8K, StrategyQA, FACTOR, and VicunaQA. Applying DoLa to smaller models like GPT-2 is not recommended, as shown in Appendix N of the paper.
38 changes: 38 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -60,6 +60,7 @@ class GenerationMode(ExplicitEnum):
GREEDY_SEARCH = "greedy_search"
SAMPLE = "sample"
ASSISTED_GENERATION = "assisted_generation"
DOLA_GENERATION = "dola_generation"
# Beam methods
BEAM_SEARCH = "beam_search"
BEAM_SAMPLE = "beam_sample"
@@ -81,6 +82,7 @@ class GenerationConfig(PushToHubMixin):
- *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
- *dola decoding* if `dola_layers` is passed to `.generate()`
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
@@ -305,6 +307,18 @@ class GenerationConfig(PushToHubMixin):
max_matching_ngram_size (`int`, *optional*, defaults to `None`):
The maximum ngram size to be considered for matching in the prompt. Defaults to 2 if not provided.
> Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)
dola_layers (`str` or `List[int]`, *optional*):
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
layers up to the last 20 layers.
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
> Parameters specific to the caching mechanism:
cache_implementation (`str`, *optional*, defaults to `None`):
@@ -397,6 +411,9 @@ def __init__(self, **kwargs):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

# DoLa generation
self.dola_layers = kwargs.pop("dola_layers", None)

# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
@@ -495,6 +512,16 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)

# DoLa generation may extend some generation modes
if self.dola_layers is not None:
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.DOLA_GENERATION
else:
raise ValueError(
"You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
"is only supported with Greedy Search and Sample."
)
return generation_mode

def validate(self, is_init=False):
@@ -700,6 +727,17 @@ def validate(self, is_init=False):
"`generate()` (or a pipeline) directly."
)

# 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
dola_decoding_wrong_parameter_msg = (
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
"which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
)
warnings.warn(
dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
UserWarning,
)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],