[Whisper] Fix whisper tokenizer (huggingface#34537)
* handle single timestamp ending

* include last timestamp token

* handle single timestamp ending

* avoid floating-point arithmetic limitations

* ensure float64 operations

* new test

* make fixup

* make copies

* handle edge case: double timestamp ending whose two tokens differ (illustrated in the sketch below)

* handle single timestamp ending

* make fixup

* handle conditioning on prev segments

* fix

* Update src/transformers/models/whisper/generation_whisper.py

Co-authored-by: Yoach Lacombe <[email protected]>

* [run-slow] whisper

* don't call item() to avoid unnecessary sync

* fix

---------

Co-authored-by: Yoach Lacombe <[email protected]>
Co-authored-by: Eustache Le Bihan <[email protected]>
3 people authored and BernardZach committed Dec 5, 2024
1 parent eb2d804 commit b80e418
Showing 4 changed files with 173 additions and 25 deletions.
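
Note on the recurring theme in these commits: everything hinges on how a 30-second decoding window ends. A window that ends with two consecutive timestamp tokens ("double ending") was closed by the model itself, while a window that ends with a single timestamp token ("single ending") was cut mid-speech. The sketch below restates the distinction with assumed token ids (not the real vocabulary); `timestamp_begin` and the helper are illustrative, not library code.

# Token ids >= timestamp_begin are timestamp tokens such as <|0.00|>, <|5.00|>, ...
timestamp_begin = 50364  # assumed id of <|0.00|>

def is_single_timestamp_ending(tokens):
    # single ending: the last token is a timestamp but the one before it is not
    return len(tokens) >= 2 and tokens[-1] >= timestamp_begin and tokens[-2] < timestamp_begin

double_ending = [50364, 440, 871, 50614, 50614]  # <|0.00|> ...text... <|5.00|><|5.00|>
single_ending = [50364, 440, 871, 50614]         # <|0.00|> ...text... <|5.00|>

assert not is_single_timestamp_ending(double_ending)
assert is_single_timestamp_ending(single_ending)

After a double ending, decoding resumes from the last predicted timestamp; after a single ending, the next window restarts at <|0.00|> and a full window length has to be added when stitching absolute timestamps, which is what several of the diffs below implement.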
40 changes: 30 additions & 10 deletions src/transformers/models/whisper/generation_whisper.py
@@ -308,6 +308,7 @@ def generate(
num_segment_frames: Optional[int] = None,
attention_mask: Optional[torch.Tensor] = None,
time_precision: float = 0.02,
+ time_precision_features: float = 0.01,
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
return_dict_in_generate: Optional[bool] = None,
@@ -417,6 +418,8 @@ def generate(
time_precision (`int`, *optional*, defaults to 0.02):
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
for 20 ms.
+ time_precision_features (`int`, *optional*, defaults to 0.01):
+ The duration represented by a feature frame in seconds.
return_token_timestamps (`bool`, *optional*):
Whether to return token-level timestamps with the text. This can be used with or without the
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
@@ -629,7 +632,7 @@ def generate(
cur_bsz=cur_bsz,
batch_idx_map=batch_idx_map,
)
- time_offset = seek * time_precision / input_stride
+ time_offset = seek.to(torch.float64) * time_precision / input_stride
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)

# 6.2 cut out next 30s segment from input features
@@ -658,6 +661,7 @@ def generate(
config=self.config,
device=init_tokens.device,
suppress_tokens=suppress_tokens,
+ timestamp_begin=timestamp_begin,
kwargs=kwargs,
)

@@ -718,6 +722,7 @@ def generate(
timestamp_begin=timestamp_begin,
seek_num_frames=seek_num_frames,
time_precision=time_precision,
+ time_precision_features=time_precision_features,
input_stride=input_stride,
prev_idx=prev_i,
idx=i,
@@ -1665,6 +1670,7 @@ def _prepare_decoder_input_ids(
config,
device,
suppress_tokens,
+ timestamp_begin,
kwargs,
):
if "decoder_input_ids" in kwargs:
@@ -1684,6 +1690,14 @@ def _prepare_decoder_input_ids(
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]

+ for segments in active_segments:
+ for seg in segments:
+ if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
+ # the segment finishes with two timestamp tokens
+ # we need to ignore the last timestamp token
+ # see https://github.com/huggingface/transformers/pull/34537
+ seg["tokens"] = seg["tokens"][:-1]
+
if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
prev_ids = prompt_ids
else:
@@ -1778,6 +1792,7 @@ def _retrieve_segment(
timestamp_begin,
seek_num_frames,
time_precision,
+ time_precision_features,
input_stride,
prev_idx,
idx,
@@ -1799,17 +1814,22 @@ def _retrieve_segment(
segments = []
if single_timestamp_ending:
slices.append(len(seek_sequence))
+ else:
+ # we want to include the last timestamp token in the last segment to know it was no single ending
+ slices[-1] += 1

last_slice = 0
# Add each segment to list of all segments
- for current_slice in slices:
+ for i, current_slice in enumerate(slices):
+ is_last_slice = i == len(slices) - 1
sliced_tokens = seek_sequence[last_slice:current_slice]
- start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
- end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
+ start_timestamp_pos = sliced_tokens[0] - timestamp_begin
+ idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
+ end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
segments.append(
{
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
"end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
"start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
"end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
"tokens": sliced_tokens,
"result": seek_outputs[idx],
}
@@ -1827,16 +1847,16 @@ def _retrieve_segment(
# otherwise, ignore the unfinished segment and seek to the last timestamp
# here we throw away all predictions after the last predicted "end of segment"
# since we are cutting right in the middle of an audio
- last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
+ last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
segment_offset = last_timestamp_pos * input_stride
else:
# If whisper does not predict any "end of segment" token, then
# the whole decoding is considered a segment and we add it to the list of segments
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
- last_timestamp_pos = seek_num_frames[prev_idx]
- if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
+ last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
+ if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
- last_timestamp_pos = timestamps[-1].item() - timestamp_begin
+ last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
segments = [
{
"start": time_offset[prev_idx],
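
Two of the changes above are easy to gloss over: the `.to(torch.float64)` casts and the removal of `.item()` calls. PyTorch promotes an integer tensor multiplied by a Python float to float32 by default, so seek offsets deep into long audio lose precision; keeping the value a tensor instead of calling `.item()` also avoids a device-to-host sync on every loop iteration. A minimal sketch of the precision issue, with an assumed seek value:

import torch

seek = torch.tensor(291000)  # frame offset far into a long audio file
time_precision, input_stride = 0.02, 2

t32 = (seek * time_precision / input_stride).item()                    # promoted to float32
t64 = (seek.to(torch.float64) * time_precision / input_stride).item()  # as in the fix
print(t32, t64)  # roughly 2909.9999 vs 2910.0000...: float32 has drifted by ~0.06 ms
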
34 changes: 27 additions & 7 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -528,7 +528,9 @@ def basic_normalize(text, remove_diacritics=False):
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
return normalizer(text)

- def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
+ def _decode_with_timestamps(
+ self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500
+ ) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
@@ -538,15 +540,25 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre

cur_max_timestamp = 0.0
prev_segments_len = 0.0
+ penultimate_timestamp = 0.0

- for token in token_ids:
+ for i, token in enumerate(token_ids):
if token >= timestamp_begin:
timestamp = float((token - timestamp_begin) * time_precision)

if timestamp < cur_max_timestamp:
# next segment has started
- prev_segments_len += cur_max_timestamp
+ last_was_single_ending = i >= 2 and not (
+ token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
+ )
+ if last_was_single_ending:
+ prev_segments_len += time_precision * segment_size
+ else:
+ cur_max_timestamp = penultimate_timestamp
+ prev_segments_len += penultimate_timestamp
+ outputs = outputs[:-2]
+
+ penultimate_timestamp = cur_max_timestamp
cur_max_timestamp = timestamp

outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
@@ -558,7 +570,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre
]
return "".join(outputs)

- def _compute_offsets(self, token_ids, time_precision=0.02):
+ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
"""
Compute offsets for a given tokenized input
@@ -567,6 +579,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
List of tokenized input ids. Can be obtained using the `__call__` method.
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time.
+ segment_size (`int`, *optional*, defaults to 1500):
+ The number of features in the input mel spectrogram.
"""
offsets = []
# ensure torch tensor of token ids is placed on cpu
@@ -597,7 +611,13 @@ def _compute_offsets(self, token_ids, time_precision=0.02):

if start_timestamp_position < cur_max_timestamp:
# next segment has started
- prev_segments_len += cur_max_timestamp
+ is_single_ending = last_slice >= 2 and not (
+ token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
+ )
+ if is_single_ending:
+ prev_segments_len += segment_size
+ else:
+ prev_segments_len += cur_max_timestamp

cur_max_timestamp = end_timestamp_position

@@ -609,8 +629,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
{
"text": text,
"timestamp": (
- (start_timestamp_position + prev_segments_len) * time_precision,
- (end_timestamp_position + prev_segments_len) * time_precision,
+ start_timestamp_position * time_precision + prev_segments_len * time_precision,
+ end_timestamp_position * time_precision + prev_segments_len * time_precision,
),
}
)
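
The `segment_size` parameter threads through both helpers, but note the units: `_decode_with_timestamps` keeps its running offset in seconds (`prev_segments_len += time_precision * segment_size`), while `_compute_offsets` keeps it in timestamp positions and converts only at the end (`prev_segments_len * time_precision`). A worked example of the accounting, with assumed timestamps:

time_precision, segment_size = 0.02, 1500  # 1500 mel frames x 0.02 s = one 30 s window

# Single ending: window 1 ends with a lone <|26.00|>, so the audio was cut
# mid-speech and the next window covers a fresh 30 s.
offset = segment_size * time_precision  # 30.0
print(4.00 + offset)  # a <|4.00|> token in window 2 decodes to 34.0 s

# Double ending: window 1 ends with <|26.00|><|26.00|>, so the model sealed
# the segment at 26 s and decoding resumed there.
offset = 26.00
print(4.00 + offset)  # the same <|4.00|> token decodes to 30.0 s
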
36 changes: 28 additions & 8 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -169,7 +169,9 @@ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
return super()._encode_plus(*args, **kwargs)

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps
- def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
+ def _decode_with_timestamps(
+ self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500
+ ) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
@@ -179,15 +181,25 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre

cur_max_timestamp = 0.0
prev_segments_len = 0.0
+ penultimate_timestamp = 0.0

- for token in token_ids:
+ for i, token in enumerate(token_ids):
if token >= timestamp_begin:
timestamp = float((token - timestamp_begin) * time_precision)

if timestamp < cur_max_timestamp:
# next segment has started
- prev_segments_len += cur_max_timestamp

+ last_was_single_ending = i >= 2 and not (
+ token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
+ )
+ if last_was_single_ending:
+ prev_segments_len += time_precision * segment_size
+ else:
+ cur_max_timestamp = penultimate_timestamp
+ prev_segments_len += penultimate_timestamp
+ outputs = outputs[:-2]
+
+ penultimate_timestamp = cur_max_timestamp
cur_max_timestamp = timestamp

outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
@@ -200,7 +212,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre
return "".join(outputs)

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
- def _compute_offsets(self, token_ids, time_precision=0.02):
+ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
"""
Compute offsets for a given tokenized input
@@ -209,6 +221,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
List of tokenized input ids. Can be obtained using the `__call__` method.
time_precision (`float`, *optional*, defaults to 0.02):
The time ratio to convert from token to time.
+ segment_size (`int`, *optional*, defaults to 1500):
+ The number of features in the input mel spectrogram.
"""
offsets = []
# ensure torch tensor of token ids is placed on cpu
@@ -239,7 +253,13 @@ def _compute_offsets(self, token_ids, time_precision=0.02):

if start_timestamp_position < cur_max_timestamp:
# next segment has started
- prev_segments_len += cur_max_timestamp
+ is_single_ending = last_slice >= 2 and not (
+ token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
+ )
+ if is_single_ending:
+ prev_segments_len += segment_size
+ else:
+ prev_segments_len += cur_max_timestamp

cur_max_timestamp = end_timestamp_position

@@ -251,8 +271,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
{
"text": text,
"timestamp": (
- (start_timestamp_position + prev_segments_len) * time_precision,
- (end_timestamp_position + prev_segments_len) * time_precision,
+ start_timestamp_position * time_precision + prev_segments_len * time_precision,
+ end_timestamp_position * time_precision + prev_segments_len * time_precision,
),
}
)
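
The fast tokenizer changes mirror the slow ones line for line via the `# Copied from` markers, which the `make copies` commit keeps in sync mechanically. The shared boundary test can be restated as a standalone function (hypothetical name, for illustration): at a timestamp reset, the previous window had a single ending if and only if the two tokens just before the reset are not both timestamps.

def window_ended_single(token_ids, last_slice, timestamp_begin):
    # mirrors the is_single_ending check in _compute_offsets
    if last_slice < 2:
        return False
    return not (
        token_ids[last_slice - 2] >= timestamp_begin
        and token_ids[last_slice - 1] >= timestamp_begin
    )
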
88 changes: 88 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -2096,6 +2096,94 @@ def test_tiny_longform_timestamps_generation(self):
transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)

+ @slow
+ def test_small_longform_timestamps_generation(self):
+ processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")
+ model.to(torch_device)
+
+ dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
+ sample = dataset[0]["audio"]["array"]
+ sampling_rate = dataset[0]["audio"]["sampling_rate"]
+
+ sample = [*sample[: 15 * sampling_rate], *np.zeros(16 * sampling_rate).tolist(), *sample[15 * sampling_rate :]]
+ sample = np.array(sample)
+
+ input_features = processor(
+ sample,
+ sampling_rate=16_000,
+ padding="longest",
+ truncation=False,
+ return_attention_mask=True,
+ return_tensors="pt",
+ ).input_features
+
+ input_features = input_features.to(torch_device)
+ generated_ids = model.generate(input_features, return_timestamps=True, return_segments=True)
+
+ EXPECTED_TRANSCRIPT = [
+ {
+ "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
+ "timestamp": (0.0, 6.38),
+ },
+ {
+ "text": " Nor is Mr. Quilter's manner less interesting than his matter.",
+ "timestamp": (6.38, 11.32),
+ },
+ {
+ "text": " He tells us that at this festive season of the year,",
+ "timestamp": (11.32, 15.0),
+ },
+ {
+ "text": " With Christmas and roast beef looming before us, similes drawn from eating and its results",
+ "timestamp": (30.0, 36.76),
+ },
+ {
+ "text": " occur most readily to the mind.",
+ "timestamp": (36.76, 39.80),
+ },
+ {
+ "text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
+ "timestamp": (39.80, 45.36),
+ },
+ {
+ "text": " can discover in it but little of rocky Ithaca.",
+ "timestamp": (45.36, 49.0),
+ },
+ {
+ "text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
+ "timestamp": (49.0, 56.28),
+ },
+ {
+ "text": " are as national as a jingo poem. Mr. Burkett fosters landscape's smile at one much in",
+ "timestamp": (56.28, 64.12),
+ },
+ {
+ "text": " the same way that Mr. Karker used to flash his teeth. And Mr. John Collier gives his",
+ "timestamp": (64.12, 70.76),
+ },
+ {
+ "text": " sitter a cheerful slap on the back before he says, like a shampoo or in a Turkish bath,",
+ "timestamp": (70.76, 77.16),
+ },
+ {
+ "text": " Next Man",
+ "timestamp": (77.16, 78.16),
+ },
+ ]
+
+ transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
+ self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
+
+ transcript_segments = [
+ {
+ "text": processor.decode(seg["tokens"], skip_special_tokens=True),
+ "timestamp": (seg["start"].item(), seg["end"].item()),
+ }
+ for seg in generated_ids["segments"][0]
+ ]
+ self.assertEqual(transcript_segments, EXPECTED_TRANSCRIPT)

@slow
def test_large_timestamp_generation(self):
set_seed(0)
