Create vae_mnist_new_architecture.jl #487
@@ -0,0 +1,164 @@
# Library includes
using Knet
using PyPlot
using AutoGrad


# General type definitions
const F = Float32                                   # Data type for GPU usage
const Atype = gpu() >= 0 ? KnetArray{F} : Array{F}
const Itype = Union{KnetArray{F,4},AutoGrad.Result{KnetArray{F,4}}}
abstract type Layer end


# Parameter definitions
nz = 20    # Encoding dimension
nh = 400   # Size of hidden layer


"""
The Convolution layer
"""
struct Conv <: Layer; w; b; f::Function; pad::Int; str::Int; end
(c::Conv)(x::Itype) = c.f.(conv4(c.w, x, padding = c.pad, stride = c.str) .+ c.b)
Conv(w1, w2, cx, cy; f = relu, pad = 1, str = 1) = Conv(param(w1, w2, cx, cy), param0(1, 1, cy, 1), f, pad, str)

"""
The DeConvolution layer = reverse of Conv
"""
struct DeConv <: Layer; w; b; f::Function; pad::Int; str::Int; end
(c::DeConv)(x) = c.f.(deconv4(c.w, x, padding = c.pad, stride = c.str) .+ c.b)
DeConv(w1, w2, cx, cy; f = relu, pad = 1, str = 1) = DeConv(param(w1, w2, cx, cy), param0(1, 1, cx, 1), f, pad, str)


"""
The Dense layer
"""
struct Dense <: Layer; w; b; f::Function; end
(d::Dense)(x) = d.f.(d.w * mat(x) .+ d.b)
Dense(i::Int, o::Int; f = relu) = Dense(param(o, i), param0(o), f)


"""
Chain of layers
"""
struct Chain; layers; end
(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)
(c::Chain)(x, m) = (for (index, l) in enumerate(c.layers); x = l(x, m[index]); end; x)

"""
Chain of networks -> Autoencoder
"""
struct Autoencoder; ϕ::Chain; θ::Chain; end

function (ae::Autoencoder)(x; samples = 1, β = 1, F = Float32)
    z_out = ae.ϕ(x)
    μ, logσ² = z_out[1:nz, :], z_out[nz + 1:end, :]
    σ² = exp.(logσ²)
    σ = sqrt.(σ²)

    KL = -sum(@. 1 + logσ² - μ * μ - σ²) / 2
    KL /= length(x)

    BCE = F(0)

    for s = 1:samples
        ϵ = convert(Atype, randn(F, size(μ)))
        z = @. μ + ϵ * σ
        x̂ = ae.θ(z)
        BCE += binary_cross_entropy(x, x̂)
    end

    BCE /= samples

    return BCE + β * KL
end

(ae::Autoencoder)(x, y) = ae(x)


function binary_cross_entropy(x, x̂)
    x = reshape(x, size(x̂))
    s = @. x * log(x̂ + F(1e-10)) + (1 - x) * log(1 - x̂ + F(1e-10))
    return -sum(s) / length(x)
end

# Definition of the Encoder
ϕ = Chain((
    Conv(3, 3, 1, 16, pad = 1),
    Conv(4, 4, 16, 32, pad = 1, str = 2),
    Conv(3, 3, 32, 32, pad = 1),
    Conv(4, 4, 32, 64, pad = 1, str = 2),

    x -> mat(x),

    Dense(64 * 7^2, nh),
    Dense(nh, 2 * nz),
))


# Definition of the Decoder
θ = Chain((
    Dense(nz, nh),
    Dense(nh, 64 * 7^2),

    x -> reshape(x, (7, 7, 64, :)),

    DeConv(4, 4, 32, 64, pad = 1, str = 2),
    DeConv(3, 3, 32, 32, pad = 1),
    DeConv(4, 4, 16, 32, pad = 1, str = 2),
    DeConv(3, 3, 1, 16, f = sigm, pad = 1),
))

# Initialize the autoencoder with Encoder and Decoder
ae = Autoencoder(ϕ, θ)


# Load dataset
include(Knet.dir("data", "mnist.jl"))
dtrn, dtst = mnistdata()


"""
Visualize the progress during training
"""
function cb_plot(ae, img, epoch)
    img_o = convert(Array{Float64}, img)
    img_r = convert(Array{Float64}, ae.θ(ae.ϕ(img)[1:nz, :]))

    figure("Epoch $epoch")
    clf()
    subplot(1, 2, 1)
    title("Original")
    imshow(img_o[:, :, 1, 1])
    subplot(1, 2, 2)
    title("Reproduced")
    imshow(img_r[:, :, 1, 1])
end


"""
Main function for training
Questions to: [email protected]
"""
function train(ae, dtrn, iters)
    img = convert(Atype, reshape(dtrn.x[:, 1], (28, 28, 1, 1)))
    for epoch = 1:iters
        @time adam!(ae, dtrn)

Review comment on the `adam!` call above: if I'm not wrong, this is not the correct way to iterate over epochs, since here each time a new Adam struct is created and information (e.g. accumulated moments) from previous epochs is lost.

Author reply: Oh yes, completely correct! Thanks for the advice, I have adapted my example proposal accordingly.

        if (epoch % 20) == 0
            @show ae(first(dtrn)...)
            cb_plot(ae, img, epoch)
        end
    end
end


# Precompile
@info "Precompile"
ae(first(dtrn)...)
@time adam!(ae, dtrn)


# Train
@info "Start training!"
@time train(ae, dtrn, 50)
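
As a follow-up to the review thread on the `adam!` call inside `train`, a minimal sketch of one way to keep the Adam moments across epochs is given below: a single `adam` pass over the minibatches repeated `iters` times. This is not part of the PR; `train_persistent` is a hypothetical name, and the use of the `adam` iterator with `IterTools.ncycle` follows the pattern from the Knet tutorials.

# Sketch only: one Adam state per parameter for the whole run.
using IterTools: ncycle   # assumption: IterTools is available, as in the Knet tutorials

function train_persistent(ae, dtrn, iters)
    img = convert(Atype, reshape(dtrn.x[:, 1], (28, 28, 1, 1)))
    nbatches = length(dtrn)
    # A single `adam` iterator over `iters` epochs worth of minibatches,
    # so the accumulated moments are never re-initialized between epochs.
    for (step, _) in enumerate(adam(ae, ncycle(dtrn, iters)))
        if step % (20 * nbatches) == 0   # roughly every 20 epochs
            epoch = step ÷ nbatches
            @show ae(first(dtrn)...)
            cb_plot(ae, img, epoch)
        end
    end
end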
Reviewer comment (referring to the sampling loop in the Autoencoder loss): this is probably not efficient; you can run all samples at once by adding a "sample" batch dimension. First, reshape μ to (nz, B, 1), then sample from randn with size (nz, B, Nsample) and broadcast μ on it. Then you can change binary_cross_entropy to deal with (nz, B, Nsample) input.

Author reply: Correct, the form was not efficient for sampling multiple times within one batch. The suggestion to broadcast is more efficient and much faster. However, I was not able to broadcast this efficiently through the decoder network. As the sampling doesn't increase performance as far as I can tell, and the majority of implementations I found do not use it, I have also abandoned this for the example here.
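
For illustration only, a rough sketch of the reviewer's broadcast idea follows. It is not part of the PR (the author abandoned the approach); `vae_loss_broadcast` and `nsamples` are made-up names, the sketch reuses the ϕ/θ/nz/Atype/F definitions from the file above, and it folds the sample axis into the batch axis so the samples fit through the convolutional decoder. Depending on the Knet version, the 3-D broadcasts may need adjustment.

# Sketch only: all `nsamples` reparameterization draws in one broadcast.
function vae_loss_broadcast(ae::Autoencoder, x; nsamples = 5, β = 1)
    z_out = ae.ϕ(x)
    μ, logσ² = z_out[1:nz, :], z_out[nz + 1:end, :]
    σ = exp.(logσ² ./ 2)
    B = size(μ, 2)

    KL = -sum(@. 1 + logσ² - μ * μ - exp(logσ²)) / 2
    KL /= length(x)

    # ϵ has shape (nz, B, nsamples); μ and σ broadcast along the sample axis.
    ϵ = convert(Atype, randn(F, nz, B, nsamples))
    z = reshape(μ, nz, B, 1) .+ ϵ .* reshape(σ, nz, B, 1)

    # The conv decoder expects an (nz, batch) input, so fold the sample axis
    # into the batch axis before decoding.
    x̂ = ae.θ(reshape(z, nz, B * nsamples))

    # Binary cross entropy with x broadcast over the sample axis.
    x3 = reshape(mat(x), 28 * 28, B, 1)
    x̂3 = reshape(mat(x̂), 28 * 28, B, nsamples)
    s = @. x3 * log(x̂3 + F(1e-10)) + (1 - x3) * log(1 - x̂3 + F(1e-10))
    BCE = -sum(s) / (length(x) * nsamples)

    return BCE + β * KL
end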