Skip to content

Commit

Permalink
Merge pull request #364 from ReactiveBayes/hind-on-normal-usage
Browse files Browse the repository at this point in the history
Restore error message for unparametrized Normal/MvNormal/Gamma
  • Loading branch information
wouterwln authored Oct 3, 2024
2 parents f5401c8 + 1c0e49b commit d9536e2
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 24 deletions.
8 changes: 4 additions & 4 deletions docs/src/manuals/comparison.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,19 @@ Nowadays there's plenty of probabilistic programming languages and packages avai
using RxInfer #hide
@model function inner_inner(τ, y, x)
y ~ Normal(τ[1], τ[2] + x)
y ~ Normal(mean = τ[1], var = τ[2] + x)
end
@model function inner(θ, α)
β ~ Normal(0, 1)
α ~ Gamma(β, 1)
β ~ Normal(mean = 0.0, var = 1.0)
α ~ Gamma(shape = β, rate = 1.0)
α ~ inner_inner(τ = θ, x = 3)
end
@model function outer()
local w
for i = 1:5
w[i] ~ inner(θ = Gamma(1, 1))
w[i] ~ inner(θ = Gamma(shape = 1.0, rate = 1.0))
end
y ~ inner(θ = w[2:3])
end
Expand Down
8 changes: 4 additions & 4 deletions docs/src/manuals/constraints-specification.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,19 +177,19 @@ Read more about the `@constraints` macro in the [official documentation](https:/

```@example manual_constraints
@model function inner_inner(τ, y)
y ~ Normal(τ[1], τ[2])
y ~ Normal(mean = τ[1], var = τ[2])
end
@model function inner(θ, α)
β ~ Normal(0, 1)
α ~ Gamma(β, 1)
β ~ Normal(mean = 0.0, var = 1.0)
α ~ Gamma(shape = β, rate = 1.0)
α ~ inner_inner(τ = θ)
end
@model function outer()
local w
for i = 1:5
w[i] ~ inner(θ = Gamma(1, 1))
w[i] ~ inner(θ = Gamma(shape = 1.0, rate = 1.0))
end
y ~ inner(θ = w[2:3])
end
Expand Down
2 changes: 1 addition & 1 deletion docs/src/manuals/model-specification.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ The `@model` macro returns a regular Julia function (in this example `model_name
```@example model-specification-model-macro
using RxInfer #hide
@model function my_model(observation, hyperparameter)
observations ~ Normal(0.0, hyperparameter)
observations ~ Normal(mean = 0.0, var = hyperparameter)
end
```

Expand Down
6 changes: 3 additions & 3 deletions src/model/graphppl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ end

GraphPPL.factor_alias(::ReactiveMPGraphPPLBackend, ::Type{Normal}, ::GraphPPL.StaticInterfaces{(:μ, :v)}) = ExponentialFamily.NormalMeanVariance
GraphPPL.factor_alias(::ReactiveMPGraphPPLBackend, ::Type{Normal}, ::GraphPPL.StaticInterfaces{(:μ, :τ)}) = ExponentialFamily.NormalMeanPrecision
GraphPPL.default_parametrization(::ReactiveMPGraphPPLBackend, ::Type{Normal}) =
GraphPPL.default_parametrization(::ReactiveMPGraphPPLBackend, ::GraphPPL.Atomic, ::Type{Normal}, rhs) =
error("`Normal` cannot be constructed without keyword arguments. Use `Normal(mean = ..., var = ...)` or `Normal(mean = ..., precision = ...)`.")

# GraphPPL.interfaces(::ReactiveMPGraphPPLBackend, ::Type{<:ExponentialFamily.NormalMeanVariance}, _) = GraphPPL.StaticInterfaces((:out, :μ, :v))
Expand All @@ -251,15 +251,15 @@ GraphPPL.interface_aliases(::ReactiveMPGraphPPLBackend, ::Type{Normal}) = GraphP

GraphPPL.factor_alias(::ReactiveMPGraphPPLBackend, ::Type{MvNormal}, ::GraphPPL.StaticInterfaces{(:μ, :Σ)}) = ExponentialFamily.MvNormalMeanCovariance
GraphPPL.factor_alias(::ReactiveMPGraphPPLBackend, ::Type{MvNormal}, ::GraphPPL.StaticInterfaces{(:μ, :Λ)}) = ExponentialFamily.MvNormalMeanPrecision
GraphPPL.default_parametrization(::ReactiveMPGraphPPLBackend, ::Type{MvNormal}) =
GraphPPL.default_parametrization(::ReactiveMPGraphPPLBackend, ::GraphPPL.Atomic, ::Type{MvNormal}, rhs) =
error("`MvNormal` cannot be constructed without keyword arguments. Use `MvNormal(mean = ..., covariance = ...)` or `MvNormal(mean = ..., precision = ...)`.")

GraphPPL.interface_aliases(::ReactiveMPGraphPPLBackend, ::Type{MvNormal}) =
GraphPPL.StaticInterfaceAliases(((:mean, ), (:m, ), (:covariance, ), (:cov, ), (:Λ⁻¹, ), (:V, ), (:precision, ), (:prec, ), (:W, ), (:Σ⁻¹, )))

GraphPPL.factor_alias(::ReactiveMPGraphPPLBackend, ::Type{Gamma}, ::GraphPPL.StaticInterfaces{(:α, :θ)}) = ExponentialFamily.GammaShapeScale
GraphPPL.factor_alias(::ReactiveMPGraphPPLBackend, ::Type{Gamma}, ::GraphPPL.StaticInterfaces{(:α, :β)}) = ExponentialFamily.GammaShapeRate
GraphPPL.default_parametrization(::ReactiveMPGraphPPLBackend, ::Type{Gamma}) =
GraphPPL.default_parametrization(::ReactiveMPGraphPPLBackend, ::GraphPPL.Atomic, ::Type{Gamma}, rhs) =
error("`Gamma` cannot be constructed without keyword arguments. Use `Gamma(shape = ..., rate = ...)` or `Gamma(shape = ..., scale = ...)`.")

GraphPPL.interface_aliases(::ReactiveMPGraphPPLBackend, ::Type{Gamma}) =
Expand Down
4 changes: 2 additions & 2 deletions test/ext/ProjectionExt/inference_with_projection_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ end

@model function mymodel(y, C)
a ~ Beta(2, 1)
b ~ Gamma(2, 1)
b ~ Gamma(shape = 2.0, rate = 1.0)
μ := foo(a, b)
for i in eachindex(y)
y[i] ~ Normal(mean = μ, variance = C)
Expand Down Expand Up @@ -448,7 +448,7 @@ end

@model function mymodel(y, C)
a ~ Beta(1, 1)
b ~ Gamma(1, 1)
b ~ Gamma(shape = 1.0, rate = 1.0)
μ := foo(a, b)
for i in eachindex(y)
y[i] ~ MvNormal(mean = μ, covariance = C)
Expand Down
6 changes: 3 additions & 3 deletions test/inference/inference_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ end
@testitem "__infer_create_factor_graph_model" begin
@model function simple_model_for_infer_create_model(y, a, b)
x ~ Beta(a, b)
y ~ Normal(x, 1.0)
y ~ Normal(mean = x, var = 1.0)
end

import RxInfer: __infer_create_factor_graph_model, ProbabilisticModel, getmodel
Expand All @@ -52,7 +52,7 @@ end
# A simple model for testing that resembles a simple kalman filter with
# random walk state transition and unknown observational noise
@model function test_model1(y)
τ ~ Gamma(1.0, 1.0)
τ ~ Gamma(shape = 1.0, rate = 1.0)

x[1] ~ Normal(mean = 0.0, variance = 1.0)
y[1] ~ Normal(mean = x[1], precision = τ)
Expand Down Expand Up @@ -364,7 +364,7 @@ end
@testitem "Invalid data size error" begin
@model function test_model1(y)
n = length(y)
τ ~ Gamma(1.0, 1.0)
τ ~ Gamma(shape = 1.0, rate = 1.0)

x[1] ~ Normal(mean = 0.0, variance = 1.0)
y[1] ~ Normal(mean = x[1], precision = τ)
Expand Down
2 changes: 1 addition & 1 deletion test/model/graphppl_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
input = :(a = constvar())
@test @capture(apply_pipeline(input, error_datavar_constvar_randomvar), error(_))

input = :(x ~ Normal(0, 1))
input = :(x ~ Normal(mean = 0.0, var = 1.0))
@test apply_pipeline(input, error_datavar_constvar_randomvar) == input
end

Expand Down
12 changes: 6 additions & 6 deletions test/model/initialization_plugin_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ end
import RxInfer: SpecificSubModelInit, InitSpecification, InitDescriptor, InitMarginal, InitObject, GeneralSubModelInit

@model function dummymodel()
x ~ Normal(0, 1)
y ~ Normal(x, 1)
x ~ Normal(mean = 0.0, var = 1.0)
y ~ Normal(mean = x, var = 1.0)
end

@test SpecificSubModelInit(GraphPPL.FactorID(dummymodel, 1), InitSpecification()) isa SpecificSubModelInit
Expand All @@ -42,8 +42,8 @@ end
import RxInfer: SpecificSubModelInit, InitSpecification, InitDescriptor, InitMarginal, InitObject, GeneralSubModelInit

@model function dummymodel()
x ~ Normal(0, 1)
y ~ Normal(x, 1)
x ~ Normal(mean = 0.0, var = 1.0)
y ~ Normal(mean = x, var = 1.0)
end

@test GeneralSubModelInit(dummymodel, InitSpecification()) isa GeneralSubModelInit
Expand Down Expand Up @@ -641,7 +641,7 @@ end
local x
for i in 1:3
for j in 1:3
x[i, j] ~ Normal(0, 1)
x[i, j] ~ Normal(mean = 0.0, var = 1.0)
end
end
end
Expand Down Expand Up @@ -683,7 +683,7 @@ end
@test default_init(some_model) === RxInfer.EmptyInit

@model function model_with_init()
x ~ Normal(0.0, 1.0)
x ~ Normal(mean = 0.0, var = 1.0)
end

default_init(::typeof(model_with_init)) = @initialization begin
Expand Down
10 changes: 10 additions & 0 deletions test/models/aliases/aliases_gamma_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,13 @@
@test first(results.free_energy[end]) 4.385584096993327
@test all(<=(1e-14), diff(results.free_energy)) # it oscilates a bit at the end, but all should be less or equal to zero
end

@testitem "`Gamma` by itself cannot be used as a node" begin
@model function gamma_by_itself(d)
x ~ Gamma(1.0, 1.0)
d ~ Gamma(x, 1.0)
end
@test_throws "`Gamma` cannot be constructed without keyword arguments. Use `Gamma(shape = ..., rate = ...)` or `Gamma(shape = ..., scale = ...)`." infer(
model = gamma_by_itself(), data = (d = 1.0,), iterations = 1, free_energy = false
)
end
20 changes: 20 additions & 0 deletions test/models/aliases/aliases_normal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,23 @@
@test last(result.free_energy) 2.319611135721246
@test all(iszero, diff(result.free_energy))
end

@testitem "`Normal` by itself cannot be used as a node" begin
@model function normal_by_itself(d)
x ~ Normal(0.0, 1.0)
d ~ Normal(x, 1.0)
end
@test_throws "`Normal` cannot be constructed without keyword arguments. Use `Normal(mean = ..., var = ...)` or `Normal(mean = ..., precision = ...)`." infer(
model = normal_by_itself(), data = (d = 1.0,), iterations = 1, free_energy = false
)
end

@testitem "`MvNormal` by itself cannot be used as a node" begin
@model function mvnormal_by_itself(d)
x ~ MvNormal(zeros(2), diageye(2))
d ~ MvNormal(x, diageye(2))
end
@test_throws "`MvNormal` cannot be constructed without keyword arguments. Use `MvNormal(mean = ..., covariance = ...)` or `MvNormal(mean = ..., precision = ...)`." infer(
model = mvnormal_by_itself(), data = (d = 1.0,), iterations = 1, free_energy = false
)
end

0 comments on commit d9536e2

Please sign in to comment.