[Whisper] Allow basic text normalization (#26149)
* [Whisper] Allow basic text normalization

* up

* style copies
sanchit-gandhi authored Oct 3, 2023
1 parent bd62059 commit 57f44dc
Showing 3 changed files with 100 additions and 10 deletions.
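For context, a minimal sketch of how the flags added by this commit are meant to be used. The checkpoint name is illustrative, not part of the commit:

from transformers import WhisperTokenizer

# Any multilingual Whisper checkpoint should work; "openai/whisper-tiny" is an assumption.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
ids = tokenizer("Hola güey!").input_ids

# Plain decoding leaves the text untouched.
print(tokenizer.decode(ids, skip_special_tokens=True))

# basic_normalize applies the multilingual BasicTextNormalizer (lowercasing, punctuation stripping).
print(tokenizer.decode(ids, skip_special_tokens=True, basic_normalize=True))

# remove_diacritics additionally strips accents and is lossy ("güey" -> "guey").
print(tokenizer.decode(ids, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True))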
37 changes: 32 additions & 5 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -23,7 +23,7 @@
 
 from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
-from .english_normalizer import EnglishTextNormalizer
+from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
 
 
 VOCAB_FILES_NAMES = {
@@ -510,6 +510,15 @@ def _normalize(self, text):
         normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
         return normalizer(text)
 
+    @staticmethod
+    def _basic_normalize(text, remove_diacritics=False):
+        """
+        Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
+        multilingual text.
+        """
+        normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
+        return normalizer(text)
+
     def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
         """
         Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
@@ -617,6 +626,9 @@ def decode(
         output_offsets: bool = False,
         time_precision=0.02,
         decode_with_timestamps: bool = False,
+        normalize: bool = False,
+        basic_normalize: bool = False,
+        remove_diacritics: bool = False,
         **kwargs,
     ) -> str:
         """
@@ -633,15 +645,24 @@
             clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces. If `None`, will default to
                 `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
-            kwargs (additional keyword arguments, *optional*):
-                Will be passed to the underlying model specific decode method.
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
             time_precision (`float`, *optional*, defaults to 0.02):
                 The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
+            normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
+                target text is in English. Otherwise, the basic text normalizer should be applied.
+            basic_normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
+                target text.
+            remove_diacritics (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
+                destroy information in the decoded text, hence it should be used with caution.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.
         Returns:
             `str`: The decoded sentence.
         """
@@ -654,7 +675,9 @@
             filtered_ids,
             skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            decode_with_timestamps=decode_with_timestamps,
+            normalize=normalize,
+            basic_normalize=basic_normalize,
+            remove_diacritics=remove_diacritics,
             **kwargs,
         )
         if decode_with_timestamps:
@@ -676,7 +699,8 @@ def _decode(
         token_ids: Union[int, List[int]],
         skip_special_tokens: bool = False,
         normalize: bool = False,
-        decode_with_timestamps: bool = False,
+        basic_normalize: bool = False,
+        remove_diacritics: bool = False,
         **kwargs,
     ) -> str:
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
@@ -705,6 +729,9 @@ def _decode(
         if normalize:
             clean_text = self._normalize(text)
             return clean_text
+        elif basic_normalize:
+            clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
+            return clean_text
         else:
             return text
 
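Note the branch order in _decode above: normalize (the English normalizer) takes precedence when both flags are set, and basic_normalize is only consulted otherwise. The normalizer classes can also be used standalone; a small sketch, assuming an empty English spelling mapping for illustration:

from transformers.models.whisper.english_normalizer import BasicTextNormalizer, EnglishTextNormalizer

# Language-agnostic normalization; diacritic removal is opt-in.
basic = BasicTextNormalizer(remove_diacritics=False)
print(basic("Hola güey!"))  # "hola güey "

# The English normalizer additionally takes a spelling-normalization dict
# (the tokenizer passes self.english_spelling_normalizer; {} here is illustrative).
english = EnglishTextNormalizer({})
print(english("Mr. Quilter is the apostle of the middle classes."))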
39 changes: 34 additions & 5 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -25,7 +25,7 @@
 from ...tokenization_utils_base import BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
-from .english_normalizer import EnglishTextNormalizer
+from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
 from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
 
 
@@ -331,6 +331,9 @@ def decode(
         output_offsets: bool = False,
         time_precision=0.02,
         decode_with_timestamps: bool = False,
+        normalize: bool = False,
+        basic_normalize: bool = False,
+        remove_diacritics: bool = False,
         **kwargs,
     ) -> str:
         """
@@ -347,15 +350,24 @@
             clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces. If `None`, will default to
                 `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
-            kwargs (additional keyword arguments, *optional*):
-                Will be passed to the underlying model specific decode method.
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
             time_precision (`float`, *optional*, defaults to 0.02):
                 The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
+            normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
+                target text is in English. Otherwise, the basic text normalizer should be applied.
+            basic_normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
+                target text.
+            remove_diacritics (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
+                destroy information in the decoded text, hence it should be used with caution.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.
         Returns:
             `str`: The decoded sentence.
         """
@@ -368,7 +380,9 @@
             filtered_ids,
             skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            decode_with_timestamps=decode_with_timestamps,
+            normalize=normalize,
+            basic_normalize=basic_normalize,
+            remove_diacritics=remove_diacritics,
             **kwargs,
         )
         if decode_with_timestamps:
@@ -385,12 +399,17 @@
             return {"text": text, "offsets": offsets}
         return text
 
-    def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
+    def _decode(
+        self, *args, normalize: bool = False, basic_normalize: bool = False, remove_diacritics: bool = False, **kwargs
+    ) -> str:
         text = super()._decode(*args, **kwargs)
 
         if normalize:
             clean_text = self._normalize(text)
             return clean_text
+        elif basic_normalize:
+            clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
+            return clean_text
         else:
             return text
 
@@ -403,6 +422,16 @@ def _normalize(self, text):
         normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
         return normalizer(text)
 
+    @staticmethod
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize
+    def _basic_normalize(text, remove_diacritics=False):
+        """
+        Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
+        multilingual text.
+        """
+        normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
+        return normalizer(text)
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         files = self._tokenizer.model.save(save_directory, name=filename_prefix)
 
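Background, not part of this diff: diacritic removal of this kind is conventionally implemented with Unicode NFKD decomposition followed by dropping combining marks, which is why the docstrings warn that it can destroy information. A rough standalone sketch of the idea:

import unicodedata

def strip_diacritics(text: str) -> str:
    # Decompose "ü" into "u" plus a combining diaeresis, then drop the
    # combining marks (Unicode category "Mn").
    decomposed = unicodedata.normalize("NFKD", text)
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

print(strip_diacritics("güey"))  # "guey"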
34 changes: 34 additions & 0 deletions tests/models/whisper/test_tokenization_whisper.py
@@ -273,6 +273,40 @@ def test_combine_tokens_into_words(self):
         self.assertEqual(expected_tokens, output_rust[1])
         self.assertEqual(expected_indices, output_rust[2])
 
+    def test_basic_normalizer(self):
+        tokenizer = self.get_tokenizer()
+        rust_tokenizer = self.get_rust_tokenizer()
+
+        input_str = "Hola güey!"
+        expected_output_normalize = "hola güey "
+        expected_output_diacritics = "hola guey "
+
+        # tokenizer tests
+        encoded_input = tokenizer(input_str).input_ids
+        decoded_output = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
+        self.assertEqual(decoded_output, input_str)
+
+        decoded_output_normalize = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
+        self.assertEqual(decoded_output_normalize, expected_output_normalize)
+
+        decoded_output_diacritics = tokenizer.decode(
+            encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
+        )
+        self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
+
+        # fast tokenizer tests
+        encoded_input = rust_tokenizer(input_str).input_ids
+        decoded_output = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
+        self.assertEqual(decoded_output, input_str)
+
+        decoded_output_normalize = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
+        self.assertEqual(decoded_output_normalize, expected_output_normalize)
+
+        decoded_output_diacritics = rust_tokenizer.decode(
+            encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
+        )
+        self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
+
 
 class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
     checkpoint_name = "openai/whisper-small.en"
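One detail worth noting in the expected strings above: the basic normalizer replaces punctuation such as the trailing "!" with a space rather than deleting it, and collapses runs of whitespace to a single space, which is why "hola güey " ends in a space. A quick standalone check mirroring the test:

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()
assert normalizer("Hola güey!") == "hola güey "  # lowercased, "!" becomes a space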
