stablehlo.compare derivative isn't implemented #57
Comments
Can you print out the MLIR module before AD?
On Fri, Jul 26, 2024 at 9:21 PM Avik Pal wrote:
> error: Unimplemented derivative for argument 0 in reverse mode for op %4 = "stablehlo.select"(%3, %2, %1) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
> error: could not compute the adjoint for this operation %3 = "stablehlo.compare"(%2, %1) <{comparison_direction = #stablehlo<comparison_direction GT>}> : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
> Originally posted by @avik-pal in #55 (comment)
> The relu activation test is marked broken for now. Once this is fixed that should pass.
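For anyone trying to reproduce this, a minimal sketch of the kind of call that hits the error. The sumabs2 definition, the ∇sumabs2 wrapper, and the exact Reactant/Enzyme calls below are my assumptions about the linked test, not copied from it:

using Enzyme, Reactant

relu(x) = max(x, zero(x))
sumabs2(f, x) = sum(abs2, f.(x))              # assumed definition of the helper seen in the dumps below

function ∇sumabs2(f, x)                       # hypothetical reverse-mode wrapper
    dx = Enzyme.make_zero(x)
    Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
    return dx
end

x_act    = randn(Float32, 10, 10)
x_act_ca = Reactant.ConcreteRArray(x_act)     # assumed way of moving the data onto the Reactant side

# Tracing/compiling the gradient is what reaches the missing select/compare adjoints.
∇sumabs2_compiled = Reactant.compile(∇sumabs2, (relu, x_act_ca))
∇sumabs2_compiled(relu, x_act_ca)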
The problem here is that an invalid compare was emitted.
julia> Reactant.@code_hlo sumabs2(relu, x_act_ca)
Module:
module attributes {transform.with_named_sequence} {
func.func @main(%arg0: tensor<10x10xf32>) -> tensor<f32> {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = stablehlo.compare GT, %0, %cst_0 : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
%2 = stablehlo.select %1, %0, %cst_0 : tensor<10x10xi1>, tensor<10x10xf32>
%3 = stablehlo.multiply %2, %2 : tensor<10x10xf32>
%4 = stablehlo.reduce(%3 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<10x10xf32>, tensor<f32>) -> tensor<f32>
return %4 : tensor<f32>
}
}
julia> Reactant.@code_hlo optimize=false sumabs2(relu, x_act_ca)
Module:
module {
func.func private @abs2_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
%1 = stablehlo.multiply %0, %0 : tensor<f32>
%2 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
%3 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
return %2, %3 : tensor<f32>, tensor<f32>
}
func.func @main(%arg0: tensor<10x10xf32>) -> (tensor<10x10xf32>, tensor<f32>) {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
%1 = stablehlo.compare GT, %0, %cst : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
%2 = stablehlo.select %1, %0, %cst_0 : tensor<10x10xi1>, tensor<10x10xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%3:2 = enzyme.batch @abs2_broadcast_scalar(%2) {batch_shape = array<i64: 10, 10>} : (tensor<10x10xf32>) -> (tensor<10x10xf32>, tensor<10x10xf32>)
%4 = stablehlo.reduce(%3#1 init: %cst_1) applies stablehlo.add across dimensions = [0, 1] : (tensor<10x10xf32>, tensor<f32>) -> tensor<f32>
%5 = stablehlo.transpose %0, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
%6 = stablehlo.transpose %4, dims = [] : (tensor<f32>) -> tensor<f32>
return %5, %6 : tensor<10x10xf32>, tensor<f32>
}
}
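(For context, my reading of the dump rather than something stated in the thread: the compare/select pair in the optimized module is simply how relu lowers, so reverse mode needs exactly the two adjoints the error message names. A plain-Julia sketch of the correspondence:)

relu(x) = ifelse(x > zero(x), x, zero(x))   # x > 0       -> stablehlo.compare GT
                                            # pick x or 0 -> stablehlo.select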
Fun fact: the optimize=false module is printed before AD is run, so you could even run optimize=false on the function containing the autodiff (which will make it easier to repro in the future).
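For example, something along these lines should print the pre-AD module for the differentiated call directly (a sketch; ∇sumabs2 is the hypothetical gradient wrapper from the reproducer above, not a name from this thread):

# Unoptimized module for the call containing the autodiff, printed before AD runs.
Reactant.@code_hlo optimize=false ∇sumabs2(relu, x_act_ca)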
This should be fixed with EnzymeAD/Enzyme-JAX#106.
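For reference, the reverse rules involved are small. A plain-Julia paraphrase of the math, not the Enzyme-JAX implementation:

# Reverse rule for y = select(pred, a, b): route the incoming cotangent ȳ to the
# branch that was actually selected; the predicate itself receives no cotangent.
function select_pullback(pred, ȳ)
    ā = ifelse.(pred, ȳ, zero.(ȳ))    # cotangent flowing to a
    b̄ = ifelse.(pred, zero.(ȳ), ȳ)    # cotangent flowing to b
    return ā, b̄
end
# compare returns booleans, so its derivative with respect to both float operands is
# zero and its adjoint contributes nothing; relu's gradient is just a masked copy of ȳ.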
Originally posted by @avik-pal in #55 (comment):

The relu
activation test is marked broken for now. Once this is fixed, that should pass.