support dilation and feature/batch group count in convolution reverse #181

Merged
3 commits merged into EnzymeAD:main on Dec 14, 2024

Conversation

Pangoraw
Collaborator

@Pangoraw Pangoraw commented Dec 9, 2024

No description provided.

Comment on lines 21 to 23
// REVERSE-NEXT{LITERAL}: %1 = stablehlo.reshape %0 : (tensor<8x66x66x512xf32>) -> tensor<8x66x66x1x512xf32>
// REVERSE-NEXT{LITERAL}: %2 = stablehlo.transpose %1, dims = [1, 2, 0, 3, 4] : (tensor<8x66x66x1x512xf32>) -> tensor<66x66x8x1x512xf32>
// REVERSE-NEXT{LITERAL}: %3 = stablehlo.reshape %2 : (tensor<66x66x8x1x512xf32>) -> tensor<8x66x66x512xf32>
Collaborator Author

@Pangoraw Pangoraw Dec 9, 2024


We should add an optimization to get rid of this pattern in a follow-up:

(transpose (unsqueeze)) -> (unsqueeze) // an unsqueeze is a reshape that only introduces 1-sized dims
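For example (an illustrative sketch reusing the shapes from the test above, not actual pass output), a transpose that only moves the 1-sized dim keeps the element order intact, so the pair folds into a single reshape:

%1 = stablehlo.reshape %0 : (tensor<8x66x66x512xf32>) -> tensor<8x66x66x1x512xf32>
%2 = stablehlo.transpose %1, dims = [0, 1, 3, 2, 4] : (tensor<8x66x66x1x512xf32>) -> tensor<8x66x1x66x512xf32>
// folds to a single reshape, since only the unit dim moved:
%2 = stablehlo.reshape %0 : (tensor<8x66x66x512xf32>) -> tensor<8x66x1x66x512xf32>

Note that the transpose in the diff above uses dims = [1, 2, 0, 3, 4], which also reorders the non-unit dims, so it would not be eligible for this fold.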

Member


this reshape feels sketchy, but I'll trust your math here

Collaborator


That optimization only holds if the permutation keeps the relative order of the original non-unit dims, right?

Like, you only permute 1-sized dims with any other dim.

Collaborator Author


Yes, that's right: it has no effect only when permuting 1-sized dims.

This is to support batch group count, which is moved into the feature dimension during the reverse. Having a static select TableGen operator would be nice for these things:

(StaticSelect (Cond), $a, $b)

It would create either $a or $b based on the condition (this would also be useful for #176, IIUC).
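A rough sketch of how such a helper might read in a TableGen DRR pattern (every name here, StaticSelect included, is a hypothetical placeholder for illustration, not a definition that exists in the repo):

// Hypothetical DRR usage: StaticSelect picks one result pattern at
// pattern-build time instead of emitting a runtime select op.
def : Pat<(SomeRevConvOp $lhs, $rhs),
          (StaticSelect (IsBatchGrouped $rhs),
                        (BuildGroupedReverse $lhs, $rhs),   // batch_group_count > 1
                        (BuildPlainReverse $lhs, $rhs))>;   // batch_group_count == 1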

Member


Is it okay to use a regular select here? Constant propagation will immediately get rid of it.
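For reference, a minimal StableHLO sketch of the idea (values and shapes are made up): when the predicate is a compile-time constant, the select folds away.

%pred = stablehlo.constant dense<true> : tensor<i1>
// a runtime select with a constant predicate...
%r = "stablehlo.select"(%pred, %a, %b) : (tensor<i1>, tensor<8x512xf32>, tensor<8x512xf32>) -> tensor<8x512xf32>
// ...which constant propagation rewrites to just %a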

Collaborator Author


Yeah, actually I think this reshape is wrong here, because it's not actually a no-op. And we already have the optimization I was referring to; it's called TransposeIsReshape.

Collaborator Author


I fixed the problem in the reshape/transpose; the operations are completely removed when batchGroupCount == 1 (which will always be the case coming from NNlib.conv).

@wsmoses wsmoses merged commit 024902a into EnzymeAD:main Dec 14, 2024
5 of 9 checks passed
vimarsh6739 pushed a commit that referenced this pull request Dec 14, 2024
support dilation and feature/batch group count in convolution reverse (#181)

* support dilation and feature group count in convolution reverse

* support batch group count

* fix dimensions for post conv transpose (batch group count)