[Whisper] Fix word-level timestamps with bs>1 or num_beams>1 (#28114)
* fix frames

* use smaller chunk length

* correct beam search + tentative stride

* fix whisper word timestamp in batch

* add test batch generation with return token timestamps

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>

* clean a test

* make style + correct typo

* write clearer comments

* explain test in comment

---------

Co-authored-by: sanchit-gandhi <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>
4 people authored Dec 22, 2023
1 parent c4df7c1 commit 5da3db3
Showing 4 changed files with 138 additions and 11 deletions.
71 changes: 61 additions & 10 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -2224,6 +2224,7 @@ def generate(
if return_token_timestamps:
kwargs["output_attentions"] = True
return_dict_in_generate = True
kwargs["output_scores"] = True

if getattr(generation_config, "task", None) == "translate":
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
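For context, a minimal sketch of the user-facing call that reaches this branch; the checkpoint, alignment heads, and dummy audio below are illustrative (they mirror the test added later in this commit) and are not part of the diff. Requesting token-level timestamps is what forces output_attentions, output_scores, and return_dict_in_generate, so that the cross-attentions and beam_indices needed for alignment are returned:

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# Cross-attention heads used for alignment (same values as the test added below).
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

# Placeholder audio: one second of silence at 16 kHz; use real speech in practice.
raw_audio = np.zeros(16_000, dtype=np.float32)
inputs = processor(raw_audio, sampling_rate=16_000, return_tensors="pt")

out = model.generate(
    inputs.input_features,
    return_timestamps=True,
    return_token_timestamps=True,  # triggers the kwargs set in this hunk
    num_beams=3,
)
print(out.sequences.shape, out.token_timestamps.shape)  # one timestamp per token
```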
@@ -2555,22 +2556,72 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
# of shape (batch size, num selected, output length, input length).
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
weights = weights.permute([1, 0, 2, 3])

if "beam_indices" in generate_outputs:
# If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
# since the beam search strategy chooses the most probable sequences at the end of the search.
# In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
weights = weights[:, :, :weight_length]

# If beam index is still -1, it means that the associated token id is EOS
# We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
beam_indices = generate_outputs.beam_indices[:, :weight_length]
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)

# Select the cross attention from the right beam for each output sequences
weights = torch.stack(
[
torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
for i in range(beam_indices.shape[1])
],
dim=2,
)

timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
batch_size = timestamps.shape[0]

if num_frames is not None:
weights = weights[..., : num_frames // 2]
# two cases:
# 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
# 2. num_frames is different, compute the DTW matrix for each sample sequentially

# we're using np.unique because num_frames can be int/list/tuple
if len(np.unique(num_frames)) == 1:
# if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch
num_frames = num_frames if isinstance(num_frames, int) else num_frames[0]

# Normalize and smoothen the weights.
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = _median_filter(weights, self.config.median_filter_width)
weights = weights[..., : num_frames // 2]
else:
# num_frames is of shape (batch_size,) whereas batch_size is truly batch_size*num_return_sequences
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
num_frames = np.repeat(num_frames, repeat_time)

# Average the different cross-attention heads.
matrix = weights.mean(dim=1)
if num_frames is None or isinstance(num_frames, int):
# Normalize and smoothen the weights.
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = _median_filter(weights, self.config.median_filter_width)

timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
# Average the different cross-attention heads.
weights = weights.mean(dim=1)

# Perform dynamic time warping on each element of the batch.
for batch_idx in range(timestamps.shape[0]):
text_indices, time_indices = _dynamic_time_warping(-matrix[batch_idx].double().cpu().numpy())
for batch_idx in range(batch_size):
if num_frames is not None and isinstance(num_frames, (tuple, list)):
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]

# Normalize and smoothen the weights.
std, mean = torch.std_mean(matrix, dim=-2, keepdim=True, unbiased=False)
matrix = (matrix - mean) / std
matrix = _median_filter(matrix, self.config.median_filter_width)

# Average the different cross-attention heads.
matrix = matrix.mean(dim=0)
else:
matrix = weights[batch_idx]

text_indices, time_indices = _dynamic_time_warping(-matrix.double().cpu().numpy())
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] * time_precision
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
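To make the beam realignment near the top of this hunk concrete, here is a standalone sketch with toy shapes and made-up values (none of the tensors below come from the model): for every output position, torch.index_select picks the attention row belonging to the beam that actually produced that token.

```python
import torch

# Toy shapes. `weights` plays the role of the stacked cross-attention weights,
# one row per beam hypothesis; `beam_indices[b, i]` says which of those rows
# produced token i of final sequence b, with -1 marking padding after EOS.
weights = torch.randn(4, 2, 5, 6)
beam_indices = torch.tensor(
    [[0, 1, 1, 3, -1],
     [2, 2, 0, -1, -1],
     [1, 1, 1, 1, -1],
     [3, 0, 0, 0, -1]]
)

# Trim to the longest finished sequence and map the -1 padding to a valid index.
weight_length = (beam_indices != -1).sum(-1).max()
weights = weights[:, :, :weight_length]
beam_indices = beam_indices[:, :weight_length]
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)

# For every output position, gather the weights from the beam it came from.
realigned = torch.stack(
    [
        torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
        for i in range(beam_indices.shape[1])
    ],
    dim=2,
)
print(realigned.shape)  # torch.Size([4, 2, 4, 6])
```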
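And a standalone sketch of the last step of the hunk, turning a DTW alignment path into per-token start times; the toy text_indices / time_indices arrays stand in for the output of _dynamic_time_warping, and 0.02 s per encoder frame is Whisper's default time_precision.

```python
import numpy as np

time_precision = 0.02  # seconds per encoder frame (Whisper's default)

# Toy DTW path: text_indices[k] is the token aligned at path step k,
# time_indices[k] the encoder frame aligned at that step.
text_indices = np.array([0, 0, 1, 1, 1, 2, 3, 3])
time_indices = np.array([0, 1, 2, 3, 4, 5, 6, 7])

# A "jump" is the first path step at which the aligned token index increases.
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] * time_precision
print(jump_times)  # [0.   0.04 0.1  0.12] -> assigned to timestamps[batch_idx, 1:]
```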
5 changes: 4 additions & 1 deletion src/transformers/pipelines/automatic_speech_recognition.py
@@ -559,7 +559,10 @@ def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
generate_kwargs["return_token_timestamps"] = True

if stride is not None:
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
if isinstance(stride, tuple):
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
else:
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]

if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
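A small sketch of the stride-to-frames conversion this hunk generalizes, assuming the Whisper feature extractor's defaults (16 kHz sampling rate, hop_length=160); the sample counts below are illustrative only.

```python
# Whisper feature-extractor defaults: 16 kHz audio, hop_length=160 (10 ms per mel frame).
hop_length = 160

# Un-chunked input: `stride` is a single (chunk_len, stride_left, stride_right) tuple in samples.
stride = (93_680, 0, 0)               # ~5.9 s clip
num_frames = stride[0] // hop_length  # 585 mel frames

# Chunked input (chunk_length_s / batching): one stride tuple per chunk, so
# num_frames becomes a list with one entry per element of the batch.
strides = [(48_000, 0, 8_000), (48_000, 8_000, 8_000)]
num_frames = [s[0] // hop_length for s in strides]  # [300, 300]
```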
29 changes: 29 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -1850,6 +1850,35 @@ def test_tiny_token_timestamp_generation(self):

self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))

@slow
def test_tiny_token_timestamp_batch_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

num_samples = 4
num_return_sequences = 2

input_speech = self._load_datasamples(num_samples)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)

generate_outputs = model.generate(
input_features,
max_length=448,
return_timestamps=True,
return_token_timestamps=True,
num_beams=3,
num_return_sequences=num_return_sequences,
)

self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)

self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)

@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"
44 changes: 44 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -674,6 +674,50 @@ def test_whisper_timestamp_prediction(self):
},
)

@slow
@require_torch
def test_whisper_word_timestamps_batched(self):
pipe = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny",
chunk_length_s=3,
return_timestamps="word",
)
data = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = data[0]["audio"]

# not the same output as test_simple_whisper_asr because of chunking
EXPECTED_OUTPUT = {
"text": " Mr. Quilder is the apostle of the middle classes and we are glad to welcome his gospel.",
"chunks": [
{"text": " Mr.", "timestamp": (0.48, 0.96)},
{"text": " Quilder", "timestamp": (0.96, 1.24)},
{"text": " is", "timestamp": (1.24, 1.5)},
{"text": " the", "timestamp": (1.5, 1.72)},
{"text": " apostle", "timestamp": (1.72, 1.98)},
{"text": " of", "timestamp": (1.98, 2.32)},
{"text": " the", "timestamp": (2.32, 2.5)},
{"text": " middle", "timestamp": (2.5, 2.68)},
{"text": " classes", "timestamp": (2.68, 3.2)},
{"text": " and", "timestamp": (3.2, 3.56)},
{"text": " we", "timestamp": (3.56, 3.68)},
{"text": " are", "timestamp": (3.68, 3.8)},
{"text": " glad", "timestamp": (3.8, 4.1)},
{"text": " to", "timestamp": (4.1, 4.34)},
{"text": " welcome", "timestamp": (4.3, 4.6)},
{"text": " his", "timestamp": (4.6, 4.94)},
{"text": " gospel.", "timestamp": (4.94, 5.82)},
],
}

# batch size 1: copy the audio sample since pipeline consumes it
output = pipe(sample.copy(), batch_size=1)
self.assertDictEqual(output, EXPECTED_OUTPUT)

# batch size 2: input audio is chunked into smaller pieces so it's testing batching
output = pipe(sample, batch_size=2)
self.assertDictEqual(output, EXPECTED_OUTPUT)

@require_torch
@slow
def test_torch_speech_encoder_decoder(self):