From f5150d935f322944d703fb9616b5c8863924d7cb Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 30 Sep 2024 12:16:52 +0200 Subject: [PATCH] add documentation for the change --- docs/Project.toml | 1 + docs/src/manuals/model-specification.md | 113 ++++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index 70ead18e6..db4713649 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,6 @@ [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Cairo = "159f3aea-2a34-519c-b102-8c37f9878175" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/src/manuals/model-specification.md b/docs/src/manuals/model-specification.md index ab3d4b6e7..d7e24d53e 100644 --- a/docs/src/manuals/model-specification.md +++ b/docs/src/manuals/model-specification.md @@ -257,6 +257,119 @@ model = RxInfer.create_model(conditioned) GraphPlot.gplot(RxInfer.getmodel(model)) ``` +## Node Contraction + +RxInfer's model specification extension for GraphPPL supports a feature called _node contraction_. This feature allows you to _contract_ (or _replace_) a submodel with a corresponding factor node. Node contraction can be useful in several scenarios: + +- When running inference in a submodel is computationally expensive +- When a submodel contains many variables whose inference results are not of primary importance +- When specialized message passing update rules can be derived for variables in the Markov blanket of the submodel + +Let's illustrate this concept with a simple example. We'll first create a basic submodel and then allow the inference backend to replace it with a corresponding node that has well-defined message update rules. + +```@example node-contraction +using RxInfer, Plots + +@model function ShiftedNormal(data, mean, precision, shift) + shifted_mean := mean + shift + data ~ Normal(mean = shifted_mean, precision = precision) +end + +@model function Model(data, precision, shift) + mean ~ Normal(mean = 15.0, var = 1.0) + data ~ ShiftedNormal(mean = mean, precision = precision, shift = shift) +end + +result = infer( + model = Model(precision = 1.0, shift = 1.0), + data = (data = 10.0, ) +) + +plot(title = "Inference results over `mean`") +plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2) +plot!(0:0.1:20.0, (x) -> pdf(result.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2) +vline!([ 10.0 ], label = "data point") +``` + +As we can see, we can run inference on this model. We can also visualize the model's structure, as shown in the [Model structure visualisation](@ref user-guide-model-specification-visualization) section. + +```@example node-contraction +using Cairo, GraphPlot + +GraphPlot.gplot(getmodel(result.model)) +``` + +Now, let's create an optimized version of the `ShiftedNormal` submodel as a standalone node with its own message passing update rules. + +!!! note + Creating correct message passing update rules is beyond the scope of this section. For more information about custom message passing update rules, refer to the [Custom Node](@ref create-node) section. + +```@example node-contraction +@node typeof(ShiftedNormal) Stochastic [ data, mean, precision, shift ] + +@rule typeof(ShiftedNormal)(:mean, Marginalisation) (q_data::PointMass, q_precision::PointMass, q_shift::PointMass, ) = begin + return @call_rule NormalMeanPrecision(:μ, Marginalisation) (q_out = PointMass(mean(q_data) - mean(q_shift)), q_τ = q_precision) +end + +result_with_contraction = infer( + model = Model(precision = 1.0, shift = 1.0), + data = (data = 10.0, ), + allow_node_contraction = true +) +using Test #hide +@test result.posteriors[:mean] ≈ result_with_contraction.posteriors[:mean] #hide + +plot(title = "Inference results over `mean` with node contraction") +plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2) +plot!(0:0.1:20.0, (x) -> pdf(result_with_contraction.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2) +vline!([ 10.0 ], label = "data point") +``` + +As you can see, the inference result is identical to the previous case. However, the structure of the model is different: + +```@example node-contraction +GraphPlot.gplot(getmodel(result_with_contraction.model)) +``` + +With node contraction, we no longer have access to the variables defined inside the `ShiftedNormal` submodel, as it has been contracted to a single factor node. It's worth noting that this feature heavily relies on existing message passing update rules for the submodel. However, it can also be combined with another useful inference technique [where no explicit message passing update rules are required](@ref inference-undefinedrules). + +We can also verify that node contraction indeed improves the performance of the inference: + +```@example node-contraction +using BenchmarkTools + +benchmark_without_contraction = @benchmark infer( + model = Model(precision = 1.0, shift = 1.0), + data = (data = 10.0, ) +) + +benchmark_with_contraction = @benchmark infer( + model = Model(precision = 1.0, shift = 1.0), + data = (data = 10.0, ), + allow_node_contraction = true +) + +using Test #hide +@test benchmark_with_contraction.allocs < benchmark_without_contraction.allocs #hide +@test mean(benchmark_with_contraction.times) < mean(benchmark_without_contraction.times) #hide +@test median(benchmark_with_contraction.times) < median(benchmark_without_contraction.times) #hide +@test minimum(benchmark_with_contraction.times) < minimum(benchmark_without_contraction.times) #hide +nothing #hide +``` + +Let's examine the benchmark results: + +```@example node-contraction +benchmark_without_contraction +``` + +```@example node-contraction +benchmark_with_contraction +``` + +As we can see, the inference with node contraction runs faster due to the simplified model structure and optimized message update rules. +This performance improvement is reflected in reduced execution time and fewer memory allocations. + ### [Node creation options](@id user-guide-model-specification-node-creation-options) `GraphPPL` allows to pass optional arguments to the node creation constructor with the `where { options... }` options specification syntax.