From ef6d4591b36194fca069d8bc7ae8c1e2ee288080 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 21 Apr 2023 10:59:22 +0200 Subject: [PATCH] Fix method ambiguities (#483) * Fix method ambiguities * Update test/kernels/kernelsum.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Try to fix `map` issues * Define `map` for `ColVecs`/`RowVecs` * Fix ambiguity issues * Better fix * Add back some definitions --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 2 +- src/chainrules.jl | 5 ++++- src/kernels/overloads.jl | 8 ++++++++ src/transform/chaintransform.jl | 7 +++++-- src/transform/transform.jl | 20 ++++++++++++++++++-- src/utils.jl | 8 ++++---- test/kernels/kernelproduct.jl | 6 ++++++ test/kernels/kernelsum.jl | 5 +++++ test/kernels/kerneltensorproduct.jl | 6 ++++++ test/runtests.jl | 5 ++++- test/transform/chaintransform.jl | 1 + 11 files changed, 62 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 32486a3d4..c763a8824 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.54" +version = "0.10.55" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/chainrules.jl b/src/chainrules.jl index dde352f8c..3b52860dd 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -2,7 +2,10 @@ # Note that this is type piracy as the derivative should be NaN for x == y. function ChainRulesCore.frule( - (_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector + (_, Δx, Δy)::Tuple{<:Any,<:Any,<:Any}, + d::Distances.Euclidean, + x::AbstractVector, + y::AbstractVector, ) Δ = x - y D = sqrt(sum(abs2, Δ)) diff --git a/src/kernels/overloads.jl b/src/kernels/overloads.jl index b3ba76c7f..3285c3dd4 100644 --- a/src/kernels/overloads.jl +++ b/src/kernels/overloads.jl @@ -18,5 +18,13 @@ for (M, op, T) in ( $M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k) $M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k)) + + # Fix method ambiguity issues + function $M.$op(ks1::$T, ks2::$T{<:AbstractVector{<:Kernel}}) + return $T(vcat(collect(ks1.kernels), ks2.kernels)) + end + function $M.$op(ks1::$T{<:AbstractVector{<:Kernel}}, ks2::$T) + return $T(vcat(ks1.kernels, collect(ks2.kernels))) + end end end diff --git a/src/transform/chaintransform.jl b/src/transform/chaintransform.jl index 9129fb84b..208b1f689 100644 --- a/src/transform/chaintransform.jl +++ b/src/transform/chaintransform.jl @@ -34,8 +34,11 @@ function ChainTransform(v, θ::AbstractVector) end Base.:∘(t₁::Transform, t₂::Transform) = ChainTransform((t₂, t₁)) -Base.:∘(t::Transform, tc::ChainTransform) = ChainTransform(tuple(tc.transforms..., t)) -Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(tuple(t, tc.transforms...)) +Base.:∘(t::Transform, tc::ChainTransform) = ChainTransform((tc.transforms..., t)) +Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform((t, tc.transforms...)) +function Base.:∘(tc1::ChainTransform, tc2::ChainTransform) + return ChainTransform((tc2.transforms..., tc1.transforms...)) +end (t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x) diff --git a/src/transform/transform.jl b/src/transform/transform.jl index c7da2729d..795a3498f 100644 --- a/src/transform/transform.jl +++ b/src/transform/transform.jl @@ -8,8 +8,21 @@ abstract type Transform end # We introduce our own _map for Transform so that we can work around # https://github.com/FluxML/Zygote.jl/issues/646 and define our own pullback # (see zygoterules.jl) -Base.map(t::Transform, x::AbstractVector) = _map(t, x) -_map(t::Transform, x::AbstractVector) = t.(x) +Base.map(t::Transform, x::ColVecs) = _map(t, x) +Base.map(t::Transform, x::RowVecs) = _map(t, x) + +# Fallback +# No separate methods for `x::ColVecs` and `x::RowVecs` to avoid method ambiguities +function _map(t::Transform, x::AbstractVector) + # Avoid stackoverflow + if x isa RowVecs + return map(t, eachrow(x.X)) + elseif x isa ColVecs + return map(t, eachcol(x.X)) + else + return map(t, x) + end +end """ IdentityTransform() @@ -19,6 +32,9 @@ Transformation that returns exactly the input. struct IdentityTransform <: Transform end (t::IdentityTransform)(x) = x + +# More efficient implementation than `map(IdentityTransform(), x)` +# Introduces, however, discrepancy between `map` and `_map` _map(::IdentityTransform, x::AbstractVector) = x ### TODO Maybe defining adjoints could help but so far it's not working diff --git a/src/utils.jl b/src/utils.jl index bb805b15f..0942fa5f7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -101,10 +101,10 @@ _to_colvecs(x::AbstractVector{<:Real}) = ColVecs(reshape(x, 1, :)) pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2) pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2) -function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs) +function pairwise(d::PreMetric, x::AbstractVector{<:AbstractVector{<:Real}}, y::ColVecs) return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2) end -function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector) +function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector{<:AbstractVector{<:Real}}) return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2) end function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs) @@ -172,10 +172,10 @@ dim(x::RowVecs) = size(x.X, 2) pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1) pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1) -function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs) +function pairwise(d::PreMetric, x::AbstractVector{<:AbstractVector{<:Real}}, y::RowVecs) return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1) end -function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector) +function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector{<:AbstractVector{<:Real}}) return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1) end function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs) diff --git a/test/kernels/kernelproduct.jl b/test/kernels/kernelproduct.jl index b2cdea042..94e816ade 100644 --- a/test/kernels/kernelproduct.jl +++ b/test/kernels/kernelproduct.jl @@ -3,6 +3,12 @@ k2 = SqExponentialKernel() k = KernelProduct(k1, k2) @test k == KernelProduct([k1, k2]) == KernelProduct((k1, k2)) + for (_k1, _k2) in Iterators.product( + (k1, KernelProduct((k1,)), KernelProduct([k1])), + (k2, KernelProduct((k2,)), KernelProduct([k2])), + ) + @test k == _k1 * _k2 + end @test length(k) == 2 @test string(k) == ( "Product of 2 kernels:\n\tLinear Kernel (c = 0.0)\n\tSquared " * diff --git a/test/kernels/kernelsum.jl b/test/kernels/kernelsum.jl index cfe239717..4b6f30f94 100644 --- a/test/kernels/kernelsum.jl +++ b/test/kernels/kernelsum.jl @@ -3,6 +3,11 @@ k2 = SqExponentialKernel() k = KernelSum(k1, k2) @test k == KernelSum([k1, k2]) == KernelSum((k1, k2)) + for (_k1, _k2) in Iterators.product( + (k1, KernelSum((k1,)), KernelSum([k1])), (k2, KernelSum((k2,)), KernelSum([k2])) + ) + @test k == _k1 + _k2 + end @test length(k) == 2 @test repr(k) == ( "Sum of 2 kernels:\n" * diff --git a/test/kernels/kerneltensorproduct.jl b/test/kernels/kerneltensorproduct.jl index d181b89ac..a59c972ee 100644 --- a/test/kernels/kerneltensorproduct.jl +++ b/test/kernels/kerneltensorproduct.jl @@ -13,6 +13,12 @@ @test kernel1 == kernel2 @test kernel1.kernels === (k1, k2) === KernelTensorProduct((k1, k2)).kernels + for (_k1, _k2) in Iterators.product( + (k1, KernelTensorProduct((k1,)), KernelTensorProduct([k1])), + (k2, KernelTensorProduct((k2,)), KernelTensorProduct([k2])), + ) + @test kernel1 == _k1 ⊗ _k2 + end @test length(kernel1) == length(kernel2) == 2 @test string(kernel1) == ( "Tensor product of 2 kernels:\n" * diff --git a/test/runtests.jl b/test/runtests.jl index f486ddd2a..caf43cb91 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -149,7 +149,10 @@ include("test_utils.jl") if GROUP == "" || GROUP == "Others" include("utils.jl") - @test isempty(detect_unbound_args(KernelFunctions)) + @testset "general" begin + @test isempty(detect_unbound_args(KernelFunctions)) + @test isempty(detect_ambiguities(KernelFunctions)) + end @testset "distances" begin include("distances/pairwise.jl") diff --git a/test/transform/chaintransform.jl b/test/transform/chaintransform.jl index dad6f563d..8d7442e05 100644 --- a/test/transform/chaintransform.jl +++ b/test/transform/chaintransform.jl @@ -12,6 +12,7 @@ # Check composition constructors. @test (tf ∘ ChainTransform([tp])).transforms == (tp, tf) @test (ChainTransform([tf]) ∘ tp).transforms == (tp, tf) + @test (ChainTransform([tf]) ∘ ChainTransform([tp])).transforms == (tp, tf) # Verify correctness. x = ColVecs(randn(rng, 2, 3))