From 23c50eca9db107bfef233d856b5c496a218d409c Mon Sep 17 00:00:00 2001
From: Nikhil Narasimhan
Date: Thu, 26 Sep 2024 20:21:41 +0530
Subject: [PATCH 1/3] Update rnnt_decoding.py

---
 .../asr/parts/submodules/rnnt_decoding.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py
index 70cc5be13..912bc2dcc 100644
--- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py
+++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -502,7 +502,7 @@ def rnnt_decoder_predictions_tensor(
             if self.preserve_frame_confidence and (
                 self.preserve_word_confidence or self.preserve_token_confidence
             ):
-                hypotheses = self.compute_confidence(hypotheses)
+                hypotheses = self.compute_confidence(hypotheses, lang_ids)
             return hypotheses, None
 
         best_hyp_text = [h.text for h in hypotheses]
@@ -561,7 +561,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis], lang_ids: List[st
 
         return hypotheses_list
 
-    def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothesis]:
+    def compute_confidence(self, hypotheses_list: List[Hypothesis], lang_ids: List[str] = None) -> List[Hypothesis]:
         """
         Computes high-level (per-token and/or per-word) confidence scores for a list of hypotheses.
         Assumes that `frame_confidence` is present in the hypotheses.
@@ -595,8 +595,11 @@ def compute_confidence(self, hypotheses_list: List[Hypothes
                             offset += 1
                 hyp.token_confidence = token_confidence
         if self.preserve_word_confidence:
-            for hyp in hypotheses_list:
-                hyp.word_confidence = self._aggregate_token_confidence(hyp)
+            for idx, hyp in enumerate(hypotheses_list):
+                if lang_ids:
+                    hyp.word_confidence = self._aggregate_token_confidence(hyp, lang_ids[idx])
+                else:
+                    hyp.word_confidence = self._aggregate_token_confidence(hyp)
         return hypotheses_list
 
     @abstractmethod
@@ -1401,7 +1404,7 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec, blank
         if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer):
             self.decoding.set_decoding_type('subword')
 
-    def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]:
+    def _aggregate_token_confidence(self, hypothesis: Hypothesis, lang_id: str = None) -> List[float]:
        """
         Implemented by subclass in order to reduce token confidence to a word-level confidence.
 
@@ -1414,7 +1417,7 @@ def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]:
             A list of word-level confidence scores.
         """
         return self._aggregate_token_confidence_subwords_sentencepiece(
-            hypothesis.words, hypothesis.token_confidence, hypothesis.y_sequence
+            hypothesis.words, hypothesis.token_confidence, hypothesis.y_sequence, lang_id
         )
 
     def decode_tokens_to_str(self, tokens: List[int], lang: str = None) -> str:
@@ -1431,19 +1434,23 @@ def decode_tokens_to_str(self, tokens: List[int], lang: str = None) -> str:
             hypothesis = self.tokenizer.ids_to_text(tokens, lang)
         else:
             hypothesis = self.tokenizer.ids_to_text(tokens)
+        return hypothesis
 
-    def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
+    def decode_ids_to_tokens(self, tokens: List[int], lang: str = None) -> List[str]:
         """
         Implemented by subclass in order to decode a token id list into a token list.
         A token list is the string representation of each token id.
 
         Args:
             tokens: List of int representing the token ids.
 
         Returns:
             A list of decoded tokens.
         """
-        token_list = self.tokenizer.ids_to_tokens(tokens)
+        if lang is not None:
+            token_list = self.tokenizer.ids_to_tokens(tokens, lang)
+        else:
+            token_list = self.tokenizer.ids_to_tokens(tokens)
         return token_list
 
     def decode_tokens_to_lang(self, tokens: List[int]) -> str:

From bca336024c4a40d097d6b83d85b7977aaa51c701 Mon Sep 17 00:00:00 2001
From: Nikhil Narasimhan
Date: Thu, 26 Sep 2024 20:22:49 +0530
Subject: [PATCH 2/3] Update asr_confidence_utils.py

---
 nemo/collections/asr/parts/utils/asr_confidence_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/asr/parts/utils/asr_confidence_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_utils.py
index 27ced569b..ff1f6fa59 100644
--- a/nemo/collections/asr/parts/utils/asr_confidence_utils.py
+++ b/nemo/collections/asr/parts/utils/asr_confidence_utils.py
@@ -424,7 +424,7 @@ def _aggregate_token_confidence_chars(self, words: List[str], token_confidence:
         return word_confidence
 
     def _aggregate_token_confidence_subwords_sentencepiece(
-        self, words: List[str], token_confidence: List[float], token_ids: List[int]
+        self, words: List[str], token_confidence: List[float], token_ids: List[int], lang_id: str = None
     ) -> List[float]:
         """Implementation of token confidence aggregation for subword-based models.
 
@@ -445,8 +445,8 @@ def _aggregate_token_confidence_subwords_sentencepiece(
             prev_unk = False
             prev_underline = False
             for i, token_id in enumerate(token_ids):
-                token = self.decode_ids_to_tokens([int(token_id)])[0]
-                token_text = self.decode_tokens_to_str([int(token_id)])
+                token = self.decode_ids_to_tokens([int(token_id)], lang_id)[0]
+                token_text = self.decode_tokens_to_str([int(token_id)], lang_id)
                 # treat `<unk>` as a separate word regardless of the next token
                 # to match the result of `tokenizer.ids_to_text`
                 if (token != token_text or prev_unk) and i > j:

From 89d12f3f5339e722f83a7eb752d922c7659bae6b Mon Sep 17 00:00:00 2001
From: Nikhil Narasimhan
Date: Thu, 26 Sep 2024 20:23:33 +0530
Subject: [PATCH 3/3] Update multilingual_tokenizer.py

---
 .../common/tokenizers/multilingual_tokenizer.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/nemo/collections/common/tokenizers/multilingual_tokenizer.py b/nemo/collections/common/tokenizers/multilingual_tokenizer.py
index 1b4e66ed3..432b4a2ca 100644
--- a/nemo/collections/common/tokenizers/multilingual_tokenizer.py
+++ b/nemo/collections/common/tokenizers/multilingual_tokenizer.py
@@ -117,10 +117,10 @@ def ids_to_text(self, ids, lang):
             ids = ids.tolist()
 
         tokens = []
+        tokenizer = self.tokenizers_dict[lang]
         for id in ids:
             # offset_id = self.offset_token_ids_by_token_id[id]
             # tokenizer = self.tokenizers_by_token_id[id]
-            tokenizer = self.tokenizers_dict[lang]
             # tokens.extend(tokenizer.ids_to_tokens([offset_id]))
             tokens.extend(tokenizer.ids_to_tokens([id]))
         text = ''.join(tokens).replace('▁', ' ')
@@ -131,14 +131,9 @@ def token_to_id(self, token, lang_id):
         tokenizer = self.tokenizers_dict[lang_id]
         return tokenizer.token_to_id(token) + self.token_id_offset[lang_id]
 
-    def ids_to_tokens(self, ids):
-        tokens = []
-
-        for id in ids:
-            offset_id = self.offset_token_ids_by_token_id[id]
-            tokenizer = self.tokenizers_by_token_id[id]
-            token = tokenizer.ids_to_tokens([offset_id])[0]
-            tokens.append(token)
+    def ids_to_tokens(self, ids, lang_id):
+        tokenizer = self.tokenizers_dict[lang_id]
+        tokens = [tokenizer.ids_to_tokens([id])[0] for id in ids]
 
        return tokens