Fix grad accum when using loss_mask
#842
Conversation
thanks for tackling this! I think we can push more of the logic into grad_accum instead and that will be a bit cleaner. WDYT?
src/levanter/models/loss.py (outdated)

```diff
@@ -20,6 +24,7 @@ def maybe_fused_next_token_loss(
     loss_mask: Optional[NamedArray] = None,
     reduction: Optional[hax.ReductionFunction] = hax.mean,
     reduction_axis: Optional[hax.AxisSelection] = None,
+    batch_num_elements: Optional[int] = None,
```
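For orientation, here is a minimal sketch (plain `jax.numpy`, not the actual haliax code in this diff) of what a `batch_num_elements` argument is meant to do: when it is supplied, the microbatch's masked losses are summed and divided by the element count of the *whole* batch, instead of being averaged over the microbatch alone.

```python
import jax.numpy as jnp

def masked_loss(per_token_loss, loss_mask, batch_num_elements=None):
    """Hypothetical helper: per_token_loss and loss_mask have shape (microbatch, seq_len)."""
    masked = per_token_loss * loss_mask
    if batch_num_elements is not None:
        # sum over this microbatch, but normalize by the whole batch's element count
        return masked.sum() / batch_num_elements
    # default behaviour: mean over the elements present in this microbatch
    return masked.sum() / jnp.maximum(loss_mask.sum(), 1)
```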
imho i think the thing to do is to use reduction=None inside grad accum and do the reduction separately. Also, can't you infer batch_num_elements from loss_mask.sum?
@dlwh We can't get that from `loss_mask.sum()`, since the `loss_mask` we get is for the microbatch, not the whole batch. Instead of computing a mean loss per microbatch and averaging those means, we need to sum the losses over all microbatches and divide by the true number of loss elements in the batch.
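To make the discrepancy concrete, a hypothetical two-microbatch example (numbers made up):

```python
import numpy as np

loss1, mask1 = np.array([1.0, 1.0, 1.0, 1.0]), np.array([1.0, 1.0, 1.0, 1.0])  # 4 unmasked tokens
loss2, mask2 = np.array([3.0, 9.0, 9.0, 9.0]), np.array([1.0, 0.0, 0.0, 0.0])  # 1 unmasked token

# averaging per-microbatch means: (1.0 + 3.0) / 2 = 2.0
mean_of_means = ((loss1 * mask1).sum() / mask1.sum() + (loss2 * mask2).sum() / mask2.sum()) / 2

# true masked mean over the whole batch: (4.0 + 3.0) / 5 = 1.4
true_mean = ((loss1 * mask1).sum() + (loss2 * mask2).sum()) / (mask1.sum() + mask2.sum())
```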
```python
batch_kwargs["reduction"] = hax.sum  # the loss fn should sum the loss and divide by the number of elements, not average
reduce = ReductionType.SUM  # we're already normalizing the loss
grad_fn = microbatched(
```
let's have the grad_fn return both the total (masked) loss and the number of elements (and the gradient?). WDYT?
So we'd change the signature of `ComputeLossFunction` to return something like `(model, input) -> (loss, extras)`, where `extras` contains the number of elements in the microbatch?

I also find myself wanting to log things this way (by returning values in an `extras` struct), so this would be a nice way to add that in.

The only concern is that if the number of microbatch elements or the losses are particularly large, we could lose precision in the value of the loss, right? When computing `(loss1 + loss2 + ...) / num_elems` vs. `loss1/num_elems + loss2/num_elems + ...`, I would kinda expect the second to be more accurate. Is this a reasonable concern, or am I prematurely optimizing?
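A rough sketch of what that could look like (the names here are illustrative, not the real `ComputeLossFunction` API): each microbatch call returns the summed masked loss plus an `extras` dict carrying the element count, and the caller divides once at the end.

```python
def compute_loss(model, batch, loss_mask):
    # hypothetical per-microbatch loss fn; assumes some per_token_loss(model, batch) helper
    per_token = per_token_loss(model, batch)
    summed = (per_token * loss_mask).sum()
    extras = {"num_elements": loss_mask.sum()}
    return summed, extras

def batch_loss(microbatch_results):
    # microbatch_results: list of (summed_loss, extras) pairs from the microbatch loop
    total = sum(loss for loss, _ in microbatch_results)
    num_elements = sum(extras["num_elements"] for _, extras in microbatch_results)
    return total / num_elements
```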
there's actually a standard way of dealing with this which I've encapsulated as RunningMean. The basic idea is to maintain the mean and the count rather than the sum and the count. This solves most precision problems usually. (You can actually do variance too but we don't need it)
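For reference, the merge rule behind that idea is just a weighted combination of (mean, count) pairs; a minimal standalone version (not necessarily what levanter's `RunningMean` does exactly) might look like:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class RunningMeanSketch:
    mean: float = 0.0
    count: float = 0.0

    def add(self, other_mean: float, other_count: float) -> "RunningMeanSketch":
        total = self.count + other_count
        if total == 0:
            return self
        # weighted combination: values stay near the scale of the mean
        # instead of accumulating an ever-growing sum
        new_mean = self.mean + (other_mean - self.mean) * (other_count / total)
        return RunningMeanSketch(new_mean, total)
```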
@dlwh I tried to encapsulate this in this last commit. I'm pretty sure it doesn't work right since it doesn't pass the `grad_accum` tests (trees don't match). Could you give it a quick peek and let me know if you see anything?
(branch force-pushed from d37665a to e8a78d3, then from e8a78d3 to 3e60320)
Passing tests, I think this is ready for a peek @dlwh
Implement the fix described here. I don't think this is the cleanest way to do it, but my idea was to make a mixin for batch elements that allows them to specify when they have a specific `num_elements` that the microbatch losses should be divided by.