From f3482badf1f275e2a98bf7e338d406d399609f37 Mon Sep 17 00:00:00 2001
From: Gergo Bohner
Date: Mon, 9 Jul 2018 09:13:09 +0100
Subject: [PATCH] Tuning dispatch framework for discussion

---
 MLR/TuningDispatch_framework.ipynb | 681 +++++++++++++++++++++++++++++
 1 file changed, 681 insertions(+)
 create mode 100644 MLR/TuningDispatch_framework.ipynb

diff --git a/MLR/TuningDispatch_framework.ipynb b/MLR/TuningDispatch_framework.ipynb
new file mode 100644
index 0000000..bf32449
--- /dev/null
+++ b/MLR/TuningDispatch_framework.ipynb
@@ -0,0 +1,681 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "predict (generic function with 1 method)"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Define BaseModel as an abstract type; all models will belong to this category\n",
+    "abstract type BaseModel end\n",
+    "abstract type BaseModelFit{T<:BaseModel} end\n",
+    "\n",
+    "# When we fit a model, we get back a ModelFit object that holds the original model as well as the fit results, and should not change once fitted\n",
+    "struct ModelFit{T} <: BaseModelFit{T}\n",
+    "    model :: T\n",
+    "    fit_result\n",
+    "end\n",
+    "model(modelFit::ModelFit) = modelFit.model # Accessor function for the family of ModelFit types, used instead of reaching into the field directly. The accessor already knows the model type, since it infers it from the type of the ModelFit it receives, so it can be specialised per model while callers stay decoupled from the struct layout.\n",
+    "\n",
+    "# Define a generic predict for BaseModelFit that dispatches on the model the fit came from\n",
+    "predict(modelFit::BaseModelFit, Xnew) = predict(model(modelFit), modelFit, Xnew)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "predict (generic function with 3 methods)"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Every model has:\n",
+    "# - a unique name (the model type itself)\n",
+    "# - a fit function that dispatches on a first argument of the model type and returns a ModelFit\n",
+    "# - a predict function that dispatches on a first argument of the model type and a second argument of ModelFit type\n",
+    "\n",
+    "mutable struct LinModel <: BaseModel\n",
+    "    parameters # Maybe a dictionary of names and values for now\n",
+    "end\n",
+    "\n",
+    "fit(model::LinModel, X::AbstractArray, y::AbstractArray) = ModelFit(model, model.parameters[\"x1\"])\n",
+    "predict(model::LinModel, modelFit::BaseModelFit, Xnew) = 11\n",
+    "\n",
+    "mutable struct NonLinModel <: BaseModel\n",
+    "    parameters # Maybe a dictionary of names and values for now\n",
+    "end\n",
+    "\n",
+    "fit(model::NonLinModel, X::AbstractArray, y::AbstractArray) = ModelFit(model, 2)\n",
+    "predict(model::NonLinModel, modelFit::BaseModelFit, Xnew) = 22"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "LinModel(Dict(\"x1\"=>1,\"x0\"=>0))"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "lm = LinModel(Dict(\"x0\" => 0, \"x1\" => 1))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "ModelFit{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), 1)"
+      ]
+     },
"execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lm_fit = fit(lm, [],[])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ModelFit{NonLinModel}(NonLinModel(\"haha\"), 2)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nlm_fit = ModelFit(NonLinModel(\"haha\"),2)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "11" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predict(lm_fit,[])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "22" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predict(nlm_fit,[])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tuning as a composable wrapper" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Define some types of tuning\n", + "abstract type BaseTuning end\n", + "struct SimpleGridTuning <: BaseTuning\n", + " grid\n", + "end\n", + "struct ModelSelectTuning <: BaseTuning\n", + " models\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tuning (generic function with 1 method)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now imagine we want to add tuning to the model, this should result in a composite that inherits the initial model, as well as adds a tuning grid to change the parameters on\n", + "struct TunedModel{T<:BaseModel} <: BaseModel\n", + " model :: T\n", + " tuning :: BaseTuning\n", + "end\n", + "\n", + "# Accessor functions (for compile-time lookup gain)\n", + "model(tunedModel::TunedModel) = tunedModel.model\n", + "tuning(tunedModel::TunedModel) = tunedModel.tuning" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "model (generic function with 3 methods)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "struct TunedModelFit{T} <: BaseModelFit{T}\n", + " model :: T\n", + " fit_result\n", + " tuning :: BaseTuning\n", + " tuning_result\n", + "end\n", + "model(modelFit::TunedModelFit) = modelFit.model" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "fit (generic function with 3 methods)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This is just an example of SimpleGridTuning, do_tuning dispatches on the type of tuning as well\n", + "function do_tuning(model::BaseModel, tuning::SimpleGridTuning, X, y)\n", + " tuning_result = [fit(typeof(model)(parameters), X, y) for parameters in tuning.grid]\n", + " best_result_ind = 2\n", + " fit_result = tuning_result[best_result_ind] # Choose the best one given metric\n", + " return TunedModelFit(\n", + " typeof(model)(tuning.grid[best_result_ind]), \n", + " fit_result, \n", + " tuning, \n", + " tuning_result)\n", + "end\n", + "\n", + "# This is just an example of ModelSelectTuning, do_tuning 
dispatches on the type of tuning as well\n", + "function do_tuning(model::BaseModel, tuning::ModelSelectTuning, X, y)\n", + " tuning_result = [fit(cur_model, X, y) for cur_model in tuning.models]\n", + " best_result_ind = 3\n", + " fit_result = tuning_result[best_result_ind] # Choose the best one given metric\n", + " return TunedModelFit(\n", + " fit_result.model, \n", + " fit_result, \n", + " tuning, \n", + " tuning_result)\n", + "end\n", + "\n", + "\n", + "fit(tunedModel::TunedModel, X, y) = do_tuning(model(tunedModel), tuning(tunedModel), X, y)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3-element Array{Dict{String,Int64},1}:\n", + " Dict(\"x1\"=>1,\"x0\"=>0)\n", + " Dict(\"x1\"=>2,\"x0\"=>0)\n", + " Dict(\"x1\"=>3,\"x0\"=>0)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "curTuneGrid = [\n", + " Dict(\"x0\" => 0, \"x1\" => 1),\n", + " Dict(\"x0\" => 0, \"x1\" => 2),\n", + " Dict(\"x0\" => 0, \"x1\" => 3)\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TunedModel{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), SimpleGridTuning(Dict{String,Int64}[Dict(\"x1\"=>1,\"x0\"=>0), Dict(\"x1\"=>2,\"x0\"=>0), Dict(\"x1\"=>3,\"x0\"=>0)]))" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lm_tuned = TunedModel(lm, SimpleGridTuning(curTuneGrid))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3-element Array{Dict{String,Int64},1}:\n", + " Dict(\"x1\"=>1,\"x0\"=>0)\n", + " Dict(\"x1\"=>2,\"x0\"=>0)\n", + " Dict(\"x1\"=>3,\"x0\"=>0)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tuning(lm_tuned).grid" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TunedModel{LinModel}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "typeof(lm_tuned)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TunedModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), SimpleGridTuning(Dict{String,Int64}[Dict(\"x1\"=>1,\"x0\"=>0), Dict(\"x1\"=>2,\"x0\"=>0), Dict(\"x1\"=>3,\"x0\"=>0)]), ModelFit{LinModel}[ModelFit{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), 1), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>3,\"x0\"=>0)), 3)])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lm_tuned_fit = fit(lm_tuned, [], [])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "11" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predict(lm_tuned_fit, [])" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3-element Array{ModelFit{LinModel},1}:\n", + " ModelFit{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), 1)\n", + " ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 
2)\n", + " ModelFit{LinModel}(LinModel(Dict(\"x1\"=>3,\"x0\"=>0)), 3)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lm_tuned_fit.tuning_result" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LinModel(Dict(\"x1\"=>2,\"x0\"=>0))" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(lm_tuned_fit)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TunedModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), TunedModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), SimpleGridTuning(Dict{String,Int64}[Dict(\"x1\"=>1,\"x0\"=>0), Dict(\"x1\"=>2,\"x0\"=>0), Dict(\"x1\"=>3,\"x0\"=>0)]), ModelFit{LinModel}[ModelFit{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), 1), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>3,\"x0\"=>0)), 3)]), ModelSelectTuning(BaseModel[LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), NonLinModel(\"haha\"), TunedModel{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), SimpleGridTuning(Dict{String,Int64}[Dict(\"x1\"=>1,\"x0\"=>0), Dict(\"x1\"=>2,\"x0\"=>0), Dict(\"x1\"=>3,\"x0\"=>0)]))]), Any[ModelFit{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), 1), ModelFit{NonLinModel}(NonLinModel(\"haha\"), 2), TunedModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), SimpleGridTuning(Dict{String,Int64}[Dict(\"x1\"=>1,\"x0\"=>0), Dict(\"x1\"=>2,\"x0\"=>0), Dict(\"x1\"=>3,\"x0\"=>0)]), ModelFit{LinModel}[ModelFit{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), 1), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>3,\"x0\"=>0)), 3)])])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Compose model selection and grid tuning\n", + "curTuneModels = [lm, NonLinModel(\"haha\"), lm_tuned]\n", + "selectionTunedModel = TunedModel(lm, ModelSelectTuning(curTuneModels))\n", + "\n", + "selectionTunedModel_fit = fit(selectionTunedModel, [], [])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TunedModel{LinModel}" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "typeof(selectionTunedModel)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TunedModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), SimpleGridTuning(Dict{String,Int64}[Dict(\"x1\"=>1,\"x0\"=>0), Dict(\"x1\"=>2,\"x0\"=>0), Dict(\"x1\"=>3,\"x0\"=>0)]), ModelFit{LinModel}[ModelFit{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), 1), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>3,\"x0\"=>0)), 3)])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The underlying tuned model still gets tuned by its own underlying fit function as the result (same result as calling fit on 
the tuned model directly; see lm_tuned_fit)\n",
+    "selectionTunedModel_fit.tuning_result[3]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TunedModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), SimpleGridTuning(Dict{String,Int64}[Dict(\"x1\"=>1,\"x0\"=>0), Dict(\"x1\"=>2,\"x0\"=>0), Dict(\"x1\"=>3,\"x0\"=>0)]), ModelFit{LinModel}[ModelFit{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), 1), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>3,\"x0\"=>0)), 3)])"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "lm_tuned_fit"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Side design question: Should the model returned after ModelSelectTuning be as simplified as possible, or retain its full complexity?"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "LinModel(Dict(\"x1\"=>2,\"x0\"=>0))"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# At the moment it returns the fitted, simplified model rather than the full underlying model (the fitting is still done to evaluate performance properly, and its results are stored in tuning_result as expected)\n",
+    "selectionTunedModel_fit.model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TunedModel{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), SimpleGridTuning(Dict{String,Int64}[Dict(\"x1\"=>1,\"x0\"=>0), Dict(\"x1\"=>2,\"x0\"=>0), Dict(\"x1\"=>3,\"x0\"=>0)]))"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "selectionTunedModel_fit.tuning.models[3] # This returns the full underlying model that was further tuned"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TunedModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), SimpleGridTuning(Dict{String,Int64}[Dict(\"x1\"=>1,\"x0\"=>0), Dict(\"x1\"=>2,\"x0\"=>0), Dict(\"x1\"=>3,\"x0\"=>0)]), ModelFit{LinModel}[ModelFit{LinModel}(LinModel(Dict(\"x1\"=>1,\"x0\"=>0)), 1), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>2,\"x0\"=>0)), 2), ModelFit{LinModel}(LinModel(Dict(\"x1\"=>3,\"x0\"=>0)), 3)])"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "selectionTunedModel_fit.tuning_result[3] # ...with the results of tuning that model stored here"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "11"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "predict(selectionTunedModel_fit, []) # This predicts with the simplest underlying model that was selected after ModelSelectTuning and SimpleGridTuning"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Traits - old\n",
+    "\n",
+    "\n",
+    "# using SimpleTraits\n",
+    "# @traitdef IsLinearModelFit{X}\n",
+    "# @traitimpl IsLinearModelFit{X} <- (isa(X.model,LinModel)) # This is not the way to do this\n",
+    "\n",
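+    "# A possible alternative to the SimpleTraits attempt in this cell, left here only as an untested,\n",
+    "# commented-out sketch: the plain \"Holy traits\" pattern needs no package and dispatches on a trait\n",
+    "# value computed from the fit type. The names islinearfit and predict_by_trait are hypothetical and\n",
+    "# not part of the framework above.\n",
+    "#\n",
+    "# islinearfit(::Type{<:BaseModelFit{LinModel}}) = Val{true}()\n",
+    "# islinearfit(::Type{<:BaseModelFit}) = Val{false}()\n",
+    "# predict_by_trait(modelFit::BaseModelFit, Xnew) = predict_by_trait(islinearfit(typeof(modelFit)), modelFit, Xnew)\n",
+    "# predict_by_trait(::Val{true}, modelFit, Xnew) = 11\n",
+    "# predict_by_trait(::Val{false}, modelFit, Xnew) = 12\n",
+    "\n",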
+ "# @traitfn predict{X; IsLinearModelFit{X}}(modelFit::X, Xnew) = 11\n", + "# @traitfn predict{X; !IsLinearModelFit{X}}(modelFit::X, Xnew) = 12" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 0.6.2", + "language": "julia", + "name": "julia-0.6" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "0.6.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}