diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md
index b000cc06779918..68430de643f17b 100644
--- a/docs/source/en/generation_strategies.md
+++ b/docs/source/en/generation_strategies.md
@@ -178,7 +178,7 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te
 
 The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However the key and value cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
-Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed. 
+Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.
 
 KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache]
 (https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper.
@@ -213,11 +213,11 @@ I like rock music because it's loud and energetic. I like to listen to it when I
 
 ## Watermarking
 
-The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green". 
+The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green".
 When generating the "green" will have a small 'bias' value added to their logits, thus having a higher chance to be generated.
 The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is
-statistically to obtain that amount of "green" tokens for human-generated text. This watermarking strategy was proposed in the paper 
-["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on 
+statistically to obtain that amount of "green" tokens for human-generated text. This watermarking strategy was proposed in the paper
+["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on
 the inner functioning of watermarking, it is recommended to refer to the paper.
 
 The watermarking can be used with any generative model in `tranformers` and does not require an extra classification model
@@ -484,3 +484,59 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
 Alternativelly, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
 to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
 
+### DoLa Decoding
+
+**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy for improving the factuality and reducing the
+hallucinations of LLMs, as described in the ICLR 2024 paper [DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models](https://arxiv.org/abs/2309.03883).
+
+DoLa works by contrasting the logits obtained from the final
+layer with those from earlier layers, thus amplifying the factual knowledge localized to particular parts of the transformer layers.
+
+Do the following two steps to activate DoLa decoding when calling the `model.generate` function:
+1. Set the `dola_layers` argument, which can be either a string or a list of integers.
+    - If set to a string, it can be one of `low` or `high`.
+    - If set to a list of integers, it should be a list of layer indices between 0 and the total number of layers in the model. The 0-th layer is the word embedding, the 1st layer is the first transformer layer, and so on.
+2. Set `repetition_penalty = 1.2`, which is suggested to reduce repetition in DoLa decoding.
+
+See the following examples for DoLa decoding with the 32-layer LLaMA-7B model.
+
+```python
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+>>> import torch
+
+>>> tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+>>> model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16)
+>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
+>>> model.to(device)
+>>> set_seed(42)
+
+>>> text = "On what date was the Declaration of Independence officially signed?"
+>>> inputs = tokenizer(text, return_tensors="pt").to(device)
+
+# Vanilla greedy decoding
+>>> vanilla_output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
+>>> tokenizer.batch_decode(vanilla_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
+['\nThe Declaration of Independence was signed on July 4, 1776.\nWhat was the date of the signing of the Declaration of Independence?\nThe Declaration of Independence was signed on July 4,']
+
+# DoLa decoding, contrasting the higher part of the layers (layers 16, 18, ..., 30)
+>>> dola_high_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers='high')
+>>> tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
+['\nJuly 4, 1776, when the Continental Congress voted to separate from Great Britain. The 56 delegates to the Continental Congress signed the Declaration on August 2, 1776.']
+
+# DoLa decoding, contrasting specific layers (layers 28 and 30)
+>>> dola_custom_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers=[28,30], repetition_penalty=1.2)
+>>> tokenizer.batch_decode(dola_custom_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
+['\nIt was officially signed on 2 August 1776, when 56 members of the Second Continental Congress, representing the original 13 American colonies, voted unanimously for the resolution for independence. The 2']
+```
+
+#### Understanding the `dola_layers` argument
+
+`dola_layers` stands for the candidate layers in premature layer selection, as described in the DoLa paper. The selected premature layer will be contrasted with the final layer.
+
+Setting `dola_layers` to `'low'` or `'high'` will select the lower or higher part of the layers to contrast, respectively.
+- For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` are used for `'low'` and `'high'` layers, respectively (see the sketch after this list).
+- For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for `'low'` and `'high'` layers, respectively.
+- If the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, as the early exit from the word embeddings would be an identity function.
+- Set `dola_layers` to a list of integers to contrast manually specified layers. For example, setting `dola_layers=[28,30]` will contrast the final layer (the 32nd layer) with the 28th and 30th layers.
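+
+To make the mapping above concrete, the short sketch below reproduces the `'low'`/`'high'` candidate-layer ranges just described for two example depths. It is an illustration only: `candidate_layers_for` is a made-up helper, not a `transformers` API, and it ignores the tied-word-embeddings adjustment mentioned in the list above.
+
+```python
+# Illustration only -- mirrors the candidate-layer ranges described above; not part of the library.
+def candidate_layers_for(num_layers: int, dola_layers: str) -> list:
+    if num_layers <= 40:
+        low = list(range(0, num_layers // 2, 2))
+        high = list(range(num_layers // 2, num_layers, 2))
+    else:  # very deep models only consider the first / last 20 layers
+        low = list(range(0, 20, 2))
+        high = list(range(num_layers - 20, num_layers, 2))
+    return low if dola_layers == "low" else high
+
+# 32-layer LLaMA-7B: 'high' contrasts the final layer with layers 16, 18, ..., 30
+print(candidate_layers_for(32, "high"))  # [16, 18, 20, 22, 24, 26, 28, 30]
+# 80-layer model (N > 40): 'low' only considers every other layer among the first 20
+print(candidate_layers_for(80, "low"))  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+```
+
+In the actual implementation added to `utils.py` in this diff, the start of the `'low'` range additionally shifts to layer 2 when the word embeddings are tied.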
+
+The paper suggests contrasting `'high'` layers for short-answer tasks like TruthfulQA, and contrasting `'low'` layers for all the other long-answer reasoning tasks, such as GSM8K, StrategyQA, FACTOR, and VicunaQA. Applying DoLa to smaller models like GPT-2 is not recommended, as shown in Appendix N of the paper.
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 8ba17a6a350f78..dcdccad23a54c1 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -60,6 +60,7 @@ class GenerationMode(ExplicitEnum):
     GREEDY_SEARCH = "greedy_search"
     SAMPLE = "sample"
     ASSISTED_GENERATION = "assisted_generation"
+    DOLA_GENERATION = "dola_generation"
     # Beam methods
     BEAM_SEARCH = "beam_search"
     BEAM_SAMPLE = "beam_sample"
@@ -81,6 +82,7 @@ class GenerationConfig(PushToHubMixin):
         - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
         - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
         - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
+        - *dola decoding* if `dola_layers` is passed to `.generate()`
 
     To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
 
@@ -305,6 +307,18 @@ class GenerationConfig(PushToHubMixin):
         max_matching_ngram_size (`int`, *optional*, default to `None`):
            The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
 
+        > Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)
+
+        dola_layers (`str` or `List[int]`, *optional*):
+            The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
+            be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
+            "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
+            layers up to the last 20 layers.
+            If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
+            The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
+            `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
+            or [the paper](https://arxiv.org/abs/2309.03883) for more details.
+
         > Parameters specific to the caching mechanism:
 
         cache_implementation (`str`, *optional*, default to `None`):
@@ -397,6 +411,9 @@ def __init__(self, **kwargs):
         self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
         self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
 
+        # DoLa generation
+        self.dola_layers = kwargs.pop("dola_layers", None)
+
         # Cache implementation
         self.cache_implementation = kwargs.pop("cache_implementation", None)
         self.cache_config = kwargs.pop("cache_config", None)
@@ -495,6 +512,16 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non
                     "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
                     "is only supported with Greedy Search and Sample."
) + + # DoLa generation may extend some generation modes + if self.dola_layers is not None: + if generation_mode in ("greedy_search", "sample"): + generation_mode = GenerationMode.DOLA_GENERATION + else: + raise ValueError( + "You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate " + "is only supported with Greedy Search and Sample." + ) return generation_mode def validate(self, is_init=False): @@ -700,6 +727,17 @@ def validate(self, is_init=False): "`generate()` (or a pipeline) directly." ) + # 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2 + if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2): + dola_decoding_wrong_parameter_msg = ( + "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, " + "which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`." + ) + warnings.warn( + dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty), + UserWarning, + ) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 371c59460309b3..6b4a055fba8de3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -20,9 +20,11 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import torch import torch.distributed as dist from torch import nn +from torch.nn import functional as F from ..cache_utils import ( Cache, @@ -1908,6 +1910,28 @@ def generate( streamer=streamer, **model_kwargs, ) + elif generation_mode == GenerationMode.DOLA_GENERATION: + if self._is_stateful: + # DoLa decoding was not designed for stateful models, and would require some changes + raise ValueError( + f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}" + ) + prepared_logits_warper = ( + self._get_logits_warper(generation_config, device=input_ids.device) + if generation_config.do_sample + else None + ) + result = self._dola_decoding( + input_ids, + dola_layers=generation_config.dola_layers, + logits_processor=prepared_logits_processor, + logits_warper=prepared_logits_warper, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: if not model_kwargs["use_cache"]: @@ -2189,6 +2213,225 @@ def contrastive_search(self, *args, **kwargs): ) return self._contrastive_search(*args, **kwargs) + def _dola_decoding( + self, + input_ids: torch.LongTensor, + dola_layers: Union[str, List[int]], + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"] = None, + logits_warper: Optional[LogitsProcessorList] = None, + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **dola decoding** and can be + used for decoder-only text models. + The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language + Models" (https://arxiv.org/abs/2309.03883) in ICLR 2024. 
+ + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + dola_layers (`Union[str, List[int]]`): + The candidate layers used in contrasting layers of DoLa. It can be either 1) 'low' or 'high', which + means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices + to be used for candidate layers. The 0-th layer is the word embedding layer of the model. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] + or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + + if self.config.is_encoder_decoder: + raise ValueError("DoLa decoding is only available for decoder-only models.") + # init values + + pad_token_id = generation_config.pad_token_id + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): + raise ValueError( + "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " + f"{logits_warper})." 
+ ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + this_peer_finished = False + + # prepare layers for DoLa decoding + final_layer = self.config.num_hidden_layers + # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, + # as the early exit from word embeddings will become identity function + # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th + # layer otherwise. Notice that DoLa does not help shallow models much. + if not self.config.tie_word_embeddings: + start_layer = 0 + elif final_layer > 2: + start_layer = 2 + elif final_layer == 2: + start_layer = 1 + else: + start_layer = 0 + + # For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` + # are used for `'low'` and `'high'` layers, respectively. + # For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for + # `'low'` and `'high'` layers, respectively. + if isinstance(dola_layers, str) and dola_layers == "low": + if start_layer == final_layer // 2: + candidate_premature_layers = [start_layer] + else: + candidate_premature_layers = ( + list(range(start_layer, final_layer // 2, 2)) + if final_layer <= 40 + else list(range(start_layer, 20, 2)) + ) + elif isinstance(dola_layers, str) and dola_layers == "high": + candidate_premature_layers = ( + list(range(final_layer // 2, final_layer, 2)) + if final_layer <= 40 + else list(range(final_layer - 20, final_layer, 2)) + ) + # Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers. 
+ elif isinstance(dola_layers, list): + candidate_premature_layers = [i for i in dola_layers if i < final_layer] + else: + raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.") + + lm_head = self.get_output_embeddings() + if lm_head is None: + raise ValueError("DoLa is not supported for models that don't have output embeddings.") + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=True, + ) + + final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone() + final_logits = outputs.logits[:, -1, :] + candidate_premature_logits = {} + for candidate_premature_layer in candidate_premature_layers: + candidate_premature_logits[candidate_premature_layer] = lm_head( + outputs.hidden_states[candidate_premature_layer][:, -1, :] + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = _dola_select_contrast( + candidate_premature_layers, candidate_premature_logits, final_logits + ) + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + if do_sample: # sample + next_token_scores = logits_warper(input_ids, next_token_scores) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (final_layer_next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + if do_sample: # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: # argmax + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + + # stop when each sentence is finished + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + @torch.no_grad() def _contrastive_search( self, @@ -4197,3 +4440,75 @@ def _concat(data): # Return a new object of the inferred class with the concatenated 
attributes return model_output_cls(**concatenated_data) + + +def _relative_top_filter( + scores: torch.FloatTensor, + baseline_scores: torch.FloatTensor, + relative_top: float = 0.1, + filter_value: float = -float("Inf"), + base_filter_value=-1e-3, + min_tokens_to_keep: int = 1, +) -> torch.FloatTensor: + """ + Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235 + Apply filtering to only keep tokens with a probability above a certain threshold. The threshold is defined as `relative_top` * max probability in the distribution. + """ + scores_normalized = scores.log_softmax(dim=-1) + baseline_scores_normalized = baseline_scores.log_softmax(dim=-1) + sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True) + min_thresh = sorted_logits[..., min_tokens_to_keep - 1] + probs_max = torch.max(scores_normalized, dim=-1).values + probs_thresh = probs_max + np.log(relative_top) + probs_thresh = torch.min(min_thresh, probs_thresh) + probs_thresh = probs_thresh.unsqueeze(-1) + baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value + scores_normalized[scores_normalized < probs_thresh] = filter_value + return scores_normalized, baseline_scores_normalized + + +def _dola_select_contrast( + candidate_premature_layers: List[int], + candidate_premature_logits: Dict[int, torch.FloatTensor], + final_logits: torch.FloatTensor, +) -> torch.FloatTensor: + if len(candidate_premature_layers) == 1: + base_logits = candidate_premature_logits[candidate_premature_layers[0]] + final_logits, base_logits = _relative_top_filter(final_logits, base_logits) + logits = final_logits - base_logits + return logits + + # 1. Stacking all premature_layers into a new dimension + stacked_premature_layers = torch.stack([candidate_premature_logits[i] for i in candidate_premature_layers], dim=0) + + # 2. Calculate the softmax values for mature_layer and all premature_layers + # shape: (batch_size, vocab_size) + softmax_mature_layer = F.softmax(final_logits, dim=-1) + # shape: (num_premature_layers, batch_size, vocab_size) + softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1) + + # 3. Calculate the average distribution + # shape: (num_premature_layers, batch_size, vocab_size) + avg_dist = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers) + + # 4. Calculate log-softmax for the KL divergence + # shape: (batch_size, vocab_size) + log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1) + # shape: (num_premature_layers, batch_size, vocab_size) + log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1) + + # 5. Calculate the KL divergences and then the JS divergences + # shape: (num_premature_layers, batch_size) + kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], avg_dist, reduction="none").mean(-1) + # shape: (num_premature_layers, batch_size) + kl2 = F.kl_div(log_softmax_premature_layers, avg_dist, reduction="none").mean(-1) + js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size) + + # 6. 
Reduce the batchmean + js_divs = js_divs.mean(-1) # shape: (num_premature_layers,) + premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())] + + base_logits = candidate_premature_logits[premature_layer] + final_logits, base_logits = _relative_top_filter(final_logits, base_logits) + logits = final_logits - base_logits + return logits diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 8fa41fbdbe2b07..cab6fe8d094cd6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1264,6 +1264,55 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): for output in (output_greedy, output_prompt_lookup): self._check_outputs(output, input_ids, model.config, use_cache=True) + def test_dola_decoding_sample(self): + # TODO (joao): investigate skips, try to reduce incompatibilities + for model_class in self.all_generative_model_classes: + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support DoLa decoding") + + if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): + self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.") + + if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]): + self.skipTest("DoLa is not supported for models that don't return layerwise hidden states") + + # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm + config, input_ids, attention_mask = self._get_input_ids_and_config() + + # Some models don't support the cache and returning past_key_values + if not hasattr(config, "use_cache"): + config.use_cache = False + else: + config.use_cache = True + + # Encoder-decoder models are not supported + if config.is_encoder_decoder: + self.skipTest("DoLa is not supported for encoder-decoder models") + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + + if model.get_output_embeddings() is None: + self.skipTest("DoLa is not supported for models that don't have output embeddings") + # Sets dola generation arguments such that: + # a) no EOS is generated, to ensure generation doesn't break early + # b) there are at least two forward passes in the main model, to ensure the input preparation of + # the main model is correct + generation_kwargs = { + "eos_token_id": -1, # see a) + "max_new_tokens": 4, # see b) + "num_beams": 1, + "do_sample": True, + "output_scores": True, + "output_logits": True, + "output_hidden_states": True, + "output_attentions": self.has_attentions, + "return_dict_in_generate": True, + } + generation_kwargs.update({"dola_layers": "low"}) + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs) + self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache) + def test_assisted_decoding_sample(self): # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index e6a40eb10255a5..373885f535bd73 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -839,7 +839,6 @@ def test_model_7b_4bit(self): output = model.generate(**inputs, max_new_tokens=20, do_sample=False) 
output_text = tokenizer.batch_decode(output, skip_special_tokens=True) - self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version]) @slow @@ -898,3 +897,24 @@ def test_compile_static_cache(self): ) static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) + + def test_model_2b_bf16_dola(self): + model_id = "google/gemma-2b" + # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + EXPECTED_TEXTS = [ + "Hello I am doing an experiment and need to get the mass of a block. The problem is, it has no scale", + "Hi today we have the review for a 2016/2017 season of", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to( + torch_device + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate( + **inputs, max_new_tokens=20, do_sample=False, dola_layers="low", repetition_penalty=1.2 + ) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + self.assertEqual(output_text, EXPECTED_TEXTS) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 0935e802c685b9..3e1d5c26eb1587 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -703,6 +703,29 @@ def test_model_7b_logits(self): ) ) + @slow + def test_model_7b_dola_generation(self): + # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + EXPECTED_TEXT_COMPLETION = ( + "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " + "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " + "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " + "understanding of space and time." + ) + prompt = "Simply put, the theory of relativity states that " + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16 + ) + model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + # greedy generation outputs + generated_ids = model.generate( + **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low" + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + @slow @require_torch_gpu @require_read_token diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index f2ba612a47b1c7..34246b8e967a82 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -555,6 +555,30 @@ def test_model_7b_generation(self): text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text) + @slow + def test_model_7b_dola_generation(self): + # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + EXPECTED_TEXT_COMPLETION = ( + """My favourite condiment is 100% ketchup. 
I love it on everything, and I’m not ash""" + ) + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16 + ) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate( + input_ids, max_new_tokens=20, temperature=0, dola_layers="low", repetition_penalty=1.2 + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() + @require_bitsandbytes @slow @require_flash_attn