diff --git a/CITATION.bib b/CITATION.bib
index 83daa6a2..39732c66 100644
--- a/CITATION.bib
+++ b/CITATION.bib
@@ -1,6 +1 @@
-@article{orozco2023invertiblenetworks,
- title={InvertibleNetworks. jl: A Julia package for scalable normalizing flows},
- author={Orozco, Rafael and Witte, Philipp and Louboutin, Mathias and Siahkoohi, Ali and Rizzuti, Gabrio and Peters, Bas and Herrmann, Felix J},
- journal={arXiv preprint arXiv:2312.13480},
- year={2023}
-}
\ No newline at end of file
+@article{Orozco2024,
+  doi = {10.21105/joss.06554},
+  url = {https://doi.org/10.21105/joss.06554},
+  year = {2024},
+  publisher = {The Open Journal},
+  volume = {9},
+  number = {99},
+  pages = {6554},
+  author = {Rafael Orozco and Philipp Witte and Mathias Louboutin and Ali Siahkoohi and Gabrio Rizzuti and Bas Peters and Felix J. Herrmann},
+  title = {InvertibleNetworks.jl: A Julia package for scalable normalizing flows},
+  journal = {Journal of Open Source Software}
+}
diff --git a/CITATION.cff b/CITATION.cff
new file mode 100644
index 00000000..bab91ad2
--- /dev/null
+++ b/CITATION.cff
@@ -0,0 +1,50 @@
+cff-version: "1.2.0"
+authors:
+- family-names: Orozco
+ given-names: Rafael
+- family-names: Witte
+ given-names: Philipp
+- family-names: Louboutin
+ given-names: Mathias
+- family-names: Siahkoohi
+ given-names: Ali
+- family-names: Rizzuti
+ given-names: Gabrio
+- family-names: Peters
+ given-names: Bas
+- family-names: Herrmann
+ given-names: Felix J.
+doi: 10.5281/zenodo.12810006
+message: If you use this software, please cite our article in the
+ Journal of Open Source Software.
+preferred-citation:
+ authors:
+ - family-names: Orozco
+ given-names: Rafael
+ - family-names: Witte
+ given-names: Philipp
+ - family-names: Louboutin
+ given-names: Mathias
+ - family-names: Siahkoohi
+ given-names: Ali
+ - family-names: Rizzuti
+ given-names: Gabrio
+ - family-names: Peters
+ given-names: Bas
+ - family-names: Herrmann
+ given-names: Felix J.
+ date-published: 2024-07-30
+ doi: 10.21105/joss.06554
+ issn: 2475-9066
+ issue: 99
+ journal: Journal of Open Source Software
+ publisher:
+ name: Open Journals
+ start: 6554
+ title: "InvertibleNetworks.jl: A Julia package for scalable
+ normalizing flows"
+ type: article
+ url: "https://joss.theoj.org/papers/10.21105/joss.06554"
+ volume: 9
+title: "InvertibleNetworks.jl: A Julia package for scalable normalizing
+ flows"
diff --git a/Project.toml b/Project.toml
index 270b7fcf..445fa874 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte
", "Ali Siahkoohi ", "Mathias Louboutin ", "Gabrio Rizzuti ", "Rafael Orozco ", "Felix J. Herrmann "]
-version = "2.2.8"
+version = "2.3.1"
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -15,6 +15,11 @@ Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[compat]
+CUDA = "1, 2, 3, 4, 5"
+cuDNN = "1"
+ChainRulesCore = "0.8, 0.9, 0.10, 1"
+Flux = "0.11, 0.12, 0.13, 0.14"
+NNlib = "0.7, 0.8, 0.9"
TimerOutputs = "0.5"
Wavelets = "0.9, 0.10"
diff --git a/README.md b/README.md
index f832a0ae..dcd72fe8 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,9 @@
# InvertibleNetworks.jl
-| **Documentation** | **Build Status** | |
+| **Documentation** | **Build Status** | **JOSS paper** |
|:-----------------:|:-----------------:|:----------------:|
-|[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://slimgroup.github.io/InvertibleNetworks.jl/stable/) [![](https://img.shields.io/badge/docs-dev-blue.svg)](https://slimgroup.github.io/InvertibleNetworks.jl/dev/)| [![CI](https://github.com/slimgroup/InvertibleNetworks.jl/actions/workflows/runtests.yml/badge.svg)](https://github.com/slimgroup/InvertibleNetworks.jl/actions/workflows/runtests.yml)| [![DOI](https://zenodo.org/badge/239018318.svg)](https://zenodo.org/badge/latestdoi/239018318)
+|[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://slimgroup.github.io/InvertibleNetworks.jl/stable/) [![](https://img.shields.io/badge/docs-dev-blue.svg)](https://slimgroup.github.io/InvertibleNetworks.jl/dev/)| [![CI](https://github.com/slimgroup/InvertibleNetworks.jl/actions/workflows/runtests.yml/badge.svg)](https://github.com/slimgroup/InvertibleNetworks.jl/actions/workflows/runtests.yml)| [![DOI](https://joss.theoj.org/papers/10.21105/joss.06554/status.svg)](https://doi.org/10.21105/joss.06554)
+
Building blocks for invertible neural networks in the [Julia] programming language.
@@ -26,7 +27,7 @@ InvertibleNetworks is registered and can be added like any standard Julia packag
## Uncertainty-aware image reconstruction
-Due to its memory scaling InvertibleNetworks.jl has been particularily successful at Bayesian posterior sampling with simulation-based inference. To get started with this application refer to a simple example ([Conditional sampling for MNSIT inpainting](https://github.com/slimgroup/InvertibleNetworks.jl/tree/master/examples/applications/application_conditional_mnist_inpainting.jl)) but feel free to modify this script for your application and please reach out to us if you run into any trouble.
+Due to its memory scaling, InvertibleNetworks.jl has been particularly successful at Bayesian posterior sampling with simulation-based inference. To get started with this application, refer to a simple example ([Conditional sampling for MNIST inpainting](https://github.com/slimgroup/InvertibleNetworks.jl/tree/master/examples/applications/conditional_sampling/amortized_glow_mnist_inpainting.jl)), but feel free to modify this script for your application, and please reach out to us for help.
![mnist_sampling_cond](docs/src/figures/mnist_sampling_cond.png)
@@ -92,12 +93,7 @@ Y_, logdet = AN.forward(X)
If you use InvertibleNetworks.jl in your research, we would be grateful if you cite us with the following bibtex:
```
-@article{orozco2023invertiblenetworks,
- title={InvertibleNetworks. jl: A Julia package for scalable normalizing flows},
- author={Orozco, Rafael and Witte, Philipp and Louboutin, Mathias and Siahkoohi, Ali and Rizzuti, Gabrio and Peters, Bas and Herrmann, Felix J},
- journal={arXiv preprint arXiv:2312.13480},
- year={2023}
-}
+@article{Orozco2024,
+  doi = {10.21105/joss.06554},
+  url = {https://doi.org/10.21105/joss.06554},
+  year = {2024},
+  publisher = {The Open Journal},
+  volume = {9},
+  number = {99},
+  pages = {6554},
+  author = {Rafael Orozco and Philipp Witte and Mathias Louboutin and Ali Siahkoohi and Gabrio Rizzuti and Bas Peters and Felix J. Herrmann},
+  title = {InvertibleNetworks.jl: A Julia package for scalable normalizing flows},
+  journal = {Journal of Open Source Software}
+}
```
diff --git a/examples/applications/application_conditional_hint_seismic.jl b/examples/applications/application_conditional_hint_seismic.jl
deleted file mode 100644
index d862f9d6..00000000
--- a/examples/applications/application_conditional_hint_seismic.jl
+++ /dev/null
@@ -1,162 +0,0 @@
-# Generative model using the change of variables formula
-# Author: Philipp Witte, pwitte3@gatech.edu
-# Date: January 2020
-
-using LinearAlgebra, InvertibleNetworks, PyPlot, Flux, Random, Test, JLD, Statistics
-import Flux.Optimise.update!
-
-# Random seed
-Random.seed!(66)
-
-
-####################################################################################################
-
-# Load original data X (size of n1 x n2 x nc x ntrain)
-datadir = dirname(pathof(InvertibleNetworks))*"/../data/"
-filename = "seismic_samples_64_by_64_num_10k.jld"
-~isfile("$(datadir)$(filename)") && run(`curl -L https://www.dropbox.com/s/mh5dv0yprestot4/seismic_samples_64_by_64_num_10k.jld\?dl\=0 --create-dirs -o $(datadir)$(filename)`)
-X_orig = load("$(datadir)$(filename)")["X"]
-n1, n2, nc, nsamples = size(X_orig)
-AN = ActNorm(nsamples)
-X_orig = AN.forward(X_orig) # zero mean and unit std
-
-# Split in training - testing
-ntrain = Int(nsamples*.9)
-ntest = nsamples - ntrain
-
-# Dimensions after wavelet squeeze to increase no. of channels
-nx = Int(n1/2)
-ny = Int(n2/2)
-n_in = Int(nc*4)
-
-# Apply wavelet squeeze (change dimensions to -> n1/2 x n2/2 x nc*4 x ntrain)
-X_train = zeros(Float32, nx, ny, n_in, ntrain)
-for j=1:ntrain
- X_train[:, :, :, j:j] = wavelet_squeeze(X_orig[:, :, :, j:j])
-end
-
-X_test = zeros(Float32, nx, ny, n_in, ntest)
-for j=1:ntest
- X_test[:, :, :, j:j] = wavelet_squeeze(X_orig[:, :, :, ntrain+j:ntrain+j])
-end
-
-# Create network
-n_hidden = 64
-batchsize = 4
-depth = 8
-CH = NetworkConditionalHINT(n_in, n_hidden, depth)
-Params = get_params(CH)
-
-####################################################################################################
-
-# Loss
-function loss(CH, X, Y)
- Zx, Zy, logdet = CH.forward(X, Y)
- f = -log_likelihood(tensor_cat(Zx, Zy)) - logdet
- ΔZ = -∇log_likelihood(tensor_cat(Zx, Zy))
- ΔZx, ΔZy = tensor_split(ΔZ)
- ΔX, ΔY = CH.backward(ΔZx, ΔZy, Zx, Zy)[1:2]
- return f, ΔX, ΔY
-end
-
-# Training
-maxiter = 1000
-opt = Flux.ADAM(1f-3)
-lr_step = 100
-lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.)
-fval = zeros(Float32, maxiter)
-
-for j=1:maxiter
-
- # Evaluate objective and gradients
- idx = randperm(ntrain)[1:batchsize]
- X = X_train[:, :, :, idx]
- Y = X + .5f0*randn(Float32, nx, ny, n_in, batchsize)
-
- fval[j] = loss(CH, X, Y)[1]
- mod(j, 10) == 0 && (print("Iteration: ", j, "; f = ", fval[j], "\n"))
-
- # Update params
- for p in Params
- update!(opt, p.data, p.grad)
- update!(lr_decay_fn, p.data, p.grad)
- end
- clear_grad!(CH)
-end
-
-####################################################################################################
-# Plotting
-
-# Testing
-test_size = 100
-idx = randperm(ntest)[1:test_size] # draw random samples from testing data
-X = X_test[:, :, :, idx]
-Y = X + .5f0*randn(Float32, nx, ny, n_in, test_size)
-Zx_, Zy_ = CH.forward(X, Y)[1:2]
-
-Zx = randn(Float32, nx, ny, n_in, test_size)
-Zy = randn(Float32, nx, ny, n_in, test_size)
-X_, Y_ = CH.inverse(Zx, Zy)
-
-# Now select single fixed sample from all Ys
-idx = 1
-X_fixed = X[:, :, :, idx:idx]
-Y_fixed = Y[:, :, :, idx:idx]
-Zy_fixed = CH.forward_Y(Y_fixed)
-
-# Draw new Zx, while keeping Zy fixed
-X_post = CH.inverse(Zx, Zy_fixed.*ones(Float32, nx, ny, n_in, test_size))[1]
-
-# Unsqueeze all tensors
-X = wavelet_unsqueeze(X)
-Y = wavelet_unsqueeze(Y)
-Zx_ = wavelet_unsqueeze(Zx_)
-Zy_ = wavelet_unsqueeze(Zy_)
-
-X_ = wavelet_unsqueeze(X_)
-Y_ = wavelet_unsqueeze(Y_)
-Zx = wavelet_unsqueeze(Zx)
-Zy = wavelet_unsqueeze(Zy)
-
-X_fixed = wavelet_unsqueeze(X_fixed)
-Y_fixed = wavelet_unsqueeze(Y_fixed)
-Zy_fixed = wavelet_unsqueeze(Zy_fixed)
-X_post = wavelet_unsqueeze(X_post)
-
-# Plot one sample from X and Y and their latent versions
-figure(figsize=[16,8])
-ax1 = subplot(2,4,1); imshow(X[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax2 = subplot(2,4,2); imshow(Y[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y=x+n$ ")
-ax3 = subplot(2,4,3); imshow(X_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax4 = subplot(2,4,4); imshow(Y_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Data space: $y = f(zx|zy)^{-1}$")
-ax5 = subplot(2,4,5); imshow(Zx_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zx = f(x|y)$")
-ax6 = subplot(2,4,6); imshow(Zy_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zy = f(x|y)$")
-ax7 = subplot(2,4,7); imshow(Zx[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zx \sim \hat{p}_{zx}$")
-ax8 = subplot(2,4,8); imshow(Zy[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zy \sim \hat{p}_{zy}$")
-
-# Plot various samples from X and Y
-figure(figsize=[16,8])
-i = randperm(test_size)[1:4]
-ax1 = subplot(2,4,1); imshow(X_[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax2 = subplot(2,4,2); imshow(X_[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax3 = subplot(2,4,3); imshow(X_[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax4 = subplot(2,4,4); imshow(X_[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax5 = subplot(2,4,5); imshow(X[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax6 = subplot(2,4,6); imshow(X[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax7 = subplot(2,4,7); imshow(X[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax8 = subplot(2,4,8); imshow(X[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-
-# Plot posterior samples, mean and standard deviation
-figure(figsize=[16,8])
-X_post_mean = mean(X_post; dims=4)
-X_post_std = std(X_post; dims=4)
-ax1 = subplot(2,4,1); imshow(X_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title("True x")
-ax2 = subplot(2,4,2); imshow(X_post[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax3 = subplot(2,4,3); imshow(X_post[:, :, 1, 2], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax4 = subplot(2,4,4); imshow(X_post_mean[:, :, 1, 1], cmap="gray", aspect="auto"); title("Posterior mean")
-ax5 = subplot(2,4,5); imshow(Y_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y_i=x_i+n$ ")
-ax6 = subplot(2,4,6); imshow(X_post[:, :, 1, 4], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax7 = subplot(2,4,7); imshow(X_post[:, :, 1, 5], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax8 = subplot(2,4,8); imshow(X_post_std[:, :, 1,1], cmap="binary", aspect="auto", vmin=0, vmax=0.9*maximum(X_post_std));
-colorbar(); title("Posterior std");
-
diff --git a/examples/applications/application_conditional_hint_seismic_linear.jl b/examples/applications/application_conditional_hint_seismic_linear.jl
deleted file mode 100644
index 44d107e4..00000000
--- a/examples/applications/application_conditional_hint_seismic_linear.jl
+++ /dev/null
@@ -1,192 +0,0 @@
-# Generative model using the change of variables formula
-# Author: Philipp Witte, pwitte3@gatech.edu
-# Date: January 2020
-
-using LinearAlgebra, InvertibleNetworks, JOLI, PyPlot, Flux, Random, Test, JLD, Statistics
-import Flux.Optimise.update!
-using Distributions: Uniform
-
-# Random seed
-Random.seed!(66)
-
-####################################################################################################
-
-# Load original data X (size of n1 x n2 x nc x ntrain)
-datadir = dirname(pathof(InvertibleNetworks))*"/../data/"
-filename = "seismic_samples_64_by_64_num_10k.jld"
-~isfile("$(datadir)$(filename)") && run(`curl -L https://www.dropbox.com/s/mh5dv0yprestot4/seismic_samples_64_by_64_num_10k.jld\?dl\=0 --create-dirs -o $(datadir)$(filename)`)
-X_orig = load("$(datadir)$(filename)")["X"]
-n1, n2, nc, nsamples = size(X_orig)
-AN = ActNorm(nsamples)
-X_orig = AN.forward(X_orig) # zero mean and unit std
-
-# Split in training - testing
-ntrain = Int(nsamples*.9)
-ntest = nsamples - ntrain
-
-# Dimensions after wavelet squeeze to increase no. of channels
-nx = Int(n1/2)
-ny = Int(n2/2)
-n_in = Int(nc*4)
-
-# Apply wavelet squeeze (change dimensions to -> n1/2 x n2/2 x nc*4 x ntrain)
-X_train = zeros(Float32, nx, ny, n_in, ntrain)
-for j=1:ntrain
- X_train[:, :, :, j:j] = wavelet_squeeze(X_orig[:, :, :, j:j])
-end
-
-X_test = zeros(Float32, nx, ny, n_in, ntest)
-for j=1:ntest
- X_test[:, :, :, j:j] = wavelet_squeeze(X_orig[:, :, :, ntrain+j:ntrain+j])
-end
-
-# Create network
-n_hidden = 32
-batchsize = 4
-depth = 8
-CH = NetworkConditionalHINT(n_in, n_hidden, depth)
-Params = get_params(CH)
-
-# Data modeling function
-function model_data(X, A)
- X = wavelet_unsqueeze(X)
- nx, ny, nc, nb = size(X)
- Y = reshape(A*reshape(X, :, nb), nx, ny, nc, nb)
- Y = wavelet_squeeze(Y)
- return Y
-end
-
-# Forward operator, precompute phase in ambient dimension
-function phasecode(n)
- F = joDFT(n; DDT=Float32,RDT=ComplexF64)
- phase=F*(adjoint(F)*exp.(1im*2*pi*convert(Array{Float32},rand(dist,n))))
- phase = phase ./abs.(phase)
- sgn = sign.(convert(Array{Float32},randn(n)))
- # Return operator
- return M = joDiag(sgn) * adjoint(F) * joDiag(phase)*F
-end
-
-dist = Uniform(-1, 1)
-input_dim = (n1, n2)
-subsamp = 2
-M = phasecode(prod(input_dim))
-R = joRestriction(prod(input_dim),randperm(prod(input_dim))[1:Int(round(prod(input_dim)/subsamp))]; DDT=Float32, RDT=Float32);
-A_flat = R*M;
-A = A_flat'*A_flat
-
-Y_train = model_data(X_train, A)
-Y_test = model_data(X_test, A)
-
-####################################################################################################
-
-# Loss
-function loss(CH, X, Y)
- Zx, Zy, logdet = CH.forward(X, Y)
- f = -log_likelihood(tensor_cat(Zx, Zy)) - logdet
- ΔZ = -∇log_likelihood(tensor_cat(Zx, Zy))
- ΔZx, ΔZy = tensor_split(ΔZ)
- ΔX, ΔY = CH.backward(ΔZx, ΔZy, Zx, Zy)[1:2]
- return f, ΔX, ΔY
-end
-
-# Training
-maxiter = 1000
-opt = Flux.ADAM(1f-3)
-lr_step = 100
-lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.)
-fval = zeros(Float32, maxiter)
-
-for j=1:maxiter
-
- # Evaluate objective and gradients
- idx = randperm(ntrain)[1:batchsize]
- X = X_train[:, :, :, idx]
- Y = Y_train[:, :, :, idx]
-
- fval[j] = loss(CH, X, Y)[1]
- mod(j, 10) == 0 && (print("Iteration: ", j, "; f = ", fval[j], "\n"))
-
- # Update params
- for p in Params
- update!(opt, p.data, p.grad)
- update!(lr_decay_fn, p.data, p.grad)
- end
- clear_grad!(CH)
-end
-
-####################################################################################################
-# Plotting
-
-# Testing
-test_size = 100
-idx = randperm(ntest)[1:test_size] # draw random samples from testing data
-X = X_test[:, :, :, idx]
-Y = Y_test[:, :, :, idx]
-Zx_, Zy_ = CH.forward(X, Y)[1:2]
-
-Zx = randn(Float32, nx, ny, n_in, test_size)
-Zy = randn(Float32, nx, ny, n_in, test_size)
-X_, Y_ = CH.inverse(Zx, Zy)
-
-# Now select single fixed sample from all Ys
-idx = 1
-X_fixed = X[:, :, :, idx:idx]
-Y_fixed = Y[:, :, :, idx:idx]
-Zy_fixed = CH.forward_Y(Y_fixed)
-
-# Draw new Zx, while keeping Zy fixed
-X_post = CH.inverse(Zx, Zy_fixed.*ones(Float32, nx, ny, n_in, test_size))[1]
-
-# Unsqueeze all tensors
-X = wavelet_unsqueeze(X)
-Y = wavelet_unsqueeze(Y)
-Zx_ = wavelet_unsqueeze(Zx_)
-Zy_ = wavelet_unsqueeze(Zy_)
-
-X_ = wavelet_unsqueeze(X_)
-Y_ = wavelet_unsqueeze(Y_)
-Zx = wavelet_unsqueeze(Zx)
-Zy = wavelet_unsqueeze(Zy)
-
-X_fixed = wavelet_unsqueeze(X_fixed)
-Y_fixed = wavelet_unsqueeze(Y_fixed)
-Zy_fixed = wavelet_unsqueeze(Zy_fixed)
-X_post = wavelet_unsqueeze(X_post)
-
-# Plot one sample from X and Y and their latent versions
-figure(figsize=[16,8])
-ax1 = subplot(2,4,1); imshow(X[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax2 = subplot(2,4,2); imshow(Y[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y=x+n$ ")
-ax3 = subplot(2,4,3); imshow(X_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax4 = subplot(2,4,4); imshow(Y_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Data space: $y = f(zx|zy)^{-1}$")
-ax5 = subplot(2,4,5); imshow(Zx_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zx = f(x|y)$")
-ax6 = subplot(2,4,6); imshow(Zy_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zy = f(x|y)$")
-ax7 = subplot(2,4,7); imshow(Zx[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zx \sim \hat{p}_{zx}$")
-ax8 = subplot(2,4,8); imshow(Zy[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zy \sim \hat{p}_{zy}$")
-
-# Plot various samples from X and Y
-figure(figsize=[16,8])
-i = randperm(test_size)[1:4]
-ax1 = subplot(2,4,1); imshow(X_[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax2 = subplot(2,4,2); imshow(X_[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax3 = subplot(2,4,3); imshow(X_[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax4 = subplot(2,4,4); imshow(X_[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax5 = subplot(2,4,5); imshow(X[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax6 = subplot(2,4,6); imshow(X[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax7 = subplot(2,4,7); imshow(X[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax8 = subplot(2,4,8); imshow(X[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-
-# Plot posterior samples, mean and standard deviation
-figure(figsize=[16,8])
-X_post_mean = mean(X_post; dims=4)
-X_post_std = std(X_post; dims=4)
-ax1 = subplot(2,4,1); imshow(X_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title("True x")
-ax2 = subplot(2,4,2); imshow(X_post[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax3 = subplot(2,4,3); imshow(X_post[:, :, 1, 2], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax4 = subplot(2,4,4); imshow(X_post_mean[:, :, 1, 1], cmap="gray", aspect="auto"); title("Posterior mean")
-ax5 = subplot(2,4,5); imshow(Y_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y_i=x_i+n$ ")
-ax6 = subplot(2,4,6); imshow(X_post[:, :, 1, 4], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax7 = subplot(2,4,7); imshow(X_post[:, :, 1, 5], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax8 = subplot(2,4,8); imshow(X_post_std[:, :, 1,1], cmap="binary", aspect="auto", vmin=0, vmax=0.9*maximum(X_post_std));
-colorbar(); title("Posterior std");
-
diff --git a/examples/applications/application_multiscale_conditional_hint_seismic_linear.jl b/examples/applications/application_multiscale_conditional_hint_seismic_linear.jl
deleted file mode 100644
index 7aa80a8e..00000000
--- a/examples/applications/application_multiscale_conditional_hint_seismic_linear.jl
+++ /dev/null
@@ -1,177 +0,0 @@
-# Generative model using the change of variables formula
-# Author: Philipp Witte, pwitte3@gatech.edu
-# Date: January 2020
-
-using LinearAlgebra, InvertibleNetworks, JOLI, PyPlot, Flux, Random, Test, JLD, Statistics
-import Flux.Optimise.update!
-using Distributions: Uniform
-
-# Random seed
-Random.seed!(66)
-
-####################################################################################################
-
-# Load original data X (size of n1 x n2 x nc x ntrain)
-#X_orig = load("../../data/seismic_samples_32_by_32_num_10k.jld")["X"]
-
-
-# Load original data X (size of n1 x n2 x nc x ntrain)
-datadir = dirname(pathof(InvertibleNetworks))*"/../data/"
-filename = "seismic_samples_64_by_64_num_10k.jld"
-~isfile("$(datadir)$(filename)") && run(`curl -L https://www.dropbox.com/s/mh5dv0yprestot4/seismic_samples_32_by_32_num_10k.jld\?dl\=0 --create-dirs -o $(datadir)$(filename)`)
-X_orig = load("$(datadir)$(filename)")["X"]
-n1, n2, nc, n_samples = size(X_orig)
-
-AN = ActNorm(n_samples)
-X_orig = AN.forward(X_orig) # zero mean and unit std
-
-# Split in training - testing
-ntrain = Int(n_samples*.9)
-ntest = n_samples - ntrain
-
-# Dimensions after wavelet squeeze to increase no. of channels
-nx = Int(n1)
-ny = Int(n2)
-n_in = Int(nc)
-
-# Apply wavelet squeeze (change dimensions to -> n1/2 x n2/2 x nc*4 x ntrain)
-X_train = zeros(Float32, nx, ny, n_in, ntrain)
-for j=1:ntrain
- X_train[:, :, :, j:j] = X_orig[:, :, :, j:j]
-end
-
-X_test = zeros(Float32, nx, ny, n_in, ntest)
-for j=1:ntest
- X_test[:, :, :, j:j] = X_orig[:, :, :, ntrain+j:ntrain+j]
-end
-
-# Create network
-n_hidden = 32
-batchsize = 4
-L = 1
-K = 8
-CH = NetworkMultiScaleConditionalHINT(n_in, n_hidden, L, K; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, split_scales=false)
-Params = get_params(CH)
-
-# Data modeling function
-function model_data(X, A)
- nx, ny, nc, nb = size(X)
- Y = reshape(A*reshape(X, :, nb), nx, ny, nc, nb)
- return Y
-end
-
-# Forward operator, precompute phase in ambient dimension
-function phasecode(n)
- F = joDFT(n; DDT=Float32,RDT=ComplexF64)
- phase=F*(adjoint(F)*exp.(1im*2*pi*convert(Array{Float32},rand(dist,n))))
- phase = phase ./abs.(phase)
- sgn = sign.(convert(Array{Float32},randn(n)))
- # Return operator
- return M = joDiag(sgn) * adjoint(F) * joDiag(phase)*F
-end
-
-# Generate observed data
-dist = Uniform(-1, 1)
-input_dim = (n1, n2)
-subsamp = 2
-M = phasecode(prod(input_dim))
-R = joRestriction(prod(input_dim),randperm(prod(input_dim))[1:Int(round(prod(input_dim)/subsamp))]; DDT=Float32, RDT=Float32);
-A_flat = R*M;
-A = A_flat'*A_flat
-
-Y_train = model_data(X_train, A)
-Y_test = model_data(X_test, A)
-
-####################################################################################################
-
-# Loss
-function loss(CH, X, Y)
- Zx, Zy, logdet = CH.forward(X, Y)
- f = -log_likelihood(tensor_cat(Zx, Zy)) - logdet
- ΔZ = -∇log_likelihood(tensor_cat(Zx, Zy))
- ΔZx, ΔZy = tensor_split(ΔZ)
- ΔX, ΔY = CH.backward(ΔZx, ΔZy, Zx, Zy)[1:2]
- return f, ΔX, ΔY
-end
-
-# Training
-maxiter = 500
-opt = Flux.ADAM(1f-3)
-lr_step = 100
-lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.)
-fval = zeros(Float32, maxiter)
-
-for j=1:maxiter
-
- # Evaluate objective and gradients
- idx = randperm(ntrain)[1:batchsize]
- X = X_train[:, :, :, idx]
- Y = Y_train[:, :, :, idx]
-
- fval[j] = loss(CH, X, Y)[1]
- mod(j, 10) == 0 && (print("Iteration: ", j, "; f = ", fval[j], "\n"))
-
- # Update params
- for p in Params
- update!(opt, p.data, p.grad)
- update!(lr_decay_fn, p.data, p.grad)
- end
- clear_grad!(CH)
-end
-
-####################################################################################################
-# Plotting
-
-# Testing
-test_size = 100
-idx = randperm(ntest)[1:test_size] # draw random samples from testing data
-X = X_test[:, :, :, idx]
-Y = Y_test[:, :, :, idx]
-Zx_, Zy_ = CH.forward(X, Y)[1:2]
-
-Zx = randn(Float32, size(Zx_))
-Zy = randn(Float32, size(Zy_))
-X_, Y_ = CH.inverse(Zx, Zy)
-
-# Now select single fixed sample from all Ys
-idx = 1
-X_fixed = X[:, :, :, idx:idx]
-Y_fixed = Y[:, :, :, idx:idx]
-Zy_fixed = CH.forward_Y(Y_fixed)
-
-# Draw new Zx, while keeping Zy fixed
-CH.forward_Y(X) # set X dimensions in forward pass (this needs to be fixed)
-X_post = CH.inverse(Zx, Zy_fixed.*ones(Float32, size(Zx_)))[1]
-
-# Plot one sample from X and Y and their latent versions
-figure(figsize=[8,8])
-ax1 = subplot(2,2,1); imshow(X[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax2 = subplot(2,2,2); imshow(Y[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y=x+n$ ")
-ax3 = subplot(2,2,3); imshow(X_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax4 = subplot(2,2,4); imshow(Y_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Data space: $y = f(zx|zy)^{-1}$")
-
-# Plot various samples from X and Y
-figure(figsize=[16,8])
-i = randperm(test_size)[1:4]
-ax1 = subplot(2,4,1); imshow(X_[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax2 = subplot(2,4,2); imshow(X_[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax3 = subplot(2,4,3); imshow(X_[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax4 = subplot(2,4,4); imshow(X_[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$")
-ax5 = subplot(2,4,5); imshow(X[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax6 = subplot(2,4,6); imshow(X[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax7 = subplot(2,4,7); imshow(X[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-ax8 = subplot(2,4,8); imshow(X[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$")
-
-# Plot posterior samples, mean and standard deviation
-figure(figsize=[16,8])
-X_post_mean = mean(X_post; dims=4)
-X_post_std = std(X_post; dims=4)
-ax1 = subplot(2,4,1); imshow(X_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title("True x")
-ax2 = subplot(2,4,2); imshow(X_post[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax3 = subplot(2,4,3); imshow(X_post[:, :, 1, 2], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax4 = subplot(2,4,4); imshow(X_post_mean[:, :, 1, 1], cmap="gray", aspect="auto"); title("Posterior mean")
-ax5 = subplot(2,4,5); imshow(Y_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y_i=x_i+n$ ")
-ax6 = subplot(2,4,6); imshow(X_post[:, :, 1, 4], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax7 = subplot(2,4,7); imshow(X_post[:, :, 1, 5], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$")
-ax8 = subplot(2,4,8); imshow(X_post_std[:, :, 1,1], cmap="binary", aspect="auto", vmin=0, vmax=0.9*maximum(X_post_std));
-colorbar(); title("Posterior std");
\ No newline at end of file
diff --git a/examples/applications/amortized_posterior_inference_1d_gaussian.jl b/examples/applications/conditional_sampling/amortized_glow_1d_gaussian.jl
similarity index 100%
rename from examples/applications/amortized_posterior_inference_1d_gaussian.jl
rename to examples/applications/conditional_sampling/amortized_glow_1d_gaussian.jl
diff --git a/examples/applications/application_conditional_mnist_inpainting.jl b/examples/applications/conditional_sampling/amortized_glow_mnist_inpainting.jl
similarity index 67%
rename from examples/applications/application_conditional_mnist_inpainting.jl
rename to examples/applications/conditional_sampling/amortized_glow_mnist_inpainting.jl
index 10b5167d..40ecf1a9 100644
--- a/examples/applications/application_conditional_mnist_inpainting.jl
+++ b/examples/applications/conditional_sampling/amortized_glow_mnist_inpainting.jl
@@ -1,16 +1,20 @@
+# Takes around 6 minutes on CPU
using Pkg
-Pkg.activate(".")
-# Take around 6 minutes on CPU
+Pkg.add("InvertibleNetworks")
+Pkg.add("Flux")
+Pkg.add("MLDatasets")
+Pkg.add("ProgressMeter")
+Pkg.add("MLUtils")
+Pkg.add("ImageTransformations")
+
using InvertibleNetworks
using Flux
using LinearAlgebra
using MLDatasets
-using Statistics
-using PyPlot
-using ProgressMeter: Progress, next!
-using Images
+using ImageTransformations
using MLUtils
+using ProgressMeter: Progress, next!
function posterior_sampler(G, y, size_x; device=gpu, num_samples=1, batch_size=16)
# make samples from posterior for train sample
@@ -111,52 +115,35 @@ for e=1:epochs # epoch loop
append!(loss_val, norm(ZX)^2 / (N*n_val) - logdet_i / N) # normalize by image size and batch size
end
-# Training logs
-final_obj_train = round(loss_train[end];digits=3)
-final_obj_val = round(loss_val[end];digits=3)
-fig = figure()
-title("Objective value: train=$(final_obj_train) validation=$(final_obj_val)")
+Pkg.add("Statistics")
+Pkg.add("Plots")
+using Statistics
+using Plots
+
+# Training logs
plot(loss_train;label="Train");
plot(batches:batches:batches*(epochs), loss_val;label="Validation");
-xlabel("Parameter update"); ylabel("Negative log likelihood objective") ;
-legend()
-savefig("log.png",dpi=300)
-
-# Make Figure of README
-num_plot = 2
-fig = figure(figsize=(11, 5));
-for (i,ind) in enumerate([1,3])
- x = XY_val[1][:,:,:,ind:ind]
- y = XY_val[2][:,:,:,ind:ind]
- X_post = posterior_sampler(G, y, size(x); device=device, num_samples=64) |> cpu
-
- X_post_mean = mean(X_post; dims=ndims(X_post))
- X_post_var = var(X_post; dims=ndims(X_post))
+xlabel!("Parameter update"); ylabel!("Negative log likelihood objective") ;
+savefig("log.png")
- ssim_val = round(assess_ssim(X_post_mean[:,:,1,1], x[:,:,1,1]) ,digits=2)
+# Make a figure like the one in the README
+ind = 1
+x = XY_val[1][:,:,:,ind:ind]
+y = XY_val[2][:,:,:,ind:ind]
+X_post = posterior_sampler(G, y, size(x); device=device, num_samples=64) |> cpu
- subplot(num_plot,7,1+7*(i-1)); imshow(x[:,:,1,1], vmin=0, vmax=1, cmap="gray")
- axis("off"); title(L"$x$");
+X_post_mean = mean(X_post; dims=ndims(X_post))
+X_post_var = var(X_post; dims=ndims(X_post))
- subplot(num_plot,7,2+7*(i-1)); imshow(y[:,:,1,1] |> cpu, cmap="gray")
- axis("off"); title(L"$y$");
+p1 = heatmap(x[:,:,1,1];title="Ground truth",cbar=false,clims = (0, 1),axis=([], false))
+p2 = heatmap(y[:,:,1,1];title="Observation",cbar=false,clims = (0, 1),axis=([], false))
+p3 = heatmap(X_post_mean[:,:,1,1];title="Posterior mean",cbar=false,clims = (0, 1),axis=([], false))
+p4 = heatmap(X_post_var[:,:,1,1];title="Posterior variance",cbar=false,axis=([], false))
+p5 = heatmap(X_post[:,:,1,1];title="Posterior sample",cbar=false,clims = (0, 1),axis=([], false))
+p6 = heatmap(X_post[:,:,1,2];title="Posterior sample",cbar=false,clims = (0, 1),axis=([], false))
- subplot(num_plot,7,3+7*(i-1)); imshow(X_post_mean[:,:,1,1] , vmin=0, vmax=1, cmap="gray")
- axis("off"); title("SSIM="*string(ssim_val)*" \n"*"Conditional Mean") ;
+plot(p1,p2,p3,p4,p5,p6,aspect_ratio=:equal, size=(600,400))
+savefig("posterior_sampling.png")
- subplot(num_plot,7,4+7*(i-1)); imshow(X_post_var[:,:,1,1] , cmap="magma")
- axis("off"); title(L"$UQ$") ;
-
- subplot(num_plot,7,5+7*(i-1)); imshow(X_post[:,:,1,1] |> cpu, vmin=0, vmax=1,cmap="gray")
- axis("off"); title("Posterior Sample") ;
-
- subplot(num_plot,7,6+7*(i-1)); imshow(X_post[:,:,1,2] |> cpu, vmin=0, vmax=1,cmap="gray")
- axis("off"); title("Posterior Sample") ;
-
- subplot(num_plot,7,7+7*(i-1)); imshow(X_post[:,:,1,3] |> cpu, vmin=0, vmax=1, cmap="gray")
- axis("off"); title("Posterior Sample") ;
-end
-tight_layout()
-savefig("mnist_sampling_cond.png",dpi=300,bbox_inches="tight")
diff --git a/examples/applications/application_conditional_hint_banana.jl b/examples/applications/conditional_sampling/amortized_hint_banana_denoise.jl
similarity index 85%
rename from examples/applications/application_conditional_hint_banana.jl
rename to examples/applications/conditional_sampling/amortized_hint_banana_denoise.jl
index 3617f2c8..c3b8949b 100644
--- a/examples/applications/application_conditional_hint_banana.jl
+++ b/examples/applications/conditional_sampling/amortized_hint_banana_denoise.jl
@@ -1,6 +1,12 @@
-# Generative model using the change of variables formula
+# Method: Amortized posterior sampler / simulation-based inference / forward KL variational inference
+# Application: sample from the conditional distribution given noisy observations of the Rosenbrock (banana) distribution.
+# Note: we currently recommend conditional glow architectures instead of HINT, unless you need the latent space of
+# the observation.
+
# Author: Philipp Witte, pwitte3@gatech.edu
# Date: January 2020
+using Pkg
+Pkg.add("InvertibleNetworks"); Pkg.add("Flux"); Pkg.add("PyPlot")
using LinearAlgebra, InvertibleNetworks, PyPlot, Flux, Random, Test
import Flux.Optimise.update!
@@ -35,9 +41,7 @@ end
# Training
maxiter = 1000
-opt = Flux.ADAM(1f-3)
-lr_step = 100
-lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.)
+opt = Flux.Optimiser(Flux.ExpDecay(1f-3, .9, 100, 0.), Flux.ADAM(1f-3))
fval = zeros(Float32, maxiter)
for j=1:maxiter
@@ -51,7 +55,6 @@ for j=1:maxiter
# Update params
for p in get_params(H)
update!(opt, p.data, p.grad)
- update!(lr_decay_fn, p.data, p.grad)
end
clear_grad!(H)
end
@@ -96,5 +99,4 @@ plot(Y_fixed[1, 1, 1, :], Y_fixed[1, 1, 2, :], "r."); title(L"Model space: $x =
ax7.set_xlim([-3.5,3.5]); ax7.set_ylim([0,50])
ax8 = subplot(2,4,8); plot(Zx[1, 1, 1, :], Zx[1, 1, 2, :], ".");
plot(Zy_fixed[1, 1, 1, :], Zy_fixed[1, 1, 2, :], "r."); title(L"Latent space: $zx \sim \hat{p}_{zx}$")
-ax8.set_xlim([-3.5, 3.5]); ax8.set_ylim([-3.5, 3.5])
-
+ax8.set_xlim([-3.5, 3.5]); ax8.set_ylim([-3.5, 3.5])
\ No newline at end of file
diff --git a/examples/applications/application_conditional_hint_banana_linear.jl b/examples/applications/conditional_sampling/amortized_hint_banana_linear.jl
similarity index 87%
rename from examples/applications/application_conditional_hint_banana_linear.jl
rename to examples/applications/conditional_sampling/amortized_hint_banana_linear.jl
index 51a04a73..2ae6af7a 100644
--- a/examples/applications/application_conditional_hint_banana_linear.jl
+++ b/examples/applications/conditional_sampling/amortized_hint_banana_linear.jl
@@ -1,6 +1,14 @@
-# Generative model using the change of variables formula
+# Method: Amortized posterior sampler / simulation-based inference / forward KL variational inference
+# Application: sample from the conditional distribution given observations of the Rosenbrock (banana) distribution passed through
+# a linear operator.
+# Note: we currently recommend conditional glow architectures instead of HINT, unless you need the latent space of
+# the observation.
+
# Author: Philipp Witte, pwitte3@gatech.edu
# Date: January 2020
+using Pkg
+
+Pkg.add("InvertibleNetworks"); Pkg.add("Flux"); Pkg.add("PyPlot")
using LinearAlgebra, InvertibleNetworks, PyPlot, Flux, Random, Test
import Flux.Optimise.update!
@@ -39,13 +47,10 @@ end
# Training
maxiter = 1000
-opt = Flux.ADAM(1f-3)
-lr_step = 100
-lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.)
+opt = Flux.Optimiser(Flux.ExpDecay(1f-3, .9, 100, 0.), Flux.ADAM(1f-3))
fval = zeros(Float32, maxiter)
for j=1:maxiter
-
# Evaluate objective and gradients
X = sample_banana(batchsize)
Y = reshape(A*reshape(X, :, batchsize), nx, ny, n_in, batchsize)
@@ -57,7 +62,6 @@ for j=1:maxiter
# Update params
for p in get_params(H)
update!(opt, p.data, p.grad)
- update!(lr_decay_fn, p.data, p.grad)
end
clear_grad!(H)
end
@@ -114,7 +118,3 @@ plot(Zy_fixed[1, 1, 1, :], Zy_fixed[1, 1, 2, :], "r."); title(L"Latent space: $z
ax10.set_xlim([-3.5, 3.5]); ax10.set_ylim([-3.5, 3.5])
-
-
-
-
diff --git a/examples/applications/application_hint_scenario_2.jl b/examples/applications/conditional_sampling/non_amortized_hint_linear_Gaussian.jl
similarity index 96%
rename from examples/applications/application_hint_scenario_2.jl
rename to examples/applications/conditional_sampling/non_amortized_hint_linear_Gaussian.jl
index fdd9a3c9..f06d4cf5 100644
--- a/examples/applications/application_hint_scenario_2.jl
+++ b/examples/applications/conditional_sampling/non_amortized_hint_linear_Gaussian.jl
@@ -2,12 +2,17 @@
# Obtaining samples from posterior for the following problem:
# y = Ax + ϵ, x ~ N(μ_x, Σ_x), ϵ ~ N(μ_ϵ, Σ_ϵ), A ~ N(0, I/|x|)
+using Pkg
+Pkg.add("InvertibleNetworks"); Pkg.add("Flux"); Pkg.add("PyPlot");
+Pkg.add("Distributions"); Pkg.add("Printf")
+
using InvertibleNetworks, LinearAlgebra, Test
using Distributions
import Flux.Optimise.update!, Flux
using PyPlot
using Printf
using Random
+
Random.seed!(19)
# Model and data dimension
@@ -102,9 +107,7 @@ end
# Optimizer
η = 0.01
-opt = Flux.ADAM(η)
-lr_step = 2000
-lr_decay_fn = Flux.ExpDecay(η, .3, lr_step, 0.)
+opt = Flux.Optimiser(Flux.ExpDecay(η, .3, 2000, 0.), Flux.ADAM(η))
# Loss function
function loss(z_in, y_in)
diff --git a/examples/applications/application_conditionalScenario2_hint_banana.jl b/examples/applications/conditional_sampling/non_amortized_hint_rosenbrock_banana.jl
similarity index 95%
rename from examples/applications/application_conditionalScenario2_hint_banana.jl
rename to examples/applications/conditional_sampling/non_amortized_hint_rosenbrock_banana.jl
index 71eee2aa..2887d753 100644
--- a/examples/applications/application_conditionalScenario2_hint_banana.jl
+++ b/examples/applications/conditional_sampling/non_amortized_hint_rosenbrock_banana.jl
@@ -2,6 +2,9 @@
# Author: Philipp Witte, pwitte3@gatech.edu
# Date: January 2020
+using Pkg
+Pkg.add("InvertibleNetworks"); Pkg.add("Flux"); Pkg.add("PyPlot");
+
using LinearAlgebra, InvertibleNetworks, PyPlot, Flux, Random
import Flux.Optimise.update!
@@ -88,9 +91,8 @@ end
# Training
maxiter = 1000
-opt = Flux.ADAM(1f-3)
-lr_step = 50
-lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.)
+opt = Flux.Optimiser(Flux.ExpDecay(1f-3, .9, 50, 0.), Flux.ADAM(1f-3))
+
fval = zeros(Float32, maxiter)
for j=1:maxiter
diff --git a/examples/applications/application_glow_banana_dist.jl b/examples/applications/non_conditional_sampling/glow_banana.jl
similarity index 100%
rename from examples/applications/application_glow_banana_dist.jl
rename to examples/applications/non_conditional_sampling/glow_banana.jl
diff --git a/examples/applications/generative_sampling.jl b/examples/applications/non_conditional_sampling/glow_seismic.jl
similarity index 100%
rename from examples/applications/generative_sampling.jl
rename to examples/applications/non_conditional_sampling/glow_seismic.jl
diff --git a/examples/applications/application_hint_banana_dist.jl b/examples/applications/non_conditional_sampling/hint_banana.jl
similarity index 100%
rename from examples/applications/application_hint_banana_dist.jl
rename to examples/applications/non_conditional_sampling/hint_banana.jl
diff --git a/examples/benchmarks/memory_usage_invertiblenetworks.jl b/examples/benchmarks/memory_usage_invertiblenetworks.jl
new file mode 100644
index 00000000..d241d5af
--- /dev/null
+++ b/examples/benchmarks/memory_usage_invertiblenetworks.jl
@@ -0,0 +1,119 @@
+# Test memory usage of InvertibleNetworks as network depth increases
+
+
+using InvertibleNetworks, LinearAlgebra, Flux
+import Flux.Optimise.update!
+
+device = InvertibleNetworks.CUDA.functional() ? gpu : cpu
+
+# Turn off the CUDA.jl memory pool so the measurements reflect actual memory usage rather than pooled allocations
+ENV["JULIA_CUDA_MEMORY_POOL"] = "none"
+
+using CUDA, Printf
+
+export @gpumem
+
+get_mem() = NVML.memory_info(NVML.Device(0)).used/(1024^3)
+
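+# Background task: repeatedly sample GPU memory usage (GiB) into `used` until `status[]` is set to false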
+function monitor_gpu_mem(used::Vector{T}, status::Ref) where {T<:Real}
+ while status[]
+ #cleanup()
+ push!(used, get_mem())
+ end
+ nothing
+end
+
+cleanup() = begin GC.gc(true); CUDA.reclaim(); end
+
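+# `@gpumem expr` evaluates `expr` while a spawned task samples GPU memory via NVML, then prints the min/peak usage and returns the recorded samples (GiB)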
+macro gpumem(expr)
+ return quote
+ # Cleanup
+ cleanup()
+ monitoring = Ref(true)
+ used = [get_mem()]
+ Threads.@spawn monitor_gpu_mem(used, monitoring)
+ val = $(esc(expr))
+ monitoring[] = false
+ cleanup()
+ @printf("Min memory: %1.3fGiB , Peak memory: %1.3fGiB \n",
+ extrema(used)...)
+ used
+ end
+end
+
+# Objective function
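+# One forward + backward pass of the Glow network with the standard normalizing-flow objective (0.5*norm(Y)^2/batchsize minus the log-determinant); this is the step whose peak GPU memory is measured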
+function loss(G,X)
+ Y, logdet = G.forward(X)
+ #cleanup()
+ f = .5f0/batchsize*norm(Y)^2 - logdet
+ ΔX, X_ = G.backward(1f0./batchsize*Y, Y)
+ return f
+end
+
+# size of network input
+nx = 256
+ny = nx
+n_in = 3
+batchsize = 8
+X = rand(Float32, nx, ny, n_in, batchsize) |> device
+
+# Define network
+n_hidden = 256
+L = 3 # number of scales
+
+num_retests=1
+mems_max= []
+mem_tested = [4,8,16,32,48,64,80]
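+# For each depth K: build a Glow network, run one warm-up pass, then record the peak memory of a forward/backward pass, keeping the minimum peak over `num_retests` runs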
+for K in mem_tested
+ G = NetworkGlow(n_in, n_hidden, L, K; split_scales=true) |> device
+
+ loss(G,X)
+ curr_maxes = []
+ for i in 1:num_retests
+ usedmem = @gpumem begin
+ loss(G,X)
+ end
+ append!(curr_maxes,maximum(usedmem))
+ end
+ append!(mems_max, minimum(curr_maxes))
+ println(mems_max)
+end
+
+# Control for the memory used just to store the network parameters on the GPU, which is not relevant to backpropagation
+mems_model = []
+for K in mem_tested
+ G = NetworkGlow(n_in, n_hidden, L, K; split_scales=true) |> device
+ G(X)
+ G = G |> cpu
+ usedmem = @gpumem begin
+ G = G |> device
+ end
+
+ append!(mems_model, maximum(usedmem))
+end
+mems_model_norm = mems_model .- mems_model[1]
+
+
+mem_used_invnets = mems_max .- mems_model_norm
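+# Peak memory (GiB) reported by the companion PyTorch/normflows script (memory_usage_normflows.py); NaN marks the deepest network, for which no PyTorch measurement is available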
+mem_used_pytorch = [5.0897216796875, 7.8826904296875, 13.5487060546875, 24.8709716796875, 36.2010498046875, 40, NaN]
+mem_ticks = mem_tested
+
+
+using PyPlot
+
+font_size=15
+PyPlot.rc("font", family="serif");
+PyPlot.rc("font", family="serif", size=font_size); PyPlot.rc("xtick", labelsize=font_size); PyPlot.rc("ytick", labelsize=font_size);
+PyPlot.rc("axes", labelsize=font_size) # fontsize of the x and y labels
+
+# Plot peak memory versus network depth for both packages
+fig = figure(figsize=(14,6))
+plot(log.(mem_tested),mem_used_pytorch; color="black", linestyle="--",markevery=collect(range(0,4,step=1)),label="PyTorch package",marker="o",markerfacecolor="r")
+plot(log.(mem_tested),mem_used_invnets; color="black", label="InvertibleNetworks.jl package",marker="o",markerfacecolor="b")
+axvline(log.(mem_tested)[end-1], linestyle="--",color="red",label="PyTorch out of memory error")
+grid()
+xticks(log.(mem_ticks), [string(i) for i in mem_ticks],rotation=60)
+legend()
+xlabel("Network depth [# of layers]");
+ylabel("Peak memory used [GB]");
+fig.savefig("mem_used_new_depth_new.png", bbox_inches="tight", dpi=400)
diff --git a/examples/benchmarks/memory_usage_normflows.py b/examples/benchmarks/memory_usage_normflows.py
new file mode 100644
index 00000000..adaecf7a
--- /dev/null
+++ b/examples/benchmarks/memory_usage_normflows.py
@@ -0,0 +1,138 @@
+# Test memory usage of normflows as network depth increases
+
+# Example SLURM allocation used to run this benchmark: salloc -A rafael -t00:80:00 --gres=gpu:1 --mem-per-cpu=30G srun --pty python
+
+import torch
+import nvidia_smi
+
+def _get_gpu_mem(synchronize=True, empty_cache=True):
+ handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
+ mem = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
+ return mem.used
+
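+# Build a hook that appends the current GPU memory usage to `mem` each time the module runs (forward or backward)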
+def _generate_mem_hook(handle_ref, mem, idx, hook_type, exp):
+ def hook(self, *args):
+ if hook_type == 'pre':
+ return
+ if len(mem) == 0 or mem[-1]["exp"] != exp:
+ call_idx = 0
+ else:
+ call_idx = mem[-1]["call_idx"] + 1
+ mem_all = _get_gpu_mem()
+ torch.cuda.synchronize()
+ lname = type(self).__name__
+ lname = 'conv' if 'conv' in lname.lower() else lname
+ lname = 'ReLU' if 'relu' in lname.lower() else lname
+ mem.append({
+ 'layer_idx': idx,
+ 'call_idx': call_idx,
+ 'layer_type': f"{lname}_{hook_type}",
+ 'exp': exp,
+ 'hook_type': hook_type,
+ 'mem_all': mem_all,
+ })
+ return hook
+
+
+def _add_memory_hooks(idx, mod, mem_log, exp, hr):
+ h = mod.register_forward_pre_hook(_generate_mem_hook(hr, mem_log, idx, 'pre', exp))
+ hr.append(h)
+ h = mod.register_forward_hook(_generate_mem_hook(hr, mem_log, idx, 'fwd', exp))
+ hr.append(h)
+ h = mod.register_backward_hook(_generate_mem_hook(hr, mem_log, idx, 'bwd', exp))
+ hr.append(h)
+
+
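+# Attach memory hooks to every module of `model`, run one forward/backward pass on `inp`, and return the per-call memory log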
+def log_mem(model, inp, mem_log=None, exp=None):
+ nvidia_smi.nvmlInit()
+ mem_log = mem_log or []
+ exp = exp or f'exp_{len(mem_log)}'
+ hr = []
+ for idx, module in enumerate(model.modules()):
+ _add_memory_hooks(idx, module, mem_log, exp, hr)
+ try:
+ out = model(inp)
+ loss = out.sum()
+ loss.backward()
+ except Exception as e:
+ print(f"Errored with error {e}")
+ finally:
+ [h.remove() for h in hr]
+ return mem_log
+
+# Used this in the InvertibleNetworks.jl benchmark, so use it here as well; it did not make much difference in performance, though.
+import os
+os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = "1"
+
+import pandas as pd
+import torch
+import torchvision as tv
+import numpy as np
+import normflows as nf
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+
+torch.manual_seed(0)
+
+L = 3
+nx = 256
+max_mems = []
+
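+# For each depth, build a multiscale Glow model matching the Julia benchmark (L=3 scales, 256 hidden channels, 256x256x3 inputs, batch size 8) and log the peak memory of one forward/backward pass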
+for depth in [4,8,16,32,48,64]:
+ print(depth)
+ K = depth
+ input_shape = (3, nx, nx)
+ n_dims = np.prod(input_shape)
+ channels = 3
+ hidden_channels = 256
+ split_mode = 'channel'
+ scale = True
+
+ # Set up flows, distributions and merge operations
+ q0 = []
+ merges = []
+ flows = []
+ for i in range(L):
+ flows_ = []
+ for j in range(K):
+ flows_ += [nf.flows.GlowBlock(channels * 2 ** (L + 1 - i), hidden_channels,
+ split_mode=split_mode, scale=scale)]
+ flows_ += [nf.flows.Squeeze()]
+ flows += [flows_]
+ if i > 0:
+ merges += [nf.flows.Merge()]
+ latent_shape = (input_shape[0] * 2 ** (L - i), input_shape[1] // 2 ** (L - i),
+ input_shape[2] // 2 ** (L - i))
+ else:
+ latent_shape = (input_shape[0] * 2 ** (L + 1), input_shape[1] // 2 ** L,
+ input_shape[2] // 2 ** L)
+ q0 += [nf.distributions.DiagGaussian(latent_shape,trainable=False)]
+ # Construct flow model with the multiscale architecture
+ model = nf.MultiscaleFlow(q0, flows, merges)
+ enable_cuda = True
+ device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
+ model = model.to(device)
+ batch_size = 8
+ transform = tv.transforms.Compose([tv.transforms.ToTensor(), nf.utils.Scale(255. / 256.), nf.utils.Jitter(1 / 256.),tv.transforms.Resize(size=input_shape[1])])
+ train_data = tv.datasets.CIFAR10('datasets/', train=True,
+ download=True, transform=transform)
+ train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
+ drop_last=True)
+ train_iter = iter(train_loader)
+ optimizer = torch.optim.Adamax(model.parameters(), lr=1e-3, weight_decay=1e-5)
+ x, y = next(train_iter)
+
+ mem_log = []
+ try:
+ mem_log.extend(log_mem(model, x.to(device), exp='std'))
+ except Exception as e:
+ print(f'log_mem failed because of {e}')
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ df = pd.DataFrame(mem_log)
+ max_mem = df['mem_all'].max()/(1024**3)
+ print("Peak memory usage: %.2f Gb" % (max_mem,))
+ max_mems.append(max_mem)
+
+# >>> max_mems
+# [5.0897216796875, 7.8826904296875, 13.5487060546875, 24.8709716796875, 36.2010498046875, 39.9959716796875]
\ No newline at end of file
diff --git a/src/conditional_layers/conditional_layer_glow.jl b/src/conditional_layers/conditional_layer_glow.jl
index 7d0d69e4..35538fbe 100644
--- a/src/conditional_layers/conditional_layer_glow.jl
+++ b/src/conditional_layers/conditional_layer_glow.jl
@@ -78,7 +78,12 @@ function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze
# 1x1 Convolution and residual block for invertible layers
C = Conv1x1(n_in; freeze=freeze_conv)
- RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)
+
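+ # Support an odd number of input channels: X is split into X1 (split_num channels) and X2 (n_in - split_num channels);
+ # the residual block takes X2 plus the condition and outputs scale and shift for X1, i.e. 2*split_num channels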
+ split_num = Int(round(n_in/2))
+ in_split = n_in-split_num
+ out_chan = 2*split_num
+
+ RB = ResidualBlock(in_split+n_cond, n_hidden; n_out=out_chan, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)
return ConditionalLayerGlow(C, RB, logdet, activation)
end
@@ -143,7 +148,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA
# Backpropagate RB
ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C)))
- ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2))
+ ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=size(ΔY2)[N-1])
ΔX2 += ΔY2
# Backpropagate 1x1 conv
diff --git a/test/runtests.jl b/test/runtests.jl
index 4dd9dce0..2d63b6cf 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -16,7 +16,7 @@ basics = ["test_utils/test_objectives.jl",
"test_utils/test_nnlib_convolution.jl",
"test_utils/test_activations.jl",
"test_utils/test_squeeze.jl",
- "test_utils/test_jacobian.jl",
+ #"test_utils/test_jacobian.jl",
"test_utils/test_chainrules.jl",
"test_utils/test_flux.jl"]
diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl
index a2a0380d..946b39b0 100644
--- a/test/test_networks/test_conditional_glow_network.jl
+++ b/test/test_networks/test_conditional_glow_network.jl
@@ -9,6 +9,49 @@ device = InvertibleNetworks.CUDA.functional() ? gpu : cpu
# Random seed
Random.seed!(3);
+# Define network
+nx = 32; ny = 32; nz = 32
+n_in = 3
+n_cond = 3
+n_hidden = 4
+batchsize = 2
+L = 2
+K = 2
+split_scales = false
+N = (nx,ny)
+
+########################################### Test with split_scales = false N = (nx,ny) #########################
+# Invertibility
+
+# Network and input
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device
+X = rand(Float32, N..., n_in, batchsize) |> device
+Cond = rand(Float32, N..., n_cond, batchsize) |> device
+
+Y, Cond = G.forward(X,Cond)
+X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes
+
+@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
+
+# Test gradients are set and cleared
+G.backward(Y, Y, Cond)
+
+P = get_params(G)
+gsum = 0
+for p in P
+ ~isnothing(p.grad) && (global gsum += 1)
+end
+@test isequal(gsum, L*K*10+2)
+
+clear_grad!(G)
+gsum = 0
+for p in P
+ ~isnothing(p.grad) && (global gsum += 1)
+end
+@test isequal(gsum, 0)
+
+
+Random.seed!(3);
# Define network
nx = 32; ny = 32; nz = 32
n_in = 2