Relbo.jl uses an expression tree to represent the computational graph of variational inference. With this representation, symbolic rewriting based on Metatheory.jl can manipulate the computation at compile time; for example, a Rao-Blackwellization rewrite for variance reduction can be implemented directly on the tree. Relbo.jl also takes advantage of code generation (metaprogramming) and automatic differentiation (Zygote.jl) to handle the numerical computation efficiently.
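Concretely, the compile-time manipulation is ordinary term rewriting. The sketch below illustrates the idea on a plain Julia Expr rather than Relbo.jl's own ExprTerm type; it mirrors the rule style used later in this README, but the exact @rule syntax and the location of the rewriter combinators may differ between Metatheory.jl versions.
using Metatheory
using Metatheory.Rewriters: Postwalk, PassThrough

# Rewrite any subterm of the form x * 1 to x.
mul_one = @rule x (x * 1) => x
simplify = Postwalk(PassThrough(mul_one))

simplify(:(log(a * 1) + b))   # returns :(log(a) + b)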
This project is motivated by the JuliaSymbolics community (Metatheory.jl, Symbolics.jl) for the symbolic rewriting ideas, and by Pyro for the idea of variational inference in a PPL. Related projects include Turing.jl and Soss.jl; the difference is that Turing.jl and Soss.jl focus on HMC-based posterior sampling, while Relbo.jl focuses on variational inference (and is currently a very small trial project).
The easiest way to use Relbo.jl is through the DSL and the provided helper functions:
using Relbo
using Relbo: train
elbo = @ELBO ga, gb begin
(i, j, k, l)::Index
data::Observe(q)
(a, b)::Param
a | i
b | i
data | i
ga | i
gb | i
z ~ Beta(a, b)
q ~ InverseGaussian(z)
obsq = q(data)
guide ~ Beta(ga, gb) ≈ z
return Expectation(guide, obsq)
end
data = ones(100)
g = GD(elbo, :data; ga=12.0, gb=4.0, a=10.0, b=3.0)
train(g::GD, data, 10)
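As far as the call signatures suggest (see the package source for the exact semantics), GD takes the ELBO, the name of the observed variable (:data), and initial parameter values as keyword arguments, and train then runs 10 training iterations over data.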
The expression tree of sf_grad_elbo_eval can be printed using AbstractTrees:
using Relbo
using AbstractTrees
print_tree(sf_grad_elbo_eval |> to_tree; maxdepth=20)
The result is shown below:
"ExprTerm_*"
└─ Dict{Any, Any}
├─ ""
│ └─ "ExprTerm_*"
│ └─ Dict{Any, Any}
│ ├─ ""
│ │ └─ "ExprTerm_observe"
│ │ └─ Dict{Any, Any}
│ │ ├─ ""
│ │ │ └─ "Atom_q_InverseGaussian_observe"
│ │ │ └─ ""
│ │ │ └─ "Atom_z_Beta_sampling"
│ │ │ └─ Dict{Any, Any}
│ │ │ ├─ ""
│ │ │ │ └─ "Param_ga"
│ │ │ │ └─ ""
│ │ │ │ └─ "nothing"
│ │ │ └─ ""
│ │ │ └─ "Param_gb"
│ │ │ └─ ""
│ │ │ └─ "nothing"
│ │ └─ ""
│ │ └─ "Atom_q_data_observe"
│ │ └─ ""
│ │ └─ "Param_data"
│ │ └─ ""
│ │ └─ "nothing"
│ └─ ""
│ └─ "ExprTerm_observe"
│ └─ Dict{Any, Any}
│ ├─ ""
│ │ └─ "Atom_z_data_observe"
│ │ └─ ""
│ │ └─ "Param_z"
│ │ └─ ""
│ │ └─ "Atom_z_Beta_sampling"
│ │ └─ Dict{Any, Any}
│ │ ├─ ""
│ │ │ └─ "Param_gb"
│ │ │ └─ ""
│ │ │ └─ "nothing"
│ │ └─ ""
│ │ └─ "Param_ga"
│ │ └─ ""
│ │ └─ "nothing"
│ └─ ""
│ └─ "Atom_z_Beta_observe"
│ └─ Dict{Any, Any}
│ ├─ ""
│ │ └─ "Param_a"
│ │ └─ ""
│ │ └─ "nothing"
│ └─ ""
│ └─ "Param_b"
│ └─ ""
│ └─ "nothing"
└─ ""
└─ "ExprTerm_grad"
└─ ""
└─ "ExprTerm_log"
└─ ""
└─ "ExprTerm_observe"
└─ Dict{Any, Any}
├─ ""
│ └─ "Atom_z_Beta_observe"
│ └─ Dict{Any, Any}
│ ├─ ""
│ │ └─ "Param_gb"
│ │ └─ ""
│ │ └─ "nothing"
│ └─ ""
│ └─ "Param_ga"
│ └─ ""
│ └─ "nothing"
└─ ""
└─ "Atom_z_data_observe"
└─ ""
└─ "Param_z"
└─ ""
└─ "Atom_z_Beta_sampling"
└─ Dict{Any, Any}
├─ ""
│ └─ "Param_ga"
│ └─ ""
│ └─ "nothing"
└─ ""
└─ "Param_gb"
└─ ""
└─ "nothing"
The rewriting itself just uses the term-rewriting functions provided by Metatheory.jl. Below is how to define a score-function estimator rewriter that transforms a gradient over an integral into a score-function version, which has smaller variance.
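For context (standard background, not specific to Relbo.jl), the transformation is the score-function / REINFORCE identity

∇_φ E_{q_φ(z)}[ f(z) ] = E_{q_φ(z)}[ f(z) · ∇_φ log q_φ(z) ],

so the gradient of an expectation under the guide q_φ can itself be written as an expectation under q_φ and estimated by sampling from the guide. The rewriter builds exactly this form by appending a grad(log(guide)) factor to the expectation term: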
function sf_estimator(x::ExprTerm)
    # x is expected to be a gradient node wrapping a single ELBO expectation.
    grad_op = x.op
    @assert is_single_arg(x)
    elbo = get_single_arg(x)
    # The first argument of the expectation is the guide term.
    guide = elbo.args[1]
    nelbo, nguide = copy(elbo), copy(guide)
    # Build grad(log(guide)) and append it as an extra factor of the expectation,
    # turning grad(E_guide[...]) into E_guide[... * grad(log guide)].
    log_guide = ExprTerm(FunctorOperation(:log), nguide)
    grad_log_guide = ExprTerm(grad_op, log_guide)
    push!(nelbo.args, grad_log_guide)
    return nelbo
end
function sf_estimator_2expr(x::ExprTerm)
    # Run the score-function rewrite and return the result wrapped in a quoted expression.
    elbo = sf_estimator(x)
    return :($elbo)
end
sf_grad_rule = @rule x x::ExprTerm => sf_estimator_2expr(x) where is_sf_grad(x)
sf_grad_rule = Postwalk(PassThrough(sf_grad_rule))
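Here Postwalk applies the rule bottom-up over the whole tree and PassThrough leaves non-matching nodes unchanged, so the rewrite only fires on gradient terms for which is_sf_grad(x) holds (this is standard Metatheory.jl rewriter-combinator behaviour).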
For more information on how to manipulate the expression tree, see src/rewrite.jl.
The expression tree of Relbo.jl can easily be transformed into runnable code using the cgen functions provided in src/codegen.jl; the code for gradients is generated with Zygote.jl. The resulting code broadcasts along the batch dimension, which makes it easy to scale up.
code = cgen(sf_grad_elbo_eval)
grad_func = sampling_fun_generator(code, [:data, :ga, :gb, :a, :b], true)
ga = 12
gb = 4
a = 10
b = 11
data = rand(100)
@time grad_func.(data, Ref(ga), Ref(gb), Ref(a), Ref(b))
grad = sum(grad_func.(data, Ref(ga), Ref(gb), Ref(a), Ref(b)))
@show size(grad)
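Note the broadcasting pattern: wrapping the scalar parameters in Ref keeps them fixed while grad_func is broadcast over the 100 data points; each call returns a per-sample gradient estimate, and summing them gives the batch gradient.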
The generated grad_func is shown below:
begin
    (data, ga, gb, a, b)->begin
        begin
            z = Beta(ga, gb)
            z_observe = rand(z)
            var"z##328" = Beta(a, b)
            q = InverseGaussian(z_observe)
        end
        return (pdf(q, data) * pdf(var"z##328", z_observe)) * collect(gradient(((ga, gb)->begin
            begin
                z = Beta(ga, gb)
            end
            log(pdf(z, z_observe))
        end), (ga, gb)...))
    end
end
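In this generated code, the product of the model densities pdf(q, data) * pdf(var"z##328", z_observe) is multiplied by the Zygote gradient of log pdf(Beta(ga, gb), z_observe) with respect to (ga, gb), i.e. the grad(log(guide)) score-function factor added by the rewriter.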
For more information, see test/test_terms.jl.