This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

bring back torch.autograd.Function for float8 matmul #336

Closed
wants to merge 5 commits

Conversation

vkuzo
Contributor

@vkuzo commented Jul 25, 2024

Stack from ghstack (oldest at bottom):

Summary:

This is a redo of #316.

With upcoming support for scaling granularities other than tensorwise,
we need a good way to control which gemm kernel to call and how to scale
the input tensors in the forward and backward passes. A `torch.autograd.Function`
override is the cleanest way to do that, and in 2024 this now works with
`torch.compile`.
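
To make the mechanism concrete, here is a minimal sketch (illustrative only, not the code in this PR) of a `torch.autograd.Function` that owns both the input scaling and the matmul choice in the forward and backward passes. The names `Float8Matmul`, `tensorwise_scale`, and `to_float8_and_back` are assumptions made for this example; float8 precision is emulated by a round trip through the float8 dtypes so the snippet runs without a dedicated float8 gemm kernel.

```python
# Hedged sketch: emulates float8 matmul scaling inside an autograd.Function.
# A real implementation would call a float8 gemm (e.g. a scaled-mm kernel)
# instead of the emulated cast-and-back used here.
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX = torch.finfo(torch.float8_e5m2).max

def tensorwise_scale(t: torch.Tensor, float8_max: float) -> torch.Tensor:
    # One scale per tensor; other granularities (rowwise, blockwise)
    # would return a differently shaped scale here.
    return float8_max / t.abs().max().clamp(min=1e-12)

def to_float8_and_back(t: torch.Tensor, scale: torch.Tensor, dtype) -> torch.Tensor:
    # Emulate float8 precision loss, then return to the original dtype
    # so a plain torch.mm can stand in for the float8 gemm.
    return ((t * scale).to(dtype).to(t.dtype)) / scale

class Float8Matmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x, w)
        # Forward: activations and weights commonly use e4m3.
        x_fp8 = to_float8_and_back(x, tensorwise_scale(x, E4M3_MAX), torch.float8_e4m3fn)
        w_fp8 = to_float8_and_back(w, tensorwise_scale(w, E4M3_MAX), torch.float8_e4m3fn)
        return torch.mm(x_fp8, w_fp8.t())

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        x, w = ctx.saved_tensors
        # Backward: gradients commonly use e5m2, and each gemm can pick
        # its own scaling independently of the forward.
        g_fp8 = to_float8_and_back(grad_out, tensorwise_scale(grad_out, E5M2_MAX), torch.float8_e5m2)
        grad_x = torch.mm(g_fp8, w)        # (M, N) @ (N, K) -> (M, K)
        grad_w = torch.mm(g_fp8.t(), x)    # (N, M) @ (M, K) -> (N, K)
        return grad_x, grad_w

# Usage (also works when wrapped in torch.compile on recent PyTorch):
#   y = Float8Matmul.apply(x, w)
```

With this structure, changing the scaling granularity or the gemm path only means changing the helpers called inside `forward`/`backward`, which is the flexibility described above.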

Test Plan:

`./test/test_everything.sh`

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D60252068](https://our.internmc.facebook.com/intern/diff/D60252068)

vkuzo added a commit that referenced this pull request Jul 25, 2024
ghstack-source-id: 6cb1588bf59be73b5782f6af94e7a360eba7f40e
Pull Request resolved: #336
@facebook-github-bot added the CLA Signed label Jul 25, 2024
vkuzo added a commit that referenced this pull request Jul 25, 2024
ghstack-source-id: 42dd59511e4ec2a55846c2593955c4ff5f12b254
Pull Request resolved: #336
@vkuzo
Contributor Author

vkuzo commented Jul 25, 2024

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@vkuzo
Contributor Author

vkuzo commented Jul 26, 2024

started a new PR due to ghstack error

@vkuzo closed this Jul 26, 2024