diff --git a/examples/evaluation_examples/calibrated_predictive_distributions_demo.ipynb b/examples/evaluation_examples/calibrated_predictive_distributions_demo.ipynb new file mode 100644 index 00000000..2133f889 --- /dev/null +++ b/examples/evaluation_examples/calibrated_predictive_distributions_demo.ipynb @@ -0,0 +1,290 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Demo notebook for the calibrated predictive distributions implementation in RAIL" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "Author: Luca Tortorelli, Biprateep Dey\n", + "\n", + "Last run successfully: Nov 7, 2022\n", + "\n", + "The purpose of this notebook is to demonstrate the implementation of calibrated predictive distributions (Dey et al. 2022) in RAIL.\n", + "Biprateep provided test data in .npz format (/src/rail/examples/testdata/bpz_test_red.npz) containing:\n", + "- a galaxy catalogue with spectroscopic redshifts, magnitudes and their errors\n", + "- conditional density estimates for each galaxy, i.e. PDFs evaluated with a photo-z method on a representative sample of objects\n", + "- the redshift grid of the conditional density estimates" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import numpy as np\n", + "import os\n", + "import pandas as pd\n", + "import qp\n", + "from src.rail.evaluation.metrics.pit import ConditionPIT" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "We do a small amount of preprocessing before feeding the data into RAIL." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# read Biprateep's test data\n", + "\n", + "root = 'src/rail/examples/testdata/'\n", + "data = np.load(os.path.join(root, 'bpz_test_red.npz'), allow_pickle=True)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# display the keywords to access the data\n", + "for name in data.keys(): print(name)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# conveniently read the galaxy catalogue as a pandas dataframe\n", + "cat = pd.DataFrame(data[\"test_cat\"])" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# create new features for training the method, in this case colours and their errors\n", + "\n", + "cat[\"UG\"] = cat[\"U\"]-cat[\"G\"]\n", + "cat[\"UGERR\"] = np.sqrt(cat[\"UERR\"]**2 + cat[\"GERR\"]**2)\n", + "cat[\"UR\"] = cat[\"U\"]-cat[\"R\"]\n", + "cat[\"URERR\"] = np.sqrt(cat[\"UERR\"]**2 + cat[\"RERR\"]**2)\n", + "cat[\"UI\"] = cat[\"U\"]-cat[\"I\"]\n", + "cat[\"UIERR\"] = np.sqrt(cat[\"UERR\"]**2 + cat[\"IERR\"]**2)\n", + "cat[\"UZ\"] = cat[\"U\"]-cat[\"Z\"]\n", + "cat[\"UZERR\"] = np.sqrt(cat[\"UERR\"]**2 + cat[\"ZERR\"]**2)\n", + "cat[\"UY\"] = cat[\"U\"]-cat[\"Y\"]\n", + "cat[\"UYERR\"] = np.sqrt(cat[\"UERR\"]**2 + cat[\"YERR\"]**2)\n", + "\n", + "cat[\"GR\"] = cat[\"G\"]-cat[\"R\"]\n", + "cat[\"GRERR\"] = np.sqrt(cat[\"GERR\"]**2 + cat[\"RERR\"]**2)\n", + "cat[\"GI\"] = cat[\"G\"]-cat[\"I\"]\n", + "cat[\"GIERR\"] = np.sqrt(cat[\"GERR\"]**2 + cat[\"IERR\"]**2)\n", + "cat[\"GZ\"] = cat[\"G\"]-cat[\"Z\"]\n", + "cat[\"GZERR\"] = np.sqrt(cat[\"GERR\"]**2 + 
cat[\"ZERR\"]**2)\n", + "cat[\"GY\"] = cat[\"G\"]-cat[\"Y\"]\n", + "cat[\"GYERR\"] = np.sqrt(cat[\"GERR\"]**2 + cat[\"YERR\"]**2)\n", + "\n", + "cat[\"RI\"] = cat[\"R\"]-cat[\"I\"]\n", + "cat[\"RIERR\"] = np.sqrt(cat[\"RERR\"]**2 + cat[\"IERR\"]**2)\n", + "cat[\"RZ\"] = cat[\"R\"]-cat[\"Z\"]\n", + "cat[\"RZERR\"] = np.sqrt(cat[\"RERR\"]**2 + cat[\"ZERR\"]**2)\n", + "cat[\"RY\"] = cat[\"R\"]-cat[\"Y\"]\n", + "cat[\"RYERR\"] = np.sqrt(cat[\"RERR\"]**2 + cat[\"YERR\"]**2)\n", + "\n", + "cat[\"IZ\"] = cat[\"I\"]-cat[\"Z\"]\n", + "cat[\"IZERR\"] = np.sqrt(cat[\"IERR\"]**2 + cat[\"ZERR\"]**2)\n", + "cat[\"IY\"] = cat[\"I\"]-cat[\"Y\"]\n", + "cat[\"IYERR\"] = np.sqrt(cat[\"IERR\"]**2 + cat[\"YERR\"]**2)\n", + "\n", + "cat[\"ZY\"] = cat[\"Z\"]-cat[\"Y\"]\n", + "cat[\"ZYERR\"] = np.sqrt(cat[\"ZERR\"]**2 + cat[\"YERR\"]**2)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# normalise the conditional density estimates across the redshift grid\n", + "z_grid = data[\"z_grid\"]\n", + "\n", + "cde = data[\"cde_test\"] # conditional density estimate\n", + "norm = np.trapz(cde, z_grid) # normalize across the redshift grid\n", + "norm[norm==0] = 1\n", + "cde = cde/norm[:,None]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# define the number of galaxies to train the method and split the sample into training and testing set\n", + "SEED = 299792458\n", + "\n", + "num_calib = 800\n", + "n_gal = len(cat)\n", + "num_test = n_gal - num_calib\n", + "\n", + "rng = np.random.default_rng(SEED)\n", + "indices = rng.permutation(n_gal) # creating index permutation for splitting in train and test\n", + "\n", + "cde_calib = cde[indices[:num_calib]] # splitting cde in training set\n", + "cde_test = cde[indices[num_calib:]] # and test set\n", + "\n", + "z_calib = cat[\"SPECZ\"][indices[:num_calib]].values\n", + "z_test = cat[\"SPECZ\"][indices[num_calib:]].values\n", + "\n", + "cat_calib = cat.iloc[indices[:num_calib]]\n", + "cat_test = cat.iloc[indices[num_calib:]]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# define a list of features for the method\n", + "features = [\"I\", \"UG\", \"GR\", \"RI\", \"IZ\", \"ZY\", \"IZERR\", \"RIERR\", \"GRERR\", \"UGERR\", \"IERR\", \"ZYERR\"]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# store the conditional density estimates for the training and test set into qp ensembles\n", + "qp_ens_cde_calib = qp.Ensemble(qp.interp, data=dict(xvals=z_grid, yvals=cde_calib))\n", + "qp_ens_cde_test = qp.Ensemble(qp.interp, data=dict(xvals=z_grid, yvals=cde_test))" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "Initialisation of the ConditionPIT class" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "cond_pit = ConditionPIT(cde_calib, cde_test, z_grid, z_calib, z_test, cat_calib[features].values,\n", + " cat_test[features].values, qp_ens_cde_calib)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# train the method using the provided data\n", + "cond_pit.train(patience=10, n_epochs=10, lr=0.001, 
weight_decay=0.01, batch_size=100, frac_mlp_train=0.9,\n", + " lr_decay=0.95, oversample=50, n_alpha=201, checkpt_path=\"checkpoint_GPZ_wide_CDE_test.pt\",\n", + " hidden_layers=[2, 2, 2])" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# compute the local pit\n", + "pit_local, pit_local_fit = cond_pit.evaluate(model_checkpt_path='checkpoint_GPZ_wide_CDE_test.pt',\n", + " model_hidden_layers=[2, 2, 2], nn_type='monotonic',\n", + " batch_size=100, num_basis=40, num_cores=1)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# plot the local P-P plot diagnostics\n", + "cond_pit.diagnostics(pit_local, pit_local_fit)" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/pyproject.toml b/pyproject.toml index 3bdd3877..c809e28d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,4 +151,4 @@ show_error_codes = true strict_equality = true warn_redundant_casts = true warn_unreachable = true -warn_unused_ignores = true +warn_unused_ignores = true \ No newline at end of file diff --git a/src/rail/evaluation/metrics/condition_pit_utils/MonotonicNN.py b/src/rail/evaluation/metrics/condition_pit_utils/MonotonicNN.py new file mode 100644 index 00000000..236eb692 --- /dev/null +++ b/src/rail/evaluation/metrics/condition_pit_utils/MonotonicNN.py @@ -0,0 +1,64 @@ +# Copyright (C) 2021,2022 Bitrateep Dey, University of Pittsburgh, USA + +import torch +import torch.nn as nn +from .NeuralIntegral import NeuralIntegral +from .ParallelNeuralIntegral import ParallelNeuralIntegral + + +def _flatten(sequence): + flat = [p.contiguous().view(-1) for p in sequence] + return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) + +def clipped_relu(x, device): + return torch.minimum(torch.maximum(torch.Tensor([0]).to(device),x), torch.Tensor([1]).to(device)) + +class IntegrandNN(nn.Module): + def __init__(self, in_d, hidden_layers): + super(IntegrandNN, self).__init__() + self.net = [] + hs = [in_d] + hidden_layers + [1] + for h0, h1 in zip(hs, hs[1:]): + self.net.extend([ + nn.Linear(h0, h1), + nn.ReLU(), + ]) + self.net.pop() # pop the last ReLU for the output layer + self.net.append(nn.ELU()) + self.net = nn.Sequential(*self.net) + + def forward(self, x, h): + return self.net(torch.cat((x, h), 1)) + 1. + +class MonotonicNN(nn.Module): + def __init__(self, in_d, hidden_layers, nb_steps=50, sigmoid=False, dev="cpu"): + super(MonotonicNN, self).__init__() + self.integrand = IntegrandNN(in_d, hidden_layers) + self.net = [] + hs = [in_d-1] + hidden_layers + [2] + for h0, h1 in zip(hs, hs[1:]): + self.net.extend([ + nn.Linear(h0, h1), + nn.ReLU(), + ]) + self.net.pop() # pop the last ReLU for the output layer + # It will output the scaling and offset factors. + self.net = nn.Sequential(*self.net) + self.device = dev + self.nb_steps = nb_steps + self.sigmoid = sigmoid + + ''' + The forward procedure takes as input x which is the variable for which the integration must be made, h is just other conditionning variables. 
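+
+    The expected input layout is a single tensor whose first column is x and whose
+    remaining columns are the conditioning variables h (i.e. the layout produced by
+    torch.cat([x, h], 1)). A minimal usage sketch, for illustration only (the layer
+    sizes, batch size and feature count below are arbitrary choices, not values used
+    elsewhere in this package):
+
+        net = MonotonicNN(in_d=3, hidden_layers=[64, 64], nb_steps=50, dev="cpu")
+        x_input = torch.rand(8, 3)   # column 0 is x, columns 1: are the two h features
+        out = net(x_input)           # shape (8, 1), monotonically increasing in x
+
+    Because the integrand is forced to be positive (ELU + 1) and the scaling factor is
+    an exponential, the output is monotonically increasing in x for any fixed h.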
+ ''' + def forward(self, x_input): + x = x_input[:, 0][:, None] + h = x_input[:, 1:] + x0 = torch.zeros(x.shape).to(self.device) + out = self.net(h) + offset = out[:, [0]] + scaling = torch.exp(out[:, [1]]) + if self.sigmoid: + return torch.sigmoid(scaling*ParallelNeuralIntegral.apply(x0, x, self.integrand, _flatten(self.integrand.parameters()), h, self.nb_steps) + offset) + else: + return scaling*ParallelNeuralIntegral.apply(x0, x, self.integrand, _flatten(self.integrand.parameters()), h, self.nb_steps) + offset diff --git a/src/rail/evaluation/metrics/condition_pit_utils/NeuralIntegral.py b/src/rail/evaluation/metrics/condition_pit_utils/NeuralIntegral.py new file mode 100644 index 00000000..c30188cc --- /dev/null +++ b/src/rail/evaluation/metrics/condition_pit_utils/NeuralIntegral.py @@ -0,0 +1,87 @@ +import torch +import numpy as np +import math + + +def _flatten(sequence): + flat = [p.contiguous().view(-1) for p in sequence] + return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) + + +def compute_cc_weights(nb_steps): + lam = np.arange(0, nb_steps + 1, 1).reshape(-1, 1) + lam = np.cos((lam @ lam.T) * math.pi / nb_steps) + lam[:, 0] = .5 + lam[:, -1] = .5 * lam[:, -1] + lam = lam * 2 / nb_steps + W = np.arange(0, nb_steps + 1, 1).reshape(-1, 1) + W[np.arange(1, nb_steps + 1, 2)] = 0 + W = 2 / (1 - W ** 2) + W[0] = 1 + W[np.arange(1, nb_steps + 1, 2)] = 0 + cc_weights = torch.tensor(lam.T @ W).float() + steps = torch.tensor(np.cos(np.arange(0, nb_steps + 1, 1).reshape(-1, 1) * math.pi / nb_steps)).float() + + return cc_weights, steps + + +def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=None): + #Clenshaw-Curtis Quadrature Method + cc_weights, steps = compute_cc_weights(nb_steps) + + device = x0.get_device() if x0.is_cuda else "cpu" + cc_weights, steps = cc_weights.to(device), steps.to(device) + + if compute_grad: + g_param = 0. + g_h = 0. + else: + z = 0. 
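+    # Map the Clenshaw-Curtis nodes (steps lie in [-1, 1]) onto the interval [x0, xT]
+    # and accumulate either the weighted integrand values (forward pass) or its
+    # gradients with respect to the integrand parameters and h (backward pass).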
+ xT = x0 + nb_steps*step_sizes + for i in range(nb_steps + 1): + x = (x0 + (xT - x0)*(steps[i] + 1)/2) + if compute_grad: + dg_param, dg_h = computeIntegrand(x, h, integrand, x_tot*(xT - x0)/2) + g_param += cc_weights[i]*dg_param + g_h += cc_weights[i]*dg_h + else: + dz = integrand(x, h) + z = z + cc_weights[i]*dz + + if compute_grad: + return g_param, g_h + + return z*(xT - x0)/2 + + +def computeIntegrand(x, h, integrand, x_tot): + with torch.enable_grad(): + f = integrand.forward(x, h) + g_param = _flatten(torch.autograd.grad(f, integrand.parameters(), x_tot, create_graph=True, retain_graph=True)) + g_h = _flatten(torch.autograd.grad(f, h, x_tot)) + + return g_param, g_h + + +class NeuralIntegral(torch.autograd.Function): + + @staticmethod + def forward(ctx, x0, x, integrand, flat_params, h, nb_steps=20): + with torch.no_grad(): + x_tot = integrate(x0, nb_steps, (x - x0)/nb_steps, integrand, h, False) + # Save for backward + ctx.integrand = integrand + ctx.nb_steps = nb_steps + ctx.save_for_backward(x0.clone(), x.clone(), h) + return x_tot + + @staticmethod + def backward(ctx, grad_output): + x0, x, h = ctx.saved_tensors + integrand = ctx.integrand + nb_steps = ctx.nb_steps + integrand_grad, h_grad = integrate(x0, nb_steps, x/nb_steps, integrand, h, True, grad_output) + x_grad = integrand(x, h) + x0_grad = integrand(x0, h) + # Leibniz formula + return -x0_grad*grad_output, x_grad*grad_output, None, integrand_grad, h_grad.view(h.shape), None diff --git a/src/rail/evaluation/metrics/condition_pit_utils/ParallelNeuralIntegral.py b/src/rail/evaluation/metrics/condition_pit_utils/ParallelNeuralIntegral.py new file mode 100644 index 00000000..4fa20384 --- /dev/null +++ b/src/rail/evaluation/metrics/condition_pit_utils/ParallelNeuralIntegral.py @@ -0,0 +1,97 @@ +import torch +import numpy as np +import math + + +def _flatten(sequence): + flat = [p.contiguous().view(-1) for p in sequence] + return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) + + +def compute_cc_weights(nb_steps): + lam = np.arange(0, nb_steps + 1, 1).reshape(-1, 1) + lam = np.cos((lam @ lam.T) * math.pi / nb_steps) + lam[:, 0] = .5 + lam[:, -1] = .5 * lam[:, -1] + lam = lam * 2 / nb_steps + W = np.arange(0, nb_steps + 1, 1).reshape(-1, 1) + W[np.arange(1, nb_steps + 1, 2)] = 0 + W = 2 / (1 - W ** 2) + W[0] = 1 + W[np.arange(1, nb_steps + 1, 2)] = 0 + cc_weights = torch.tensor(lam.T @ W).float() + steps = torch.tensor(np.cos(np.arange(0, nb_steps + 1, 1).reshape(-1, 1) * math.pi / nb_steps)).float() + + return cc_weights, steps + + +def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=None): + #Clenshaw-Curtis Quadrature Method + cc_weights, steps = compute_cc_weights(nb_steps) + + device = x0.get_device() if x0.is_cuda else "cpu" + cc_weights, steps = cc_weights.to(device), steps.to(device) + + xT = x0 + nb_steps*step_sizes + if not compute_grad: + x0_t = x0.unsqueeze(1).expand(-1, nb_steps + 1, -1) + xT_t = xT.unsqueeze(1).expand(-1, nb_steps + 1, -1) + h_steps = h.unsqueeze(1).expand(-1, nb_steps + 1, -1) + steps_t = steps.unsqueeze(0).expand(x0_t.shape[0], -1, x0_t.shape[2]) + X_steps = x0_t + (xT_t-x0_t)*(steps_t + 1)/2 + X_steps = X_steps.contiguous().view(-1, x0_t.shape[2]) + h_steps = h_steps.contiguous().view(-1, h.shape[1]) + dzs = integrand(X_steps, h_steps) + dzs = dzs.view(xT_t.shape[0], nb_steps+1, -1) + dzs = dzs*cc_weights.unsqueeze(0).expand(dzs.shape) + z_est = dzs.sum(1) + return z_est*(xT - x0)/2 + else: + + x0_t = x0.unsqueeze(1).expand(-1, nb_steps + 1, -1) + xT_t = 
xT.unsqueeze(1).expand(-1, nb_steps + 1, -1) + x_tot = x_tot * (xT - x0) / 2 + x_tot_steps = x_tot.unsqueeze(1).expand(-1, nb_steps + 1, -1) * cc_weights.unsqueeze(0).expand(x_tot.shape[0], -1, x_tot.shape[1]) + h_steps = h.unsqueeze(1).expand(-1, nb_steps + 1, -1) + steps_t = steps.unsqueeze(0).expand(x0_t.shape[0], -1, x0_t.shape[2]) + X_steps = x0_t + (xT_t - x0_t) * (steps_t + 1) / 2 + X_steps = X_steps.contiguous().view(-1, x0_t.shape[2]) + h_steps = h_steps.contiguous().view(-1, h.shape[1]) + x_tot_steps = x_tot_steps.contiguous().view(-1, x_tot.shape[1]) + + g_param, g_h = computeIntegrand(X_steps, h_steps, integrand, x_tot_steps, nb_steps+1) + return g_param, g_h + + +def computeIntegrand(x, h, integrand, x_tot, nb_steps): + h.requires_grad_(True) + with torch.enable_grad(): + f = integrand.forward(x, h) + g_param = _flatten(torch.autograd.grad(f, integrand.parameters(), x_tot, create_graph=True, retain_graph=True)) + g_h = _flatten(torch.autograd.grad(f, h, x_tot)) + + return g_param, g_h.view(int(x.shape[0]/nb_steps), nb_steps, -1).sum(1) + + +class ParallelNeuralIntegral(torch.autograd.Function): + + @staticmethod + def forward(ctx, x0, x, integrand, flat_params, h, nb_steps=20): + with torch.no_grad(): + x_tot = integrate(x0, nb_steps, (x - x0)/nb_steps, integrand, h, False) + # Save for backward + ctx.integrand = integrand + ctx.nb_steps = nb_steps + ctx.save_for_backward(x0.clone(), x.clone(), h) + return x_tot + + @staticmethod + def backward(ctx, grad_output): + x0, x, h = ctx.saved_tensors + integrand = ctx.integrand + nb_steps = ctx.nb_steps + integrand_grad, h_grad = integrate(x0, nb_steps, x/nb_steps, integrand, h, True, grad_output) + x_grad = integrand(x, h) + x0_grad = integrand(x0, h) + # Leibniz formula + return -x0_grad*grad_output, x_grad*grad_output, None, integrand_grad, h_grad.view(h.shape), None diff --git a/src/rail/evaluation/metrics/condition_pit_utils/ispline.py b/src/rail/evaluation/metrics/condition_pit_utils/ispline.py new file mode 100644 index 00000000..cec511c6 --- /dev/null +++ b/src/rail/evaluation/metrics/condition_pit_utils/ispline.py @@ -0,0 +1,1277 @@ +""" +================= +ispline class obtained from: https://jbloomlab.github.io/dms_variants/_modules/dms_variants/ispline.html#Isplines +pdf and cdf fitting funtions added by Biprateep Dey +================= + +Implements :class:`Isplines`, which are monotonic spline functions that are +defined in terms of :class:`Msplines`. Also implements :class:`Isplines_total` +for the weighted sum of a :class:`Isplines` family. + +See `Ramsay (1988)`_ for details about these splines, and also note the +corrections in the `Praat manual`_ to the errors in the I-spline formula +by `Ramsay (1988)`_. + +.. _`Ramsay (1988)`: https://www.jstor.org/stable/2245395 +.. 
_`Praat manual`: http://www.fon.hum.uva.nl/praat/manual/spline.html + +""" + +import numpy as np + +# This is optional but makes fitting faster +# from sklearnex import patch_sklearn +# patch_sklearn() + +from sklearn.linear_model import LinearRegression + + +def fit_cdf(x, y, x_predict=None, num_basis=10, fit_intercept=True): + """[summary] + + Parameters + ---------- + x : [type] + [description] + y : [type] + [description] + num_basis : int, optional + [description], by default 10 + fit_intercept : bool, optional + [description], by default True + + Returns + ------- + [type] + [description] + + Raises + ------ + ValueError + [description] + """ + order = 3 # fixing the I-spline order + num_mesh_points = num_basis + 2 - order # num_splines = num_mesh_points + 2 - order + if (type(num_basis) != int) or (num_mesh_points <= 0): + raise ValueError(f"num_basis should be an integer greater than {order - 2}") + mesh = np.linspace(0, 1, num_mesh_points) + isplines = Isplines(order, mesh, x) + + # if fit_intercept: + # num_basis = num_basis + 1 + X = np.ones((len(x), num_basis)) + for i in range(isplines.n): + X[:, i] = isplines.I(i + 1) + model = LinearRegression(positive=True, fit_intercept=fit_intercept) + model.fit(X, y) + if x_predict is not None: + isplines = Isplines(order, mesh, x_predict) + X = np.ones((len(x_predict), num_basis)) + for i in range(isplines.n): + X[:, i] = isplines.I(i + 1) + y_fit = model.predict(X) + else: + y_fit = model.predict(X) + + return y_fit, model.coef_, model.intercept_ + + +def get_pdf(cdf_grid, cdf, pdf_grid, num_basis=10, fit_intercept=True): + """[summary] + + Parameters + ---------- + cdf_grid : [type] + [description] + cdf : [type] + [description] + pdf_grid : [type] + [description] + num_basis : int, optional + [description], by default 10 + fit_intercept : bool, optional + [description], by default True + + Returns + ------- + [type] + [description] + """ + _, coef, intercept = fit_cdf( + x=cdf_grid, y=cdf, num_basis=num_basis, fit_intercept=fit_intercept + ) + + order = 3 # fixing the I-spline order + num_mesh_points = num_basis + 2 - order # num_splines = num_mesh_points + 2 - order + mesh = np.linspace(0, 1, num_mesh_points) + isplines = Isplines(order, mesh, pdf_grid) + + pdf = np.ones((len(pdf_grid), (isplines.n))) + for i in range(isplines.n): + pdf[:, i] = isplines.dI_dx(i + 1) + # if fit_intercept: + # coef = coef[:-1] #The last coefficient is the intercept + norm = coef.sum() # The basis CDF range from 0 to 1, i.e. basis PDFs integrate to 1 + if norm <= 0.0: + norm = 1.0 + pdf = np.sum(coef * pdf, axis=-1) # /norm + + return pdf, coef, intercept + + +class Isplines_total: + r"""Evaluate the weighted sum of an I-spline family (see `Ramsay (1988)`_). + + Parameters + ---------- + order : int + Sets :attr:`Isplines_total.order`. + mesh : array-like + Sets :attr:`Isplines_total.mesh`. + x : np.ndarray + Sets :attr:`Isplines_total.x`. + + Attributes + ---------- + order : int + See :attr:`Isplines.order`. + mesh : np.ndarray + See :attr:`Isplines.mesh`. + n : int + See :attr:`Isplines.n`. + lower : float + See :attr:`Isplines.lower`. + upper : float + See :attr:`Isplines.upper`. + + Note + ---- + Evaluates the full interpolating curve from the I-splines. When + :math:`x` falls within the lower :math:`L` and upper :math:`U` + bounds of the range covered by the I-splines (:math:`L \le x \le U`), + then this curve is defined as: + + .. math:: + + I_{\rm{total}}\left(x\right) + = + w_{\rm{lower}} + \sum_i w_i I_i\left(x\right). 
+ + When :math:`x` is outside the range of the mesh covered by the splines, + the values are linearly extrapolated from first derivative at the + bounds. Specifically, if :math:`x < L` then: + + .. math:: + + I_{\rm{total}}\left(x\right) + = + I_{\rm{total}}\left(L\right) + + \left(x - L\right) + \left.\frac{\partial I_{\rm{total}}\left(y\right)} + {\partial y}\right\rvert_{y=L}, + + and if :math:`x > U` then: + + .. math:: + + I_{\rm{total}}\left(x\right) + = + I_{\rm{total}}\left(U\right) + + \left(x - U\right) + \left.\frac{\partial I_{\rm{total}}\left(y\right)} + {\partial y}\right\rvert_{y=U}. + + Note also that: + + .. math:: + + I_{\rm{total}}\left(L\right) &=& w_{\rm{lower}}, \\ + I_{\rm{total}}\left(U\right) &=& w_{\rm{lower}} + \sum_i w_i + + Example + ------- + Short examples to demonstrate and test :class:`Isplines_total`: + + .. plot:: + :context: reset + + >>> import itertools + >>> import numpy as np + >>> import pandas as pd + >>> import scipy.optimize + >>> from dms_variants.ispline import Isplines_total + + >>> order = 3 + >>> mesh = [0.0, 0.3, 0.5, 0.6, 1.0] + >>> x = np.array([0, 0.2, 0.3, 0.4, 0.8, 0.99999]) + >>> isplines_total = Isplines_total(order, mesh, x) + >>> weights = np.array([1.2, 2, 1.2, 1.2, 3, 0]) / 6 + >>> np.round(isplines_total.Itotal(weights, w_lower=0), 2) + array([0. , 0.38, 0.54, 0.66, 1.21, 1.43]) + + Now calculate using some points that require linear extrapolation + outside the mesh and also have a nonzero `w_lower`: + + >>> x2 = np.array([-0.5, -0.25, 0, 0.01, 1.0, 1.5]) + >>> isplines_total2 = Isplines_total(order, mesh, x2) + >>> np.round(isplines_total2.Itotal(weights, w_lower=1), 3) + array([0. , 0.5 , 1. , 1.02 , 2.433, 2.433]) + + Test :meth:`Isplines_total.dItotal_dx`: + + >>> x_deriv = np.array([-0.5, -0.25, 0, 0.01, 0.5, 0.7, 1.0, 1.5]) + >>> for xval in x_deriv: + ... xval = np.array([xval]) + ... def func(xval): + ... return Isplines_total(order, mesh, xval).Itotal(weights, 0) + ... def dfunc(xval): + ... return Isplines_total(order, mesh, xval).dItotal_dx(weights) + ... err = scipy.optimize.check_grad(func, dfunc, xval) + ... if err > 1e-5: + ... raise ValueError(f"excess err {err} for {xval}") + + >>> (isplines_total.dItotal_dw_lower() == np.ones(x.shape)).all() + True + + Test :meth:`Isplines_total.dItotal_dweights`: + + >>> isplines_total3 = Isplines_total(order, mesh, x_deriv) + >>> wl = 1.5 + >>> (isplines_total3.dItotal_dweights(weights, wl).shape == + ... (len(x_deriv), len(weights))) + True + >>> weightslist = list(weights) + >>> for ix, iw in itertools.product(range(len(x_deriv)), + ... range(len(weights))): + ... w = np.array([weightslist[iw]]) + ... def func(w): + ... iweights = np.array(weightslist[: iw] + + ... list(w) + + ... weightslist[iw + 1:]) + ... return isplines_total3.Itotal(iweights, wl)[ix] + ... def dfunc(w): + ... iweights = np.array(weightslist[: iw] + + ... list(w) + + ... weightslist[iw + 1:]) + ... return isplines_total3.dItotal_dweights(iweights, wl)[ix, + ... iw] + ... err = scipy.optimize.check_grad(func, dfunc, w) + ... if err > 1e-6: + ... raise ValueError(f"excess err {err} for {ix, iw}") + + Plot the total of the I-spline family shown in Fig. 1 of + `Ramsay (1988)`_, adding some linear extrapolation outside the + mesh range: + + >>> xplot = np.linspace(-0.2, 1.2, 1000) + >>> isplines_totalplot = Isplines_total(order, mesh, xplot) + >>> df = pd.DataFrame({'x': xplot, + ... 'Itotal': isplines_totalplot.Itotal(weights, 0)}) + >>> _ = df.plot(x='x', y='Itotal') + + .. 
_`Ramsay (1988)`: https://www.jstor.org/stable/2245395 + + """ + + def __init__(self, order, mesh, x): + """See main class docstring.""" + if not (isinstance(order, int) and order >= 1): + raise ValueError(f"`order` not int >= 1: {order}") + self.order = order + + self.mesh = np.array(mesh, dtype="float") + if self.mesh.ndim != 1: + raise ValueError(f"`mesh` not array-like of dimension 1: {mesh}") + if len(self.mesh) < 2: + raise ValueError(f"`mesh` not length >= 2: {mesh}") + if not np.array_equal(self.mesh, np.unique(self.mesh)): + raise ValueError(f"`mesh` elements not unique and sorted: {mesh}") + self.lower = self.mesh[0] + self.upper = self.mesh[-1] + assert self.lower < self.upper + + self.n = len(self.mesh) - 2 + self.order + + self._x = x.copy() + self._x.flags.writeable = False + + # indices of `x` in, above, or below I-spline range + self._index = { + "lower": np.flatnonzero(self.x < self.lower), + "upper": np.flatnonzero(self.x > self.upper), + "in": np.flatnonzero((self.x >= self.lower) & (self.x <= self.upper)), + } + + # values of x in each range + self._x_byrange = { + rangename: self.x[index] for rangename, index in self._index.items() + } + + # Isplines for each range: for lower and upper it is value at bound + self._isplines = { + "in": Isplines(self.order, self.mesh, self._x_byrange["in"]), + "lower": Isplines(self.order, self.mesh, np.array([self.lower])), + "upper": Isplines(self.order, self.mesh, np.array([self.upper])), + } + + # for caching values + self._cache = {} + self._max_cache_size = 100 + + @property + def x(self): + """np.ndarray: Points at which spline is evaluated.""" + return self._x + + def Itotal(self, weights, w_lower): + r"""Weighted sum of spline family at points :attr:`Isplines_total.x`. + + Parameters + ---------- + weights : array-like + Nonnegative weights :math:`w_i` of members :math:`I_i` of spline + family, should be of length equal to :attr:`Isplines.n`. + w_lower : float + The value at the lower bound :math:`L` of the spline range, + :math:`w_{\rm{lower}}`. + + Returns + ------- + np.ndarray + :math:`I_{\rm{total}}` for each point in :attr:`Isplines_total.x`. + + """ + args = (tuple(weights), w_lower, "Itotal") + if args not in self._cache: + if len(self._cache) > self._max_cache_size: + self._cache = {} + self._cache[args] = self._calculate_Itotal_or_dItotal(*args) + return self._cache[args] + + def _calculate_Itotal_or_dItotal(self, weights, w_lower, quantity): + """Calculate :meth:`Isplines.Itotal` or derivatives. + + Parameters have same meaning as for :meth:`Isplines.Itotal` + except for `quantity`, which should be + + - 'Itotal' to compute :meth:`Isplines.Itotal` + - 'dItotal_dx' to compute :meth:`Isplines.dItotal_dx` + - 'dItotal_dweights` to compute :meth:`Isplines.dItotal_dweights` + + Also, `weights` must be hashable (e.g., a tuple). 
+ + """ + # check validity of `weights` + if len(weights) != self.n: + raise ValueError(f"invalid length of `weights`: {weights}") + if any(weight < 0 for weight in weights): + raise ValueError(f"`weights` not all non-negative: {weights}") + + # compute return values for each category of indices + returnvals = {} + + if quantity == "Itotal": + returnshape = len(self.x) + if len(self._index["in"]): + returnvals["in"] = ( + np.sum( + [ + self._isplines["in"].I(i) * weights[i - 1] + for i in range(1, self.n + 1) + ], + axis=0, + ) + + w_lower + ) + # values of Itotal at limits + Itotal_limits = {"lower": w_lower, "upper": w_lower + sum(weights)} + for name, limit in [("lower", self.lower), ("upper", self.upper)]: + if not len(self._index[name]): + continue + returnvals[name] = Itotal_limits[name] + ( + self._x_byrange[name] - limit + ) * sum( + self._isplines[name].dI_dx(i) * weights[i - 1] + for i in range(1, self.n + 1) + ) + + elif quantity == "dItotal_dx": + returnshape = len(self.x) + if len(self._index["in"]): + returnvals["in"] = np.sum( + [ + self._isplines["in"].dI_dx(i) * weights[i - 1] + for i in range(1, self.n + 1) + ], + axis=0, + ) + for name in ["lower", "upper"]: + if not len(self._index[name]): + continue + returnvals[name] = sum( + self._isplines[name].dI_dx(i) * weights[i - 1] + for i in range(1, self.n + 1) + ) + + elif quantity == "dItotal_dweights": + returnshape = (len(self.x), len(weights)) + if len(self._index["in"]): + returnvals["in"] = ( + np.vstack([self._isplines["in"].I(i) for i in range(1, self.n + 1)]) + ).transpose() + # values of I at limits + I_limits = {"lower": 0.0, "upper": 1.0} + for name, limit in [("lower", self.lower), ("upper", self.upper)]: + if not len(self._index[name]): + continue + returnvals[name] = np.vstack( + [ + I_limits[name] + + (self._x_byrange[name] - limit) + * self._isplines[name].dI_dx(i) + for i in range(1, self.n + 1) + ] + ).transpose() + + else: + raise ValueError(f"invalid `quantity` {quantity}") + + # reconstruct single return value from indices and returnvalues + returnval = np.full(returnshape, fill_value=np.nan) + for name, name_index in self._index.items(): + if len(name_index): + returnval[name_index] = returnvals[name] + assert not np.isnan(returnval).any() + returnval.flags.writeable = False + return returnval + + def dItotal_dx(self, weights): + r"""Deriv :meth:`Isplines_total.Itotal` by :attr:`Isplines_total.x`. + + Note + ---- + Derivatives calculated from equations in :meth:`Isplines_total.Itotal`: + + .. math:: + + \frac{\partial I_{\rm{total}}\left(x\right)}{\partial x} + = + \begin{cases} + \sum_i w_i \frac{\partial I_i\left(x\right)}{\partial x} + & \rm{if\;} L \le x \le U, \\ + \left.\frac{\partial I_{\rm{total}}\left(y\right)} + {\partial y}\right\rvert_{y=L} + & \rm{if\;} x < L, \\ + \left.\frac{\partial I_{\rm{total}}\left(y\right)} + {\partial y}\right\rvert_{y=U} + & \rm{otherwise}. + \end{cases} + + Note that + + .. math:: + + \left.\frac{\partial I_{\rm{total}}\left(y\right)} + {\partial y}\right\rvert_{y=L} + &=& + \sum_i w_i \left.\frac{\partial I_i\left(y\right)}{\partial y} + \right\rvert_{y=L} + \\ + \left.\frac{\partial I_{\rm{total}}\left(y\right)} + {\partial y}\right\rvert_{y=U} + &=& + \sum_i w_i \left.\frac{\partial I_i\left(y\right)}{\partial y} + \right\rvert_{y=U} + + Parameters + ---------- + weights : array-like + Same meaning as for :meth:`Isplines_total.Itotal`. 
+ + Returns + ------- + np.ndarray + Derivative :math:`\frac{\partial I_{\rm{total}}}{\partial x}` + for each point in :attr:`Isplines_total.x`. + + """ + args = (tuple(weights), None, "dItotal_dx") + if args not in self._cache: + if len(self._cache) > self._max_cache_size: + self._cache = {} + self._cache[args] = self._calculate_Itotal_or_dItotal(*args) + return self._cache[args] + + def dItotal_dweights(self, weights, w_lower): + r"""Derivative of :meth:`Isplines_total.Itotal` by :math:`w_i`. + + Parameters + ---------- + weights : array-like + Same meaning as for :meth:`Isplines.Itotal`. + w_lower : float + Same meaning as for :meth:`Isplines.Itotal`. + + Returns + ------- + np.ndarray + Array is of shape `(len(x), len(weights))`, and element + `ix, iweight` gives derivative with respect to weight + `weights[iweight]` at element `[ix]` of :attr:`Isplines_total.x`. + + Note + ---- + The derivative is: + + .. math:: + + \frac{\partial I_{\rm{total}}\left(x\right)}{\partial w_i} + = + \begin{cases} + I_i\left(x\right) + & \rm{if\;} L \le x \le U, \\ + I_i\left(L\right) + \left(x-L\right) + \left.\frac{\partial I_i\left(y\right)}{\partial y}\right\vert_{y=L} + & \rm{if\;} x < L, \\ + I_i\left(U\right) + \left(x-U\right) + \left.\frac{\partial I_i\left(y\right)}{\partial y}\right\vert_{y=U} + & \rm{if\;} x > U. + \end{cases} + + Note that: + + .. math:: + + I_i\left(L\right) &=& 0 \\ + I_i\left(U\right) &=& 1. + + """ + return self._calculate_Itotal_or_dItotal( + tuple(weights), w_lower, "dItotal_dweights" + ) + + def dItotal_dw_lower(self): + r"""Deriv of :meth:`Isplines_total.Itotal` by :math:`w_{\rm{lower}}`. + + Returns + ------- + np.ndarray + :math:`\frac{\partial{I_{\rm{total}}}}{\partial w_{\rm{lower}}}`, + which is just one for all :attr:`Isplines_total.x`. + + """ + res = np.ones(self.x.shape, dtype="float") + res.flags.writeable = False + return res + + +class Isplines: + r"""Implements I-splines (see `Ramsay (1988)`_). + + Parameters + ---------- + order : int + Sets :attr:`Isplines.order`. + mesh : array-like + Sets :attr:`Isplines.mesh`. + x : np.ndarray + Sets :attr:`Isplines.x`. + + Attributes + ---------- + order : int + Order of spline, :math:`k` in notation of `Ramsay (1988)`_. Note that + the degree of the I-spline is equal to :math:`k`, while the + associated M-spline has order :math:`k` but degree :math:`k - 1`. + mesh : np.ndarray + Mesh sequence, :math:`\xi_1 < \ldots < \xi_q` in the notation + of `Ramsay (1988)`_. This class implements **fixed** mesh sequences. + n : int + Number of members in spline, denoted as :math:`n` in `Ramsay (1988)`_. + Related to number of points :math:`q` in the mesh and the order + :math:`k` by :math:`n = q - 2 + k`. + lower : float + Lower end of interval spanned by the splines (first point in mesh). + upper : float + Upper end of interval spanned by the splines (last point in mesh). + + Note + ---- + The methods of this class cache their results and return immutable + numpy arrays. Do **not** make these arrays mutable and change their + values, as this will lead to invalid caching. + + Example + ------- + Short examples to demonstrate and test :class:`Isplines`: + + .. 
plot:: + :context: reset + + >>> import itertools + >>> import numpy as np + >>> import pandas as pd + >>> import scipy.optimize + >>> from dms_variants.ispline import Isplines + + >>> order = 3 + >>> mesh = [0.0, 0.3, 0.5, 0.6, 1.0] + >>> x = np.array([0, 0.2, 0.3, 0.4, 0.8, 0.99999]) + >>> isplines = Isplines(order, mesh, x) + >>> isplines.order + 3 + >>> isplines.mesh + array([0. , 0.3, 0.5, 0.6, 1. ]) + >>> isplines.n + 6 + >>> isplines.lower + 0.0 + >>> isplines.upper + 1.0 + + Evaluate the I-splines at some selected points: + + >>> for i in range(1, isplines.n + 1): + ... print(f"I{i}: {np.round(isplines.I(i), 2)}") + ... # doctest: +NORMALIZE_WHITESPACE + I1: [0. 0.96 1. 1. 1. 1. ] + I2: [0. 0.52 0.84 0.98 1. 1. ] + I3: [0. 0.09 0.3 0.66 1. 1. ] + I4: [0. 0. 0. 0.02 0.94 1. ] + I5: [0. 0. 0. 0. 0.58 1. ] + I6: [0. 0. 0. 0. 0.13 1. ] + + Check that gradients are correct for :meth:`Isplines.dI_dx`: + + >>> for i, xval in itertools.product(range(1, isplines.n + 1), x): + ... xval = np.array([xval]) + ... def func(xval): + ... return Isplines(order, mesh, xval).I(i) + ... def dfunc(xval): + ... return Isplines(order, mesh, xval).dI_dx(i) + ... err = scipy.optimize.check_grad(func, dfunc, xval) + ... if err > 1e-5: + ... raise ValueError(f"excess err {err} for {i}, {xval}") + + Plot the I-splines in Fig. 1 of `Ramsay (1988)`_: + + >>> xplot = np.linspace(0, 1, 1000) + >>> isplines_xplot = Isplines(order, mesh, xplot) + >>> data = {'x': xplot} + >>> for i in range(1, isplines.n + 1): + ... data[f"I{i}"] = isplines_xplot.I(i) + >>> df = pd.DataFrame(data) + >>> _ = df.plot(x='x') + + .. _`Ramsay (1988)`: https://www.jstor.org/stable/2245395 + + """ + + def __init__(self, order, mesh, x): + """See main class docstring.""" + if not (isinstance(order, int) and order >= 1): + raise ValueError(f"`order` not int >= 1: {order}") + self.order = order + + self.mesh = np.array(mesh, dtype="float") + if self.mesh.ndim != 1: + raise ValueError(f"`mesh` not array-like of dimension 1: {mesh}") + if len(self.mesh) < 2: + raise ValueError(f"`mesh` not length >= 2: {mesh}") + if not np.array_equal(self.mesh, np.unique(self.mesh)): + raise ValueError(f"`mesh` elements not unique and sorted: {mesh}") + self.lower = self.mesh[0] + self.upper = self.mesh[-1] + assert self.lower < self.upper + + self.n = len(self.mesh) - 2 + self.order + + if not (isinstance(x, np.ndarray) and x.ndim == 1): + raise ValueError("`x` is not np.ndarray of dimension 1") + if (x < self.lower).any() or (x > self.upper).any(): + raise ValueError(f"`x` outside {self.lower} and {self.upper}: {x}") + self._x = x.copy() + self._x.flags.writeable = False + + self._msplines = Msplines(order + 1, mesh, self.x) + + # for caching values + self._cache = {} + self._max_cache_size = 100 + + @property + def x(self): + """np.ndarray: Points at which spline is evaluated.""" + return self._x + + def I(self, i): # noqa: E743,E741 + r"""Evaluate spline :math:`I_i` at point(s) :attr:`Isplines.x`. + + Parameters + ---------- + i : int + Spline member :math:`I_i`, where :math:`1 \le i \le` + :attr:`Isplines.n`. + + Returns + ------- + np.ndarray + The values of the I-spline at each point in :attr:`Isplines.x`. + + Note + ---- + The spline is evaluated using the formula given in the + `Praat manual`_, which corrects some errors in the formula + provided by `Ramsay (1988)`_: + + .. 
math:: + + I_i\left(x\right) + = + \begin{cases} + 0 & \rm{if\;} i > j, \\ + 1 & \rm{if\;} i < j - k, \\ + \sum_{m=i+1}^j \left(t_{m+k+1} - t_m\right) + M_m\left(x \mid k + 1\right) / \left(k + 1 \right) + & \rm{otherwise}, + \end{cases} + + where :math:`j` is the index such that :math:`t_j \le x < t_{j+1}` + (the :math:`\left\{t_j\right\}` are the :attr:`Msplines.knots` for a + M-spline of order :math:`k + 1`) and :math:`k` is + :attr:`Isplines.order`. + + .. _`Ramsay (1988)`: https://www.jstor.org/stable/2245395 + .. _`Praat manual`: http://www.fon.hum.uva.nl/praat/manual/spline.html + + """ + args = (i, "I") + if args not in self._cache: + if len(self._cache) > self._max_cache_size: + self._cache = {} + self._cache[args] = self._calculate_I_or_dI(*args) + return self._cache[args] + + @property + def j(self): + """np.ndarray: :math:`j` as defined in :meth:`Isplines.I`.""" + if not hasattr(self, "_j"): + self._j = np.searchsorted(self._msplines.knots, self.x, "right") + assert (1 <= self._j).all() and (self._j <= len(self._msplines.knots)).all() + assert self.x.shape == self._j.shape + return self._j + + @property + def _sum_terms_I(self): + """np.ndarray: sum terms for :meth:`Isplines.I`. + + Row `m - 1` has summation term for `m`. + + """ + if not hasattr(self, "_sum_terms_I_val"): + k = self.order + self._sum_terms_I_val = np.vstack( + [ + (self._msplines.knots[m + k] - self._msplines.knots[m - 1]) + * self._msplines.M(m, k + 1) + / (k + 1) + for m in range(1, self._msplines.n + 1) + ] + ) + assert self._sum_terms_I_val.shape == (self._msplines.n, len(self.x)) + return self._sum_terms_I_val + + @property + def _sum_terms_dI_dx(self): + """np.ndarray: sum terms for :meth:`Isplines.dI_dx`. + + Row `m - 1` has summation term for `m`. + + """ + if not hasattr(self, "_sum_terms_dI_dx_val"): + k = self.order + self._sum_terms_dI_dx_val = np.vstack( + [ + (self._msplines.knots[m + k] - self._msplines.knots[m - 1]) + * self._msplines.dM_dx(m, k + 1) + / (k + 1) + for m in range(1, self._msplines.n + 1) + ] + ) + assert self._sum_terms_dI_dx_val.shape == (self._msplines.n, len(self.x)) + return self._sum_terms_dI_dx_val + + def _calculate_I_or_dI(self, i, quantity): + """Calculate :meth:`Isplines.I` or :meth:`Isplines.dI_dx`. + + Parameters + ---------- + i : int + Same meaning as for :meth:`Isplines.I`. + quantity : {'I', 'dI'} + Calculate :meth:`Isplines.I` or :meth:`Isplines.dI_dx`? + + Returns + ------- + np.ndarray + The return value of :meth:`Isplines.I` or :meth:`Isplines.dI_dx`. + + Note + ---- + Most calculations for :meth:`Isplines.I` and :meth:`Isplines.dI_dx` + are the same, so this method implements both. + + """ + if quantity == "I": + sum_terms = self._sum_terms_I + i_lt_jminusk = 1.0 + elif quantity == "dI": + sum_terms = self._sum_terms_dI_dx + i_lt_jminusk = 0.0 + else: + raise ValueError(f"invalid `quantity` {quantity}") + + if not (1 <= i <= self.n): + raise ValueError(f"invalid spline member `i` of {i}") + + k = self.order + + # create `binary_terms` where entry (m - 1, x) is 1 if and only if + # the corresponding `sum_terms` entry is part of the sum. 
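+        # (that is, the term for index m is included exactly when i + 1 <= m <= j,
+        # matching the summation limits in the I-spline formula quoted in `Isplines.I`)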
+ binary_terms = np.vstack( + [ + np.zeros(len(self.x)) if m < i + 1 else (m <= self.j).astype(int) + for m in range(1, self._msplines.n + 1) + ] + ) + assert binary_terms.shape == sum_terms.shape + + # compute sums from `sum_terms` and `binary_terms` + sums = np.sum(sum_terms * binary_terms, axis=0) + assert sums.shape == self.x.shape + + # return value with sums, 0, or 1 + res = np.where(i > self.j, 0.0, np.where(i < self.j - k, i_lt_jminusk, sums)) + res.flags.writeable = False + return res + + def dI_dx(self, i): + r"""Derivative of :meth:`Isplines.I` by :attr:`Isplines.x`. + + Parameters + ---------- + i : int + Same meaning as for :meth:`Isplines.I`. + + Returns + ------- + np.ndarray + Derivative of I-spline with respect to :attr:`Isplines.x`. + + Note + ---- + The derivative is calculated from the equation in :meth:`Isplines.I`: + + .. math:: + + \frac{\partial I_i\left(x\right)}{\partial x} + = + \begin{cases} + 0 & \rm{if\;} i > j \rm{\; or \;} i < j - k, \\ + \sum_{m=i+1}^j\left(t_{m+k+1} - t_m\right) + \frac{\partial M_m\left(x \mid k+1\right)}{\partial x} + \frac{1}{k + 1} + & \rm{otherwise}. + \end{cases} + + """ + args = (i, "dI") + if args not in self._cache: + if len(self._cache) > self._max_cache_size: + self._cache = {} + self._cache[args] = self._calculate_I_or_dI(*args) + return self._cache[args] + + +class Msplines: + r"""Implements M-splines (see `Ramsay (1988)`_). + + Parameters + ---------- + order : int + Sets :attr:`Msplines.order`. + mesh : array-like + Sets :attr:`Msplines.mesh`. + x : np.ndarray + Sets :attr:`Msplines.x`. + + Attributes + ---------- + order : int + Order of spline, :math:`k` in notation of `Ramsay (1988)`_. + Polynomials are of degree :math:`k - 1`. + mesh : np.ndarray + Mesh sequence, :math:`\xi_1 < \ldots < \xi_q` in the notation + of `Ramsay (1988)`_. This class implements **fixed** mesh sequences. + n : int + Number of members in spline, denoted as :math:`n` in `Ramsay (1988)`_. + Related to number of points :math:`q` in the mesh and the order + :math:`k` by :math:`n = q - 2 + k`. + knots : np.ndarray + The knot sequence, :math:`t_1, \ldots, t_{n + k}` in the notation of + `Ramsay (1988)`_. + lower : float + Lower end of interval spanned by the splines (first point in mesh). + upper : float + Upper end of interval spanned by the splines (last point in mesh). + + Note + ---- + The methods of this class cache their results and return immutable + numpy arrays. Do **not** make those arrays mutable and change their + values as this will lead to invalid caching. + + Example + ------- + Demonstrate and test :class:`Msplines`: + + .. plot:: + :context: reset + + >>> import functools + >>> import itertools + >>> import numpy as np + >>> import pandas as pd + >>> import scipy.optimize + >>> from dms_variants.ispline import Msplines + + >>> order = 3 + >>> mesh = [0.0, 0.3, 0.5, 0.6, 1.0] + >>> x = np.array([0, 0.2, 0.3, 0.4, 0.8, 0.99999]) + >>> msplines = Msplines(order, mesh, x) + >>> msplines.order + 3 + >>> msplines.mesh + array([0. , 0.3, 0.5, 0.6, 1. ]) + >>> msplines.n + 6 + >>> msplines.knots + array([0. , 0. , 0. , 0.3, 0.5, 0.6, 1. , 1. , 1. ]) + >>> msplines.lower + 0.0 + >>> msplines.upper + 1.0 + + Evaluate the M-splines at some selected points: + + >>> for i in range(1, msplines.n + 1): + ... print(f"M{i}: {np.round(msplines.M(i), 2)}") + ... # doctest: +NORMALIZE_WHITESPACE + M1: [10. 1.11 0. 0. 0. 0. ] + M2: [0. 3.73 2.4 0.6 0. 0. ] + M3: [0. 1.33 3. 3.67 0. 0. ] + M4: [0. 0. 0. 0.71 0.86 0. ] + M5: [0. 0. 0. 0. 3.3 0. ] + M6: [0. 0. 
0. 0. 1.88 7.5 ] + + Check that the gradients are correct: + + >>> for i, xval in itertools.product(range(1, msplines.n + 1), x): + ... xval = np.array([xval]) + ... def func(xval): + ... return Msplines(order, mesh, xval).M(i) + ... def dfunc(xval): + ... return Msplines(order, mesh, xval).dM_dx(i) + ... err = scipy.optimize.check_grad(func, dfunc, xval) + ... if err > 1e-5: + ... raise ValueError(f"excess err {err} for {i}, {xval}") + + Plot the M-splines in in Fig. 1 of `Ramsay (1988)`_: + + >>> xplot = np.linspace(0, 1, 1000, endpoint=False) + >>> msplines_plot = Msplines(order, mesh, xplot) + >>> data = {'x': xplot} + >>> for i in range(1, msplines_plot.n + 1): + ... data[f"M{i}"] = msplines_plot.M(i) + >>> df = pd.DataFrame(data) + >>> _ = df.plot(x='x') + + .. _`Ramsay (1988)`: https://www.jstor.org/stable/2245395 + + """ + + def __init__(self, order, mesh, x): + """See main class docstring.""" + if not (isinstance(order, int) and order >= 1): + raise ValueError(f"`order` not int >= 1: {order}") + self.order = order + + self.mesh = np.array(mesh, dtype="float") + if self.mesh.ndim != 1: + raise ValueError(f"`mesh` not array-like of dimension 1: {mesh}") + if len(self.mesh) < 2: + raise ValueError(f"`mesh` not length >= 2: {mesh}") + if not np.array_equal(self.mesh, np.unique(self.mesh)): + raise ValueError(f"`mesh` elements not unique and sorted: {mesh}") + self.lower = self.mesh[0] + self.upper = self.mesh[-1] + assert self.lower < self.upper + + self.knots = np.array( + [self.lower] * self.order + + list(self.mesh[1:-1]) + + [self.upper] * self.order, + dtype="float", + ) + + self.n = len(self.knots) - self.order + assert self.n == len(self.mesh) - 2 + self.order + + if not (isinstance(x, np.ndarray) and x.ndim == 1): + raise ValueError("`x` is not np.ndarray of dimension 1") + if (x < self.lower).any() or (x > self.upper).any(): + raise ValueError(f"`x` outside {self.lower} and {self.upper}: {x}") + self._x = x.copy() + self._x.flags.writeable = False + + self._ti_le_x_lt_tiplusk_cache = {} + + # for caching values + self._M_cache = {} + self._dM_cache = {} + self._max_cache_size = 100 + + def _ti_le_x_lt_tiplusk(self, ti, tiplusk): + r"""Indices where :math:`t_i \le x \le t_{i+k}`. + + Parameters + ---------- + ti : float + :math:`t_i` + tiplusk : float + :math:`t_{i+k}` + + Returns + ------- + np.ndarray + Array of booleans of same length as :attr:`Msplines.x` indicating + if :math:`t_i \le x \le t_{i+k}`. + + """ + key = (ti, tiplusk) + if key not in self._ti_le_x_lt_tiplusk_cache: + val = (ti <= self.x) & (self.x < tiplusk) + val.flags.writeable = False + assert val.dtype == bool + if len(self._ti_le_x_lt_tiplusk_cache) > self._max_cache_size: + self._ti_le_x_lt_tiplusk_cache = {} + self._ti_le_x_lt_tiplusk_cache[key] = val + return self._ti_le_x_lt_tiplusk_cache[key] + + @property + def x(self): + """np.ndarray: Points at which spline is evaluated.""" + return self._x + + def M(self, i, k=None, invalid_i="raise"): + r"""Evaluate spline :math:`M_i` at point(s) :attr:`Msplines.x`. + + Parameters + ---------- + i : int + Spline member :math:`M_i`, where :math:`1 \le i \le` + :attr:`Msplines.n`. + k : int or None + Order of spline. If `None`, assumed to be :attr:`Msplines.order`. + invalid_i : {'raise', 'zero'} + If `i` is invalid, do we raise an error or return 0? + + Returns + ------- + np.ndarray + The values of the M-spline at each point in :attr:`Msplines.x`. + + Note + ---- + The spline is evaluated using the recursive relationship given by + `Ramsay (1988) `_: + + .. 
math:: + + M_i\left(x \mid k=1\right) + &=& + \begin{cases} + 1 / \left(t_{i+1} - t_i\right), & \rm{if\;} t_i \le x < t_{i+1} \\ + 0, & \rm{otherwise} + \end{cases} \\ + M_i\left(x \mid k > 1\right) &=& + \begin{cases} + \frac{k\left[\left(x - t_i\right) M_i\left(x \mid k-1\right) + + \left(t_{i+k} -x\right) M_{i+1}\left(x \mid k-1\right) + \right]} + {\left(k - 1\right)\left(t_{i + k} - t_i\right)}, + & \rm{if\;} t_i \le x < t_{i+k} \\ + 0, & \rm{otherwise} + \end{cases} + + """ + args = (i, k, invalid_i) + if args not in self._M_cache: + if len(self._M_cache) > self._max_cache_size: + self._M_cache = {} + self._M_cache[args] = self._calculate_M(*args) + return self._M_cache[args] + + def _calculate_M(self, i, k, invalid_i): + """Calculate :meth:`Msplines.M` with result caching.""" + if not (1 <= i <= self.n): + if invalid_i == "raise": + raise ValueError(f"invalid spline member `i` of {i}") + elif invalid_i == "zero": + return 0 + else: + raise ValueError(f"invalid `invalid_i` of {invalid_i}") + if k is None: + k = self.order + if not 1 <= k <= self.order: + raise ValueError(f"invalid spline order `k` of {k}") + + tiplusk = self.knots[i + k - 1] + ti = self.knots[i - 1] + if tiplusk == ti: + return 0 + + boolindex = self._ti_le_x_lt_tiplusk(ti, tiplusk) + if k == 1: + res = np.where(boolindex, 1.0 / (tiplusk - ti), 0.0) + res.flags.writeable = False + return res + else: + assert k > 1 + res = np.where( + boolindex, + ( + k + * ( + (self.x - ti) * self.M(i, k - 1) + + (tiplusk - self.x) * self.M(i + 1, k - 1, invalid_i="zero") + ) + / ((k - 1) * (tiplusk - ti)) + ), + 0.0, + ) + res.flags.writeable = False + return res + + def dM_dx(self, i, k=None, invalid_i="raise"): + r"""Derivative of :meth:`Msplines.M` by to :attr:`Msplines.x`. + + Parameters + ---------- + i : int + Same as for :meth:`Msplines.M`. + k : int or None + Same as for :meth:`Msplines.M`. + invalid_i : {'raise', 'zero'} + Same as for :meth:`Msplines.M`. + + Returns + ------- + np.ndarray + Derivative of M-spline with respect to :attr:`Msplines.x`. + + Note + ---- + The derivative is calculated from the equation in :meth:`Msplines.M`: + + .. 
math:: + + \frac{\partial M_i\left(x \mid k=1\right)}{\partial x} &=& 0 + \\ + \frac{\partial M_i\left(x \mid k > 1\right)}{\partial x} + &=& + \begin{cases} + \frac{k\left[\left(x - t_i\right) + \frac{\partial M_i\left(x \mid k-1\right)}{\partial x} + + + M_i\left(x \mid k-1\right) + + + \left(t_{i+k} -x\right) + \frac{\partial M_{i+1}\left(x \mid k-1\right)} + {\partial x} + - + M_{i+1}\left(x \mid k-1\right) + \right]} + {\left(k - 1\right)\left(t_{i + k} - t_i\right)}, + & \rm{if\;} t_i \le x < t_{i+1} \\ + 0, & \rm{otherwise} + \end{cases} + + """ + args = (i, k, invalid_i) + if args not in self._dM_cache: + if len(self._dM_cache) > self._max_cache_size: + self._dM_cache = {} + self._dM_cache[args] = self._calculate_dM_dx(*args) + return self._dM_cache[args] + + def _calculate_dM_dx(self, i, k, invalid_i): + """Calculate :meth:`Msplines.dM_dx` with results caching.""" + if not (1 <= i <= self.n): + if invalid_i == "raise": + raise ValueError(f"invalid spline member `i` of {i}") + elif invalid_i == "zero": + return 0 + else: + raise ValueError(f"invalid `invalid_i` of {invalid_i}") + if k is None: + k = self.order + if not 1 <= k <= self.order: + raise ValueError(f"invalid spline order `k` of {k}") + + tiplusk = self.knots[i + k - 1] + ti = self.knots[i - 1] + if tiplusk == ti or k == 1: + return 0 + else: + assert k > 1 + boolindex = self._ti_le_x_lt_tiplusk(ti, tiplusk) + res = np.where( + boolindex, + ( + k + * ( + (self.x - ti) * self.dM_dx(i, k - 1) + + self.M(i, k - 1) + + (tiplusk - self.x) + * self.dM_dx(i + 1, k - 1, invalid_i="zero") + - self.M(i + 1, k - 1, invalid_i="zero") + ) + / ((k - 1) * (tiplusk - ti)) + ), + 0.0, + ) + res.flags.writeable = False + return res + + +if __name__ == "__main__": + import doctest + + doctest.testmod() + +# from scipy.interpolate import PchipInterpolator + +# def fit_cdf(x, y, x_predict, **kwargs): +# cdf_funct = PchipInterpolator(x,y, extrapolate=True) +# cdf = cdf_funct(x_predict) + +# return cdf, 0,0 + +# def get_pdf(cdf_grid, cdf, pdf_grid, **kwargs): + +# cdf_funct = PchipInterpolator(cdf_grid,cdf, extrapolate=True) +# pdf_func = cdf_funct.derivative(1) +# pdf = pdf_func(pdf_grid) + +# return pdf, 0,0 diff --git a/src/rail/evaluation/metrics/condition_pit_utils/mlp_training.py b/src/rail/evaluation/metrics/condition_pit_utils/mlp_training.py new file mode 100644 index 00000000..0cf6fd93 --- /dev/null +++ b/src/rail/evaluation/metrics/condition_pit_utils/mlp_training.py @@ -0,0 +1,367 @@ +# Copyright (C) 2021,2022 Bitrateep Dey, University of Pittsburgh, USA + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader, TensorDataset +from src.rail.evaluation.metrics.condition_pit_utils.MonotonicNN import MonotonicNN +from prettytable import PrettyTable +from tqdm import trange + +use_amp = True # Flag to use automatic mixed precision + + +def count_parameters(model): + table = PrettyTable(["Modules", "Parameters"]) + total_params = 0 + for name, parameter in model.named_parameters(): + if not parameter.requires_grad: + continue + param = parameter.numel() + table.add_row([name, param]) + total_params += param + print(table) + print(f"Total Trainable Params: {total_params}") + return total_params + + +# Define an MLP model +class MLP(nn.Module): + def __init__(self, input_dim=6, output_dim=1, hidden_layers=[512, 512, 512]): + super().__init__() + self.all_layers = [input_dim] + self.all_layers.extend(hidden_layers) + self.all_layers.append(output_dim) + + self.layer_list = [] + for i in 
range(len(self.all_layers) - 1): + self.layer_list.append(nn.Linear(self.all_layers[i], self.all_layers[i + 1])) + self.layer_list.append(nn.PReLU()) + + self.layer_list.pop() + # self.layer_list.append( nn.Sigmoid()) + self.layers = nn.Sequential(*self.layer_list) + + def init_weights(m): + if isinstance(m, nn.Linear): + torch.nn.init.kaiming_normal_(m.weight) + m.bias.data.fill_(0.01) + + self.layers.apply(init_weights) + + def forward(self, x): + + return self.layers(x) + + +class EarlyStopping: + """Early stops the training if validation loss doesn't improve after a given patience.""" + + def __init__( + self, patience=7, verbose=False, delta=0, path="checkpoint.pt", trace_func=print + ): + """ + Args: + patience (int): How long to wait after last time validation loss improved. + Default: 7 + verbose (bool): If True, prints a message for each validation loss improvement. + Default: False + delta (float): Minimum change in the monitored quantity to qualify as an improvement. + Default: 0 + path (str): Path for the checkpoint to be saved to. + Default: 'checkpoint.pt' + trace_func (function): trace print function. + Default: print + """ + self.patience = patience + self.verbose = verbose + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.Inf + self.delta = delta + self.path = path + self.trace_func = trace_func + + def __call__(self, val_loss, model): + + score = -val_loss + + if self.best_score is None: + self.best_score = score + self.save_checkpoint(val_loss, model) + elif score < self.best_score + self.delta: + self.counter += 1 + self.trace_func( + f"EarlyStopping counter: {self.counter} out of {self.patience}" + ) + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + self.save_checkpoint(val_loss, model) + self.counter = 0 + + def save_checkpoint(self, val_loss, model): + """Saves model when validation loss decrease.""" + if self.verbose: + self.trace_func( + f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..." 
+ ) + torch.save(model.state_dict(), self.path) + self.val_loss_min = val_loss + + +class RandomDataset(Dataset): + def __init__(self, X, Y, oversample=1): + self.X = X + self.Y = Y + self.len_x = len(X) + self.oversample = oversample + + def __len__(self): + return int(len(self.X) * self.oversample) + + def __getitem__(self, idx): + alpha = torch.rand(1) + feature = torch.hstack((alpha, torch.Tensor(self.X[idx % self.len_x]))) + target = (self.Y[idx % self.len_x] <= alpha).float() + return feature, target + + +def load_model(input_size, hidden_layers, checkpt_path, nn_type="monotonic", sigmoid=False, gpu_id="cuda:0"): + # Use gpu if available + device = torch.device(gpu_id if torch.cuda.is_available() else "cpu") + if nn_type == "mlp": + rhat = MLP(input_size, 1, hidden_layers).to(device) + if nn_type == "monotonic": + rhat = MonotonicNN(input_size, hidden_layers, nb_steps=200, dev=device, sigmoid=sigmoid).to(device) + + count_parameters(rhat) + rhat.load_state_dict(torch.load(checkpt_path)) + + return rhat + + +def train_local_pit(X, + pit_values, + patience=20, + n_epochs=1000, + lr=0.001, + weight_decay=1e-5, + batch_size=2048, + frac_mlp_train=0.9, + lr_decay=0.99, + trace_func=print, + oversample=1, + n_alpha=201, + checkpt_path="./checkpoint_.pt", + nn_type="monotonic", + hidden_layers=[512, 512, 512], + gpu_id="cuda:0", + sigmoid=False): + _EPSILON = 0.01 + # Use gpu if available + device = torch.device(gpu_id if torch.cuda.is_available() else "cpu") + # alpha grid for validation set + alphas_grid = np.linspace(0.001, 0.999, n_alpha) + + # Split into train and valid sets + train_size = int(frac_mlp_train * len(X)) + valid_size = len(X) - train_size + + rnd_idx = np.random.default_rng().permutation(len(X)) + + x_train_rnd = X[rnd_idx[:train_size]] + x_val_rnd = X[rnd_idx[train_size:]] + + pit_train_rand = pit_values[rnd_idx[:train_size]] + pit_val_rand = pit_values[rnd_idx[train_size:]] + + # Creat randomized Data set for training + trainset = RandomDataset(x_train_rnd, pit_train_rand, oversample=oversample) + + # Create static dataset for testing + feature_val = torch.cat( + [ + torch.Tensor(np.repeat(alphas_grid, len(x_val_rnd)))[:, None], + torch.Tensor(np.tile(x_val_rnd, (len(alphas_grid), 1))), + ], + dim=-1, + ) + target_val = torch.Tensor( + np.tile(pit_val_rand, len(alphas_grid)) + <= np.repeat(alphas_grid, len(x_val_rnd)) + ).float()[:, None] + + validset = TensorDataset(feature_val, target_val) + + # Create Data loader + train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True) + valid_dataloader = DataLoader(validset, batch_size=batch_size, shuffle=False) + + # Initialize the Model and optimizer, etc. 
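+    # Per-epoch history buffers: training loss, validation MSE / BCE / weighted MSE,
+    # and a calibration loss (reported as valid_ece below) computed on the fixed grid
+    # of n_alpha coverage levels (the alpha grid).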
+    training_loss = []
+    validation_mse = []
+    validation_bce = []
+    validation_weighted_mse = []
+    cal_loss = []
+
+    input_size = X.shape[1] + 1
+    if nn_type == "mlp":
+        rhat = MLP(input_size, 1, hidden_layers).to(device)
+    if nn_type == "monotonic":
+        rhat = MonotonicNN(input_size, hidden_layers, nb_steps=200, dev=device, sigmoid=sigmoid).to(device)
+
+    count_parameters(rhat)
+    # Optimizer
+    optimizer = torch.optim.AdamW(rhat.parameters(), lr=lr, weight_decay=weight_decay)
+    # optimizer = torch.optim.SGD(rhat.parameters(), lr=lr)
+    # Use lr decay
+    schedule_rule = lambda epoch: lr_decay ** epoch
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule_rule)
+
+    # Cosine annealing
+    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)
+
+    early_stopping = EarlyStopping(
+        patience=patience, verbose=True, path=checkpt_path, trace_func=trace_func
+    )
+
+    # Training loop
+    for epoch in range(1, n_epochs + 1):
+        training_loss_batch = []
+        validation_mse_batch = []
+        validation_bce_batch = []
+        validation_weighted_mse_batch = []
+        alpha_arr = []
+        out_arr = []
+        target_arr = []
+
+        # Training
+        rhat.train()  # prep model for training
+        for batch, (feature, target) in enumerate(train_dataloader, start=1):
+            feature = feature.to(device)  # .requires_grad_()
+            target = target.to(device)  # .requires_grad_()
+            alpha = feature[:, 0]
+            # clear the gradients of all optimized variables
+            output = rhat(feature.float())
+            loss = ((output - target.float()) ** 2).sum()
+
+            # loss = (torch.squeeze((output - target.float()) ** 2)/((_EPSILON + alpha)*(_EPSILON + 1.0 - alpha))).sum()
+
+            # loss_fn = torch.nn.BCELoss(reduction='sum')
+            # loss = loss_fn(torch.clamp(torch.squeeze(output), min=0.0, max=1.0), torch.squeeze(target))
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # record training loss
+            training_loss_batch.append(loss.item())
+        # Validation
+        rhat.eval()  # prep model for evaluation
+
+        for feature, target in valid_dataloader:
+            feature = feature.to(device)
+            target = target.to(device)
+            alpha = feature[:, 0]
+
+            # forward pass: compute predicted outputs by passing inputs to the model
+            output = rhat(feature.float())
+
+            # calculate the loss
+            mse = ((output - target.float()) ** 2).sum()
+            # record validation loss
+            validation_mse_batch.append(mse.item())
+
+            weighted_mse = (torch.squeeze((output - target.float()) ** 2) / (
+                (_EPSILON + alpha) * (_EPSILON + 1.0 - alpha))).sum()
+            validation_weighted_mse_batch.append(weighted_mse.item())
+
+            criterion = torch.nn.BCELoss(reduction='sum')
+            bce = criterion(torch.clamp(torch.squeeze(output), min=1e-6, max=0.9999999), torch.squeeze(target))
+            validation_bce_batch.append(bce.item())
+
+            alpha_arr.extend(alpha.tolist())
+            out_arr.extend(torch.squeeze(output).tolist())
+            target_arr.extend(torch.squeeze(target).tolist())
+
+        out_arr = (np.array(out_arr) <= np.array(alpha_arr))
+        out_arr = np.array(out_arr).reshape(-1, n_alpha)
+        target_arr = np.array(target_arr).reshape(-1, n_alpha)
+        cal_loss_epoch = np.mean((np.mean(out_arr, axis=0) - np.mean(target_arr)) ** 2)
+        # print training/validation statistics
+        # calculate average loss over an epoch
+        train_loss_epoch = np.sum(training_loss_batch) / (train_size * oversample)
+        valid_mse_epoch = np.sum(validation_mse_batch) / (valid_size * n_alpha)
+        valid_bce_epoch = np.sum(validation_bce_batch) / (valid_size * n_alpha)
+        valid_weighted_mse_epoch = np.sum(validation_weighted_mse_batch) / (valid_size * n_alpha)
+
+        training_loss.append(train_loss_epoch)
+        validation_mse.append(valid_mse_epoch)
+        validation_bce.append(valid_bce_epoch)
+        validation_weighted_mse.append(valid_weighted_mse_epoch)
+        cal_loss.append(cal_loss_epoch)
+
+        epoch_len = len(str(n_epochs))
+
+        msg = (
+            f"[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] | "
+            + f"train_loss: {train_loss_epoch:.5f} |\n"
+            + f"valid_ece: {cal_loss_epoch:.5f} | "
+            + f"valid_mse: {valid_mse_epoch:.5f} | "
+            + f"valid_wght_mse: {valid_weighted_mse_epoch:.5f} | "
+            + f"valid_bce: {valid_bce_epoch:.5f} | "
+        )
+
+        trace_func(msg)
+
+        # change the lr
+        scheduler.step()
+
+        # early_stopping needs the validation loss to check if it has decreased,
+        # and if it has, it will make a checkpoint of the current model
+        early_stopping(valid_mse_epoch, rhat)
+        # early_stopping(cal_loss_epoch, rhat)
+
+        if early_stopping.early_stop:
+            print("Early stopping")
+            break
+
+    # load the last checkpoint with the best model
+    rhat.load_state_dict(torch.load(checkpt_path))
+    return rhat, training_loss, (validation_mse, validation_bce)
+
+
+def get_local_pit(rhat, x_test, alphas, batch_size=1, gpu_id="cuda:0"):
+    # Use gpu if available
+
+    device = torch.device(gpu_id if torch.cuda.is_available() else "cpu")
+    n_test = len(x_test)
+    all_betas = []
+    rhat.to(device)
+    n_alpha = len(alphas)
+    n_batches = (n_test - 1) // batch_size + 1
+    for i in trange(n_batches):
+        x = x_test[i * batch_size: (i + 1) * batch_size]
+        with torch.no_grad():
+            all_betas_batch = rhat(
+                torch.Tensor(np.hstack([np.repeat(alphas, len(x))[:, None], np.tile(x, (n_alpha, 1))])).to(device)
+            ).detach().cpu().numpy().reshape(n_alpha, -1).T
+
+            all_betas_batch[all_betas_batch < 0] = 0
+            all_betas_batch[all_betas_batch > 1] = 1
+            all_betas.extend(all_betas_batch)
+        # print(f"Batch: {i}/{n_test}", end="\r")
+
+    return np.array(all_betas)
+
+
+def trapz_grid(y, x):
+    """
+    Does trapezoid integration between the same limits as the grid.
+    """
+    dx = np.diff(x)
+    trapz_area = dx * (y[:, 1:] + y[:, :-1]) / 2
+    integral = np.cumsum(trapz_area, axis=-1)
+    return np.hstack((np.zeros(len(integral))[:, None], integral))
diff --git a/src/rail/evaluation/metrics/condition_pit_utils/utils.py b/src/rail/evaluation/metrics/condition_pit_utils/utils.py
new file mode 100644
index 00000000..4bf3947b
--- /dev/null
+++ b/src/rail/evaluation/metrics/condition_pit_utils/utils.py
@@ -0,0 +1,210 @@
+from scipy.stats import binom
+from matplotlib import pyplot as plt
+import numpy as np
+
+
+def normalize(cde_estimates, x_grid, tol=1e-6, max_iter=200):
+    """Normalizes conditional density estimates to be non-negative and
+    integrate to one.
+
+    :param cde_estimates: a numpy array or matrix of conditional density estimates.
+    :param x_grid: a numpy array, the grid points at which the estimates are evaluated.
+    :param tol: float, the tolerance to accept for abs(area - 1).
+    :param max_iter: int, the maximal number of search iterations.
+    :returns: the normalized conditional density estimates.
+    :rtype: numpy array or matrix.
+
+    """
+    if cde_estimates.ndim == 1:
+        normalized_cde = _normalize(cde_estimates, x_grid, tol, max_iter)
+    else:
+        normalized_cde = np.apply_along_axis(_normalize, 1, cde_estimates, x_grid, tol=tol, max_iter=max_iter)
+    return normalized_cde
+
+
+def _normalize(density, x_grid, tol=1e-6, max_iter=500):
+    """Normalizes a density estimate to be non-negative and integrate to
+    one.
+
+    :param density: a numpy array of density estimates.
+    :param x_grid: an array, the grid points at which the density is estimated.
+    :param tol: float, the tolerance to accept for abs(area - 1).
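A quick, illustrative sanity check on `trapz_grid` (not part of the module): because it is a cumulative trapezoid rule, the last column of its output should agree with `np.trapz` over the full grid. The PDF below is made up.

import numpy as np

z = np.linspace(0.0, 3.0, 301)
pdf = np.exp(-0.5 * ((z - 1.2) / 0.1) ** 2)          # made-up, unnormalised PDF
pdf = (pdf / np.trapz(pdf, z))[None, :]              # shape (1, n_grid), normalised

cdf = trapz_grid(pdf, z)                             # per-object CDF on the same grid
assert np.isclose(cdf[0, -1], np.trapz(pdf[0], z))   # endpoint equals the full integral
assert np.isclose(cdf[0, -1], 1.0)                   # ~1 after normalisation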
+ :param max_iter: int, the maximal number of search iterations. + :returns: the normalized density estimate. + :rtype: numpy array. + + """ + hi = np.max(density) + lo = 0.0 + + area = np.trapz(np.maximum(density, 0.0), x_grid) + if area == 0.0: + # replace with uniform if all negative density + density[:] = 1 / (x_grid.max() - x_grid.min()) + elif area < 1: + density /= area + density[density < 0.0] = 0.0 + return density + + for _ in range(max_iter): + mid = (hi + lo) / 2 + area = np.trapz(np.maximum(density - mid, 0.0), x_grid) + if abs(1.0 - area) <= tol: + break + if area < 1.0: + hi = mid + else: + lo = mid + + # update in place + density -= mid + density[density < 0.0] = 0.0 + + return density + + +def kolmogorov_smirnov_statistic(cdf_test, cdf_ref): + """ + cdf_test: CDF of the test distribution (array) + cdf_ref: CDF of the reference distribution on the same grid (array) + """ + ks = np.max(np.abs(cdf_test - cdf_ref), axis=-1) + + return ks + + +def cramer_von_mises(cdf_test, cdf_ref): + """ + cdf_test: CDF of the test distribution (1D array) + cdf_ref: CDF of the reference distribution on the same grid (1D array) + """ + diff = (cdf_test - cdf_ref) ** 2 + + cvm2 = np.trapz(diff, cdf_ref, axis=-1) + return np.sqrt(cvm2) + + +def anderson_darling_statistic(cdf_test, cdf_ref, n_tot=1): + """ + cdf_test: CDF of the test distribution (1D array) + cdf_ref: CDF of the reference distribution on the same grid (1D array) + n_tot:Scaling factor equal to the number of PDFs used to construct ECDF + """ + num = (cdf_test - cdf_ref) ** 2 + den = cdf_ref * (1 - cdf_ref) + + ad2 = n_tot * np.trapz((num / den), cdf_ref, axis=-1) + return np.sqrt(ad2) + + +def get_pit(cdes: np.ndarray, z_grid: np.ndarray, z_test: np.ndarray) -> np.ndarray: + """ + Calculates PIT based on CDE + + cdes: a numpy array of conditional density estimates; + each row corresponds to an observation, each column corresponds to a grid + point + z_grid: a numpy array of the grid points at which cde_estimates is evaluated + z_test: a numpy array of the true z values corresponding to the rows of cde_estimates + + returns: A numpy array of values + + """ + # flatten the input arrays to 1D + z_grid = np.ravel(z_grid) + z_test = np.ravel(z_test) + + # Sanity checks + nrow_cde, ncol_cde = cdes.shape + n_samples = z_test.shape[0] + n_grid_points = z_grid.shape[0] + + if nrow_cde != n_samples: + raise ValueError( + "Number of samples in CDEs should be the same as in z_test." + "Currently %s and %s." % (nrow_cde, n_samples) + ) + if ncol_cde != n_grid_points: + raise ValueError( + "Number of grid points in CDEs should be the same as in z_grid." + "Currently %s and %s." 
% (nrow_cde, n_grid_points) + ) + + z_min = np.min(z_grid) + z_max = np.max(z_grid) + z_delta = (z_max - z_min) / (n_grid_points - 1) + + # Vectorized implementation using masked arrays + pit = np.ma.masked_array(cdes, (z_grid > z_test[:, np.newaxis])) + pit = np.trapz(pit, z_grid) + + return np.array(pit) + + +def plot_pit(pit_values, ci_level, n_bins=30, y_true=None, ax=None, **fig_kw): + """ + Plots the PIT/HPD histogram and calculates the confidence interval for the bin values, were the PIT/HPD values follow an uniform distribution + + @param values: a numpy array with PIT/HPD values + @param ci_level: a float between 0 and 1 indicating the size of the confidence level + @param x_label: a string, populates the x_label of the plot + @param n_bins: an integer, the number of bins in the histogram + @param figsize: a tuple, the plot size (width, height) + @param ylim: a list of two elements, including the lower and upper limit for the y axis + + @returns The matplotlib figure object with the histogram of the PIT/HPD values and the CI for the uniform distribution + """ + + # Extract the number of CDEs + n = pit_values.shape[0] + + # Creating upper and lower limit for selected uniform band + ci_quantity = (1 - ci_level) / 2 + low_lim = binom.ppf(q=ci_quantity, n=n, p=1 / n_bins) + upp_lim = binom.ppf(q=ci_level + ci_quantity, n=n, p=1 / n_bins) + + # Creating figure + + if ax is None: + fig, ax = plt.subplots(1, 2, **fig_kw) + + # plot PIT histogram + ax[0].hist(pit_values, bins=n_bins) + ax[0].axhline(y=low_lim, color="grey") + ax[0].axhline(y=upp_lim, color="grey") + ax[0].axhline(y=n / n_bins, label="Uniform Average", color="red") + ax[0].fill_between( + x=np.linspace(0, 1, 100), + y1=np.repeat(low_lim, 100), + y2=np.repeat(upp_lim, 100), + color="grey", + alpha=0.2, + ) + ax[0].set_xlabel("PIT Values") + ax[0].legend(loc="best") + + # plot P-P plot + prob_theory = np.linspace(0.01, 0.99, 100) + prob_data = [np.sum(pit_values < i) / len(pit_values) for i in prob_theory] + # # plot Q-Q + # quants = np.linspace(0, 100, 100) + # quant_theory = quants/100. 
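A minimal, illustrative check of `get_pit` (not part of the module): for a flat conditional density on [0, 1], the PIT of an object is simply its true redshift. The arrays are made up.

import numpy as np

z_grid_demo = np.linspace(0.0, 1.0, 101)
cdes_demo = np.ones((2, z_grid_demo.size))     # two made-up, uniform CDEs on [0, 1]
z_true_demo = np.array([0.25, 0.8])

pits = get_pit(cdes_demo, z_grid_demo, z_true_demo)
print(pits)                                    # approximately [0.25, 0.8]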
+ # quant_data = np.percentile(pit_values,quants) + + ax[1].scatter(prob_theory, prob_data, marker=".") + ax[1].plot(prob_theory, prob_theory, c="k", ls="--") + ax[1].set_xlim(0, 1) + ax[1].set_ylim(0, 1) + ax[1].set_xlabel("Expected Cumulative Probability") + ax[1].set_ylabel("Empirical Cumulative Probability") + xlabels = np.linspace(0, 1, 6)[1:] + ax[1].set_xticks(xlabels) + ax[1].set_aspect("equal") + if y_true is not None: + ks = kolmogorov_smirnov_statistic(prob_data, prob_theory) + ad = anderson_darling_statistic(prob_data, prob_theory, len(y_true)) + cvm = cramer_von_mises(prob_data, prob_theory) + ax[1].text(0.05, 0.9, f"KS: ${ks:.3f} $", fontsize=15) + ax[1].text(0.05, 0.84, f"CvM: ${cvm:.3f} $", fontsize=15) + ax[1].text(0.05, 0.78, f"AD: ${ad:.2f} $", fontsize=15) + + return fig, ax \ No newline at end of file diff --git a/src/rail/evaluation/metrics/pit.py b/src/rail/evaluation/metrics/pit.py new file mode 100644 index 00000000..c9ea0d70 --- /dev/null +++ b/src/rail/evaluation/metrics/pit.py @@ -0,0 +1,418 @@ +import inspect +import numpy as np +from scipy import stats +import qp +from deprecated import deprecated +from .base import MetricEvaluator +from rail.evaluation.utils import stat_and_pval, stat_crit_sig +from sklearn.preprocessing import StandardScaler +from src.rail.evaluation.metrics.condition_pit_utils.mlp_training import train_local_pit, load_model, get_local_pit, trapz_grid +from joblib import Parallel, delayed +from src.rail.evaluation.metrics.condition_pit_utils.ispline import fit_cdf +from src.rail.evaluation.metrics.condition_pit_utils.utils import get_pit +from tqdm import trange +import matplotlib.pyplot as plt + + +default_quants = np.linspace(0, 1, 100) +_pitMetaMetrics = {} + + +def PITMetaMetric(cls): + """Decorator function to attach metrics to a class""" + argspec = inspect.getargspec(cls.evaluate) + if argspec.defaults is not None: + num_defaults = len(argspec.defaults) + kwargs = dict(zip(argspec.args[-num_defaults:], argspec.defaults)) + _pitMetaMetrics.setdefault(cls, {})["default"] = kwargs + return cls + +@deprecated( + reason=""" + This implementation of PIT is deprecated. + Please use qp.metrics.pit.PIT(qp_ens, z_true, quant_grid) from the qp package. + """, + category=DeprecationWarning) +class PIT(MetricEvaluator): + """ Probability Integral Transform """ + + def __init__(self, qp_ens, ztrue): + """Class constructor""" + super().__init__(qp_ens) + + self._ztrue = ztrue + self._pit_samps = np.array([self._qp_ens[i].cdf(self._ztrue[i])[0][0] for i in range(len(self._ztrue))]) + + + @property + def pit_samps(self): + """Return the samples used to compute the PIT""" + return self._pit_samps + + + def evaluate(self, eval_grid=default_quants, meta_options=_pitMetaMetrics): + """Compute PIT array using qp.Ensemble class + Notes + ----- + We will create a quantile Ensemble to store the PIT distribution, but also store the + full set of PIT values as ancillary data of the (single PDF) ensemble. I think + the current metrics do not actually need the distribution, but we'll keep it here + in case future PIT metrics need to make use of it. 
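For orientation, this sketch reproduces the quantity the deprecated `PIT` class stores in `_pit_samps`: the CDF of each per-object distribution evaluated at its true redshift. It only uses `qp` calls that already appear in this diff (`qp.Ensemble(qp.interp, ...)` and per-object `.cdf`); the grid, PDFs and redshifts are made up, and for new code the deprecation notice above recommends `qp.metrics.pit.PIT` instead.

import numpy as np
import qp

z_grid_demo = np.linspace(0.0, 3.0, 301)
centers = np.array([[0.5], [1.0]])
yvals_demo = np.exp(-0.5 * ((z_grid_demo[None, :] - centers) / 0.1) ** 2)   # made-up PDFs
ens_demo = qp.Ensemble(qp.interp, data=dict(xvals=z_grid_demo, yvals=yvals_demo))
z_true_demo = np.array([0.55, 0.9])

# Same quantity as _pit_samps above: CDF evaluated at the true redshift, one value per object
pit_samps_demo = np.array([ens_demo[i].cdf(z_true_demo[i])[0][0] for i in range(len(z_true_demo))])
print(pit_samps_demo)    # each value lies in [0, 1]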
+ """ + n_pit = np.min([len(self._pit_samps), len(eval_grid)]) + if n_pit < len(eval_grid): #pragma: no cover + eval_grid = np.linspace(0, 1, n_pit) + data_quants = np.quantile(self._pit_samps, eval_grid) + pit = qp.Ensemble(qp.quant_piecewise, data=dict(quants=eval_grid, locs=np.atleast_2d(data_quants))) + + #pit = qp.spline_from_samples(xvals=eval_grid, + # samples=np.atleast_2d(self._pit_samps)) + #pit.samples = self._pit_samps + + if meta_options is not None: + metamets = {} + for cls, params in meta_options.items(): + meta = cls(self._pit_samps, pit) + for name, kwargs in params.items(): + metamets[(cls, name)] = meta.evaluate(**kwargs) + + # self.qq = _evaluate_qq(self._eval_grid) + return pit, metamets + + # def _evaluate_qq(self): + # q_data = qp.convert(self._pit, 'quant', quants=self._eval_grid) + # return q_data + # + # def plot_all_pit(self): + # plot_utils.ks_plot(self) + # + # + # def plot_pit_qq(self, bins=None, code=None, title=None, show_pit=True, + # show_qq=True, pit_out_rate=None, savefig=False): + # """Make plot PIT-QQ as Figure 2 from Schmidt et al. 2020.""" + # fig_filename = plot_utils.plot_pit_qq(self, bins=bins, code=code, title=title, + # show_pit=show_pit, show_qq=show_qq, + # pit_out_rate=pit_out_rate, + # savefig=savefig) + # return fig_filename + + +class ConditionPIT(MetricEvaluator): + #def __init__(self, qp_ens_cde_calib, qp_ens_cde_test, z_grid, z_true_calib, z_true_test, + # features_calib, features_test): + def __init__(self, cde_calib, cde_test, z_grid, z_true_calib, z_true_test, + features_calib, features_test, qp_ens_cde_calib): + """ + + Parameters + ---------- + cde_calib + cde_test + z_grid + z_true_calib + z_true_test + features_calib + features_test + qp_ens_cde_calib + """ + + super().__init__(qp_ens_cde_calib) + + # cde conditional density estimate + # reading and storing + # input are PDFs evaluated with a photo=z method on representative sample of objects, + # features and ztrue (zspec for real data). + # calculate uncondition pit on the fly because it's input to training + # up to block 10 in bitrateep notebook goes outside, but the standard scaling will be done inside init + # because the scaler itself is needed for both training and evaluation + # pit_calib = get_pit(cde_calib, z_grid, z_calib) + # pit_test = get_pit(cde_test, z_grid, z_test) + # the train are going to be those galaxies that have spectroscopic redshifts, test the others + + # single leading underscore is indicates to the user of the class that the attribute should only be accessed + # by the class's internals (or perhaps those of a subclass) and that they need not directly access it + # and probably shouldn't modify it. when you import everything from the + # class you don't import objects whose name starts with an underscore + #self._qp_ens_cde_test = qp_ens_cde_test + self._cde_calib = cde_calib + self._cde_test = cde_test + self._zgrid = z_grid + self._ztrue_calib = z_true_calib + self._ztrue_test = z_true_test + self._features_calib = features_calib + self._features_test = features_test + + # now let's apply the standard scaler + scaler = StandardScaler() + self.x_calib = scaler.fit_transform(self._features_calib) # with or without the underscore? 
+ self.x_test = scaler.transform(self._features_test) + + # now let's do pit using Bitrateep utils get_pit + self.uncond_pit_calib = get_pit(cde_calib, z_grid, self._ztrue_calib) + self.uncond_pit_test = get_pit(cde_test, z_grid, self._ztrue_test) + + # now let's do pit using the unconditional pit coded above + # uncond_pit_calib_class = PIT(self._qp_ens, self._ztrue_calib) + # self.uncond_pit_calib = uncond_pit_calib_class.evaluate(eval_grid=self._zgrid) + # uncond_pit_test_class = PIT(self._qp_ens_cde_test, self._ztrue_test) + # self.uncond_pit_test = uncond_pit_test_class.evaluate(eval_grid=self._zgrid) + + def train(self, patience=10, n_epochs=10000, lr=0.001, weight_decay=0.01, batch_size=2048, frac_mlp_train=0.9, + lr_decay=0.95, oversample=50, n_alpha=201, checkpt_path="./checkpoint_GPZ_wide_CDE_1024x512x512.pt", + hidden_layers=None): + """ + + Parameters + ---------- + patience + n_epochs + lr + weight_decay + batch_size + frac_mlp_train + lr_decay + oversample + n_alpha + checkpt_path + hidden_layers + + Returns + ------- + + """ + + # training, hyperparameters need to be tuned + if hidden_layers is None: #pragma: no cover + hidden_layers = [256, 256, 256] + rhat, _, _ = train_local_pit(X=self.x_calib, pit_values=self.uncond_pit_calib, patience=patience, + n_epochs=n_epochs, lr=lr, weight_decay=weight_decay, batch_size=batch_size, + frac_mlp_train=frac_mlp_train, lr_decay=lr_decay, trace_func=print, + oversample=oversample, n_alpha=n_alpha, checkpt_path=checkpt_path, + hidden_layers=hidden_layers) + + def evaluate(self, eval_grid=default_quants, meta_options=None, model_checkpt_path='model_checkpt_path', + model_hidden_layers=None, nn_type='monotonic', batch_size=100, num_basis=40, + num_cores=1): + """ + + Parameters + ---------- + eval_grid + meta_options + model_checkpt_path + model_hidden_layers + nn_type + batch_size + num_basis + num_cores + + Returns + ------- + + """ + + # we just need the features X test since the model has been trained in the function train and we just need to + # run the model on the features to obtain directly the calibrated PDFs. 
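For context, a hypothetical end-to-end use of `ConditionPIT`, mirroring the (commented-out) calls in the unit test added at the bottom of this diff; it assumes `cde_calib`, `cde_test`, `z_grid`, `z_calib`, `z_test`, `cat_calib`, `cat_test`, `features` and `qp_ens_cde_calib` have been prepared as in that test, and the checkpoint path is a placeholder.

cond_pit = ConditionPIT(cde_calib, cde_test, z_grid, z_calib, z_test,
                        cat_calib[features].values, cat_test[features].values,
                        qp_ens_cde_calib)
cond_pit.train(patience=10, n_epochs=2, lr=0.001, weight_decay=0.01, batch_size=100,
               frac_mlp_train=0.9, lr_decay=0.95, oversample=50, n_alpha=201,
               checkpt_path="./checkpoint_demo.pt", hidden_layers=[2, 2, 2])
pit_local, pit_local_fit = cond_pit.evaluate(model_checkpt_path="./checkpoint_demo.pt",
                                             model_hidden_layers=[2, 2, 2], nn_type="monotonic",
                                             batch_size=100, num_basis=40, num_cores=1)
cond_pit.diagnostics(pit_local, pit_local_fit, "local_pp_plot.pdf")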
+ # get pit local and ispline fits + + if meta_options is None: #pragma: no cover + meta_options = _pitMetaMetrics + if model_hidden_layers is None: #pragma: no cover + model_hidden_layers = [1024, 512, 512] + + rhat = load_model(input_size=self.x_test.shape[1] + 1, hidden_layers=model_hidden_layers, + checkpt_path=model_checkpt_path, nn_type=nn_type) + self.alphas = np.linspace(0.0, 1, len(self._zgrid)) + pit_local = get_local_pit(rhat, self.x_test, alphas=self.alphas, batch_size=batch_size) + + self.cdf_test = trapz_grid(self._cde_test, self._zgrid) + self.cdf_test[self.cdf_test > 1] = 1 + + pit_local_fit, _, _ = zip(*Parallel(n_jobs=num_cores)( + delayed(fit_cdf)(self.alphas, pit_local[i, :], self.cdf_test[i, :], num_basis=num_basis) for i in + trange(len(pit_local)))) + + return pit_local, np.array(pit_local_fit) + + def diagnostics(self, pit_local, pit_local_fit, figure_filepath): + """ + + Parameters + ---------- + pit_local + pit_local_fit + figure_filepath + + Returns + ------- + + """ + + # P-P plot creation, not one for every galaxy but something clever + plt.clf() + rng = np.random.default_rng(42) + random_idx = rng.choice(len(self.x_test), 25, replace=False) + fig, axs = plt.subplots(5,5, figsize=(15, 15)) + axs = np.ravel(axs) + + for count, index in enumerate(random_idx): + axs[count].scatter(self.alphas, pit_local[index], s=1) + axs[count].scatter(self.cdf_test[index], pit_local_fit[index], c="C1") + axs[count].plot(self._zgrid, self.cdf_test[index], c="k") + axs[count].plot(np.linspace(0, 1, 10), np.linspace(0, 1, 10), color="k", ls="--") + axs[count].set_xlim(0, 1) + axs[count].set_ylim(0, 1) + axs[count].set_aspect("equal") + fig.suptitle("Local P-P plot", fontsize=30) + + fig.text(0.5,-0.05,"Theoretical P", fontsize=30) + fig.text(-0.05,0.5,"Empirical P", rotation=90, fontsize=30) + plt.tight_layout() + plt.savefig(figure_filepath) + plt.close() + + +@deprecated( + reason=""" + This class is deprecated. + It has been superseded by the qp.metrics.pit.PIT class in the qp-prob package. + """, + category=DeprecationWarning) +class PITMeta(): + """ A superclass for metrics of the PIT""" + + def __init__(self, pit_vals, pit): + """Class constructor. + Parameters + ---------- + pit: qp.spline_from_samples object + PIT values + """ + self._pit = pit + self._pit_vals = pit_vals + + # they all seem to have a way to trim the ends, so maybe just bring those together here? + # def _trim(self, pit_min=0., pit_max=1.): + # + + def evaluate(self): #pragma: no cover + """ + Evaluates the metric a function of the truth and prediction + + Returns + ------- + metric: dictionary + value of the metric and statistics thereof + """ + raise NotImplementedError + + +@deprecated( + reason=""" + This class is deprecated. + It has been incorporated into the qp.metrics.pit.PIT class in the qp-prob package. + """, + category=DeprecationWarning) +@PITMetaMetric +class PITOutRate(PITMeta): + """ Fraction of PIT outliers """ + + def evaluate(self, pit_min=0.0001, pit_max=0.9999): + """Compute fraction of PIT outliers""" + out_area = (self._pit.cdf(pit_min) + (1. - self._pit.cdf(pit_max)))[0][0] + return out_area + + +@deprecated( + reason=""" + This class is deprecated. + It has been incorporated into the qp.metrics.pit.PIT class in the qp-prob package. 
+ """, + category=DeprecationWarning) +@PITMetaMetric +class PITKS(PITMeta): + """ Kolmogorov-Smirnov test statistic """ + + def evaluate(self): + """ Use scipy.stats.kstest to compute the Kolmogorov-Smirnov test statistic for + the PIT values by comparing with a uniform distribution between 0 and 1. """ + stat, pval = stats.kstest(self._pit_vals, 'uniform') + return stat_and_pval(stat, pval) + +@deprecated( + reason=""" + This class is deprecated. + It has been incorporated into the qp.metrics.pit.PIT class in the qp-prob package. + """, + category=DeprecationWarning) +@PITMetaMetric +class PITCvM(PITMeta): + """ Cramer-von Mises statistic """ + + def evaluate(self): + """ Use scipy.stats.cramervonmises to compute the Cramer-von Mises statistic for + the PIT values by comparing with a uniform distribution between 0 and 1. """ + cvm_stat_and_pval = stats.cramervonmises(self._pit_vals, 'uniform') + return stat_and_pval(cvm_stat_and_pval.statistic, + cvm_stat_and_pval.pvalue) + +@deprecated( + reason=""" + This class is deprecated. + It has been incorporated into the qp.metrics.pit.PIT class in the qp-prob package. + """, + category=DeprecationWarning) +@PITMetaMetric +class PITAD(PITMeta): + """ Anderson-Darling statistic """ + + def evaluate(self, pit_min=0., pit_max=1.): + """ Use scipy.stats.anderson_ksamp to compute the Anderson-Darling statistic + for the PIT values by comparing with a uniform distribution between 0 and 1. + Up to the current version (1.6.2), scipy.stats.anderson does not support + uniform distributions as reference for 1-sample test. + + Parameters + ---------- + pit_min: float, optional + PIT values below this are discarded + pit_max: float, optional + PIT values greater than this are discarded + + Returns + ------- + + """ + pits = self._pit_vals + mask = (pits >= pit_min) & (pits <= pit_max) + pits_clean = pits[mask] + diff = len(pits) - len(pits_clean) + if diff > 0: + print(f"{diff} PITs removed from the sample.") + uniform_yvals = np.linspace(pit_min, pit_max, len(pits_clean)) + ad_results = stats.anderson_ksamp([pits_clean, uniform_yvals]) + stat, crit_vals, sig_lev = ad_results + + return stat_crit_sig(stat, crit_vals, sig_lev) + +# comment out for now due to discrete approx +#@PITMetaMetric +#class PITKLD(PITMeta): +# """ Kullback-Leibler Divergence """ +# +# def __init__(self, pit_vals, pit): +# super().__init__(pit_vals, pit) +# +# def evaluate(self, eval_grid=default_quants): +# """ Use scipy.stats.entropy to compute the Kullback-Leibler +# Divergence between the empirical PIT distribution and a +# theoretical uniform distribution between 0 and 1.""" +# warnings.warn("This KLD implementation is based on scipy.stats.entropy, " + +# "therefore it uses a discrete distribution of PITs " + +# "(internally obtained from PIT object).") +# pits = self._pit_vals +# uniform_yvals = np.linspace(0., 1., len(pits)) +# pit_pdf, _ = np.histogram(pits, bins=eval_grid) +# uniform_pdf, _ = np.histogram(uniform_yvals, bins=eval_grid) +# kld_metric = stats.entropy(pit_pdf, uniform_pdf) +# return stat_and_pval(kld_metric, None) diff --git a/src/rail/examples_data/testdata/bpz_test_red.npz b/src/rail/examples_data/testdata/bpz_test_red.npz new file mode 100644 index 00000000..f1e7a2ab Binary files /dev/null and b/src/rail/examples_data/testdata/bpz_test_red.npz differ diff --git a/tests/evaluation/test_evaluation.py b/tests/evaluation/test_evaluation.py index 1f0571fc..f1375f65 100644 --- a/tests/evaluation/test_evaluation.py +++ b/tests/evaluation/test_evaluation.py @@ -1,7 
+1,17 @@ import os import numpy as np +from rail.core.stage import RailStage +from rail.core.data import QPHandle, TableHandle +from rail.evaluation.metrics.pit import PIT, PITOutRate, PITKS, PITCvM, PITAD +from src.rail.evaluation.metrics.pit import ConditionPIT +from rail.evaluation.metrics.cdeloss import CDELoss +import rail.evaluation.metrics.pointestimates as pe +from rail.evaluation.evaluator import Evaluator +from rail.core.utils import RAILDIR import qp +import pandas as pd +import subprocess import rail.evaluation.metrics.pointestimates as pe from rail.core.data import QPHandle, TableHandle @@ -18,9 +28,9 @@ CDEVAL = -4.31200 SIGIQR = 0.0045947 BIAS = -0.00001576 -OUTRATE = 0.0 SIGMAD = 0.0046489 +default_testdata_folder = os.path.join(RAILDIR, 'rail', 'examples', 'testdata') def construct_test_ensemble(): np.random.seed(87) @@ -43,6 +53,53 @@ def test_cdeloss_metric(): assert np.isclose(cde_stat, CDEVAL) +def test_condition_pit_metric(): + """ + Unit test for condition pit metric + + Returns + ------- + + """ + + data = np.load(os.path.join(default_testdata_folder, 'bpz_test_red.npz'), allow_pickle=True) + z_grid = data['z_grid'] + cat = pd.DataFrame(data["test_cat"]) + + cde = data["cde_test"] # conditional density estimate + norm = np.trapz(cde, z_grid) # normalize across the redshift grid + norm[norm == 0] = 1 + cde = cde / norm[:, None] + num_calib = 800 + SEED = 299792458 + n_gal = 1000 + rng = np.random.default_rng(SEED) + indices = rng.permutation(n_gal) # creating index permutation for splitting in train and test + cde_calib = cde[indices[:num_calib]] # splitting cde in training set + cde_test = cde[indices[num_calib:]] # and test set + z_calib = cat["SPECZ"][indices[:num_calib]].values + z_test = cat["SPECZ"][indices[num_calib:]].values + cat_calib = cat.iloc[indices[:num_calib]] + cat_test = cat.iloc[indices[num_calib:]] + features = ["I", "UG", "GR", "RI", "IZ", "ZY", "IZERR", "RIERR", "GRERR", "UGERR", "IERR", "ZYERR"] + + qp_ens_cde_calib = qp.Ensemble(qp.interp, data=dict(xvals=z_grid, yvals=cde_calib)) + cond_pit = ConditionPIT(cde_calib, cde_test, z_grid, z_calib, z_test, cat_calib[features].values, + cat_test[features].values, qp_ens_cde_calib) + # cond_pit.train(patience=10, n_epochs=2, lr=0.001, weight_decay=0.01, batch_size=100, frac_mlp_train=0.9, + # lr_decay=0.95, oversample=50, n_alpha=201, + # checkpt_path=os.path.join(default_testdata_folder, 'checkpoint_GPZ_wide_CDE_test.pt'), + # hidden_layers=[2, 2, 2]) + # pit_local, pit_local_fit = cond_pit.evaluate(model_checkpt_path=os.path.join(default_testdata_folder, + # 'checkpoint_GPZ_wide_CDE_test.pt'), + # model_hidden_layers=[2, 2, 2], nn_type='monotonic', + # batch_size=100, num_basis=40, num_cores=1) + # subprocess.run(['rm', os.path.join(default_testdata_folder, 'checkpoint_GPZ_wide_CDE_test.pt')]) + # cond_pit.diagnostics(pit_local, pit_local_fit, os.path.join(default_testdata_folder, 'local_pp_plot.pdf')) + # assert os.path.isfile(os.path.join(default_testdata_folder, 'local_pp_plot.pdf')) + assert len(cond_pit.x_test) != 0 + + def test_point_metrics(): zgrid, zspec, pdf_ens, true_ez = construct_test_ensemble() zb = pdf_ens.mode(grid=zgrid).flatten()