From 6fa6026ee0ee0497fb6e347f1d56b7ed9b4e85e9 Mon Sep 17 00:00:00 2001
From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com>
Date: Thu, 23 May 2024 16:47:47 -0700
Subject: [PATCH] Fixing the state.timestamp.batch.value issue in loss v len
 callback (#1232)

* adding print statements

* testing fix

* fix

* removing print statements

* minor fix
---
 llmfoundry/callbacks/loss_perp_v_len_callback.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/callbacks/loss_perp_v_len_callback.py b/llmfoundry/callbacks/loss_perp_v_len_callback.py
index 1a3ac05651..aa9519c255 100644
--- a/llmfoundry/callbacks/loss_perp_v_len_callback.py
+++ b/llmfoundry/callbacks/loss_perp_v_len_callback.py
@@ -107,19 +107,24 @@ def after_backward(self, state: State, logger: Logger) -> None:
         )
 
     def batch_end(self, state: State, logger: Logger) -> None:
-        if state.timestamp.batch.value % self.compute_batch_interval == 0:
+        if (
+            state.timestamp.batch.value - 1
+        ) % self.compute_batch_interval == 0:  # state.timestamp.batch.value - 1 because batch is incremented before batch_end (https://github.com/mosaicml/composer/blob/57c7b72b9df41b0c9777bad1c2bec17f3103c31f/composer/trainer/trainer.py#L2478C1-L2484C55)
             current_metric_dict = self.loss_perp_v_len.compute()
             if dist.get_global_rank() == 0:
                 for k, v in current_metric_dict.items():
                     v = v.tolist()
                     v.append(
-                        state.timestamp.batch.value,
+                        state.timestamp.batch.value -
+                        1,  # state.timestamp.batch.value - 1 because batch is incremented before batch_end (https://github.com/mosaicml/composer/blob/57c7b72b9df41b0c9777bad1c2bec17f3103c31f/composer/trainer/trainer.py#L2478C1-L2484C55)
                     )  # Add the current batch index as the last column
                     if k not in self.metric_dict:
                         self.metric_dict[k] = []
                     self.metric_dict[k].append(v)
-        if state.timestamp.batch.value % self.log_batch_interval == 0 and dist.get_global_rank(
-        ) == 0:
+        if (
+            state.timestamp.batch.value - 1
+        ) % self.log_batch_interval == 0 and dist.get_global_rank(
+        ) == 0:  # state.timestamp.batch.value - 1 because batch is incremented before batch_end (https://github.com/mosaicml/composer/blob/57c7b72b9df41b0c9777bad1c2bec17f3103c31f/composer/trainer/trainer.py#L2478C1-L2484C55)
             for k, v in self.metric_dict.items():
                 columns = []
                 columns = [