Skip to content

Commit

Permalink
Reopening of previous PR (#309)
Browse files Browse the repository at this point in the history
* fixed some signatures for Model

* fixed a method call

* fixed method signatures

* sort of fixed the matchingvalue functionality for model

* formatting

* removed redundant _tilde method

* removed left-over acclogp! that should not be here anymore

* export SamplingContext

* use context instead of ctx to refer to contexts

* formatting

* use context instead of ctx for variables

* use context instead of ctx to refer to contexts

* Update src/compiler.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/context_implementations.jl

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* added some whitespace to some docstrings

* deprecated tilde and dot_tilde plus exported new versions

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* minor version bump

* added impl of matchingvalue for contexts

* reverted the change that makes assume always resample

* removed the inds arguments from assume and dot_assume to stay non-breaking

* Update src/context_implementations.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added missing sampler arg to tilde_observe

* added missing sampler argument in dot_tilde_observe

* fixed order of arguments in some dot_assume calls

* formatting

* formatting

* added missing sampler argument in tilde_observe for SamplingContext

* added missing word in a docstring

* updated submodel macro

* removed unwrap_childcontext and related since its not needed for this PR

* updated submodel macro

* fixed evaluation implementations of dot_assume

* updated pointwise_loglikelihoods and related

* added proper tests for pointwise_loglikelihoods

* updated DPPL tests to reflect recent changes

* bump minor version since this will be breaking

* formatting

* formatting

* renamed mean_of_mean_models used in tests

* bumped dppl version in integration tests

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* fixed ambiguity error

* Introduction of `SamplingContext`: keeping it simple (#259)

This is #253 but the only motivation here is to get `SamplingContext` in, nothing relating to interactions with other contexts, etc.

Co-authored-by: Hong Ge <[email protected]>

* Update src/DynamicPPL.jl

Co-authored-by: David Widmann <[email protected]>

* added initial impl of SimpleVarInfo

* remove unnecessary debug statements to be compat with Zygote

* make reconstruct slightly more generic

* added a couple of convenience constructors

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* small fix

* return var_info from tilde-statements, allowing impl of immutable versions

* allow usage of non-Ref types in SimpleVarInfo

* update submodel-macro

* formatting and docstring for submodel-macro

* attempt at supporting implicit returns too

* added a small comment

* simplifed submodel macro a bit

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed typo

* use bang-bang convention

* updated PointwiseLikelihoodContext

* fixed issue where we unnecessarily replace the return-statement

* check subtype in the retval

* formatting

* fixed type-instability in retval check

* introduced evaluate method for model

* remove unnecessary type-requirement

* make return-value check much nicer

* removed redundant creation of anonymous function

* dont use UnionAll in return_values

* updated tests for submodel to reflect new syntax

* moved to using BangBang-convention for most methods

* remove SimpleVarInfo from this branch

* added a comment

* reverted submodel macro to use = rather than ~

* updated SimpleVarInfo impl

* added a couple of missing deprecations

* updated tests

* updated implementations of logjoint and others

* formatting

* added eltype impl for SimpleVarInfo

* formatting

* fixed eltype for SimpleVarInfo

* implement setindex!! in prep for allowing sampling with immutable vi

* formatting

* initial work on allowing sampling using SimpleVarInfo

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add constructor for SimpleVarInfo using model

* improved leftover to_namedtuple_expr, fixing a bug when used with Zygote

* bumped patch version

* fixed set_flag!!

* forgot the return in the replace_returns

* bigboy update to benchmarks

* fixed some issues and added support for usage of Dict in SimpleVarInfo

* added docstring and improved indexing behvaior for SimpleVarInfo

* formatting

* dont allow sampling with indexing when using SimpleVarInfo with NamedTuple unless shapes are specified

* _setval_kernel and others are only supported by VarInfo atm

* fixed typo in comment

* added more values_as impls

* removed redundant values_from_metadata

* fixed bug in push!! for SimpleVarInfo

* forgot which branch Im on

* added handling of short defs in replace_returns and more docstrings

* fixed bug in generate_tilde introduced in a merge

* fixed a bug in isfuncdef

* fixed tests

* formatting

* uncomment mistakenly commented code

* bumped version

* updated doctests

* dont carry over bang-bang versions that we dont need for general varinfos

* Apply suggestions from @phipsgabler

Co-authored-by: Philipp Gabler <[email protected]>

* updated tests

* removed unnecessary BangBang methods

* fixed zygote rule for dot_observe

* fixed Setfield.jl + returning VarInfo bug in model-macro

* updated tests

* fixed docs

* formatting

* fixed issues when using ThreadSafeVarInfo

* fixed _pointwise_observe for ThreadSafeVarInfo

* updated ThreadSafeVarInfo

* made SimpleVarInfo compat with ThreadSafeVarInfo and added show

* added some tests for return-values of models

* formatting

* fixed doctest for SimpleVarInfo

* formatting

* removed comparison of show from doctest for SimpleVarInfo

* Update src/compiler.jl

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* removed OrderedCollections from docs

* some additional fixes

* fixed method ambiguity and some ill-defined map

* renamed evaluate to evaluate!!

* added implementations of haskey, getindex and setindex!! for SimpleVarInfo

* formatting

* dropped redundant definition

* use getproperty instead of getindex

* fixed method-ambiguity and added some comments

* fixed docstring of SimpleVarInfo

* fixed docstrings

* fixed Project.toml for docs

* fixed docstring of canview

* fixed docstrings

* another attempt at fixing docstrings

* added a TODO comment

* remove some output from docstring of SimpleVarInfo

* fixed haskey and hasvalue for AbstractDict

* updated some comments

* updated some errors

* added sampling dot_assume for SimpleVarInfo

* added true versions of density computations to TestUtils

* added tests specific for SimpleVarInfo

* also document TestUtils

* added TestUtils to docs

* fixed setindex!! for SimpleVarInfo using AbstractDict

* added more tests

* formatting

* dont use BangBang for setall!

* revert unnecessary changes to settrans!

* revert unnecessary changes to set_flag!

* revert some changes to docstrings

* fixed some comments and docstrings

* added more convenient logjoint, logprior, and loglikelihood methods

* removed unnecessary export

* fixed export

* use the Setfield impl of getindex, etc. as default and specialize on AbstractDict

* fixed docstrings of logjoint, etc.

* Apply suggestions from code review

Co-authored-by: Philipp Gabler <[email protected]>

* fixed docstring for model

* replaced return_values by capturing return-value from tilde-statements instead

* added some tests for return-value of model

* added broadcast_foreach

* Apply suggestions from @devmotion

Co-authored-by: David Widmann <[email protected]>

* remove broadcast_foreach for now

* some fixes to ThreadSafeVarInfo

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* fixed docstrings

* forgot qualification for set

* formatting

* added comment about why we cant use MacroTools.isdef

* remove unnecessary deprecation

* udpated some docstrings

* fixed more docstrings

* make overloads of BangBang methods qualified

* remove overloading of values and instead use values_as without the type specified

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* renamed hasvalue for SimpleVarInfo to _haskey

* revert changes from previous commit

* minor version bump

* fixed sampling with ThreadSafeVarInfo

* fixed setindex!! for ThreadSafeVarInfo

* fixed eltype for ThreadSafeVarInfo wrapping a SimpleVarInfo

* fixed a test

* relax atol in serialization tests a bit

* temporarily disable Julia 1.3

* relax atol for a prior check

* Improvements to `@submodel` in #309 (#348)

* added prefix keyword argument to submodel-macro

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* converted example in docs into test

* fixed docstring

* Apply suggestions from code review

Co-authored-by: Philipp Gabler <[email protected]>

* removed redundant prefix_submodel_context def and added another example to docstring

* fixed doctests

* attempt at fixing doctests

* another attempt at fixing doctests

* had a typo in docstring

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Philipp Gabler <[email protected]>

* fixed a test case using submodel

* improved docstring according to comments by @devmotion

Co-authored-by: David Widmann <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Philipp Gabler <[email protected]>
  • Loading branch information
5 people authored Dec 12, 2021
1 parent 12f3b36 commit 1744ba7
Show file tree
Hide file tree
Showing 28 changed files with 1,563 additions and 264 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
matrix:
version:
- '1.3' # minimum supported version
# - '1.3' # minimum supported version
- '1' # current stable version
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.16.2"
version = "0.17.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
Distributions = "0.25"
Documenter = "0.27"
Setfield = "0.7.1, 0.8"
StableRNGs = "1"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ makedocs(;
sitename="DynamicPPL",
format=Documenter.HTML(),
modules=[DynamicPPL],
pages=["Home" => "index.md"],
pages=["Home" => "index.md", "TestUtils" => "test_utils.md"],
strict=true,
checkdocs=:exports,
doctestfilters=[
Expand Down
5 changes: 5 additions & 0 deletions docs/src/test_utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# DynamicPPL.TestUtils

```@autodocs
Modules = [DynamicPPL.TestUtils]
```
30 changes: 29 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using ChainRulesCore: ChainRulesCore
using MacroTools: MacroTools
using ZygoteRules: ZygoteRules
using BangBang: BangBang
using Setfield: Setfield

using Setfield: Setfield
using BangBang: BangBang
Expand All @@ -31,15 +32,23 @@ import Base:
keys,
haskey

using BangBang: push!!, empty!!, setindex!!

# VarInfo
export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
TypedVarInfo,
SimpleVarInfo,
push!!,
empty!!,
getlogp,
setlogp!,
acclogp!,
resetlogp!,
setlogp!!,
acclogp!!,
resetlogp!!,
get_num_produce,
set_num_produce!,
reset_num_produce!,
Expand Down Expand Up @@ -139,13 +148,32 @@ include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("threadsafe.jl")
include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")

include("test_utils.jl")

# Deprecations
@deprecate empty!(vi::VarInfo) empty!!(vi::VarInfo)
@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!(
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution
)
@deprecate push!(
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler
) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler)
@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!(
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector
)
@deprecate push!(
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}
) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector})

