1st half
gante committed Dec 5, 2023
1 parent ac97507 commit e7be3c1
Showing 2 changed files with 177 additions and 61 deletions.
225 changes: 174 additions & 51 deletions src/transformers/generation/logits_process.py
@@ -100,13 +100,39 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa

class MinLengthLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
[`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. Note that, for decoder-only models
like most LLMs, the length includes the prompt.
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer(["A number:"], return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one
>>> # setting `min_length` to a value smaller than the uncontrolled output length has no impact
>>> gen_out = model.generate(**inputs, min_length=3)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one
>>> # setting a larger `min_length` will force the model to generate beyond its natural ending point, which is not
>>> # necessarily incorrect
>>> gen_out = model.generate(**inputs, min_length=10)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one thousand, nine hundred and ninety-four
```
"""

def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
@@ -133,9 +159,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
Note that for decoder-only models, such as Llama2, `min_length` will compute the length of `prompt + newly
generated tokens` whereas for other models it will behave as `min_new_tokens`, that is, taking only into account
the newly generated ones.
Unlike [`MinLengthLogitsProcessor`], this processor ignores the prompt.
Args:
prompt_length_to_skip (`int`):
@@ -149,29 +173,21 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> model.config.pad_token_id = model.config.eos_token_id
>>> inputs = tokenizer(["Hugging Face Company is"], return_tensors="pt")
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer(["A number:"], return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one
>>> # If the maximum length (default = 20) is smaller than the minimum length constraint, the latter is ignored!
>>> outputs = model.generate(**inputs, min_new_tokens=30)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a company that has been working on a new product for the past year.
>>> # For testing purposes, let's set `eos_token` to `"company"`, the first generated token. This will make
>>> # generation end there.
>>> outputs = model.generate(**inputs, eos_token_id=1664)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a company
>>> # Increasing `min_new_tokens` will make generation ignore occurrences of `"company"` (eos token) before the
>>> # minimum length condition is honored.
>>> outputs = model.generate(**inputs, min_new_tokens=2, eos_token_id=1664)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a new company
>>> # setting `min_new_tokens` will force the model to generate beyond its natural ending point, which is not
>>> # necessarily incorrect
>>> gen_out = model.generate(**inputs, min_new_tokens=2)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one thousand
```
"""

@@ -205,7 +221,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
class TemperatureLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
that it can control the randomness of the predicted tokens.
that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
[`TopKLogitsWarper`].
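Temperature itself is a one-line transformation of the logits before softmax/sampling; roughly:

```python
import torch

def apply_temperature(scores: torch.FloatTensor, temperature: float) -> torch.FloatTensor:
    # temperature < 1.0 sharpens the distribution (closer to greedy decoding),
    # temperature > 1.0 flattens it (more random samples).
    return scores / temperature
```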
<Tip>
@@ -269,22 +286,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty. This technique
shares some similarities with coverage mechanisms and other aimed at reducing repetition. During the text
generation process, the probability distribution for the next token is determined using a formula that incorporates
token scores based on their occurrence in the generated sequence. Tokens with higher scores are more likely to be
selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
This technique can also be used to reward and thus encourage repetition in a similar manner. To penalize and reduce
[`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.
In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.
Args:
penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
tokens. Between 0.0 and 1.0 rewards previously generated tokens. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
tokens. Between 0.0 and 1.0 rewards previously generated tokens.
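The penalty rule itself is compact: each token already present in the sequence has its logit divided by `penalty` when positive and multiplied by it when negative, so a value above 1.0 always lowers its probability. A sketch of that update, assuming this gather/where/scatter formulation matches the intended behaviour:

```python
import torch

def apply_repetition_penalty(
    scores: torch.FloatTensor, input_ids: torch.LongTensor, penalty: float
) -> torch.FloatTensor:
    # Collect the scores of tokens already present in the sequence (prompt
    # included, for decoder-only models); each is penalized at most once.
    score = torch.gather(scores, 1, input_ids)
    # Dividing positive logits / multiplying negative ones lowers the token's
    # probability when penalty > 1.0, and raises it when 0.0 < penalty < 1.0.
    score = torch.where(score < 0, score * penalty, score / penalty)
    return scores.scatter(1, input_ids, score)
```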
Examples:
@@ -327,20 +340,39 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that avoids hallucination by boosting the probabilities of tokens found within the original
input.
[`LogitsProcessor`] that works similarly to [`RepetitionPenaltyLogitsProcessor`], but with an *inverse* penalty
that is applied to the tokens present in the prompt. In other words, a penalty above 1.0 increases the odds of
selecting tokens that were present in the prompt.
This technique can also be used to reward and thus encourage hallucination (or creativity) in a similar manner. To
penalize and reduce hallucination, use `penalty` values above 1.0, where a higher value penalizes more strongly. To
reward and encourage hallucination, use `penalty` values between 0.0 and 1.0, where a lower value rewards more
strongly.
It was designed to avoid hallucination in input-grounded tasks, like summarization. Although originally intended
for encoder-decoder models, it can also be used with decoder-only models like LLMs.
Args:
penalty (`float`):
The parameter for hallucination penalty. 1.0 means no penalty. Above 1.0 penalizes hallucination. Between
0.0 and 1.0 rewards hallucination.
The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 rewards prompt tokens. Between 0.0
and 1.0 penalizes prompt tokens.
encoder_input_ids (`torch.LongTensor`):
The encoder_input_ids that should be repeated within the decoder ids.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer(["Alice and Bob. The third member's name was"], return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
Alice and Bob. The third member's name was not mentioned.
>>> # With the `encoder_repetition_penalty` argument we can trigger this logits processor in `generate`, which can
>>> # promote the use of prompt tokens ("Bob" in this example)
>>> gen_out = model.generate(**inputs, encoder_repetition_penalty=1.2)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
Alice and Bob. The third member's name was Bob. The third member's name was Bob.
```
"""

def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
@@ -363,7 +395,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class TopPLogitsWarper(LogitsWarper):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
[`LogitsWarper`] that performs top-p, i.e. restricting to the smallest set of most probable tokens whose cumulative
probability adds up to `top_p` or higher. Often used together with [`TemperatureLogitsWarper`] and
[`TopKLogitsWarper`].
Args:
top_p (`float`):
@@ -375,6 +408,7 @@ class TopPLogitsWarper(LogitsWarper):
Minimum number of tokens that cannot be filtered.
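The filtering step behind top-p can be sketched as follows, as an illustration of the cumulative-probability cutoff rather than the library's exact code:

```python
import torch

def top_p_filter(
    scores: torch.FloatTensor, top_p: float, filter_value: float = -float("inf"), min_tokens_to_keep: int = 1
) -> torch.FloatTensor:
    # Sort tokens by score and keep the smallest set of most probable tokens
    # whose cumulative probability reaches `top_p`; everything else is masked.
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Never filter away the `min_tokens_to_keep` most probable tokens.
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)
```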
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
@@ -426,7 +460,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class TopKLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together
with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
Args:
top_k (`int`):
@@ -435,6 +470,29 @@ class TopKLogitsWarper(LogitsWarper):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> set_seed(0)
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> inputs = tokenizer("A sequence: A, B, C, D", return_tensors="pt")
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: A, B, C, D, G, H, I. A, M
>>> # With `top_k` sampling, the output gets restricted to the k most likely tokens.
>>> # Pro tip: In practice, LLMs use `top_k` in the 5-50 range.
>>> outputs = model.generate(**inputs, do_sample=True, top_k=2)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: A, B, C, D, E, F, G, H, I
```
"""

def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -455,8 +513,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class TypicalLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information.
[`LogitsWarper`] that performs typical decoding. Inspired by how humans use language, it prioritizes tokens whose
negative log-probability (information content) is close to the entropy of the token probability distribution. This
means that the most likely tokens may be discarded in the process.
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
Args:
mass (`float`, *optional*, defaults to 0.9):
@@ -465,6 +526,41 @@ class TypicalLogitsWarper(LogitsWarper):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer("1, 2, 3", return_tensors="pt")
>>> # We can see that greedy decoding produces a sequence of numbers
>>> outputs = model.generate(**inputs)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
>>> # For this particular seed, we can see that sampling produces nearly the same low-information (= low entropy)
>>> # sequence
>>> set_seed(18)
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
>>> # With `typical_p` set, the most obvious sequence is no longer produced, which may be good for your problem
>>> set_seed(18)
>>> outputs = model.generate(
... **inputs, do_sample=True, typical_p=0.1, return_dict_in_generate=True, output_scores=True
... )
>>> print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0])
1, 2, 3 and 5
>>> # We can see that the token corresponding to "4" (token 934), the most likely token in the second position under
>>> # the default parameterization, was entirely blocked out
>>> print(outputs.scores[1][0, 934])
tensor(-inf)
```
"""

def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -721,7 +817,8 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). In text generation,
avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`] enforces no
repetition of n-grams by setting the scores of banned tokens to negative infinity which eliminates those tokens
from consideration when further processing the scores.
from consideration when further processing the scores. Note that, for decoder-only models like most LLMs, the
prompt is also considered to obtain the n-grams.
See [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345) for a reference implementation.
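A simplified sketch of the banning step: collect every n-gram seen so far, then forbid any token that would complete an n-gram whose (n-1)-token prefix matches the end of the current sequence. Illustrative code, not the optimized implementation:

```python
import torch

def ban_repeated_ngrams(
    scores: torch.FloatTensor, input_ids: torch.LongTensor, ngram_size: int
) -> torch.FloatTensor:
    for batch_idx, seq in enumerate(input_ids.tolist()):
        # Map each (n-1)-token prefix to the tokens that followed it so far.
        seen = {}
        for i in range(len(seq) - ngram_size + 1):
            prefix = tuple(seq[i : i + ngram_size - 1])
            seen.setdefault(prefix, set()).add(seq[i + ngram_size - 1])
        # Tokens that would complete an already-seen n-gram are banned.
        current_prefix = tuple(seq[len(seq) - ngram_size + 1 :])
        for token in seen.get(current_prefix, ()):
            scores[batch_idx, token] = -float("inf")
    return scores
```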
<Tip>
@@ -774,14 +871,40 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces no repetition of encoder input ids n-grams for the decoder ids. See
[ParlAI](https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350).
[`LogitsProcessor`] that works similarly to [`NoRepeatNGramLogitsProcessor`], but applied exclusively to prevent
the repetition of n-grams present in the prompt.
It was designed to promote chattiness in a language model, by preventing the generation of n-grams present in
previous conversation rounds.
Args:
encoder_ngram_size (`int`):
All ngrams of size `encoder_ngram_size` can only occur within the encoder input ids.
encoder_input_ids (`int`):
The encoder_input_ids that should not be repeated within the decoder ids.
Examples:
```py
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer("Alice: I love cats. What do you love?\nBob:", return_tensors="pt")
>>> # With greedy decoding, we see Bob repeating Alice's opinion. If Bob was a chatbot, it would be a poor one.
>>> outputs = model.generate(**inputs)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice: I love cats. What do you love?
Bob: I love cats. What do you
>>> # With this logits processor, we can prevent Bob from repeating Alice's opinion.
>>> outputs = model.generate(**inputs, encoder_no_repeat_ngram_size=2)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice: I love cats. What do you love?
Bob: My cats are very cute.
```
"""

def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
13 changes: 3 additions & 10 deletions src/transformers/generation/utils.py
@@ -1031,16 +1031,9 @@ def _get_logits_processor(
generation_config.encoder_no_repeat_ngram_size is not None
and generation_config.encoder_no_repeat_ngram_size > 0
):
if self.config.is_encoder_decoder:
processors.append(
EncoderNoRepeatNGramLogitsProcessor(
generation_config.encoder_no_repeat_ngram_size, encoder_input_ids
)
)
else:
raise ValueError(
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
)
processors.append(
EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids)
)
if generation_config.bad_words_ids is not None:
processors.append(
NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
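With this simplification, decoder-only models can use `encoder_no_repeat_ngram_size` as well; in that case the prompt is what gets passed as `encoder_input_ids`, so prompt n-grams are banned from the continuation. A hedged sketch of roughly what the branch now enables, building the processor by hand instead of passing `encoder_no_repeat_ngram_size=2` to `generate`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from transformers.generation.logits_process import EncoderNoRepeatNGramLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
inputs = tokenizer("Alice: I love cats. What do you love?\nBob:", return_tensors="pt")

# For a decoder-only model, the prompt plays the role of the "encoder input ids".
processors = LogitsProcessorList(
    [EncoderNoRepeatNGramLogitsProcessor(encoder_ngram_size=2, encoder_input_ids=inputs["input_ids"])]
)
outputs = model.generate(**inputs, logits_processor=processors)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```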
