From 03c6a44db7a48ca025f8f6061c261c25333f14d4 Mon Sep 17 00:00:00 2001
From: kaushal-py
Date: Sun, 18 Aug 2024 17:16:42 +0000
Subject: [PATCH] Fix CUDA error in multilingual tokenizer id mapping

---
 nemo/collections/asr/data/audio_to_text.py    |  6 +++++-
 .../asr/models/hybrid_rnnt_ctc_bpe_models.py  |  1 -
 .../asr/modules/conformer_encoder.py          |  8 +++----
 nemo/collections/asr/modules/rnnt.py          |  4 ++--
 .../tokenizers/multilingual_tokenizer.py      | 21 ++++++++++++-------
 5 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py
index 3b5f61cb6..b010b58c7 100644
--- a/nemo/collections/asr/data/audio_to_text.py
+++ b/nemo/collections/asr/data/audio_to_text.py
@@ -100,7 +100,11 @@ def _speech_collate_fn(batch, pad_id):
         tokens.append(tokens_i)
 
     if has_audio:
-        audio_signal = torch.stack(audio_signal)
+        try:
+            audio_signal = torch.stack(audio_signal)
+        except RuntimeError as e:
+            # Stacking fails when the padded audio tensors in a batch have mismatched shapes.
+            raise RuntimeError(f"Could not stack audio signals in _speech_collate_fn: {e}") from e
         audio_lengths = torch.stack(audio_lengths)
     else:
         audio_signal, audio_lengths = None, None
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
index 16e749bca..07fa6861e 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -124,7 +124,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             # self.language_masks[language].extend([True]*num_languages)
             self.language_masks[language].append(True) # Insert blank token
         # breakpoint()
-        num_classes=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()) + num_languages
         # breakpoint()
         self.ctc_loss = CTCLoss(
             num_classes=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()) + num_languages - 1,
diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index 79cb89c3f..85017295f 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ b/nemo/collections/asr/modules/conformer_encoder.py
@@ -561,10 +561,10 @@ def forward_internal(
             audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
 
         # breakpoint()
-        # if language_ids is not None:
-        #     language_ints = torch.tensor([self.language_to_idx[language] for language in language_ids], device=audio_signal.device)
-        #     language_inputs = self.language_embeddings(language_ints).unsqueeze(1).repeat(1, 32, 1)
-        #     audio_signal = torch.cat((language_inputs, audio_signal), 1)
+        if language_ids is not None:
+            language_ints = torch.tensor([self.language_to_idx[language] for language in language_ids], device=audio_signal.device)
+            language_inputs = self.language_embeddings(language_ints).unsqueeze(1).repeat(1, 32, 1)
+            audio_signal = torch.cat((language_inputs, audio_signal), 1)
         # breakpoint()
 
         # Create the self-attention and padding masks
diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py
index a5e69751c..aa0d98dff 100644
--- a/nemo/collections/asr/modules/rnnt.py
+++ b/nemo/collections/asr/modules/rnnt.py
@@ -1609,7 +1609,7 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor, language_ids=
             res = torch.stack(res_single)
         else:
             res = self.joint_net(inp)  # [B, T, U, V + 1]
-        
+
         del inp
 
         if self.preserve_memory:
@@ -1662,7 +1662,7 @@ def _joint_net_modules(self, num_classes, pred_n_hidden, enc_n_hidden, joint_n_h
         final_layer = torch.nn.ModuleDict()
         logging.info(f"Vocab size for each language: {self._vocab_size // len(self.language_keys)}")
len(self.language_keys)}") for lang in self.language_keys: - final_layer[lang] = torch.nn.Linear(joint_n_hidden, (self._vocab_size // len(self.language_keys)+1)) + final_layer[lang] = torch.nn.Linear(joint_n_hidden, ((self._vocab_size)//len(self.language_keys) + 1 + len(self.language_keys))) layers = ( [activation] + ([torch.nn.Dropout(p=dropout)] if dropout else []) diff --git a/nemo/collections/common/tokenizers/multilingual_tokenizer.py b/nemo/collections/common/tokenizers/multilingual_tokenizer.py index 179a32adf..a24853b06 100644 --- a/nemo/collections/common/tokenizers/multilingual_tokenizer.py +++ b/nemo/collections/common/tokenizers/multilingual_tokenizer.py @@ -108,7 +108,7 @@ def text_to_tokens(self, text, lang_id): def text_to_ids(self, text, lang_id): tokenizer = self.tokenizers_dict[lang_id] token_ids = tokenizer.text_to_ids(text) - lang_token_index = self.vocabulary.index(f"<{lang_id}>") + lang_token_index = self.vocabulary.index(f"<{lang_id}>") - self.token_id_offset[""] + len(tokenizer.vocab) # Insert language token at index 0 to the list of tokens token_ids = [lang_token_index] + token_ids # token_ids[:] = [t + self.token_id_offset[lang_id] for t in token_ids] @@ -132,8 +132,12 @@ def ids_to_text(self, ids, lang): # tokenizer = self.tokenizers_by_token_id[id] tokenizer = self.tokenizers_dict[lang] # tokens.extend(tokenizer.ids_to_tokens([offset_id])) - if id >= self.token_id_offset[""]: - tokens.append(self.vocabulary[id]+' ') + if id >= len(tokenizer.vocab): + try: + tokens.append(self.vocabulary[id-len(tokenizer.vocab)+self.token_id_offset[""]]+' ') + except IndexError: + print("Index error occured") + breakpoint() else: tokens.extend(tokenizer.ids_to_tokens([id])) text = ''.join(tokens).replace('▁', ' ') @@ -148,10 +152,13 @@ def ids_to_tokens(self, ids): tokens = [] for id in ids: - offset_id = self.offset_token_ids_by_token_id[id] - tokenizer = self.tokenizers_by_token_id[id] - token = tokenizer.ids_to_tokens([offset_id])[0] - tokens.append(token) + if id >= self.token_id_offset[""]: + tokens.append(self.vocabulary[id]+' ') + else: + offset_id = self.offset_token_ids_by_token_id[id] + tokenizer = self.tokenizers_by_token_id[id] + token = tokenizer.ids_to_tokens([offset_id])[0] + tokens.append(token) return tokens