From e29fb264f9fd30e43c873b7cffb85be4cf2972fa Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 5 Dec 2023 12:20:31 +0000 Subject: [PATCH] one more --- src/transformers/generation/logits_process.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 550380c07cdee1..59492c37a34c55 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -547,6 +547,7 @@ class TypicalLogitsWarper(LogitsWarper): >>> set_seed(18) >>> outputs = model.generate(**inputs, do_sample=True) >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + 1, 2, 3, 4, 5, 6, 7, 8, 9 and 10 >>> # With `typical_p` set, the most obvious sequence is no longer produced, which may be good for your problem >>> set_seed(18) @@ -557,7 +558,7 @@ class TypicalLogitsWarper(LogitsWarper): 1, 2, 3 and 5 >>> # We can see that the token corresponding to "4" (token 934) in the second position, the most likely token - >>> # under default parameterization, was entirely blocked out + >>> # as seen with greedy decoding, was entirely blocked out >>> print(outputs.scores[1][0, 934]) tensor(-inf) ``` @@ -1183,6 +1184,40 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor): arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID `batch_id`. + + Examples: + + ```py + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m") + >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m") + + >>> inputs = tokenizer("Alice and Bob", return_tensors="pt") + + >>> # By default, it continues generating according to the model's logits + >>> outputs = model.generate(**inputs, max_new_tokens=5) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + Alice and Bob are friends + + >>> # We can contrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix. + >>> # For instance, we can force an entire entity to be generated when its beginning is detected. + >>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0] # 3 tokens + >>> def prefix_allowed_tokens_fn(batch_id, input_ids): + ... ''' + ... Attempts to generate 'Bob Marley' when 'Bob' is detected. + ... In this case, `batch_id` is not used, but you can set rules for each batch member. + ... ''' + ... if input_ids[-1] == entity[0]: + ... return entity[1] + ... elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]: + ... return entity[2] + ... return list(range(tokenizer.vocab_size)) # If no match, allow all tokens + + >>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + Alice and Bob Marley + ``` """ def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):