Add aten.linear.default implementation to mx_ops #806

Open
wants to merge 1 commit into main


@Ali-Flt Ali-Flt commented Sep 4, 2024

Fixes #796

pytorch-bot bot commented Sep 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/806

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8b79caa with merge base f5703b0 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @Ali-Flt!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot added the CLA Signed label Sep 4, 2024
@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@vkuzo
Contributor

vkuzo commented Sep 4, 2024

@jerryzh168 jerryzh168 requested a review from vkuzo September 4, 2024 17:26
@Ali-Flt
Author

Ali-Flt commented Sep 4, 2024

@vkuzo I tried several ways to trigger a call to aten.linear but couldn't.

@vkuzo
Contributor

vkuzo commented Sep 4, 2024

I'd definitely recommend a test that fails before this PR and passes after this PR. Would it work to wrap the code snippet you were using (which ended up calling into the linear override) into a test?

@gau-nernst
Collaborator

Just curious. Would it be better to implement F.linear() under __torch_function__() instead? Previously I also faced strange behavior on what aten ops will be dispatched by F.linear(), so implementing F.linear() directly would solve the problem.
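For illustration, a minimal sketch of the `__torch_function__` approach suggested here. `LinearInterceptTensor` is a hypothetical subclass made up for this example; it is not torchao's MXTensor, and it only records the call and falls back to the stock kernel rather than doing any MX quantization.

```python
import torch
import torch.nn.functional as F


class LinearInterceptTensor(torch.Tensor):
    """Hypothetical tensor subclass (not MXTensor) that intercepts F.linear."""

    linear_calls = 0  # how many times F.linear was intercepted

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            # F.linear is caught here, before it is lowered to any aten op,
            # so the handler does not depend on which aten ops get dispatched.
            cls.linear_calls += 1
            # A real subclass would run its own linear implementation here.
            # This sketch just runs the stock kernel with overrides disabled.
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        # Everything else keeps the default subclass behavior.
        return super().__torch_function__(func, types, args, kwargs)


x = torch.randn(4, 8)
w_plain = torch.randn(3, 8)
w = w_plain.as_subclass(LinearInterceptTensor)  # shares storage with w_plain

out = F.linear(x, w)
expected = x @ w_plain.t()
```

The trade-off raised in the replies still applies: this hooks the Python-level function, so how well it works under torch.compile would need to be checked separately.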

@vkuzo
Contributor

vkuzo commented Sep 5, 2024

> Would it be better to implement F.linear() under __torch_function__() instead?

As of a few months ago, __torch_dispatch__ was better supported by torch.compile than __torch_function__, at least for the things we needed for float8. I haven't checked whether torch.compile + __torch_function__ coverage is better now; that would be good to check.
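To see which aten ops F.linear() actually reaches at the dispatch level in a given setup, one option is to log them with a `TorchDispatchMode`. This is a rough sketch: `OpLogger` and `logged_ops` are names made up for this example, and the logged ops can differ across PyTorch versions and between normal and inference mode, so no specific output is claimed.

```python
import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode


class OpLogger(TorchDispatchMode):
    """Records the name of every aten op that reaches the dispatch level."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))  # e.g. "aten.addmm.default"
        return func(*args, **(kwargs or {}))


def logged_ops(use_inference_mode: bool):
    # Create inputs outside the logger so only F.linear's ops are recorded.
    x, w, b = torch.randn(4, 8), torch.randn(3, 8), torch.randn(3)
    with OpLogger() as logger:
        if use_inference_mode:
            with torch.inference_mode():
                F.linear(x, w, b)
        else:
            F.linear(x, w, b)
    return logger.ops


print("default:       ", logged_ops(use_inference_mode=False))
print("inference_mode:", logged_ops(use_inference_mode=True))
```

Comparing the two printed lists is one way to find a code path that calls into aten.linear.default, which could then be turned into a regression test.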

@balaabhijit

+1, kudos for the fix! I can confirm that I also ran into this same error without this PR.

@balaabhijit

A simple test that fails before this PR, with:

torch==2.5.1+cu121
torchao==0.6.1

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

class MLP(nn.Module):
    def __init__(self, in_features: int = 128, out_features: int = 256):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.nn.functional.gelu(x)
        return x


model = MLP()
# Note: swap_linear_with_mx_inference_linear does not hit this error.
swap_linear_with_mx_linear(model, elem_dtype=torch.float8_e4m3fn, block_size=32)

input_tensor = torch.randn(10, 128)
with torch.inference_mode():
    _ = model(input_tensor)

@balaabhijit

@vkuzo Thanks for the great insight on __torch_dispatch__ vs. __torch_function__; this is super helpful.

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Make command box contiguous again, but duplicating the single command we need into a separate command issues with updown/shell
Successfully merging this pull request may close these issues.

NotImplementedError: aten.linear.default not implemented when using MXTensor