diff --git a/Project.toml b/Project.toml
index d98fbf67..251a0981 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte
", "Ali Siahkoohi ", "Mathias Louboutin ", "Gabrio Rizzuti ", "Rafael Orozco ", "Felix J. herrmann "]
-version = "2.2.0"
+version = "2.2.1"
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/src/conditional_layers/conditional_layer_hint.jl b/src/conditional_layers/conditional_layer_hint.jl
index 364b7c59..cf8b75e1 100644
--- a/src/conditional_layers/conditional_layer_hint.jl
+++ b/src/conditional_layers/conditional_layer_hint.jl
@@ -65,12 +65,12 @@ end
# 2D Constructor from input dimensions
function ConditionalLayerHINT(n_in::Int64, n_hidden::Int64; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1,
- logdet=true, permute=true, ndims=2)
+ logdet=true, permute=true, ndims=2, activation::ActivationFunction=SigmoidLayer())
# Create basic coupling layers
- CL_X = CouplingLayerHINT(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
- CL_Y = CouplingLayerHINT(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
- CL_YX = CouplingLayerBasic(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
+ CL_X = CouplingLayerHINT(n_in, n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
+ CL_Y = CouplingLayerHINT(n_in, n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
+ CL_YX = CouplingLayerBasic(n_in, n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
# Permutation using 1x1 convolution
permute == true ? (C_X = Conv1x1(n_in)) : (C_X = nothing)
@@ -209,11 +209,10 @@ function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::Abs
return ΔZx, ΔZy, Zx, Zy
end
-function forward_Y(Y::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
+function forward_Y(Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=false) where {T, N}
~isnothing(CH.C_Y) ? (Yp = CH.C_Y.forward(Y)) : (Yp = copy(Y))
- Zy = CH.CL_Y.forward(Yp; logdet=false)
- return Zy
-
+ logdet ? (Zy, logdet_) = CH.CL_Y.forward(Yp; logdet=true) : Zy = CH.CL_Y.forward(Yp; logdet=false)
+ logdet ? (return Zy, logdet_) : (return Zy)
end
function inverse_Y(Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
@@ -222,7 +221,6 @@ function inverse_Y(Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T,
return Y
end
-
## Jacobian-related utils
function jacobian(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
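Note (illustration, not part of the patch): a minimal usage sketch of the new activation keyword on ConditionalLayerHINT and of the logdet option added to forward_Y. Array shapes and the random seed are illustrative; any ActivationFunction can be passed in place of the SigmoidLayer() default.

    using InvertibleNetworks, Random
    Random.seed!(1)

    n_in, n_hidden = 4, 16                                # n_in must be even for the HINT split
    CH = ConditionalLayerHINT(n_in, n_hidden; logdet=true, permute=true,
                              activation=SigmoidLayer())  # new keyword, same default as before

    Y = randn(Float32, 16, 16, n_in, 2)                   # (nx, ny, n_in, batch)
    Zy = CH.forward_Y(Y)                                  # previous behaviour (logdet=false)
    Zy, logdet_y = CH.forward_Y(Y; logdet=true)           # new: also returns the log-determinant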
diff --git a/src/layers/invertible_layer_hint.jl b/src/layers/invertible_layer_hint.jl
index fb56e0f8..ca21a45a 100644
--- a/src/layers/invertible_layer_hint.jl
+++ b/src/layers/invertible_layer_hint.jl
@@ -73,17 +73,17 @@ end
# Constructor for given coupling layer and 1 x 1 convolution
CouplingLayerHINT(CL::AbstractArray{CouplingLayerBasic, 1}, C::Union{Conv1x1, Nothing};
- logdet=false, permute="none") = CouplingLayerHINT(CL, C, logdet, permute, false)
+ logdet=false, permute="none", activation::ActivationFunction=SigmoidLayer()) = CouplingLayerHINT(CL, C, logdet, permute, false)
# 2D Constructor from input dimensions
function CouplingLayerHINT(n_in::Int64, n_hidden::Int64; logdet=false, permute="none",
- k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, ndims=2)
+ k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, ndims=2, activation::ActivationFunction=SigmoidLayer())
# Create basic coupling layers
n = get_depth(n_in)
CL = Array{CouplingLayerBasic}(undef, n)
for j=1:n
- CL[j] = CouplingLayerBasic(Int(n_in/2^j), n_hidden; k1=k1, k2=k2, p1=p1, p2=p2,
+ CL[j] = CouplingLayerBasic(Int(n_in/2^j), n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2,
s1=s1, s2=s2, logdet=logdet, ndims=ndims)
end
diff --git a/src/networks/invertible_network_conditional_glow.jl b/src/networks/invertible_network_conditional_glow.jl
index 073fdde3..253b329a 100644
--- a/src/networks/invertible_network_conditional_glow.jl
+++ b/src/networks/invertible_network_conditional_glow.jl
@@ -169,6 +169,7 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractA
end
if G.split_scales
+ ΔC_total = G.squeezer.inverse(ΔC_total)
C = G.squeezer.inverse(C)
X = G.squeezer.inverse(X)
ΔX = G.squeezer.inverse(ΔX)
diff --git a/src/networks/invertible_network_conditional_hint.jl b/src/networks/invertible_network_conditional_hint.jl
index ace3a875..9144e3a6 100644
--- a/src/networks/invertible_network_conditional_hint.jl
+++ b/src/networks/invertible_network_conditional_hint.jl
@@ -58,7 +58,7 @@ end
@Flux.functor NetworkConditionalHINT
# Constructor
-function NetworkConditionalHINT(n_in, n_hidden, depth; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2)
+function NetworkConditionalHINT(n_in, n_hidden, depth; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, activation::ActivationFunction=SigmoidLayer())
AN_X = Array{ActNorm}(undef, depth)
AN_Y = Array{ActNorm}(undef, depth)
@@ -68,7 +68,7 @@ function NetworkConditionalHINT(n_in, n_hidden, depth; k1=3, k2=3, p1=1, p2=1, s
for j=1:depth
AN_X[j] = ActNorm(n_in; logdet=logdet)
AN_Y[j] = ActNorm(n_in; logdet=logdet)
- CL[j] = ConditionalLayerHINT(n_in, n_hidden; permute=true, k1=k1, k2=k2, p1=p1, p2=p2,
+ CL[j] = ConditionalLayerHINT(n_in, n_hidden; activation=activation, permute=true, k1=k1, k2=k2, p1=p1, p2=p2,
s1=s1, s2=s2, logdet=logdet, ndims=ndims)
end
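Note (illustration, not part of the patch): constructing the conditional HINT network with a non-default activation. The three-output forward call follows the package's documented usage; shapes are illustrative.

    using InvertibleNetworks

    n_in, n_hidden, depth = 2, 32, 4
    CH = NetworkConditionalHINT(n_in, n_hidden, depth; activation=SigmoidLayer())  # new keyword

    X = randn(Float32, 16, 16, n_in, 2)
    Y = randn(Float32, 16, 16, n_in, 2)
    Zx, Zy, lgdet = CH.forward(X, Y)   # forward pass itself is unchanged by this patch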
diff --git a/src/networks/invertible_network_conditional_hint_multiscale.jl b/src/networks/invertible_network_conditional_hint_multiscale.jl
index f3134cc3..0e89bc6b 100644
--- a/src/networks/invertible_network_conditional_hint_multiscale.jl
+++ b/src/networks/invertible_network_conditional_hint_multiscale.jl
@@ -71,7 +71,7 @@ end
# Constructor
function NetworkMultiScaleConditionalHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
- split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, squeezer::Squeezer=ShuffleLayer())
+ split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
AN_X = Array{ActNorm}(undef, L, K)
AN_Y = Array{ActNorm}(undef, L, K)
@@ -89,7 +89,7 @@ function NetworkMultiScaleConditionalHINT(n_in::Int64, n_hidden::Int64, L::Int64
for j=1:K
AN_X[i, j] = ActNorm(n_in*4; logdet=logdet)
AN_Y[i, j] = ActNorm(n_in*4; logdet=logdet)
- CL[i, j] = ConditionalLayerHINT(n_in*4, n_hidden; permute=true, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
+ CL[i, j] = ConditionalLayerHINT(n_in*4, n_hidden; permute=true, activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
end
n_in *= channel_factor
end
@@ -131,6 +131,28 @@ function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkMult
logdet ? (return X, Y, logdet_) : (return X, Y)
end
+# Forward pass and compute logdet
+function forward_Y(Y::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT; logdet=false) where {T, N}
+ CH.split_scales && (Y_save = array_of_array(Y, CH.L-1))
+
+ logdet_ = 0f0
+ for i=1:CH.L
+ Y = CH.squeezer.forward(Y)
+ for j=1:CH.K
+ logdet ? (Y_, logdet1) = CH.AN_Y[i, j].forward(Y; logdet=true) : Y_ = CH.AN_Y[i, j].forward(Y; logdet=false)
+ logdet ? (Y, logdet2) = CH.CL[i, j].forward_Y(Y_; logdet=true) : Y = CH.CL[i, j].forward_Y(Y_; logdet=false)
+ logdet && (logdet_ += (logdet1 + logdet2))
+ end
+ if CH.split_scales && i < CH.L # don't split after last iteration
+ Y, Zy = tensor_split(Y)
+ Y_save[i] = Zy
+ CH.XY_dims[i] = collect(size(Zy))
+ end
+ end
+ CH.split_scales && (Y = cat_states(Y_save, Y))
+ logdet ? (return Y, logdet_) : (return Y)
+end
+
# Inverse pass and compute gradients
function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT; logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet
@@ -234,26 +256,7 @@ function backward_inv(ΔX, ΔY, X, Y, CH::NetworkMultiScaleConditionalHINT)
end
end
-# Forward pass and compute logdet
-function forward_Y(Y::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT) where {T, N}
- CH.split_scales && (Y_save = array_of_array(Y, CH.L-1))
-
- for i=1:CH.L
- Y = CH.squeezer.forward(Y)
- for j=1:CH.K
- Y_ = CH.AN_Y[i, j].forward(Y; logdet=false)
- Y = CH.CL[i, j].forward_Y(Y_)
- end
- if CH.split_scales && i < CH.L # don't split after last iteration
- Y, Zy = tensor_split(Y)
- Y_save[i] = Zy
- CH.XY_dims[i] = collect(size(Zy))
- end
- end
- CH.split_scales && (Y = cat_states(Y_save, Y))
- return Y
-end
# Inverse pass and compute gradients
function inverse_Y(Zy::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT) where {T, N}
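Note (illustration, not part of the patch): exercising the relocated forward_Y on the multiscale conditional network with the new logdet keyword. The dot-call syntax is assumed to work for networks as it does for layers; spatial dimensions are chosen divisible by 2^L.

    using InvertibleNetworks

    n_in, n_hidden, L, K = 2, 16, 2, 2
    CH = NetworkMultiScaleConditionalHINT(n_in, n_hidden, L, K;
                                          split_scales=true, activation=SigmoidLayer())

    Y = randn(Float32, 16, 16, n_in, 2)            # nx, ny divisible by 2^L
    Zy = CH.forward_Y(Y)                           # previous behaviour
    Zy, logdet_y = CH.forward_Y(Y; logdet=true)    # accumulated ActNorm + HINT logdet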
diff --git a/src/networks/invertible_network_hint_multiscale.jl b/src/networks/invertible_network_hint_multiscale.jl
index 8428ecac..e4a73910 100644
--- a/src/networks/invertible_network_hint_multiscale.jl
+++ b/src/networks/invertible_network_hint_multiscale.jl
@@ -67,7 +67,7 @@ end
# Constructor
function NetworkMultiScaleHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
- split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, ndims=2)
+ split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, activation::ActivationFunction=SigmoidLayer(), ndims=2)
AN = Array{ActNorm}(undef, L, K)
CL = Array{CouplingLayerHINT}(undef, L, K)
@@ -83,7 +83,7 @@ function NetworkMultiScaleHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
for i=1:L
for j=1:K
AN[i, j] = ActNorm(n_in*4; logdet=true)
- CL[i, j] = CouplingLayerHINT(n_in*4, n_hidden; permute="full", k1=k1, k2=k2, p1=p1, p2=p2,
+ CL[i, j] = CouplingLayerHINT(n_in*4, n_hidden; activation=activation, permute="full", k1=k1, k2=k2, p1=p1, p2=p2,
s1=s1, s2=s2, logdet=true, ndims=ndims)
end
n_in *= channel_factor
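Note (illustration, not part of the patch): the same activation keyword on the unconditional multiscale HINT network. The (Z, logdet) return of forward is assumed from the constructor's logdet=true layers; shapes are illustrative.

    using InvertibleNetworks

    n_in, n_hidden, L, K = 2, 16, 2, 2
    H = NetworkMultiScaleHINT(n_in, n_hidden, L, K; split_scales=true, activation=SigmoidLayer())

    X = randn(Float32, 16, 16, n_in, 2)
    Z, lgdet = H.forward(X)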
diff --git a/src/networks/invertible_network_irim.jl b/src/networks/invertible_network_irim.jl
index 34af85fc..fde4d0ce 100644
--- a/src/networks/invertible_network_irim.jl
+++ b/src/networks/invertible_network_irim.jl
@@ -66,7 +66,7 @@ end
@Flux.functor NetworkLoop
# 2D Constructor
-function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2)
+function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2, activation::ActivationFunction=SigmoidLayer())
if type == "additive"
L = Array{CouplingLayerIRIM}(undef, maxiter)
@@ -77,7 +77,7 @@ function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4,
AN = Array{ActNorm}(undef, maxiter)
for j=1:maxiter
if type == "additive"
- L[j] = CouplingLayerIRIM(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims)
+ L[j] = CouplingLayerIRIM(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, activation=activation, ndims=ndims)
elseif type == "HINT"
L[j] = CouplingLayerHINT(n_in, n_hidden; logdet=false, permute="both", k1=k1, k2=k2, p1=p1, p2=p2,
s1=s1, s2=s2, ndims=ndims)
diff --git a/test/test_layers/test_actnorm.jl b/test/test_layers/test_actnorm.jl
index 852549df..a5a3142c 100644
--- a/test/test_layers/test_actnorm.jl
+++ b/test/test_layers/test_actnorm.jl
@@ -2,8 +2,9 @@
# Date: January 2020
using InvertibleNetworks, LinearAlgebra, Test, Statistics
+using Random
-
+Random.seed!(11)
###############################################################################
# Test logdet implementation