diff --git a/CITATION.bib b/CITATION.bib index 83daa6a2..39732c66 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -1,6 +1 @@ -@article{orozco2023invertiblenetworks, - title={InvertibleNetworks. jl: A Julia package for scalable normalizing flows}, - author={Orozco, Rafael and Witte, Philipp and Louboutin, Mathias and Siahkoohi, Ali and Rizzuti, Gabrio and Peters, Bas and Herrmann, Felix J}, - journal={arXiv preprint arXiv:2312.13480}, - year={2023} -} \ No newline at end of file +@article{Orozco2024, doi = {10.21105/joss.06554}, url = {https://doi.org/10.21105/joss.06554}, year = {2024}, publisher = {The Open Journal}, volume = {9}, number = {99}, pages = {6554}, author = {Rafael Orozco and Philipp Witte and Mathias Louboutin and Ali Siahkoohi and Gabrio Rizzuti and Bas Peters and Felix J. Herrmann}, title = {InvertibleNetworks.jl: A Julia package for scalable normalizing flows}, journal = {Journal of Open Source Software} } diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 00000000..bab91ad2 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,50 @@ +cff-version: "1.2.0" +authors: +- family-names: Orozco + given-names: Rafael +- family-names: Witte + given-names: Philipp +- family-names: Louboutin + given-names: Mathias +- family-names: Siahkoohi + given-names: Ali +- family-names: Rizzuti + given-names: Gabrio +- family-names: Peters + given-names: Bas +- family-names: Herrmann + given-names: Felix J. +doi: 10.5281/zenodo.12810006 +message: If you use this software, please cite our article in the + Journal of Open Source Software. +preferred-citation: + authors: + - family-names: Orozco + given-names: Rafael + - family-names: Witte + given-names: Philipp + - family-names: Louboutin + given-names: Mathias + - family-names: Siahkoohi + given-names: Ali + - family-names: Rizzuti + given-names: Gabrio + - family-names: Peters + given-names: Bas + - family-names: Herrmann + given-names: Felix J. + date-published: 2024-07-30 + doi: 10.21105/joss.06554 + issn: 2475-9066 + issue: 99 + journal: Journal of Open Source Software + publisher: + name: Open Journals + start: 6554 + title: "InvertibleNetworks.jl: A Julia package for scalable + normalizing flows" + type: article + url: "https://joss.theoj.org/papers/10.21105/joss.06554" + volume: 9 +title: "InvertibleNetworks.jl: A Julia package for scalable normalizing + flows" diff --git a/Project.toml b/Project.toml index 270b7fcf..445fa874 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "InvertibleNetworks" uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3" authors = ["Philipp Witte ", "Ali Siahkoohi ", "Mathias Louboutin ", "Gabrio Rizzuti ", "Rafael Orozco ", "Felix J. 
Herrmann "] -version = "2.2.8" +version = "2.3.1" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -15,6 +15,11 @@ Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] +CUDA = "1, 2, 3, 4, 5" +cuDNN = "1" +ChainRulesCore = "0.8, 0.9, 0.10, 1" +Flux = "0.11, 0.12, 0.13, 0.14" +NNlib = "0.7, 0.8, 0.9" TimerOutputs = "0.5" Wavelets = "0.9, 0.10" diff --git a/README.md b/README.md index f832a0ae..dcd72fe8 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ # InvertibleNetworks.jl -| **Documentation** | **Build Status** | | +| **Documentation** | **Build Status** | **JOSS paper** | |:-----------------:|:-----------------:|:----------------:| -|[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://slimgroup.github.io/InvertibleNetworks.jl/stable/) [![](https://img.shields.io/badge/docs-dev-blue.svg)](https://slimgroup.github.io/InvertibleNetworks.jl/dev/)| [![CI](https://github.com/slimgroup/InvertibleNetworks.jl/actions/workflows/runtests.yml/badge.svg)](https://github.com/slimgroup/InvertibleNetworks.jl/actions/workflows/runtests.yml)| [![DOI](https://zenodo.org/badge/239018318.svg)](https://zenodo.org/badge/latestdoi/239018318) +|[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://slimgroup.github.io/InvertibleNetworks.jl/stable/) [![](https://img.shields.io/badge/docs-dev-blue.svg)](https://slimgroup.github.io/InvertibleNetworks.jl/dev/)| [![CI](https://github.com/slimgroup/InvertibleNetworks.jl/actions/workflows/runtests.yml/badge.svg)](https://github.com/slimgroup/InvertibleNetworks.jl/actions/workflows/runtests.yml)| [![DOI](https://joss.theoj.org/papers/10.21105/joss.06554/status.svg)](https://doi.org/10.21105/joss.06554) + Building blocks for invertible neural networks in the [Julia] programming language. @@ -26,7 +27,7 @@ InvertibleNetworks is registered and can be added like any standard Julia packag ## Uncertainty-aware image reconstruction -Due to its memory scaling InvertibleNetworks.jl has been particularily successful at Bayesian posterior sampling with simulation-based inference. To get started with this application refer to a simple example ([Conditional sampling for MNSIT inpainting](https://github.com/slimgroup/InvertibleNetworks.jl/tree/master/examples/applications/application_conditional_mnist_inpainting.jl)) but feel free to modify this script for your application and please reach out to us if you run into any trouble. +Due to its memory scaling InvertibleNetworks.jl, has been particularily successful at Bayesian posterior sampling with simulation-based inference. To get started with this application refer to a simple example ([Conditional sampling for MNSIT inpainting](https://github.com/slimgroup/InvertibleNetworks.jl/tree/master/examples/applications/conditional_sampling/amortized_glow_mnist_inpainting.jl)) but feel free to modify this script for your application and please reach out to us for help. ![mnist_sampling_cond](docs/src/figures/mnist_sampling_cond.png) @@ -92,12 +93,7 @@ Y_, logdet = AN.forward(X) If you use InvertibleNetworks.jl in your research, we would be grateful if you cite us with the following bibtex: ``` -@article{orozco2023invertiblenetworks, - title={InvertibleNetworks. 
jl: A Julia package for scalable normalizing flows}, - author={Orozco, Rafael and Witte, Philipp and Louboutin, Mathias and Siahkoohi, Ali and Rizzuti, Gabrio and Peters, Bas and Herrmann, Felix J}, - journal={arXiv preprint arXiv:2312.13480}, - year={2023} -} +@article{Orozco2024, doi = {10.21105/joss.06554}, url = {https://doi.org/10.21105/joss.06554}, year = {2024}, publisher = {The Open Journal}, volume = {9}, number = {99}, pages = {6554}, author = {Rafael Orozco and Philipp Witte and Mathias Louboutin and Ali Siahkoohi and Gabrio Rizzuti and Bas Peters and Felix J. Herrmann}, title = {InvertibleNetworks.jl: A Julia package for scalable normalizing flows}, journal = {Journal of Open Source Software} } ``` diff --git a/examples/applications/application_conditional_hint_seismic.jl b/examples/applications/application_conditional_hint_seismic.jl deleted file mode 100644 index d862f9d6..00000000 --- a/examples/applications/application_conditional_hint_seismic.jl +++ /dev/null @@ -1,162 +0,0 @@ -# Generative model using the change of variables formula -# Author: Philipp Witte, pwitte3@gatech.edu -# Date: January 2020 - -using LinearAlgebra, InvertibleNetworks, PyPlot, Flux, Random, Test, JLD, Statistics -import Flux.Optimise.update! - -# Random seed -Random.seed!(66) - - -#################################################################################################### - -# Load original data X (size of n1 x n2 x nc x ntrain) -datadir = dirname(pathof(InvertibleNetworks))*"/../data/" -filename = "seismic_samples_64_by_64_num_10k.jld" -~isfile("$(datadir)$(filename)") && run(`curl -L https://www.dropbox.com/s/mh5dv0yprestot4/seismic_samples_64_by_64_num_10k.jld\?dl\=0 --create-dirs -o $(datadir)$(filename)`) -X_orig = load("$(datadir)$(filename)")["X"] -n1, n2, nc, nsamples = size(X_orig) -AN = ActNorm(nsamples) -X_orig = AN.forward(X_orig) # zero mean and unit std - -# Split in training - testing -ntrain = Int(nsamples*.9) -ntest = nsamples - ntrain - -# Dimensions after wavelet squeeze to increase no. of channels -nx = Int(n1/2) -ny = Int(n2/2) -n_in = Int(nc*4) - -# Apply wavelet squeeze (change dimensions to -> n1/2 x n2/2 x nc*4 x ntrain) -X_train = zeros(Float32, nx, ny, n_in, ntrain) -for j=1:ntrain - X_train[:, :, :, j:j] = wavelet_squeeze(X_orig[:, :, :, j:j]) -end - -X_test = zeros(Float32, nx, ny, n_in, ntest) -for j=1:ntest - X_test[:, :, :, j:j] = wavelet_squeeze(X_orig[:, :, :, ntrain+j:ntrain+j]) -end - -# Create network -n_hidden = 64 -batchsize = 4 -depth = 8 -CH = NetworkConditionalHINT(n_in, n_hidden, depth) -Params = get_params(CH) - -#################################################################################################### - -# Loss -function loss(CH, X, Y) - Zx, Zy, logdet = CH.forward(X, Y) - f = -log_likelihood(tensor_cat(Zx, Zy)) - logdet - ΔZ = -∇log_likelihood(tensor_cat(Zx, Zy)) - ΔZx, ΔZy = tensor_split(ΔZ) - ΔX, ΔY = CH.backward(ΔZx, ΔZy, Zx, Zy)[1:2] - return f, ΔX, ΔY -end - -# Training -maxiter = 1000 -opt = Flux.ADAM(1f-3) -lr_step = 100 -lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.) 
-fval = zeros(Float32, maxiter) - -for j=1:maxiter - - # Evaluate objective and gradients - idx = randperm(ntrain)[1:batchsize] - X = X_train[:, :, :, idx] - Y = X + .5f0*randn(Float32, nx, ny, n_in, batchsize) - - fval[j] = loss(CH, X, Y)[1] - mod(j, 10) == 0 && (print("Iteration: ", j, "; f = ", fval[j], "\n")) - - # Update params - for p in Params - update!(opt, p.data, p.grad) - update!(lr_decay_fn, p.data, p.grad) - end - clear_grad!(CH) -end - -#################################################################################################### -# Plotting - -# Testing -test_size = 100 -idx = randperm(ntest)[1:test_size] # draw random samples from testing data -X = X_test[:, :, :, idx] -Y = X + .5f0*randn(Float32, nx, ny, n_in, test_size) -Zx_, Zy_ = CH.forward(X, Y)[1:2] - -Zx = randn(Float32, nx, ny, n_in, test_size) -Zy = randn(Float32, nx, ny, n_in, test_size) -X_, Y_ = CH.inverse(Zx, Zy) - -# Now select single fixed sample from all Ys -idx = 1 -X_fixed = X[:, :, :, idx:idx] -Y_fixed = Y[:, :, :, idx:idx] -Zy_fixed = CH.forward_Y(Y_fixed) - -# Draw new Zx, while keeping Zy fixed -X_post = CH.inverse(Zx, Zy_fixed.*ones(Float32, nx, ny, n_in, test_size))[1] - -# Unsqueeze all tensors -X = wavelet_unsqueeze(X) -Y = wavelet_unsqueeze(Y) -Zx_ = wavelet_unsqueeze(Zx_) -Zy_ = wavelet_unsqueeze(Zy_) - -X_ = wavelet_unsqueeze(X_) -Y_ = wavelet_unsqueeze(Y_) -Zx = wavelet_unsqueeze(Zx) -Zy = wavelet_unsqueeze(Zy) - -X_fixed = wavelet_unsqueeze(X_fixed) -Y_fixed = wavelet_unsqueeze(Y_fixed) -Zy_fixed = wavelet_unsqueeze(Zy_fixed) -X_post = wavelet_unsqueeze(X_post) - -# Plot one sample from X and Y and their latent versions -figure(figsize=[16,8]) -ax1 = subplot(2,4,1); imshow(X[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax2 = subplot(2,4,2); imshow(Y[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y=x+n$ ") -ax3 = subplot(2,4,3); imshow(X_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax4 = subplot(2,4,4); imshow(Y_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Data space: $y = f(zx|zy)^{-1}$") -ax5 = subplot(2,4,5); imshow(Zx_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zx = f(x|y)$") -ax6 = subplot(2,4,6); imshow(Zy_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zy = f(x|y)$") -ax7 = subplot(2,4,7); imshow(Zx[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zx \sim \hat{p}_{zx}$") -ax8 = subplot(2,4,8); imshow(Zy[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zy \sim \hat{p}_{zy}$") - -# Plot various samples from X and Y -figure(figsize=[16,8]) -i = randperm(test_size)[1:4] -ax1 = subplot(2,4,1); imshow(X_[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax2 = subplot(2,4,2); imshow(X_[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax3 = subplot(2,4,3); imshow(X_[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax4 = subplot(2,4,4); imshow(X_[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax5 = subplot(2,4,5); imshow(X[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax6 = subplot(2,4,6); imshow(X[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax7 = subplot(2,4,7); imshow(X[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax8 = 
subplot(2,4,8); imshow(X[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") - -# Plot posterior samples, mean and standard deviation -figure(figsize=[16,8]) -X_post_mean = mean(X_post; dims=4) -X_post_std = std(X_post; dims=4) -ax1 = subplot(2,4,1); imshow(X_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title("True x") -ax2 = subplot(2,4,2); imshow(X_post[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax3 = subplot(2,4,3); imshow(X_post[:, :, 1, 2], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax4 = subplot(2,4,4); imshow(X_post_mean[:, :, 1, 1], cmap="gray", aspect="auto"); title("Posterior mean") -ax5 = subplot(2,4,5); imshow(Y_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y_i=x_i+n$ ") -ax6 = subplot(2,4,6); imshow(X_post[:, :, 1, 4], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax7 = subplot(2,4,7); imshow(X_post[:, :, 1, 5], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax8 = subplot(2,4,8); imshow(X_post_std[:, :, 1,1], cmap="binary", aspect="auto", vmin=0, vmax=0.9*maximum(X_post_std)); -colorbar(); title("Posterior std"); - diff --git a/examples/applications/application_conditional_hint_seismic_linear.jl b/examples/applications/application_conditional_hint_seismic_linear.jl deleted file mode 100644 index 44d107e4..00000000 --- a/examples/applications/application_conditional_hint_seismic_linear.jl +++ /dev/null @@ -1,192 +0,0 @@ -# Generative model using the change of variables formula -# Author: Philipp Witte, pwitte3@gatech.edu -# Date: January 2020 - -using LinearAlgebra, InvertibleNetworks, JOLI, PyPlot, Flux, Random, Test, JLD, Statistics -import Flux.Optimise.update! -using Distributions: Uniform - -# Random seed -Random.seed!(66) - -#################################################################################################### - -# Load original data X (size of n1 x n2 x nc x ntrain) -datadir = dirname(pathof(InvertibleNetworks))*"/../data/" -filename = "seismic_samples_64_by_64_num_10k.jld" -~isfile("$(datadir)$(filename)") && run(`curl -L https://www.dropbox.com/s/mh5dv0yprestot4/seismic_samples_64_by_64_num_10k.jld\?dl\=0 --create-dirs -o $(datadir)$(filename)`) -X_orig = load("$(datadir)$(filename)")["X"] -n1, n2, nc, nsamples = size(X_orig) -AN = ActNorm(nsamples) -X_orig = AN.forward(X_orig) # zero mean and unit std - -# Split in training - testing -ntrain = Int(nsamples*.9) -ntest = nsamples - ntrain - -# Dimensions after wavelet squeeze to increase no. 
of channels -nx = Int(n1/2) -ny = Int(n2/2) -n_in = Int(nc*4) - -# Apply wavelet squeeze (change dimensions to -> n1/2 x n2/2 x nc*4 x ntrain) -X_train = zeros(Float32, nx, ny, n_in, ntrain) -for j=1:ntrain - X_train[:, :, :, j:j] = wavelet_squeeze(X_orig[:, :, :, j:j]) -end - -X_test = zeros(Float32, nx, ny, n_in, ntest) -for j=1:ntest - X_test[:, :, :, j:j] = wavelet_squeeze(X_orig[:, :, :, ntrain+j:ntrain+j]) -end - -# Create network -n_hidden = 32 -batchsize = 4 -depth = 8 -CH = NetworkConditionalHINT(n_in, n_hidden, depth) -Params = get_params(CH) - -# Data modeling function -function model_data(X, A) - X = wavelet_unsqueeze(X) - nx, ny, nc, nb = size(X) - Y = reshape(A*reshape(X, :, nb), nx, ny, nc, nb) - Y = wavelet_squeeze(Y) - return Y -end - -# Forward operator, precompute phase in ambient dimension -function phasecode(n) - F = joDFT(n; DDT=Float32,RDT=ComplexF64) - phase=F*(adjoint(F)*exp.(1im*2*pi*convert(Array{Float32},rand(dist,n)))) - phase = phase ./abs.(phase) - sgn = sign.(convert(Array{Float32},randn(n))) - # Return operator - return M = joDiag(sgn) * adjoint(F) * joDiag(phase)*F -end - -dist = Uniform(-1, 1) -input_dim = (n1, n2) -subsamp = 2 -M = phasecode(prod(input_dim)) -R = joRestriction(prod(input_dim),randperm(prod(input_dim))[1:Int(round(prod(input_dim)/subsamp))]; DDT=Float32, RDT=Float32); -A_flat = R*M; -A = A_flat'*A_flat - -Y_train = model_data(X_train, A) -Y_test = model_data(X_test, A) - -#################################################################################################### - -# Loss -function loss(CH, X, Y) - Zx, Zy, logdet = CH.forward(X, Y) - f = -log_likelihood(tensor_cat(Zx, Zy)) - logdet - ΔZ = -∇log_likelihood(tensor_cat(Zx, Zy)) - ΔZx, ΔZy = tensor_split(ΔZ) - ΔX, ΔY = CH.backward(ΔZx, ΔZy, Zx, Zy)[1:2] - return f, ΔX, ΔY -end - -# Training -maxiter = 1000 -opt = Flux.ADAM(1f-3) -lr_step = 100 -lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.) 
-fval = zeros(Float32, maxiter) - -for j=1:maxiter - - # Evaluate objective and gradients - idx = randperm(ntrain)[1:batchsize] - X = X_train[:, :, :, idx] - Y = Y_train[:, :, :, idx] - - fval[j] = loss(CH, X, Y)[1] - mod(j, 10) == 0 && (print("Iteration: ", j, "; f = ", fval[j], "\n")) - - # Update params - for p in Params - update!(opt, p.data, p.grad) - update!(lr_decay_fn, p.data, p.grad) - end - clear_grad!(CH) -end - -#################################################################################################### -# Plotting - -# Testing -test_size = 100 -idx = randperm(ntest)[1:test_size] # draw random samples from testing data -X = X_test[:, :, :, idx] -Y = Y_test[:, :, :, idx] -Zx_, Zy_ = CH.forward(X, Y)[1:2] - -Zx = randn(Float32, nx, ny, n_in, test_size) -Zy = randn(Float32, nx, ny, n_in, test_size) -X_, Y_ = CH.inverse(Zx, Zy) - -# Now select single fixed sample from all Ys -idx = 1 -X_fixed = X[:, :, :, idx:idx] -Y_fixed = Y[:, :, :, idx:idx] -Zy_fixed = CH.forward_Y(Y_fixed) - -# Draw new Zx, while keeping Zy fixed -X_post = CH.inverse(Zx, Zy_fixed.*ones(Float32, nx, ny, n_in, test_size))[1] - -# Unsqueeze all tensors -X = wavelet_unsqueeze(X) -Y = wavelet_unsqueeze(Y) -Zx_ = wavelet_unsqueeze(Zx_) -Zy_ = wavelet_unsqueeze(Zy_) - -X_ = wavelet_unsqueeze(X_) -Y_ = wavelet_unsqueeze(Y_) -Zx = wavelet_unsqueeze(Zx) -Zy = wavelet_unsqueeze(Zy) - -X_fixed = wavelet_unsqueeze(X_fixed) -Y_fixed = wavelet_unsqueeze(Y_fixed) -Zy_fixed = wavelet_unsqueeze(Zy_fixed) -X_post = wavelet_unsqueeze(X_post) - -# Plot one sample from X and Y and their latent versions -figure(figsize=[16,8]) -ax1 = subplot(2,4,1); imshow(X[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax2 = subplot(2,4,2); imshow(Y[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y=x+n$ ") -ax3 = subplot(2,4,3); imshow(X_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax4 = subplot(2,4,4); imshow(Y_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Data space: $y = f(zx|zy)^{-1}$") -ax5 = subplot(2,4,5); imshow(Zx_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zx = f(x|y)$") -ax6 = subplot(2,4,6); imshow(Zy_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zy = f(x|y)$") -ax7 = subplot(2,4,7); imshow(Zx[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zx \sim \hat{p}_{zx}$") -ax8 = subplot(2,4,8); imshow(Zy[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Latent space: $zy \sim \hat{p}_{zy}$") - -# Plot various samples from X and Y -figure(figsize=[16,8]) -i = randperm(test_size)[1:4] -ax1 = subplot(2,4,1); imshow(X_[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax2 = subplot(2,4,2); imshow(X_[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax3 = subplot(2,4,3); imshow(X_[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax4 = subplot(2,4,4); imshow(X_[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax5 = subplot(2,4,5); imshow(X[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax6 = subplot(2,4,6); imshow(X[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax7 = subplot(2,4,7); imshow(X[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax8 = subplot(2,4,8); imshow(X[:, :, 1, i[4]], cmap="gray", 
aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") - -# Plot posterior samples, mean and standard deviation -figure(figsize=[16,8]) -X_post_mean = mean(X_post; dims=4) -X_post_std = std(X_post; dims=4) -ax1 = subplot(2,4,1); imshow(X_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title("True x") -ax2 = subplot(2,4,2); imshow(X_post[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax3 = subplot(2,4,3); imshow(X_post[:, :, 1, 2], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax4 = subplot(2,4,4); imshow(X_post_mean[:, :, 1, 1], cmap="gray", aspect="auto"); title("Posterior mean") -ax5 = subplot(2,4,5); imshow(Y_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y_i=x_i+n$ ") -ax6 = subplot(2,4,6); imshow(X_post[:, :, 1, 4], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax7 = subplot(2,4,7); imshow(X_post[:, :, 1, 5], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax8 = subplot(2,4,8); imshow(X_post_std[:, :, 1,1], cmap="binary", aspect="auto", vmin=0, vmax=0.9*maximum(X_post_std)); -colorbar(); title("Posterior std"); - diff --git a/examples/applications/application_multiscale_conditional_hint_seismic_linear.jl b/examples/applications/application_multiscale_conditional_hint_seismic_linear.jl deleted file mode 100644 index 7aa80a8e..00000000 --- a/examples/applications/application_multiscale_conditional_hint_seismic_linear.jl +++ /dev/null @@ -1,177 +0,0 @@ -# Generative model using the change of variables formula -# Author: Philipp Witte, pwitte3@gatech.edu -# Date: January 2020 - -using LinearAlgebra, InvertibleNetworks, JOLI, PyPlot, Flux, Random, Test, JLD, Statistics -import Flux.Optimise.update! -using Distributions: Uniform - -# Random seed -Random.seed!(66) - -#################################################################################################### - -# Load original data X (size of n1 x n2 x nc x ntrain) -#X_orig = load("../../data/seismic_samples_32_by_32_num_10k.jld")["X"] - - -# Load original data X (size of n1 x n2 x nc x ntrain) -datadir = dirname(pathof(InvertibleNetworks))*"/../data/" -filename = "seismic_samples_64_by_64_num_10k.jld" -~isfile("$(datadir)$(filename)") && run(`curl -L https://www.dropbox.com/s/mh5dv0yprestot4/seismic_samples_32_by_32_num_10k.jld\?dl\=0 --create-dirs -o $(datadir)$(filename)`) -X_orig = load("$(datadir)$(filename)")["X"] -n1, n2, nc, n_samples = size(X_orig) - -AN = ActNorm(n_samples) -X_orig = AN.forward(X_orig) # zero mean and unit std - -# Split in training - testing -ntrain = Int(n_samples*.9) -ntest = n_samples - ntrain - -# Dimensions after wavelet squeeze to increase no. 
of channels -nx = Int(n1) -ny = Int(n2) -n_in = Int(nc) - -# Apply wavelet squeeze (change dimensions to -> n1/2 x n2/2 x nc*4 x ntrain) -X_train = zeros(Float32, nx, ny, n_in, ntrain) -for j=1:ntrain - X_train[:, :, :, j:j] = X_orig[:, :, :, j:j] -end - -X_test = zeros(Float32, nx, ny, n_in, ntest) -for j=1:ntest - X_test[:, :, :, j:j] = X_orig[:, :, :, ntrain+j:ntrain+j] -end - -# Create network -n_hidden = 32 -batchsize = 4 -L = 1 -K = 8 -CH = NetworkMultiScaleConditionalHINT(n_in, n_hidden, L, K; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, split_scales=false) -Params = get_params(CH) - -# Data modeling function -function model_data(X, A) - nx, ny, nc, nb = size(X) - Y = reshape(A*reshape(X, :, nb), nx, ny, nc, nb) - return Y -end - -# Forward operator, precompute phase in ambient dimension -function phasecode(n) - F = joDFT(n; DDT=Float32,RDT=ComplexF64) - phase=F*(adjoint(F)*exp.(1im*2*pi*convert(Array{Float32},rand(dist,n)))) - phase = phase ./abs.(phase) - sgn = sign.(convert(Array{Float32},randn(n))) - # Return operator - return M = joDiag(sgn) * adjoint(F) * joDiag(phase)*F -end - -# Generate observed data -dist = Uniform(-1, 1) -input_dim = (n1, n2) -subsamp = 2 -M = phasecode(prod(input_dim)) -R = joRestriction(prod(input_dim),randperm(prod(input_dim))[1:Int(round(prod(input_dim)/subsamp))]; DDT=Float32, RDT=Float32); -A_flat = R*M; -A = A_flat'*A_flat - -Y_train = model_data(X_train, A) -Y_test = model_data(X_test, A) - -#################################################################################################### - -# Loss -function loss(CH, X, Y) - Zx, Zy, logdet = CH.forward(X, Y) - f = -log_likelihood(tensor_cat(Zx, Zy)) - logdet - ΔZ = -∇log_likelihood(tensor_cat(Zx, Zy)) - ΔZx, ΔZy = tensor_split(ΔZ) - ΔX, ΔY = CH.backward(ΔZx, ΔZy, Zx, Zy)[1:2] - return f, ΔX, ΔY -end - -# Training -maxiter = 500 -opt = Flux.ADAM(1f-3) -lr_step = 100 -lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.) 
-fval = zeros(Float32, maxiter) - -for j=1:maxiter - - # Evaluate objective and gradients - idx = randperm(ntrain)[1:batchsize] - X = X_train[:, :, :, idx] - Y = Y_train[:, :, :, idx] - - fval[j] = loss(CH, X, Y)[1] - mod(j, 10) == 0 && (print("Iteration: ", j, "; f = ", fval[j], "\n")) - - # Update params - for p in Params - update!(opt, p.data, p.grad) - update!(lr_decay_fn, p.data, p.grad) - end - clear_grad!(CH) -end - -#################################################################################################### -# Plotting - -# Testing -test_size = 100 -idx = randperm(ntest)[1:test_size] # draw random samples from testing data -X = X_test[:, :, :, idx] -Y = Y_test[:, :, :, idx] -Zx_, Zy_ = CH.forward(X, Y)[1:2] - -Zx = randn(Float32, size(Zx_)) -Zy = randn(Float32, size(Zy_)) -X_, Y_ = CH.inverse(Zx, Zy) - -# Now select single fixed sample from all Ys -idx = 1 -X_fixed = X[:, :, :, idx:idx] -Y_fixed = Y[:, :, :, idx:idx] -Zy_fixed = CH.forward_Y(Y_fixed) - -# Draw new Zx, while keeping Zy fixed -CH.forward_Y(X) # set X dimensions in forward pass (this needs to be fixed) -X_post = CH.inverse(Zx, Zy_fixed.*ones(Float32, size(Zx_)))[1] - -# Plot one sample from X and Y and their latent versions -figure(figsize=[8,8]) -ax1 = subplot(2,2,1); imshow(X[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax2 = subplot(2,2,2); imshow(Y[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y=x+n$ ") -ax3 = subplot(2,2,3); imshow(X_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax4 = subplot(2,2,4); imshow(Y_[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Data space: $y = f(zx|zy)^{-1}$") - -# Plot various samples from X and Y -figure(figsize=[16,8]) -i = randperm(test_size)[1:4] -ax1 = subplot(2,4,1); imshow(X_[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax2 = subplot(2,4,2); imshow(X_[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax3 = subplot(2,4,3); imshow(X_[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax4 = subplot(2,4,4); imshow(X_[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x = f(zx|zy)^{-1}$") -ax5 = subplot(2,4,5); imshow(X[:, :, 1, i[1]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax6 = subplot(2,4,6); imshow(X[:, :, 1, i[2]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax7 = subplot(2,4,7); imshow(X[:, :, 1, i[3]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") -ax8 = subplot(2,4,8); imshow(X[:, :, 1, i[4]], cmap="gray", aspect="auto"); title(L"Model space: $x \sim \hat{p}_x$") - -# Plot posterior samples, mean and standard deviation -figure(figsize=[16,8]) -X_post_mean = mean(X_post; dims=4) -X_post_std = std(X_post; dims=4) -ax1 = subplot(2,4,1); imshow(X_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title("True x") -ax2 = subplot(2,4,2); imshow(X_post[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax3 = subplot(2,4,3); imshow(X_post[:, :, 1, 2], cmap="gray", aspect="auto"); title(L"Post. 
sample: $x = f(zx|zy_{fix})^{-1}$") -ax4 = subplot(2,4,4); imshow(X_post_mean[:, :, 1, 1], cmap="gray", aspect="auto"); title("Posterior mean") -ax5 = subplot(2,4,5); imshow(Y_fixed[:, :, 1, 1], cmap="gray", aspect="auto"); title(L"Noisy data $y_i=x_i+n$ ") -ax6 = subplot(2,4,6); imshow(X_post[:, :, 1, 4], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax7 = subplot(2,4,7); imshow(X_post[:, :, 1, 5], cmap="gray", aspect="auto"); title(L"Post. sample: $x = f(zx|zy_{fix})^{-1}$") -ax8 = subplot(2,4,8); imshow(X_post_std[:, :, 1,1], cmap="binary", aspect="auto", vmin=0, vmax=0.9*maximum(X_post_std)); -colorbar(); title("Posterior std"); \ No newline at end of file diff --git a/examples/applications/amortized_posterior_inference_1d_gaussian.jl b/examples/applications/conditional_sampling/amortized_glow_1d_gaussian.jl similarity index 100% rename from examples/applications/amortized_posterior_inference_1d_gaussian.jl rename to examples/applications/conditional_sampling/amortized_glow_1d_gaussian.jl diff --git a/examples/applications/application_conditional_mnist_inpainting.jl b/examples/applications/conditional_sampling/amortized_glow_mnist_inpainting.jl similarity index 67% rename from examples/applications/application_conditional_mnist_inpainting.jl rename to examples/applications/conditional_sampling/amortized_glow_mnist_inpainting.jl index 10b5167d..40ecf1a9 100644 --- a/examples/applications/application_conditional_mnist_inpainting.jl +++ b/examples/applications/conditional_sampling/amortized_glow_mnist_inpainting.jl @@ -1,16 +1,20 @@ +# Take around 6 minutes on CPU using Pkg -Pkg.activate(".") -# Take around 6 minutes on CPU +Pkg.add("InvertibleNetworks") +Pkg.add("Flux") +Pkg.add("MLDatasets") +Pkg.add("ProgressMeter") +Pkg.add("MLUtils") +Pkg.add("ImageTransformations") + using InvertibleNetworks using Flux using LinearAlgebra using MLDatasets -using Statistics -using PyPlot -using ProgressMeter: Progress, next! -using Images +using ImageTransformations using MLUtils +using ProgressMeter: Progress, next! 
function posterior_sampler(G, y, size_x; device=gpu, num_samples=1, batch_size=16) # make samples from posterior for train sample @@ -111,52 +115,35 @@ for e=1:epochs # epoch loop append!(loss_val, norm(ZX)^2 / (N*n_val) - logdet_i / N) # normalize by image size and batch size end -# Training logs -final_obj_train = round(loss_train[end];digits=3) -final_obj_val = round(loss_val[end];digits=3) -fig = figure() -title("Objective value: train=$(final_obj_train) validation=$(final_obj_val)") +Pkg.add("Statistics") +Pkg.add("Plots") +using Statistics +using Plots + +# Training logs plot(loss_train;label="Train"); plot(batches:batches:batches*(epochs), loss_val;label="Validation"); -xlabel("Parameter update"); ylabel("Negative log likelihood objective") ; -legend() -savefig("log.png",dpi=300) - -# Make Figure of README -num_plot = 2 -fig = figure(figsize=(11, 5)); -for (i,ind) in enumerate([1,3]) - x = XY_val[1][:,:,:,ind:ind] - y = XY_val[2][:,:,:,ind:ind] - X_post = posterior_sampler(G, y, size(x); device=device, num_samples=64) |> cpu - - X_post_mean = mean(X_post; dims=ndims(X_post)) - X_post_var = var(X_post; dims=ndims(X_post)) +xlabel!("Parameter update"); ylabel!("Negative log likelihood objective") ; +savefig("log.png") - ssim_val = round(assess_ssim(X_post_mean[:,:,1,1], x[:,:,1,1]) ,digits=2) +# Make Figure like README +ind = 1 +x = XY_val[1][:,:,:,ind:ind] +y = XY_val[2][:,:,:,ind:ind] +X_post = posterior_sampler(G, y, size(x); device=device, num_samples=64) |> cpu - subplot(num_plot,7,1+7*(i-1)); imshow(x[:,:,1,1], vmin=0, vmax=1, cmap="gray") - axis("off"); title(L"$x$"); +X_post_mean = mean(X_post; dims=ndims(X_post)) +X_post_var = var(X_post; dims=ndims(X_post)) - subplot(num_plot,7,2+7*(i-1)); imshow(y[:,:,1,1] |> cpu, cmap="gray") - axis("off"); title(L"$y$"); +p1 = heatmap(x[:,:,1,1];title="Ground truth",cbar=false,clims = (0, 1),axis=([], false)) +p2 = heatmap(y[:,:,1,1];title="Observation",cbar=false,clims = (0, 1),axis=([], false)) +p3 = heatmap(X_post_mean[:,:,1,1];title="Posterior mean",cbar=false,clims = (0, 1),axis=([], false)) +p4 = heatmap(X_post_var[:,:,1,1];title="Posterior variance",cbar=false,axis=([], false)) +p5 = heatmap(X_post[:,:,1,1];title="Posterior sample",cbar=false,clims = (0, 1),axis=([], false)) +p6 = heatmap(X_post[:,:,1,2];title="Posterior sample",cbar=false,clims = (0, 1),axis=([], false)) - subplot(num_plot,7,3+7*(i-1)); imshow(X_post_mean[:,:,1,1] , vmin=0, vmax=1, cmap="gray") - axis("off"); title("SSIM="*string(ssim_val)*" \n"*"Conditional Mean") ; +plot(p1,p2,p3,p4,p5,p6,aspect_ratio=:equal, size=(600,400)) +savefig("posterior_sampling.png") - subplot(num_plot,7,4+7*(i-1)); imshow(X_post_var[:,:,1,1] , cmap="magma") - axis("off"); title(L"$UQ$") ; - - subplot(num_plot,7,5+7*(i-1)); imshow(X_post[:,:,1,1] |> cpu, vmin=0, vmax=1,cmap="gray") - axis("off"); title("Posterior Sample") ; - - subplot(num_plot,7,6+7*(i-1)); imshow(X_post[:,:,1,2] |> cpu, vmin=0, vmax=1,cmap="gray") - axis("off"); title("Posterior Sample") ; - - subplot(num_plot,7,7+7*(i-1)); imshow(X_post[:,:,1,3] |> cpu, vmin=0, vmax=1, cmap="gray") - axis("off"); title("Posterior Sample") ; -end -tight_layout() -savefig("mnist_sampling_cond.png",dpi=300,bbox_inches="tight") diff --git a/examples/applications/application_conditional_hint_banana.jl b/examples/applications/conditional_sampling/amortized_hint_banana_denoise.jl similarity index 85% rename from examples/applications/application_conditional_hint_banana.jl rename to 
examples/applications/conditional_sampling/amortized_hint_banana_denoise.jl index 3617f2c8..c3b8949b 100644 --- a/examples/applications/application_conditional_hint_banana.jl +++ b/examples/applications/conditional_sampling/amortized_hint_banana_denoise.jl @@ -1,6 +1,12 @@ -# Generative model using the change of variables formula +# Method: Amortized posterior sampler / simulation-based inference / Forward KL variational inference +# Application: sample from the conditional distribution given noisy observations of the Rosenbrock distribution. +# Note: we currently recommend conditional glow architectures instead of HINT, unless you need the latent space of +# the observation. + # Author: Philipp Witte, pwitte3@gatech.edu # Date: January 2020 +using Pkg +Pkg.add("InvertibleNetworks"); Pkg.add("Flux"); Pkg.add("PyPlot") using LinearAlgebra, InvertibleNetworks, PyPlot, Flux, Random, Test import Flux.Optimise.update! @@ -35,9 +41,7 @@ end # Training maxiter = 1000 -opt = Flux.ADAM(1f-3) -lr_step = 100 -lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.) +opt = Flux.Optimiser(Flux.ExpDecay(1f-3, .9, 100, 0.), Flux.ADAM(1f-3)) fval = zeros(Float32, maxiter) for j=1:maxiter @@ -51,7 +55,6 @@ for j=1:maxiter # Update params for p in get_params(H) update!(opt, p.data, p.grad) - update!(lr_decay_fn, p.data, p.grad) end clear_grad!(H) end @@ -96,5 +99,4 @@ plot(Y_fixed[1, 1, 1, :], Y_fixed[1, 1, 2, :], "r."); title(L"Model space: $x = ax7.set_xlim([-3.5,3.5]); ax7.set_ylim([0,50]) ax8 = subplot(2,4,8); plot(Zx[1, 1, 1, :], Zx[1, 1, 2, :], "."); plot(Zy_fixed[1, 1, 1, :], Zy_fixed[1, 1, 2, :], "r."); title(L"Latent space: $zx \sim \hat{p}_{zx}$") -ax8.set_xlim([-3.5, 3.5]); ax8.set_ylim([-3.5, 3.5]) - +ax8.set_xlim([-3.5, 3.5]); ax8.set_ylim([-3.5, 3.5]) \ No newline at end of file diff --git a/examples/applications/application_conditional_hint_banana_linear.jl b/examples/applications/conditional_sampling/amortized_hint_banana_linear.jl similarity index 87% rename from examples/applications/application_conditional_hint_banana_linear.jl rename to examples/applications/conditional_sampling/amortized_hint_banana_linear.jl index 51a04a73..2ae6af7a 100644 --- a/examples/applications/application_conditional_hint_banana_linear.jl +++ b/examples/applications/conditional_sampling/amortized_hint_banana_linear.jl @@ -1,6 +1,14 @@ -# Generative model using the change of variables formula +# Method: Amortized posterior sampler / simulation-based inference / Forward KL variational inference +# Application: sample from the conditional distribution given observations of the Rosenbrock distribution passed through +# a linear operator. +# Note: we currently recommend conditional glow architectures instead of HINT, unless you need the latent space of +# the observation. + # Author: Philipp Witte, pwitte3@gatech.edu # Date: January 2020 +using Pkg + +Pkg.add("InvertibleNetworks"); Pkg.add("Flux"); Pkg.add("PyPlot") using LinearAlgebra, InvertibleNetworks, PyPlot, Flux, Random, Test import Flux.Optimise.update! @@ -39,13 +47,10 @@ end # Training maxiter = 1000 -opt = Flux.ADAM(1f-3) -lr_step = 100 -lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.)
+opt = Flux.Optimiser(Flux.ExpDecay(1f-3, .9, 100, 0.), Flux.ADAM(1f-3)) fval = zeros(Float32, maxiter) for j=1:maxiter - # Evaluate objective and gradients X = sample_banana(batchsize) Y = reshape(A*reshape(X, :, batchsize), nx, ny, n_in, batchsize) @@ -57,7 +62,6 @@ for j=1:maxiter # Update params for p in get_params(H) update!(opt, p.data, p.grad) - update!(lr_decay_fn, p.data, p.grad) end clear_grad!(H) end @@ -114,7 +118,3 @@ plot(Zy_fixed[1, 1, 1, :], Zy_fixed[1, 1, 2, :], "r."); title(L"Latent space: $z ax10.set_xlim([-3.5, 3.5]); ax10.set_ylim([-3.5, 3.5]) - - - - diff --git a/examples/applications/application_hint_scenario_2.jl b/examples/applications/conditional_sampling/non_amortized_hint_linear_Gaussian.jl similarity index 96% rename from examples/applications/application_hint_scenario_2.jl rename to examples/applications/conditional_sampling/non_amortized_hint_linear_Gaussian.jl index fdd9a3c9..f06d4cf5 100644 --- a/examples/applications/application_hint_scenario_2.jl +++ b/examples/applications/conditional_sampling/non_amortized_hint_linear_Gaussian.jl @@ -2,12 +2,17 @@ # Obtaining samples from posterior for the following problem: # y = Ax + ϵ, x ~ N(μ_x, Σ_x), ϵ ~ N(μ_ϵ, Σ_ϵ), A ~ N(0, I/|x|) +using Pkg +Pkg.add("InvertibleNetworks"); Pkg.add("Flux"); Pkg.add("PyPlot"); +Pkg.add("Distributions"); Pkg.add("Printf") + using InvertibleNetworks, LinearAlgebra, Test using Distributions import Flux.Optimise.update!, Flux using PyPlot using Printf using Random + Random.seed!(19) # Model and data dimension @@ -102,9 +107,7 @@ end # Optimizer η = 0.01 -opt = Flux.ADAM(η) -lr_step = 2000 -lr_decay_fn = Flux.ExpDecay(η, .3, lr_step, 0.) +opt = Flux.Optimiser(Flux.ExpDecay(η, .3, 2000, 0.), Flux.ADAM(η)) # Loss function function loss(z_in, y_in) diff --git a/examples/applications/application_conditionalScenario2_hint_banana.jl b/examples/applications/conditional_sampling/non_amortized_hint_rosenbrock_banana.jl similarity index 95% rename from examples/applications/application_conditionalScenario2_hint_banana.jl rename to examples/applications/conditional_sampling/non_amortized_hint_rosenbrock_banana.jl index 71eee2aa..2887d753 100644 --- a/examples/applications/application_conditionalScenario2_hint_banana.jl +++ b/examples/applications/conditional_sampling/non_amortized_hint_rosenbrock_banana.jl @@ -2,6 +2,9 @@ # Author: Philipp Witte, pwitte3@gatech.edu # Date: January 2020 +using Pkg +Pkg.add("InvertibleNetworks"); Pkg.add("Flux"); Pkg.add("PyPlot"); + using LinearAlgebra, InvertibleNetworks, PyPlot, Flux, Random import Flux.Optimise.update! @@ -88,9 +91,8 @@ end # Training maxiter = 1000 -opt = Flux.ADAM(1f-3) -lr_step = 50 -lr_decay_fn = Flux.ExpDecay(1f-3, .9, lr_step, 0.) 
+opt = Flux.Optimiser(Flux.ExpDecay(1f-3, .9, 50, 0.), Flux.ADAM(1f-3)) + fval = zeros(Float32, maxiter) for j=1:maxiter diff --git a/examples/applications/application_glow_banana_dist.jl b/examples/applications/non_conditional_sampling/glow_banana.jl similarity index 100% rename from examples/applications/application_glow_banana_dist.jl rename to examples/applications/non_conditional_sampling/glow_banana.jl diff --git a/examples/applications/generative_sampling.jl b/examples/applications/non_conditional_sampling/glow_seismic.jl similarity index 100% rename from examples/applications/generative_sampling.jl rename to examples/applications/non_conditional_sampling/glow_seismic.jl diff --git a/examples/applications/application_hint_banana_dist.jl b/examples/applications/non_conditional_sampling/hint_banana.jl similarity index 100% rename from examples/applications/application_hint_banana_dist.jl rename to examples/applications/non_conditional_sampling/hint_banana.jl diff --git a/examples/benchmarks/memory_usage_invertiblenetworks.jl b/examples/benchmarks/memory_usage_invertiblenetworks.jl new file mode 100644 index 00000000..d241d5af --- /dev/null +++ b/examples/benchmarks/memory_usage_invertiblenetworks.jl @@ -0,0 +1,119 @@ +# Test memory usage of InvertibleNetworks as network depth increases + + +using InvertibleNetworks, LinearAlgebra, Flux +import Flux.Optimise.update! + +device = InvertibleNetworks.CUDA.functional() ? gpu : cpu + +# turn off JULIA cuda optimization to get raw performance +ENV["JULIA_CUDA_MEMORY_POOL"] = "none" + +using CUDA, Printf + +export @gpumem + +get_mem() = NVML.memory_info(NVML.Device(0)).used/(1024^3) + +function monitor_gpu_mem(used::Vector{T}, status::Ref) where {T<:Real} + while status[] + #cleanup() + push!(used, get_mem()) + end + nothing +end + +cleanup() = begin GC.gc(true); CUDA.reclaim(); end + +macro gpumem(expr) + return quote + # Cleanup + cleanup() + monitoring = Ref(true) + used = [get_mem()] + Threads.@spawn monitor_gpu_mem(used, monitoring) + val = $(esc(expr)) + monitoring[] = false + cleanup() + @printf("Min memory: %1.3fGiB , Peak memory: %1.3fGiB \n", + extrema(used)...)
+ used + end +end + +# Objective function +function loss(G,X) + Y, logdet = G.forward(X) + #cleanup() + f = .5f0/batchsize*norm(Y)^2 - logdet + ΔX, X_ = G.backward(1f0./batchsize*Y, Y) + return f +end + +# size of network input +nx = 256 +ny = nx +n_in = 3 +batchsize = 8 +X = rand(Float32, nx, ny, n_in, batchsize) |> device + +# Define network +n_hidden = 256 +L = 3 # number of scales + +num_retests=1 +mems_max= [] +mem_tested = [4,8,16,32,48,64,80] +for K in mem_tested + G = NetworkGlow(n_in, n_hidden, L, K; split_scales=true) |> device + + loss(G,X) + curr_maxes = [] + for i in 1:num_retests + usedmem = @gpumem begin + loss(G,X) + end + append!(curr_maxes,maximum(usedmem)) + end + append!(mems_max, minimum(curr_maxes)) + println(mems_max) +end + +# control for memory of storing the network parameters on GPU, not relevant to backpropagation +mems_model = [] +for K in mem_tested + G = NetworkGlow(n_in, n_hidden, L, K; split_scales=true) |> device + G(X) + G = G |> cpu + usedmem = @gpumem begin + G = G |> device + end + + append!(mems_model, maximum(usedmem)) +end +mems_model_norm = mems_model .- mems_model[1] + + +mem_used_invnets = mems_max .- mems_model_norm +mem_used_pytorch = [5.0897216796875, 7.8826904296875, 13.5487060546875, 24.8709716796875, 36.2010498046875, 40, NaN] +mem_ticks = mem_tested + + +using PyPlot + +font_size=15 +PyPlot.rc("font", family="serif"); +PyPlot.rc("font", family="serif", size=font_size); PyPlot.rc("xtick", labelsize=font_size); PyPlot.rc("ytick", labelsize=font_size); +PyPlot.rc("axes", labelsize=font_size) # fontsize of the x and y labels + +#nice plot +fig = figure(figsize=(14,6)) +plot(log.(mem_tested),mem_used_pytorch; color="black", linestyle="--",markevery=collect(range(0,4,step=1)),label="PyTorch package",marker="o",markerfacecolor="r") +plot(log.(mem_tested),mem_used_invnets; color="black", label="InvertibleNetworks.jl package",marker="o",markerfacecolor="b") +axvline(log.(mem_tested)[end-1], linestyle="--",color="red",label="PyTorch out of memory error") +grid() +xticks(log.(mem_ticks), [string(i) for i in mem_ticks],rotation=60) +legend() +xlabel("Network depth [# of layers]"); +ylabel("Peak memory used [GB]"); +fig.savefig("mem_used_new_depth_new.png", bbox_inches="tight", dpi=400) diff --git a/examples/benchmarks/memory_usage_normflows.py b/examples/benchmarks/memory_usage_normflows.py new file mode 100644 index 00000000..adaecf7a --- /dev/null +++ b/examples/benchmarks/memory_usage_normflows.py @@ -0,0 +1,138 @@ +# Test memory usage of normflows as network depth increases + +#salloc -A rafael -t00:80:00 --gres=gpu:1 --mem-per-cpu=30G srun --pty python + +import torch +import nvidia_smi + +def _get_gpu_mem(synchronize=True, empty_cache=True): + handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0) + mem = nvidia_smi.nvmlDeviceGetMemoryInfo(handle) + return mem.used + +def _generate_mem_hook(handle_ref, mem, idx, hook_type, exp): + def hook(self, *args): + if hook_type == 'pre': + return + if len(mem) == 0 or mem[-1]["exp"] != exp: + call_idx = 0 + else: + call_idx = mem[-1]["call_idx"] + 1 + mem_all = _get_gpu_mem() + torch.cuda.synchronize() + lname = type(self).__name__ + lname = 'conv' if 'conv' in lname.lower() else lname + lname = 'ReLU' if 'relu' in lname.lower() else lname + mem.append({ + 'layer_idx': idx, + 'call_idx': call_idx, + 'layer_type': f"{lname}_{hook_type}", + 'exp': exp, + 'hook_type': hook_type, + 'mem_all': mem_all, + }) + return hook + + +def _add_memory_hooks(idx, mod, mem_log, exp, hr): + h = 
mod.register_forward_pre_hook(_generate_mem_hook(hr, mem_log, idx, 'pre', exp)) + hr.append(h) + h = mod.register_forward_hook(_generate_mem_hook(hr, mem_log, idx, 'fwd', exp)) + hr.append(h) + h = mod.register_backward_hook(_generate_mem_hook(hr, mem_log, idx, 'bwd', exp)) + hr.append(h) + + +def log_mem(model, inp, mem_log=None, exp=None): + nvidia_smi.nvmlInit() + mem_log = mem_log or [] + exp = exp or f'exp_{len(mem_log)}' + hr = [] + for idx, module in enumerate(model.modules()): + _add_memory_hooks(idx, module, mem_log, exp, hr) + try: + out = model(inp) + loss = out.sum() + loss.backward() + except Exception as e: + print(f"Errored with error {e}") + finally: + [h.remove() for h in hr] + return mem_log + +# Used this in InvertibleNetworks.jl so should use here. Though we didn't see much difference in performance +import os +os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = "1" + +import pandas as pd +import torch +import torchvision as tv +import numpy as np +import normflows as nf +from matplotlib import pyplot as plt +from tqdm import tqdm + +torch.manual_seed(0) + +L = 3 +nx = 256 +max_mems = [] + +for depth in [4,8,16,32,48,64]: + print(depth) + K = depth + input_shape = (3, nx, nx) + n_dims = np.prod(input_shape) + channels = 3 + hidden_channels = 256 + split_mode = 'channel' + scale = True + + # Set up flows, distributions and merge operations + q0 = [] + merges = [] + flows = [] + for i in range(L): + flows_ = [] + for j in range(K): + flows_ += [nf.flows.GlowBlock(channels * 2 ** (L + 1 - i), hidden_channels, + split_mode=split_mode, scale=scale)] + flows_ += [nf.flows.Squeeze()] + flows += [flows_] + if i > 0: + merges += [nf.flows.Merge()] + latent_shape = (input_shape[0] * 2 ** (L - i), input_shape[1] // 2 ** (L - i), + input_shape[2] // 2 ** (L - i)) + else: + latent_shape = (input_shape[0] * 2 ** (L + 1), input_shape[1] // 2 ** L, + input_shape[2] // 2 ** L) + q0 += [nf.distributions.DiagGaussian(latent_shape,trainable=False)] + # Construct flow model with the multiscale architecture + model = nf.MultiscaleFlow(q0, flows, merges) + enable_cuda = True + device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu') + model = model.to(device) + batch_size = 8 + transform = tv.transforms.Compose([tv.transforms.ToTensor(), nf.utils.Scale(255.
/ 256.), nf.utils.Jitter(1 / 256.),tv.transforms.Resize(size=input_shape[1])]) + train_data = tv.datasets.CIFAR10('datasets/', train=True, + download=True, transform=transform) + train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, + drop_last=True) + train_iter = iter(train_loader) + optimizer = torch.optim.Adamax(model.parameters(), lr=1e-3, weight_decay=1e-5) + x, y = next(train_iter) + + mem_log = [] + try: + mem_log.extend(log_mem(model, x.to(device), exp='std')) + except Exception as e: + print(f'log_mem failed because of {e}') + torch.cuda.synchronize() + torch.cuda.empty_cache() + df = pd.DataFrame(mem_log) + max_mem = df['mem_all'].max()/(1024**3) + print("Peak memory usage: %.2f Gb" % (max_mem,)) + max_mems.append(max_mem) + +# >>> max_mems +# [5.0897216796875, 7.8826904296875, 13.5487060546875, 24.8709716796875, 36.2010498046875, 39.9959716796875] \ No newline at end of file diff --git a/src/conditional_layers/conditional_layer_glow.jl b/src/conditional_layers/conditional_layer_glow.jl index 7d0d69e4..35538fbe 100644 --- a/src/conditional_layers/conditional_layer_glow.jl +++ b/src/conditional_layers/conditional_layer_glow.jl @@ -78,7 +78,12 @@ function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze # 1x1 Convolution and residual block for invertible layers C = Conv1x1(n_in; freeze=freeze_conv) - RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims) + + split_num = Int(round(n_in/2)) + in_split = n_in-split_num + out_chan = 2*split_num + + RB = ResidualBlock(in_split+n_cond, n_hidden; n_out=out_chan, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims) return ConditionalLayerGlow(C, RB, logdet, activation) end @@ -143,7 +148,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA # Backpropagate RB ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C))) - ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2)) + ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=size(ΔY2)[N-1]) ΔX2 += ΔY2 # Backpropagate 1x1 conv diff --git a/test/runtests.jl b/test/runtests.jl index 4dd9dce0..2d63b6cf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,7 +16,7 @@ basics = ["test_utils/test_objectives.jl", "test_utils/test_nnlib_convolution.jl", "test_utils/test_activations.jl", "test_utils/test_squeeze.jl", - "test_utils/test_jacobian.jl", + #"test_utils/test_jacobian.jl", "test_utils/test_chainrules.jl", "test_utils/test_flux.jl"] diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl index a2a0380d..946b39b0 100644 --- a/test/test_networks/test_conditional_glow_network.jl +++ b/test/test_networks/test_conditional_glow_network.jl @@ -9,6 +9,49 @@ device = InvertibleNetworks.CUDA.functional() ? 
gpu : cpu # Random seed Random.seed!(3); +# Define network +nx = 32; ny = 32; nz = 32 +n_in = 3 +n_cond = 3 +n_hidden = 4 +batchsize = 2 +L = 2 +K = 2 +split_scales = false +N = (nx,ny) + +########################################### Test with split_scales = false N = (nx,ny) ######################### +# Invertibility + +# Network and input +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device +X = rand(Float32, N..., n_in, batchsize) |> device +Cond = rand(Float32, N..., n_cond, batchsize) |> device + +Y, Cond = G.forward(X,Cond) +X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes + +@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) + +# Test gradients are set and cleared +G.backward(Y, Y, Cond) + +P = get_params(G) +gsum = 0 +for p in P + ~isnothing(p.grad) && (global gsum += 1) +end +@test isequal(gsum, L*K*10+2) + +clear_grad!(G) +gsum = 0 +for p in P + ~isnothing(p.grad) && (global gsum += 1) +end +@test isequal(gsum, 0) + + +Random.seed!(3); # Define network nx = 32; ny = 32; nz = 32 n_in = 2