Skip to content

Commit

Permalink
Fix method ambiguities (#483)
Browse files Browse the repository at this point in the history
* Fix method ambiguities

* Update test/kernels/kernelsum.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Try to fix `map` issues

* Define `map` for `ColVecs`/`RowVecs`

* Fix ambiguity issues

* Better fix

* Add back some definitions

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
devmotion and github-actions[bot] authored Apr 21, 2023
1 parent 9da7bfd commit ef6d459
Show file tree
Hide file tree
Showing 11 changed files with 62 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.54"
version = "0.10.55"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
5 changes: 4 additions & 1 deletion src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

# Note that this is type piracy as the derivative should be NaN for x == y.
function ChainRulesCore.frule(
(_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector
(_, Δx, Δy)::Tuple{<:Any,<:Any,<:Any},
d::Distances.Euclidean,
x::AbstractVector,
y::AbstractVector,
)
Δ = x - y
D = sqrt(sum(abs2, Δ))
Expand Down
8 changes: 8 additions & 0 deletions src/kernels/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,13 @@ for (M, op, T) in (

$M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k)
$M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k))

# Fix method ambiguity issues
function $M.$op(ks1::$T, ks2::$T{<:AbstractVector{<:Kernel}})
return $T(vcat(collect(ks1.kernels), ks2.kernels))
end
function $M.$op(ks1::$T{<:AbstractVector{<:Kernel}}, ks2::$T)
return $T(vcat(ks1.kernels, collect(ks2.kernels)))
end
end
end
7 changes: 5 additions & 2 deletions src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ function ChainTransform(v, θ::AbstractVector)
end

Base.:(t₁::Transform, t₂::Transform) = ChainTransform((t₂, t₁))
Base.:(t::Transform, tc::ChainTransform) = ChainTransform(tuple(tc.transforms..., t))
Base.:(tc::ChainTransform, t::Transform) = ChainTransform(tuple(t, tc.transforms...))
Base.:(t::Transform, tc::ChainTransform) = ChainTransform((tc.transforms..., t))
Base.:(tc::ChainTransform, t::Transform) = ChainTransform((t, tc.transforms...))
function Base.:(tc1::ChainTransform, tc2::ChainTransform)
return ChainTransform((tc2.transforms..., tc1.transforms...))
end

(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)

Expand Down
20 changes: 18 additions & 2 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,21 @@ abstract type Transform end
# We introduce our own _map for Transform so that we can work around
# https://github.com/FluxML/Zygote.jl/issues/646 and define our own pullback
# (see zygoterules.jl)
Base.map(t::Transform, x::AbstractVector) = _map(t, x)
_map(t::Transform, x::AbstractVector) = t.(x)
Base.map(t::Transform, x::ColVecs) = _map(t, x)
Base.map(t::Transform, x::RowVecs) = _map(t, x)

# Fallback
# No separate methods for `x::ColVecs` and `x::RowVecs` to avoid method ambiguities
function _map(t::Transform, x::AbstractVector)
# Avoid stackoverflow
if x isa RowVecs
return map(t, eachrow(x.X))
elseif x isa ColVecs
return map(t, eachcol(x.X))
else
return map(t, x)
end
end

"""
IdentityTransform()
Expand All @@ -19,6 +32,9 @@ Transformation that returns exactly the input.
struct IdentityTransform <: Transform end

(t::IdentityTransform)(x) = x

# More efficient implementation than `map(IdentityTransform(), x)`
# Introduces, however, discrepancy between `map` and `_map`
_map(::IdentityTransform, x::AbstractVector) = x

### TODO Maybe defining adjoints could help but so far it's not working
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ _to_colvecs(x::AbstractVector{<:Real}) = ColVecs(reshape(x, 1, :))

pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)
function pairwise(d::PreMetric, x::AbstractVector{<:AbstractVector{<:Real}}, y::ColVecs)
return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2)
end
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector)
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector{<:AbstractVector{<:Real}})
return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
Expand Down Expand Up @@ -172,10 +172,10 @@ dim(x::RowVecs) = size(x.X, 2)

pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1)
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1)
function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs)
function pairwise(d::PreMetric, x::AbstractVector{<:AbstractVector{<:Real}}, y::RowVecs)
return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
end
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector)
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector{<:AbstractVector{<:Real}})
return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
Expand Down
6 changes: 6 additions & 0 deletions test/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
k2 = SqExponentialKernel()
k = KernelProduct(k1, k2)
@test k == KernelProduct([k1, k2]) == KernelProduct((k1, k2))
for (_k1, _k2) in Iterators.product(
(k1, KernelProduct((k1,)), KernelProduct([k1])),
(k2, KernelProduct((k2,)), KernelProduct([k2])),
)
@test k == _k1 * _k2
end
@test length(k) == 2
@test string(k) == (
"Product of 2 kernels:\n\tLinear Kernel (c = 0.0)\n\tSquared " *
Expand Down
5 changes: 5 additions & 0 deletions test/kernels/kernelsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
k2 = SqExponentialKernel()
k = KernelSum(k1, k2)
@test k == KernelSum([k1, k2]) == KernelSum((k1, k2))
for (_k1, _k2) in Iterators.product(
(k1, KernelSum((k1,)), KernelSum([k1])), (k2, KernelSum((k2,)), KernelSum([k2]))
)
@test k == _k1 + _k2
end
@test length(k) == 2
@test repr(k) == (
"Sum of 2 kernels:\n" *
Expand Down
6 changes: 6 additions & 0 deletions test/kernels/kerneltensorproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@

@test kernel1 == kernel2
@test kernel1.kernels === (k1, k2) === KernelTensorProduct((k1, k2)).kernels
for (_k1, _k2) in Iterators.product(
(k1, KernelTensorProduct((k1,)), KernelTensorProduct([k1])),
(k2, KernelTensorProduct((k2,)), KernelTensorProduct([k2])),
)
@test kernel1 == _k1 _k2
end
@test length(kernel1) == length(kernel2) == 2
@test string(kernel1) == (
"Tensor product of 2 kernels:\n" *
Expand Down
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ include("test_utils.jl")
if GROUP == "" || GROUP == "Others"
include("utils.jl")

@test isempty(detect_unbound_args(KernelFunctions))
@testset "general" begin
@test isempty(detect_unbound_args(KernelFunctions))
@test isempty(detect_ambiguities(KernelFunctions))
end

@testset "distances" begin
include("distances/pairwise.jl")
Expand Down
1 change: 1 addition & 0 deletions test/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Check composition constructors.
@test (tf ChainTransform([tp])).transforms == (tp, tf)
@test (ChainTransform([tf]) tp).transforms == (tp, tf)
@test (ChainTransform([tf]) ChainTransform([tp])).transforms == (tp, tf)

# Verify correctness.
x = ColVecs(randn(rng, 2, 3))
Expand Down

2 comments on commit ef6d459

@theogf
Copy link
Member

@theogf theogf commented on ef6d459 Apr 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/82044

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.55 -m "<description of version>" ef6d4591b36194fca069d8bc7ae8c1e2ee288080
git push origin v0.10.55

Please sign in to comment.