[Whisper] Add sequential longform decoding (#27492)
* [Whisper] Add seq gen

* [Whisper] Add seq gen

* more debug

* Fix whisper logit processor

* Improve whisper code further

* Fix more

* more debug

* more debug

* Improve further

* Add tests

* Prep for batch size > 1

* Get batch_size>1 working

* Correct more

* Add extensive tests

* more debug

* more debug

* more debug

* add more tests

* more debug

* Apply suggestions from code review

* more debug

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* Add more examples

* add comments to explain the code better

* fix more

* add comments to explain the code better

* add comments to explain the code better

* correct

* correct

* finalize

* Apply suggestions from code review

* Apply suggestions from code review
patrickvonplaten authored Nov 22, 2023
1 parent b2c63c7 commit 4151fbb
Showing 5 changed files with 836 additions and 83 deletions.
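
The headline feature: `WhisperForConditionalGeneration.generate` can now transcribe audio longer than Whisper's 30-second window by decoding windows sequentially, using the predicted timestamps to decide where each next window begins, mirroring OpenAI's original inference code. A minimal usage sketch follows; the checkpoint name, the zero-filled stand-in audio, and the preprocessing flags are illustrative assumptions, not part of this diff:

```python
# Sketch: sequential long-form transcription after this commit.
# Assumptions: transformers at or after this commit, "openai/whisper-tiny.en"
# as checkpoint, and silent stand-in audio instead of real speech.
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# 60 seconds of 16 kHz audio -- twice Whisper's usual 30-second window.
raw_audio = np.zeros(16_000 * 60, dtype=np.float32)

# Keep the full long-form input instead of truncating to 30 seconds.
inputs = processor(
    raw_audio,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
)

# For inputs longer than 30 s, generate() now decodes window after window.
predicted_ids = model.generate(**inputs, return_timestamps=True)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription[0])
```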
src/transformers/generation/logits_process.py (36 additions, 14 deletions)
@@ -1487,6 +1487,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
         max_initial_timestamp_index (`int`, *optional*, defaults to 1):
             Used to set the maximum value of the initial timestamp. This is used to prevent the model from
             predicting timestamps that are too far in the future.
+        _detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.

     Examples:
     ``` python
@@ -1517,29 +1518,35 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
     ```
     """

-    def __init__(self, generate_config):  # support for the kwargs
+    def __init__(
+        self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
+    ):  # support for the kwargs
         self.eos_token_id = generate_config.eos_token_id
         self.no_timestamps_token_id = generate_config.no_timestamps_token_id
         self.timestamp_begin = generate_config.no_timestamps_token_id + 1

-        self.begin_index = len(generate_config.forced_decoder_ids) + 2
-        if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
-            self.begin_index -= 1
-        self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
+        # this variable is mostly just used for testing
+        self._detect_timestamp_from_logprob = (
+            _detect_timestamp_from_logprob
+            if _detect_timestamp_from_logprob is not None
+            else getattr(generate_config, "_detect_timestamp_from_logprob", True)
+        )
+
+        self.begin_index = (
+            len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
+        )
+        self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # suppress <|notimestamps|> which is handled by without_timestamps
         scores[:, self.no_timestamps_token_id] = -float("inf")

-        if input_ids.shape[1] == self.begin_index - 1:
-            scores[:, :] = -float("inf")
-            scores[:, self.timestamp_begin] = 0
-            return scores
-
         # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
         for k in range(input_ids.shape[0]):
-            seq = list(input_ids[k, self.begin_index :].tolist())
+            sampled_tokens = input_ids[k, self.begin_index :]
+            seq = list(sampled_tokens.tolist())
+
             last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
             penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
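
Two behavioral changes in `__init__` are worth calling out: `begin_index` no longer assumes `forced_decoder_ids` is set (long-form decoding constructs the processor with it unset), and both `max_initial_timestamp_index` and the new `_detect_timestamp_from_logprob` flag are read defensively via `getattr`. A sketch of standalone instantiation; the checkpoint name is an assumption:

```python
# Sketch: building the processor directly from a model's generation config.
from transformers import WhisperForConditionalGeneration
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
gen_config = model.generation_config

proc = WhisperTimeStampLogitsProcessor(gen_config)
print(proc.begin_index)  # len(forced_decoder_ids) + 1, or 1 if they are None
print(proc.max_initial_timestamp_index)

# The new private flag is mostly for testing and can be forced off:
proc_no_detect = WhisperTimeStampLogitsProcessor(gen_config, _detect_timestamp_from_logprob=False)
```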

@@ -1549,8 +1556,23 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
             else:  # cannot be normal text tokens
                 scores[k, : self.eos_token_id] = -float("inf")

-        # apply the `max_initial_timestamp` option
-        if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
-            last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
-            scores[:, last_allowed + 1 :] = -float("inf")
+            timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
+            if timestamps.numel() > 0:
+                # `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
+                # The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
+                if last_was_timestamp and not penultimate_was_timestamp:
+                    timestamp_last = timestamps[-1]
+                else:
+                    # Avoid to emit <|0.00|> again
+                    timestamp_last = timestamps[-1] + 1
+
+                scores[k, self.timestamp_begin : timestamp_last] = -float("inf")
+
+        # apply the `max_initial_timestamp` option
+        if input_ids.shape[1] == self.begin_index:
+            scores[:, : self.timestamp_begin] = -float("inf")
+
+            if self.max_initial_timestamp_index is not None:
+                last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
+                scores[:, last_allowed + 1 :] = -float("inf")
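
The per-sequence block added above enforces that timestamp tokens never move backwards within a window: once a timestamp has been emitted, earlier timestamps are masked out, and an open begin/end pair may close on the same timestamp. A toy re-implementation of just that rule (not the library code), with made-up token ids, to show the masking in isolation:

```python
# Toy re-implementation of the "timestamps must not decrease" mask.
# The vocabulary layout is made up: ids >= 10 are timestamp tokens.
import torch

timestamp_begin = 10  # first timestamp token id (made up)
vocab_size = 20

def mask_decreasing_timestamps(sampled_tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
    seq = sampled_tokens.tolist()
    last_was_timestamp = len(seq) >= 1 and seq[-1] >= timestamp_begin
    penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= timestamp_begin

    timestamps = sampled_tokens[sampled_tokens.ge(timestamp_begin)]
    if timestamps.numel() > 0:
        if last_was_timestamp and not penultimate_was_timestamp:
            # an open pair: the closing timestamp may equal the opening one
            timestamp_last = timestamps[-1]
        else:
            # otherwise the next timestamp must be strictly larger than the last
            timestamp_last = timestamps[-1] + 1
        scores[timestamp_begin:timestamp_last] = -float("inf")
    return scores

# already emitted timestamps 10 and 12: timestamp ids below 13 are now banned
scores = mask_decreasing_timestamps(torch.tensor([10, 3, 4, 12, 5]), torch.zeros(vocab_size))
print(scores[timestamp_begin:])  # first three entries are -inf
```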

@@ -1559,7 +1581,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         for k in range(input_ids.shape[0]):
             timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
             max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
-            if timestamp_logprob > max_text_token_logprob:
+            if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
                 scores[k, : self.timestamp_begin] = -float("inf")

         return scores
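
This final block keeps OpenAI's heuristic of forcing a timestamp whenever the total probability mass on all timestamp tokens exceeds that of the most likely text token, now gated behind `_detect_timestamp_from_logprob`. A small numeric illustration of the comparison, with a made-up vocabulary layout:

```python
# Numeric illustration of the logprob-based timestamp detection heuristic.
import torch

timestamp_begin = 10  # made-up boundary: ids >= 10 are timestamp tokens
scores = torch.randn(20)  # made-up logits over a 20-token vocabulary

logprobs = torch.nn.functional.log_softmax(scores.float(), dim=-1)
timestamp_logprob = logprobs[timestamp_begin:].logsumexp(dim=-1)  # total timestamp mass
max_text_token_logprob = logprobs[:timestamp_begin].max()  # best single text token

if timestamp_logprob > max_text_token_logprob:
    # timestamps collectively outweigh every text token: force a timestamp
    scores[:timestamp_begin] = -float("inf")
print(scores)
```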
