Skip to content

Commit

Permalink
decompose logging and rename helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
joyce-chen-uni committed Aug 26, 2024
1 parent 04d20fa commit 4658a8f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 25 deletions.
34 changes: 15 additions & 19 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self.loss_window = deque(maxlen=self.window_size)
self.loss_cap = MAX_LOSS_CAP

def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
def _detect_loss_spike(self, train_loss: float, running_loss_avg: float):
# Train loss is an outlier
if train_loss >= running_loss_avg * self.outlier_multiplier:
self.outlier_counter += 1
Expand All @@ -82,7 +82,7 @@ def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
self.outlier_counter = 0
return False

def detect_high_losses(self, current_step: int):
def _detect_high_losses(self, current_step: int):
# Half of the running losses are greater than our "high loss" threshold, after an initial buffer period
if (current_step >= self.window_size * 2) and (
sum(1 for loss in self.loss_window if loss > self.loss_cap) >=
Expand All @@ -93,34 +93,30 @@ def detect_high_losses(self, current_step: int):
)
return True
return False

def handle_loss_spike(
self, logger: Logger, running_loss_avg: float
) -> None:

def _log_metadata(self, logger: Logger, key: str, message: str) -> None:
    """Log a message and the current loss window to any MosaicML logger destinations.

    Args:
        logger: Composer logger whose destinations are scanned.
        key: Metadata key to record the message under (e.g. 'loss_spike', 'high_loss').
        message: Human-readable description of the detected loss condition.
    """
    # NOTE(review): the pasted span interleaved removed and added diff lines,
    # leaving a stray 'loss_spike' entry and a duplicate 'loss_window' key in
    # the dict literal; this is the reconstructed post-refactor implementation.
    for destination in logger.destinations:
        # Only MosaicMLLogger destinations support log_metadata; others are skipped.
        if isinstance(destination, MosaicMLLogger):
            destination.log_metadata({
                key: message,
                'loss_window': list(self.loss_window),
            })

def _handle_loss_spike(
    self, logger: Logger, running_loss_avg: float
) -> None:
    """Report a detected loss spike and, unless in log-only mode, abort the run.

    Args:
        logger: Composer logger used to surface the spike metadata.
        running_loss_avg: Running average of recent training losses, included
            in the raised error for context.

    Raises:
        LossSpikeError: When ``self.log_only`` is False.
    """
    spike_message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.'
    self._log_metadata(logger, 'loss_spike', spike_message)
    # In log-only mode we surface the metadata but let training continue.
    if self.log_only:
        return
    raise LossSpikeError(
        outlier_multiplier=self.outlier_multiplier,
        running_loss_avg=round(running_loss_avg),
        outlier_counter=self.outlier_counter,
    )

def handle_high_losses(self, logger: Logger) -> None:
for destination in logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({
'high_loss':
f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.',
'loss_window':
list(self.loss_window),
})
def _handle_high_losses(self, logger: Logger) -> None:
message = f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.'
self._log_metadata(logger, 'high_loss', message)
if not self.log_only:
raise HighLossError(
loss_cap=self.loss_cap,
Expand Down
12 changes: 6 additions & 6 deletions tests/callbacks/test_kill_loss_spike_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ def test_detect_loss_spike_no_spike(self, _):
self.callback.outlier_counter = 0
train_loss = 4
running_loss_avg = 2
result = self.callback.detect_loss_spike(train_loss, running_loss_avg)
result = self.callback._detect_loss_spike(train_loss, running_loss_avg)
self.assertFalse(result)

@patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
def test_detect_loss_spike_with_spike(self, _):
self.callback.outlier_counter = 4 # Simulating previous spikes
train_loss = 4
running_loss_avg = 2
result = self.callback.detect_loss_spike(train_loss, running_loss_avg)
result = self.callback._detect_loss_spike(train_loss, running_loss_avg)
self.assertTrue(result)

@patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
Expand All @@ -59,7 +59,7 @@ def test_no_error_raised_with_log_only_true(self, _):
self.callback.outlier_counter = 4
self.callback.loss_window = deque([2] * 10, maxlen=10)

result = self.callback.detect_loss_spike(state.loss.item(), 2)
result = self.callback._detect_loss_spike(state.loss.item(), 2)
self.assertTrue(result)

# batch_end should not raise an error due to log_only=True
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_error_raised_with_log_only_false(self, _):
self.callback.loss_window = deque([2] * 10, maxlen=10)
self.callback.log_only = False

result = self.callback.detect_loss_spike(state.loss.item(), 2)
result = self.callback._detect_loss_spike(state.loss.item(), 2)
self.assertTrue(result)

# batch_end should raise an error due to log_only=False
Expand All @@ -98,7 +98,7 @@ def test_error_raised_with_log_only_false(self, _):
def test_detect_high_losses_no_high_losses(self, _):
self.callback.loss_window = deque([2] * 10, maxlen=10)
current_step = 21
result = self.callback.detect_high_losses(current_step)
result = self.callback._detect_high_losses(current_step)
self.assertFalse(result)

@patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
Expand All @@ -108,5 +108,5 @@ def test_detect_high_losses_with_high_losses(self, _):
maxlen=10,
) # Simulate mix of losses in loss window
current_step = 21
result = self.callback.detect_high_losses(current_step)
result = self.callback._detect_high_losses(current_step)
self.assertTrue(result)

0 comments on commit 4658a8f

Please sign in to comment.