Align logged attributes for errors and run metadata in kill_loss_spike_callback.py (#1494)

joyce-chen-uni authored Aug 29, 2024
1 parent bf6cfdf commit 8516181
Showing 3 changed files with 34 additions and 11 deletions.
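In short: _log_metadata previously received a prose message string, which was logged even when the callback was about to raise. After this change it receives a dict, metadata is logged only in log_only mode, and the logged fields mirror the attributes carried by LossSpikeError and HighLossError. A sketch of the run-metadata payload before and after, with illustrative numbers (not from a real run):

# Hypothetical 'loss_spike' metadata payloads, before vs. after this commit.
payload_before = {
    'loss_spike': 'Training loss spike detected for 3 consecutive steps. '
    'Consider stopping this run and resubmitting with a lower learning rate.',
    'loss_window': [2.1, 2.0, 6.5, 6.8, 7.0],
}
payload_after = {
    'loss_spike': {
        'outlier_multiplier': 2.0,
        'running_loss_avg': 2.05,
        'outlier_counter': 3,
    },
    'loss_window': [2.1, 2.0, 6.5, 6.8, 7.0],
}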
35 changes: 26 additions & 9 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -109,11 +109,11 @@ def _detect_high_losses(self, current_step: int) -> bool:
 
         return is_high_loss
 
-    def _log_metadata(self, logger: Logger, key: str, message: str) -> None:
+    def _log_metadata(self, logger: Logger, key: str, value: dict) -> None:
         for destination in logger.destinations:
             if isinstance(destination, MosaicMLLogger):
                 destination.log_metadata({
-                    key: message,
+                    key: value,
                     'loss_window': list(self.loss_window),
                 })
 
@@ -122,22 +122,39 @@ def _handle_loss_spike(
         logger: Logger,
         running_loss_avg: float,
     ) -> None:
-        message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.'
-        self._log_metadata(logger, 'loss_spike', message)
-        if not self.log_only:
+        if self.log_only:
+            self._log_metadata(
+                logger,
+                'loss_spike',
+                {
+                    'outlier_multiplier': self.outlier_multiplier,
+                    'running_loss_avg': running_loss_avg,
+                    'outlier_counter': self.outlier_counter,
+                },
+            )
+        else:
             raise LossSpikeError(
                 outlier_multiplier=self.outlier_multiplier,
-                running_loss_avg=round(running_loss_avg),
+                running_loss_avg=running_loss_avg,
                 outlier_counter=self.outlier_counter,
+                loss_window=list(self.loss_window),
             )
 
     def _handle_high_losses(self, logger: Logger) -> None:
-        message = f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.'
-        self._log_metadata(logger, 'high_loss', message)
-        if not self.log_only:
+        if self.log_only:
+            self._log_metadata(
+                logger,
+                'high_loss',
+                {
+                    'loss_cap': self.loss_cap,
+                    'window_size': self.window_size,
+                },
+            )
+        else:
             raise HighLossError(
                 loss_cap=self.loss_cap,
                 window_size=self.window_size,
+                loss_window=list(self.loss_window),
             )
 
     def _set_window_size(self, state: State) -> None:
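Note the branching above: logging and raising are now mutually exclusive, and with log_only=False the run is killed by raising. A hedged usage sketch for callers that want to react programmatically (assumes trainer is a Composer Trainer already configured with the KillLossSpike callback):

from llmfoundry.utils.exceptions import HighLossError, LossSpikeError

try:
    trainer.fit()
except LossSpikeError as e:
    # Message embeds outlier_multiplier, the rounded running loss average,
    # and outlier_counter; the raw loss_window now travels with the error too.
    print(f'Run killed by loss spike: {e}')
except HighLossError as e:
    print(f'Run killed by persistently high losses: {e}')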
8 changes: 6 additions & 2 deletions llmfoundry/utils/exceptions.py
@@ -395,18 +395,20 @@ class LossSpikeError(UserError):
     def __init__(
         self,
         outlier_multiplier: float,
-        running_loss_avg: int,
+        running_loss_avg: float,
         outlier_counter: int,
+        loss_window: list[float],
     ) -> None:
         message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \
-the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \
+the running average loss (approx. {round(running_loss_avg, 1)}) over {outlier_counter} consecutive training steps. \
 Please try submitting the run again with a lower learning rate.'
 
         super().__init__(
             message,
             outlier_multiplier=outlier_multiplier,
             running_loss_avg=running_loss_avg,
             outlier_counter=outlier_counter,
+            loss_window=loss_window,
         )
 
 
@@ -417,6 +419,7 @@ def __init__(
         self,
         loss_cap: float,
         window_size: int,
+        loss_window: list[float],
     ) -> None:
         message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \
 for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.'
@@ -425,4 +428,5 @@
             message,
             loss_cap=loss_cap,
             window_size=window_size,
+            loss_window=loss_window,
         )
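Constructing the updated exception directly shows the widened signature; values below are illustrative. The message rounds to one decimal place, while the attribute keeps full precision:

from llmfoundry.utils.exceptions import LossSpikeError

err = LossSpikeError(
    outlier_multiplier=2.0,
    running_loss_avg=2.0333,  # now a float; the caller no longer pre-rounds
    outlier_counter=3,
    loss_window=[2.1, 2.0, 6.5, 6.8, 7.0],
)
print(err)  # '... (approx. 2.0) over 3 consecutive training steps. ...'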
2 changes: 2 additions & 0 deletions tests/utils/test_exceptions.py
@@ -37,6 +37,8 @@ def get_default_value(arg_type: Optional[type] = None):
         return 1
     elif arg_type == float:
         return 1.0
+    elif arg_type == list[float]:
+        return [1.0]
     elif arg_type == set[str]:
         return {'set'}
     elif arg_type == list[str]:
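The test helper needs the new branch because the suite instantiates each exception from its __init__ type hints, so every annotated type must map to a default. A simplified reconstruction of that pattern (not the exact test code; build_exception is hypothetical):

import inspect

def build_exception(exc_type):
    # Resolve a default value for every annotated constructor argument,
    # so new fields such as loss_window: list[float] are exercised.
    params = inspect.signature(exc_type.__init__).parameters
    kwargs = {
        name: get_default_value(p.annotation)
        for name, p in params.items()
        if name != 'self'
    }
    return exc_type(**kwargs)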
