Skip to content

Commit

Permalink
add documentation for the change
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Sep 30, 2024
1 parent 73d1304 commit f5150d9
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Expand Down
113 changes: 113 additions & 0 deletions docs/src/manuals/model-specification.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,119 @@ model = RxInfer.create_model(conditioned)
GraphPlot.gplot(RxInfer.getmodel(model))
```

## Node Contraction

RxInfer's model specification extension for GraphPPL supports a feature called _node contraction_. This feature allows you to _contract_ (or _replace_) a submodel with a corresponding factor node. Node contraction can be useful in several scenarios:

- When running inference in a submodel is computationally expensive
- When a submodel contains many variables whose inference results are not of primary importance
- When specialized message passing update rules can be derived for variables in the Markov blanket of the submodel

Let's illustrate this concept with a simple example. We'll first create a basic submodel and then allow the inference backend to replace it with a corresponding node that has well-defined message update rules.

```@example node-contraction
using RxInfer, Plots
@model function ShiftedNormal(data, mean, precision, shift)
shifted_mean := mean + shift
data ~ Normal(mean = shifted_mean, precision = precision)
end
@model function Model(data, precision, shift)
mean ~ Normal(mean = 15.0, var = 1.0)
data ~ ShiftedNormal(mean = mean, precision = precision, shift = shift)
end
result = infer(
model = Model(precision = 1.0, shift = 1.0),
data = (data = 10.0, )
)
plot(title = "Inference results over `mean`")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")
```

As we can see, we can run inference on this model. We can also visualize the model's structure, as shown in the [Model structure visualisation](@ref user-guide-model-specification-visualization) section.

```@example node-contraction
using Cairo, GraphPlot
GraphPlot.gplot(getmodel(result.model))
```

Now, let's create an optimized version of the `ShiftedNormal` submodel as a standalone node with its own message passing update rules.

!!! note
Creating correct message passing update rules is beyond the scope of this section. For more information about custom message passing update rules, refer to the [Custom Node](@ref create-node) section.

```@example node-contraction
@node typeof(ShiftedNormal) Stochastic [ data, mean, precision, shift ]
@rule typeof(ShiftedNormal)(:mean, Marginalisation) (q_data::PointMass, q_precision::PointMass, q_shift::PointMass, ) = begin
return @call_rule NormalMeanPrecision(:μ, Marginalisation) (q_out = PointMass(mean(q_data) - mean(q_shift)), q_τ = q_precision)
end
result_with_contraction = infer(
model = Model(precision = 1.0, shift = 1.0),
data = (data = 10.0, ),
allow_node_contraction = true
)
using Test #hide
@test result.posteriors[:mean] ≈ result_with_contraction.posteriors[:mean] #hide
plot(title = "Inference results over `mean` with node contraction")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result_with_contraction.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")
```

As you can see, the inference result is identical to the previous case. However, the structure of the model is different:

```@example node-contraction
GraphPlot.gplot(getmodel(result_with_contraction.model))
```

With node contraction, we no longer have access to the variables defined inside the `ShiftedNormal` submodel, as it has been contracted to a single factor node. It's worth noting that this feature heavily relies on existing message passing update rules for the submodel. However, it can also be combined with another useful inference technique [where no explicit message passing update rules are required](@ref inference-undefinedrules).

We can also verify that node contraction indeed improves the performance of the inference:

```@example node-contraction
using BenchmarkTools
benchmark_without_contraction = @benchmark infer(
model = Model(precision = 1.0, shift = 1.0),
data = (data = 10.0, )
)
benchmark_with_contraction = @benchmark infer(
model = Model(precision = 1.0, shift = 1.0),
data = (data = 10.0, ),
allow_node_contraction = true
)
using Test #hide
@test benchmark_with_contraction.allocs < benchmark_without_contraction.allocs #hide
@test mean(benchmark_with_contraction.times) < mean(benchmark_without_contraction.times) #hide
@test median(benchmark_with_contraction.times) < median(benchmark_without_contraction.times) #hide
@test minimum(benchmark_with_contraction.times) < minimum(benchmark_without_contraction.times) #hide
nothing #hide
```

Let's examine the benchmark results:

```@example node-contraction
benchmark_without_contraction
```

```@example node-contraction
benchmark_with_contraction
```

As we can see, the inference with node contraction runs faster due to the simplified model structure and optimized message update rules.
This performance improvement is reflected in reduced execution time and fewer memory allocations.

### [Node creation options](@id user-guide-model-specification-node-creation-options)

`GraphPPL` allows to pass optional arguments to the node creation constructor with the `where { options... }` options specification syntax.
Expand Down

0 comments on commit f5150d9

Please sign in to comment.