Fix grad accum when using loss_mask #842

Open
wants to merge 5 commits into main
Changes from 2 commits
7 changes: 7 additions & 0 deletions src/levanter/grad_accum.py
@@ -1,3 +1,4 @@
import abc
import enum
import functools
from typing import Callable, Optional, ParamSpec, TypeVar
@@ -20,6 +21,12 @@
R = TypeVar("R")


class NumElementsBatch(abc.ABC):
@abc.abstractmethod
def num_elements(self) -> int:
pass


class ReductionType(enum.Enum):
SUM = enum.auto()
MEAN = enum.auto()
8 changes: 7 additions & 1 deletion src/levanter/models/lm_model.py
@@ -11,6 +11,7 @@
import haliax as hax
from haliax import Axis, NamedArray, NamedOrNumeric

from levanter.grad_accum import NumElementsBatch
from levanter.models.attention import AttentionMask
from levanter.models.loss import maybe_fused_next_token_loss

@@ -19,7 +20,7 @@
LmT = TypeVar("LmT", bound="LmHeadModel")


class LmExample(eqx.Module):
class LmExample(eqx.Module, NumElementsBatch):
tokens: hax.NamedArray
loss_mask: hax.NamedArray
attn_mask: AttentionMask | NamedArray = AttentionMask.causal()
@@ -88,6 +89,9 @@ def from_prompt_and_completion(

return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask)

def num_elements(self):
return self.loss_mask.sum()


# TODO: for some reason, mypy doesn't like the discover_packages_path argument?
@dataclass(frozen=True)
@@ -221,6 +225,7 @@ def compute_next_token_loss(
key=None,
reduction: Optional[hax.ReductionFunction] = hax.mean,
reduction_axis: Optional[hax.AxisSelection] = None,
batch_num_elements: Optional[int] = None,
logsumexp_weight: Optional[float] = None,
loss_dtype: Optional[Type[jnp.dtype]] = jnp.float32,
) -> jnp.ndarray | NamedArray:
@@ -241,6 +246,7 @@
loss_mask=example.loss_mask,
reduction=reduction,
reduction_axis=reduction_axis,
batch_num_elements=batch_num_elements,
logsumexp_weight=logsumexp_weight,
dtype=loss_dtype,
block_size=model.config.cross_entropy_block_size,
61 changes: 37 additions & 24 deletions src/levanter/models/loss.py
@@ -1,4 +1,5 @@
import functools
import logging
from typing import Optional

import equinox
@@ -10,6 +11,9 @@
from haliax.nn import cross_entropy_loss_and_log_normalizers


logger = logging.getLogger(__name__)


def maybe_fused_next_token_loss(
Pos: hax.AxisSelector,
Embed: hax.AxisSelector,
Expand All @@ -20,6 +24,7 @@ def maybe_fused_next_token_loss(
loss_mask: Optional[NamedArray] = None,
reduction: Optional[hax.ReductionFunction] = hax.mean,
reduction_axis: Optional[hax.AxisSelection] = None,
batch_num_elements: Optional[int] = None,
dlwh (Member) commented:

imho i think the thing to do is to use reduction=None inside grad accum and do the reduction separately. Also, can't you infer batch_num_elements from loss_mask.sum?

Aphoh (Contributor Author) replied:

@dlwh We can't get that from loss_mask.sum, since the loss_mask we get is for the microbatch and not the whole batch. Instead of computing a mean loss per microbatch and averaging those mean losses, we need to sum the losses over all microbatches and divide by the true number of losses in the batch.
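To illustrate the reply above, here is a small self-contained sketch (plain Python, with made-up loss values) of why averaging per-microbatch means differs from the true batch mean when the loss mask leaves different numbers of tokens in each microbatch:

```python
# Hypothetical numbers, for illustration only: two microbatches with different
# numbers of unmasked (loss-contributing) tokens.
mb1_losses = [2.0, 2.0, 2.0]  # 3 unmasked tokens
mb2_losses = [4.0]            # 1 unmasked token

# Averaging per-microbatch means weights the lone token in mb2 too heavily.
mean_of_means = (sum(mb1_losses) / len(mb1_losses) + sum(mb2_losses) / len(mb2_losses)) / 2

# Summing every loss and dividing by the total count (batch_num_elements) gives the true mean.
batch_num_elements = len(mb1_losses) + len(mb2_losses)
true_batch_mean = (sum(mb1_losses) + sum(mb2_losses)) / batch_num_elements

print(mean_of_means)    # 3.0
print(true_batch_mean)  # 2.5
```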

logsumexp_weight: Optional[float] = None,
block_size: Optional[int] = None,
dtype: Optional[jnp.dtype] = jnp.float32,
@@ -36,6 +41,7 @@
loss_mask (Optional[NamedArray]): Mask to apply to the loss.
reduction (Optional[hax.ReductionFunction]): Reduction function.
reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction.
batch_num_elements (Optional[int]): The total number of loss elements in the full batch. When passed, the loss is divided by this value rather than averaged over the microbatch.
logsumexp_weight (Optional[float]): Weight for logsumexp penalty.
block_size (Optional[int]): Size of each block for processing.

@@ -45,6 +51,9 @@
# Resolve axes
Pos = pred_embeddings.resolve_axis(Pos)
Vocab = pred_lm_head.resolve_axis(Vocab)
if batch_num_elements is not None:
if reduction is not hax.sum:
logger.warning("batch_num_elements given when reduction is not hax.sum, make sure this is intended")

if block_size is None:
# Full softmax computation
@@ -53,32 +62,36 @@
logits = logits.astype(dtype)

# Shift target tokens to predict the next token
return next_token_loss(Pos, Vocab, logits, true_ids, loss_mask, reduction, reduction_axis, logsumexp_weight)

# Shift target tokens to predict the next token
target_y = hax.roll(true_ids, -1, Pos)

# Create a mask that excludes the last token
not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) # type: ignore
if loss_mask is not None:
loss_mask = loss_mask * not_last_loss_mask
loss = next_token_loss(Pos, Vocab, logits, true_ids, loss_mask, reduction, reduction_axis, logsumexp_weight)
else:
loss_mask = not_last_loss_mask
# Shift target tokens to predict the next token
target_y = hax.roll(true_ids, -1, Pos)

# Compute the loss with optional block-wise processing
return fused_cross_entropy_loss_and_logsumexp_penalty(
pred_embeddings,
pred_lm_head,
Contract=Embed,
Label=Vocab,
target_y=target_y,
reduction=reduction,
reduction_axis=reduction_axis,
where=loss_mask,
logsumexp_weight=logsumexp_weight,
block_size=block_size,
dtype=dtype,
)
# Create a mask that excludes the last token
not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) # type: ignore
if loss_mask is not None:
loss_mask = loss_mask * not_last_loss_mask
else:
loss_mask = not_last_loss_mask

# Compute the loss with optional block-wise processing
loss = fused_cross_entropy_loss_and_logsumexp_penalty(
pred_embeddings,
pred_lm_head,
Contract=Embed,
Label=Vocab,
target_y=target_y,
reduction=reduction,
reduction_axis=reduction_axis,
where=loss_mask,
logsumexp_weight=logsumexp_weight,
block_size=block_size,
dtype=dtype,
)

if batch_num_elements is not None:
return loss / batch_num_elements
return loss


def next_token_loss(
20 changes: 15 additions & 5 deletions src/levanter/trainer.py
@@ -37,7 +37,7 @@
from levanter.config import JsonAtom
from levanter.data import AsyncDataset, DataLoader
from levanter.distributed import DistributedConfig, RayConfig
from levanter.grad_accum import microbatched
from levanter.grad_accum import NumElementsBatch, ReductionType, microbatched
from levanter.tracker import TrackerConfig, capture_time
from levanter.trainer_state import TrainerState, saveable_training_mask
from levanter.utils import cloud_utils, fsspec_utils
@@ -380,7 +380,7 @@ def checkpoint_path(self) -> str:
checkpoint_path = self.config.checkpointer.expanded_path(self.run_id)
return checkpoint_path

def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
def train_step(self, state: S, batch: X, **batch_kwargs) -> StepInfo[S]:
"""
Performs a single training step.
"""
@@ -529,7 +529,7 @@ def _train_step(
key, new_key = jax.random.split(state.training_key)
model = inference_mode(state.model, False)

loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, batch, **batch_kwargs, key=key)

with hax.axis_mapping(self.parameter_axis_mapping):
if not _no_hooks:
@@ -549,18 +549,28 @@ def obj_fun(trainable_model):
else:
return loss, new_state, hook_infos

def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwargs) -> tuple[Scalar, M]:
def _compute_gradients_microbatched(self, loss_fn, model: M, batch: X, **batch_kwargs) -> tuple[Scalar, M]:
grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
mbs = self.config.microbatch_size
reduce = ReductionType.MEAN
if isinstance(batch, NumElementsBatch) and mbs != self.TrainBatch.size:
batch_kwargs[
"batch_num_elements"
] = batch.num_elements() # tell the loss function how many elements are in the batch
batch_kwargs[
"reduction"
] = hax.sum # the loss fn should sum the loss and divide by the number of elements, not average
reduce = ReductionType.SUM # we're already normalizing the loss
grad_fn = microbatched(
dlwh (Member) commented:

let's have the grad_fn return both the total (masked) loss and the number of elements (and the gradient?). WDYT?

Aphoh (Contributor Author) replied:

So we'd change the signature of ComputeLossFunction to return something like (model, input) -> (loss, extras), where extras has the number of elements in the microbatch in it?

I also find myself wanting to log things this way (by returning elements in an extras struct), so this would be a nice way to add that in.

The only concern is that if the number of microbatch elements or the losses are particularly large, we could lose precision in the value of the loss, right? When computing (loss1 + loss2 + ...) / num_elems vs. loss1/num_elems + loss2/num_elems + ..., I would kinda expect the second to be more accurate. Is this a reasonable concern or am I prematurely optimizing?

Aphoh (Contributor Author) commented on Dec 13, 2024:

I think this might be a valid concern, given we'd be summing up fp16 gradients. This is what happens if you sum up a bunch of values near ln(vocab_size), which is around the maximum value you'd get from the loss:

[image: plot of accumulation error when summing values near ln(vocab_size) in fp16]
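To make the precision concern concrete, here is a quick NumPy sketch (illustrative only, not part of the PR) comparing an fp16 running sum against an fp16 running mean for values near ln(vocab_size) with a ~50k vocabulary:

```python
import numpy as np

# Accumulate ~10.8 (≈ ln(50_000), roughly the maximum per-token loss for a ~50k
# vocab) 4096 times, keeping the accumulators in float16 the whole way.
vals = np.full(4096, np.log(50_000.0), dtype=np.float16)

running_sum = np.float16(0.0)
running_mean = np.float16(0.0)
for i, v in enumerate(vals, start=1):
    running_sum = np.float16(running_sum + v)  # large sum -> coarse fp16 rounding steps
    running_mean = np.float16(running_mean + (v - running_mean) / np.float16(i))  # stays ~10.8

print(float(running_sum) / len(vals))  # ends up far below 10.8 once the sum grows large
print(float(running_mean))             # ≈ 10.8
```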

dlwh (Member) replied:

there's actually a standard way of dealing with this, which I've encapsulated as RunningMean. The basic idea is to maintain the mean and the count rather than the sum and the count. This solves most precision problems usually. (You can actually do variance too, but we don't need it.)
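For reference, a minimal sketch of the running-mean idea described above (this is not Levanter's actual RunningMean class, just the update rule it alludes to):

```python
class RunningMeanSketch:
    """Keeps (mean, count) instead of (sum, count), so the accumulator stays at the
    magnitude of a single loss value and avoids large-sum rounding error."""

    def __init__(self) -> None:
        self.mean = 0.0
        self.count = 0

    def update(self, other_mean: float, other_count: int) -> None:
        total = self.count + other_count
        if total == 0:
            return
        # Weighted combination of the two means; no large intermediate sum is ever formed.
        self.mean += (other_mean - self.mean) * (other_count / total)
        self.count = total
```

Each microbatch would contribute its masked mean loss and its element count, and the combined mean equals the true batch mean without ever forming the large sum.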

Aphoh (Contributor Author) replied:

@dlwh I tried to encapsulate this in this last commit. I'm pretty sure it doesn't work right, since it doesn't pass the grad_accum tests (trees don't match). Could you give it a quick peek and let me know if you see anything?

grad_fn,
self.TrainBatch,
mbs,
self.parameter_axis_mapping,
self.compute_axis_mapping,
reduce=reduce,
)
with hax.axis_mapping(self.compute_axis_mapping):
return grad_fn(model, *batch, **batch_kwargs)
return grad_fn(model, batch, **batch_kwargs)
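Putting the pieces together, here is a simplified sketch (hypothetical helper names, not Levanter's API) of the accumulation scheme this change sets up in `_compute_gradients_microbatched`: each microbatch loss is summed and pre-divided by the full batch's element count, so summing across microbatches yields the true batch mean.

```python
# `masked_losses(mb)` is a stand-in for the real loss function: it returns the list
# of unmasked per-token losses for one microbatch.
def accumulated_batch_loss(microbatches, masked_losses):
    # Computed once from the whole batch's loss mask, as NumElementsBatch.num_elements() provides.
    batch_num_elements = sum(len(masked_losses(mb)) for mb in microbatches)

    total = 0.0
    for mb in microbatches:
        # reduction=hax.sum + batch_num_elements: each microbatch contributes
        # sum(its masked losses) / batch_num_elements ...
        total += sum(masked_losses(mb)) / batch_num_elements
    # ... and ReductionType.SUM in `microbatched` just adds the contributions,
    # giving (sum of all masked losses) / batch_num_elements, i.e. the true batch mean.
    return total
```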


def _initialize_global_tracker(config, run_id):
3 changes: 2 additions & 1 deletion src/levanter/utils/types.py
@@ -51,9 +51,10 @@ class ComputeLossFunction(Protocol[M_con, X]):
def __call__(
self,
model: M_con,
*inputs: X,
input: X,
reduction: Optional[hax.ReductionFunction] = hax.mean,
reduction_axis: Optional[hax.AxisSelection] = None,
batch_num_elements: Optional[int] = None,
**kwargs,
) -> Scalar | hax.NamedArray:
...