Merge pull request #79 from slimgroup/flux-fix
fix Flux compat
mloubout authored Apr 20, 2023
2 parents 5c092f9 + b2eeb12 commit df92d43
Showing 9 changed files with 88 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Rafael Orozco <[email protected]>", "Felix J. herrmann <[email protected]>"]
version = "2.2.4"
version = "2.2.5"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2 changes: 1 addition & 1 deletion src/layers/invertible_layer_glow.jl
@@ -142,7 +142,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL
ΔX2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X2) + ΔY2
else
ΔX2, Δθrb = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT; ), X2; set_grad=set_grad)
_, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS_, S), 0f0.*ΔT;), X2; set_grad=set_grad)
_, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), 0f0.*ΔT;), X2; set_grad=set_grad)
ΔX2 += ΔY2
end
ΔX_ = tensor_cat(ΔX1, ΔX2)
2 changes: 1 addition & 1 deletion src/layers/invertible_layer_hint.jl
@@ -299,7 +299,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL
end

# Input are two tensors ΔX, X
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing) where {T, N}
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, set_grad::Bool=true) where {T, N}
isnothing(permute) ? permute = H.permute : permute = permute

# Permutation
28 changes: 15 additions & 13 deletions src/utils/chainrules.jl
@@ -1,8 +1,9 @@
using ChainRulesCore
export logdetjac
import ChainRulesCore: frule, rrule

export logdetjac, getrrule
import ChainRulesCore: frule, rrule, @non_differentiable

@non_differentiable get_params(::Invertible)
@non_differentiable get_params(::Reversed)
## Tape types and utilities

"""
@@ -81,7 +82,6 @@ function forward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N},
if logdet isa Float32
state.logdet === nothing ? (state.logdet = logdet) : (state.logdet += logdet)
end

end

"""
@@ -97,15 +97,13 @@ function backward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N}
state.Y[state.counter_block] = X
state.counter_layer -= 1
end

state.counter_block == 0 && reset!(state) # reset state when first block/first layer is reached

end

## Chain rules for invertible networks
# General pullback function
function pullback(net::Invertible, ΔY::AbstractArray{T,N};
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}

# Check state coherency
check_coherence(state, net)
@@ -114,19 +112,19 @@ function pullback(net::Invertible, ΔY::AbstractArray{T,N};
T2 = typeof(current(state))
ΔY = convert(T2, ΔY)
# Backward pass
ΔX, X_ = net.backward(ΔY, current(state))

ΔX, X_ = net.backward(ΔY, current(state); set_grad=true)
Δθ = getfield.(get_params(net), :grad)
# Update state
backward_update!(state, X_)

return nothing, ΔX
return NoTangent(), NoTangent(), ΔX, Δθ
end


# Reverse-mode AD rule
function ChainRulesCore.rrule(net::Invertible, X::AbstractArray{T, N};
function ChainRulesCore.rrule(::typeof(forward_net), net::Invertible, X::AbstractArray{T, N}, θ...;
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}

# Forward pass
net.logdet ? ((Y, logdet) = net.forward(X)) : (Y = net.forward(X); logdet = nothing)

@@ -142,4 +140,8 @@ end

## Logdet utilities for Zygote pullback

logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet
logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet

## Utility to get the pullback directly for testing

getrrule(net::Invertible, X::AbstractArray) = rrule(forward_net, net, X, getfield.(get_params(net), :data))
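
Note (usage sketch, not part of the diff): a pullback obtained through getrrule returns the 4-tuple (NoTangent(), NoTangent(), ΔX, Δθ) produced by pullback above, so the input gradient sits at index 3 and the parameter gradients at index 4; this is also why the indices in the test updates further down move from [2] to [3]. A minimal sketch, with the ActNorm layer and array shapes chosen purely for illustration:

    using InvertibleNetworks

    net = ActNorm(2; logdet=false)   # any Invertible network
    X   = randn(Float32, 1, 1, 2, 8)

    # getrrule(net, X) = rrule(forward_net, net, X, getfield.(get_params(net), :data))
    Y, back = getrrule(net, X)
    Δ  = back(Y)     # pull an output cotangent back through the layer
    ΔX = Δ[3]        # gradient with respect to the input X
    Δθ = Δ[4]        # gradients with respect to the network parameters
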
2 changes: 1 addition & 1 deletion src/utils/dimensionality_operations.jl
@@ -475,7 +475,7 @@ end

# Split and reshape 1D vector Y in latent space back to states Zi
# where Zi is the split tensor at each multiscale level.
function split_states(Y::AbstractVector{T}, Z_dims) where {T, N}
function split_states(Y::AbstractVector{T}, Z_dims) where {T}
L = length(Z_dims) + 1
inds = cumsum([1, [prod(Z_dims[j]) for j=1:L-1]...])
Z_save = [reshape(Y[inds[j]:inds[j+1]-1], xy_dims(Z_dims[j], Val(j==L), Val(length(Z_dims[j])))) for j=1:L-1]
5 changes: 3 additions & 2 deletions src/utils/neuralnet.jl
@@ -33,7 +33,7 @@ end
getproperty(I::Invertible, s::Symbol) = _get_property(I, Val{s}())

_get_property(I::Invertible, ::Val{s}) where {s} = getfield(I, s)
_get_property(R::Reversed, ::Val{:I}) where s = getfield(R, :I)
_get_property(R::Reversed, ::Val{:I}) = getfield(R, :I)
_get_property(R::Reversed, ::Val{s}) where s = _get_property(R.I, Val{s}())

for m ∈ _INet_modes
@@ -128,4 +128,5 @@ function set_params!(N::Invertible, θnew::Array{Parameter, 1})
end

# Make invertible nets callable objects
(N::Invertible)(X::AbstractArray{T,N} where {T, N}) = N.forward(X)
(net::Invertible)(X::AbstractArray{T,N} where {T, N}) = forward_net(net, X, getfield.(get_params(net), :data))
forward_net(net::Invertible, X::AbstractArray{T,N}, ::Any) where {T, N} = net.forward(X)
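
Note (illustrative sketch, not part of the diff; the ActNorm layer and shapes are assumptions): the callable form now routes through forward_net, which still just runs net.forward(X) but exposes the parameter arrays as call arguments so the rrule in chainrules.jl has a slot to return Δθ through:

    using InvertibleNetworks

    net = ActNorm(2; logdet=false)
    X   = randn(Float32, 1, 1, 2, 8)

    # net(X) dispatches to forward_net(net, X, getfield.(get_params(net), :data)),
    # i.e. plain net.forward(X) with the parameters threaded along for AD.
    Y = net(X)

    # With the parameters exposed this way, implicit-style Flux gradients get
    # populated; see the new test/test_utils/test_flux.jl below for the full pattern.
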
3 changes: 2 additions & 1 deletion test/runtests.jl
@@ -17,7 +17,8 @@ basics = ["test_utils/test_objectives.jl",
"test_utils/test_activations.jl",
"test_utils/test_squeeze.jl",
"test_utils/test_jacobian.jl",
"test_utils/test_chainrules.jl"]
"test_utils/test_chainrules.jl",
"test_utils/test_flux.jl"]

# Layers
layers = ["test_layers/test_residual_block.jl",
28 changes: 14 additions & 14 deletions test/test_utils/test_chainrules.jl
@@ -23,22 +23,22 @@ N10 = CouplingLayerHINT(n_ch, n_hidden; logdet=logdet, permute="full")

# Forward pass + gathering pullbacks
function fw(X)
X1, ∂1 = rrule(N1, X)
X2, ∂2 = rrule(N2, X1)
X3, ∂3 = rrule(N3, X2)
X1, ∂1 = getrrule(N1, X)
X2, ∂2 = getrrule(N2, X1)
X3, ∂3 = getrrule(N3, X2)
X5, ∂5 = Flux.Zygote.pullback(Chain(N4, N5), X3)
X6, ∂6 = rrule(N6, X5)
X7, ∂7 = rrule(N7, X6)
X6, ∂6 = getrrule(N6, X5)
X7, ∂7 = getrrule(N7, X6)
X9, ∂9 = Flux.Zygote.pullback(Chain(N8, N9), X7)
X10, ∂10 = rrule(N10, X9)
d1 = x -> ∂1(x)[2]
d2 = x -> ∂2(x)[2]
d3 = x -> ∂3(x)[2]
X10, ∂10 = getrrule(N10, X9)
d1 = x -> ∂1(x)[3]
d2 = x -> ∂2(x)[3]
d3 = x -> ∂3(x)[3]
d5 = x -> ∂5(x)[1]
d6 = x -> ∂6(x)[2]
d7 = x -> ∂7(x)[2]
d6 = x -> ∂6(x)[3]
d7 = x -> ∂7(x)[3]
d9 = x -> ∂9(x)[1]
d10 = x -> ∂10(x)[2]
d10 = x -> ∂10(x)[3]
return X10, ΔY -> d1(d2(d3(d5(d6(d7(d9(d10(ΔY))))))))
end

@@ -65,9 +65,9 @@ g2 = gradient(X -> loss(X), X)
## test Reverse network AD

Nrev = reverse(N10)
Xrev, ∂rev = rrule(Nrev, X)
Xrev, ∂rev = getrrule(Nrev, X)
grev = ∂rev(Xrev-Y0)

g2rev = gradient(X -> 0.5f0*norm(Nrev(X) - Y0)^2, X)

@test grev[2] ≈ g2rev[1] rtol=1f-6
@test grev[3] ≈ g2rev[1] rtol=1f-6
50 changes: 50 additions & 0 deletions test/test_utils/test_flux.jl
@@ -0,0 +1,50 @@
using InvertibleNetworks, Flux, Test, LinearAlgebra

# Define network
nx = 1
ny = 1
n_in = 2
n_hidden = 10
batchsize = 32

# net
AN = ActNorm(n_in; logdet = false)
C = CouplingLayerGlow(n_in, n_hidden; logdet = false, k1 = 1, k2 = 1, p1 = 0, p2 = 0)
pan, pc = deepcopy(get_params(AN)), deepcopy(get_params(C))
model = Chain(AN, C)

# dummy input & target
X = randn(Float32, nx, ny, n_in, batchsize)
Y = model(X)
X0 = rand(Float32, nx, ny, n_in, batchsize) .+ 1

# loss fn
loss(model, X, Y) = Flux.mse(Y, model(X))

# old, implicit-style Flux
θ = Flux.params(model)
opt = Descent(0.001)

l, grads = Flux.withgradient(θ) do
loss(model, X0, Y)
end

for θi in θ
@test θi ∈ keys(grads.grads)
@test !isnothing(grads.grads[θi])
@test size(grads.grads[θi]) == size(θi)
end

Flux.update!(opt, θ, grads)

for i = 1:5
li, grads = Flux.withgradient(θ) do
loss(model, X, Y)
end

@info "Loss: $li"
@test li != l
global l = li

Flux.update!(opt, θ, grads)
end

1 comment on commit df92d43

@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/81994

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v2.2.5 -m "<description of version>" df92d4301b2d84e6d3c7637e7d643db2952b0532
git push origin v2.2.5
