diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index feafcbc30..fb72bef41 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -338,9 +338,14 @@ def change_vocabulary( decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) - self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, - ) + if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder: + self.decoding = RNNTBPEDecoding( + decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()) + ) + else: + self.decoding = RNNTBPEDecoding( + decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + ) self.wer = RNNTBPEWER( decoding=self.decoding, @@ -405,7 +410,10 @@ def change_vocabulary( ctc_decoding_cls = OmegaConf.create(OmegaConf.to_container(ctc_decoding_cls)) ctc_decoding_cfg = OmegaConf.merge(ctc_decoding_cls, ctc_decoding_cfg) - self.ctc_decoding = CTCBPEDecoding(decoding_cfg=ctc_decoding_cfg, tokenizer=self.tokenizer) + if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder: + self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys())) + else: + self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer) self.ctc_wer = WERBPE( decoding=self.ctc_decoding, @@ -444,9 +452,14 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) - self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, - ) + if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder: + self.decoding = RNNTBPEDecoding( + decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()) + ) + else: + self.decoding = RNNTBPEDecoding( + decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + ) self.wer = RNNTBPEWER( decoding=self.decoding, @@ -483,7 +496,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) - self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer) + if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder: + self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys())) + else: + self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer) self.ctc_wer = WERBPE( decoding=self.ctc_decoding,