Skip to content

Commit

Permalink
Update hybrid_rnnt_ctc_bpe_models.py
Browse files: browse the repository at this point in the history
  • Loading branch information
tahirjmakhdoomi authored Feb 6, 2024
1 parent 1ed6d46 commit b6ca450
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,14 @@ def change_vocabulary(
decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

self.decoding = RNNTBPEDecoding(
decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
)
if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
self.decoding = RNNTBPEDecoding(
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
)
else:
self.decoding = RNNTBPEDecoding(
decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
)

self.wer = RNNTBPEWER(
decoding=self.decoding,
Expand Down Expand Up @@ -405,7 +410,10 @@ def change_vocabulary(
ctc_decoding_cls = OmegaConf.create(OmegaConf.to_container(ctc_decoding_cls))
ctc_decoding_cfg = OmegaConf.merge(ctc_decoding_cls, ctc_decoding_cfg)

self.ctc_decoding = CTCBPEDecoding(decoding_cfg=ctc_decoding_cfg, tokenizer=self.tokenizer)
if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
else:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)

self.ctc_wer = WERBPE(
decoding=self.ctc_decoding,
Expand Down Expand Up @@ -444,9 +452,14 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

self.decoding = RNNTBPEDecoding(
decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
)
if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
self.decoding = RNNTBPEDecoding(
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
)
else:
self.decoding = RNNTBPEDecoding(
decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
)

self.wer = RNNTBPEWER(
decoding=self.decoding,
Expand Down Expand Up @@ -483,7 +496,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer)
if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
else:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)

self.ctc_wer = WERBPE(
decoding=self.ctc_decoding,
Expand Down

0 comments on commit b6ca450

Please sign in to comment.