diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py index 6103d17..dcd5895 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py +++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py @@ -102,10 +102,11 @@ def __init__( self._add_unk_trg_tokens = add_unk_trg_tokens self._mpn = MosesPunctNormalizer() self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions] + self._stats = TrainStats() @property def stats(self) -> TrainStats: - return super().stats + return self._stats def train( self, @@ -366,6 +367,7 @@ def preprocess_function(examples): self._metrics = train_result.metrics self._metrics["train_samples"] = len(train_dataset) + self._stats.train_corpus_size = self._metrics["train_samples"] self._trainer.log_metrics("train", self._metrics) logger.info("Model training finished")