
feat: more coverage for common NN operations #55

Merged: 15 commits into EnzymeAD:main from ap/broadcast_coverage on Aug 5, 2024

Conversation

avik-pal (Collaborator) commented Jul 26, 2024

Overview

  • Activations
    • sigmoid
    • sigmoid_fast
    • relu
    • gelu
  • abs2
  • normalization
    • mean
    • var
  • Add tests
    • Activations
    • abs2
    • normalization
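
For orientation, here is a rough usage sketch of the operations listed above (illustrative only: it assumes Reactant's ConcreteRArray constructor and the @code_hlo macro that appears later in this thread, and the function names are made up):

using Reactant, NNlib, Statistics

# Illustrative functions touching the newly covered operations
acts(x) = NNlib.gelu.(NNlib.relu.(NNlib.sigmoid.(x)))                   # activations
layernorm(x) = (x .- mean(x; dims=1)) ./ sqrt.(var(x; dims=1) .+ 1f-5)  # mean / var
sq(x) = abs2.(x)                                                        # abs2

x = Reactant.ConcreteRArray(randn(Float32, 3, 2))  # assumed entry point for device arrays
Reactant.@code_hlo acts(x)                         # inspect the generated StableHLO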

avik-pal changed the title from "feat: more coverage for common NN activations" to "feat: more coverage for common NN operations" on Jul 26, 2024
avik-pal (Collaborator, Author): Need to fix the compat entries. The Lux failure will be gone with LuxDL/LuxLib.jl#105.

Review thread on ext/ReactantNNlibExt.jl (outdated, resolved)
avik-pal (Collaborator, Author): The 1.11 failures are due to Enzyme not working on Julia 1.11.

Review thread on the following code:

function Reactant.elem_apply(

Member: We generally shouldn't need to add more elem_apply overloads anymore now that we have batching working properly. I think just defining this for an RArray of size zero should suffice.


avik-pal (Collaborator, Author): I tried this:

function NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T}
    return max(x, zero(T))
end

but I am getting

error: 'stablehlo.constant' op inferred type(s) 'tensor<f32>' are incompatible with return type(s) of operation 'tensor<2x3xf32>'
error: 'stablehlo.constant' op failed to infer returned types
ERROR: "failed to run pass manager on module"
Stacktrace:
 [1] run!
   @ /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Pass.jl:70 [inlined]
 [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String)
   @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1178
 [3] compile_to_module(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{Any}; optimize::Bool)
   @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1198
 [4] (::var"#51#52")()
   @ Main /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1315
 [5] context!(f::var"#51#52", ctx::Reactant.MLIR.IR.Context)
   @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Context.jl:71
 [6] top-level scope
   @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1313

Before optimization, the code is:

Module:
module {
  func.func private @relu_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.maximum %0, %cst : tensor<f32>
    %2 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %3 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %2, %3 : tensor<f32>, tensor<f32>
  }
  func.func @main(%arg0: tensor<3x2xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32>
    %1:2 = enzyme.batch @relu_broadcast_scalar(%0) {batch_shape = array<i64: 2, 3>} : (tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>)
    %2 = stablehlo.transpose %1#0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32>
    %3 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32>
    return %2, %3 : tensor<3x2xf32>, tensor<3x2xf32>
  }
}

Member: This is the "we don't support constants in batching yet" issue, which I'm presently working on. I'll try to get this squared away today/tomorrow.

avik-pal (Collaborator, Author): Should I just leave a "TODO" comment for relu, or wait for the support?

Two review threads on src/overloads.jl (outdated, resolved)

Review thread on the gelu implementation:

    ::typeof(NNlib.gelu), lhs::Reactant.TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
    # See https://arxiv.org/pdf/1606.08415v5 Section 2
    return lhs .* sigmoid.(ElType(1.702) .* lhs)

avik-pal (Collaborator, Author): Is there an erf op in HLO?

avik-pal (Collaborator, Author): So we can use any of the dialects? I was using https://openxla.org/s/results?q=erf#gsc.tab=0&gsc.q=erf&gsc.sort= as a reference.

avik-pal (Collaborator, Author): How do we feel about adding a direct dependency on SpecialFunctions for erf/erfinv...? Otherwise we would have to create a second NNlibSpecialFunctionsExt to define the exact gelu implementation.
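
For context, a minimal sketch of what an exact, erf-based gelu would look like next to the sigmoid approximation currently in this PR (this assumes a SpecialFunctions dependency and is illustrative, not the PR's code):

using SpecialFunctions: erf

# Exact GELU (https://arxiv.org/pdf/1606.08415v5): gelu(x) = x * Φ(x)
exact_gelu(x::T) where {T<:Real} = x * T(0.5) * (one(T) + erf(x / sqrt(T(2))))

# Sigmoid approximation used in the PR (Section 2 of the same paper)
approx_gelu(x::T) where {T<:Real} = x / (one(T) + exp(-T(1.702) * x))

# The two agree to within roughly 1e-2 over typical activation ranges
maximum(abs, exact_gelu.(-5:0.01:5) .- approx_gelu.(-5:0.01:5))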

Member: Presently, yes (and we need to make sure we have the corresponding lowerings from one to the others, and potentially derivatives).

avik-pal (Collaborator, Author): erf doesn't seem to have its derivatives implemented (EnzymeAD/Enzyme-JAX#88), so I am more inclined to keep the current implementation and switch later once the derivatives are implemented.

Member: Do you want to take a stab at it? It would go right here: https://github.com/EnzymeAD/Enzyme-JAX/blob/4eeaef06e0da144bebd08ec739cf01911dcddb47/src/enzyme_ad/jax/Implementations/CHLODerivatives.td#L142 and shouldn't be bad. cc @mofeing and/or @Pangoraw, who may be able to help.
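
For reference, the rule such an entry would encode is d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2). A quick numerical sanity check in Julia (illustrative, separate from the CHLODerivatives.td TableGen entry linked above):

using SpecialFunctions: erf

# Verify d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2) against a central finite difference
derf(x) = 2 / sqrt(pi) * exp(-x^2)
x = 0.7
fd = (erf(x + 1e-6) - erf(x - 1e-6)) / 2e-6
@assert isapprox(fd, derf(x); rtol=1e-5)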

avik-pal (Collaborator, Author): I will give it a shot.

avik-pal (Collaborator, Author): How important is 1.9 support? Lux doesn't support 1.9 (neither does ArrayInterface), which causes the test failures. I could skip the tests on 1.9, but now that 1.10 will be the LTS we might as well drop it.

wsmoses (Member) commented Jul 27, 2024 via email

codecov-commenter commented:

Welcome to Codecov 🎉

Once you merge this PR into your default branch, you're all set! Codecov will compare coverage reports and display results in all future pull requests.

Thanks for integrating Codecov - We've got you covered ☂️

wsmoses (Member) commented Aug 1, 2024

@avik-pal the jll bump with constant propagation of broadcast is merged. rebase this?

avik-pal force-pushed the ap/broadcast_coverage branch from 9cf45c7 to 2c65ae7 on August 2, 2024 at 00:15

avik-pal (Collaborator, Author) commented Aug 2, 2024

wsmoses (Member) commented Aug 2, 2024 via email

avik-pal (Collaborator, Author) commented Aug 2, 2024:

Module:
module {
  func.func private @relu_broadcast_scalar(%arg0: tensor<f64>) -> (tensor<f64>, tensor<f64>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f64>) -> tensor<f64>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %1 = stablehlo.maximum %0, %cst : tensor<f64>
    %2 = stablehlo.transpose %0, dims = [] : (tensor<f64>) -> tensor<f64>
    %3 = stablehlo.transpose %1, dims = [] : (tensor<f64>) -> tensor<f64>
    return %2, %3 : tensor<f64>, tensor<f64>
  }
  func.func @main(%arg0: tensor<2x2xf64>) -> (tensor<2x2xf64>, tensor<2x2xf64>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
    %1:2 = enzyme.batch @relu_broadcast_scalar(%0) {batch_shape = array<i64: 2, 2>} : (tensor<2x2xf64>) -> (tensor<2x2xf64>, tensor<2x2xf64>)
    %2 = stablehlo.transpose %1#0, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
    %3 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
    return %2, %3 : tensor<2x2xf64>, tensor<2x2xf64>
  }
}

wsmoses (Member) commented Aug 4, 2024

@avik-pal the answer was silly: we just didn't bump the repo dependency (commit 8f25b97).

wsmoses (Member) commented Aug 4, 2024

Now awaiting the jll: JuliaPackaging/Yggdrasil#9196 [though @avik-pal, this won't include the erf derivative yet -- unless you have a PR we can quickly merge and then get into the new jll].

avik-pal (Collaborator, Author) commented Aug 4, 2024

Let's move forward without the erf derivative for now. I am trying to help one of our GSoC students with some benchmarking, so that needs to be finished first 😓

wsmoses (Member) commented Aug 4, 2024

@avik-pal the jll merged, can you rebase here?

avik-pal force-pushed the ap/broadcast_coverage branch from 343a45d to d5d75f3 on August 4, 2024 at 21:45

Review threads on Project.toml and src/Reactant.jl (outdated, resolved)

Review thread on the following code:

@@ -140,6 +172,17 @@ for (jlop, hloop, RT) in (
end
end

function Base.ifelse(
    pred::TracedRArray{Bool,(),0}, x::TracedRArray{T1,(),0}, y::TracedRArray{T2,(),0}

Member: Can we make this generalize to any shape/size, not just 0?

avik-pal (Collaborator, Author): Won't the broadcasting handle the shape automatically?

Member: Yes, but someone could also call ifelse(true, ones(4,4), zeros(4,4)) or ifelse(trues(4,4), ones(4,4), zeros(4,4)), etc., outside a broadcast [though yes, the 0-dim one will generalize to anything in a broadcast].

Member: Though I don't think the latter case is legal in Julia at the moment, so just generalizing to ifelse(true, ones(4,4), zeros(4,4)) probably makes sense.

avik-pal (Collaborator, Author): Something is wrong with this version I defined:

julia> f(x) = ifelse.(true, x, x)
f (generic function with 1 method)

julia> Reactant.@code_hlo optimize=false f(x)
Module:
module {
  func.func private @ifelse_broadcast_scalar(%arg0: tensor<i1>, %arg1: tensor<f64>) -> (tensor<i1>, tensor<f64>, tensor<f64>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<i1>) -> tensor<i1>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f64>) -> tensor<f64>
    %2 = stablehlo.select %0, %1, %1 : tensor<i1>, tensor<f64>
    %3 = stablehlo.transpose %0, dims = [] : (tensor<i1>) -> tensor<i1>
    %4 = stablehlo.transpose %1, dims = [] : (tensor<f64>) -> tensor<f64>
    %5 = stablehlo.transpose %2, dims = [] : (tensor<f64>) -> tensor<f64>
    return %3, %4, %5 : tensor<i1>, tensor<f64>, tensor<f64>
  }
  func.func @main(%arg0: tensor<3x2xf64>) -> (tensor<3x2xf64>, tensor<3x2xf64>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x2xf64>) -> tensor<2x3xf64>
    %c = stablehlo.constant dense<true> : tensor<2x3xi1>
    %1:3 = enzyme.batch @ifelse_broadcast_scalar(%c, %0) {batch_shape = array<i64: 2, 3>} : (tensor<2x3xi1>, tensor<2x3xf64>) -> (tensor<2x3xi1>, tensor<2x3xf64>, tensor<2x3xf64>)
    %2 = stablehlo.transpose %1#2, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
    %3 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
    return %2, %3 : tensor<3x2xf64>, tensor<3x2xf64>
  }
}

julia> Reactant.@code_hlo f(x)
Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<3x2xf64>) {
    return
  }
}

Review thread on src/overloads.jl (outdated, resolved)
avik-pal force-pushed the ap/broadcast_coverage branch from d74fbb2 to d0e1dbd on August 4, 2024 at 22:28
wsmoses merged commit 4177c8d into EnzymeAD:main on Aug 5, 2024; 8 of 12 checks passed.
avik-pal deleted the ap/broadcast_coverage branch on August 5, 2024 at 20:56.