diff --git a/feature_importance/ablation_demo.ipynb b/feature_importance/ablation_demo.ipynb index c41a3a0..e072049 100644 --- a/feature_importance/ablation_demo.ipynb +++ b/feature_importance/ablation_demo.ipynb @@ -135,6 +135,260 @@ " return result_table" ] }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "n = 200\n", + "d = 10\n", + "mean = [[0]*5 + [0]*5, [10]*5 + [0]*5]\n", + "scale = [[1]*10,[1]*10]\n", + "s = 5\n", + "X = sample_normal_X_subgroups(n, d, mean, scale)\n", + "beta = np.concatenate((np.ones(s), np.zeros(d-s)))\n", + "y = np.matmul(X, beta)\n", + "split_seed = 0\n", + "X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(X, y, \"train-test\", split_seed)\n", + "\n", + "rf_regressor = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=331)\n", + "rf_regressor.fit(X_train, y_train)\n", + "seed = 0\n", + "rf_plus_model = RandomForestPlusRegressor(rf_model=copy.deepcopy(rf_regressor), include_raw=False)\n", + "rf_plus_model = RandomForestPlusRegressor(rf_model=rf_regressor, include_raw=False)\n", + "rf_plus_model.fit(X_train, y_train)\n", + "\n", + "score = rf_plus_model.get_mdi_plus_scores(X_test, y_test, lfi=True, lfi_abs = \"outside\", sample_split=None)\n", + "local_fi_score = score[\"lfi\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " | 0 | \n", + "1 | \n", + "2 | \n", + "3 | \n", + "4 | \n", + "5 | \n", + "6 | \n", + "7 | \n", + "8 | \n", + "9 | \n", + "
---|---|---|---|---|---|---|---|---|---|---|
0 | \n", + "3.440747 | \n", + "4.255933 | \n", + "6.700399 | \n", + "8.639199 | \n", + "5.333437 | \n", + "0.325618 | \n", + "0.640087 | \n", + "0.223689 | \n", + "0.185920 | \n", + "0.176673 | \n", + "
1 | \n", + "2.805189 | \n", + "4.695955 | \n", + "6.887698 | \n", + "5.888381 | \n", + "5.774552 | \n", + "0.462473 | \n", + "0.569851 | \n", + "0.166908 | \n", + "0.252426 | \n", + "0.145043 | \n", + "
2 | \n", + "2.994209 | \n", + "4.769220 | \n", + "7.218987 | \n", + "5.956899 | \n", + "5.627507 | \n", + "0.449225 | \n", + "0.525465 | \n", + "0.166205 | \n", + "0.216060 | \n", + "0.127154 | \n", + "
3 | \n", + "4.183443 | \n", + "4.734916 | \n", + "8.226171 | \n", + "4.774889 | \n", + "7.257146 | \n", + "0.282170 | \n", + "0.363478 | \n", + "0.157185 | \n", + "0.224982 | \n", + "0.166350 | \n", + "
4 | \n", + "3.262282 | \n", + "4.380160 | \n", + "7.201394 | \n", + "5.638692 | \n", + "5.665000 | \n", + "0.589775 | \n", + "0.425661 | \n", + "0.138957 | \n", + "0.197739 | \n", + "0.126664 | \n", + "
... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "
61 | \n", + "4.045584 | \n", + "4.820155 | \n", + "6.939586 | \n", + "5.668030 | \n", + "7.662810 | \n", + "0.294419 | \n", + "0.379058 | \n", + "0.130060 | \n", + "0.160509 | \n", + "0.136078 | \n", + "
62 | \n", + "3.239955 | \n", + "3.826604 | \n", + "6.492881 | \n", + "5.832585 | \n", + "5.394387 | \n", + "0.666766 | \n", + "0.847282 | \n", + "0.292280 | \n", + "0.235026 | \n", + "0.101330 | \n", + "
63 | \n", + "4.852265 | \n", + "3.986563 | \n", + "6.360245 | \n", + "7.147750 | \n", + "5.026830 | \n", + "0.525546 | \n", + "0.372760 | \n", + "0.084053 | \n", + "0.113026 | \n", + "0.166653 | \n", + "
64 | \n", + "3.000221 | \n", + "4.311225 | \n", + "6.570626 | \n", + "5.680667 | \n", + "5.084522 | \n", + "0.680960 | \n", + "0.919951 | \n", + "0.111115 | \n", + "0.261413 | \n", + "0.163289 | \n", + "
65 | \n", + "3.974615 | \n", + "4.496878 | \n", + "6.546694 | \n", + "7.366003 | \n", + "5.269932 | \n", + "0.339395 | \n", + "0.458521 | \n", + "0.099534 | \n", + "0.238012 | \n", + "0.151942 | \n", + "
66 rows × 10 columns
\n", + "