This repository has been archived by the owner on Dec 3, 2019. It is now read-only.

Implement Automatic Relevance Determination #83

Open
trthatcher opened this issue Jan 12, 2019 · 1 comment

Comments

@trthatcher
Owner

Implement Automatic Relevance Determination (ARD) as described in Rasmussen and Williams, Gaussian Processes for Machine Learning, Chapter 5 (page 2 of the PDF, page 106 of the text):

http://www.gaussianprocess.org/gpml/chapters/RW5.pdf
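For reference, the ARD covariance in that chapter is the squared exponential with a separate length scale ℓ_d for each input dimension, k(x, y) = σ_f² exp(−½ Σ_d (x_d − y_d)² / ℓ_d²), with the noise term omitted here. A large ℓ_d makes dimension d nearly irrelevant. The following is a minimal sketch of that formula only, not package code; the function name `ard_se` and its arguments are hypothetical:

```julia
# Hypothetical helper (not part of the package): ARD squared-exponential kernel,
# noise term omitted. ℓ holds one length scale per input dimension and σf² is
# the signal variance.
function ard_se(x::AbstractVector{T}, y::AbstractVector{T},
                ℓ::AbstractVector{T}, σf²::T = one(T)) where {T<:Real}
    length(x) == length(y) == length(ℓ) ||
        throw(DimensionMismatch("x, y and ℓ must have the same length"))
    s = zero(T)
    for i in eachindex(x, y, ℓ)
        s += ((x[i] - y[i]) / ℓ[i])^2   # squared difference scaled per dimension
    end
    return σf² * exp(-s / 2)
end

x = [0.0, 0.0]
y = [1.0, 5.0]

k_both  = ard_se(x, y, [1.0, 1.0])   # both dimensions contribute: exp(-13)
k_first = ard_se(x, y, [1.0, 1e6])   # a huge ℓ in dimension 2 effectively ignores it
```

Under this parameterization, per-dimension weights w_d = 1/ℓ_d² in a weighted squared Euclidean distance reproduce the sum in the exponent, which is one way a vector-valued scaling field could express ARD.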

@trthatcher
Owner Author

@theogf Thank you for your ideas in GH-82! I like having a scaling variable field of type Union{Real,AbstractVector{<:Real}}.

I want to propose a couple of changes, in particular:

  • Keep the original unsafe_base_evaluate but define an additional version with a signature that includes the scaling variable
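A stripped-down sketch of this point, assuming plain vectors and the scaling field type discussed above: the original two-argument method stays untouched, a new method takes the scaling variable, and a scalar scale simply reuses the unweighted loop. The names `sqeuclid` and `ScaledDistance` are hypothetical and do not exist in the package:

```julia
# Hypothetical sketch; these names are not part of the package.

# Original two-argument method, left unchanged.
function sqeuclid(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
    s = zero(promote_type(eltype(x), eltype(y)))
    @inbounds @simd for i in eachindex(x, y)
        s += (x[i] - y[i])^2
    end
    return s
end

# Additional method whose signature includes a per-dimension scaling vector.
function sqeuclid(x::AbstractVector{<:Real}, y::AbstractVector{<:Real},
                  w::AbstractVector{<:Real})
    s = zero(promote_type(eltype(x), eltype(y), eltype(w)))
    @inbounds @simd for i in eachindex(x, y, w)
        s += w[i] * (x[i] - y[i])^2
    end
    return s
end

# A scalar scale needs no new loop: it multiplies the unweighted result.
sqeuclid(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}, a::Real) = a * sqeuclid(x, y)

# Scaling field typed as Union{Real,AbstractVector{<:Real}}.
struct ScaledDistance
    scale::Union{Real,AbstractVector{<:Real}}
end

# Dispatch on the runtime type of the field selects the scalar or the vector method.
(d::ScaledDistance)(x, y) = sqeuclid(x, y, d.scale)

x = [0.1, 0.5, 0.9]
y = [0.4, 0.2, 0.7]
iso = ScaledDistance(2.0)(x, y)               # scalar scale
ard = ScaledDistance([2.0, 2.0, 2.0])(x, y)   # uniform weights reproduce the scalar case
```

Treating the scalar as a multiplier applied after aggregation keeps the isotropic path as fast as before, while the vector path pays for one extra load and multiply per dimension.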

This would require a modification to the BaseFunction types. Outline of approach:

  • Parameterize the BaseFunction types and define a WeightedBaseFunction abstract type (<: BaseFunction)
  • ARD versions of base functions would be concrete subtypes with vector field for scaling
  • The scaling variable in the Kernel type would be of type Union{AbstractVector{T},T}
  • When kernelmatrix is called, the corresponding BaseFunction/WeightedBaseFunction would be constructed and passed to the base_evaluate or basematrix function for dispatch
  • There would be two base_evaluate methods, one for (x,y) and one for (x,y,scale)
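The outline above can be sketched end to end as follows. This is a self-contained simplification under stated assumptions: observations are the rows of a matrix, the hypothetical `pairvalue` stands in for `base_evaluate`, and the `basematrix` and `kernelmatrix` bodies are illustrative versions written for this sketch, not the package's implementations:

```julia
# Hypothetical simplification of the proposed kernelmatrix -> base function flow.
abstract type BaseFunction{T<:Real} end
abstract type WeightedBaseFunction{T<:Real} <: BaseFunction{T} end

struct SquaredEuclidean{T<:Real} <: BaseFunction{T}
    a::T
end
struct WeightedSquaredEuclidean{T<:Real} <: WeightedBaseFunction{T}
    w::Vector{T}
end

# The kernel's scale field decides which base function is constructed.
base_sqdist(a::T) where {T<:Real} = SquaredEuclidean{T}(a)
base_sqdist(w::AbstractVector{T}) where {T<:Real} = WeightedSquaredEuclidean{T}(collect(w))

# Base function evaluated on one pair of observations.
pairvalue(f::SquaredEuclidean, x, y) = f.a * sum(abs2, x .- y)
pairvalue(f::WeightedSquaredEuclidean, x, y) = sum(f.w .* abs2.(x .- y))

# Base function applied to every pair of rows of X.
function basematrix(f::BaseFunction{T}, X::AbstractMatrix{T}) where {T}
    n = size(X, 1)
    D = Matrix{T}(undef, n, n)
    for j in 1:n, i in 1:n
        D[i, j] = pairvalue(f, view(X, i, :), view(X, j, :))
    end
    return D
end

struct RationalQuadraticKernel{T<:Real}
    α::Union{T,AbstractVector{T}}
    β::T
end
kappa(k::RationalQuadraticKernel{T}, d²::T) where {T} = (one(T) + d²)^(-k.β)

# Build the base function once, then let dispatch pick the evaluation method.
function kernelmatrix(k::RationalQuadraticKernel{T}, X::AbstractMatrix{T}) where {T}
    D = basematrix(base_sqdist(k.α), X)
    return map(d -> kappa(k, d), D)
end

X = [0.0 0.0 0.0;
     1.0 2.0 3.0]
K_iso = kernelmatrix(RationalQuadraticKernel{Float64}(2.0, 1.0), X)
K_ard = kernelmatrix(RationalQuadraticKernel{Float64}([1.0, 0.0, 0.0], 1.0), X)
```

Because `basematrix` receives the already-constructed base function as an argument, it acts as a function barrier: even though the kernel's `α` field has a `Union` type, the pairwise loop runs with a single concrete base-function type.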

Here is a minimal working example for the RationalQuadraticKernel:

```julia
# Base Functions
abstract type BaseFunction{T<:Real} end

abstract type WeightedBaseFunction{T<:Real} <: BaseFunction{T} end

struct SquaredEuclidean{T<:Real} <: BaseFunction{T}
    a::T
end

struct WeightedSquaredEuclidean{T<:Real} <: WeightedBaseFunction{T}
    w::Vector{T}
end

# Will be used by the Kernel types to construct the appropriate version of the Base Function
base_sqdist(a::T) where {T} = SquaredEuclidean{T}(a)
base_sqdist(w::AbstractVector{T}) where {T} = WeightedSquaredEuclidean{T}(w)

# Will be used in unsafe_base_evaluate to retrieve the weights and pass them through
get_scale_factor(f::WeightedSquaredEuclidean) = f.w

# Base Function Rules
@inline base_initiate(::BaseFunction{T}) where {T} = zero(T)
@inline base_return(::BaseFunction{T}, s::T) where {T} = s

@inline base_aggregate(::SquaredEuclidean{T}, s::T, x::T, y::T) where {T} = s + (x-y)^2
@inline base_aggregate(::WeightedSquaredEuclidean{T}, s::T, x::T, y::T, w::T) where {T} = s + w*(x-y)^2

# The scalar scale is applied once, after aggregation
@inline base_return(f::SquaredEuclidean{T}, s::T) where {T} = f.a*s


# Base Evaluation Changes
function unsafe_base_evaluate(
        f::BaseFunction{T},
        x::AbstractArray{T},
        y::AbstractArray{T}
    ) where {T<:Real}
    println("Running unweighted unsafe_base_evaluate")
    s = base_initiate(f)
    @simd for I in eachindex(x, y)
        @inbounds xi = x[I]
        @inbounds yi = y[I]
        s = base_aggregate(f, s, xi, yi)
    end
    base_return(f, s)
end

# These are new:
function unsafe_base_evaluate(
        f::BaseFunction{T},
        x::AbstractArray{T},
        y::AbstractArray{T},
        w::AbstractArray{T}
    ) where {T<:Real}
    println("Running weighted unsafe_base_evaluate")
    s = base_initiate(f)
    @simd for I in eachindex(x, y, w)
        @inbounds xi = x[I]
        @inbounds yi = y[I]
        @inbounds wi = w[I]
        s = base_aggregate(f, s, xi, yi, wi)
    end
    base_return(f, s)
end

# Weighted base functions forward their stored weights to the 4-argument method
@inline function unsafe_base_evaluate(
        f::WeightedBaseFunction{T},
        x::AbstractArray{T},
        y::AbstractArray{T}
    ) where {T<:Real}
    unsafe_base_evaluate(f, x, y, get_scale_factor(f))
end


# Kernels could be defined as:
const Scale{T} = Union{AbstractVector{T},T}

abstract type Kernel{T<:Real} end

struct RationalQuadraticKernel{T<:Real} <: Kernel{T}
    α::Scale{T}
    β::T
end

@inline function kappa(κ::RationalQuadraticKernel{T}, d²::T) where {T}
    return (one(T) + d²)^(-κ.β)
end

@inline basefunction(κ::RationalQuadraticKernel) = base_sqdist(κ.α)


# Demonstration

k1 = RationalQuadraticKernel{Float64}(2.0,1.0)
k2 = RationalQuadraticKernel{Float64}([2.0, 2.0, 2.0],1.0)
k3 = RationalQuadraticKernel{Float64}([1.0, 2.0, 3.0],1.0)

b1 = basefunction(k1)   # SquaredEuclidean{Float64}
b2 = basefunction(k2)   # WeightedSquaredEuclidean{Float64}
b3 = basefunction(k3)   # WeightedSquaredEuclidean{Float64}

x = rand(3)
y = rand(3)

# b1 and b2 should agree: uniform weights reproduce the scalar scale
unsafe_base_evaluate(b1, x, y)
unsafe_base_evaluate(b2, x, y)
unsafe_base_evaluate(b3, x, y)
```
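One possible refinement, offered as an assumption rather than part of the proposal: because `Scale{T}` is a `Union` that includes the abstract `AbstractVector{T}`, the field `α` is not concretely typed, so the compiler cannot infer which base function `basefunction` returns. Adding the scale's type as a second type parameter keeps both the scalar and the ARD variants concretely typed. `RQKernel` is a hypothetical name used only for this sketch:

```julia
# Hypothetical variant: the concrete type of the scale becomes a type parameter A.
struct RQKernel{T<:Real, A<:Union{T,AbstractVector{T}}}
    α::A
    β::T
end

kappa(k::RQKernel{T}, d²::T) where {T} = (one(T) + d²)^(-k.β)

# Both parameters are inferred by the default constructor.
k_iso = RQKernel(2.0, 1.0)               # RQKernel{Float64, Float64}
k_ard = RQKernel([1.0, 2.0, 3.0], 1.0)   # RQKernel{Float64, Vector{Float64}}

# Unlike a field declared as Union{AbstractVector{T},T}, α now has a concrete type.
iso_concrete = isconcretetype(fieldtype(typeof(k_iso), :α))
ard_concrete = isconcretetype(fieldtype(typeof(k_ard), :α))
```

The trade-off is that scalar and ARD kernels become distinct concrete types, so code that stores several kernels in one container would see a non-concrete element type instead.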

What are your thoughts on this approach?
