Fix grad accum when using loss_mask #842

Open · Aphoh wants to merge 5 commits into main from aphoh/fix-grad-accum

Conversation

@Aphoh Aphoh (Contributor) commented Dec 13, 2024

Implements the fix described here. I don't think this is the cleanest way to do it, but my idea was to make a mixin for batch elements that allows them to specify, when applicable, a num_elements that the microbatch losses should be divided by.
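For illustration, the mixin idea could look roughly like this (hypothetical names, not the actual levanter API):

    # Illustrative sketch only; the real class/field names in this PR may differ.
    from typing import Optional, Protocol


    class HasNumElements(Protocol):
        """Batch elements implementing this can report how many loss-contributing
        elements (e.g. unmasked tokens) they contain, so grad accumulation can sum
        microbatch losses and divide by the true batch-wide count."""

        def num_elements(self) -> Optional[int]:
            ...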

@dlwh dlwh (Member) left a comment


thanks for tackling this! I think we can push more of the logic into grad_accum instead and that will be a bit cleaner. WDYT?

@@ -20,6 +24,7 @@ def maybe_fused_next_token_loss(
     loss_mask: Optional[NamedArray] = None,
     reduction: Optional[hax.ReductionFunction] = hax.mean,
     reduction_axis: Optional[hax.AxisSelection] = None,
+    batch_num_elements: Optional[int] = None,
@dlwh (Member):

imho i think the thing to do is to use reduction=None inside grad accum and do the reduction separately. Also, can't you infer batch_num_elements from loss_mask.sum?
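For concreteness, a rough sketch of that suggestion (hypothetical helper, not levanter's actual grad-accum code):

    import jax.numpy as jnp


    def masked_mean(per_token_loss: jnp.ndarray, loss_mask: jnp.ndarray) -> jnp.ndarray:
        # the loss fn would return unreduced per-token losses (reduction=None);
        # the caller then does a single masked reduction over the accumulated batch
        return (per_token_loss * loss_mask).sum() / loss_mask.sum()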

@Aphoh (Contributor, Author):

@dlwh We can't get that from loss_mask.sum, since the loss_mask we get is for the microbatch, not the whole batch. Instead of computing a mean loss per microbatch and averaging those means, we need to sum the losses across all microbatches and divide by the true number of loss elements in the batch.
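A tiny numeric illustration (made-up values) of the difference:

    import numpy as np

    # two microbatches with different numbers of unmasked tokens after loss_mask
    mb1 = np.array([1.0, 2.0, 3.0])  # 3 valid tokens, mean = 2.0
    mb2 = np.array([10.0])           # 1 valid token,  mean = 10.0

    mean_of_means = (mb1.mean() + mb2.mean()) / 2                # 6.0 (wrong)
    true_mean = (mb1.sum() + mb2.sum()) / (mb1.size + mb2.size)  # 4.0 (correct)
    print(mean_of_means, true_mean)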

batch_kwargs["reduction"] = hax.sum  # the loss fn should sum the loss and divide by the number of elements, not average
reduce = ReductionType.SUM  # we're already normalizing the loss
grad_fn = microbatched(
@dlwh (Member):

let's have the grad_fn return both the total (masked) loss and the number of elements (and the gradient?). WDYT?

@Aphoh (Contributor, Author):

So we'd change the signature of ComputeLossFunction to return something like
(model, input) -> (loss, extras) where extras has the number of elements in the microbatch in it?
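Roughly this return shape, for illustration (hypothetical helper; only the (loss, extras) part mirrors the proposal, the arguments here are simplified):

    import jax.numpy as jnp


    def compute_loss_with_extras(per_token_loss: jnp.ndarray, loss_mask: jnp.ndarray):
        # sum the masked loss for this microbatch and report how many elements
        # contributed, so the accumulator can divide by the batch-wide total later
        extras = {"num_elements": loss_mask.sum()}
        return (per_token_loss * loss_mask).sum(), extras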

I also find myself wanting to log things this way (by returning elements in an extras struct) so this would be a nice way to add that in.

The only concern is that if the number of microbatch elements or the losses are particularly large, we could lose precision in the value of the loss, right?
When computing (loss1 + loss2 + ...) / num_elems vs. loss1/num_elems + loss2/num_elems + ..., I would kind of expect the second to be more accurate. Is this a reasonable concern, or am I prematurely optimizing?

@Aphoh Aphoh (Contributor, Author) commented Dec 13, 2024:

I think this might be a valid concern, given we'd be summing up fp16 gradients? This is what happens if you sum up a bunch of values near ln(vocab_size), which is around the maximum value you'd expect from the loss:
[figure: accumulated error when naively summing fp16 values near ln(vocab_size)]
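A small standalone repro of the effect (assumed setup; the exact experiment behind the figure may differ):

    import numpy as np

    vocab_size = 50_257                          # e.g. a GPT-2-sized vocabulary
    loss_value = np.float16(np.log(vocab_size))  # ~10.8, near the max expected loss

    acc = np.float16(0.0)
    for _ in range(4096):                        # naive sequential fp16 accumulation
        acc = np.float16(acc + loss_value)

    reference = 4096 * float(np.log(vocab_size))
    print(acc, reference)  # the fp16 sum stalls far below the true sum (~44k)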

@dlwh (Member):

There's actually a standard way of dealing with this, which I've encapsulated as RunningMean. The basic idea is to maintain the mean and the count rather than the sum and the count. This usually solves most precision problems. (You can actually track variance too, but we don't need it here.)
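For context, the core of that trick looks roughly like this (a simplified sketch, not levanter's actual RunningMean implementation):

    from dataclasses import dataclass


    @dataclass
    class RunningMean:
        mean: float = 0.0
        count: float = 0.0

        def add(self, new_mean: float, new_count: float) -> "RunningMean":
            total = self.count + new_count
            if total == 0:
                return self
            # weighted update: fold in the new microbatch mean without ever
            # materializing a large raw sum, which preserves precision
            updated = self.mean + (new_mean - self.mean) * (new_count / total)
            return RunningMean(mean=updated, count=total)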

@Aphoh (Contributor, Author):

@dlwh I tried to encapsulate this in this last commit. I'm pretty sure it doesn't work right since it doesn't pass the grad_accum tests (trees don't match). Could you give it a quick peek and let me know if you see anything?

@Aphoh Aphoh force-pushed the aphoh/fix-grad-accum branch from d37665a to e8a78d3 on December 18, 2024 at 06:19
@Aphoh Aphoh force-pushed the aphoh/fix-grad-accum branch from e8a78d3 to 3e60320 on December 18, 2024 at 06:32
@Aphoh Aphoh (Contributor, Author) commented Dec 18, 2024:

Tests are passing; I think this is ready for a peek @dlwh
