diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py index 9d43b28..04086af 100644 --- a/machine/translation/huggingface/hugging_face_nmt_engine.py +++ b/machine/translation/huggingface/hugging_face_nmt_engine.py @@ -55,8 +55,8 @@ def __init__( self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True) if isinstance(self._tokenizer, (NllbTokenizer, NllbTokenizerFast)): self._mpn = MosesPunctNormalizer() - self._mpn.substitutions = [ - (str(re.compile(r)), sub) + self._mpn.substitutions = [ # type: ignore + (re.compile(r), sub) for r, sub in self._mpn.substitutions if isinstance(r, str) and isinstance(sub, str) ] @@ -236,8 +236,12 @@ def _forward(self, model_inputs, **generate_kwargs): input_tokens = model_inputs["input_tokens"] del model_inputs["input_tokens"] - generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length) - generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length) + if hasattr(self.model, "generation_config") and self.model.generation_config is not None: + config = self.model.generation_config + else: + config = self.model.config + generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length) + generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length) self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) output = self.model.generate( **model_inputs, diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py index f15ba4e..1192243 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py +++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py @@ -100,10 +100,8 @@ def __init__( self._add_unk_src_tokens = add_unk_src_tokens self._add_unk_tgt_tokens = add_unk_tgt_tokens self._mpn = MosesPunctNormalizer() - self._mpn.substitutions = [ - (str(re.compile(r)), sub) - for r, sub in self._mpn.substitutions - if isinstance(r, str) and isinstance(sub, str) + self._mpn.substitutions = [ # type: ignore + (re.compile(r), sub) for r, sub in self._mpn.substitutions if isinstance(r, str) and isinstance(sub, str) ] self._stats = TrainStats()