diff --git a/feature_importance/ablation_demo.ipynb b/feature_importance/ablation_demo.ipynb index c41a3a0..e072049 100644 --- a/feature_importance/ablation_demo.ipynb +++ b/feature_importance/ablation_demo.ipynb @@ -135,6 +135,260 @@ " return result_table" ] }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "n = 200\n", + "d = 10\n", + "mean = [[0]*5 + [0]*5, [10]*5 + [0]*5]\n", + "scale = [[1]*10,[1]*10]\n", + "s = 5\n", + "X = sample_normal_X_subgroups(n, d, mean, scale)\n", + "beta = np.concatenate((np.ones(s), np.zeros(d-s)))\n", + "y = np.matmul(X, beta)\n", + "split_seed = 0\n", + "X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(X, y, \"train-test\", split_seed)\n", + "\n", + "rf_regressor = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=331)\n", + "rf_regressor.fit(X_train, y_train)\n", + "seed = 0\n", + "rf_plus_model = RandomForestPlusRegressor(rf_model=copy.deepcopy(rf_regressor), include_raw=False)\n", + "rf_plus_model = RandomForestPlusRegressor(rf_model=rf_regressor, include_raw=False)\n", + "rf_plus_model.fit(X_train, y_train)\n", + "\n", + "score = rf_plus_model.get_mdi_plus_scores(X_test, y_test, lfi=True, lfi_abs = \"outside\", sample_split=None)\n", + "local_fi_score = score[\"lfi\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0123456789
03.4407474.2559336.7003998.6391995.3334370.3256180.6400870.2236890.1859200.176673
12.8051894.6959556.8876985.8883815.7745520.4624730.5698510.1669080.2524260.145043
22.9942094.7692207.2189875.9568995.6275070.4492250.5254650.1662050.2160600.127154
34.1834434.7349168.2261714.7748897.2571460.2821700.3634780.1571850.2249820.166350
43.2622824.3801607.2013945.6386925.6650000.5897750.4256610.1389570.1977390.126664
.................................
614.0455844.8201556.9395865.6680307.6628100.2944190.3790580.1300600.1605090.136078
623.2399553.8266046.4928815.8325855.3943870.6667660.8472820.2922800.2350260.101330
634.8522653.9865636.3602457.1477505.0268300.5255460.3727600.0840530.1130260.166653
643.0002214.3112256.5706265.6806675.0845220.6809600.9199510.1111150.2614130.163289
653.9746154.4968786.5466947.3660035.2699320.3393950.4585210.0995340.2380120.151942
\n", + "

66 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 \\\n", + "0 3.440747 4.255933 6.700399 8.639199 5.333437 0.325618 0.640087 \n", + "1 2.805189 4.695955 6.887698 5.888381 5.774552 0.462473 0.569851 \n", + "2 2.994209 4.769220 7.218987 5.956899 5.627507 0.449225 0.525465 \n", + "3 4.183443 4.734916 8.226171 4.774889 7.257146 0.282170 0.363478 \n", + "4 3.262282 4.380160 7.201394 5.638692 5.665000 0.589775 0.425661 \n", + ".. ... ... ... ... ... ... ... \n", + "61 4.045584 4.820155 6.939586 5.668030 7.662810 0.294419 0.379058 \n", + "62 3.239955 3.826604 6.492881 5.832585 5.394387 0.666766 0.847282 \n", + "63 4.852265 3.986563 6.360245 7.147750 5.026830 0.525546 0.372760 \n", + "64 3.000221 4.311225 6.570626 5.680667 5.084522 0.680960 0.919951 \n", + "65 3.974615 4.496878 6.546694 7.366003 5.269932 0.339395 0.458521 \n", + "\n", + " 7 8 9 \n", + "0 0.223689 0.185920 0.176673 \n", + "1 0.166908 0.252426 0.145043 \n", + "2 0.166205 0.216060 0.127154 \n", + "3 0.157185 0.224982 0.166350 \n", + "4 0.138957 0.197739 0.126664 \n", + ".. ... ... ... \n", + "61 0.130060 0.160509 0.136078 \n", + "62 0.292280 0.235026 0.101330 \n", + "63 0.084053 0.113026 0.166653 \n", + "64 0.111115 0.261413 0.163289 \n", + "65 0.099534 0.238012 0.151942 \n", + "\n", + "[66 rows x 10 columns]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "local_fi_score" + ] + }, { "cell_type": "code", "execution_count": 5,