This repository has been archived by the owner on Dec 3, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 37
Implement Automatic Relevance Determination #83
Labels
Comments
@theogf Thank you for your ideas in GH-82! I like having a scaling-variable field that can hold either a scalar or a vector (`Union{AbstractVector{T},T}` in the example below). I want to propose a couple of changes. Particularly:
This would require a modification to the base-function evaluation (`unsafe_base_evaluate`), as sketched below.
Here is a minimal working example for the `RationalQuadraticKernel`:

# Base Functions
"""
    BaseFunction{T<:Real}

Abstract supertype of the base (distance-like) functions that kernels are built on.
`T` is the element type of the arrays they are evaluated on (see the
`unsafe_base_evaluate` signatures).
"""
abstract type BaseFunction{T<:Real} end
"""
    WeightedBaseFunction{T}

Base functions that carry a per-dimension weight array. Subtypes are expected to
provide `get_scale_factor`, which `unsafe_base_evaluate` calls to fetch the weights.
"""
abstract type WeightedBaseFunction{T} <: BaseFunction{T} end
"""
    SquaredEuclidean{T}(a)

Squared Euclidean distance scaled by the scalar `a`: evaluating it on `x` and `y`
yields `a * sum((x .- y) .^ 2)`.
"""
struct SquaredEuclidean{T<:Real} <: BaseFunction{T}
a::T  # scale factor multiplying the summed squared differences (applied in `base_return`)
end
"""
    WeightedSquaredEuclidean{T}(w)

Weighted squared Euclidean distance for automatic relevance determination (ARD):
evaluating it on `x` and `y` yields `sum(w .* (x .- y) .^ 2)`.
"""
struct WeightedSquaredEuclidean{T<:Real} <: WeightedBaseFunction{T}
w::Vector{T}  # one weight per input dimension (must share indices with x and y)
end
"""
    base_sqdist(a)
    base_sqdist(w::AbstractVector)

Build the squared-Euclidean base function that matches a kernel's scale parameter.

A scalar `a` produces a `SquaredEuclidean{T}` holding that single factor. A vector `w`
produces a `WeightedSquaredEuclidean{T}` holding one weight per input dimension.
Kernel types call this to obtain their base function.
"""
function base_sqdist(a::T) where {T}
    return SquaredEuclidean{T}(a)
end

function base_sqdist(w::AbstractVector{T}) where {T}
    return WeightedSquaredEuclidean{T}(w)
end
"""
    get_scale_factor(f::WeightedSquaredEuclidean)

Return the per-dimension weight vector stored in `f`. The three-argument
`unsafe_base_evaluate` for weighted base functions uses it to forward the weights to
the weighted evaluation loop.
"""
function get_scale_factor(f::WeightedSquaredEuclidean)
    weights = f.w
    return weights
end
# Base Function Rules
#
# A base function is evaluated as a fold: the accumulator starts at `base_initiate`,
# each coordinate pair is folded in with `base_aggregate`, and the finished sum is
# post-processed by `base_return`.

# The accumulator starts at the zero of the element type.
@inline function base_initiate(::BaseFunction{T}) where {T}
    return zero(T)
end

# Default post-processing: hand the accumulated sum back unchanged.
@inline function base_return(::BaseFunction{T}, s::T) where {T}
    return s
end

# Scalar-scaled squared distance: multiply the finished sum by the factor `a`.
@inline function base_return(f::SquaredEuclidean{T}, s::T) where {T}
    return f.a * s
end

# Unweighted step: add the squared difference of one coordinate pair.
@inline function base_aggregate(::SquaredEuclidean{T}, s::T, x::T, y::T) where {T}
    return s + (x - y)^2
end

# Weighted (ARD) step: add the squared difference scaled by that coordinate's weight.
@inline function base_aggregate(
    ::WeightedSquaredEuclidean{T}, s::T, x::T, y::T, w::T
) where {T}
    return s + w * (x - y)^2
end
# Base Evaluation Changes

"""
    unsafe_base_evaluate(f::BaseFunction{T}, x::AbstractArray{T}, y::AbstractArray{T})

Fold the base function `f` over the paired elements of `x` and `y` and return the
post-processed result (`base_initiate`, then `base_aggregate` per pair, then
`base_return`).

`eachindex(x, y)` throws a `DimensionMismatch` when the two arrays have different
indices. The element reads themselves are unchecked (`@inbounds`). Every call prints a
trace line, presumably so the demonstration can show which method was dispatched.
"""
function unsafe_base_evaluate(
    f::BaseFunction{T},
    x::AbstractArray{T},
    y::AbstractArray{T},
) where {T<:Real}
    println("Running unweighted unsafe_base_evaluate")
    acc = base_initiate(f)
    @simd for k in eachindex(x, y)
        @inbounds begin
            xk = x[k]
            yk = y[k]
        end
        acc = base_aggregate(f, acc, xk, yk)
    end
    return base_return(f, acc)
end
# These are new:

"""
    unsafe_base_evaluate(f::BaseFunction{T}, x, y, w)

Weighted variant of the base-function fold. Each coordinate pair `(x[k], y[k])` is
folded in together with its weight `w[k]` through the five-argument `base_aggregate`.

`eachindex(x, y, w)` throws a `DimensionMismatch` when the three arrays have different
indices. The element reads themselves are unchecked (`@inbounds`). Every call prints a
trace line.
"""
function unsafe_base_evaluate(
    f::BaseFunction{T},
    x::AbstractArray{T},
    y::AbstractArray{T},
    w::AbstractArray{T},
) where {T<:Real}
    println("Running weighted unsafe_base_evaluate")
    acc = base_initiate(f)
    @simd for k in eachindex(x, y, w)
        @inbounds begin
            xk = x[k]
            yk = y[k]
            wk = w[k]
        end
        acc = base_aggregate(f, acc, xk, yk, wk)
    end
    return base_return(f, acc)
end
# Weighted base functions carry their own weights: fetch them with `get_scale_factor`
# and forward to the four-argument (weighted) method, so callers can use the same
# three-argument call for weighted and unweighted base functions.
@inline unsafe_base_evaluate(
    f::WeightedBaseFunction{T},
    x::AbstractArray{T},
    y::AbstractArray{T},
) where {T<:Real} = unsafe_base_evaluate(f, x, y, get_scale_factor(f))
# Kernels could be defined as:

"""
    Scale{T}

Admissible types for a kernel's scale parameter.

Either a single scalar `T` (isotropic kernel) or a vector of per-dimension weights
(automatic relevance determination, ARD).
"""
const Scale{T} = Union{AbstractVector{T},T}

"""
    Kernel{T<:Real}

Abstract supertype of all kernels whose parameters have element type `T`.
"""
abstract type Kernel{T<:Real} end

"""
    RationalQuadraticKernel{T}(α, β)

Rational quadratic kernel with scale parameter `α` and exponent `β`.

`α` may be a scalar (one scale shared by every input dimension) or a vector (one
weight per input dimension, ARD). `β` is the exponent in `(1 + d²)^(-β)`.

The concrete type of `α` is recorded in the type parameter `A`, so the field is
concretely typed. Code that reads `κ.α` (for example to choose the base function) can
therefore be specialised by the compiler. Previously it worked through an abstract
`Union` field.
"""
struct RationalQuadraticKernel{T<:Real,A<:Scale{T}} <: Kernel{T}
    α::A    # scalar scale or per-dimension weight vector
    β::T    # exponent
end

# Keep the `RationalQuadraticKernel{T}(α, β)` call form working now that the type has
# a second parameter: infer `A` from `α` and convert `β` to `T`.
function RationalQuadraticKernel{T}(α::A, β::Real) where {T<:Real,A<:Scale{T}}
    return RationalQuadraticKernel{T,A}(α, convert(T, β))
end
"""
    kappa(κ::RationalQuadraticKernel{T}, d²::T)

Apply the rational quadratic profile to a squared distance `d²` and return
`(1 + d²)^(-β)`. The squared distance is presumably produced by `basefunction(κ)`,
which already folds in the scale `α`.
"""
@inline kappa(κ::RationalQuadraticKernel{T}, d²::T) where {T} = (one(T) + d²)^(-κ.β)
# Choose the base function from the kernel's scale parameter: a scalar `α` gives a
# `SquaredEuclidean`, a vector `α` gives a `WeightedSquaredEuclidean` (via `base_sqdist`).
@inline function basefunction(κ::RationalQuadraticKernel)
    α = κ.α
    return base_sqdist(α)
end
# Demonstration
#
# k1 uses one scalar scale, and k2 applies that same scale to every dimension, so
# their evaluations should agree up to floating-point rounding. k3 weights each
# dimension differently (ARD).
k1 = RationalQuadraticKernel{Float64}(2.0, 1.0)
k2 = RationalQuadraticKernel{Float64}([2.0, 2.0, 2.0], 1.0)
k3 = RationalQuadraticKernel{Float64}([1.0, 2.0, 3.0], 1.0)
b1 = basefunction(k1)   # SquaredEuclidean
b2 = basefunction(k2)   # WeightedSquaredEuclidean
b3 = basefunction(k3)   # WeightedSquaredEuclidean
x = rand(3)
y = rand(3)
unsafe_base_evaluate(b1, x, y)   # prints "Running unweighted unsafe_base_evaluate"
unsafe_base_evaluate(b2, x, y)   # prints "Running weighted unsafe_base_evaluate"
unsafe_base_evaluate(b3, x, y)   # prints "Running weighted unsafe_base_evaluate"

# What are your thoughts on this approach?
Sign up for free
to subscribe to this conversation on GitHub.
Already have an account?
Sign in.
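Implement Automatic Relevance Determination (ARD) as described in Rasmussen & Williams, *Gaussian Processes for Machine Learning* (page 2 of the PDF, page 106 of the text). Each input dimension gets its own scale, so the squared distance becomes a weighted sum $\sum_i w_i (x_i - y_i)^2$ instead of a single scale times $\sum_i (x_i - y_i)^2$.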
The text was updated successfully, but these errors were encountered: