Create vae_mnist_new_architecture.jl #487

Open · wants to merge 3 commits into master (showing changes from 1 commit)
164 changes: 164 additions & 0 deletions examples/variational-autoencoder/vae_mnist_new_architecture.jl
@@ -0,0 +1,164 @@
# Library includes
using Knet
using PyPlot
using AutoGrad


# General Type definitions
const F = Float32 # Data type for gpu usage
const Atype = gpu() >= 0 ? KnetArray{F} : Array{F}
const Itype = Union{Atype,AutoGrad.Result} # accepted inputs: plain arrays (CPU or GPU) and AutoGrad-tracked values
abstract type Layer end;


# Parameter definitions
nz = 20 # Encoding dimension
nh = 400 # Size of hidden layer


"""
The Convolution layer
"""
struct Conv <: Layer; w; b; f::Function; pad::Int; str::Int; end
(c::Conv)(x::Itype) = c.f.(conv4(c.w, x, padding = c.pad, stride = c.str) .+ c.b)
Conv(w1, w2, cx, cy; f=relu, pad=1, str=1) = Conv(param(w1, w2, cx, cy), param0(1, 1, cy, 1), f, pad, str)


"""
The DeConvolution layer (transposed convolution), the reverse of Conv
"""
struct DeConv <: Layer; w; b; f::Function; pad::Int; str::Int; end
(c::DeConv)(x) = c.f.(deconv4(c.w, x, padding = c.pad, stride = c.str) .+ c.b)
DeConv(w1, w2, cx, cy; f=relu, pad=1, str=1) = DeConv(param(w1, w2, cx, cy), param0(1, 1, cx, 1), f, pad, str)


"""
The Dense layer
"""
struct Dense <: Layer; w; b; f::Function; end
(d::Dense)(x) = d.f.(d.w * mat(x) .+ d.b)
Dense(i::Int, o::Int; f = relu) = Dense(param(o, i), param0(o), f)


"""
Chain of layers
"""
struct Chain; layers; end
(c::Chain)(x) = (for l in c.layers; x=l(x); end; x)
(c::Chain)(x, m) = (for (index, l) in enumerate(c.layers); x = l(x, m[index]); end; x)


"""
Chain of Networks -> Autoencoder
"""
struct Autoencoder; ϕ::Chain; θ::Chain; end
function (ae::Autoencoder)(x; samples=1, β=1, F=Float32)
z_out = ae.ϕ(x)
μ, logσ² = z_out[1:nz, :], z_out[nz + 1:end, :]
σ² = exp.(logσ²)
σ = sqrt.(σ²)

KL = -sum(@. 1 + logσ² - μ * μ - σ²) / 2
KL /= length(x)

BCE = F(0)

for s = 1:samples
Collaborator:

This is probably not efficient; you can run all samples at once with an additional "sample" batching dimension. First, reshape μ to (nz, B, 1), then sample from randn with size (nz, B, Nsample) and broadcast μ onto it. Finally, change binary_cross_entropy to deal with (nz, B, Nsample) input.
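
A minimal sketch of this sample-batching idea (sample_all and Nsample are illustrative names, not part of the PR; Atype and F are assumed to be defined as in the file above):

# Draw all Nsample noise vectors at once and flatten them into one large batch
# of size B*Nsample for the decoder, instead of looping over samples.
function sample_all(μ, σ, Nsample)
    nz, B = size(μ)
    μ3 = reshape(μ, nz, B, 1)                      # (nz, B, 1)
    σ3 = reshape(σ, nz, B, 1)                      # (nz, B, 1)
    ϵ  = convert(Atype, randn(F, nz, B, Nsample))  # (nz, B, Nsample)
    z  = μ3 .+ ϵ .* σ3                             # broadcast μ and σ over the sample dimension
    return reshape(z, nz, B * Nsample)             # one decoder batch
end

Usage inside the loss would then be z = sample_all(μ, σ, samples); x̂ = ae.θ(z), and binary_cross_entropy would have to compare x̂ against x repeated Nsample times, as the comment above describes.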

Author:

Correct, the original form was not efficient for sampling multiple times within one batch. The suggestion to broadcast is more efficient and much faster. However, I was not able to broadcast this efficiently through the decoder network. Since the sampling doesn't improve performance as far as I can tell, and the majority of implementations I found do not use it, I have abandoned it for this example.

ϵ = convert(Atype, randn(F, size(μ)))
z = @. μ + ϵ * σ
x̂ = ae.θ(z)
BCE += binary_cross_entropy(x, x̂)
end

BCE /= samples

return BCE + β * KL
end

(ae::Autoencoder)(x, y) = ae(x)


function binary_cross_entropy(x, x̂)
x = reshape(x, size(x̂))
s = @. x * log(x̂ + F(1e-10)) + (1 - x) * log(1 - x̂ + F(1e-10))
return -sum(s) / length(x)
end


# Definition of the Encoder
ϕ = Chain((
Conv(3, 3, 1, 16, pad=1),
Conv(4, 4, 16, 32, pad=1, str=2),
Conv(3, 3, 32, 32, pad=1),
Conv(4, 4, 32, 64, pad=1, str=2),

x->mat(x),

Dense(64 * 7^2, nh),
Dense(nh, 2 * nz),
))


# Definition of the Decoder
θ = Chain((
Dense(nz, nh),
Dense(nh, 64 * 7^2),

x->reshape(x, (7, 7, 64, :)),

DeConv(4, 4, 32, 64, pad=1, str=2),
DeConv(3, 3, 32, 32, pad=1),
DeConv(4, 4, 16, 32, pad=1, str=2),
DeConv(3, 3, 1, 16, f=sigm, pad=1),
))

# Initialize the autoencoder with Encoder and Decoder
ae = Autoencoder(ϕ, θ)

# Load dataset
include(Knet.dir("data", "mnist.jl"))
dtrn, dtst = mnistdata()


"""
Visualize the progress during training
"""
function cb_plot(ae, img, epoch)
img_o = convert(Array{Float64}, img)
img_r = convert(Array{Float64}, ae.θ(ae.ϕ(img)[1:nz, :]))

figure("Epoch $epoch")
clf()
subplot(1, 2, 1)
title("Original")
imshow(img_o[:, :, 1, 1])
subplot(1, 2, 2)
title("Reproduced")
imshow(img_r[:, :, 1, 1])
end


"""
Main function for training
Questions to: [email protected]
"""
function train(ae, dtrn, iters)
img = convert(Atype, reshape(dtrn.x[:,1], (28, 28, 1, 1)))
for epoch = 1:iters
@time adam!(ae, dtrn)
Collaborator:

If I'm not mistaken, this is not the correct way to iterate over epochs: each call to adam! creates a new Adam struct, so information from previous epochs (e.g. the accumulated moments) is lost.
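
A minimal sketch of one way to keep a single Adam state across epochs (train_once is an illustrative name; it assumes adam! accepts any iterator of (x, y) batches and that dtrn has a length, and it drops the per-epoch callback, which would need a separate loop over adam(...) if kept):

# Repeat the training data `iters` times and drive one adam! call over it,
# so the accumulated moments persist across epochs instead of being reset.
function train_once(ae, dtrn, iters)
    alldata = Iterators.take(Iterators.cycle(dtrn), iters * length(dtrn))
    @time adam!(ae, alldata)
end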

Author:

Oh yes, completely correct! Thanks for the advice; I have adapted my example proposal accordingly.


if (epoch % 20) == 0
@show ae(first(dtrn)...)
cb_plot(ae, img, epoch)
end
end
end

# Precompile
@info "Precompile"
ae(first(dtrn)...)
@time adam!(ae, dtrn)

# Train
@info "Start training!"
@time train(ae, dtrn, 50)