Skip to content

Commit

Permalink
Generate: consistently handle special tokens as tensors (#30624)
Browse files Browse the repository at this point in the history
* tmp commit

* [test_all] mvp

* missing not

* [test_all] final test fixes

* fix musicgen_melody and rag

* [test_all] empty commit

* PR comments

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <[email protected]>

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
gante and ArthurZucker authored May 9, 2024
1 parent 5413b89 commit 7130a22
Show file tree
Hide file tree
Showing 12 changed files with 297 additions and 191 deletions.
40 changes: 24 additions & 16 deletions src/transformers/generation/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def process(
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
Expand All @@ -245,8 +245,10 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

for batch_idx in range(batch_size):
batch_group_idx = batch_idx * self.num_beam_groups + group_index
Expand Down Expand Up @@ -322,15 +324,17 @@ def finalize(
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps) // self.num_beam_groups

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

# finalize all open beam hypotheses and add to generated hypotheses
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
Expand Down Expand Up @@ -513,8 +517,8 @@ def process(
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.Tensor]:
Expand Down Expand Up @@ -578,8 +582,10 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
Expand Down Expand Up @@ -811,15 +817,17 @@ def finalize(
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
Expand Down
137 changes: 81 additions & 56 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
Examples:
Expand Down Expand Up @@ -137,23 +137,23 @@ class MinLengthLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

self.min_length = min_length
self.eos_token_id = eos_token_id

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
Expand All @@ -171,8 +171,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
input length.
min_new_tokens (`int`):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
Examples:
Expand All @@ -195,18 +195,20 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
def __init__(
self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
("min_new_tokens", min_new_tokens),
]:
if not isinstance(arg_value, int) or arg_value < 0:
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens
Expand All @@ -217,8 +219,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_id = torch.tensor(self.eos_token_id, device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, eos_token_id)
self.eos_token_id = self.eos_token_id.to(scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)

Expand Down Expand Up @@ -1195,8 +1197,8 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
Args:
bad_words_ids (`List[List[int]]`):
List of list of token ids that are not allowed to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`, *optional*):
The id(s) of the *end-of-sequence* token.
Examples:
Expand Down Expand Up @@ -1233,18 +1235,22 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
```
"""

def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
def __init__(
self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None
):
self.bad_word_ids = bad_words_ids
self._validate_arguments()

# Filter EOS token from bad_words_ids
if eos_token_id is None:
eos_token_id = []
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
bad_words_ids = list(
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)
if eos_token_id is not None:
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)

bad_words_ids = list(
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)

# Forbidding a sequence is equivalent to setting its bias to -inf
sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
Expand Down Expand Up @@ -1522,9 +1528,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
Examples:
Expand All @@ -1548,15 +1553,22 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
self.max_length = max_length
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
scores_processed = scores
if cur_len == self.max_length - 1:
scores_processed = torch.full_like(scores, -math.inf)
Expand Down Expand Up @@ -1595,8 +1607,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
exponential_decay_length_penalty (`tuple(int, float)`):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
input_ids_seq_length (`int`):
The length of the input sequence.
Expand Down Expand Up @@ -1656,27 +1668,33 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
def __init__(
self,
exponential_decay_length_penalty: Tuple[int, float],
eos_token_id: Union[int, List[int]],
eos_token_id: Union[int, List[int], torch.Tensor],
input_ids_seq_length: int,
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
self.eos_token_id = self.eos_token_id.to(scores.device)
penalties = torch.zeros_like(scores)
scores_processed = scores
if cur_len > self.regulation_start:
for i in self.eos_token_id:
penalty_idx = cur_len - self.regulation_start
# To support negative logits we compute the penalty of the absolute value and add to the original logit
penalty = torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
penalties[:, i] = penalty
scores_processed = scores + penalties
penalty_idx = cur_len - self.regulation_start
# To support negative logits we compute the penalty of the absolute value and add to the original logit
penalty = torch.abs(scores[:, self.eos_token_id]) * (pow(self.regulation_factor, penalty_idx) - 1)
penalties[:, self.eos_token_id] = penalty
scores_processed = scores + penalties
return scores_processed


Expand Down Expand Up @@ -1753,7 +1771,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
"""

def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
self.begin_index = begin_index

def set_begin_index(self, begin_index):
Expand All @@ -1762,8 +1780,8 @@ def set_begin_index(self, begin_index):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens)
self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
scores_processed = scores
if input_ids.shape[-1] == self.begin_index:
scores_processed = torch.where(suppress_token_mask, -float("inf"), scores)
Expand Down Expand Up @@ -1801,13 +1819,13 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
"""

def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
self.suppress_tokens = torch.tensor(list(suppress_tokens))

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device)
suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
self.suppress_tokens = self.suppress_tokens.to(scores.device)
suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
scores = torch.where(suppress_token_mask, -float("inf"), scores)
return scores

Expand Down Expand Up @@ -2268,23 +2286,30 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
</Tip>
Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
eos_token_id (`Union[int, List[int], torch.Tensor]`):
The id(s) of the *end-of-sequence* token.
min_eos_p (`float`, *optional*):
Minimum end of speech threshold.
"""

def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
self.eos_token_id = eos_token_id

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

if min_eos_p is not None and min_eos_p <= 0:
raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
self.min_eos_p = min_eos_p

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores
self.eos_token_id = self.eos_token_id.to(scores.device)
if self.min_eos_p:
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# create scores full of -inf except for the eos_token_id
Expand Down
Loading

0 comments on commit 7130a22

Please sign in to comment.