From df580b5531c70a7755e47e0705397e753daee5c2 Mon Sep 17 00:00:00 2001 From: zachrewolinski Date: Mon, 23 Dec 2024 17:08:43 -0800 Subject: [PATCH] added baseline of raw data --- .../subgroup/current/subgroup-debug.ipynb | 978 ++++++++++++++++++ .../subgroup/current/subgroup.py | 263 +++-- 2 files changed, 1155 insertions(+), 86 deletions(-) create mode 100644 feature_importance/subgroup/current/subgroup-debug.ipynb diff --git a/feature_importance/subgroup/current/subgroup-debug.ipynb b/feature_importance/subgroup/current/subgroup-debug.ipynb new file mode 100644 index 0000000..0caee94 --- /dev/null +++ b/feature_importance/subgroup/current/subgroup-debug.ipynb @@ -0,0 +1,978 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# standard data science packages\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "# imodels imports\n", + "from imodels.tree.rf_plus.rf_plus.rf_plus_models import \\\n", + " RandomForestPlusRegressor, RandomForestPlusClassifier\n", + "from imodels.tree.rf_plus.feature_importance.rfplus_explainer import \\\n", + " RFPlusMDI, AloRFPlusMDI\n", + "\n", + "# functions for subgroup experiments\n", + "from subgroup_detection import *\n", + "from subgroup_experiment import *\n", + "import shap\n", + "\n", + "# sklearn imports\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.linear_model import LogisticRegression, LinearRegression\n", + "from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, \\\n", + " accuracy_score, r2_score, f1_score, log_loss, root_mean_squared_error\n", + "\n", + "# pipeline imports\n", + "from subgroup import *" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# set inputs\n", + "seed = 1\n", + "dataids = [361247, 361243, 361242, 361251, 361253, 361260, 361259, 361256, 361254, 361622]\n", + "dataid = dataids[0]\n", + "clustertype = \"hierarchical\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1111007/2881731511.py:2: FutureWarning: Starting from Version 0.15.0 `download_splits` will default to ``False`` instead of ``True`` and be independent from `download_data`. To disable this message until version 0.15 explicitly set `download_splits` to a bool.\n", + " X, y = get_openml_data(dataid)\n", + "/scratch/users/zachrewolinski/conda/envs/mdi/lib/python3.10/site-packages/openml/tasks/functions.py:442: FutureWarning: Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`.\n", + " dataset = get_dataset(task.dataset_id, *dataset_args, **get_dataset_kwargs)\n", + "/scratch/users/zachrewolinski/conda/envs/mdi/lib/python3.10/site-packages/openml/tasks/task.py:150: FutureWarning: Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`.\n", + " return datasets.get_dataset(self.dataset_id)\n", + "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 34 tasks | elapsed: 11.8s\n", + "[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 14.3s finished\n", + "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 34 tasks | elapsed: 1.6min\n", + "[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 4.5min finished\n" + ] + } + ], + "source": [ + "# get data\n", + "X, y = get_openml_data(dataid)\n", + "\n", + "# split data\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3,\n", + " random_state=seed)\n", + "\n", + "# check if task is regression or classification\n", + "if len(np.unique(y)) == 2:\n", + " task = 'classification'\n", + "else:\n", + " task = 'regression'\n", + " \n", + "# fit the prediction models\n", + "rf, rf_plus_baseline, rf_plus = fit_models(X_train, y_train, task)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# obtain shap feature importances\n", + "shap_explainer = shap.TreeExplainer(rf)\n", + "shap_train_values, shap_train_rankings = get_shap(X_train, shap_explainer,\n", + " task)\n", + "shap_test_values, shap_test_rankings = get_shap(X_test, shap_explainer,\n", + " task)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# create list of lmdi variants\n", + "lmdi_variants = create_lmdi_variant_map()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# obtain lmdi feature importances\n", + "lmdi_explainers = get_lmdi_explainers(rf_plus, lmdi_variants,\n", + " rf_plus_baseline = rf_plus_baseline)\n", + "lfi_train_values, lfi_train_rankings = get_lmdi(X_train, y_train,\n", + " lmdi_variants,\n", + " lmdi_explainers)\n", + "lfi_test_values, lfi_test_rankings = get_lmdi(X_test, None,\n", + " lmdi_variants,\n", + " lmdi_explainers)\n", + "# add shap to the dictionaries\n", + "lfi_train_values[\"shap\"] = shap_train_values\n", + "lfi_train_rankings[\"shap\"] = shap_train_rankings\n", + "lfi_test_values[\"shap\"] = shap_test_values\n", + "lfi_test_rankings[\"shap\"] = shap_test_rankings\n", + "\n", + "# add the raw data to the dictionaries as a baseline of comparison\n", + "lfi_train_values[\"rawdata\"] = X_train\n", + "lfi_test_values[\"rawdata\"] = X_test" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# get the clusterings\n", + "# method_to_labels, method_to_indices = get_train_clusters(lfi_train_values, clustertype)\n", + "train_clusters = get_train_clusters(lfi_train_values, clustertype)\n", + "cluster_centroids = get_cluster_centroids(lfi_train_values, train_clusters)\n", + "test_clusters = get_test_clusters(lfi_test_values, cluster_centroids)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# compute the performance\n", + "metrics_to_scores = compute_performance(X_train, X_test, y_train, y_test,\n", + " train_clusters, test_clusters, task)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'r2': {'lmdi_baseline': {2: 0.6445649451638324,\n", + " 3: 0.8680766126738856,\n", + " 4: 0.913119037406979,\n", + " 5: -215.23091111985252,\n", + " 6: -215.91991426100282,\n", + " 7: -78.7535766025578,\n", + " 8: -84.01922175742995,\n", + " 9: -141.54733632707973,\n", + " 10: -149.47990644371674},\n", + " 'aloo_l2_signed_normed_leafavg_rank': {2: 0.9927271668877967,\n", + " 3: 0.998027875479179,\n", + " 4: 0.9980136251092391,\n", + " 5: 0.9980137595945218,\n", + " 6: 0.9980868009483398,\n", + " 7: 0.9989535457257578,\n", + " 8: 0.9989539218165313,\n", + " 9: 0.9990000203695893,\n", + " 10: 0.9994746285775445},\n", + " 'aloo_l2_signed_normed_leafavg_norank': {2: 0.9927271668877967,\n", + " 3: 0.998027875479179,\n", + " 4: 0.9980136251092391,\n", + " 5: 0.9980137595945218,\n", + " 6: 0.9980868009483398,\n", + " 7: 0.9989535457257578,\n", + " 8: 0.9989539218165313,\n", + " 9: 0.9990000203695893,\n", + " 10: 0.9994746285775445},\n", + " 'aloo_l2_signed_normed_noleafavg_rank': {2: 0.9927271668877967,\n", + " 3: 0.8129611156493682,\n", + " 4: 0.6971885071248699,\n", + " 5: 0.8822557273813058,\n", + " 6: 0.9980904410847861,\n", + " 7: 0.9993043660518036,\n", + " 8: 0.9993188781851715,\n", + " 9: 0.999347240392281,\n", + " 10: 0.9994359402433696},\n", + " 'aloo_l2_signed_normed_noleafavg_norank': {2: 0.9927271668877967,\n", + " 3: 0.8129611156493682,\n", + " 4: 0.6971885071248699,\n", + " 5: 0.8822557273813058,\n", + " 6: 0.9980904410847861,\n", + " 7: 0.9993043660518036,\n", + " 8: 0.9993188781851715,\n", + " 9: 0.999347240392281,\n", + " 10: 0.9994359402433696},\n", + " 'aloo_l2_signed_nonnormed_leafavg_rank': {2: -52.9816719619314,\n", + " 3: 0.9973996973346272,\n", + " 4: 0.9985780941826747,\n", + " 5: 0.9989663293158638,\n", + " 6: 0.9996827174346178,\n", + " 7: 0.999737262860902,\n", + " 8: 0.9997371956413068,\n", + " 9: 0.9997812873442385,\n", + " 10: 0.9998603303452446},\n", + " 'aloo_l2_signed_nonnormed_leafavg_norank': {2: -52.9816719619314,\n", + " 3: 0.9973996973346272,\n", + " 4: 0.9985780941826747,\n", + " 5: 0.9989663293158638,\n", + " 6: 0.9996827174346178,\n", + " 7: 0.999737262860902,\n", + " 8: 0.9997371956413068,\n", + " 9: 0.9997812873442385,\n", + " 10: 0.9998603303452446},\n", + " 'aloo_l2_signed_nonnormed_noleafavg_rank': {2: -29.26605305234625,\n", + " 3: 0.9973996973346272,\n", + " 4: 0.9985780941826747,\n", + " 5: 0.9966306361354462,\n", + " 6: 0.9996782292486582,\n", + " 7: 0.9957629106275575,\n", + " 8: 0.9958174560538414,\n", + " 9: 0.9958173703663622,\n", + " 10: 0.9953993015397371},\n", + " 'aloo_l2_signed_nonnormed_noleafavg_norank': {2: -29.26605305234625,\n", + " 3: 0.9973996973346272,\n", + " 4: 0.9985780941826747,\n", + " 5: 0.9966306361354462,\n", + " 6: 0.9996782292486582,\n", + " 7: 0.9957629106275575,\n", + " 8: 0.9958174560538414,\n", + " 9: 0.9958173703663622,\n", + " 10: 0.9953993015397371},\n", + " 'aloo_l2_unsigned_normed_leafavg_rank': {2: 0.8482029670826879,\n", + " 3: 0.8490268540525132,\n", + " 4: 0.4752099034178157,\n", + " 5: 0.7866454423884263,\n", + " 6: 0.8864053082688182,\n", + " 7: 0.9161437520446064,\n", + " 8: 0.9161429805289311,\n", + " 9: 0.9089617638902763,\n", + " 10: 0.9483759544174902},\n", + " 'aloo_l2_unsigned_normed_leafavg_norank': {2: 0.8482029670826879,\n", + " 3: 0.8490268540525132,\n", + " 4: 0.4752099034178157,\n", + " 5: 0.7866454423884263,\n", + " 6: 0.8864053082688182,\n", + " 7: 0.9161437520446064,\n", + " 8: 0.9161429805289311,\n", + " 9: 0.9089617638902763,\n", + " 10: 0.9483759544174902},\n", + " 'aloo_l2_unsigned_normed_noleafavg_rank': {2: -47713267.04342204,\n", + " 3: 0.867252100105942,\n", + " 4: 0.6655578674159756,\n", + " 5: 0.6345143125955007,\n", + " 6: 0.804234774948962,\n", + " 7: 0.8881409168059027,\n", + " 8: 0.9488394790547148,\n", + " 9: 0.9488377681214137,\n", + " 10: 0.7713817873735194},\n", + " 'aloo_l2_unsigned_normed_noleafavg_norank': {2: -47713267.04342204,\n", + " 3: 0.867252100105942,\n", + " 4: 0.6655578674159756,\n", + " 5: 0.6345143125955007,\n", + " 6: 0.804234774948962,\n", + " 7: 0.8881409168059027,\n", + " 8: 0.9488394790547148,\n", + " 9: 0.9488377681214137,\n", + " 10: 0.7713817873735194},\n", + " 'aloo_l2_unsigned_nonnormed_leafavg_rank': {2: 0.8745572555125906,\n", + " 3: 0.9962856146908532,\n", + " 4: 0.9983512181147747,\n", + " 5: 0.9991032706284847,\n", + " 6: 0.9991035789215647,\n", + " 7: 0.9996230204161064,\n", + " 8: 0.9994815946931114,\n", + " 9: 0.9995007296846626,\n", + " 10: 0.9994995502282262},\n", + " 'aloo_l2_unsigned_nonnormed_leafavg_norank': {2: 0.8745572555125906,\n", + " 3: 0.9962856146908532,\n", + " 4: 0.9983512181147747,\n", + " 5: 0.9991032706284847,\n", + " 6: 0.9991035789215647,\n", + " 7: 0.9996230204161064,\n", + " 8: 0.9994815946931114,\n", + " 9: 0.9995007296846626,\n", + " 10: 0.9994995502282262},\n", + " 'aloo_l2_unsigned_nonnormed_noleafavg_rank': {2: -658.8304129932814,\n", + " 3: 0.9962856146908532,\n", + " 4: 0.9985217433061586,\n", + " 5: 0.9550494786428747,\n", + " 6: 0.9550492096203371,\n", + " 7: 0.9551393110163726,\n", + " 8: 0.36180130458821824,\n", + " 9: 0.28500823784611395,\n", + " 10: 0.28500724039632985},\n", + " 'aloo_l2_unsigned_nonnormed_noleafavg_norank': {2: -658.8304129932814,\n", + " 3: 0.9962856146908532,\n", + " 4: 0.9985217433061586,\n", + " 5: 0.9550494786428747,\n", + " 6: 0.9550492096203371,\n", + " 7: 0.9551393110163726,\n", + " 8: 0.36180130458821824,\n", + " 9: 0.28500823784611395,\n", + " 10: 0.28500724039632985},\n", + " 'aloo_nonl2_unsigned_nonnormed_leafavg_rank': {2: 0.9927271668877967,\n", + " 3: -2.2008705734289133,\n", + " 4: 0.999276923103626,\n", + " 5: 0.9993201962730601,\n", + " 6: 0.9993389618038269,\n", + " 7: 0.9993530489270159,\n", + " 8: -75.07335140556614,\n", + " 9: -7.538142546851767,\n", + " 10: -7.53814242118957},\n", + " 'aloo_nonl2_unsigned_nonnormed_leafavg_norank': {2: 0.9927271668877967,\n", + " 3: -2.2008705734289133,\n", + " 4: 0.999276923103626,\n", + " 5: 0.9993201962730601,\n", + " 6: 0.9993389618038269,\n", + " 7: 0.9993530489270159,\n", + " 8: -75.07335140556614,\n", + " 9: -7.538142546851767,\n", + " 10: -7.53814242118957},\n", + " 'aloo_nonl2_unsigned_nonnormed_noleafavg_rank': {2: 0.9927271668877967,\n", + " 3: -2.22103085531022,\n", + " 4: 0.999276923103626,\n", + " 5: 0.9993201962730601,\n", + " 6: 0.9993389618038269,\n", + " 7: 0.9993530489270159,\n", + " 8: 0.999757159351784,\n", + " 9: 0.9911520759914881,\n", + " 10: 0.9912216418992041},\n", + " 'aloo_nonl2_unsigned_nonnormed_noleafavg_norank': {2: 0.9927271668877967,\n", + " 3: -2.22103085531022,\n", + " 4: 0.999276923103626,\n", + " 5: 0.9993201962730601,\n", + " 6: 0.9993389618038269,\n", + " 7: 0.9993530489270159,\n", + " 8: 0.999757159351784,\n", + " 9: 0.9911520759914881,\n", + " 10: 0.9912216418992041},\n", + " 'nonloo_l2_signed_normed_leafavg_rank': {2: 0.9927271668877967,\n", + " 3: 0.998027875479179,\n", + " 4: 0.9980136251092391,\n", + " 5: 0.9980137595945218,\n", + " 6: 0.9980669352513551,\n", + " 7: 0.9989461317620801,\n", + " 8: 0.9989590498768045,\n", + " 9: 0.9989995847211592,\n", + " 10: 0.9995338286777976},\n", + " 'nonloo_l2_signed_normed_leafavg_norank': {2: 0.9927271668877967,\n", + " 3: 0.998027875479179,\n", + " 4: 0.9980136251092391,\n", + " 5: 0.9980137595945218,\n", + " 6: 0.9980669352513551,\n", + " 7: 0.9989461317620801,\n", + " 8: 0.9989590498768045,\n", + " 9: 0.9989995847211592,\n", + " 10: 0.9995338286777976},\n", + " 'nonloo_l2_signed_normed_noleafavg_rank': {2: 0.9927271668877967,\n", + " 3: 0.8129611156493682,\n", + " 4: 0.6971885071248699,\n", + " 5: 0.8822557273813058,\n", + " 6: 0.9980906108267896,\n", + " 7: 0.9989931209611429,\n", + " 8: 0.9990100639855947,\n", + " 9: 0.9990366613191032,\n", + " 10: 0.9995749480788625},\n", + " 'nonloo_l2_signed_normed_noleafavg_norank': {2: 0.9927271668877967,\n", + " 3: 0.8129611156493682,\n", + " 4: 0.6971885071248699,\n", + " 5: 0.8822557273813058,\n", + " 6: 0.9980906108267896,\n", + " 7: 0.9989931209611429,\n", + " 8: 0.9990100639855947,\n", + " 9: 0.9990366613191032,\n", + " 10: 0.9995749480788625},\n", + " 'nonloo_l2_signed_nonnormed_leafavg_rank': {2: -52.9816719619314,\n", + " 3: 0.5756124996208882,\n", + " 4: -50.92942236079049,\n", + " 5: 0.9896764983573555,\n", + " 6: 0.9995400537972231,\n", + " 7: 0.9995945992235074,\n", + " 8: 0.9997769501779349,\n", + " 9: 0.9997764171304747,\n", + " 10: 0.9997768842758006},\n", + " 'nonloo_l2_signed_nonnormed_leafavg_norank': {2: -52.9816719619314,\n", + " 3: 0.5756124996208882,\n", + " 4: -50.92942236079049,\n", + " 5: 0.9896764983573555,\n", + " 6: 0.9995400537972231,\n", + " 7: 0.9995945992235074,\n", + " 8: 0.9997769501779349,\n", + " 9: 0.9997764171304747,\n", + " 10: 0.9997768842758006},\n", + " 'nonloo_l2_signed_nonnormed_noleafavg_rank': {2: -29.26605305234625,\n", + " 3: 0.9973996973346272,\n", + " 4: 0.9985780941826747,\n", + " 5: 0.9966306361354462,\n", + " 6: 0.9996782292486582,\n", + " 7: 0.9997688522669854,\n", + " 8: 0.9998233976932694,\n", + " 9: 0.9998229137977586,\n", + " 10: 0.9998551396737074},\n", + " 'nonloo_l2_signed_nonnormed_noleafavg_norank': {2: -29.26605305234625,\n", + " 3: 0.9973996973346272,\n", + " 4: 0.9985780941826747,\n", + " 5: 0.9966306361354462,\n", + " 6: 0.9996782292486582,\n", + " 7: 0.9997688522669854,\n", + " 8: 0.9998233976932694,\n", + " 9: 0.9998229137977586,\n", + " 10: 0.9998551396737074},\n", + " 'nonloo_l2_unsigned_normed_leafavg_rank': {2: 0.7722969912831866,\n", + " 3: 0.8780570274586403,\n", + " 4: 0.7502217286627606,\n", + " 5: 0.6267247355471606,\n", + " 6: 0.8640735533939706,\n", + " 7: 0.8699235557372856,\n", + " 8: 0.8699229538020387,\n", + " 9: 0.9097004235949722,\n", + " 10: 0.944465712417476},\n", + " 'nonloo_l2_unsigned_normed_leafavg_norank': {2: 0.7722969912831866,\n", + " 3: 0.8780570274586403,\n", + " 4: 0.7502217286627606,\n", + " 5: 0.6267247355471606,\n", + " 6: 0.8640735533939706,\n", + " 7: 0.8699235557372856,\n", + " 8: 0.8699229538020387,\n", + " 9: 0.9097004235949722,\n", + " 10: 0.944465712417476},\n", + " 'nonloo_l2_unsigned_normed_noleafavg_rank': {2: -4243.494895430736,\n", + " 3: -4275.535490057371,\n", + " 4: -3812.4902858534997,\n", + " 5: -3812.572032759011,\n", + " 6: -1848.282758661586,\n", + " 7: -577.0981130078727,\n", + " 8: 0.9513335156605569,\n", + " 9: 0.9513317478982386,\n", + " 10: 0.9627777208883127},\n", + " 'nonloo_l2_unsigned_normed_noleafavg_norank': {2: -4243.494895430736,\n", + " 3: -4275.535490057371,\n", + " 4: -3812.4902858534997,\n", + " 5: -3812.572032759011,\n", + " 6: -1848.282758661586,\n", + " 7: -577.0981130078727,\n", + " 8: 0.9513335156605569,\n", + " 9: 0.9513317478982386,\n", + " 10: 0.9627777208883127},\n", + " 'nonloo_l2_unsigned_nonnormed_leafavg_rank': {2: 0.8745572555125906,\n", + " 3: 0.9962856146908532,\n", + " 4: 0.9983793104588853,\n", + " 5: 0.9991313629725951,\n", + " 6: 0.9991307853136793,\n", + " 7: 0.9996242893857193,\n", + " 8: 0.9995068982867004,\n", + " 9: 0.9995071975348331,\n", + " 10: 0.9995644448819649},\n", + " 'nonloo_l2_unsigned_nonnormed_leafavg_norank': {2: 0.8745572555125906,\n", + " 3: 0.9962856146908532,\n", + " 4: 0.9983793104588853,\n", + " 5: 0.9991313629725951,\n", + " 6: 0.9991307853136793,\n", + " 7: 0.9996242893857193,\n", + " 8: 0.9995068982867004,\n", + " 9: 0.9995071975348331,\n", + " 10: 0.9995644448819649},\n", + " 'nonloo_l2_unsigned_nonnormed_noleafavg_rank': {2: -658.8304129932814,\n", + " 3: 0.9962856146908532,\n", + " 4: 0.9986192282152526,\n", + " 5: 0.9551469635519686,\n", + " 6: 0.6578314231077766,\n", + " 7: 0.6578311989028846,\n", + " 8: 0.6043819673645229,\n", + " 9: 0.6043813139199089,\n", + " 10: 0.6043811645905632},\n", + " 'nonloo_l2_unsigned_nonnormed_noleafavg_norank': {2: -658.8304129932814,\n", + " 3: 0.9962856146908532,\n", + " 4: 0.9986192282152526,\n", + " 5: 0.9551469635519686,\n", + " 6: 0.6578314231077766,\n", + " 7: 0.6578311989028846,\n", + " 8: 0.6043819673645229,\n", + " 9: 0.6043813139199089,\n", + " 10: 0.6043811645905632},\n", + " 'nonloo_nonl2_unsigned_nonnormed_leafavg_rank': {2: 0.9927271668877967,\n", + " 3: -2.2008705734289133,\n", + " 4: 0.999276923103626,\n", + " 5: 0.9993201962730601,\n", + " 6: 0.9993389618038269,\n", + " 7: 0.9993530489270159,\n", + " 8: 0.8882223157339382,\n", + " 9: 0.9998110097411719,\n", + " 10: 0.9998111354033692},\n", + " 'nonloo_nonl2_unsigned_nonnormed_leafavg_norank': {2: 0.9927271668877967,\n", + " 3: -2.2008705734289133,\n", + " 4: 0.999276923103626,\n", + " 5: 0.9993201962730601,\n", + " 6: 0.9993389618038269,\n", + " 7: 0.9993530489270159,\n", + " 8: 0.8882223157339382,\n", + " 9: 0.9998110097411719,\n", + " 10: 0.9998111354033692},\n", + " 'nonloo_nonl2_unsigned_nonnormed_noleafavg_rank': {2: 0.9927271668877967,\n", + " 3: -2.22103085531022,\n", + " 4: 0.999276923103626,\n", + " 5: 0.9993201962730601,\n", + " 6: 0.9993389618038269,\n", + " 7: 0.9993530489270159,\n", + " 8: 0.9997574324459434,\n", + " 9: 0.9997633734563234,\n", + " 10: 0.9998264060218799},\n", + " 'nonloo_nonl2_unsigned_nonnormed_noleafavg_norank': {2: 0.9927271668877967,\n", + " 3: -2.22103085531022,\n", + " 4: 0.999276923103626,\n", + " 5: 0.9993201962730601,\n", + " 6: 0.9993389618038269,\n", + " 7: 0.9993530489270159,\n", + " 8: 0.9997574324459434,\n", + " 9: 0.9997633734563234,\n", + " 10: 0.9998264060218799},\n", + " 'shap': {2: 0.29463484466336426,\n", + " 3: -50.062463572581315,\n", + " 4: -100.56554667809782,\n", + " 5: -136.81628954851814,\n", + " 6: -32.02680562951272,\n", + " 7: -60.33838945037482,\n", + " 8: -15.745605023818893,\n", + " 9: 0.3053033586575728,\n", + " 10: 0.34611583681096836},\n", + " 'rawdata': {2: 0.9927271668877967,\n", + " 3: 0.9928404246313651,\n", + " 4: 0.9928596794536289,\n", + " 5: 0.9993201962730601,\n", + " 6: 0.9993342833962491,\n", + " 7: 0.9993530489270159,\n", + " 8: 0.999752876689095,\n", + " 9: 0.9998425132653208,\n", + " 10: 0.9998917498213533}},\n", + " 'rmse': {'lmdi_baseline': {2: 0.0047210746863881135,\n", + " 3: 0.0023212981104360976,\n", + " 4: 0.0019270920598536637,\n", + " 5: 0.05946331151554615,\n", + " 6: 0.05919640530791291,\n", + " 7: 0.03545394967155662,\n", + " 8: 0.04231645458226893,\n", + " 9: 0.04480164709499497,\n", + " 10: 0.05206890943814465},\n", + " 'aloo_l2_signed_normed_leafavg_rank': {2: 0.0010008008075137371,\n", + " 3: 0.000536739872320485,\n", + " 4: 0.00053294713724969,\n", + " 5: 0.0005296889943642573,\n", + " 6: 0.0004747451089050948,\n", + " 7: 0.0003495539250932273,\n", + " 8: 0.00034647000291166254,\n", + " 9: 0.0003273818086388954,\n", + " 10: 0.00022901527768305624},\n", + " 'aloo_l2_signed_normed_leafavg_norank': {2: 0.0010008008075137371,\n", + " 3: 0.000536739872320485,\n", + " 4: 0.00053294713724969,\n", + " 5: 0.0005296889943642573,\n", + " 6: 0.0004747451089050948,\n", + " 7: 0.0003495539250932273,\n", + " 8: 0.00034647000291166254,\n", + " 9: 0.0003273818086388954,\n", + " 10: 0.00022901527768305624},\n", + " 'aloo_l2_signed_normed_noleafavg_rank': {2: 0.0010008008075137371,\n", + " 3: 0.004300224499885554,\n", + " 4: 0.005053756772585791,\n", + " 5: 0.0012867913090043796,\n", + " 6: 0.00047547323046174807,\n", + " 7: 0.0002828423048293196,\n", + " 8: 0.0002747572861646693,\n", + " 9: 0.0002589466190436697,\n", + " 10: 0.00023309436012589806},\n", + " 'aloo_l2_signed_normed_noleafavg_norank': {2: 0.0010008008075137371,\n", + " 3: 0.004300224499885554,\n", + " 4: 0.005053756772585791,\n", + " 5: 0.0012867913090043796,\n", + " 6: 0.00047547323046174807,\n", + " 7: 0.0002828423048293196,\n", + " 8: 0.0002747572861646693,\n", + " 9: 0.0002589466190436697,\n", + " 10: 0.00023309436012589806},\n", + " 'aloo_l2_signed_nonnormed_leafavg_rank': {2: 0.04362311263699772,\n", + " 3: 0.0006923221574011059,\n", + " 4: 0.0004826977777701225,\n", + " 5: 0.00039775556231891107,\n", + " 6: 0.00021093025527390966,\n", + " 7: 0.0001675986628028056,\n", + " 8: 0.00016581218948602696,\n", + " 9: 0.0001449748649842588,\n", + " 10: 0.0001246268639072646},\n", + " 'aloo_l2_signed_nonnormed_leafavg_norank': {2: 0.04362311263699772,\n", + " 3: 0.0006923221574011059,\n", + " 4: 0.0004826977777701225,\n", + " 5: 0.00039775556231891107,\n", + " 6: 0.00021093025527390966,\n", + " 7: 0.0001675986628028056,\n", + " 8: 0.00016581218948602696,\n", + " 9: 0.0001449748649842588,\n", + " 10: 0.0001246268639072646},\n", + " 'aloo_l2_signed_nonnormed_noleafavg_rank': {2: 0.03324421596018426,\n", + " 3: 0.0006923221574011059,\n", + " 4: 0.0004826977777701225,\n", + " 5: 0.000634375534552158,\n", + " 6: 0.00020783171745813686,\n", + " 7: 0.0003901173103087517,\n", + " 8: 0.00034678571783764757,\n", + " 9: 0.0003450834802206439,\n", + " 10: 0.0003189984716008969},\n", + " 'aloo_l2_signed_nonnormed_noleafavg_norank': {2: 0.03324421596018426,\n", + " 3: 0.0006923221574011059,\n", + " 4: 0.0004826977777701225,\n", + " 5: 0.000634375534552158,\n", + " 6: 0.00020783171745813686,\n", + " 7: 0.0003901173103087517,\n", + " 8: 0.00034678571783764757,\n", + " 9: 0.0003450834802206439,\n", + " 10: 0.0003189984716008969},\n", + " 'aloo_l2_unsigned_normed_leafavg_rank': {2: 0.0054218889942934485,\n", + " 3: 0.00517150600374909,\n", + " 4: 0.008823311387533365,\n", + " 5: 0.0054810169479879755,\n", + " 6: 0.003923156755000956,\n", + " 7: 0.0033396637262877006,\n", + " 8: 0.0033393150177579243,\n", + " 9: 0.003393317378827907,\n", + " 10: 0.00260741294777391},\n", + " 'aloo_l2_unsigned_normed_leafavg_norank': {2: 0.0054218889942934485,\n", + " 3: 0.00517150600374909,\n", + " 4: 0.008823311387533365,\n", + " 5: 0.0054810169479879755,\n", + " 6: 0.003923156755000956,\n", + " 7: 0.0033396637262877006,\n", + " 8: 0.0033393150177579243,\n", + " 9: 0.003393317378827907,\n", + " 10: 0.00260741294777391},\n", + " 'aloo_l2_unsigned_normed_noleafavg_rank': {2: 19.94148697062952,\n", + " 3: 0.004946765983662409,\n", + " 4: 0.007265081166580929,\n", + " 5: 0.007222219490041424,\n", + " 6: 0.005473624441481075,\n", + " 7: 0.004088962169097856,\n", + " 8: 0.0027309398154122476,\n", + " 9: 0.0027302555225232296,\n", + " 10: 0.003978847512238699},\n", + " 'aloo_l2_unsigned_normed_noleafavg_norank': {2: 19.94148697062952,\n", + " 3: 0.004946765983662409,\n", + " 4: 0.007265081166580929,\n", + " 5: 0.007222219490041424,\n", + " 6: 0.005473624441481075,\n", + " 7: 0.004088962169097856,\n", + " 8: 0.0027309398154122476,\n", + " 9: 0.0027302555225232296,\n", + " 10: 0.003978847512238699},\n", + " 'aloo_l2_unsigned_nonnormed_leafavg_rank': {2: 0.004837433473419856,\n", + " 3: 0.0008029092680965121,\n", + " 4: 0.0005170377869226024,\n", + " 5: 0.00030579750188431266,\n", + " 6: 0.0003034355700354145,\n", + " 7: 0.00022209253681759975,\n", + " 8: 0.00024723493934226634,\n", + " 9: 0.00024207451364163218,\n", + " 10: 0.00024203165614719571},\n", + " 'aloo_l2_unsigned_nonnormed_leafavg_norank': {2: 0.004837433473419856,\n", + " 3: 0.0008029092680965121,\n", + " 4: 0.0005170377869226024,\n", + " 5: 0.00030579750188431266,\n", + " 6: 0.0003034355700354145,\n", + " 7: 0.00022209253681759975,\n", + " 8: 0.00024723493934226634,\n", + " 9: 0.00024207451364163218,\n", + " 10: 0.00024203165614719571},\n", + " 'aloo_l2_unsigned_nonnormed_noleafavg_rank': {2: 0.13240418573095328,\n", + " 3: 0.0008029092680965121,\n", + " 4: 0.000497900499762969,\n", + " 5: 0.001747029897643657,\n", + " 6: 0.001745549373276284,\n", + " 7: 0.0017270719834404395,\n", + " 8: 0.0047689147547438655,\n", + " 9: 0.004530104287197566,\n", + " 10: 0.004530023527798446},\n", + " 'aloo_l2_unsigned_nonnormed_noleafavg_norank': {2: 0.13240418573095328,\n", + " 3: 0.0008029092680965121,\n", + " 4: 0.000497900499762969,\n", + " 5: 0.001747029897643657,\n", + " 6: 0.001745549373276284,\n", + " 7: 0.0017270719834404395,\n", + " 8: 0.0047689147547438655,\n", + " 9: 0.004530104287197566,\n", + " 10: 0.004530023527798446},\n", + " 'aloo_nonl2_unsigned_nonnormed_leafavg_rank': {2: 0.0010008008075137371,\n", + " 3: 0.014669224406046547,\n", + " 4: 0.00032001266862670267,\n", + " 5: 0.00028012512198598245,\n", + " 6: 0.00026570106606690696,\n", + " 7: 0.000246920668737022,\n", + " 8: 0.04455038012818413,\n", + " 9: 0.014314887941730794,\n", + " 10: 0.014313029298138996},\n", + " 'aloo_nonl2_unsigned_nonnormed_leafavg_norank': {2: 0.0010008008075137371,\n", + " 3: 0.014669224406046547,\n", + " 4: 0.00032001266862670267,\n", + " 5: 0.00028012512198598245,\n", + " 6: 0.00026570106606690696,\n", + " 7: 0.000246920668737022,\n", + " 8: 0.04455038012818413,\n", + " 9: 0.014314887941730794,\n", + " 10: 0.014313029298138996},\n", + " 'aloo_nonl2_unsigned_nonnormed_noleafavg_rank': {2: 0.0010008008075137371,\n", + " 3: 0.01467016319492342,\n", + " 4: 0.00032001266862670267,\n", + " 5: 0.00028012512198598245,\n", + " 6: 0.00026570106606690696,\n", + " 7: 0.000246920668737022,\n", + " 8: 0.00015775457939034754,\n", + " 9: 0.000582632076442091,\n", + " 10: 0.0005646898909798848},\n", + " 'aloo_nonl2_unsigned_nonnormed_noleafavg_norank': {2: 0.0010008008075137371,\n", + " 3: 0.01467016319492342,\n", + " 4: 0.00032001266862670267,\n", + " 5: 0.00028012512198598245,\n", + " 6: 0.00026570106606690696,\n", + " 7: 0.000246920668737022,\n", + " 8: 0.00015775457939034754,\n", + " 9: 0.000582632076442091,\n", + " 10: 0.0005646898909798848},\n", + " 'nonloo_l2_signed_normed_leafavg_rank': {2: 0.0010008008075137371,\n", + " 3: 0.000536739872320485,\n", + " 4: 0.00053294713724969,\n", + " 5: 0.0005296889943642573,\n", + " 6: 0.00048000190772776906,\n", + " 7: 0.00035355145997252504,\n", + " 8: 0.00034957392127082015,\n", + " 9: 0.00033139044486839775,\n", + " 10: 0.00022390808018535533},\n", + " 'nonloo_l2_signed_normed_leafavg_norank': {2: 0.0010008008075137371,\n", + " 3: 0.000536739872320485,\n", + " 4: 0.00053294713724969,\n", + " 5: 0.0005296889943642573,\n", + " 6: 0.00048000190772776906,\n", + " 7: 0.00035355145997252504,\n", + " 8: 0.00034957392127082015,\n", + " 9: 0.00033139044486839775,\n", + " 10: 0.00022390808018535533},\n", + " 'nonloo_l2_signed_normed_noleafavg_rank': {2: 0.0010008008075137371,\n", + " 3: 0.004300224499885554,\n", + " 4: 0.005053756772585791,\n", + " 5: 0.0012867913090043796,\n", + " 6: 0.00047598649846853573,\n", + " 7: 0.00034202731284905754,\n", + " 8: 0.0003341798631485496,\n", + " 9: 0.0003175368917328057,\n", + " 10: 0.00020995637112838862},\n", + " 'nonloo_l2_signed_normed_noleafavg_norank': {2: 0.0010008008075137371,\n", + " 3: 0.004300224499885554,\n", + " 4: 0.005053756772585791,\n", + " 5: 0.0012867913090043796,\n", + " 6: 0.00047598649846853573,\n", + " 7: 0.00034202731284905754,\n", + " 8: 0.0003341798631485496,\n", + " 9: 0.0003175368917328057,\n", + " 10: 0.00020995637112838862},\n", + " 'nonloo_l2_signed_nonnormed_leafavg_rank': {2: 0.04362311263699772,\n", + " 3: 0.006831112666336973,\n", + " 4: 0.0676746623208454,\n", + " 5: 0.0008223947293657446,\n", + " 6: 0.00024914333640892633,\n", + " 7: 0.00020581174393782223,\n", + " 8: 0.000159765967155884,\n", + " 9: 0.00015890301795810433,\n", + " 10: 0.00015751895482127068},\n", + " 'nonloo_l2_signed_nonnormed_leafavg_norank': {2: 0.04362311263699772,\n", + " 3: 0.006831112666336973,\n", + " 4: 0.0676746623208454,\n", + " 5: 0.0008223947293657446,\n", + " 6: 0.00024914333640892633,\n", + " 7: 0.00020581174393782223,\n", + " 8: 0.000159765967155884,\n", + " 9: 0.00015890301795810433,\n", + " 10: 0.00015751895482127068},\n", + " 'nonloo_l2_signed_nonnormed_noleafavg_rank': {2: 0.03324421596018426,\n", + " 3: 0.0006923221574011059,\n", + " 4: 0.0004826977777701225,\n", + " 5: 0.000634375534552158,\n", + " 6: 0.00020783171745813686,\n", + " 7: 0.00018129147472408712,\n", + " 8: 0.00013795988225298302,\n", + " 9: 0.00013684951532915487,\n", + " 10: 0.0001263749749850754},\n", + " 'nonloo_l2_signed_nonnormed_noleafavg_norank': {2: 0.03324421596018426,\n", + " 3: 0.0006923221574011059,\n", + " 4: 0.0004826977777701225,\n", + " 5: 0.000634375534552158,\n", + " 6: 0.00020783171745813686,\n", + " 7: 0.00018129147472408712,\n", + " 8: 0.00013795988225298302,\n", + " 9: 0.00013684951532915487,\n", + " 10: 0.0001263749749850754},\n", + " 'nonloo_l2_unsigned_normed_leafavg_rank': {2: 0.006571336052670374,\n", + " 3: 0.004757091793588651,\n", + " 4: 0.006413862892245409,\n", + " 5: 0.0064660997513278605,\n", + " 6: 0.004395800415431629,\n", + " 7: 0.0040796854424555665,\n", + " 8: 0.004079238949976211,\n", + " 9: 0.00347126300103831,\n", + " 10: 0.002751661944334366},\n", + " 'nonloo_l2_unsigned_normed_leafavg_norank': {2: 0.006571336052670374,\n", + " 3: 0.004757091793588651,\n", + " 4: 0.006413862892245409,\n", + " 5: 0.0064660997513278605,\n", + " 6: 0.004395800415431629,\n", + " 7: 0.0040796854424555665,\n", + " 8: 0.004079238949976211,\n", + " 9: 0.00347126300103831,\n", + " 10: 0.002751661944334366},\n", + " 'nonloo_l2_unsigned_normed_noleafavg_rank': {2: 0.2761825250944572,\n", + " 3: 0.2671282086148904,\n", + " 4: 0.1425099785803983,\n", + " 5: 0.1442282152282163,\n", + " 6: 0.09518249773189962,\n", + " 7: 0.040043649321861746,\n", + " 8: 0.0027270103463411218,\n", + " 9: 0.0027263491743422016,\n", + " 10: 0.0021793819732008938},\n", + " 'nonloo_l2_unsigned_normed_noleafavg_norank': {2: 0.2761825250944572,\n", + " 3: 0.2671282086148904,\n", + " 4: 0.1425099785803983,\n", + " 5: 0.1442282152282163,\n", + " 6: 0.09518249773189962,\n", + " 7: 0.040043649321861746,\n", + " 8: 0.0027270103463411218,\n", + " 9: 0.0027263491743422016,\n", + " 10: 0.0021793819732008938},\n", + " 'nonloo_l2_unsigned_nonnormed_leafavg_rank': {2: 0.004837433473419856,\n", + " 3: 0.0008029092680965121,\n", + " 4: 0.0005132270259987337,\n", + " 5: 0.0003019867409604438,\n", + " 6: 0.00030078296661500194,\n", + " 7: 0.00022138381893241834,\n", + " 8: 0.00023981757075261644,\n", + " 9: 0.00023876395222442026,\n", + " 10: 0.00022358577621710168},\n", + " 'nonloo_l2_unsigned_nonnormed_leafavg_norank': {2: 0.004837433473419856,\n", + " 3: 0.0008029092680965121,\n", + " 4: 0.0005132270259987337,\n", + " 5: 0.0003019867409604438,\n", + " 6: 0.00030078296661500194,\n", + " 7: 0.00022138381893241834,\n", + " 8: 0.00023981757075261644,\n", + " 9: 0.00023876395222442026,\n", + " 10: 0.00022358577621710168},\n", + " 'nonloo_l2_unsigned_nonnormed_noleafavg_rank': {2: 0.13240418573095328,\n", + " 3: 0.0008029092680965121,\n", + " 4: 0.000497137026154365,\n", + " 5: 0.001746266424035053,\n", + " 6: 0.004618157364729109,\n", + " 7: 0.004616506469518288,\n", + " 8: 0.004355707968446244,\n", + " 9: 0.004355501252683398,\n", + " 10: 0.004355132006055126},\n", + " 'nonloo_l2_unsigned_nonnormed_noleafavg_norank': {2: 0.13240418573095328,\n", + " 3: 0.0008029092680965121,\n", + " 4: 0.000497137026154365,\n", + " 5: 0.001746266424035053,\n", + " 6: 0.004618157364729109,\n", + " 7: 0.004616506469518288,\n", + " 8: 0.004355707968446244,\n", + " 9: 0.004355501252683398,\n", + " 10: 0.004355132006055126},\n", + " 'nonloo_nonl2_unsigned_nonnormed_leafavg_rank': {2: 0.0010008008075137371,\n", + " 3: 0.014669224406046547,\n", + " 4: 0.00032001266862670267,\n", + " 5: 0.00028012512198598245,\n", + " 6: 0.00026570106606690696,\n", + " 7: 0.000246920668737022,\n", + " 8: 0.0018554642536271896,\n", + " 9: 0.0001348629730128722,\n", + " 10: 0.00013300432942107306},\n", + " 'nonloo_nonl2_unsigned_nonnormed_leafavg_norank': {2: 0.0010008008075137371,\n", + " 3: 0.014669224406046547,\n", + " 4: 0.00032001266862670267,\n", + " 5: 0.00028012512198598245,\n", + " 6: 0.00026570106606690696,\n", + " 7: 0.000246920668737022,\n", + " 8: 0.0018554642536271896,\n", + " 9: 0.0001348629730128722,\n", + " 10: 0.00013300432942107306},\n", + " 'nonloo_nonl2_unsigned_nonnormed_noleafavg_rank': {2: 0.0010008008075137371,\n", + " 3: 0.01467016319492342,\n", + " 4: 0.00032001266862670267,\n", + " 5: 0.00028012512198598245,\n", + " 6: 0.00026570106606690696,\n", + " 7: 0.000246920668737022,\n", + " 8: 0.00015790485849546005,\n", + " 9: 0.00015436606456363365,\n", + " 10: 0.00013739965557381148},\n", + " 'nonloo_nonl2_unsigned_nonnormed_noleafavg_norank': {2: 0.0010008008075137371,\n", + " 3: 0.01467016319492342,\n", + " 4: 0.00032001266862670267,\n", + " 5: 0.00028012512198598245,\n", + " 6: 0.00026570106606690696,\n", + " 7: 0.000246920668737022,\n", + " 8: 0.00015790485849546005,\n", + " 9: 0.00015436606456363365,\n", + " 10: 0.00013739965557381148},\n", + " 'shap': {2: 0.008380277473891064,\n", + " 3: 0.035960980457123246,\n", + " 4: 0.05047451188845567,\n", + " 5: 0.05623627046448045,\n", + " 6: 0.028939192047380397,\n", + " 7: 0.03962259570880885,\n", + " 8: 0.01750540713575003,\n", + " 9: 0.0016792764024878538,\n", + " 10: 0.0015849334182907898},\n", + " 'rawdata': {2: 0.0010008008075137371,\n", + " 3: 0.0009342802609423693,\n", + " 4: 0.000913464733850751,\n", + " 5: 0.00028012512198598245,\n", + " 6: 0.00026134472465609755,\n", + " 7: 0.000246920668737022,\n", + " 8: 0.00015393669132091455,\n", + " 9: 0.00013034343829786626,\n", + " 10: 0.00010604838976719766}}}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metrics_to_scores" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mdi", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/feature_importance/subgroup/current/subgroup.py b/feature_importance/subgroup/current/subgroup.py index 1224d89..20ae223 100644 --- a/feature_importance/subgroup/current/subgroup.py +++ b/feature_importance/subgroup/current/subgroup.py @@ -138,6 +138,44 @@ def fit_models(X_train: np.ndarray, y_train: np.ndarray, task: str): # return tuple of models return rf, rf_plus_baseline, rf_plus +def create_lmdi_variant_map() -> Dict[str, Dict[str, bool]]: + """ + Create a mapping of LMDI+ variants to argument mappings. + + Outputs: + - lmdi_variants (Dict[str, Dict[str, bool]]): The LMDI variants to use. + """ + + # enumerate the different options when initializing a LMDI+ explainer. + loo = {True: "aloo", False: "nonloo"} + l2norm = {True: "l2", False: "nonl2"} + sign = {True: "signed", False: "unsigned"} + normalize = {True: "normed", False: "nonnormed"} + leaf_average = {True: "leafavg", False: "noleafavg"} + ranking = {True: "rank", False: "norank"} + + # create the mapping of variants to argument mappings + lmdi_variants = {} + for l in loo: + for n in l2norm: + for s in sign: + for nn in normalize: + # sign and normalize are only relevant if l2norm is True + if (not n) and (s or nn): + continue + for la in leaf_average: + for r in ranking: + # create the name the variant will be stored under + variant_name = f"{loo[l]}_{l2norm[n]}_{sign[s]}" + \ + f"_{normalize[nn]}_{leaf_average[la]}_{ranking[r]}" + # store the arguments for the lmdi+ explainer + arg_map = {"loo": l, "l2norm": n, "sign": s, + "normalize": nn, "leaf_average": la, + "ranking": r} + lmdi_variants[variant_name] = arg_map + + return lmdi_variants + def get_shap(X: np.ndarray, shap_explainer: shap.TreeExplainer, task: str): """ Get the SHAP values and rankings for the given data. @@ -203,9 +241,9 @@ def get_lmdi_explainers(rf_plus, lmdi_variants: Dict[str, Dict[str, bool]], # if a baseline is provided, we need to treat it separately if rf_plus_baseline is not None: # evaluate on inbag samples only - lmdi_explainers["baseline"] = RFPlusMDI(rf_plus_baseline, - mode = "only_k", - evaluate_on = "inbag") + lmdi_explainers["lmdi_baseline"] = RFPlusMDI(rf_plus_baseline, + mode = "only_k", + evaluate_on = "inbag") # create the explainer objects for each variant, using AloRFPlusMDI if loo # is True and RFPlusMDI if loo is False for variant_name in lmdi_variants.keys(): @@ -252,19 +290,19 @@ def get_lmdi(X: np.ndarray, y: np.ndarray, # if the explainer mapping has a baseline, we need to treat it differently if len(lmdi_explainers) == len(lmdi_variants) + 1 and \ - "baseline" in lmdi_explainers: + "lmdi_baseline" in lmdi_explainers: # we need to get the values with all of the params set to False - lmdi_values["baseline"] = \ - lmdi_explainers["baseline"].explain_linear_partial(X, y, + lmdi_values["lmdi_baseline"] = \ + lmdi_explainers["lmdi_baseline"].explain_linear_partial(X, y, l2norm=False, sign=False, normalize=False, leaf_average=False, ranking=False) # get the rankings using the method in the explainer class - lmdi_rankings["baseline"] = \ - lmdi_explainers["baseline"].get_rankings( - np.abs(lmdi_values["baseline"]) + lmdi_rankings["lmdi_baseline"] = \ + lmdi_explainers["lmdi_baseline"].get_rankings( + np.abs(lmdi_values["lmdi_baseline"]) ) # for all the other variants, we loop through the explainer objects, @@ -272,7 +310,7 @@ def get_lmdi(X: np.ndarray, y: np.ndarray, for name, explainer in lmdi_explainers.items(): # skip through the baseline model, since we have already done it - if name == "baseline": + if name == "lmdi_baseline": continue # get the argument mapping @@ -550,8 +588,128 @@ def get_test_clusters(lfi_test_values: Dict[str, np.ndarray], method_to_indices[variant] = num_cluster_map return method_to_indices + +def compute_performance(X_train: np.ndarray, X_test: np.ndarray, + y_train: np.ndarray, y_test: np.ndarray, + train_clusters: Dict[str, Dict[int, Dict[int, np.ndarray]]], + test_clusters: Dict[str, Dict[int, Dict[int, np.ndarray]]], + task: str): + """ + Fit regression models on the train data for each cluster, and calculate + the performance on the test data. + + Inputs: + - X_train (np.ndarray): The feature matrix for the training set. + - X_test (np.ndarray): The feature matrix for the testing set. + - y_train (np.ndarray): The target vector for the training set. + - y_test (np.ndarray): The target vector for the testing set. + - train_clusters (Dict[str, Dict[int, Dict[int, np.ndarray]]]): Training + cluster representation. + - test_clusters (Dict[str, Dict[int, Dict[int, np.ndarray]]]): Testing + cluster representation. + - task (str): The task type, either 'classification' or 'regression'. + + Outputs: + - metrics_to_variants (Dict[str, Dict[str, Dict[int, float]]]): Mapping from + metrics (str) -> variants (str) -> nclust (int) -> score (float) + """ + + # create a mapping of metrics to measure + if task == "classification": + metrics = {"accuracy": accuracy_score, "roc_auc": roc_auc_score, + "average_precision": average_precision_score, + "f1": f1_score, "log_loss": log_loss} + else: + metrics = {"r2": r2_score, "rmse": root_mean_squared_error} + + # metrics (str) -> variants (str) -> nclust (int) -> score (float) + metrics_to_variants = {} + for metric_name, metric_func in metrics.items(): + variants_to_nclust = {} + for variant, nclust_map in train_clusters.items(): + nclust_to_score = {} + # for each number of clusters, get each cluster, fit a model, and + # calculate the metric + for nclust in range(2, 11): + # store scores in list in case some clusters have no test points + cluster_scores = [] + cluster_sizes = [] + # c = 1, ..., nclust, get the cluster and fit a model + for c in range(nclust): + # for train we can use nclust_map, but for test + # we need to use test_clusters, since nclust_map is the + # value for the training data + X_cluster_train = X_train[nclust_map[nclust][c]] + y_cluster_train = y_train[nclust_map[nclust][c]] + X_cluster_test = X_test[test_clusters[variant][nclust][c]] + y_cluster_test = y_test[test_clusters[variant][nclust][c]] + + # if no test points have been assigned to this cluster, skip + if X_cluster_test.shape[0] == 0: + continue + + # fit regression model to the cluster's training data + if task == "classification": + model = LogisticRegression() + else: + model = LinearRegression() + model.fit(X_cluster_train, y_cluster_train) + + # store the cluster scores and sizes for weighted average + y_cluster_pred = model.predict(X_cluster_test) + cluster_scores.append(metric_func(y_cluster_test, + y_cluster_pred)) + cluster_sizes.append(X_cluster_test.shape[0]) + + # now back in loop over nclust + nclust_to_score[nclust] = \ + weighted_metric(np.array(cluster_scores), + np.array(cluster_sizes)) + # now back in loop over variants + variants_to_nclust[variant] = nclust_to_score + # now back in loop over metrics + metrics_to_variants[metric_name] = variants_to_nclust + + return metrics_to_variants + +def write_results(result_dir: str, dataid: int, seed: int, clustertype: str, + metrics_to_scores: Dict[str, Dict[str, Dict[int, float]]]): + """ + Writes the results to a csv file. + + Inputs: + - result_dir (str): The directory to save the results. + - dataid (int): The OpenML dataset ID. + - seed (int): The random seed used. + - clustertype (str): The clustering method used. + - metrics_to_scores (Dict[str, Dict[str, Dict[int, float]]]): Results + calculated from compute_performance. + Outputs: + - None + """ + + # for each metric, save the results + for metric_name in metrics_to_scores.keys(): + # write the results to a csv file + print(f"Saving {metric_name} results...") + for variant in metrics_to_scores[metric_name].keys(): + # create dataframe with # of clusters and scores as columns + df = pd.DataFrame( + list(metrics_to_scores[metric_name][variant].items()), + columns=["nclust", f"{metric_name}"] + ) + # if the path does not exist, create it + if not os.path.exists(oj(result_dir, f"dataid{dataid}/seed{seed}"+ \ + f"/metric{metric_name}/{clustertype}")): + os.makedirs(oj(result_dir, f"dataid{dataid}/seed{seed}" + \ + f"/metric{metric_name}/{clustertype}")) + # save the dataframe to a csv file + df.to_csv(oj(result_dir, f"dataid{dataid}/seed{seed}/metric" + \ + f"{metric_name}/{clustertype}", f"{variant}.csv")) + return + if __name__ == '__main__': # store command-line arguments @@ -595,27 +753,7 @@ def get_test_clusters(lfi_test_values: Dict[str, np.ndarray], task) # create list of lmdi variants - loo = {True: "aloo", False: "nonloo"} - l2norm = {True: "l2", False: "nonl2"} - sign = {True: "signed", False: "unsigned"} - normalize = {True: "normed", False: "nonnormed"} - leaf_average = {True: "leafavg", False: "noleafavg"} - ranking = {True: "rank", False: "norank"} - lmdi_variants = {} - for l in loo: - for n in l2norm: - for s in sign: - for nn in normalize: - # sign and normalize are only relevant if l2norm is True - if (not n) and (s or nn): - continue - for la in leaf_average: - for r in ranking: - variant_name = f"{loo[l]}_{l2norm[n]}_{sign[s]}_{normalize[nn]}_{leaf_average[la]}_{ranking[r]}" - arg_map = {"loo": l, "l2norm": n, "sign": s, - "normalize": nn, "leaf_average": la, - "ranking": r} - lmdi_variants[variant_name] = arg_map + lmdi_variants = create_lmdi_variant_map() # obtain lmdi feature importances lmdi_explainers = get_lmdi_explainers(rf_plus, lmdi_variants, @@ -632,69 +770,22 @@ def get_test_clusters(lfi_test_values: Dict[str, np.ndarray], lfi_test_values["shap"] = shap_test_values lfi_test_rankings["shap"] = shap_test_rankings + # add the raw data to the dictionaries as a baseline of comparison + lfi_train_values["rawdata"] = X_train + lfi_test_values["rawdata"] = X_test + # get the clusterings # method_to_labels, method_to_indices = get_train_clusters(lfi_train_values, clustertype) train_clusters = get_train_clusters(lfi_train_values, clustertype) cluster_centroids = get_cluster_centroids(lfi_train_values, train_clusters) test_clusters = get_test_clusters(lfi_test_values, cluster_centroids) - # create a mapping of metrics to measure - if task == "classification": - metrics = {"accuracy": accuracy_score, "roc_auc": roc_auc_score, - "average_precision": average_precision_score, - "f1": f1_score, "log_loss": log_loss} - else: - metrics = {"r2": r2_score, "rmse": root_mean_squared_error} - - # for each method, for each number of clusters, - # train a linear model on the training set for each cluster and - # use it to predict the testing set for each cluster. save the results. - metrics_to_methods = {} - for metric_name, metric_func in metrics.items(): - metrics_to_methods[metric_name] = {} - for method in train_clusters.keys(): - methods_to_scores = {} - for num_clusters in range(2, 11): - cluster_scores = [] - cluster_sizes = [] - for cluster_idx in range(num_clusters): - X_cluster_train = X_train[train_clusters[method][num_clusters][cluster_idx]] - y_cluster_train = y_train[train_clusters[method][num_clusters][cluster_idx]] - X_cluster_test = X_test[test_clusters[method][num_clusters][cluster_idx]] - y_cluster_test = y_test[test_clusters[method][num_clusters][cluster_idx]] - if X_cluster_test.shape[0] == 0: - continue - if task == "classification": - model = LogisticRegression() - else: - model = LinearRegression() - model.fit(X_cluster_train, y_cluster_train) - # print("Method:", method, "; # Clusters:", num_clusters, "; Cluster:", cluster_idx) - # print(X_cluster_test.shape) - # print(X_cluster_train.shape) - y_cluster_pred = model.predict(X_cluster_test) - cluster_scores.append(metric_func(y_cluster_test, y_cluster_pred)) - cluster_sizes.append(X_cluster_test.shape[0]) - methods_to_scores[num_clusters] = \ - weighted_metric(np.array(cluster_scores), np.array(cluster_sizes)) - # average accuracy across clusters - metrics_to_methods[metric_name][method] = methods_to_scores + # compute the performance + metrics_to_scores = compute_performance(X_train, X_test, y_train, y_test, + train_clusters, test_clusters, task) # save the results result_dir = oj(os.path.dirname(os.path.realpath(__file__)), 'results/') - # print(result_dir) - # print(metrics_to_methods) - for metric_name in metrics_to_methods.keys(): - # write the results to a csv file - print(f"Saving {metric_name} results...") - for method in metrics_to_methods[metric_name].keys(): - print("Method:", method) - # print(metrics_to_methods[metric_name]) - df = pd.DataFrame(list(metrics_to_methods[metric_name][method].items()), columns=["nclust", f"{metric_name}"]) - # print(df) - if not os.path.exists(oj(result_dir, f"dataid{dataid}/seed{seed}/metric{metric_name}/{clustertype}")): - os.makedirs(oj(result_dir, f"dataid{dataid}/seed{seed}/metric{metric_name}/{clustertype}")) - df.to_csv(oj(result_dir, - f"dataid{dataid}/seed{seed}/metric{metric_name}/{clustertype}", f"{method}.csv")) + write_results(result_dir, dataid, seed, clustertype, metrics_to_scores) print("Results saved!") \ No newline at end of file