
add axiswise scaling to Float8Linear #920

Merged
vkuzo merged 20 commits into main on Oct 7, 2024

Conversation

vkuzo (Contributor) commented Sep 23, 2024

Summary:

This PR: support axiswise scaling for all arguments of all gemms, and
ensure that training with axiswise scaling works e2e.

Future PR: support more granular configurability, optimize performance,
and add docs.

Feel free to ignore the UX introduced in this PR; it's just an intermediate step. See the next PR for the real UX.

Test Plan:

```
// tests pass
./test/float8/test_everything.sh

// sanity check on torchtitan with LLaMa 3 8B on 4 H100s with float8:
// 1. verify performance does not regress with tensorwise scaling
// 2. smoke test that axiswise scaling works and numerics are sane, performance isn't there though
// logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f
```

Reviewers:

Subscribers:

Tasks:

Tags:
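
For readers new to the terminology, here is a minimal sketch (not the implementation in this PR; helper names and details are hypothetical) contrasting tensorwise and axiswise scaling when casting a tensor to float8_e4m3fn:

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def cast_tensorwise(x: torch.Tensor):
    # tensorwise: a single scale derived from the max absolute value of the whole tensor
    scale = E4M3_MAX / x.abs().amax().clamp(min=1e-12)
    return (x * scale).to(torch.float8_e4m3fn), scale

def cast_axiswise(x: torch.Tensor, dim: int = -1):
    # axiswise: one scale per slice along `dim`, so outlier rows don't squash
    # the dynamic range of the rest of the tensor
    amax = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12)
    scale = E4M3_MAX / amax
    return (x * scale).to(torch.float8_e4m3fn), scale
```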


pytorch-bot (bot) commented Sep 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/920

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d70326c with merge base 52d27a1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Sep 23, 2024
facebook-github-bot added the CLA Signed label Sep 23, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
vkuzo added a commit that referenced this pull request Sep 23, 2024
vkuzo added a commit that referenced this pull request Sep 23, 2024
vkuzo added a commit that referenced this pull request Sep 23, 2024
vkuzo added a commit that referenced this pull request Sep 23, 2024
vkuzo requested a review from drisspg October 1, 2024 01:30
vkuzo added 2 commits October 2, 2024 08:29
@@ -39,6 +44,7 @@
from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import CompileCounterWithBackend

# TODO(future PR): standardize IS_H100 with the rest of the codebase
Reviewer (Contributor):

what we want is SM89 for testing cublas matmuls, since in theory we have hardware with that capability on CI/CD

vkuzo (Contributor, Author):

yeah I meant the variable name, not what it's checking
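
For context, a minimal sketch of the kind of device-capability check being discussed (the helper name is hypothetical, not necessarily what the codebase standardizes on):

```python
import torch

def is_sm_at_least_89() -> bool:
    # cuBLAS float8 matmuls need compute capability >= 8.9 (e.g. L40S is 8.9, H100 is 9.0),
    # so gating tests on 8.9 rather than "is H100" lets more CI hardware exercise them
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
```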

# Cast grad_output to float8_e5m2 during backward
output = self.cast_output_to_float8_in_bw(output)

else:
Reviewer (Contributor):

why do we need a separate path for this?

vkuzo (Contributor, Author):

this is choosing to move fast on a single use case at the expense of taking on some temporary tech debt
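
For readers following along, the quoted line appears to use the usual "no-op in forward, cast the gradient in backward" autograd trick. Below is a minimal sketch of that pattern with hypothetical names; the PR's cast_output_to_float8_in_bw additionally handles scaling and Float8Tensor wrapping, which this sketch omits:

```python
import torch

class _NoopFwCastGradToE5M2Bw(torch.autograd.Function):
    """Identity in forward; in backward, round-trips the incoming gradient
    through float8_e5m2 to emulate the precision used for grad_output."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output.to(torch.float8_e5m2).to(grad_output.dtype)

x = torch.randn(4, 8, requires_grad=True)
y = _NoopFwCastGradToE5M2Bw.apply(x)
y.sum().backward()  # x.grad has been rounded to e5m2 precision
```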

vkuzo added 3 commits October 2, 2024 16:06
lw (Contributor) left a comment:

LGTM! Left some minor readability nits

and other granularities in a separate PR.
"""

# TODO(this PR): types of inputs
lw (Contributor):
Done, right?

# the reshapes are needed in order to make the shapes compatible with
# torch.mm
orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
lw (Contributor):
Is this equivalent to input_fp8.flatten(0, -2)? If so, I find this more self-descriptive

orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
lw (Contributor):
Conversely, res_bits.unflatten(0, orig_shape[:-1]).

# the reshapes are needed in order to make the shapes compatible with
# torch.mm
grad_output_orig_shape = grad_output.shape
grad_output_reshaped = grad_output.reshape(
lw (Contributor):
Same here and below
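
For reference, a small self-contained sketch (with made-up shapes) showing that the reshape-based pattern in the quoted code and the suggested flatten/unflatten form produce the same result for inputs with ndim >= 2:

```python
import torch

x = torch.randn(2, 3, 16)   # e.g. (batch, seq, in_features)
w_t = torch.randn(16, 8)    # weight transposed to (in_features, out_features)

# reshape-based version, as written in the PR
out = torch.mm(x.reshape(-1, x.shape[-1]), w_t)
out = out.reshape(*x.shape[:-1], out.shape[-1])

# flatten/unflatten version, as suggested in the review
out2 = torch.mm(x.flatten(0, -2), w_t).unflatten(0, x.shape[:-1])

assert torch.allclose(out, out2)
```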

vkuzo added 3 commits October 4, 2024 09:50
vkuzo changed the base branch from gh/vkuzo/10/head to main October 7, 2024 17:21
vkuzo merged commit e76db70 into main Oct 7, 2024
43 checks passed
jainapurva pushed a commit that referenced this pull request Oct 9, 2024
jainapurva pushed a commit that referenced this pull request Oct 15, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* Update README.md

* Adding additional minor changes and Using markdown note blocks

* Minor typos and undoing changes that are more impactful

* adds