Skip to content

Commit

Permalink
Change progress logging to once across all nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
siddartha-RE committed Jan 8, 2024
1 parent 4ab5fb8 commit b824a63
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,42 +489,42 @@ def __init__(self):
self.prediction_bar = None

def on_train_begin(self, args, state, control, **kwargs):
if state.is_local_process_zero:
if state.is_world_process_zero:
self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True)
self.current_step = 0

def on_step_end(self, args, state, control, **kwargs):
if state.is_local_process_zero:
if state.is_world_process_zero:
self.training_bar.update(state.global_step - self.current_step)
self.current_step = state.global_step

def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if state.is_local_process_zero and has_length(eval_dataloader):
if state.is_world_process_zero and has_length(eval_dataloader):
if self.prediction_bar is None:
self.prediction_bar = tqdm(
total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True
)
self.prediction_bar.update(1)

def on_evaluate(self, args, state, control, **kwargs):
if state.is_local_process_zero:
if state.is_world_process_zero:
if self.prediction_bar is not None:
self.prediction_bar.close()
self.prediction_bar = None

def on_predict(self, args, state, control, **kwargs):
if state.is_local_process_zero:
if state.is_world_process_zero:
if self.prediction_bar is not None:
self.prediction_bar.close()
self.prediction_bar = None

def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_local_process_zero and self.training_bar is not None:
if state.is_world_process_zero and self.training_bar is not None:
_ = logs.pop("total_flos", None)
self.training_bar.write(str(logs))

def on_train_end(self, args, state, control, **kwargs):
if state.is_local_process_zero:
if state.is_world_process_zero:
self.training_bar.close()
self.training_bar = None

Expand Down

0 comments on commit b824a63

Please sign in to comment.