Fix grad accum when using loss_mask
#842
base: main
Changes from 2 commits
```diff
@@ -37,7 +37,7 @@
 from levanter.config import JsonAtom
 from levanter.data import AsyncDataset, DataLoader
 from levanter.distributed import DistributedConfig, RayConfig
-from levanter.grad_accum import microbatched
+from levanter.grad_accum import NumElementsBatch, ReductionType, microbatched
 from levanter.tracker import TrackerConfig, capture_time
 from levanter.trainer_state import TrainerState, saveable_training_mask
 from levanter.utils import cloud_utils, fsspec_utils
```
```diff
@@ -380,7 +380,7 @@ def checkpoint_path(self) -> str:
         checkpoint_path = self.config.checkpointer.expanded_path(self.run_id)
         return checkpoint_path

-    def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
+    def train_step(self, state: S, batch: X, **batch_kwargs) -> StepInfo[S]:
         """
         Performs a single training step.
         """
```
```diff
@@ -529,7 +529,7 @@ def _train_step(
         key, new_key = jax.random.split(state.training_key)
         model = inference_mode(state.model, False)

-        loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
+        loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, batch, **batch_kwargs, key=key)

         with hax.axis_mapping(self.parameter_axis_mapping):
             if not _no_hooks:
```
```diff
@@ -549,18 +549,28 @@ def obj_fun(trainable_model):
         else:
             return loss, new_state, hook_infos

-    def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwargs) -> tuple[Scalar, M]:
+    def _compute_gradients_microbatched(self, loss_fn, model: M, batch: X, **batch_kwargs) -> tuple[Scalar, M]:
         grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
         mbs = self.config.microbatch_size
+        reduce = ReductionType.MEAN
+        if isinstance(batch, NumElementsBatch) and mbs != self.TrainBatch.size:
+            # tell the loss function how many elements are in the batch
+            batch_kwargs["batch_num_elements"] = batch.num_elements()
+            # the loss fn should sum the loss and divide by the number of elements, not average
+            batch_kwargs["reduction"] = hax.sum
+            reduce = ReductionType.SUM  # we're already normalizing the loss
         grad_fn = microbatched(
             grad_fn,
             self.TrainBatch,
             mbs,
             self.parameter_axis_mapping,
             self.compute_axis_mapping,
+            reduce=reduce,
         )
         with hax.axis_mapping(self.compute_axis_mapping):
-            return grad_fn(model, *batch, **batch_kwargs)
+            return grad_fn(model, batch, **batch_kwargs)


 def _initialize_global_tracker(config, run_id):
```
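The contract this hunk sets up, sketched with an invented `toy_loss_fn` (this is not levanter's actual loss function; the real change also passes `reduction=hax.sum` alongside the count):

```python
# Toy version of the contract in the hunk above (invented helper, not
# levanter's real loss function). With microbatching, the trainer passes the
# FULL batch's element count, so each microbatch returns
# sum(masked loss) / batch_num_elements, and microbatched() then sums
# (ReductionType.SUM) the per-microbatch values instead of averaging them.
import numpy as np

def toy_loss_fn(per_token_loss, loss_mask, batch_num_elements=None):
    masked = per_token_loss * loss_mask
    if batch_num_elements is not None:
        # microbatch path: sum, normalized by the whole batch's element count
        return masked.sum() / batch_num_elements
    # full-batch path: ordinary masked mean
    return masked.sum() / loss_mask.sum()

losses = np.array([1.0, 3.0, 5.0, 7.0])
mask = np.array([1.0, 1.0, 1.0, 0.0])

full = toy_loss_fn(losses, mask)  # 3.0, the true masked mean
parts = sum(
    toy_loss_fn(losses[i : i + 2], mask[i : i + 2], batch_num_elements=mask.sum())
    for i in (0, 2)
)
print(full, parts)  # 3.0 3.0 -- summing microbatch values recovers the batch mean
```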
Review thread on the `grad_fn = microbatched(` line:

**dlwh:** let's have the grad_fn return both the total (masked) loss and the number of elements (and the gradient?). WDYT?

**Author:** So we'd change the signature of [...]? I also find myself wanting to log things this way (by returning elements in an [...]). The only concern is that if the number of microbatch elements or the losses are particularly large, we could lose precision in the value of the loss, right?

**dlwh:** there's actually a standard way of dealing with this, which I've encapsulated as RunningMean. The basic idea is to maintain the mean and the count rather than the sum and the count. This usually solves most precision problems. (You can actually do variance too, but we don't need it.)

**Author:** @dlwh I tried to encapsulate this in this last commit. I'm pretty sure it doesn't work right, since it doesn't pass the [...]
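To make the mean-and-count trick concrete, here is a minimal sketch (the name `RunningMean` comes from the comment above; this exact API is an assumption, not levanter's implementation):

```python
# Track (mean, count) instead of (sum, count), so intermediate values stay on
# the scale of a single mean and large sums never accumulate.
from dataclasses import dataclass

@dataclass(frozen=True)
class RunningMean:
    mean: float = 0.0
    count: float = 0.0

    def add(self, other_mean: float, other_count: float) -> "RunningMean":
        total = self.count + other_count
        if total == 0.0:
            return self
        # nudge the mean toward the new value by its share of the total count
        new_mean = self.mean + (other_mean - self.mean) * (other_count / total)
        return RunningMean(new_mean, total)

# accumulate per-microbatch (masked mean, element count) pairs
acc = RunningMean()
for mb_mean, mb_count in [(2.0, 2), (5.0, 1)]:
    acc = acc.add(mb_mean, mb_count)
print(acc.mean)  # 3.0 == (2.0 * 2 + 5.0 * 1) / 3
```

Merging means this way keeps every intermediate on the scale of a single loss value, which is why it sidesteps the precision concern about large sums.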
A separate review thread:

**dlwh:** imho i think the thing to do is to use `reduction=None` inside grad accum and do the reduction separately. Also, can't you infer `batch_num_elements` from `loss_mask.sum`?

**Author:** @dlwh We can't get that from `loss_mask.sum`, since the `loss_mask` we get is for the microbatch and not the whole batch. Instead of computing a mean loss per microbatch and averaging the mean losses, we need to sum the loss over all the microbatches and divide by the true number of loss elements in the batch.
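A quick numerical illustration of that point, with made-up numbers (plain NumPy, not levanter code):

```python
# When microbatches have different numbers of unmasked elements, averaging
# per-microbatch mean losses biases the result; summing and dividing once by
# the batch-wide element count gives the true masked mean.
import numpy as np

losses = np.array([1.0, 3.0, 5.0, 7.0])
mask = np.array([1.0, 1.0, 1.0, 0.0])  # the last token carries no loss
mb = [(losses[:2], mask[:2]), (losses[2:], mask[2:])]

# old behavior: mean within each microbatch, then mean of the means
means = [(l * m).sum() / m.sum() for l, m in mb]
print(np.mean(means))  # 3.5 -- wrong: the second microbatch is over-weighted

# fixed behavior: sum masked losses, divide by the batch-wide element count
total = sum((l * m).sum() for l, m in mb)
print(total / mask.sum())  # 3.0 -- the true masked mean over the batch
```

The second microbatch has only one unmasked element, so averaging per-microbatch means over-weights it; normalizing once by the batch-wide count is what this PR's sum-then-divide change accomplishes.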