Commit 16c88a5
Merge pull request #3374 from flairNLP/3373-batch-count
use batch count instead of total training samples for logging metrics
alanakbik authored Nov 10, 2023
2 parents 50b3e30 + 837459c commit 16c88a5
Showing 1 changed file with 10 additions and 7 deletions.
flair/trainers/trainer.py — 17 changes: 10 additions & 7 deletions
@@ -538,6 +538,7 @@ def train_custom(
         # At any point you can hit Ctrl + C to break out of training early.
         try:
             total_train_samples = 0
+            batch_count = 0
 
             for epoch in range(epoch + 1, max_epochs + 1):
                 log_line(log)
@@ -547,7 +548,7 @@ def train_custom(
                 self.dispatch("before_training_epoch", epoch=epoch)
                 self.model.model_card["training_parameters"]["epoch"] = epoch  # type: ignore[index]
 
-                lr_info, momentum_info = self._get_current_lr_and_momentum(total_train_samples)
+                lr_info, momentum_info = self._get_current_lr_and_momentum(batch_count)
 
                 # if shuffle_first_epoch==False, the first epoch is not shuffled
                 shuffle_data_this_epoch = shuffle
@@ -579,12 +580,14 @@ def train_custom(
 
                     batch_train_loss = 0.0
                     batch_train_samples = 0
+                    batch_count += 1
 
                     batch_kw = {
                         "batch_no": batch_no,
                         "batch": batch,
                         "total_number_of_batches": len(batch_loader),
                         "epoch": epoch,
+                        "batch_count": batch_count,
                     }
 
                     self.dispatch("before_training_batch", **batch_kw)
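
Because batch_kw is forwarded via self.dispatch("before_training_batch", **batch_kw), any listener attached to that event now also receives the running batch_count. A minimal sketch of such a consumer — the listener class and method below are illustrative assumptions, only the keyword itself comes from this commit:

# Hypothetical listener for the "before_training_batch" event; the class
# and its registration are assumptions for illustration, not part of this diff.
class BatchCountPrinter:
    def before_training_batch(self, batch_no, batch, total_number_of_batches, epoch, batch_count, **kwargs):
        # Unlike batch_no, which restarts at 0 every epoch, batch_count
        # increases monotonically over the whole training run.
        if batch_count % 100 == 0:
            print(f"epoch {epoch} - batch {batch_no + 1}/{total_number_of_batches} - global batch {batch_count}")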
@@ -626,10 +629,10 @@ def train_custom(
                     if batch_train_samples > 0:
                         total_train_samples += batch_train_samples
                         train_loss = batch_train_loss / batch_train_samples
-                        self._record(MetricRecord.scalar(("train", "batch_loss"), train_loss, total_train_samples))
+                        self._record(MetricRecord.scalar(("train", "batch_loss"), train_loss, batch_count))
                         if gradient_norm is not None:
                             self._record(
-                                MetricRecord.scalar(("train", "gradient_norm"), gradient_norm, total_train_samples)
+                                MetricRecord.scalar(("train", "gradient_norm"), gradient_norm, batch_count)
                             )
 
                     epoch_train_loss += batch_train_loss
@@ -644,7 +647,7 @@ def train_custom(
 
                         current_time = time.time()
 
-                        lr_info, momentum_info = self._get_current_lr_and_momentum(total_train_samples)
+                        lr_info, momentum_info = self._get_current_lr_and_momentum(batch_count)
                         log.info(
                             f"epoch {epoch}"
                             f" - iter {batch_no + 1}/{len(batch_loader)}"
@@ -811,13 +814,13 @@ def train_custom(
 
         return return_values
 
-    def _get_current_lr_and_momentum(self, total_train_samples):
+    def _get_current_lr_and_momentum(self, batch_count):
         current_learning_rate = [group["lr"] for group in self.optimizer.param_groups]
         momentum = [group["momentum"] if "momentum" in group else 0 for group in self.optimizer.param_groups]
         lr_info = " - lr: " + ",".join([f"{m:.6f}" for m in current_learning_rate])
         momentum_info = " - momentum: " + ",".join([f"{m:.6f}" for m in momentum])
-        self._record(MetricRecord.scalar_list("learning_rate", current_learning_rate, total_train_samples))
-        self._record(MetricRecord.scalar_list(("optimizer", "momentum"), momentum, total_train_samples))
+        self._record(MetricRecord.scalar_list("learning_rate", current_learning_rate, batch_count))
+        self._record(MetricRecord.scalar_list(("optimizer", "momentum"), momentum, batch_count))
         return lr_info, momentum_info
 
     def _sample_train_split(self, monitor_train_sample):
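
The practical effect: a metric's step axis now advances by exactly one per mini-batch instead of by the (possibly varying) number of samples in the batch. A stand-alone sketch of the two conventions, where record() is a stand-in for the trainer's self._record(MetricRecord.scalar(...)):

# record() is a stand-in for self._record(MetricRecord.scalar(...)) above.
def record(name, value, step):
    print(f"{name}: {value:.4f} @ step {step}")

batch_sizes = [32, 32, 32, 17]  # the last batch of an epoch is often smaller
losses = [0.90, 0.70, 0.60, 0.55]

total_train_samples = 0
batch_count = 0
for size, loss in zip(batch_sizes, losses):
    total_train_samples += size
    batch_count += 1
    # Old behavior: steps 32, 64, 96, 113 - spacing depends on batch size.
    record("train/batch_loss (by samples)", loss, total_train_samples)
    # New behavior: steps 1, 2, 3, 4 - one step per logged point.
    record("train/batch_loss (by batches)", loss, batch_count)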
