Skip to content

Commit

Permalink
[Whisper Tokenizer] Make decoding faster after adding timestamps (#26299)
Browse files Browse the repository at this point in the history

make decoding faster
  • Loading branch information
sanchit-gandhi authored Sep 28, 2023
1 parent 4e931a8 commit 211f93a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 34 deletions.
31 changes: 14 additions & 17 deletions src/transformers/models/whisper/tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def __init__(

# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

self.language = language
super().__init__(
Expand Down Expand Up @@ -560,10 +561,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
# strip timestamp tokens from the text output
sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
text = self._decode(sliced_tokens)
text = self._filter_timestamp_ids(text)
offsets.append(
{
"text": self._decode(sliced_tokens),
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
Expand All @@ -585,9 +588,7 @@ def timestamp_ids(self, time_precision=0.02):
"""
return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])

def _preprocess_token_ids(
self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
):
def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
"""
Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
Expand All @@ -597,24 +598,17 @@ def _preprocess_token_ids(
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
removed.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
filtered out from the token ids.
time_precision (`float`, `optional`, defaults to 0.02):
The time ratio to convert from token to time.
"""
if skip_special_tokens:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)

if not decode_with_timestamps:
# filter timestamp tokens if they are contained in the vocab
timestamp_ids = self.timestamp_ids(time_precision=time_precision)
token_ids = [token for token in token_ids if token not in timestamp_ids]

return token_ids

# Remove every substring matching `self.timestamp_pat` (timestamp markers of the
# form "<|12.34|>") from already-decoded text and return the cleaned string.
# NOTE: despite its name, `token_ids` is handed straight to `re.sub`, and the
# callers visible here pass the decoded text, so it is a string rather than a
# list of token ids. The name is kept for interface compatibility.
def _filter_timestamp_ids(self, token_ids):
return re.sub(self.timestamp_pat, "", token_ids)

def decode(
self,
token_ids,
Expand Down Expand Up @@ -644,6 +638,8 @@ def decode(
output_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps.
time_precision (`float`, `optional`, defaults to 0.02):
The time ratio to convert from token to time.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text.
Returns:
Expand All @@ -652,8 +648,6 @@ def decode(
filtered_ids = self._preprocess_token_ids(
token_ids,
skip_special_tokens=skip_special_tokens,
decode_with_timestamps=decode_with_timestamps,
time_precision=time_precision,
)

text = super().decode(
Expand All @@ -668,6 +662,9 @@ def decode(
text = self._decode_with_timestamps(
filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
)
else:
text = self._filter_timestamp_ids(text)

# retrieve offsets
if output_offsets:
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
Expand Down
33 changes: 16 additions & 17 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tokenization classes for Whisper."""
import json
import os
import re
from functools import lru_cache
from typing import List, Optional, Tuple

Expand Down Expand Up @@ -190,6 +191,7 @@ def __init__(
self.english_spelling_normalizer = None

self.add_prefix_space = add_prefix_space
self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

self.language = language
self.task = task
Expand Down Expand Up @@ -269,10 +271,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
# strip timestamp tokens from the text output
sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
text = self._decode(sliced_tokens)
text = self._filter_timestamp_ids(text)
offsets.append(
{
"text": self._decode(sliced_tokens),
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
Expand All @@ -296,9 +300,7 @@ def timestamp_ids(self, time_precision=0.02):
return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
def _preprocess_token_ids(
self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
):
def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
"""
Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
Expand All @@ -308,24 +310,18 @@ def _preprocess_token_ids(
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
removed.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
filtered out from the token ids.
time_precision (`float`, `optional`, defaults to 0.02):
The time ratio to convert from token to time.
"""
if skip_special_tokens:
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)

if not decode_with_timestamps:
# filter timestamp tokens if they are contained in the vocab
timestamp_ids = self.timestamp_ids(time_precision=time_precision)
token_ids = [token for token in token_ids if token not in timestamp_ids]

return token_ids

# Remove every substring matching `self.timestamp_pat` (timestamp markers of the
# form "<|12.34|>") from already-decoded text and return the cleaned string.
# NOTE: despite its name, `token_ids` is handed straight to `re.sub`, and the
# callers visible here pass the decoded text, so it is a string rather than a
# list of token ids. The name is kept for interface compatibility.
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._filter_timestamp_ids
def _filter_timestamp_ids(self, token_ids):
return re.sub(self.timestamp_pat, "", token_ids)

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
def decode(
self,
Expand Down Expand Up @@ -356,6 +352,8 @@ def decode(
output_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps.
time_precision (`float`, `optional`, defaults to 0.02):
The time ratio to convert from token to time.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text.
Returns:
Expand All @@ -364,8 +362,6 @@ def decode(
filtered_ids = self._preprocess_token_ids(
token_ids,
skip_special_tokens=skip_special_tokens,
decode_with_timestamps=decode_with_timestamps,
time_precision=time_precision,
)

text = super().decode(
Expand All @@ -380,6 +376,9 @@ def decode(
text = self._decode_with_timestamps(
filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
)
else:
text = self._filter_timestamp_ids(text)

# retrieve offsets
if output_offsets:
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
Expand Down

0 comments on commit 211f93a

Please sign in to comment.