Skip to content

Commit

Permalink
Add learning rate to training progress bar.
Browse files Browse the repository at this point in the history
  • Loading branch information
aecelaya committed Dec 3, 2024
1 parent d48d89e commit 2730947
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions mist/runtime/progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Progress bars for MIST training and validation loops."""
from rich.progress import (
BarColumn,
TextColumn,
Progress,
MofNCompleteColumn,
TimeElapsedColumn
)

import numpy as np

class TrainProgressBar(Progress):
def __init__(self, current_epoch, fold, epochs, train_steps):
Expand All @@ -20,16 +21,26 @@ def __init__(self, current_epoch, fold, epochs, train_steps):
TimeElapsedColumn(),
TextColumn("•"),
TextColumn("{task.fields[loss]}"),
TextColumn("•"),
TextColumn("{task.fields[lr]}"), # Learning rate column
)

# Initialize tasks with loss and learning rate fields
self.task = self.progress.add_task(
description="Training",
description="Training (loss)",
total=train_steps,
loss=f"loss: "
loss=f"loss: ",
lr=f"lr: ",
)

def update(self, loss):
self.progress.update(self.task, advance=1, loss=f"loss: {loss:.4f}")
def update(self, loss, lr):
# Update loss task.
self.progress.update(
self.task,
advance=1,
loss=f"loss: {loss:.4f}",
lr=f"lr: {np.format_float_scientific(lr, precision=3)}",
)

def __enter__(self):
self.progress.start()
Expand Down

0 comments on commit 2730947

Please sign in to comment.