Skip to content

Commit

Permalink
small fix for NMT build job (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 authored Sep 6, 2024
1 parent 88f5bad commit 7987890
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ def __init__(
self._add_unk_trg_tokens = add_unk_trg_tokens
self._mpn = MosesPunctNormalizer()
self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions]
self._stats = TrainStats()

@property
def stats(self) -> TrainStats:
return super().stats
return self._stats

def train(
self,
Expand Down Expand Up @@ -366,6 +367,7 @@ def preprocess_function(examples):

self._metrics = train_result.metrics
self._metrics["train_samples"] = len(train_dataset)
self._stats.train_corpus_size = self._metrics["train_samples"]

self._trainer.log_metrics("train", self._metrics)
logger.info("Model training finished")
Expand Down

0 comments on commit 7987890

Please sign in to comment.