@deprecate setlogp!(vi, logp) setlogp!!(vi, logp)
@deprecate acclogp!(vi, logp) acclogp!!(vi, logp)
@deprecate resetlogp!(vi) resetlogp!!(vi)

end # module
4 changes: 2 additions & 2 deletions src/compat/ad.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# See https://github.com/TuringLang/Turing.jl/issues/1199
ChainRulesCore.@non_differentiable push!(
ChainRulesCore.@non_differentiable push!!(
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
)

Expand All @@ -16,7 +16,7 @@ ZygoteRules.@adjoint function dot_observe(
)
function dot_observe_fallback(spl, dists, value, vi)
increment_num_produce!(vi)
return sum(map(Distributions.loglikelihood, dists, value))
return sum(map(Distributions.loglikelihood, dists, value)), vi
end
return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi)
end
120 changes: 100 additions & 20 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,12 @@ end

function generate_tilde_literal(left, right)
# If the LHS is a literal, it is always an observation
@gensym value
return quote
$(DynamicPPL.tilde_observe!)(
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
$value
end
end

Expand All @@ -373,7 +375,7 @@ function generate_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn isassumption
@gensym vn isassumption value

# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
# that in DynamicPPL we the entire function body. Instead we should be
Expand All @@ -389,32 +391,38 @@ function generate_tilde(left, right)
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
end

$(DynamicPPL.tilde_observe!)(
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
__varinfo__,
)
$value
end
end
end

function generate_tilde_assume(left, right, vn)
expr = :(
$left = $(DynamicPPL.tilde_assume!)(
# HACK: Because the Setfield.jl macro does not support assignment
# with multiple arguments on the LHS, we need to capture the return-values
# and then update the LHS variables one by one.
@gensym value
expr = :($left = $value)
if left isa Expr
expr = AbstractPPL.drop_escape(
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
)
end

return quote
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
__context__,
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
__varinfo__,
)
)

return if left isa Expr
AbstractPPL.drop_escape(
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
)
else
return expr
$expr
$value
end
end

Expand All @@ -428,7 +436,7 @@ function generate_dot_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn isassumption
@gensym vn isassumption value
return quote
$vn = $(AbstractPPL.drop_escape(varname(left)))
$isassumption = $(DynamicPPL.isassumption(left))
Expand All @@ -440,13 +448,14 @@ function generate_dot_tilde(left, right)
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
end

$(DynamicPPL.dot_tilde_observe!)(
$value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
__varinfo__,
)
$value
end
end
end
Expand All @@ -455,15 +464,82 @@ function generate_dot_tilde_assume(left, right, vn)
# We don't need to use `Setfield.@set` here since
# `.=` is always going to be inplace + needs `left` to
# be something that supports `.=`.
return :(
$left .= $(DynamicPPL.dot_tilde_assume!)(
@gensym value
return quote
$value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
__varinfo__,
)
)
$left .= $value
$value
end
end

# Note that we cannot use `MacroTools.isdef` because
# of https://github.com/FluxML/MacroTools.jl/issues/154.
"""
isfuncdef(expr)
Return `true` if `expr` is any form of function definition, and `false` otherwise.
"""
function isfuncdef(e::Expr)
return if Meta.isexpr(e, :function)
# Classic `function f(...)`
true
elseif Meta.isexpr(e, :->)
# Anonymous functions/lambdas, e.g. `do` blocks or `->` defs.
true
elseif Meta.isexpr(e, :(=)) && Meta.isexpr(e.args[1], :call)
# Short function defs, e.g. `f(args...) = ...`.
true
else
false
end
end

"""
replace_returns(expr)
Return `Expr` with all `return ...` statements replaced with
`return ..., DynamicPPL.return_values(__varinfo__)`.
Note that this method will _not_ replace `return` statements within function
definitions. This is checked using [`isfuncdef`](@ref).
"""
replace_returns(e) = e
function replace_returns(e::Expr)
if isfuncdef(e)
return e
end

if Meta.isexpr(e, :return)
# NOTE: `return` always has an argument. In the case of
# an empty `return`, the lowered expression will be `return nothing`.
# Hence we don't need any special handling for empty returns.
retval_expr = if length(e.args) > 1
Expr(:tuple, e.args...)
else
e.args[1]
end

return :(return ($retval_expr, __varinfo__))
end

return Expr(e.head, map(replace_returns, e.args)...)
end

# If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`.
make_returns_explicit!(body) = Expr(:return, body)
function make_returns_explicit!(body::Expr)
# If the last statement is a return-statement, we don't do anything.
# Otherwise we replace the last statement with a `return` statement.
if !Meta.isexpr(body.args[end], :return)
body.args[end] = Expr(:return, body.args[end])
end
return body
end

const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}}
Expand Down Expand Up @@ -496,10 +572,14 @@ function build_output(modelinfo, linenumbernode)
# Replace the user-provided function body with the version created by DynamicPPL.
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
# to the call site.
# NOTE: We need to replace statements of the form `return ...` with
# `return (..., __varinfo__)` to ensure that the second
# element in the returned value is always the most up-to-date `__varinfo__`.
# See the docstrings of `replace_returns` for more info.
evaluatordef[:body] = MacroTools.@q begin
$(linenumbernode)
$(modelinfo[:body])
$(replace_returns(make_returns_explicit!(modelinfo[:body])))
end

## Build the model function.
Expand Down
Loading

0 comments on commit 1744ba7

Please sign in to comment.