Commit e29fb26
one more
gante committed Dec 5, 2023
1 parent e7be3c1 commit e29fb26
Showing 1 changed file with 36 additions and 1 deletion.
src/transformers/generation/logits_process.py (37 changes: 36 additions & 1 deletion)
@@ -547,6 +547,7 @@ class TypicalLogitsWarper(LogitsWarper):
>>> set_seed(18)
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
1, 2, 3, 4, 5, 6, 7, 8, 9 and 10
>>> # With `typical_p` set, the most obvious sequence is no longer produced, which may be good for your problem
>>> set_seed(18)
@@ -557,7 +558,7 @@ class TypicalLogitsWarper(LogitsWarper):
1, 2, 3 and 5
>>> # We can see that the token corresponding to "4" (token 934) in the second position, the most likely token
- >>> # under default parameterization, was entirely blocked out
+ >>> # as seen with greedy decoding, was entirely blocked out
>>> print(outputs.scores[1][0, 934])
tensor(-inf)
```
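For intuition about why the score of token 934 is pushed to `-inf`, here is a minimal sketch of the filtering rule behind typical decoding, assuming logits of shape `(batch, vocab)`. The function name is illustrative, and the real `TypicalLogitsWarper` additionally enforces a `min_tokens_to_keep` floor, which is omitted here:

```py
import torch


def typical_filter(scores: torch.Tensor, mass: float = 0.1) -> torch.Tensor:
    # Typical decoding keeps the tokens whose surprisal (-log p) is closest
    # to the entropy of the next-token distribution, adding them in order of
    # "typicality" until `mass` of the probability is covered.
    log_probs = torch.nn.functional.log_softmax(scores, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).nansum(-1, keepdim=True)
    distance = (-log_probs - entropy).abs()
    sorted_distance, sorted_idx = torch.sort(distance, descending=False)
    cumulative = probs.gather(-1, sorted_idx).cumsum(dim=-1)
    # Index of the last token needed to cover `mass` of the probability
    last_ind = (cumulative < mass).sum(dim=-1, keepdim=True).clamp(max=scores.shape[-1] - 1)
    remove = sorted_distance > sorted_distance.gather(-1, last_ind)
    remove = remove.scatter(-1, sorted_idx, remove)  # map back to vocabulary order
    return scores.masked_fill(remove, -float("inf"))
```

Note that a very predictable token can be filtered out precisely because its surprisal is far below the entropy, which is how the otherwise most likely continuation "4" gets blocked in the example above.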
@@ -1183,6 +1184,40 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
arguments `input_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the
next generation step, conditioned on the previously generated tokens `input_ids` and the batch ID
`batch_id`.
Examples:
```py
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer("Alice and Bob", return_tensors="pt")
>>> # By default, it continues generating according to the model's logits
>>> outputs = model.generate(**inputs, max_new_tokens=5)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice and Bob are friends
>>> # We can constrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix.
>>> # For instance, we can force an entire entity to be generated when its beginning is detected.
>>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0] # 3 tokens
>>> def prefix_allowed_tokens_fn(batch_id, input_ids):
...     '''
...     Attempts to generate 'Bob Marley' when 'Bob' is detected.
...     In this case, `batch_id` is not used, but you can set rules for each batch member.
...     '''
...     if input_ids[-1] == entity[0]:
...         return [entity[1].item()]
...     elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
...         return [entity[2].item()]
...     return list(range(tokenizer.vocab_size))  # If no match, allow all tokens
>>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice and Bob Marley
```
"""

def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
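For context, here is a minimal sketch of the masking step such a processor performs at each generation step (simplified, with an illustrative function name rather than the library's exact code): every score starts at `-inf`, and only the token ids returned by the callback keep their original values.

```py
import math

import torch


def apply_prefix_constraint(input_ids, scores, prefix_allowed_tokens_fn, num_beams):
    # Build an all -inf mask and unmask only the allowed token ids,
    # evaluating the callback once per (batch, beam) hypothesis.
    mask = torch.full_like(scores, -math.inf)
    for batch_id, beam_sent in enumerate(input_ids.view(-1, num_beams, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            allowed = prefix_allowed_tokens_fn(batch_id, sent)
            mask[batch_id * num_beams + beam_id, allowed] = 0
    return scores + mask
```

Because disallowed ids sit at `-inf`, neither sampling nor beam search can select them, so once the first token of the entity appears, the remaining tokens are forced.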
