Add language to word timestamps for Whisper #31572

Merged
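For context, a sketch of the pipeline-level call this change enables, based on the tests added below; "audio.flac" is a placeholder file name and openai/whisper-tiny is simply the checkpoint those tests use:

from transformers import pipeline

# Sketch only (not part of the diff): return_language is forwarded to the ASR pipeline,
# and word-level chunks then carry a "language" key next to "text" and "timestamp".
asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny",
    return_language=True,
)
result = asr("audio.flac", return_timestamps="word")  # "audio.flac" is a placeholder
# Each entry of result["chunks"] now looks like:
#   {"language": "english", "text": " Conquered", "timestamp": (0.5, 1.2)}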
10 changes: 7 additions & 3 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -1013,7 +1013,7 @@ def new_chunk():
chunk["text"] = resolved_text
if return_timestamps == "word":
chunk["words"] = _collate_word_timestamps(
- tokenizer, resolved_tokens, resolved_token_timestamps, last_language
+ tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
)
chunks.append(chunk)

@@ -1065,7 +1065,7 @@ def new_chunk():
chunk["text"] = resolved_text
if return_timestamps == "word":
chunk["words"] = _collate_word_timestamps(
- tokenizer, resolved_tokens, resolved_token_timestamps, last_language
+ tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
)
chunks.append(chunk)

@@ -1197,12 +1197,16 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
return total_sequence, []


- def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language):
+ def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language, return_language):

Collaborator:
Is return_language always a bool, or could it be None too?


Contributor Author:
It can be a bool or None. However, the behavior for return_language=None is the same as that of return_language=False.

words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)

+ optional_language_field = {"language": language} if return_language else {}

timings = [
{
"text": word,
"timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
+ **optional_language_field,
}
for word, indices in zip(words, token_indices)
]
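As the review thread notes, return_language may be True, False, or None, with None behaving like False. A minimal standalone sketch of the collation logic above, using dummy data in place of the tokenizer and _combine_tokens_into_words (all inputs here are hypothetical):

def collate_sketch(words, token_indices, token_timestamps, language, return_language):
    # None behaves like False: the language field is added only when return_language is truthy.
    optional_language_field = {"language": language} if return_language else {}
    return [
        {
            "text": word,
            "timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
            **optional_language_field,
        }
        for word, indices in zip(words, token_indices)
    ]

# Dummy data for illustration only.
words = [" Hello", " world"]
token_indices = [[0], [1]]
token_timestamps = [(0.0, 0.4), (0.4, 0.9)]

print(collate_sketch(words, token_indices, token_timestamps, "english", return_language=True))
# [{'text': ' Hello', 'timestamp': (0.0, 0.4), 'language': 'english'},
#  {'text': ' world', 'timestamp': (0.4, 0.9), 'language': 'english'}]
print(collate_sketch(words, token_indices, token_timestamps, "english", return_language=None))
# [{'text': ' Hello', 'timestamp': (0.0, 0.4)}, {'text': ' world', 'timestamp': (0.4, 0.9)}]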
63 changes: 59 additions & 4 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -324,7 +324,6 @@ def test_torch_large_with_input_features(self):

@slow
@require_torch
- @slow
def test_return_timestamps_in_preprocess(self):
pipe = pipeline(
task="automatic-speech-recognition",
@@ -334,10 +333,10 @@ def test_return_timestamps_in_preprocess(self):
)
data = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True)
sample = next(iter(data))
- pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="en", task="transcribe")

res = pipe(sample["audio"]["array"])
self.assertEqual(res, {"text": " Conquered returned to its place amidst the tents."})

res = pipe(sample["audio"]["array"], return_timestamps=True)
self.assertEqual(
res,
@@ -346,9 +345,8 @@ def test_return_timestamps_in_preprocess(self):
"chunks": [{"timestamp": (0.0, 3.36), "text": " Conquered returned to its place amidst the tents."}],
},
)
- pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
- res = pipe(sample["audio"]["array"], return_timestamps="word")

+ res = pipe(sample["audio"]["array"], return_timestamps="word")
# fmt: off
self.assertEqual(
res,
@@ -368,6 +366,63 @@ def test_return_timestamps_in_preprocess(self):
)
# fmt: on

@slow
@require_torch
def test_return_timestamps_and_language_in_preprocess(self):
pipe = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny",
chunk_length_s=8,
stride_length_s=1,
return_language=True,
)
data = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True)
sample = next(iter(data))

res = pipe(sample["audio"]["array"])
self.assertEqual(
res,
{
"text": " Conquered returned to its place amidst the tents.",
"chunks": [{"language": "english", "text": " Conquered returned to its place amidst the tents."}],
},
)

res = pipe(sample["audio"]["array"], return_timestamps=True)
self.assertEqual(
res,
{
"text": " Conquered returned to its place amidst the tents.",
"chunks": [
{
"timestamp": (0.0, 3.36),
"language": "english",
"text": " Conquered returned to its place amidst the tents.",
}
],
},
)

res = pipe(sample["audio"]["array"], return_timestamps="word")
# fmt: off
self.assertEqual(
res,
{
'text': ' Conquered returned to its place amidst the tents.',
'chunks': [
{"language": "english",'text': ' Conquered', 'timestamp': (0.5, 1.2)},
{"language": "english", 'text': ' returned', 'timestamp': (1.2, 1.64)},
{"language": "english",'text': ' to', 'timestamp': (1.64, 1.84)},
{"language": "english",'text': ' its', 'timestamp': (1.84, 2.02)},
{"language": "english",'text': ' place', 'timestamp': (2.02, 2.28)},
{"language": "english",'text': ' amidst', 'timestamp': (2.28, 2.8)},
{"language": "english",'text': ' the', 'timestamp': (2.8, 2.98)},
{"language": "english",'text': ' tents.', 'timestamp': (2.98, 3.48)},
],
},
)
# fmt: on

@slow
@require_torch
def test_return_timestamps_in_preprocess_longform(self):