Commit 03c6a44: fix cuda error
kaushal-py committed Aug 18, 2024
1 parent 871aaaa commit 03c6a44
Showing 5 changed files with 25 additions and 15 deletions.
6 changes: 5 additions & 1 deletion nemo/collections/asr/data/audio_to_text.py
@@ -100,7 +100,11 @@ def _speech_collate_fn(batch, pad_id):
         tokens.append(tokens_i)
 
     if has_audio:
-        audio_signal = torch.stack(audio_signal)
+        try:
+            audio_signal = torch.stack(audio_signal)
+        except RuntimeError:
+            print("audio signal problem")
+            breakpoint()
         audio_lengths = torch.stack(audio_lengths)
     else:
         audio_signal, audio_lengths = None, None
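Note on this hunk: torch.stack requires every tensor in the list to share one shape, so a batch of ragged audio lengths raises RuntimeError; the added try/except only drops into a debugger when that happens. A minimal standalone sketch (not from this commit) of the failure and of the usual remedy, padding each signal to the batch maximum before stacking:

import torch

def pad_and_stack(signals):
    # torch.stack needs equal shapes; pad each 1-D signal to the batch max first.
    max_len = max(s.size(0) for s in signals)
    padded = [torch.nn.functional.pad(s, (0, max_len - s.size(0))) for s in signals]
    lengths = torch.tensor([s.size(0) for s in signals])
    return torch.stack(padded), lengths

sigs = [torch.randn(16000), torch.randn(12000)]  # two utterances, unequal lengths
try:
    torch.stack(sigs)  # raises: stack expects each tensor to be equal size
except RuntimeError as e:
    print("stack failed:", e)
batch, lengths = pad_and_stack(sigs)  # batch.shape == torch.Size([2, 16000])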
1 change: 0 additions & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -124,7 +124,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             # self.language_masks[language].extend([True]*num_languages)
             self.language_masks[language].append(True) # Insert blank token
         # breakpoint()
-        num_classes=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()) + num_languages
         # breakpoint()
         self.ctc_loss = CTCLoss(
             num_classes=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()) + num_languages - 1,
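Note on this hunk: the deleted line was a dangling duplicate of the arithmetic already passed to CTCLoss two lines below. The aggregated decoder output is split evenly across languages, the language tokens are added on top, and one slot is subtracted, presumably because this CTCLoss wrapper counts the blank as an extra class on its own. A worked example with hypothetical sizes:

# Hypothetical sizes, for illustration only.
decoder_num_classes = 512   # ctc_decoder._num_classes, aggregated over all languages
num_languages = 4           # len(tokenizer.tokenizers_dict)
per_language = decoder_num_classes // num_languages   # 128 classes per language
ctc_num_classes = per_language + num_languages - 1    # 131 passed to CTCLoss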
8 changes: 4 additions & 4 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -561,10 +561,10 @@ def forward_internal(
 
         audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
         # breakpoint()
-        # if language_ids is not None:
-        #     language_ints = torch.tensor([self.language_to_idx[language] for language in language_ids], device=audio_signal.device)
-        #     language_inputs = self.language_embeddings(language_ints).unsqueeze(1).repeat(1, 32, 1)
-        #     audio_signal = torch.cat((language_inputs, audio_signal), 1)
+        if language_ids is not None:
+            language_ints = torch.tensor([self.language_to_idx[language] for language in language_ids], device=audio_signal.device)
+            language_inputs = self.language_embeddings(language_ints).unsqueeze(1).repeat(1, 32, 1)
+            audio_signal = torch.cat((language_inputs, audio_signal), 1)
         # breakpoint()
 
         # Create the self-attention and padding masks
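Note on this hunk: the re-enabled block conditions the encoder on language identity by prepending 32 copies of a per-language embedding along the time axis. With audio_signal shaped [B, T, D] after pos_enc, the lookup yields [B, D], unsqueeze(1).repeat(1, 32, 1) makes it [B, 32, D], and the concat produces [B, T+32, D]. A shape-only sketch with hypothetical dimensions (the real sizes come from the model config):

import torch

language_to_idx = {"en": 0, "hi": 1, "ta": 2}           # hypothetical mapping
language_embeddings = torch.nn.Embedding(len(language_to_idx), 176)

audio_signal = torch.randn(2, 100, 176)                 # [B=2, T=100, D=176]
language_ids = ["en", "hi"]                             # one language per batch element
language_ints = torch.tensor([language_to_idx[l] for l in language_ids])
language_inputs = language_embeddings(language_ints).unsqueeze(1).repeat(1, 32, 1)  # [2, 32, 176]
audio_signal = torch.cat((language_inputs, audio_signal), 1)
print(audio_signal.shape)                               # torch.Size([2, 132, 176])

The downstream length and padding-mask computation has to account for the 32 extra frames; that bookkeeping is outside this hunk.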
4 changes: 2 additions & 2 deletions nemo/collections/asr/modules/rnnt.py
@@ -1609,7 +1609,7 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor, language_ids=
             res = torch.stack(res_single)
         else:
             res = self.joint_net(inp) # [B, T, U, V + 1]
-
+        del inp
 
         if self.preserve_memory:
@@ -1662,7 +1662,7 @@ def _joint_net_modules(self, num_classes, pred_n_hidden, enc_n_hidden, joint_n_h
         final_layer = torch.nn.ModuleDict()
         logging.info(f"Vocab size for each language: {self._vocab_size // len(self.language_keys)}")
         for lang in self.language_keys:
-            final_layer[lang] = torch.nn.Linear(joint_n_hidden, (self._vocab_size // len(self.language_keys)+1))
+            final_layer[lang] = torch.nn.Linear(joint_n_hidden, ((self._vocab_size)//len(self.language_keys) + 1 + len(self.language_keys)))
         layers = (
             [activation]
             + ([torch.nn.Dropout(p=dropout)] if dropout else [])
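Note on these hunks: the moved del inp appears to ensure the joint's large intermediate activation (shaped [B, T, U, H]) is released in both branches, and the widened Linear gives each per-language output head room for its vocabulary share, the blank, and one tag per language, matching the tokenizer change below. The sizing, with hypothetical numbers:

# Hypothetical sizes, mirroring the Linear width change in this commit.
vocab_size = 512                          # self._vocab_size, aggregated
language_keys = ["en", "hi", "ta", "bn"]  # hypothetical
per_lang = vocab_size // len(language_keys)       # 128
old_width = per_lang + 1                          # 129: vocab share + blank
new_width = per_lang + 1 + len(language_keys)     # 133: plus one <lang> token per language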
21 changes: 14 additions & 7 deletions nemo/collections/common/tokenizers/multilingual_tokenizer.py
@@ -108,7 +108,7 @@ def text_to_tokens(self, text, lang_id):
     def text_to_ids(self, text, lang_id):
         tokenizer = self.tokenizers_dict[lang_id]
         token_ids = tokenizer.text_to_ids(text)
-        lang_token_index = self.vocabulary.index(f"<{lang_id}>")
+        lang_token_index = self.vocabulary.index(f"<{lang_id}>") - self.token_id_offset["<lang_code>"] + len(tokenizer.vocab)
         # Insert language token at index 0 to the list of tokens
         token_ids = [lang_token_index] + token_ids
         # token_ids[:] = [t + self.token_id_offset[lang_id] for t in token_ids]
@@ -132,8 +132,12 @@ def ids_to_text(self, ids, lang):
             # tokenizer = self.tokenizers_by_token_id[id]
             tokenizer = self.tokenizers_dict[lang]
             # tokens.extend(tokenizer.ids_to_tokens([offset_id]))
-            if id >= self.token_id_offset["<lang_code>"]:
-                tokens.append(self.vocabulary[id]+' ')
+            if id >= len(tokenizer.vocab):
+                try:
+                    tokens.append(self.vocabulary[id-len(tokenizer.vocab)+self.token_id_offset["<lang_code>"]]+' ')
+                except IndexError:
+                    print("Index error occured")
+                    breakpoint()
             else:
                 tokens.extend(tokenizer.ids_to_tokens([id]))
         text = ''.join(tokens).replace('▁', ' ')
@@ -148,10 +152,13 @@ def ids_to_tokens(self, ids):
         tokens = []
 
         for id in ids:
-            offset_id = self.offset_token_ids_by_token_id[id]
-            tokenizer = self.tokenizers_by_token_id[id]
-            token = tokenizer.ids_to_tokens([offset_id])[0]
-            tokens.append(token)
+            if id >= self.token_id_offset["<lang_code>"]:
+                tokens.append(self.vocabulary[id]+' ')
+            else:
+                offset_id = self.offset_token_ids_by_token_id[id]
+                tokenizer = self.tokenizers_by_token_id[id]
+                token = tokenizer.ids_to_tokens([offset_id])[0]
+                tokens.append(token)
 
         return tokens

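Note on these hunks: the commit moves each <lang> tag out of its global position in self.vocabulary and into the tail of the owning language's id space. text_to_ids now maps the tag to len(tokenizer.vocab) + k, where k is the tag's rank among the language codes, and ids_to_text inverts that mapping before rendering text, while ids_to_tokens still reads tags from their global position. A toy round-trip of the remapping under assumed sizes (not the NeMo class itself):

# Toy model of the remapping (hypothetical sizes), not the NeMo class.
per_lang_vocab = 100                 # len(tokenizer.vocab) for one language
lang_code_offset = 400               # token_id_offset["<lang_code>"]: where tags start globally
vocabulary = ["tok%d" % i for i in range(lang_code_offset)] + ["<en>", "<hi>", "<ta>"]

def lang_token_index(lang_id):       # mirrors the new text_to_ids line
    return vocabulary.index(f"<{lang_id}>") - lang_code_offset + per_lang_vocab

def id_to_piece(id):                 # mirrors the new ids_to_text branch
    if id >= per_lang_vocab:         # past the language's own vocab: it is a tag
        return vocabulary[id - per_lang_vocab + lang_code_offset] + ' '
    return "piece-%d" % id           # stand-in for tokenizer.ids_to_tokens([id])[0]

i = lang_token_index("hi")           # 401 - 400 + 100 = 101
print(i, id_to_piece(i))             # 101 <hi>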
