
[float8nocompile]: Wrap fp8 conversion kernels in autograd func and use in Float8NoCompileLinear #1452

Merged: 2 commits merged into pytorch:main on Dec 20, 2024

Conversation

@danielvegamyhre (Contributor) commented on Dec 19, 2024

Summary

  • Wrap the float8nocompile conversion Triton kernels in an autograd function to make the conversion differentiable (for use in training); a sketch of the pattern follows this list
  • Integrate these autograd functions into the Float8NoCompileLinear class
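Below is a minimal, self-contained sketch of that pattern, using plain PyTorch casts in place of the prototype's Triton kernels. The class names and the dequantize-back round trip are assumptions for illustration, not the PR's actual implementation (which hands scaled fp8 data to a float8 tensor subclass and a scaled matmul):

import torch
import torch.nn.functional as F

class ToFP8TensorwiseDynamic(torch.autograd.Function):
    """Differentiable cast to float8 with tensor-wise dynamic scaling.
    Forward: scale so the tensor's amax maps to the fp8 dtype's max, then cast.
    Backward: pass the gradient through unchanged."""

    @staticmethod
    def forward(ctx, x, float8_dtype=torch.float8_e4m3fn):
        fp8_max = torch.finfo(float8_dtype).max
        amax = x.abs().max().clamp(min=1e-12)
        scale = fp8_max / amax.to(torch.float32)
        x_fp8 = (x.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
        # For this sketch, dequantize back to the original dtype so the result
        # works with any downstream op; the real kernels keep the fp8 data.
        return (x_fp8.to(torch.float32) / scale).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input (the dtype argument gets None).
        return grad_output, None

class Float8NoCompileLinearSketch(torch.nn.Linear):
    """Hypothetical stand-in for Float8NoCompileLinear: route input and weight
    through the autograd function before the matmul."""

    def forward(self, input):
        input_fp8 = ToFP8TensorwiseDynamic.apply(input)
        weight_fp8 = ToFP8TensorwiseDynamic.apply(self.weight)
        return F.linear(input_fp8, weight_fp8, self.bias)

Because the cast lives inside an autograd.Function, autograd records it like any other op, so the module trains without torch.compile being involved.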

Test plan

  • The unit tests enforcing numerical fidelity between the existing float8 training path and this prototype's path are passing (a sketch of the test shape follows the command below):

pytest test/test.py
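A rough, hedged sketch of the shape such a fidelity test can take; the module construction below is a placeholder (it deep-copies the reference instead of actually converting it to the prototype's fp8 linear), and the tolerances are illustrative, not the prototype's actual test code:

import copy
import torch

def test_fp8_linear_parity_sketch():
    torch.manual_seed(0)
    ref = torch.nn.Linear(32, 16)
    # Placeholder: stands in for swapping `ref` for a Float8NoCompileLinear
    # built from the same weights.
    fp8 = copy.deepcopy(ref)

    x_ref = torch.randn(4, 32, requires_grad=True)
    x_fp8 = x_ref.detach().clone().requires_grad_(True)

    # Forward parity.
    out_ref = ref(x_ref)
    out_fp8 = fp8(x_fp8)
    torch.testing.assert_close(out_fp8, out_ref, atol=1e-1, rtol=1e-1)

    # Backward parity: gradients through the fp8 path should track the
    # reference path closely.
    out_ref.sum().backward()
    out_fp8.sum().backward()
    torch.testing.assert_close(x_fp8.grad, x_ref.grad, atol=1e-1, rtol=1e-1)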

pytorch-bot (bot) commented on Dec 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1452

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b8e1c6a with merge base 692236a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Dec 19, 2024
@danielvegamyhre added the "topic: not user facing" label on Dec 19, 2024
@danielvegamyhre (Contributor, Author) commented:

adding @drisspg @vkuzo for review

@drisspg requested review from vkuzo and drisspg on December 19, 2024, 23:51
@danielvegamyhre changed the title from "[float8nocompile]: Wrap float8nocompile conversion kernels in autograd func" to "[float8nocompile]: Wrap float8nocompile conversion kernels in autograd func and use in Float8NoCompileLinear" on Dec 20, 2024
@danielvegamyhre changed the title from "[float8nocompile]: Wrap float8nocompile conversion kernels in autograd func and use in Float8NoCompileLinear" to "[float8nocompile]: Wrap fp8 conversion kernels in autograd func and use in Float8NoCompileLinear" on Dec 20, 2024
"""
A differentiable conversion to fp8.
* forward: no-op
* backward: convert to fp8_e5m2 with tensor-wise dynamic scaling
Contributor review comment:

nit: remove e5m2 from docstring, since the target dtype is passed as an argument

@staticmethod
def backward(ctx, gradY):
# cast grad output to e5m2 in backward pass
Contributor review comment:

remove e5m2 from comment
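For context, a minimal sketch of the pattern these snippets and review nits describe: a no-op in the forward pass, with the gradient cast to a target float8 dtype (passed as an argument, per the nit) using tensor-wise dynamic scaling in the backward pass. The class name and the dequantize-back step are assumptions for illustration; the prototype's kernel-backed version is not shown here:

import torch

class NoopFwCastBwSketch(torch.autograd.Function):
    """Forward: no-op. Backward: cast the incoming gradient to the given
    float8 dtype with tensor-wise dynamic scaling."""

    @staticmethod
    def forward(ctx, x, float8_dtype=torch.float8_e5m2):
        ctx.float8_dtype = float8_dtype
        return x

    @staticmethod
    def backward(ctx, gradY):
        # Cast the grad output to the target fp8 dtype in the backward pass.
        fp8_max = torch.finfo(ctx.float8_dtype).max
        amax = gradY.abs().max().clamp(min=1e-12)
        scale = fp8_max / amax.to(torch.float32)
        gradY_fp8 = (gradY.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(ctx.float8_dtype)
        # Dequantize for this sketch; the real implementation wraps the raw
        # fp8 data in a float8 tensor subclass instead.
        grad = (gradY_fp8.to(torch.float32) / scale).to(gradY.dtype)
        return grad, None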

@vkuzo (Contributor) commented on Dec 20, 2024:

looks good, please ensure CI is green before landing

@danielvegamyhre merged commit 29de3e0 into pytorch:main on Dec 20, 2024
18 checks passed