From 9b7a3028472911272e164632cf56302f91755067 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 25 Aug 2023 20:02:43 +0200 Subject: [PATCH 1/2] fast path for dists in the PointMass constraint --- src/constraints/form/form_point_mass.jl | 3 + test/constraints/form/test_form_point_mass.jl | 62 +++++++++++++------ 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/src/constraints/form/form_point_mass.jl b/src/constraints/form/form_point_mass.jl index c21ce9bb8..225b962f1 100644 --- a/src/constraints/form/form_point_mass.jl +++ b/src/constraints/form/form_point_mass.jl @@ -82,6 +82,9 @@ call_starting_point(pmconstraint::PointMassFormConstraint, distribution::D) wher ReactiveMP.constrain_form(pmconstraint::PointMassFormConstraint, distribution) = call_optimizer(pmconstraint, distribution) +# There is no need to call the optimizer on a `Distribution` object since they should have a well defined `mode` +ReactiveMP.constrain_form(::PointMassFormConstraint, distribution::Distribution) = PointMass(mode(distribution)) + function default_point_mass_form_constraint_optimizer(::Type{Univariate}, ::Type{Continuous}, constraint::PointMassFormConstraint, distribution) target = let distribution = distribution (x) -> -logpdf(distribution, x[1]) diff --git a/test/constraints/form/test_form_point_mass.jl b/test/constraints/form/test_form_point_mass.jl index 6e4e131d9..a38bc70e7 100644 --- a/test/constraints/form/test_form_point_mass.jl +++ b/test/constraints/form/test_form_point_mass.jl @@ -2,11 +2,29 @@ module RxInferPointMassFormConstraintTest using Test using RxInfer, LinearAlgebra -using Random, StableRNGs, DomainSets +using Random, StableRNGs, DomainSets, Distributions +import ReactiveMP: constrain_form import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_boundaries, call_starting_point, call_optimizer +struct MyDistributionWithMode <: ContinuousUnivariateDistribution + mode :: Float64 +end + +# We are testing specifically that the point mass optimizer does not call `logpdf` and +# chooses a fast path with `mode` for `<: Distribution` objects +Distributions.logpdf(::MyDistributionWithMode, _) = error("This should not be called") +Distributions.mode(d::MyDistributionWithMode) = d.mode +Distributions.support(::MyDistributionWithMode) = RealInterval(-Inf, Inf) + +const arbitrary_dist_1 = ContinuousUnivariateLogPdf(RealLine(), (x) -> logpdf(NormalMeanVariance(0, 1), x)) +const arbitrary_dist_2 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(Gamma(1, 1), x)) +const arbitrary_dist_3 = ContinuousUnivariateLogPdf(RealLine(), (x) -> logpdf(NormalMeanVariance(-10, 10), x)) +const arbitrary_dist_4 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(GammaShapeRate(100, 10), x)) +const arbitrary_dist_5 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(GammaShapeRate(100, 100), x)) + @testset "PointMassFormConstraint" begin + @testset "is_point_mass_form_constraint" begin @test is_point_mass_form_constraint(PointMassFormConstraint()) end @@ -14,43 +32,51 @@ import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_bou @testset "boundaries" begin constraint = PointMassFormConstraint() - @test call_boundaries(constraint, NormalMeanVariance(0, 1)) === (-Inf, Inf) - @test call_boundaries(constraint, Gamma(1, 1)) === (0.0, Inf) + @test call_boundaries(constraint, arbitrary_dist_1) === (-Inf, Inf) + @test call_boundaries(constraint, arbitrary_dist_2) === (0.0, Inf) bm_constraint = PointMassFormConstraint(boundaries = (args...) -> (-1.0, 1.0)) - @test call_boundaries(bm_constraint, NormalMeanVariance(0, 1)) === (-1.0, 1.0) - @test call_boundaries(bm_constraint, Gamma(1, 1)) === (-1.0, 1.0) + @test call_boundaries(bm_constraint, arbitrary_dist_1) === (-1.0, 1.0) + @test call_boundaries(bm_constraint, arbitrary_dist_2) === (-1.0, 1.0) end @testset "starting point" begin constraint = PointMassFormConstraint() - @test call_starting_point(constraint, NormalMeanVariance(0, 1)) == [0.0] - @test_throws ErrorException call_starting_point(constraint, Gamma(1, 1)) + @test call_starting_point(constraint, arbitrary_dist_1) == [0.0] + @test_throws ErrorException call_starting_point(constraint, arbitrary_dist_2) bm_constraint = PointMassFormConstraint(starting_point = (args...) -> [1.0]) - @test call_starting_point(bm_constraint, NormalMeanVariance(0, 1)) == [1.0] - @test call_starting_point(bm_constraint, Gamma(1, 1)) == [1.0] + @test call_starting_point(bm_constraint, arbitrary_dist_1) == [1.0] + @test call_starting_point(bm_constraint, arbitrary_dist_2) == [1.0] end @testset "optimizer" begin constraint = PointMassFormConstraint() - @test isapprox(mean(call_optimizer(constraint, NormalMeanVariance(0, 1))), 0.0, atol = 0.1) - @test isapprox(mean(call_optimizer(constraint, NormalMeanVariance(-10, 10))), -10.0, atol = 0.1) - @test_throws ErrorException call_optimizer(constraint, GammaShapeRate(1, 1)) + @test isapprox(mean(constrain_form(constraint, arbitrary_dist_1)), 0.0, atol = 0.1) + @test isapprox(mean(constrain_form(constraint, arbitrary_dist_3)), -10.0, atol = 0.1) + @test_throws ErrorException constrain_form(constraint, arbitrary_dist_2) gopt_constraint = PointMassFormConstraint(starting_point = (args...) -> [1.0]) - @test isapprox(mean(call_optimizer(gopt_constraint, GammaShapeRate(100, 10))), 10, atol = 0.1) - @test isapprox(mean(call_optimizer(gopt_constraint, GammaShapeRate(100, 100))), 1, atol = 0.1) + @test isapprox(mean(constrain_form(gopt_constraint, arbitrary_dist_4)), 10, atol = 0.1) + @test isapprox(mean(constrain_form(gopt_constraint, arbitrary_dist_5)), 1, atol = 0.1) bm_constraint = PointMassFormConstraint(optimizer = (args...) -> PointMass(10.0)) - @test call_optimizer(bm_constraint, NormalMeanVariance(0, 1)) == PointMass(10.0) - @test call_optimizer(bm_constraint, Gamma(1, 1)) == PointMass(10.0) + @test constrain_form(bm_constraint, arbitrary_dist_1) == PointMass(10.0) + @test constrain_form(bm_constraint, arbitrary_dist_2) == PointMass(10.0) + end + + @testset "fast path for Distribution" begin + constraint = PointMassFormConstraint() + + for mode in randn(4) + @test mean(constrain_form(constraint, MyDistributionWithMode(mode))) === mode + end end @testset "optimizer for generic f" begin @@ -64,7 +90,7 @@ import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_bou f = ContinuousUnivariateLogPdf((x) -> logpdf(d1, x) + logpdf(d2, x)) - opt = call_optimizer(constraint, f) + opt = constrain_form(constraint, f) analytical = prod(ProdAnalytical(), d1, d2) @@ -79,7 +105,7 @@ import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_bou f = ContinuousUnivariateLogPdf(DomainSets.HalfLine(), (x) -> logpdf(d1, x) + logpdf(d2, x)) - opt = call_optimizer(constraint, f) + opt = constrain_form(constraint, f) analytical = prod(ProdAnalytical(), d1, d2) From 287aa85e2f3419134233e678b7d22e9cfcded5de Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 25 Aug 2023 20:42:36 +0200 Subject: [PATCH 2/2] style: make format --- test/constraints/form/test_form_point_mass.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/constraints/form/test_form_point_mass.jl b/test/constraints/form/test_form_point_mass.jl index a38bc70e7..4ff075d35 100644 --- a/test/constraints/form/test_form_point_mass.jl +++ b/test/constraints/form/test_form_point_mass.jl @@ -7,8 +7,8 @@ using Random, StableRNGs, DomainSets, Distributions import ReactiveMP: constrain_form import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_boundaries, call_starting_point, call_optimizer -struct MyDistributionWithMode <: ContinuousUnivariateDistribution - mode :: Float64 +struct MyDistributionWithMode <: ContinuousUnivariateDistribution + mode::Float64 end # We are testing specifically that the point mass optimizer does not call `logpdf` and @@ -24,7 +24,6 @@ const arbitrary_dist_4 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(Ga const arbitrary_dist_5 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(GammaShapeRate(100, 100), x)) @testset "PointMassFormConstraint" begin - @testset "is_point_mass_form_constraint" begin @test is_point_mass_form_constraint(PointMassFormConstraint()) end