From e2bdfb72bf3b8759d92abfd7a541bd0838a01ba1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 21 Oct 2021 21:42:32 +0100 Subject: [PATCH 01/22] added state_from_transition, parameters and setparameters!! --- src/AbstractMCMC.jl | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index ef23cb51..3ec5616c 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -79,6 +79,36 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr """ struct MCMCSerial <: AbstractMCMCEnsemble end +""" + state_from_transiton(state, transition_prev[, state_prev]) + +Return new instance of `state` using information from `transition_prev` and, optionally, `state_prev`. + +Defaults to `setparameters!!(state, parameters(transition_prev))`. +""" +function state_from_transition(state, transition_prev, state_prev) + return state_from_transition(state, transition_prev) +end + +function state_from_transition(state, transition) + return setparameters!!(state, parameters(transition)) +end + +""" + setparameters!!(state, parameters) + +Return new instance of `state` with parameters set to `parameters`. +""" +setparameters!! + +""" + parameters(transition) + +Return parameters in `transition`. +""" +parameters + + include("samplingstats.jl") include("logging.jl") include("interface.jl") From 7fa8de0b2d223cb04b142db59682b897034752c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Oct 2021 12:05:51 +0100 Subject: [PATCH 02/22] Update src/AbstractMCMC.jl Co-authored-by: David Widmann --- src/AbstractMCMC.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 3ec5616c..dd42d345 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -97,7 +97,10 @@ end """ setparameters!!(state, parameters) -Return new instance of `state` with parameters set to `parameters`. +Update the parameters of the `state` with `parameters` and return it. + +If `state` can be updated in-place, it is expected that this function returns `state` with updated +parameters. Otherwise a new `state` object with the new `parameters` is returned. """ setparameters!! From 0a4fd17e56a2ed6d0294f13007e20d1bef6d34e1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 11:18:01 +0000 Subject: [PATCH 03/22] renamed state_from_transition to updatestate!! --- src/AbstractMCMC.jl | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index dd42d345..26c30f59 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -80,19 +80,14 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr struct MCMCSerial <: AbstractMCMCEnsemble end """ - state_from_transiton(state, transition_prev[, state_prev]) + updatestate!!(state, transition_prev[, state_prev]) Return new instance of `state` using information from `transition_prev` and, optionally, `state_prev`. Defaults to `setparameters!!(state, parameters(transition_prev))`. """ -function state_from_transition(state, transition_prev, state_prev) - return state_from_transition(state, transition_prev) -end - -function state_from_transition(state, transition) - return setparameters!!(state, parameters(transition)) -end +updatestate!!(state, transition_prev, state_prev) = updatestate!!(state, transition_prev) +updatestate!!(state, transition) = setparameters!!(state, parameters(transition)) """ setparameters!!(state, parameters) From 28bdf911417c285d8bf9018007518082d8858184 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 11:19:10 +0000 Subject: [PATCH 04/22] adhere to julia convention --- src/AbstractMCMC.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 26c30f59..b420841f 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -97,14 +97,14 @@ Update the parameters of the `state` with `parameters` and return it. If `state` can be updated in-place, it is expected that this function returns `state` with updated parameters. Otherwise a new `state` object with the new `parameters` is returned. """ -setparameters!! +function setparameters!! end """ parameters(transition) Return parameters in `transition`. """ -parameters +function parameters end include("samplingstats.jl") From 86a7826a71dc47717bacce82fb04b54fe9cba217 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 11:26:55 +0000 Subject: [PATCH 05/22] added docs --- docs/src/api.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 8dcf55f4..3ac5e338 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -76,3 +76,16 @@ For chains of this type, AbstractMCMC defines the following two methods. AbstractMCMC.chainscat AbstractMCMC.chainsstack ``` + +## Interacting with states of samplers + +To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: +```@docs +AbstractMCMC.parameters(state, parameters) +AbstractMCMC.setparameters!!(state, parameters) +``` +and optionally +```@docs +AbstractMCMC.updatestate!!(state, transition, state_prev) +``` +These methods can also be useful for implementing samplers which wraps some inner samplers, e.g. a mixture of samplers. From e19cea71eb0310b6c7326043c0d76829dcdde50b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 11:34:12 +0000 Subject: [PATCH 06/22] fixed docs --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 3ac5e338..1780a03f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -81,7 +81,7 @@ AbstractMCMC.chainsstack To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: ```@docs -AbstractMCMC.parameters(state, parameters) +AbstractMCMC.parameters(state) AbstractMCMC.setparameters!!(state, parameters) ``` and optionally From d86499fdaa8e748cce2b9f4c48f7987dd12a8464 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 11:51:58 +0000 Subject: [PATCH 07/22] fixed docs --- docs/src/api.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 1780a03f..9131e059 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -81,8 +81,8 @@ AbstractMCMC.chainsstack To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: ```@docs -AbstractMCMC.parameters(state) -AbstractMCMC.setparameters!!(state, parameters) +AbstractMCMC.parameters +AbstractMCMC.setparameters!! ``` and optionally ```@docs From bce436d7344f6d95290780c7d9de1040ff3f110f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 12:33:37 +0000 Subject: [PATCH 08/22] added example for why updatestate!! is useful --- docs/src/api.md | 125 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 9131e059..67f911d3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -89,3 +89,128 @@ and optionally AbstractMCMC.updatestate!!(state, transition, state_prev) ``` These methods can also be useful for implementing samplers which wraps some inner samplers, e.g. a mixture of samplers. + +### Example: `MixtureSampler` + +In a `MixtureSampler` we need two things: +- `components`: collection of samplers. +- `weights`: collection of weights representing the probability of chosing the corresponding sampler. + +```julia +struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler + components::C + weights::W +end +``` + +To implement the state, we need to keep track of a couple of things: +- `index`: the index of the sampler used in this `step`. +- `transition`: the transition resulting from this `step`. +- `states`: the current states of _all_ the components. +Two aspects of this might seem a bit strange: +1. We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously. +2. We need to put the `transition` from the `step` into the state. +The reason for (1) is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc. +For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. maybe the sampler is working in a transformed space but returns the samples in the original space, or maybe the sampler is even independent from the current realizations and the state is simply `nothing`. Hence, we need the `transition`, which should always contain the realizations, to make sure we can resume from the same point in the space in the next `step`. +```julia +struct MixtureState{T,S} + index::Int + transition::T + states::S +end +``` +The `step` for a `MixtureSampler` is defined by the following generative process +```math +\begin{aligned} +i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\ +X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1}) +\end{aligned} +``` +where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and `w_i` denotes the weight/probability of choosing the i-th sampler. +[`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code: + +```julia +# Update the corresponding state, i.e. `state.states[i]`, using +# the state and transition from the previous iteration. +state_current = AbstractMCMC.updatestate!!( + state.states[i], state.states[i_prev], state.transition +) + +# Take a `step` for this sampler using the updated state. +transition, state_current = AbstractMCMC.step( + rng, model, sampler_current, sampler_state; + kwargs... +) +``` + +The full [`AbstractMCMC.step`](@ref) implementation would then be something like: + +```julia +function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler, state; kwargs...) + # Sample the component to use in this `step`. + i = rand(Categorical(sampler.weights)) + sampler_current = sampler.components[i] + + # Update the corresponding state, i.e. `state.states[i]`, using + # the state and transition from the previous iteration. + i_prev = state.index + state_current = AbstractMCMC.updatestate!!( + state.states[i], state.states[i_prev], state.transition + ) + + # Take a `step` for this sampler using the updated state. + transition, state_current = AbstractMCMC.step( + rng, model, sampler_current, sampler_state; + kwargs... + ) + + # Create the new states. + # NOTE: A better approach would be to use `Setfield.@set state.states[i] = ...` + # but to keep this demo self-contained, we don't. + states_new = ntuple(1:length(state.states)) do j + if j != i + state.states[i] + else + state_inner + end + end + + # Create the new `MixtureState`. + state_new = MixtureState(i, transition, states_new) + + return transition, state_new +end +``` + +And for the initial [`AbstractMCMC.step`](@ref) we have: + +```julia +function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler; kwargs...) + # Initialize every state. + transitions_and_states = map(sampler.components) do spl + AbstractMCMC.step(rng, model, spl; kwargs...) + end + + # Sample the component to use this `step`. + i = rand(Categorical(sampler.weights)) + # Extract the corresponding transition. + transition = first(transition_and_states[i]) + # Extract states. + states = map(last, transitions_and_states) + # Create new `MixtureState`. + state = MixtureState(i, transition, states) + + return transition, state +end +``` + +To use `MixtureSampler`, one could then do something like + +```julia +sampler = MixtureSampler((0.1, 0.9), (sampler1, sampler2)) +transition, state = AbstractMCMC.step(rng, model, sampler) +while ... + transition, state = AbstractMCMC.step(rng, model, sampler, state) +end +``` + From 21f4d569b8791ead958e01955ce44700715bede0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 12:41:04 +0000 Subject: [PATCH 09/22] improved MixtureState example --- docs/src/api.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 67f911d3..af4f8b10 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -126,7 +126,7 @@ i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\ X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1}) \end{aligned} ``` -where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and `w_i` denotes the weight/probability of choosing the i-th sampler. +where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler. [`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code: ```julia @@ -160,7 +160,7 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt # Take a `step` for this sampler using the updated state. transition, state_current = AbstractMCMC.step( - rng, model, sampler_current, sampler_state; + rng, model, sampler_current, state_current; kwargs... ) @@ -204,7 +204,7 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt end ``` -To use `MixtureSampler`, one could then do something like +To use `MixtureSampler` with two samplers `sampler1` and `sampler2` as components, we'd simply do ```julia sampler = MixtureSampler((0.1, 0.9), (sampler1, sampler2)) @@ -214,3 +214,4 @@ while ... end ``` +As a final note, there is one potential issue we haven't really addressed in the above implementation: a lot of samplers have their own implementations of `AbstractMCMC.AbstractModel` which means that we would also have to ensure that all the different samplers we are using would be compatible with the same model. A very easy way to fix this would be to just add a struct called `ManyModels` supporting `getindex`, e.g. `models[i]` would return the i-th `model`, and then the above `step` would just extract the `model` corresponding to the current sampler. This issue should eventually disappear as the community moves towards a unified approach to implement `AbstractMCMC.AbstractModel`. From de0e5b218e9f918d4df8f31d29fecc42581bd9fb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 12:47:58 +0000 Subject: [PATCH 10/22] further improvements to docs --- docs/src/api.md | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index af4f8b10..a13d748a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -110,8 +110,12 @@ To implement the state, we need to keep track of a couple of things: Two aspects of this might seem a bit strange: 1. We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously. 2. We need to put the `transition` from the `step` into the state. + The reason for (1) is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc. -For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. maybe the sampler is working in a transformed space but returns the samples in the original space, or maybe the sampler is even independent from the current realizations and the state is simply `nothing`. Hence, we need the `transition`, which should always contain the realizations, to make sure we can resume from the same point in the space in the next `step`. + +For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. you might have a sampler which is _independent_ of the current realizations and the state is simply `nothing`. + +Hence, we need the `transition`, which should always contain the realizations, to make sure we can resume from the same point in the space in the next `step`. ```julia struct MixtureState{T,S} index::Int @@ -127,7 +131,9 @@ X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1}) \end{aligned} ``` where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler. -[`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code: +[`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. + +If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code: ```julia # Update the corresponding state, i.e. `state.states[i]`, using @@ -214,4 +220,24 @@ while ... end ``` -As a final note, there is one potential issue we haven't really addressed in the above implementation: a lot of samplers have their own implementations of `AbstractMCMC.AbstractModel` which means that we would also have to ensure that all the different samplers we are using would be compatible with the same model. A very easy way to fix this would be to just add a struct called `ManyModels` supporting `getindex`, e.g. `models[i]` would return the i-th `model`, and then the above `step` would just extract the `model` corresponding to the current sampler. This issue should eventually disappear as the community moves towards a unified approach to implement `AbstractMCMC.AbstractModel`. +As a final note, there is one potential issue we haven't really addressed in the above implementation: a lot of samplers have their own implementations of `AbstractMCMC.AbstractModel` which means that we would also have to ensure that all the different samplers we are using would be compatible with the same model. A very easy way to fix this would be to just add a struct called `ManyModels` supporting `getindex`, e.g. `models[i]` would return the i-th `model`: + +```julia +struct ManyModels{M} <: AbstractMCMC.AbstractModel + models::M +end + +Base.getindex(model::ManyModels, I...) = model.models[I...] +``` + +Then the above `step` would just extract the `model` corresponding to the current sampler: + +```julia +# Take a `step` for this sampler using the updated state. +transition, state_current = AbstractMCMC.step( + rng, model[i], sampler_current, state_current; + kwargs... +) +``` + +This issue should eventually disappear as the community moves towards a unified approach to implement `AbstractMCMC.AbstractModel`. From 23b9119e30f613be21cf017fad0a905a47423d17 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Nov 2021 01:12:37 +0000 Subject: [PATCH 11/22] renamed parameters and setparameters!! to values and setvalues!! --- docs/src/api.md | 4 ++-- src/AbstractMCMC.jl | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index a13d748a..6a95ee75 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -81,8 +81,8 @@ AbstractMCMC.chainsstack To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: ```@docs -AbstractMCMC.parameters -AbstractMCMC.setparameters!! +AbstractMCMC.values +AbstractMCMC.setvalues!! ``` and optionally ```@docs diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index b420841f..caf4faac 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -84,27 +84,27 @@ struct MCMCSerial <: AbstractMCMCEnsemble end Return new instance of `state` using information from `transition_prev` and, optionally, `state_prev`. -Defaults to `setparameters!!(state, parameters(transition_prev))`. +Defaults to `setvalues!!(state, values(transition_prev))`. """ updatestate!!(state, transition_prev, state_prev) = updatestate!!(state, transition_prev) -updatestate!!(state, transition) = setparameters!!(state, parameters(transition)) +updatestate!!(state, transition) = setvalues!!(state, values(transition)) """ - setparameters!!(state, parameters) + setvalues!!(state, values) -Update the parameters of the `state` with `parameters` and return it. +Update the values of the `state` with `values` and return it. If `state` can be updated in-place, it is expected that this function returns `state` with updated -parameters. Otherwise a new `state` object with the new `parameters` is returned. +values. Otherwise a new `state` object with the new `values` is returned. """ -function setparameters!! end +function setvalues!! end """ - parameters(transition) + values(transition) -Return parameters in `transition`. +Return values in `transition`. """ -function parameters end +function values end include("samplingstats.jl") From b9f476cf8273b415ab1a1fc8815cc273c8331cec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Nov 2021 01:19:05 +0000 Subject: [PATCH 12/22] fixed typo in docs --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 6a95ee75..a0c64384 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -175,7 +175,7 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt # but to keep this demo self-contained, we don't. states_new = ntuple(1:length(state.states)) do j if j != i - state.states[i] + state.states[j] else state_inner end From f7b6096644665e1b06f42236a04895038933b9f8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Nov 2021 02:00:10 +0000 Subject: [PATCH 13/22] fixed documenting values --- src/AbstractMCMC.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index caf4faac..40b8c7d6 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -87,7 +87,7 @@ Return new instance of `state` using information from `transition_prev` and, opt Defaults to `setvalues!!(state, values(transition_prev))`. """ updatestate!!(state, transition_prev, state_prev) = updatestate!!(state, transition_prev) -updatestate!!(state, transition) = setvalues!!(state, values(transition)) +updatestate!!(state, transition) = setvalues!!(state, Base.values(transition)) """ setvalues!!(state, values) @@ -99,12 +99,12 @@ values. Otherwise a new `state` object with the new `values` is returned. """ function setvalues!! end -""" +@doc """ values(transition) Return values in `transition`. """ -function values end +Base.values include("samplingstats.jl") From 4ca57b00656a81da3d9ac902521e42193328aba1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Nov 2021 02:00:17 +0000 Subject: [PATCH 14/22] improved and fixed some bugs in docs --- docs/src/api.md | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index a0c64384..69919d5c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -171,13 +171,17 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt ) # Create the new states. - # NOTE: A better approach would be to use `Setfield.@set state.states[i] = ...` - # but to keep this demo self-contained, we don't. - states_new = ntuple(1:length(state.states)) do j - if j != i - state.states[j] + # NOTE: Code below will result in `states_new` begin a `Vector`. + # If we wanted to allow usage of alternative containers, e.g. `Tuple` + # it would be better to use something like `@set states[i] = state_current` + # where `@set` is from Setfield.jl. + states_new = map(1:length(state.states)) do j + if j == i + # Replace the i-th state with the new one. + state_current else - state_inner + # Otherwise we just carry over the previous ones. + state.states[j] end end @@ -200,7 +204,7 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt # Sample the component to use this `step`. i = rand(Categorical(sampler.weights)) # Extract the corresponding transition. - transition = first(transition_and_states[i]) + transition = first(transitions_and_states[i]) # Extract states. states = map(last, transitions_and_states) # Create new `MixtureState`. @@ -210,10 +214,19 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt end ``` -To use `MixtureSampler` with two samplers `sampler1` and `sampler2` as components, we'd simply do +Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `values` and `setvalues!!`: + +```julia +function AbstractMCMC.updatestate!!(::AdvancedMH.Transition, state_prev::AdvancedMH.Transition) + # Let's `deepcopy` just to be certain. + return deepcopy(state_prev) +end +``` + +To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do ```julia -sampler = MixtureSampler((0.1, 0.9), (sampler1, sampler2)) +sampler = MixtureSampler([sampler1, sampler2], [0.1, 0.9]) transition, state = AbstractMCMC.step(rng, model, sampler) while ... transition, state = AbstractMCMC.step(rng, model, sampler, state) From abebd599209f2bfc41f382c84ea1f5089c8ff7e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Nov 2021 02:17:04 +0000 Subject: [PATCH 15/22] fixed typo in docs --- docs/src/api.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 69919d5c..1dabc67a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -171,8 +171,8 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt ) # Create the new states. - # NOTE: Code below will result in `states_new` begin a `Vector`. - # If we wanted to allow usage of alternative containers, e.g. `Tuple` + # NOTE: Code below will result in `states_new` being a `Vector`. + # If we wanted to allow usage of alternative containers, e.g. `Tuple`, # it would be better to use something like `@set states[i] = state_current` # where `@set` is from Setfield.jl. states_new = map(1:length(state.states)) do j From d1d4642d7d625577caae74d0e672c99d32c8e04f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 7 Dec 2021 01:55:04 +0000 Subject: [PATCH 16/22] renamed values and setvalues!! to realize and realize!! --- docs/src/api.md | 6 +++--- src/AbstractMCMC.jl | 20 ++++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 1dabc67a..e1f78e42 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -81,8 +81,8 @@ AbstractMCMC.chainsstack To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: ```@docs -AbstractMCMC.values -AbstractMCMC.setvalues!! +AbstractMCMC.realize +AbstractMCMC.realize!! ``` and optionally ```@docs @@ -214,7 +214,7 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt end ``` -Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `values` and `setvalues!!`: +Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `realize` and `realize!!`: ```julia function AbstractMCMC.updatestate!!(::AdvancedMH.Transition, state_prev::AdvancedMH.Transition) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 40b8c7d6..e88734ed 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -84,27 +84,27 @@ struct MCMCSerial <: AbstractMCMCEnsemble end Return new instance of `state` using information from `transition_prev` and, optionally, `state_prev`. -Defaults to `setvalues!!(state, values(transition_prev))`. +Defaults to `realize!!(state, realize(transition_prev))`. """ updatestate!!(state, transition_prev, state_prev) = updatestate!!(state, transition_prev) -updatestate!!(state, transition) = setvalues!!(state, Base.values(transition)) +updatestate!!(state, transition) = realize!!(state, realize(transition)) """ - setvalues!!(state, values) + realize!!(state, realization) -Update the values of the `state` with `values` and return it. +Update the realization of the `state` with `realization` and return it. If `state` can be updated in-place, it is expected that this function returns `state` with updated -values. Otherwise a new `state` object with the new `values` is returned. +realize. Otherwise a new `state` object with the new `realization` is returned. """ -function setvalues!! end +function realize!! end -@doc """ - values(transition) +""" + realize(transition) -Return values in `transition`. +Return the realization of the random variables present in `transition`. """ -Base.values +function realize end include("samplingstats.jl") From c6c9554229c1b93f15e56f99281c593a646b15f5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 7 Dec 2021 02:12:16 +0000 Subject: [PATCH 17/22] added model to updatestate!! --- docs/src/api.md | 4 ++-- src/AbstractMCMC.jl | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index e1f78e42..6be52d6d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -161,7 +161,7 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt # the state and transition from the previous iteration. i_prev = state.index state_current = AbstractMCMC.updatestate!!( - state.states[i], state.states[i_prev], state.transition + model, state.states[i], state.states[i_prev], state.transition ) # Take a `step` for this sampler using the updated state. @@ -217,7 +217,7 @@ end Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `realize` and `realize!!`: ```julia -function AbstractMCMC.updatestate!!(::AdvancedMH.Transition, state_prev::AdvancedMH.Transition) +function AbstractMCMC.updatestate!!(model, ::AdvancedMH.Transition, state_prev::AdvancedMH.Transition) # Let's `deepcopy` just to be certain. return deepcopy(state_prev) end diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index e88734ed..b1dc6b7d 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -80,14 +80,16 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr struct MCMCSerial <: AbstractMCMCEnsemble end """ - updatestate!!(state, transition_prev[, state_prev]) + updatestate!!(model, state, transition_prev[, state_prev]) -Return new instance of `state` using information from `transition_prev` and, optionally, `state_prev`. +Return new instance of `state` using information from `model`, `transition_prev` and, optionally, `state_prev`. Defaults to `realize!!(state, realize(transition_prev))`. """ -updatestate!!(state, transition_prev, state_prev) = updatestate!!(state, transition_prev) -updatestate!!(state, transition) = realize!!(state, realize(transition)) +function updatestate!!(model, state, transition_prev, state_prev) + return updatestate!!(state, transition_prev) +end +updatestate!!(model, state, transition) = realize!!(state, realize(transition)) """ realize!!(state, realization) From 600d36cb556ccfc906833ff6363fa69caa7e960d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Oct 2024 09:28:34 -0400 Subject: [PATCH 18/22] Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/api.md | 57 +++++++++++---------------------------------- src/AbstractMCMC.jl | 32 +++++++++---------------- 2 files changed, 25 insertions(+), 64 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 648a87b8..f3ce4271 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -112,8 +112,8 @@ AbstractMCMC.chainsstack To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: ```@docs -AbstractMCMC.realize -AbstractMCMC.realize!! +AbstractMCMC.getparams +AbstractMCMC.setparams!! ``` and optionally ```@docs @@ -125,7 +125,7 @@ These methods can also be useful for implementing samplers which wraps some inne In a `MixtureSampler` we need two things: - `components`: collection of samplers. -- `weights`: collection of weights representing the probability of chosing the corresponding sampler. +- `weights`: collection of weights representing the probability of choosing the corresponding sampler. ```julia struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler @@ -136,7 +136,6 @@ end To implement the state, we need to keep track of a couple of things: - `index`: the index of the sampler used in this `step`. -- `transition`: the transition resulting from this `step`. - `states`: the current states of _all_ the components. Two aspects of this might seem a bit strange: 1. We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously. @@ -146,11 +145,9 @@ The reason for (1) is that lots of samplers keep track of more than just the pre For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. you might have a sampler which is _independent_ of the current realizations and the state is simply `nothing`. -Hence, we need the `transition`, which should always contain the realizations, to make sure we can resume from the same point in the space in the next `step`. ```julia -struct MixtureState{T,S} +struct MixtureState{S} index::Int - transition::T states::S end ``` @@ -162,15 +159,16 @@ X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1}) \end{aligned} ``` where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler. -[`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. +[`AbstractMCMC.getparams`](@ref) and [`AbstractMCMC.setparams!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code: ```julia # Update the corresponding state, i.e. `state.states[i]`, using # the state and transition from the previous iteration. -state_current = AbstractMCMC.updatestate!!( - state.states[i], state.states[i_prev], state.transition +state_current = AbstractMCMC.setparams!!( + state.states[i], + AbstractMCMC.getparams(state.states[i_prev]), ) # Take a `step` for this sampler using the updated state. @@ -191,8 +189,9 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt # Update the corresponding state, i.e. `state.states[i]`, using # the state and transition from the previous iteration. i_prev = state.index - state_current = AbstractMCMC.updatestate!!( - model, state.states[i], state.states[i_prev], state.transition + state_current = AbstractMCMC.setparams!!( + state.states[i], + AbstractMCMC.getparams(state.states[i_prev]), ) # Take a `step` for this sampler using the updated state. @@ -217,7 +216,7 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt end # Create the new `MixtureState`. - state_new = MixtureState(i, transition, states_new) + state_new = MixtureState(i, states_new) return transition, state_new end @@ -239,20 +238,14 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt # Extract states. states = map(last, transitions_and_states) # Create new `MixtureState`. - state = MixtureState(i, transition, states) + state = MixtureState(i, states) return transition, state end ``` -Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `realize` and `realize!!`: +Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `getparams` and `setparams!!`. -```julia -function AbstractMCMC.updatestate!!(model, ::AdvancedMH.Transition, state_prev::AdvancedMH.Transition) - # Let's `deepcopy` just to be certain. - return deepcopy(state_prev) -end -``` To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do @@ -263,25 +256,3 @@ while ... transition, state = AbstractMCMC.step(rng, model, sampler, state) end ``` - -As a final note, there is one potential issue we haven't really addressed in the above implementation: a lot of samplers have their own implementations of `AbstractMCMC.AbstractModel` which means that we would also have to ensure that all the different samplers we are using would be compatible with the same model. A very easy way to fix this would be to just add a struct called `ManyModels` supporting `getindex`, e.g. `models[i]` would return the i-th `model`: - -```julia -struct ManyModels{M} <: AbstractMCMC.AbstractModel - models::M -end - -Base.getindex(model::ManyModels, I...) = model.models[I...] -``` - -Then the above `step` would just extract the `model` corresponding to the current sampler: - -```julia -# Take a `step` for this sampler using the updated state. -transition, state_current = AbstractMCMC.step( - rng, model[i], sampler_current, state_current; - kwargs... -) -``` - -This issue should eventually disappear as the community moves towards a unified approach to implement `AbstractMCMC.AbstractModel`. diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 07960440..687d8f83 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -80,35 +80,25 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr struct MCMCSerial <: AbstractMCMCEnsemble end """ - updatestate!!(model, state, transition_prev[, state_prev]) + getparams(state[; kwargs...]) -Return new instance of `state` using information from `model`, `transition_prev` and, optionally, `state_prev`. - -Defaults to `realize!!(state, realize(transition_prev))`. +Retrieve the values of parameters from the sampler's `state` as a `Vector{<:Real}`. """ -function updatestate!!(model, state, transition_prev, state_prev) - return updatestate!!(state, transition_prev) -end -updatestate!!(model, state, transition) = realize!!(state, realize(transition)) +function getparams end """ - realize!!(state, realization) - -Update the realization of the `state` with `realization` and return it. + setparams!!(state, params) -If `state` can be updated in-place, it is expected that this function returns `state` with updated -realize. Otherwise a new `state` object with the new `realization` is returned. -""" -function realize!! end +Set the values of parameters in the sampler's `state` from a `Vector{<:Real}`. -""" - realize(transition) +This function should follow the `BangBang` interface: mutate `state` in-place if possible and +return the mutated `state`. Otherwise, it should return a new `state` containing the updated parameters. -Return the realization of the random variables present in `transition`. +Although not enforced, it should hold that `setparams!!(state, getparams(state)) == state`. In another +word, the sampler should implement a consistent transformation between its internal representation +and the vector representation of the parameter values. """ -function realize end - - +function setparams!! end include("samplingstats.jl") include("logging.jl") include("interface.jl") From 1bfbef1bbd1fec56fd07dad796a7cbd534bbcd3a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Oct 2024 09:28:47 -0400 Subject: [PATCH 19/22] Update docs/src/api.md Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- docs/src/api.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index f3ce4271..d8b8a4f4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -115,10 +115,6 @@ To make it a bit easier to interact with some arbitrary sampler state, we encour AbstractMCMC.getparams AbstractMCMC.setparams!! ``` -and optionally -```@docs -AbstractMCMC.updatestate!!(state, transition, state_prev) -``` These methods can also be useful for implementing samplers which wraps some inner samplers, e.g. a mixture of samplers. ### Example: `MixtureSampler` From d9480d1616daa965c51bfb4eb96b0030b37779fa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Oct 2024 09:29:18 -0400 Subject: [PATCH 20/22] Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- docs/src/api.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index d8b8a4f4..f4d52934 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -133,11 +133,8 @@ end To implement the state, we need to keep track of a couple of things: - `index`: the index of the sampler used in this `step`. - `states`: the current states of _all_ the components. -Two aspects of this might seem a bit strange: -1. We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously. -2. We need to put the `transition` from the `step` into the state. - -The reason for (1) is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc. +We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously. +The reason is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc. For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. you might have a sampler which is _independent_ of the current realizations and the state is simply `nothing`. From ddb588c3f3607474d9046c75e1a0735858c097f9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Oct 2024 15:56:38 -0400 Subject: [PATCH 21/22] Update docs/src/api.md Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- docs/src/api.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 0dffa348..db4ce565 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -142,7 +142,6 @@ To implement the state, we need to keep track of a couple of things: We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously. The reason is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc. -For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. you might have a sampler which is _independent_ of the current realizations and the state is simply `nothing`. ```julia struct MixtureState{S} From d6ab10afb993c6c02ad8affe70ccf200c635c025 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 11 Oct 2024 20:11:54 +0100 Subject: [PATCH 22/22] version bump --- Project.toml | 2 +- src/AbstractMCMC.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f57b1ff0..c4ef9ae1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.4.0" +version = "5.5.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index e490237c..8343bfa8 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -100,6 +100,7 @@ word, the sampler should implement a consistent transformation between its inter and the vector representation of the parameter values. """ function setparams!! end + include("samplingstats.jl") include("logging.jl") include("interface.jl")