From 9a1bb51641a16dfbbe3dfe288e358e19d4996bb0 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 11 Nov 2024 09:50:42 +0000 Subject: [PATCH] Permit opting in to unsafe perturbations (#362) * Add test case * Update tests to do the right thing * Fix up add_to_primal * Bump patch version * Fix error message * Impove docstring * Enable unsafe perturbation for Bijectors * Fix up Bijectors testing * Remove accidentally-commited debug info * Remove redundant code * Write docstring * Remove redundant method --- Project.toml | 2 +- ext/MooncakeCUDAExt.jl | 2 +- src/rrules/array_legacy.jl | 4 +- src/rrules/iddict.jl | 4 +- src/rrules/memory.jl | 14 +-- src/rrules/tasks.jl | 2 +- src/tangents.jl | 87 +++++++++++++++---- src/test_resources.jl | 5 ++ src/test_utils.jl | 41 +++++---- test/ext/dynamic_ppl/dynamic_ppl.jl | 2 +- .../bijectors/bijectors.jl | 10 +-- test/integration_testing/diff_tests.jl | 2 +- test/integration_testing/gp/gp.jl | 44 +++++----- .../temporalgps/temporalgps.jl | 2 +- test/integration_testing/turing/turing.jl | 2 +- test/tangents.jl | 6 ++ 16 files changed, 152 insertions(+), 77 deletions(-) diff --git a/Project.toml b/Project.toml index 12e2f750b..1fdcbad09 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.39" +version = "0.4.40" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/MooncakeCUDAExt.jl b/ext/MooncakeCUDAExt.jl index 0ab3eb1f2..b201c0574 100644 --- a/ext/MooncakeCUDAExt.jl +++ b/ext/MooncakeCUDAExt.jl @@ -35,7 +35,7 @@ TestUtils.has_equal_data(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x == y increment!!(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x .+= y __increment_should_allocate(::Type{<:CuArray{<:IEEEFloat}}) = true set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0 -_add_to_primal(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x + y +_add_to_primal(x::P, y::P, ::Bool) where {P<:CuArray{<:IEEEFloat}} = x + y _diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y _dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y)) _scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y diff --git a/src/rrules/array_legacy.jl b/src/rrules/array_legacy.jl index 50329f68f..30b700b4a 100644 --- a/src/rrules/array_legacy.jl +++ b/src/rrules/array_legacy.jl @@ -35,9 +35,9 @@ function _dot(t::T, s::T) where {T<:Array} ) end -function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}) where {P, N} +function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}, unsafe::Bool) where {P, N} x′ = Array{P, N}(undef, size(x)...) - return _map_if_assigned!(_add_to_primal, x′, x, t) + return _map_if_assigned!((x, t) -> _add_to_primal(x, t, unsafe), x′, x, t) end function _diff(p::P, q::P) where {V, N, P<:Array{V, N}} diff --git a/src/rrules/iddict.jl b/src/rrules/iddict.jl index 558ad2676..ed8b57a04 100644 --- a/src/rrules/iddict.jl +++ b/src/rrules/iddict.jl @@ -24,9 +24,9 @@ function _scale(a::Float64, t::IdDict{K, V}) where {K, V} return IdDict{K, V}([k => _scale(a, v) for (k, v) in t]) end _dot(p::T, q::T) where {T<:IdDict} = sum([_dot(p[k], q[k]) for k in keys(p)]; init=0.0) -function _add_to_primal(p::IdDict{K, V}, t::IdDict{K}) where {K, V} +function _add_to_primal(p::IdDict{K, V}, t::IdDict{K}, unsafe::Bool) where {K, V} ks = intersect(keys(p), keys(t)) - return IdDict{K, V}([k => _add_to_primal(p[k], t[k]) for k in ks]) + return IdDict{K, V}([k => _add_to_primal(p[k], t[k], unsafe) for k in ks]) end function _diff(p::P, q::P) where {K, V, P<:IdDict{K, V}} @assert union(keys(p), keys(q)) == keys(p) diff --git a/src/rrules/memory.jl b/src/rrules/memory.jl index 5026dcae9..86cce8818 100644 --- a/src/rrules/memory.jl +++ b/src/rrules/memory.jl @@ -72,8 +72,10 @@ end set_to_zero!!(x::Memory) = _map_if_assigned!(set_to_zero!!, x, x) -function _add_to_primal(p::Memory{P}, t::Memory) where {P} - return _map_if_assigned!(_add_to_primal, Memory{P}(undef, length(p)), p, t) +function _add_to_primal(p::Memory{P}, t::Memory, unsafe::Bool) where {P} + return _map_if_assigned!( + (p, t) -> _add_to_primal(p, t, unsafe), Memory{P}(undef, length(p)), p, t + ) end function _diff(p::Memory{P}, q::Memory{P}) where {P} @@ -172,9 +174,9 @@ function _dot(t::T, s::T) where {T<:Array} ) end -function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}) where {P, N} +function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}, unsafe::Bool) where {P, N} x′ = Array{P, N}(undef, size(x)...) - return _map_if_assigned!(_add_to_primal, x′, x, t) + return _map_if_assigned!((x, t) -> _add_to_primal(x, t, unsafe), x′, x, t) end function _diff(p::P, q::P) where {P<:Array} @@ -273,7 +275,9 @@ function set_to_zero!!(x::MemoryRef) return x end -_add_to_primal(p::MemoryRef, t::MemoryRef) = construct_ref(p, _add_to_primal(p.mem, t.mem)) +function _add_to_primal(p::MemoryRef, t::MemoryRef, unsafe::Bool) + return construct_ref(p, _add_to_primal(p.mem, t.mem, unsafe)) +end function _diff(p::P, q::P) where {P<:MemoryRef} @assert Core.memoryrefoffset(p) == Core.memoryrefoffset(q) diff --git a/src/rrules/tasks.jl b/src/rrules/tasks.jl index cc7074c80..3b21aad1a 100644 --- a/src/rrules/tasks.jl +++ b/src/rrules/tasks.jl @@ -17,7 +17,7 @@ increment!!(t::TaskTangent, s::TaskTangent) = t set_to_zero!!(t::TaskTangent) = t -_add_to_primal(p::Task, t::TaskTangent) = p +_add_to_primal(p::Task, t::TaskTangent, ::Bool) = p _diff(::Task, ::Task) = TaskTangent() diff --git a/src/tangents.jl b/src/tangents.jl index a8558087c..b77c03b4b 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -637,37 +637,92 @@ function _dot(t::T, s::T) where {T<:Union{Tangent, MutableTangent}} end """ - _add_to_primal(p::P, t::T) where {P, T} - -Required for testing. -_Not_ currently defined by default. -`_containerlike_add_to_primal` is potentially what you want to target when implementing for -a particular primal-tangent pair. + _add_to_primal(p::P, t::T, unsafe::Bool=false) where {P, T} Adds `t` to `p`, returning a `P`. It must be the case that `tangent_type(P) == T`. + +If `unsafe` is `true` and `P` is a composite type, then `_add_to_primal` will construct a +new instance of `P` by directly invoking the `:new` instruction for `P`, rather than +attempting to use the default constructor for `P`. This is fine if you are confident that +the new `P` constructed by adding `t` to `p` will always be a valid instance of `P`, but +could cause problems if you are not confident of this. + +This is, for example, fine for the following type: +```julia +struct Foo{T} + x::Vector{T} + y::Vector{T} + function Foo(x::Vector{T}, y::Vector{T}) where {T} + @assert length(x) == length(y) + return new{T}(x, y) + end +end +``` +Here, the value returned by `_add_to_primal` will satisfy the invariant asserted in the +inner constructor for `Foo`. """ -_add_to_primal(x, ::NoTangent) = x -_add_to_primal(x::T, t::T) where {T<:IEEEFloat} = x + t -function _add_to_primal(x::SimpleVector, t::Vector{Any}) - return svec(map(n -> _add_to_primal(x[n], t[n]), eachindex(x))...) +_add_to_primal(p, t) = _add_to_primal(p, t, false) +_add_to_primal(x, ::NoTangent, ::Bool) = x +_add_to_primal(x::T, t::T, ::Bool) where {T<:IEEEFloat} = x + t +function _add_to_primal(x::SimpleVector, t::Vector{Any}, unsafe::Bool) + return svec(map(n -> _add_to_primal(x[n], t[n], unsafe), eachindex(x))...) +end +function _add_to_primal(x::Tuple, t::Tuple, unsafe::Bool) + return _map((x, t) -> _add_to_primal(x, t, unsafe), x, t) +end +function _add_to_primal(x::NamedTuple, t::NamedTuple, unsafe::Bool) + return _map((x, t) -> _add_to_primal(x, t, unsafe), x, t) end -_add_to_primal(x::Tuple, t::Tuple) = _map(_add_to_primal, x, t) -_add_to_primal(x::NamedTuple, t::NamedTuple) = _map(_add_to_primal, x, t) -_add_to_primal(x, ::Tangent{NamedTuple{(), Tuple{}}}) = x -function _add_to_primal(p::P, t::T) where {P, T<:Union{Tangent, MutableTangent}} +struct AddToPrimalException <: Exception + primal_type::Type +end + +function Base.showerror(io::IO, err::AddToPrimalException) + msg = "Attempted to construct an instance of $(err.primal_type) using the default " * + "constuctor. In most cases, this error is caused by the lack of existence of the " * + "default constructor for this type. There are two approaches to dealing with " * + "this problem. The first is to avoid having to call `_add_to_primal` on this " * + "type, which can be achieved by avoiding testing functions whose arguments are " * + "of this type. If this cannot be avoided, you should consider using calling " * + "`Mooncake._add_to_primal` with its third positional argument set to `true`. " * + "If you are using some of Mooncake's testing functionality, this can be achieved " * + "by setting the `unsafe_perturb` setting to `true` -- check the docstring " * + "for `Mooncake._add_to_primal` to ensure that your use case is unlikely to " * + "cause problems." + println(io, msg) +end + +function _add_to_primal(p::P, t::T, unsafe::Bool) where {P, T<:Union{Tangent, MutableTangent}} Tt = tangent_type(P) if Tt != typeof(t) throw(ArgumentError("p of type $P has tangent_type $Tt, but t is of type $T")) end tmp = map(fieldnames(P)) do f tf = getfield(t.fields, f) - isdefined(p, f) && is_init(tf) && return _add_to_primal(getfield(p, f), val(tf)) + isdefined(p, f) && is_init(tf) && return _add_to_primal(getfield(p, f), val(tf), unsafe) !isdefined(p, f) && !is_init(tf) && return FieldUndefined() throw(error("unable to handle undefined-ness")) end i = findfirst(==(FieldUndefined()), tmp) - return i === nothing ? P(tmp...) : P(tmp[1:i-1]...) + + # If unsafe mode is enabled, then call `_new_` directly, and avoid the possibility that + # the default inner constructor for `P` does not exist. + if unsafe + return i === nothing ? _new_(P, tmp...) : _new_(P, tmp[1:i-1]...) + end + + # If unsafe mode is disabled, try to use the default constructor for `P`. If this does + # not work, then throw an informative error message. + try + return i === nothing ? P(tmp...) : P(tmp[1:i-1]...) + catch e + if e isa MethodError + throw(AddToPrimalException(P)) + else + rethrow(e) + end + end end """ diff --git a/src/test_resources.jl b/src/test_resources.jl index 5b32a0a3e..b8ccd5f4e 100644 --- a/src/test_resources.jl +++ b/src/test_resources.jl @@ -579,6 +579,11 @@ end tuple_with_union(x::Bool) = (x ? 5.0 : 5, nothing) +struct NoDefaultCtor{T} + x::T + NoDefaultCtor(x::T) where {T} = new{T}(x) +end + function generate_test_functions() return Any[ (false, :allocs, nothing, const_tester), diff --git a/src/test_utils.jl b/src/test_utils.jl index 9da7d370b..ee91cf0f4 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -283,35 +283,34 @@ function address_maps_are_consistent(x::AddressMap, y::AddressMap) end # Assumes that the interface has been tested, and we can simply check for numerical issues. -function test_rrule_numerical_correctness(rng::AbstractRNG, f_f̄, x_x̄...; rule) - @nospecialize rng f_f̄ x_x̄ +function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb::Bool) + @nospecialize rng x_x̄ x_x̄ = map(_deepcopy, x_x̄) # defensive copy # Run original function on deep-copies of inputs. - f = primal(f_f̄) x = map(primal, x_x̄) x̄ = map(tangent, x_x̄) # Run primal, and ensure that we still have access to mutated inputs afterwards. x_primal = _deepcopy(x) - y_primal = f(x_primal...) + y_primal = x_primal[1](x_primal[2:end]...) # Use finite differences to estimate vjps ẋ = map(_x -> randn_tangent(rng, _x), x) ε = 1e-7 - x′ = _add_to_primal(x, _scale(ε, ẋ)) - y′ = f(x′...) + x′ = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) + y′ = x′[1](x′[2:end]...) ẏ = _scale(1 / ε, _diff(y′, y_primal)) ẋ_post = map((_x′, _x_p) -> _scale(1 / ε, _diff(_x′, _x_p)), x′, x_primal) - # Run `rrule!!` on copies of `f` and `x`. We use randomly generated tangents so that we + # Run rule on copies of `f` and `x`. We use randomly generated tangents so that we # can later verify that non-zero values do not get propagated by the rule. x̄_zero = map(zero_tangent, x) x̄_fwds = map(Mooncake.fdata, x̄_zero) x_x̄_rule = map((x, x̄_f) -> fcodual_type(_typeof(x))(_deepcopy(x), x̄_f), x, x̄_fwds) inputs_address_map = populate_address_map(map(primal, x_x̄_rule), map(tangent, x_x̄_rule)) - y_ȳ_rule, pb!! = rule(to_fwds(f_f̄), x_x̄_rule...) + y_ȳ_rule, pb!! = rule(x_x̄_rule...) # Verify that inputs / outputs are the same under `f` and its rrule. @test has_equal_data(x_primal, map(primal, x_x̄_rule)) @@ -332,8 +331,8 @@ function test_rrule_numerical_correctness(rng::AbstractRNG, f_f̄, x_x̄...; rul x̄_init = map(set_to_zero!!, x̄_zero) ȳ = increment!!(ȳ_init, ȳ_delta) map(increment!!, x̄_init, x̄_delta) - _, x̄_rvs_inc... = pb!!(Mooncake.rdata(ȳ)) - x̄_rvs = map((x, x_inc) -> increment!!(rdata(x), x_inc), x̄_delta, x̄_rvs_inc) + x̄_rvs_inc = pb!!(Mooncake.rdata(ȳ)) + x̄_rvs = increment!!(map(rdata, x̄_delta), x̄_rvs_inc) x̄ = map(tangent, x̄_fwds, x̄_rvs) # Check that inputs have been returned to their original value. @@ -481,6 +480,7 @@ __get_primals(xs) = map(x -> x isa CoDual ? primal(x) : x, xs) perf_flag::Symbol=:none, interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), debug_mode::Bool=false, + unsafe_perturb::Bool=false, ) Run standardised tests on the `rule` for `x`. @@ -527,6 +527,9 @@ This function uses [`Mooncake.build_rrule`](@ref) to construct a rule. This will Typically this should be left at its default `false` value, but if you are finding that the tests are failing for a given rule, you may wish to temporarily set it to `true` in order to get access to additional information and automated testing. +- `unsafe_perturb::Bool=false`: value passed as the third argument to `_add_to_primal`. + Should usually be left `false` -- consult the docstring for `_add_to_primal` for more + info on when you might wish to set it to `true`. """ function test_rule( rng::AbstractRNG, x...; @@ -535,16 +538,16 @@ function test_rule( perf_flag::Symbol=:none, interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), debug_mode::Bool=false, + unsafe_perturb::Bool=false, ) @nospecialize rng x # Construct the rule. - rule = Mooncake.build_rrule(interp, _typeof(__get_primals(x)); debug_mode) + sig = _typeof(__get_primals(x)) + rule = Mooncake.build_rrule(interp, sig; debug_mode) # If something is primitive, then the rule should be `rrule!!`. - if is_primitive - @test rule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!) - end + is_primitive && @test rule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!) # Generate random tangents for anything that is not already a CoDual. x_x̄ = map(x -> x isa CoDual ? x : interface_only ? uninit_codual(x) : zero_codual(x), x) @@ -553,14 +556,13 @@ function test_rule( test_rrule_interface(x_x̄...; rule) # Test that answers are numerically correct / consistent. - interface_only || test_rrule_numerical_correctness(rng, x_x̄...; rule) + interface_only || test_rule_correctness(rng, x_x̄...; rule, unsafe_perturb) # Test the performance of the rule. test_rrule_performance(perf_flag, rule, x_x̄...) # Test the interface again, in order to verify that caching is working correctly. - rule_2 = Mooncake.build_rrule(interp, _typeof(__get_primals(x)); debug_mode) - test_rrule_interface(x_x̄..., rule=rule_2) + test_rrule_interface(x_x̄..., rule=Mooncake.build_rrule(interp, sig; debug_mode)) end @@ -790,6 +792,7 @@ function test_tangent_consistency(rng::AbstractRNG, p::P; interface_only=false) # Verify that operations required for finite difference testing to run, and produce the # correct output type. @test _add_to_primal(p, t) isa P + @test _add_to_primal(p, t, true) isa P @test _diff(p, p) isa T @test _dot(t, t) isa Float64 @test _scale(11.0, t) isa T @@ -798,9 +801,9 @@ function test_tangent_consistency(rng::AbstractRNG, p::P; interface_only=false) # Run some basic numerical sanity checks on the output the functions required for finite # difference testing. These are necessary but insufficient conditions. if !interface_only - @test has_equal_data(_add_to_primal(p, z), p) + @test has_equal_data(_add_to_primal(p, z, true), p) if !has_equal_data(z, r) - @test !has_equal_data(_add_to_primal(p, r), p) + @test !has_equal_data(_add_to_primal(p, r, true), p) end @test has_equal_data(_diff(p, p), zero_tangent(p)) end diff --git a/test/ext/dynamic_ppl/dynamic_ppl.jl b/test/ext/dynamic_ppl/dynamic_ppl.jl index e702fba98..c2098be0c 100644 --- a/test/ext/dynamic_ppl/dynamic_ppl.jl +++ b/test/ext/dynamic_ppl/dynamic_ppl.jl @@ -5,5 +5,5 @@ Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using DynamicPPL, Mooncake, Test @testset "DynamicPPLMooncakeExt" begin - test_rule(sr(123456), DynamicPPL.istrans, DynamicPPL.VarInfo(); interface_only=true) + test_rule(sr(123456), DynamicPPL.istrans, DynamicPPL.VarInfo(); unsafe_perturb=true) end diff --git a/test/integration_testing/bijectors/bijectors.jl b/test/integration_testing/bijectors/bijectors.jl index 9a2a7fa4b..c53b4f266 100644 --- a/test/integration_testing/bijectors/bijectors.jl +++ b/test/integration_testing/bijectors/bijectors.jl @@ -2,7 +2,7 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) -using Bijectors: Bijectors +using Bijectors: Bijectors, inverse using LinearAlgebra: LinearAlgebra using Random: randn @@ -25,8 +25,7 @@ function b_binv_test_case(bijector, dim; name = nothing, rng = Xoshiro(23)) if name === nothing name = string(bijector) end - b_inv = Bijectors.inverse(bijector) - return TestCase(x -> bijector(b_inv(x)), randn(rng, dim); name = name) + return TestCase(x -> bijector(inverse(bijector)(x)), randn(rng, dim); name = name) end @testset "Bijectors integration tests" begin @@ -43,7 +42,7 @@ end Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])), 3, ), - b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)), + b_binv_test_case(Bijectors.InvertibleBatchNorm(3; eps=1e-5, mtm=1e-1), (3, 3)), b_binv_test_case(Bijectors.LeakyReLU(0.2), 3), b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3), b_binv_test_case(Bijectors.PDBijector(), (3, 3)), @@ -128,7 +127,8 @@ end true end else - test_rule(Xoshiro(123456), case.func, case.arg; is_primitive=false) + rng = Xoshiro(123456) + test_rule(rng, case.func, case.arg; is_primitive=false, unsafe_perturb=true) end end end diff --git a/test/integration_testing/diff_tests.jl b/test/integration_testing/diff_tests.jl index 857044449..aa0bf95bf 100644 --- a/test/integration_testing/diff_tests.jl +++ b/test/integration_testing/diff_tests.jl @@ -6,6 +6,6 @@ TestResources.DIFFTESTS_FUNCTIONS[91:end], # skipping sparse_ldiv )) @info "$n: $(_typeof((f, x...)))" - test_rule(sr(123456), f, x...; interface_only=false, is_primitive=false) + test_rule(sr(123456), f, x...; is_primitive=false) end end diff --git a/test/integration_testing/gp/gp.jl b/test/integration_testing/gp/gp.jl index d0e96e156..55f74459f 100644 --- a/test/integration_testing/gp/gp.jl +++ b/test/integration_testing/gp/gp.jl @@ -5,7 +5,7 @@ Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using AbstractGPs, KernelFunctions, Mooncake, Test @testset "gp" begin - base_kernels = Any[ + ks = Any[ ZeroKernel(), ConstantKernel(; c=1.0), SEKernel(), @@ -15,35 +15,37 @@ using AbstractGPs, KernelFunctions, Mooncake, Test LinearKernel(), PolynomialKernel(; degree=2, c=0.5), ] - simple_xs = Any[ - randn(10), - randn(1), - range(0.0; step=0.1, length=11), - ColVecs(randn(2, 11)), - RowVecs(randn(9, 4)), + xs = Any[ + (randn(10), randn(10)), + (randn(1), randn(1)), + (ColVecs(randn(2, 11)), ColVecs(randn(2, 11))), + (RowVecs(randn(3, 4)), RowVecs(randn(3, 4))), ] - d_2_xs = Any[ColVecs(randn(2, 11)), RowVecs(randn(9, 2))] - @testset "$k, $(typeof(x1))" for (k, x1) in vcat( - Any[(k, x) for k in base_kernels for x in simple_xs], - Any[(with_lengthscale(k, 1.1), x) for k in base_kernels for x in simple_xs], - Any[(with_lengthscale(k, rand(2)), x) for k in base_kernels for x in d_2_xs], - Any[(k ∘ LinearTransform(randn(2, 2)), x) for k in base_kernels for x in d_2_xs], + d_2_xs = Any[ + (ColVecs(randn(2, 11)), ColVecs(randn(2, 11))), + (RowVecs(randn(9, 2)), RowVecs(randn(9, 2))), + ] + @testset "$k, $(typeof(x1))" for (k, x1, x2) in vcat( + Any[(k, x1, x2) for k in ks for (x1, x2) in xs], + Any[(with_lengthscale(k, 1.1), x1, x2) for k in ks for (x1, x2) in xs], + Any[(with_lengthscale(k, rand(2)), x1, x2) for k in ks for (x1, x2) in d_2_xs], + Any[(k ∘ LinearTransform(randn(2, 2)), x1, x2) for k in ks for (x1, x2) in d_2_xs], Any[ - (k ∘ LinearTransform(Diagonal(randn(2))), x) for - k in base_kernels for x in d_2_xs + (k ∘ LinearTransform(Diagonal(randn(2))), x1, x2) for + k in ks for (x1, x2) in d_2_xs ], ) fx = GP(k)(x1, 1.1) - @testset "$(_typeof(x))" for x in Any[ - (kernelmatrix, k, x1, x1), - (kernelmatrix_diag, k, x1, x1), + @testset "$(_typeof(args))" for args in Any[ + (kernelmatrix, k, x1, x2), + (kernelmatrix_diag, k, x1, x2), (kernelmatrix, k, x1), (kernelmatrix_diag, k, x1), - (rand, Xoshiro(123456), fx), + (fx -> rand(Xoshiro(123456), fx), fx), (logpdf, fx, rand(fx)), ] - @info typeof(x) - test_rule(sr(123456), x...; interface_only=true, is_primitive=false) + @info typeof(args) + test_rule(sr(123456), args...; is_primitive=false, unsafe_perturb=true) end end end diff --git a/test/integration_testing/temporalgps/temporalgps.jl b/test/integration_testing/temporalgps/temporalgps.jl index 8120b3e50..280b8fa86 100644 --- a/test/integration_testing/temporalgps/temporalgps.jl +++ b/test/integration_testing/temporalgps/temporalgps.jl @@ -22,6 +22,6 @@ temporalgps_logpdf_tester(k, x, y, s) = logpdf(build_gp(k)(x, s), y) f = temporalgps_logpdf_tester sig = _typeof((temporalgps_logpdf_tester, k, x, y, s)) @info "$sig" - test_rule(Xoshiro(123456), f, k, x, y, s; is_primitive=false, debug_mode=false) + test_rule(Xoshiro(123456), f, k, x, y, s; is_primitive=false) end end diff --git a/test/integration_testing/turing/turing.jl b/test/integration_testing/turing/turing.jl index d1f3eea14..b380e4946 100644 --- a/test/integration_testing/turing/turing.jl +++ b/test/integration_testing/turing/turing.jl @@ -118,6 +118,6 @@ end ) @info name f, x = build_turing_problem(sr(123), model, ex) - test_rule(sr(123456), f, x; interface_only=true, is_primitive=false) + test_rule(sr(123456), f, x; interface_only, is_primitive=false, unsafe_perturb=true) end end diff --git a/test/tangents.jl b/test/tangents.jl index 69a2e6ffb..c097cf771 100644 --- a/test/tangents.jl +++ b/test/tangents.jl @@ -164,6 +164,12 @@ end end end + @testset "restricted inner constructor" begin + p = TestResources.NoDefaultCtor(5.0) + t = Mooncake.Tangent((x=5.0, )) + @test_throws Mooncake.AddToPrimalException Mooncake._add_to_primal(p, t) + @test Mooncake._add_to_primal(p, t, true) isa typeof(p) + end end # TODO: add the following test to `tangent_test_cases`