clean up API
epwalsh committed Oct 12, 2024
1 parent 8975115 commit 8d65d68
Showing 5 changed files with 158 additions and 104 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -12,9 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `CometCallback` for logging training runs to Comet.ml.
- Added `DataMixBase` class, to allow extending to new data mix groups.
- Added support for MoE-based models.
- Added callback method `Callback.post_model_forward()`.
- Added method `DataLoaderBase.get_mock_batch()`.
- Trainer now starts with a dry-run of a fake batch created by `DataLoaderBase.get_mock_batch()`.
- Added `Callback.pre_backward()`, `.pre_eval_batch()`, and `.post_eval_batch()` methods.
- Added `Trainer.model_forward()`, `.get_losses()`, and `.eval_batch()` methods.

### Changed

31 changes: 18 additions & 13 deletions src/olmo_core/train/callbacks/callback.py
@@ -1,6 +1,6 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Dict

import torch

@@ -88,31 +88,22 @@ def pre_step(self, batch: Dict[str, Any]):
"""
del batch

def post_model_forward(
def pre_backward(
self,
*,
batch: Dict[str, Any],
micro_batch: Dict[str, Any],
num_micro_batches: int,
batch_num_tokens_for_loss: torch.Tensor,
loss: torch.Tensor,
ce_loss: torch.Tensor,
z_loss: Optional[torch.Tensor] = None,
):
"""
Runs right after the train forward pass on a micro-batch. This can be used to modify the
Runs right before the backward pass on a micro-batch. This can be used to modify the
``loss`` before ``loss.backward()`` is called.
:param batch: The full batch.
:param micro_batch: The micro-batch just used.
:param num_micro_batches: The number of micro-batches in the full batch.
:param batch_num_tokens_for_loss: The total number of tokens in the local full batch that
will be counted towards the loss.
:param loss: The combined loss from the micro-batch (``ce_loss`` plus the optional ``z_loss``).
:param ce_loss: The cross-entropy loss from the micro-batch.
:param z_loss: The Z-loss from the micro-batch.
"""
del batch, micro_batch, num_micro_batches, batch_num_tokens_for_loss, loss, ce_loss, z_loss
del batch, micro_batch, loss

def pre_optim_step(self):
"""
@@ -126,6 +117,20 @@ def post_train_batch(self):
"""
pass

def pre_eval_batch(self, batch: Dict[str, Any]):
"""
Runs right before an eval batch is processed with :meth:`~olmo_core.train.Trainer.eval_batch()`.
:param batch: The eval batch.
"""
del batch

def post_eval_batch(self):
"""
Runs after an eval batch is processed with :meth:`~olmo_core.train.Trainer.eval_batch()`.
"""
pass

def post_step(self):
"""
Runs after a complete step (potentially including evals and checkpointing).
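
For reference, a minimal sketch of a callback built on the new hooks above. This is not part of the commit; the import path and the dataclass style mirror the patterns in `callback.py`, while everything else (the class name, the `scale` field) is hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Dict

import torch

from olmo_core.train.callbacks import Callback  # assumed public re-export of the base class


@dataclass
class LossScaleCallback(Callback):
    """
    Hypothetical callback that rescales the training loss in place right before
    ``loss.backward()`` is called, using the new ``pre_backward()`` hook.
    """

    scale: float = 1.0

    def pre_backward(
        self, *, batch: Dict[str, Any], micro_batch: Dict[str, Any], loss: torch.Tensor
    ):
        del batch, micro_batch
        # Modifying ``loss`` in place here changes the gradients produced by the
        # subsequent backward pass.
        loss.mul_(self.scale)

    def pre_eval_batch(self, batch: Dict[str, Any]):
        del batch  # nothing to do before eval batches in this sketch

    def post_eval_batch(self):
        pass  # nothing to do after them either
```
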
9 changes: 3 additions & 6 deletions src/olmo_core/train/callbacks/evaluator_callback.py
@@ -2,8 +2,6 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import torch

from olmo_core.data import NumpyDatasetConfig, NumpyPaddedFSLDataset
from olmo_core.distributed.utils import get_world_size
from olmo_core.eval import Evaluator
@@ -64,10 +62,9 @@ def post_step(self):
eval_step += 1
eval_tokens += batch["input_ids"].numel() * dp_world_size
batch = move_to_device(batch, self.trainer.device)
with torch.no_grad():
ce_loss, _, logits = self.trainer._model_forward(
batch, loss_reduction="none", compute_z_loss=False
)
logits, ce_loss, _ = self.trainer.eval_batch(
batch, loss_reduction="none", compute_z_loss=False
)
evaluator.update_metrics(batch, ce_loss, logits)

if eval_step % self.trainer.cancel_check_interval == 0:
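
Note the change in return order at this call site: the old private `Trainer._model_forward()` returned `(ce_loss, z_loss, logits)`, while the new public `Trainer.eval_batch()` returns `(logits, ce_loss, z_loss)` and handles device movement and `torch.no_grad()` itself. A minimal sketch of the new call, assuming `trainer`, `evaluator`, and a collated `batch` are in scope:

```python
# eval_batch() moves the batch to the trainer's device, runs each callback's
# pre_eval_batch(), and evaluates the forward pass under torch.no_grad().
logits, ce_loss, _ = trainer.eval_batch(batch, loss_reduction="none", compute_z_loss=False)
# With loss_reduction="none" the CE loss comes back per token, reshaped to
# (batch_size, seq_len - 1).
evaluator.update_metrics(batch, ce_loss, logits)
```
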
24 changes: 16 additions & 8 deletions src/olmo_core/train/callbacks/moe_handler.py
@@ -22,6 +22,10 @@ class MoEHandlerCallback(Callback):
_batch_z_loss = None
_moe_layer = None

def clear_loss_buffers(self):
assert self._moe_layer is not None
self._moe_layer.get_loss()

def pre_train(self):
for module in self.trainer.model.modules():
if isinstance(module, MoE):
@@ -32,30 +36,34 @@ def pre_train(self):
f"No MoE layer found in model, required by {self.__class__.__name__}"
)

def post_model_forward(
def pre_step(self, batch: Dict[str, Any]):
del batch
self.clear_loss_buffers()

def post_eval_batch(self):
self.clear_loss_buffers()

def pre_backward(
self,
*,
batch: Dict[str, Any],
micro_batch: Dict[str, Any],
num_micro_batches: int,
batch_num_tokens_for_loss: torch.Tensor,
loss: torch.Tensor,
ce_loss: torch.Tensor,
z_loss: Optional[torch.Tensor] = None,
):
del batch, micro_batch, batch_num_tokens_for_loss, ce_loss, z_loss
assert self._moe_layer is not None

scale_factor = micro_batch["input_ids"].shape[0] / batch["input_ids"].shape[0]

moe_loss: Optional[torch.Tensor] = None
if (lb_loss := self._moe_layer.get_load_balancing_loss()) is not None:
lb_loss.div_(num_micro_batches)
lb_loss.mul_(scale_factor)
moe_loss = lb_loss
if self._batch_lb_loss is None:
self._batch_lb_loss = move_to_device(torch.tensor(0.0), lb_loss.device)
self._batch_lb_loss += get_local_tensor(lb_loss)

if (rz_loss := self._moe_layer.get_router_z_loss()) is not None:
rz_loss.div_(num_micro_batches)
rz_loss.mul_(scale_factor)
if moe_loss is not None:
moe_loss += rz_loss
else:
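
The substantive change above is how the auxiliary MoE losses are weighted: rather than dividing each micro-batch's load-balancing and router Z losses by `num_micro_batches`, they are now multiplied by that micro-batch's share of the full batch (`micro_batch["input_ids"].shape[0] / batch["input_ids"].shape[0]`). Purely illustrative arithmetic, with hypothetical numbers not taken from this commit:

```python
# With equally sized micro-batches the two schemes agree; with uneven micro-batches
# the new scheme weights each one by its actual share of the full batch.
full_batch_instances = 48
even_split = [16, 16, 16]   # old: 1/3 each; new: 16/48 == 1/3 each (identical)
uneven_split = [24, 16, 8]  # old: 1/3 each; new: 1/2, 1/3, 1/6 (proportional)

for size in uneven_split:
    scale_factor = size / full_batch_instances  # the micro-batch's share of the full batch
    print(f"micro-batch of {size} instances -> auxiliary-loss weight {scale_factor:.3f}")
```
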
195 changes: 119 additions & 76 deletions src/olmo_core/train/trainer.py
@@ -836,6 +836,98 @@ def add_callback(self, name: str, callback: Callback):
self._sort_callbacks()
callback.post_attach()

def model_forward(self, micro_batch: Dict[str, Any]) -> torch.Tensor:
"""
Run a forward pass on a micro-batch, returning the logits.
"""
with self._model_forward_context():
# shape: (batch_size, seq_len, vocab_size)
logits = self.model(
input_ids=micro_batch["input_ids"],
# attention_mask=micro_batch.get("attention_mask"),
# attention_bias=micro_batch.get("attention_bias"),
doc_lens=micro_batch.get("doc_lens"),
max_doc_lens=micro_batch.get("max_doc_lens"),
)
return logits

def get_losses(
self,
micro_batch: Dict[str, Any],
logits: torch.Tensor,
loss_reduction: Literal["mean", "sum", "none"] = "mean",
compute_z_loss: Optional[bool] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Compute the cross-entropy loss and optionally the Z-loss from a micro-batch and the
corresponding logits returned from :meth:`model_forward()`.
:param micro_batch: The micro-batch to evaluate.
:param logits: The logits from the forward pass.
:param loss_reduction: The (local) reduction to apply to the loss(es).
:param compute_z_loss: Whether or not to compute and return the Z-loss.
:returns: The cross entropy and optional Z-loss, respectively.
"""
loss_fn = cross_entropy_loss if not self.fused_loss else fused_cross_entropy_loss
if compute_z_loss is None:
compute_z_loss = self.z_loss_multiplier is not None

# shape: (batch_size, seq_len - 1, vocab_size)
logits_for_loss = logits[..., :-1, :].contiguous()
# shape: (batch_size * (seq_len - 1), vocab_size)
logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))

# shape: (batch_size, seq_len - 1)
labels = micro_batch.get("labels", self._get_labels(micro_batch))
# shape: (batch_size * (seq_len - 1),)
labels = labels.view(-1)

ce_loss, z_loss = loss_fn(
logits_for_loss,
labels,
ignore_index=self.data_loader.collator.label_ignore_index,
reduction=loss_reduction,
compute_z_loss=compute_z_loss,
z_loss_multiplier=self.z_loss_multiplier or 1e-4,
)

if loss_reduction == "none":
# Reshape (batch_size * (seq_len - 1),) -> (batch_size, seq_len - 1)
ce_loss = ce_loss.view(micro_batch["input_ids"].shape[0], -1)
if z_loss is not None:
z_loss = z_loss.view(micro_batch["input_ids"].shape[0], -1)

return ce_loss, z_loss

def eval_batch(
self,
batch: Dict[str, Any],
loss_reduction: Literal["mean", "sum", "none"] = "mean",
compute_z_loss: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Get the loss for an eval batch.
.. important::
You are responsible for ensuring the model is in ``.eval()`` mode before calling this.
:param batch: The batch to evaluate.
:param loss_reduction: The (local) reduction to apply to the loss(es).
:param compute_z_loss: Whether or not to compute and return the Z-loss.
:returns: The logits, cross-entropy loss, and Z-loss, respectively.
"""
batch = move_to_device(batch, self.device)
for callback in self.callbacks.values():
callback.pre_eval_batch(batch)
with torch.no_grad():
logits = self.model_forward(batch)
ce_loss, z_loss = self.get_losses(
batch, logits, loss_reduction=loss_reduction, compute_z_loss=compute_z_loss
)
return logits, ce_loss, z_loss

def _sort_callbacks(self):
self.callbacks = OrderedDict(
(
@@ -955,69 +1047,13 @@ def _model_forward(
loss_reduction: Literal["mean", "sum", "none"] = "mean",
compute_z_loss: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
with self._model_forward_context():
# shape: (batch_size, seq_len, vocab_size)
logits = self.model(
input_ids=batch["input_ids"],
# attention_mask=batch.get("attention_mask"),
# attention_bias=batch.get("attention_bias"),
doc_lens=batch.get("doc_lens"),
max_doc_lens=batch.get("max_doc_lens"),
)

# shape: (batch_size, seq_len - 1, vocab_size)
logits_for_loss = logits[..., :-1, :].contiguous()
# shape: (batch_size * (seq_len - 1), vocab_size)
logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
# shape: (batch_size, seq_len - 1)
labels = batch.get("labels", self._get_labels(batch))
# shape: (batch_size * (seq_len - 1),)
labels = labels.view(-1)

loss_fn = cross_entropy_loss if not self.fused_loss else fused_cross_entropy_loss
ce_loss, z_loss = loss_fn(
logits_for_loss,
labels,
ignore_index=self.data_loader.collator.label_ignore_index,
reduction=loss_reduction,
compute_z_loss=compute_z_loss,
z_loss_multiplier=self.z_loss_multiplier or 1e-4,
# NOTE: keep this method for backwards compatibility.
logits = self.model_forward(batch)
ce_loss, z_loss = self.get_losses(
batch, logits, loss_reduction=loss_reduction, compute_z_loss=compute_z_loss
)

if loss_reduction == "none":
# Reshape (batch_size * (seq_len - 1),) -> (batch_size, seq_len - 1)
ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
if z_loss is not None:
z_loss = z_loss.view(batch["input_ids"].shape[0], -1)

return ce_loss, z_loss, logits

def _get_microbatch_loss(
self, micro_batch: Dict[str, Any], batch_num_tokens_for_loss: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
# NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss'
# (the total number of tokens used in the loss across the whole batch, not just the micro batch)
# to avoid biasing the loss in the case where micro-batches might not be the same size.
ce_loss, z_loss, logits = self._model_forward(
micro_batch, compute_z_loss=self.z_loss_multiplier is not None, loss_reduction="sum"
)
ce_loss = ce_loss / batch_num_tokens_for_loss

# In case this helps with memory utilization.
del micro_batch

# Get loss to optimize for.
if self.z_loss_multiplier is not None:
assert z_loss is not None
z_loss = z_loss / batch_num_tokens_for_loss
loss = ce_loss + z_loss
else:
loss = ce_loss

del logits

return loss, ce_loss, z_loss

@contextlib.contextmanager
def _train_microbatch_context(
self, micro_batch_idx: int, num_micro_batches: int
@@ -1065,20 +1101,22 @@ def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False):
for micro_batch_idx, micro_batch in enumerate(micro_batches):
with self._train_microbatch_context(micro_batch_idx, num_micro_batches):
# Run forward pass.
loss, ce_loss, z_loss = self._get_microbatch_loss(
micro_batch, batch_num_tokens_for_loss
)
logits = self.model_forward(micro_batch)

for callback in self.callbacks.values():
callback.post_model_forward(
batch=batch,
micro_batch=micro_batch,
num_micro_batches=num_micro_batches,
batch_num_tokens_for_loss=batch_num_tokens_for_loss,
loss=loss,
ce_loss=ce_loss,
z_loss=z_loss,
)
# NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss'
# (the total number of tokens used in the loss across the whole batch, not just the micro batch)
# to avoid biasing the loss in the case where micro-batches might not be the same size.
ce_loss, z_loss = self.get_losses(micro_batch, logits, loss_reduction="sum")
ce_loss.div_(batch_num_tokens_for_loss)
if z_loss is not None:
z_loss.div_(batch_num_tokens_for_loss)

# Get loss to optimize for.
loss: torch.Tensor
if z_loss is not None:
loss = ce_loss + z_loss
else:
loss = ce_loss

# Update overall CE batch loss.
ce_batch_loss += get_local_tensor(ce_loss.detach())
Expand All @@ -1088,6 +1126,10 @@ def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False):
assert z_batch_loss is not None
z_batch_loss += get_local_tensor(z_loss.detach())

# Run through callbacks.
for callback in self.callbacks.values():
callback.pre_backward(batch=batch, micro_batch=micro_batch, loss=loss)

# Run backward pass.
loss.backward()

@@ -1115,6 +1157,10 @@ def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False):
if isinstance(self.optim, SkipStepOptimizer):
self.record_metric(OPTIM_STEP_SKIPPED_METRIC, self.optim.step_skipped)

# Run through callbacks.
for callback in self.callbacks.values():
callback.post_train_batch()

def _iter_batches(self) -> Generator[Dict[str, Any], None, None]:
data_iterator = iter(self.data_loader)

@@ -1175,9 +1221,6 @@ def _fit_epoch(self):

self._train_batch(batch)

for callback in self.callbacks.values():
callback.post_train_batch()

for callback in self.callbacks.values():
callback.post_step()

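
Taken together, the trainer changes split the old private `_model_forward()` / `_get_microbatch_loss()` pair into the public `model_forward()`, `get_losses()`, and `eval_batch()` methods, and move the per-micro-batch loss assembly inline into `_train_batch()`. A minimal sketch of reproducing that training loss with the public API only; the names `trainer`, `micro_batch`, and `batch_num_tokens_for_loss` are assumed to already be in scope:

```python
# Forward pass on one micro-batch, then losses with "sum" reduction, divided by the
# token count of the *full* batch so unevenly sized micro-batches do not bias the loss.
logits = trainer.model_forward(micro_batch)
ce_loss, z_loss = trainer.get_losses(micro_batch, logits, loss_reduction="sum")
ce_loss = ce_loss / batch_num_tokens_for_loss

if z_loss is not None:
    z_loss = z_loss / batch_num_tokens_for_loss
    loss = ce_loss + z_loss
else:
    loss = ce_loss

loss.backward()
```
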
