From 7987890a102093472f59919273842ea60ee23d99 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Fri, 6 Sep 2024 12:52:44 -0400 Subject: [PATCH] small fix for NMT build job (#119) --- .../translation/huggingface/hugging_face_nmt_model_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py index 6103d17..dcd5895 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py +++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py @@ -102,10 +102,11 @@ def __init__( self._add_unk_trg_tokens = add_unk_trg_tokens self._mpn = MosesPunctNormalizer() self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions] + self._stats = TrainStats() @property def stats(self) -> TrainStats: - return super().stats + return self._stats def train( self, @@ -366,6 +367,7 @@ def preprocess_function(examples): self._metrics = train_result.metrics self._metrics["train_samples"] = len(train_dataset) + self._stats.train_corpus_size = self._metrics["train_samples"] self._trainer.log_metrics("train", self._metrics) logger.info("Model training finished")