Commit 5f571ec

Merge pull request #68 from slimgroup/activation-hint

add activation parameter to conditional multiscale hint

Authored by rafaelorozco, Nov 28, 2022
2 parents: f6b2d17 + bfb1149
Showing 9 changed files with 44 additions and 41 deletions.
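
Before the per-file diffs, here is a minimal usage sketch of the keyword this commit threads through the conditional HINT constructors. The sizes, sample arrays, and variable names are hypothetical; `activation=SigmoidLayer()` only restates the default that the diffs below introduce, and any other exported ActivationFunction could be passed in its place.

using InvertibleNetworks

# Hypothetical toy sizes: 4 input channels, 32 hidden channels, L=2 scales, K=3 blocks per scale.
n_in, n_hidden, L, K = 4, 32, 2, 3

# The new keyword is forwarded down to the underlying coupling layers; SigmoidLayer() is the default.
CH = NetworkMultiScaleConditionalHINT(n_in, n_hidden, L, K; split_scales=true, activation=SigmoidLayer())

# Forward pass on random 2D samples of shape (nx, ny, n_channels, batch); spatial sizes must be divisible by 2^L.
X = randn(Float32, 16, 16, n_in, 8)
Y = randn(Float32, 16, 16, n_in, 8)
Zx, Zy, lgdet = CH.forward(X, Y)
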
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Rafael Orozco <[email protected]>", "Felix J. herrmann <[email protected]>"]
-version = "2.2.0"
+version = "2.2.1"

This comment has been minimized. (@mloubout, Member, Nov 28, 2022)
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
16 changes: 7 additions & 9 deletions src/conditional_layers/conditional_layer_hint.jl
@@ -65,12 +65,12 @@ end

# 2D Constructor from input dimensions
function ConditionalLayerHINT(n_in::Int64, n_hidden::Int64; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1,
-logdet=true, permute=true, ndims=2)
+logdet=true, permute=true, ndims=2, activation::ActivationFunction=SigmoidLayer())

# Create basic coupling layers
-CL_X = CouplingLayerHINT(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
-CL_Y = CouplingLayerHINT(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
-CL_YX = CouplingLayerBasic(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
+CL_X = CouplingLayerHINT(n_in, n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
+CL_Y = CouplingLayerHINT(n_in, n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
+CL_YX = CouplingLayerBasic(n_in, n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)

# Permutation using 1x1 convolution
permute == true ? (C_X = Conv1x1(n_in)) : (C_X = nothing)
@@ -209,11 +209,10 @@ function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::Abs
return ΔZx, ΔZy, Zx, Zy
end

-function forward_Y(Y::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
+function forward_Y(Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=false) where {T, N}
~isnothing(CH.C_Y) ? (Yp = CH.C_Y.forward(Y)) : (Yp = copy(Y))
-Zy = CH.CL_Y.forward(Yp; logdet=false)
-return Zy
-
+logdet ? (Zy, logdet_) = CH.CL_Y.forward(Yp; logdet=true) : Zy = CH.CL_Y.forward(Yp; logdet=false)
+logdet ? (return Zy, logdet_) : (return Zy)
end

function inverse_Y(Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
@@ -222,7 +221,6 @@ function inverse_Y(Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T,
return Y
end


## Jacobian-related utils

function jacobian(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
6 changes: 3 additions & 3 deletions src/layers/invertible_layer_hint.jl
@@ -73,17 +73,17 @@ end

# Constructor for given coupling layer and 1 x 1 convolution
CouplingLayerHINT(CL::AbstractArray{CouplingLayerBasic, 1}, C::Union{Conv1x1, Nothing};
-logdet=false, permute="none") = CouplingLayerHINT(CL, C, logdet, permute, false)
+logdet=false, permute="none", activation::ActivationFunction=SigmoidLayer()) = CouplingLayerHINT(CL, C, logdet, permute, false)

# 2D Constructor from input dimensions
function CouplingLayerHINT(n_in::Int64, n_hidden::Int64; logdet=false, permute="none",
-k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, ndims=2)
+k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, ndims=2, activation::ActivationFunction=SigmoidLayer())

# Create basic coupling layers
n = get_depth(n_in)
CL = Array{CouplingLayerBasic}(undef, n)
for j=1:n
-CL[j] = CouplingLayerBasic(Int(n_in/2^j), n_hidden; k1=k1, k2=k2, p1=p1, p2=p2,
+CL[j] = CouplingLayerBasic(Int(n_in/2^j), n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2,
s1=s1, s2=s2, logdet=logdet, ndims=ndims)
end

1 change: 1 addition & 0 deletions src/networks/invertible_network_conditional_glow.jl
@@ -169,6 +169,7 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractA
end

if G.split_scales
+ΔC_total = G.squeezer.inverse(ΔC_total)
C = G.squeezer.inverse(C)
X = G.squeezer.inverse(X)
ΔX = G.squeezer.inverse(ΔX)
4 changes: 2 additions & 2 deletions src/networks/invertible_network_conditional_hint.jl
@@ -58,7 +58,7 @@ end
@Flux.functor NetworkConditionalHINT

# Constructor
-function NetworkConditionalHINT(n_in, n_hidden, depth; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2)
+function NetworkConditionalHINT(n_in, n_hidden, depth; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, activation::ActivationFunction=SigmoidLayer())

AN_X = Array{ActNorm}(undef, depth)
AN_Y = Array{ActNorm}(undef, depth)
@@ -68,7 +68,7 @@ function NetworkConditionalHINT(n_in, n_hidden, depth; k1=3, k2=3, p1=1, p2=1, s
for j=1:depth
AN_X[j] = ActNorm(n_in; logdet=logdet)
AN_Y[j] = ActNorm(n_in; logdet=logdet)
-CL[j] = ConditionalLayerHINT(n_in, n_hidden; permute=true, k1=k1, k2=k2, p1=p1, p2=p2,
+CL[j] = ConditionalLayerHINT(n_in, n_hidden; activation=activation, permute=true, k1=k1, k2=k2, p1=p1, p2=p2,
s1=s1, s2=s2, logdet=logdet, ndims=ndims)
end

45 changes: 24 additions & 21 deletions src/networks/invertible_network_conditional_hint_multiscale.jl
@@ -71,7 +71,7 @@ end

# Constructor
function NetworkMultiScaleConditionalHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
-split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, squeezer::Squeezer=ShuffleLayer())
+split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())

AN_X = Array{ActNorm}(undef, L, K)
AN_Y = Array{ActNorm}(undef, L, K)
@@ -89,7 +89,7 @@ function NetworkMultiScaleConditionalHINT(n_in::Int64, n_hidden::Int64, L::Int64
for j=1:K
AN_X[i, j] = ActNorm(n_in*4; logdet=logdet)
AN_Y[i, j] = ActNorm(n_in*4; logdet=logdet)
-CL[i, j] = ConditionalLayerHINT(n_in*4, n_hidden; permute=true, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
+CL[i, j] = ConditionalLayerHINT(n_in*4, n_hidden; permute=true, activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
end
n_in *= channel_factor
end
@@ -131,6 +131,28 @@ function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkMult
logdet ? (return X, Y, logdet_) : (return X, Y)
end

+# Forward pass and compute logdet
+function forward_Y(Y::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT; logdet=false) where {T, N}
+CH.split_scales && (Y_save = array_of_array(Y, CH.L-1))
+
+logdet_ = 0f0
+for i=1:CH.L
+Y = CH.squeezer.forward(Y)
+for j=1:CH.K
+logdet ? (Y_,logdet1) = CH.AN_Y[i, j].forward(Y; logdet=true) : Y_ = CH.AN_Y[i, j].forward(Y; logdet=false)
+logdet ? (Y_,logdet2) = CH.CL[i, j].forward_Y(Y_; logdet=true) : Y = CH.CL[i, j].forward_Y(Y_; logdet=false)
+logdet && (logdet_ += (logdet1 + logdet2))
+end
+if CH.split_scales && i < CH.L # don't split after last iteration
+Y, Zy = tensor_split(Y)
+Y_save[i] = Zy
+CH.XY_dims[i] = collect(size(Zy))
+end
+end
+CH.split_scales && (Y = cat_states(Y_save, Y))
+logdet ? (return Y, logdet_) : (return Y)
+end

# Inverse pass and compute gradients
function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT; logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet
@@ -234,26 +256,7 @@ function backward_inv(ΔX, ΔY, X, Y, CH::NetworkMultiScaleConditionalHINT)
end
end

-# Forward pass and compute logdet
-function forward_Y(Y::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT) where {T, N}
-CH.split_scales && (Y_save = array_of_array(Y, CH.L-1))
-
-for i=1:CH.L
-Y = CH.squeezer.forward(Y)
-for j=1:CH.K
-Y_ = CH.AN_Y[i, j].forward(Y; logdet=false)
-Y = CH.CL[i, j].forward_Y(Y_)
-end
-if CH.split_scales && i < CH.L # don't split after last iteration
-Y, Zy = tensor_split(Y)
-Y_save[i] = Zy
-CH.XY_dims[i] = collect(size(Zy))
-end
-end
-CH.split_scales && (Y = cat_states(Y_save, Y))
-return Y
-
-end

# Inverse pass and compute gradients
function inverse_Y(Zy::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT) where {T, N}
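
As a companion to the relocated, logdet-aware forward_Y above, a short hypothetical call sketch reusing the toy CH and Y from the sketch near the top of this page; it assumes the package's usual obj.forward-style dispatch also covers forward_Y, as the layer calls inside this diff suggest.

# Condition-only pass; logdet defaults to false and only the latent Zy is returned.
Zy = CH.forward_Y(Y)

# With the new keyword, the accumulated log-determinant is returned as well.
Zy, lgdet_y = CH.forward_Y(Y; logdet=true)
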
4 changes: 2 additions & 2 deletions src/networks/invertible_network_hint_multiscale.jl
@@ -67,7 +67,7 @@ end

# Constructor
function NetworkMultiScaleHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
-split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, ndims=2)
+split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, activation::ActivationFunction=SigmoidLayer(), ndims=2)

AN = Array{ActNorm}(undef, L, K)
CL = Array{CouplingLayerHINT}(undef, L, K)
@@ -83,7 +83,7 @@ function NetworkMultiScaleHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
for i=1:L
for j=1:K
AN[i, j] = ActNorm(n_in*4; logdet=true)
-CL[i, j] = CouplingLayerHINT(n_in*4, n_hidden; permute="full", k1=k1, k2=k2, p1=p1, p2=p2,
+CL[i, j] = CouplingLayerHINT(n_in*4, n_hidden; activation=activation, permute="full", k1=k1, k2=k2, p1=p1, p2=p2,
s1=s1, s2=s2, logdet=true, ndims=ndims)
end
n_in *= channel_factor
4 changes: 2 additions & 2 deletions src/networks/invertible_network_irim.jl
@@ -66,7 +66,7 @@ end
@Flux.functor NetworkLoop

# 2D Constructor
-function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2)
+function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2, activation::ActivationFunction=SigmoidLayer())

if type == "additive"
L = Array{CouplingLayerIRIM}(undef, maxiter)
@@ -77,7 +77,7 @@ function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4,
AN = Array{ActNorm}(undef, maxiter)
for j=1:maxiter
if type == "additive"
-L[j] = CouplingLayerIRIM(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims)
+L[j] = CouplingLayerIRIM(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, activation=activation, ndims=ndims)
elseif type == "HINT"
L[j] = CouplingLayerHINT(n_in, n_hidden; logdet=false, permute="both", k1=k1, k2=k2, p1=p1, p2=p2,
s1=s1, s2=s2, ndims=ndims)
3 changes: 2 additions & 1 deletion test/test_layers/test_actnorm.jl
@@ -2,8 +2,9 @@
# Date: January 2020

using InvertibleNetworks, LinearAlgebra, Test, Statistics
+using Random


+Random.seed!(11)
###############################################################################
# Test logdet implementation


1 comment on commit 5f571ec

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/73032

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v2.2.1 -m "<description of version>" 5f571ece49ce9b04bbda40115b91f8fe4332e2c5
git push origin v2.2.1
