Skip to content

Commit

Permalink
Merge pull request #145 from biaslab/dev-pointmass-dist-fix
Browse files Browse the repository at this point in the history
fast path for dists in the PointMass constraint
  • Loading branch information
bvdmitri authored Aug 28, 2023
2 parents 6ad864e + 287aa85 commit 47618a7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 18 deletions.
3 changes: 3 additions & 0 deletions src/constraints/form/form_point_mass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ call_starting_point(pmconstraint::PointMassFormConstraint, distribution::D) wher

ReactiveMP.constrain_form(pmconstraint::PointMassFormConstraint, distribution) = call_optimizer(pmconstraint, distribution)

# There is no need to call the optimizer on a `Distribution` object since they should have a well defined `mode`
ReactiveMP.constrain_form(::PointMassFormConstraint, distribution::Distribution) = PointMass(mode(distribution))

function default_point_mass_form_constraint_optimizer(::Type{Univariate}, ::Type{Continuous}, constraint::PointMassFormConstraint, distribution)
target = let distribution = distribution
(x) -> -logpdf(distribution, x[1])
Expand Down
61 changes: 43 additions & 18 deletions test/constraints/form/test_form_point_mass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,27 @@ module RxInferPointMassFormConstraintTest

using Test
using RxInfer, LinearAlgebra
using Random, StableRNGs, DomainSets
using Random, StableRNGs, DomainSets, Distributions

import ReactiveMP: constrain_form
import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_boundaries, call_starting_point, call_optimizer

struct MyDistributionWithMode <: ContinuousUnivariateDistribution
mode::Float64
end

# We are testing specifically that the point mass optimizer does not call `logpdf` and
# chooses a fast path with `mode` for `<: Distribution` objects
Distributions.logpdf(::MyDistributionWithMode, _) = error("This should not be called")
Distributions.mode(d::MyDistributionWithMode) = d.mode
Distributions.support(::MyDistributionWithMode) = RealInterval(-Inf, Inf)

const arbitrary_dist_1 = ContinuousUnivariateLogPdf(RealLine(), (x) -> logpdf(NormalMeanVariance(0, 1), x))
const arbitrary_dist_2 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(Gamma(1, 1), x))
const arbitrary_dist_3 = ContinuousUnivariateLogPdf(RealLine(), (x) -> logpdf(NormalMeanVariance(-10, 10), x))
const arbitrary_dist_4 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(GammaShapeRate(100, 10), x))
const arbitrary_dist_5 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(GammaShapeRate(100, 100), x))

@testset "PointMassFormConstraint" begin
@testset "is_point_mass_form_constraint" begin
@test is_point_mass_form_constraint(PointMassFormConstraint())
Expand All @@ -14,43 +31,51 @@ import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_bou
@testset "boundaries" begin
constraint = PointMassFormConstraint()

@test call_boundaries(constraint, NormalMeanVariance(0, 1)) === (-Inf, Inf)
@test call_boundaries(constraint, Gamma(1, 1)) === (0.0, Inf)
@test call_boundaries(constraint, arbitrary_dist_1) === (-Inf, Inf)
@test call_boundaries(constraint, arbitrary_dist_2) === (0.0, Inf)

bm_constraint = PointMassFormConstraint(boundaries = (args...) -> (-1.0, 1.0))

@test call_boundaries(bm_constraint, NormalMeanVariance(0, 1)) === (-1.0, 1.0)
@test call_boundaries(bm_constraint, Gamma(1, 1)) === (-1.0, 1.0)
@test call_boundaries(bm_constraint, arbitrary_dist_1) === (-1.0, 1.0)
@test call_boundaries(bm_constraint, arbitrary_dist_2) === (-1.0, 1.0)
end

@testset "starting point" begin
constraint = PointMassFormConstraint()

@test call_starting_point(constraint, NormalMeanVariance(0, 1)) == [0.0]
@test_throws ErrorException call_starting_point(constraint, Gamma(1, 1))
@test call_starting_point(constraint, arbitrary_dist_1) == [0.0]
@test_throws ErrorException call_starting_point(constraint, arbitrary_dist_2)

bm_constraint = PointMassFormConstraint(starting_point = (args...) -> [1.0])

@test call_starting_point(bm_constraint, NormalMeanVariance(0, 1)) == [1.0]
@test call_starting_point(bm_constraint, Gamma(1, 1)) == [1.0]
@test call_starting_point(bm_constraint, arbitrary_dist_1) == [1.0]
@test call_starting_point(bm_constraint, arbitrary_dist_2) == [1.0]
end

@testset "optimizer" begin
constraint = PointMassFormConstraint()

@test isapprox(mean(call_optimizer(constraint, NormalMeanVariance(0, 1))), 0.0, atol = 0.1)
@test isapprox(mean(call_optimizer(constraint, NormalMeanVariance(-10, 10))), -10.0, atol = 0.1)
@test_throws ErrorException call_optimizer(constraint, GammaShapeRate(1, 1))
@test isapprox(mean(constrain_form(constraint, arbitrary_dist_1)), 0.0, atol = 0.1)
@test isapprox(mean(constrain_form(constraint, arbitrary_dist_3)), -10.0, atol = 0.1)
@test_throws ErrorException constrain_form(constraint, arbitrary_dist_2)

gopt_constraint = PointMassFormConstraint(starting_point = (args...) -> [1.0])

@test isapprox(mean(call_optimizer(gopt_constraint, GammaShapeRate(100, 10))), 10, atol = 0.1)
@test isapprox(mean(call_optimizer(gopt_constraint, GammaShapeRate(100, 100))), 1, atol = 0.1)
@test isapprox(mean(constrain_form(gopt_constraint, arbitrary_dist_4)), 10, atol = 0.1)
@test isapprox(mean(constrain_form(gopt_constraint, arbitrary_dist_5)), 1, atol = 0.1)

bm_constraint = PointMassFormConstraint(optimizer = (args...) -> PointMass(10.0))

@test call_optimizer(bm_constraint, NormalMeanVariance(0, 1)) == PointMass(10.0)
@test call_optimizer(bm_constraint, Gamma(1, 1)) == PointMass(10.0)
@test constrain_form(bm_constraint, arbitrary_dist_1) == PointMass(10.0)
@test constrain_form(bm_constraint, arbitrary_dist_2) == PointMass(10.0)
end

@testset "fast path for Distribution" begin
constraint = PointMassFormConstraint()

for mode in randn(4)
@test mean(constrain_form(constraint, MyDistributionWithMode(mode))) === mode
end
end

@testset "optimizer for generic f" begin
Expand All @@ -64,7 +89,7 @@ import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_bou

f = ContinuousUnivariateLogPdf((x) -> logpdf(d1, x) + logpdf(d2, x))

opt = call_optimizer(constraint, f)
opt = constrain_form(constraint, f)

analytical = prod(ProdAnalytical(), d1, d2)

Expand All @@ -79,7 +104,7 @@ import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_bou

f = ContinuousUnivariateLogPdf(DomainSets.HalfLine(), (x) -> logpdf(d1, x) + logpdf(d2, x))

opt = call_optimizer(constraint, f)
opt = constrain_form(constraint, f)

analytical = prod(ProdAnalytical(), d1, d2)

Expand Down

0 comments on commit 47618a7

Please sign in to comment.