From 9c31b554cbed694e688fb4d5590940ad5acc04f0 Mon Sep 17 00:00:00 2001 From: Eduardo Souza Date: Fri, 26 Aug 2022 10:34:50 -0300 Subject: [PATCH] Causal T-learner (#206) * Included initial Meta S-learner classifier code. * Included references in docstring. * Fixed numpy typing. * Removed unused import * Created new type to allow mutable features, which is required with S-learner specifically. * Fixed types based on new created type. * Removed unused import. * Requires learner to be mandatory. * Removed lgbm import * meta_learners.py: converted _predict_for_treatment and _predict_for_control functions into single _predict_by_treatment_flag function. * Renamed LearnerMutableFeaturesFnType to LearnerMutableParametersFnType. * included tests * Fixed docstring and handled case in which uplift is negative. * fix lint * removed unused imports - wip * added test__simulate_treatment_effect * fix assertions * fix assert_frame_equal * meta_learners.py: included control_name parameter in _simulate_treatment_effect call. * Included custom exceptions. * meta_learner.py: Included exception checks in some functions. * meta_learners.py: fixed bug. Changed treatment_col by treatment_name * included assertion tests * updated predict_by_treatment_flag test functions to test only the treatment value (1 or 0) * added tests for causal_s_classification_learner * changed function test name * included explanation to test__simulate_treatment_effect * moved column drop from simulate_treatment_effect to predict_by_treatment_flag * changed mock learner for a mock function at test__simulate_treatment_effect * included test for learners being the correct type * meta_learners.py: fixed replace to x.replace(__uplift, ) * test_meta_learners.py: fixed test__simulate_treatment_effect to adapt change in replace on previous commit. * meta_learners.py: fixed description of prediction_column in docstring. * causal_s_learner_demo.ipynb: included demo notebook for causal s-learner. 
* causal_s_learner_demo.ipynb: fixed notebook headers. * Rerun notebook. * causal_s_learner_demo.ipynb: included documentation in the notebook. * Improved notebook description. * Initial commit of T-Learner * Add docstring * Lint * Change _get_model_fcn return type * Fix learner type * Removed unused import * Removed unused import * Updated types * Fix variables naming * Added tests * Lint * Type fix * Improve documentation of T-Learner Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> * Fixed t-learner docstring Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> * Fixed t-learner docstring Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> * Fixed t-learner parameter type Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> * Fixed t-learner docstring Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> * Fix t-learner argument type Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> * Improve T-Learner types Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> * Remove unused type Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> * Add SHAP to logs Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> * Update src/fklearn/causal/cate_learning/meta_learners.py * Applied review comments * Update demo notebook * Merged with master and included rst * Fix test * Update t-learner notebook Co-authored-by: Giulio Santo Co-authored-by: Nicolas Behar Co-authored-by: Giulio Cesare Mastrocinque Santo <32403782+GiulioCMSanto@users.noreply.github.com> --- .../api/fklearn.causal.cate_learning.rst | 22 + notebooks/causal_s_learner_demo.ipynb | 2 +- 
notebooks/causal_t_learner_demo.ipynb | 2434 +++++++++++++++++ .../causal/cate_learning/meta_learners.py | 201 +- .../cate_learning/test_meta_learners.py | 266 +- 5 files changed, 2867 insertions(+), 58 deletions(-) create mode 100644 docs/source/api/fklearn.causal.cate_learning.rst create mode 100644 notebooks/causal_t_learner_demo.ipynb diff --git a/docs/source/api/fklearn.causal.cate_learning.rst b/docs/source/api/fklearn.causal.cate_learning.rst new file mode 100644 index 00000000..35888232 --- /dev/null +++ b/docs/source/api/fklearn.causal.cate_learning.rst @@ -0,0 +1,22 @@ +fklearn.causal.cate\_learning package +===================================== + +Submodules +---------- + +fklearn.causal.cate\_learning.double\_machine\_learning module +-------------------------------------------------------------- + +.. automodule:: fklearn.causal.cate_learning.double_machine_learning + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: fklearn.causal.cate_learning + :members: + :undoc-members: + :show-inheritance: diff --git a/notebooks/causal_s_learner_demo.ipynb b/notebooks/causal_s_learner_demo.ipynb index ab017263..d5cb5b8f 100644 --- a/notebooks/causal_s_learner_demo.ipynb +++ b/notebooks/causal_s_learner_demo.ipynb @@ -2439,4 +2439,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/notebooks/causal_t_learner_demo.ipynb b/notebooks/causal_t_learner_demo.ipynb new file mode 100644 index 00000000..fe848b5e --- /dev/null +++ b/notebooks/causal_t_learner_demo.ipynb @@ -0,0 +1,2434 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f84aeac6", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T17:38:44.522833Z", + "start_time": "2022-08-01T17:38:44.516580Z" + } + }, + "source": [ + "## Causal T-Learner Classifier\n", + "\n", + "### TL; DR\n", + "---\n", + "This notebook exemplifies how to use the causal T-learner through fklearn.\n", + "\n", + "### 
Long\n", + "---\n", + "Uplift models are a very interesting and useful type of causal model, in which one is able to identify how each sample responds to a given treatment, and what the effect of that treatment is compared to a control group. The main goal of uplift models is, therefore, to learn the difference in probability of a sample converting (using a product), given that it was submitted to some action (nudge). Meta-learners are examples of causal models, in which the CATE represents how each unit will respond to a given treatment [1]. In fact, the uplift can be understood as the incremental gain in the conversion probability if a given sample were in the treatment group instead of the control one. In addition, these models have the advantage of using conventional machine learning models, such as LightGBM.\n", + "\n", + "More specifically, the T-Learner is a meta-learner which learns the Conditional Average Treatment Effect (CATE) through the use of multiple models, one for each treatment. Each model is fitted on a subset of the data, according to the treatment. The CATE $\\tau$ is defined as $\\tau(x_{i}) = M_{1}(X=x_{i}, T=1) - M_{0}(X=x_{i}, T=0)$, where $M_{1}$ is a model fitted with treatment data, $M_{0}$ is a model fitted with control data (each can be a Machine Learning Model such as a LightGBM Classifier), and $x_{i}$ is the feature set of sample $i$.\n", + "\n", + "### Data\n", + "---\n", + "The data adopted here is provided in [1] and [2].\n", + "\n", + "### References\n", + "---\n", + "\n", + "[1] https://matheusfacure.github.io/python-causality-handbook/21-Meta-Learners.html\n", + "\n", + "[2] https://github.com/matheusfacure/python-causality-handbook/tree/master/causal-inference-for-the-brave-and-true/data\n", + "\n", + "[3] https://causalml.readthedocs.io/en/latest/methodology.html" + ] + }, + { + "cell_type": "markdown", + "id": "651c52e1", + "metadata": {}, + "source": [ + "### 1. 
Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "209dd195", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b9ae70be", + "metadata": {}, + "outputs": [], + "source": [ + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "909075ef", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/Users/eduardo.souza/dev/nu/fklearn/src'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "os.chdir(\"../src\")\n", + "os.getcwd()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "334c24ac", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:54.942876Z", + "start_time": "2022-08-01T21:08:52.183053Z" + } + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "from fklearn.causal.cate_learning.meta_learners import causal_t_classification_learner\n", + "from fklearn.training.classification import lgbm_classification_learner\n", + "from fklearn.training.calibration import isotonic_calibration_learner\n", + "from fklearn.causal.validation.curves import cumulative_gain_curve\n", + "from fklearn.training.pipeline import build_pipeline\n", + "from fklearn.training.transformation import ecdfer\n", + "\n", + "sns.set_style(\"darkgrid\")" + ] + }, + { + "cell_type": "markdown", + "id": "7951044d", + "metadata": {}, + "source": [ + "### 2. 
Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3b3820cf", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:54.949172Z", + "start_time": "2022-08-01T21:08:54.944256Z" + } + }, + "outputs": [], + "source": [ + "def plot_cumulative_gain_curve(\n", + " train_gain: np.ndarray,\n", + " test_gain: np.ndarray,\n", + " random_gain: np.ndarray,\n", + " fontsize: int = 16,\n", + " figsize: tuple = (15,5)\n", + ") -> None:\n", + " \"\"\"\n", + " Plots the cumulative gain curve.\n", + " \"\"\"\n", + "\n", + " xaxis = np.arange(len(train_gain))/len(train_gain)\n", + " \n", + " plt.figure(figsize=figsize);\n", + " plt.plot(xaxis, train_gain, label=\"Training Data\");\n", + " plt.plot(xaxis, test_gain, label=\"Testing Data\");\n", + " plt.plot(xaxis, random_gain, \"--\", label=\"Random\");\n", + " \n", + " plt.ylabel(\"Cumulative Gain\", fontsize=fontsize);\n", + " plt.xlabel(\"Population Proportion\", fontsize=fontsize);\n", + " plt.title(\"Cumulative Gain Curve\", fontsize=fontsize);\n", + " \n", + " plt.legend(fontsize=fontsize);\n", + " plt.xticks(fontsize=fontsize);\n", + " plt.yticks(fontsize=fontsize);" + ] + }, + { + "cell_type": "markdown", + "id": "1bb4b598", + "metadata": {}, + "source": [ + "### 3. Read Data" + ] + }, + { + "cell_type": "markdown", + "id": "99ced58c", + "metadata": {}, + "source": [ + "Notice that the data here adopted is provided in [1] and [2]." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "164af8a7", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:55.654415Z", + "start_time": "2022-08-01T21:08:54.952419Z" + } + }, + "outputs": [], + "source": [ + "test_data = pd.read_csv(\n", + " \"https://raw.githubusercontent.com/matheusfacure/python-causality-handbook/master/causal-inference-for-the-brave-and-true/data/invest_email_rnd.csv\"\n", + ")\n", + "train_data = pd.read_csv(\n", + " \"https://raw.githubusercontent.com/matheusfacure/python-causality-handbook/master/causal-inference-for-the-brave-and-true/data/invest_email_biased.csv\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "c4100764", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:55.665495Z", + "start_time": "2022-08-01T21:08:55.656451Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(15000, 8)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "9dbc21fd", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:55.671480Z", + "start_time": "2022-08-01T21:08:55.667548Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(15000, 8)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_data.shape" + ] + }, + { + "cell_type": "markdown", + "id": "bfba7573", + "metadata": {}, + "source": [ + "#### 3.1 Include Treatment Column" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "bed6542e", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:55.684291Z", + "start_time": "2022-08-01T21:08:55.673873Z" + } + }, + "outputs": [], + "source": [ + "train_data[\"control\"] = np.where((train_data[\"em1\"]+train_data[\"em2\"]+train_data[\"em3\"])==0,1,0)\n", + "train_data[\"treatment_col\"] = 
train_data[[\"em1\",\"em2\",\"em3\",\"control\"]].idxmax(axis=1).values\n", + "\n", + "test_data[\"control\"] = np.where((test_data[\"em1\"]+test_data[\"em2\"]+test_data[\"em3\"])==0,1,0)\n", + "test_data[\"treatment_col\"] = test_data[[\"em1\",\"em2\",\"em3\",\"control\"]].idxmax(axis=1).values" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b951bc0c", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:55.701127Z", + "start_time": "2022-08-01T21:08:55.686721Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageincomeinsuranceinvestedem1em2em3convertedcontroltreatment_col
044.15483.806155.2914294.8100100em3
139.82737.9250069.407468.1510000em1
249.02712.515707.085095.6500110em3
339.72326.3715657.976345.2000001control
435.32787.2627074.4414114.8611000em1
\n", + "
" + ], + "text/plain": [ + " age income insurance invested em1 em2 em3 converted control \\\n", + "0 44.1 5483.80 6155.29 14294.81 0 0 1 0 0 \n", + "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n", + "2 49.0 2712.51 5707.08 5095.65 0 0 1 1 0 \n", + "3 39.7 2326.37 15657.97 6345.20 0 0 0 0 1 \n", + "4 35.3 2787.26 27074.44 14114.86 1 1 0 0 0 \n", + "\n", + " treatment_col \n", + "0 em3 \n", + "1 em1 \n", + "2 em3 \n", + "3 control \n", + "4 em1 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2dc0eb3e", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:55.717300Z", + "start_time": "2022-08-01T21:08:55.703374Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageincomeinsuranceinvestedem1em2em3convertedcontroltreatment_col
044.15483.806155.2914294.8101100em2
139.82737.9250069.407468.1510000em1
249.02712.515707.085095.6510110em1
339.72326.3715657.976345.2011100em1
435.32787.2627074.4414114.8611100em1
\n", + "
" + ], + "text/plain": [ + " age income insurance invested em1 em2 em3 converted control \\\n", + "0 44.1 5483.80 6155.29 14294.81 0 1 1 0 0 \n", + "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n", + "2 49.0 2712.51 5707.08 5095.65 1 0 1 1 0 \n", + "3 39.7 2326.37 15657.97 6345.20 1 1 1 0 0 \n", + "4 35.3 2787.26 27074.44 14114.86 1 1 1 0 0 \n", + "\n", + " treatment_col \n", + "0 em2 \n", + "1 em1 \n", + "2 em1 \n", + "3 em1 \n", + "4 em1 " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "b02dd80c", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T17:34:49.478324Z", + "start_time": "2022-08-01T17:34:49.472616Z" + } + }, + "source": [ + "### 4. Causal T-Learner" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f2e41cb9", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:55.727534Z", + "start_time": "2022-08-01T21:08:55.724439Z" + } + }, + "outputs": [], + "source": [ + "target_column = \"converted\"\n", + "features = [\"age\", \"income\", \"insurance\", \"invested\"]\n", + "treatment_column = \"treatment_col\"\n", + "control_name = \"control\"\n", + "prediction_column = \"prediction\"" + ] + }, + { + "cell_type": "markdown", + "id": "f13a389d", + "metadata": {}, + "source": [ + "#### 4.1 Using T-Learner with LightGBM" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "86154394", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:55.732761Z", + "start_time": "2022-08-01T21:08:55.729735Z" + } + }, + "outputs": [], + "source": [ + "clf_learner = lgbm_classification_learner(\n", + " features = features,\n", + " target = target_column,\n", + " prediction_column = prediction_column\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "885a539f", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:55.738793Z", + 
"start_time": "2022-08-01T21:08:55.735586Z" + } + }, + "outputs": [], + "source": [ + "t_learner = causal_t_classification_learner(\n", + " treatment_col=treatment_column,\n", + " control_name=control_name,\n", + " prediction_column=prediction_column,\n", + " learner=clf_learner,\n", + " treatment_learner=clf_learner,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e2d4cd23", + "metadata": {}, + "source": [ + "**Training the model**" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "68c21ad5", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.027747Z", + "start_time": "2022-08-01T21:08:55.740932Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n", + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n", + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. 
\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[LightGBM] [Info] Number of positive: 504, number of negative: 3808\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000434 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 963\n", + "[LightGBM] [Info] Number of data points in the train set: 4312, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.116883 -> initscore=-2.022283\n", + "[LightGBM] [Info] Start training from score -2.022283\n", + "[LightGBM] [Info] Number of positive: 1392, number of negative: 5088\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000334 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 898\n", + "[LightGBM] [Info] Number of data points in the train set: 6480, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.214815 -> initscore=-1.296143\n", + "[LightGBM] [Info] Start training from score -1.296143\n", + "[LightGBM] [Info] Number of positive: 802, number of negative: 2568\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000364 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 958\n", + "[LightGBM] [Info] Number of data points in the train set: 3370, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.237982 -> initscore=-1.163774\n", + "[LightGBM] [Info] Start training from score -1.163774\n", + "[LightGBM] [Info] Number of positive: 309, number of negative: 529\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000133 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 906\n", + 
"[LightGBM] [Info] Number of data points in the train set: 838, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.368735 -> initscore=-0.537647\n", + "[LightGBM] [Info] Start training from score -0.537647\n", + "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n" + ] + } + ], + "source": [ + "t_learner_fcn, t_learner_train_df, t_learner_log = t_learner(train_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "629ee628", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.034377Z", + "start_time": "2022-08-01T21:08:56.030279Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + ".p(new_df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame>" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t_learner_fcn" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "bb863168", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.051585Z", + "start_time": "2022-08-01T21:08:56.036919Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageincomeinsuranceinvestedem1em2em3convertedcontroltreatment_coltreatment_em3__prediction_on_treatmenttreatment_em3__uplifttreatment_em1__prediction_on_treatmenttreatment_em1__uplifttreatment_em2__prediction_on_treatmenttreatment_em2__upliftupliftsuggested_treatment
044.15483.806155.2914294.8100100em30.2595510.0798250.4479920.2682650.3083170.1285900.268265treatment_em1
139.82737.9250069.407468.1510000em10.091645-0.0037860.063676-0.0317550.015090-0.080341-0.003786control
249.02712.515707.085095.6500110em30.4644480.4180730.6826840.6363090.0612110.0148360.636309treatment_em1
339.72326.3715657.976345.2000001control0.0861450.0367140.2565010.2070710.1967680.1473380.207071treatment_em1
435.32787.2627074.4414114.8611000em10.1497710.1403580.2014140.1920010.0822540.0728420.192001treatment_em1
\n", + "
" + ], + "text/plain": [ + " age income insurance invested em1 em2 em3 converted control \\\n", + "0 44.1 5483.80 6155.29 14294.81 0 0 1 0 0 \n", + "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n", + "2 49.0 2712.51 5707.08 5095.65 0 0 1 1 0 \n", + "3 39.7 2326.37 15657.97 6345.20 0 0 0 0 1 \n", + "4 35.3 2787.26 27074.44 14114.86 1 1 0 0 0 \n", + "\n", + " treatment_col treatment_em3__prediction_on_treatment \\\n", + "0 em3 0.259551 \n", + "1 em1 0.091645 \n", + "2 em3 0.464448 \n", + "3 control 0.086145 \n", + "4 em1 0.149771 \n", + "\n", + " treatment_em3__uplift treatment_em1__prediction_on_treatment \\\n", + "0 0.079825 0.447992 \n", + "1 -0.003786 0.063676 \n", + "2 0.418073 0.682684 \n", + "3 0.036714 0.256501 \n", + "4 0.140358 0.201414 \n", + "\n", + " treatment_em1__uplift treatment_em2__prediction_on_treatment \\\n", + "0 0.268265 0.308317 \n", + "1 -0.031755 0.015090 \n", + "2 0.636309 0.061211 \n", + "3 0.207071 0.196768 \n", + "4 0.192001 0.082254 \n", + "\n", + " treatment_em2__uplift uplift suggested_treatment \n", + "0 0.128590 0.268265 treatment_em1 \n", + "1 -0.080341 -0.003786 control \n", + "2 0.014836 0.636309 treatment_em1 \n", + "3 0.147338 0.207071 treatment_em1 \n", + "4 0.072842 0.192001 treatment_em1 " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t_learner_train_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "ae0405f3", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.059283Z", + "start_time": "2022-08-01T21:08:56.054044Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'causal_t_classification_learner': {'control': {'lgbm_classification_learner': {'features': ['age',\n", + " 'income',\n", + " 'insurance',\n", + " 'invested'],\n", + " 'target': 'converted',\n", + " 'prediction_column': 'prediction',\n", + " 'package': 'lightgbm',\n", + " 'package_version': '3.3.2',\n", + " 'parameters': {'eta': 0.1, 
'objective': 'binary', 'num_estimators': 100},\n", + " 'feature_importance': {'age': 739,\n", + " 'income': 811,\n", + " 'insurance': 724,\n", + " 'invested': 726},\n", + " 'training_samples': 4312,\n", + " 'running_time': '0.087 s'},\n", + " 'object': },\n", + " 'em3': {'lgbm_classification_learner': {'features': ['age',\n", + " 'income',\n", + " 'insurance',\n", + " 'invested'],\n", + " 'target': 'converted',\n", + " 'prediction_column': 'prediction',\n", + " 'package': 'lightgbm',\n", + " 'package_version': '3.3.2',\n", + " 'parameters': {'eta': 0.1, 'objective': 'binary', 'num_estimators': 100},\n", + " 'feature_importance': {'age': 627,\n", + " 'income': 772,\n", + " 'insurance': 816,\n", + " 'invested': 785},\n", + " 'training_samples': 6480,\n", + " 'running_time': '0.064 s'},\n", + " 'object': },\n", + " 'em1': {'lgbm_classification_learner': {'features': ['age',\n", + " 'income',\n", + " 'insurance',\n", + " 'invested'],\n", + " 'target': 'converted',\n", + " 'prediction_column': 'prediction',\n", + " 'package': 'lightgbm',\n", + " 'package_version': '3.3.2',\n", + " 'parameters': {'eta': 0.1, 'objective': 'binary', 'num_estimators': 100},\n", + " 'feature_importance': {'age': 769,\n", + " 'income': 747,\n", + " 'insurance': 748,\n", + " 'invested': 736},\n", + " 'training_samples': 3370,\n", + " 'running_time': '0.051 s'},\n", + " 'object': },\n", + " 'em2': {'lgbm_classification_learner': {'features': ['age',\n", + " 'income',\n", + " 'insurance',\n", + " 'invested'],\n", + " 'target': 'converted',\n", + " 'prediction_column': 'prediction',\n", + " 'package': 'lightgbm',\n", + " 'package_version': '3.3.2',\n", + " 'parameters': {'eta': 0.1, 'objective': 'binary', 'num_estimators': 100},\n", + " 'feature_importance': {'age': 618,\n", + " 'income': 745,\n", + " 'insurance': 801,\n", + " 'invested': 835},\n", + " 'training_samples': 838,\n", + " 'running_time': '0.085 s'},\n", + " 'object': }}}" + ] + }, + "execution_count": 63, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "t_learner_log" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "2ed58ac7", + "metadata": {}, + "outputs": [], + "source": [ + "score_of_interest = \"treatment_em1__uplift\"" + ] + }, + { + "cell_type": "markdown", + "id": "dadfd813", + "metadata": {}, + "source": [ + "**Making Predictions**" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "19c87675", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.138488Z", + "start_time": "2022-08-01T21:08:56.061810Z" + } + }, + "outputs": [], + "source": [ + "t_learner_test_df = t_learner_fcn(test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d992f3c8", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.155056Z", + "start_time": "2022-08-01T21:08:56.141081Z" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageincomeinsuranceinvestedem1em2em3convertedcontroltreatment_coltreatment_em3__prediction_on_treatmenttreatment_em3__uplifttreatment_em1__prediction_on_treatmenttreatment_em1__uplifttreatment_em2__prediction_on_treatmenttreatment_em2__upliftupliftsuggested_treatment
044.15483.806155.2914294.8101100em20.2595510.0798250.4479920.2682650.3083170.1285900.268265treatment_em1
139.82737.9250069.407468.1510000em10.091645-0.0037860.063676-0.0317550.015090-0.080341-0.003786control
249.02712.515707.085095.6510110em10.4644480.4180730.6826840.6363090.0612110.0148360.636309treatment_em1
339.72326.3715657.976345.2011100em10.0861450.0367140.2565010.2070710.1967680.1473380.207071treatment_em1
435.32787.2627074.4414114.8611100em10.1497710.1403580.2014140.1920010.0822540.0728420.192001treatment_em1
\n", + "
" + ], + "text/plain": [ + " age income insurance invested em1 em2 em3 converted control \\\n", + "0 44.1 5483.80 6155.29 14294.81 0 1 1 0 0 \n", + "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n", + "2 49.0 2712.51 5707.08 5095.65 1 0 1 1 0 \n", + "3 39.7 2326.37 15657.97 6345.20 1 1 1 0 0 \n", + "4 35.3 2787.26 27074.44 14114.86 1 1 1 0 0 \n", + "\n", + " treatment_col treatment_em3__prediction_on_treatment \\\n", + "0 em2 0.259551 \n", + "1 em1 0.091645 \n", + "2 em1 0.464448 \n", + "3 em1 0.086145 \n", + "4 em1 0.149771 \n", + "\n", + " treatment_em3__uplift treatment_em1__prediction_on_treatment \\\n", + "0 0.079825 0.447992 \n", + "1 -0.003786 0.063676 \n", + "2 0.418073 0.682684 \n", + "3 0.036714 0.256501 \n", + "4 0.140358 0.201414 \n", + "\n", + " treatment_em1__uplift treatment_em2__prediction_on_treatment \\\n", + "0 0.268265 0.308317 \n", + "1 -0.031755 0.015090 \n", + "2 0.636309 0.061211 \n", + "3 0.207071 0.196768 \n", + "4 0.192001 0.082254 \n", + "\n", + " treatment_em2__uplift uplift suggested_treatment \n", + "0 0.128590 0.268265 treatment_em1 \n", + "1 -0.080341 -0.003786 control \n", + "2 0.014836 0.636309 treatment_em1 \n", + "3 0.147338 0.207071 treatment_em1 \n", + "4 0.072842 0.192001 treatment_em1 " + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t_learner_test_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "60a1bfb9", + "metadata": {}, + "source": [ + "**Create Random Score**" + ] + }, + { + "cell_type": "markdown", + "id": "196a3a23", + "metadata": {}, + "source": [ + "Let's also create a random score that can be used to compute the Cumulative Gain curve." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "61728dd6", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.162206Z", + "start_time": "2022-08-01T21:08:56.157333Z" + } + }, + "outputs": [], + "source": [ + "random_score_df = test_data[[\"em1\", target_column]].copy()\n", + "random_score_df[score_of_interest] = np.random.uniform(0,1,random_score_df.shape[0])" + ] + }, + { + "cell_type": "markdown", + "id": "b5b65182", + "metadata": {}, + "source": [ + "**Checking Cumulative Gain Curve**" + ] + }, + { + "cell_type": "markdown", + "id": "c18d8dc7", + "metadata": {}, + "source": [ + "For more details about evaluating causal models, please see the following reference:\n", + "\n", + "https://matheusfacure.github.io/python-causality-handbook/19-Evaluating-Causal-Models.html?highlight=gain%20curve " + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "dada0b5c", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.167642Z", + "start_time": "2022-08-01T21:08:56.164515Z" + } + }, + "outputs": [], + "source": [ + "gain_curve = cumulative_gain_curve(\n", + " treatment = \"em1\",\n", + " outcome = target_column,\n", + " prediction = score_of_interest\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "44ac0567", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.479996Z", + "start_time": "2022-08-01T21:08:56.170218Z" + } + }, + "outputs": [], + "source": [ + "gain_curve_train = gain_curve(t_learner_train_df)\n", + "gain_curve_test = gain_curve(t_learner_test_df)\n", + "gain_curve_random = gain_curve(random_score_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "e0f39cfc", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.695829Z", + "start_time": "2022-08-01T21:08:56.481931Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_cumulative_gain_curve(\n", + " train_gain = gain_curve_train,\n", + " test_gain = gain_curve_test,\n", + " random_gain = gain_curve_random\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "16b67daa", + "metadata": {}, + "source": [ + "#### 4.2 Using T-Learner with Fklearn Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "5adc2f2f", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.701600Z", + "start_time": "2022-08-01T21:08:56.698546Z" + } + }, + "outputs": [], + "source": [ + "cdf = ecdfer(\n", + " prediction_column=score_of_interest\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "a995e275", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.705705Z", + "start_time": "2022-08-01T21:08:56.703511Z" + } + }, + "outputs": [], + "source": [ + "pipeline = build_pipeline(\n", + " *[t_learner, cdf]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "21db9ef8", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.885605Z", + "start_time": "2022-08-01T21:08:56.707835Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n", + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. 
Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[LightGBM] [Info] Number of positive: 504, number of negative: 3808\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000739 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 963\n", + "[LightGBM] [Info] Number of data points in the train set: 4312, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.116883 -> initscore=-2.022283\n", + "[LightGBM] [Info] Start training from score -2.022283\n", + "[LightGBM] [Info] Number of positive: 1392, number of negative: 5088\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000108 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 898\n", + "[LightGBM] [Info] Number of data points in the train set: 6480, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.214815 -> initscore=-1.296143\n", + "[LightGBM] [Info] Start training from score -1.296143\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n", + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. 
Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[LightGBM] [Info] Number of positive: 802, number of negative: 2568\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000157 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 958\n", + "[LightGBM] [Info] Number of data points in the train set: 3370, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.237982 -> initscore=-1.163774\n", + "[LightGBM] [Info] Start training from score -1.163774\n", + "[LightGBM] [Info] Number of positive: 309, number of negative: 529\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000331 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 906\n", + "[LightGBM] [Info] Number of data points in the train set: 838, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.368735 -> initscore=-0.537647\n", + "[LightGBM] [Info] Start training from score -0.537647\n", + "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n" + ] + } + ], + "source": [ + "pipe_fcn, pipe_train_df, pipe_log = pipeline(\n", + " train_data\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "08b227a5", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:56.901902Z", + "start_time": "2022-08-01T21:08:56.887944Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageincomeinsuranceinvestedem1em2em3convertedcontroltreatment_coltreatment_em3__prediction_on_treatmenttreatment_em3__uplifttreatment_em1__prediction_on_treatmenttreatment_em1__uplifttreatment_em2__prediction_on_treatmenttreatment_em2__upliftupliftsuggested_treatmentprediction_ecdf
044.15483.806155.2914294.8100100em30.2595510.0798250.4479920.2682650.3083170.1285900.268265treatment_em1754.266667
139.82737.9250069.407468.1510000em10.091645-0.0037860.063676-0.0317550.015090-0.080341-0.003786control110.066667
249.02712.515707.085095.6500110em30.4644480.4180730.6826840.6363090.0612110.0148360.636309treatment_em1989.133333
339.72326.3715657.976345.2000001control0.0861450.0367140.2565010.2070710.1967680.1473380.207071treatment_em1653.533333
435.32787.2627074.4414114.8611000em10.1497710.1403580.2014140.1920010.0822540.0728420.192001treatment_em1622.666667
\n", + "
" + ], + "text/plain": [ + " age income insurance invested em1 em2 em3 converted control \\\n", + "0 44.1 5483.80 6155.29 14294.81 0 0 1 0 0 \n", + "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n", + "2 49.0 2712.51 5707.08 5095.65 0 0 1 1 0 \n", + "3 39.7 2326.37 15657.97 6345.20 0 0 0 0 1 \n", + "4 35.3 2787.26 27074.44 14114.86 1 1 0 0 0 \n", + "\n", + " treatment_col treatment_em3__prediction_on_treatment \\\n", + "0 em3 0.259551 \n", + "1 em1 0.091645 \n", + "2 em3 0.464448 \n", + "3 control 0.086145 \n", + "4 em1 0.149771 \n", + "\n", + " treatment_em3__uplift treatment_em1__prediction_on_treatment \\\n", + "0 0.079825 0.447992 \n", + "1 -0.003786 0.063676 \n", + "2 0.418073 0.682684 \n", + "3 0.036714 0.256501 \n", + "4 0.140358 0.201414 \n", + "\n", + " treatment_em1__uplift treatment_em2__prediction_on_treatment \\\n", + "0 0.268265 0.308317 \n", + "1 -0.031755 0.015090 \n", + "2 0.636309 0.061211 \n", + "3 0.207071 0.196768 \n", + "4 0.192001 0.082254 \n", + "\n", + " treatment_em2__uplift uplift suggested_treatment prediction_ecdf \n", + "0 0.128590 0.268265 treatment_em1 754.266667 \n", + "1 -0.080341 -0.003786 control 110.066667 \n", + "2 0.014836 0.636309 treatment_em1 989.133333 \n", + "3 0.147338 0.207071 treatment_em1 653.533333 \n", + "4 0.072842 0.192001 treatment_em1 622.666667 " + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipe_train_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "3f0ae64f", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.152815Z", + "start_time": "2022-08-01T21:08:56.904019Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pipe_train_df[\"uplift\"].hist(bins=50);\n", + "plt.title(\"Uplift Score Distribution\", fontsize=16);\n", + "plt.xticks(fontsize=16);\n", + "plt.yticks(fontsize=16);" + ] + }, + { + "cell_type": "markdown", + "id": "80873d8e", + "metadata": {}, + "source": [ + "**Checking Gain Curve with ECDF**" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "304ce7c2", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.405122Z", + "start_time": "2022-08-01T21:08:57.402884Z" + } + }, + "outputs": [], + "source": [ + "gain_curve = cumulative_gain_curve(\n", + " treatment = \"em1\",\n", + " outcome = target_column,\n", + " prediction = \"prediction_ecdf\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "875df4da", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.536943Z", + "start_time": "2022-08-01T21:08:57.406693Z" + } + }, + "outputs": [], + "source": [ + "pipeline_train_df = pipe_fcn(train_data)\n", + "pipeline_test_df = pipe_fcn(test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "ed3ee7e7", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.747225Z", + "start_time": "2022-08-01T21:08:57.539511Z" + } + }, + "outputs": [], + "source": [ + "gain_curve_train = gain_curve(pipeline_train_df)\n", + "gain_curve_test = gain_curve(pipeline_test_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "af2707ed", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.939012Z", + "start_time": "2022-08-01T21:08:57.749300Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_cumulative_gain_curve(\n", + " train_gain = gain_curve_train,\n", + " test_gain = gain_curve_test,\n", + " random_gain = gain_curve_random\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7ee410fb", + "metadata": {}, + "source": [ + "#### 4.3 Build T-learner Using Lightgbm with Isotonic Calibration" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "ea34b176", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.944825Z", + "start_time": "2022-08-01T21:08:57.941394Z" + } + }, + "outputs": [], + "source": [ + "clf_learner = lgbm_classification_learner(\n", + " features = features,\n", + " target = target_column,\n", + " prediction_column = prediction_column\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "f328d84c", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.950805Z", + "start_time": "2022-08-01T21:08:57.947489Z" + } + }, + "outputs": [], + "source": [ + "calibrator = isotonic_calibration_learner(\n", + " target_column=target_column,\n", + " prediction_column=prediction_column,\n", + " output_column=\"calibration_prediction\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "9abd3945", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.955102Z", + "start_time": "2022-08-01T21:08:57.952772Z" + } + }, + "outputs": [], + "source": [ + "t_learner = causal_t_classification_learner(\n", + " treatment_col=treatment_column,\n", + " control_name=control_name,\n", + " prediction_column=\"calibration_prediction\",\n", + " learner=clf_learner,\n", + " learner_transformers=[calibrator]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "47630533", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.959571Z", + "start_time": "2022-08-01T21:08:57.957001Z" + } + }, + "outputs": [], + "source": [ 
+ "cdf = ecdfer(\n", + " prediction_column=score_of_interest\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "9dfb55d3", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:57.964743Z", + "start_time": "2022-08-01T21:08:57.961573Z" + } + }, + "outputs": [], + "source": [ + "pipeline = build_pipeline(\n", + " *[t_learner, cdf]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "bb117302", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:58.188382Z", + "start_time": "2022-08-01T21:08:57.968255Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n", + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. 
\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[LightGBM] [Info] Number of positive: 504, number of negative: 3808\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000469 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 963\n", + "[LightGBM] [Info] Number of data points in the train set: 4312, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.116883 -> initscore=-2.022283\n", + "[LightGBM] [Info] Start training from score -2.022283\n", + "[LightGBM] [Info] Number of positive: 1392, number of negative: 5088\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000172 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 898\n", + "[LightGBM] [Info] Number of data points in the train set: 6480, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.214815 -> initscore=-1.296143\n", + "[LightGBM] [Info] Start training from score -1.296143\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n", + "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n", + " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. 
\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[LightGBM] [Info] Number of positive: 802, number of negative: 2568\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000512 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 958\n", + "[LightGBM] [Info] Number of data points in the train set: 3370, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.237982 -> initscore=-1.163774\n", + "[LightGBM] [Info] Start training from score -1.163774\n", + "[LightGBM] [Info] Number of positive: 309, number of negative: 529\n", + "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000282 seconds.\n", + "You can set `force_col_wise=true` to remove the overhead.\n", + "[LightGBM] [Info] Total Bins 906\n", + "[LightGBM] [Info] Number of data points in the train set: 838, number of used features: 4\n", + "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.368735 -> initscore=-0.537647\n", + "[LightGBM] [Info] Start training from score -0.537647\n", + "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n" + ] + } + ], + "source": [ + "pipe_fcn, pipe_train_df, pipe_log = pipeline(\n", + " train_data\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "0a80c08c", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:58.205741Z", + "start_time": "2022-08-01T21:08:58.190857Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageincomeinsuranceinvestedem1em2em3convertedcontroltreatment_coltreatment_em3__calibration_prediction_on_treatmenttreatment_em3__uplifttreatment_em1__calibration_prediction_on_treatmenttreatment_em1__uplifttreatment_em2__calibration_prediction_on_treatmenttreatment_em2__upliftupliftsuggested_treatmentprediction_ecdf
044.15483.806155.2914294.8100100em30.3389830.2004290.8207550.6822010.0-0.1385540.682201treatment_em1813.133333
139.82737.9250069.407468.1510000em10.000000-0.0017270.000000-0.0017270.0-0.001727-0.001727control225.400000
249.02712.515707.085095.6500110em30.8833330.8833330.9825580.9825580.00.0000000.982558treatment_em1997.333333
339.72326.3715657.976345.2000001control0.0000000.0000000.1406250.1406250.00.0000000.140625treatment_em1606.133333
435.32787.2627074.4414114.8611000em10.0100500.0100500.0339510.0339510.00.0000000.033951treatment_em1518.400000
\n", + "
" + ], + "text/plain": [ + " age income insurance invested em1 em2 em3 converted control \\\n", + "0 44.1 5483.80 6155.29 14294.81 0 0 1 0 0 \n", + "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n", + "2 49.0 2712.51 5707.08 5095.65 0 0 1 1 0 \n", + "3 39.7 2326.37 15657.97 6345.20 0 0 0 0 1 \n", + "4 35.3 2787.26 27074.44 14114.86 1 1 0 0 0 \n", + "\n", + " treatment_col treatment_em3__calibration_prediction_on_treatment \\\n", + "0 em3 0.338983 \n", + "1 em1 0.000000 \n", + "2 em3 0.883333 \n", + "3 control 0.000000 \n", + "4 em1 0.010050 \n", + "\n", + " treatment_em3__uplift treatment_em1__calibration_prediction_on_treatment \\\n", + "0 0.200429 0.820755 \n", + "1 -0.001727 0.000000 \n", + "2 0.883333 0.982558 \n", + "3 0.000000 0.140625 \n", + "4 0.010050 0.033951 \n", + "\n", + " treatment_em1__uplift treatment_em2__calibration_prediction_on_treatment \\\n", + "0 0.682201 0.0 \n", + "1 -0.001727 0.0 \n", + "2 0.982558 0.0 \n", + "3 0.140625 0.0 \n", + "4 0.033951 0.0 \n", + "\n", + " treatment_em2__uplift uplift suggested_treatment prediction_ecdf \n", + "0 -0.138554 0.682201 treatment_em1 813.133333 \n", + "1 -0.001727 -0.001727 control 225.400000 \n", + "2 0.000000 0.982558 treatment_em1 997.333333 \n", + "3 0.000000 0.140625 treatment_em1 606.133333 \n", + "4 0.000000 0.033951 treatment_em1 518.400000 " + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipe_train_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "65693eaa", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:58.442712Z", + "start_time": "2022-08-01T21:08:58.208095Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pipe_train_df[\"treatment_em1__calibration_prediction_on_treatment\"].hist(bins=10);\n", + "plt.title(\"Uplift with Calibrated LightGBM Distribution\", fontsize=16);\n", + "plt.xticks(fontsize=16);\n", + "plt.yticks(fontsize=16);" + ] + }, + { + "cell_type": "markdown", + "id": "bcf7861e", + "metadata": {}, + "source": [ + "**Checking Gain Curve with ECDF**" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "08cfa1e4", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:58.613042Z", + "start_time": "2022-08-01T21:08:58.610279Z" + } + }, + "outputs": [], + "source": [ + "gain_curve = cumulative_gain_curve(\n", + " treatment = \"em1\",\n", + " outcome = target_column,\n", + " prediction = \"prediction_ecdf\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "e996c832", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:58.755011Z", + "start_time": "2022-08-01T21:08:58.614934Z" + } + }, + "outputs": [], + "source": [ + "pipeline_train_df = pipe_fcn(train_data)\n", + "pipeline_test_df = pipe_fcn(test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "159c3e84", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:58.968608Z", + "start_time": "2022-08-01T21:08:58.757109Z" + } + }, + "outputs": [], + "source": [ + "gain_curve_train = gain_curve(pipeline_train_df)\n", + "gain_curve_test = gain_curve(pipeline_test_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "d4dc9022", + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-01T21:08:59.164538Z", + "start_time": "2022-08-01T21:08:58.970696Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_cumulative_gain_curve(\n", + " train_gain = gain_curve_train,\n", + " test_gain = gain_curve_test,\n", + " random_gain = gain_curve_random\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.12 ('venv': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": false, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + }, + "vscode": { + "interpreter": { + "hash": "7992107b5676fcc2947e8c872ad2cde6c7066279bc6e7ee23218548d8b4682e3" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/fklearn/causal/cate_learning/meta_learners.py b/src/fklearn/causal/cate_learning/meta_learners.py index 73719dd7..039afde7 100644 --- a/src/fklearn/causal/cate_learning/meta_learners.py +++ b/src/fklearn/causal/cate_learning/meta_learners.py @@ -1,6 +1,6 @@ import copy import inspect -from typing import Callable, List, 
Tuple +from typing import Callable, Dict, List, Tuple import numpy as np import pandas as pd @@ -185,32 +185,24 @@ def causal_s_classification_learner( of a new sample for both scenarios, i.e., with T = 0 and T = 1. The CATE τ is defined as τ(xi) = M(X=xi, T=1) - M(X=xi, T=0), being M a Machine Learning Model. - References: [1] https://matheusfacure.github.io/python-causality-handbook/21-Meta-Learners.html [2] https://causalml.readthedocs.io/en/latest/methodology.html - Parameters ---------- - df : pd.DataFrame A Pandas' DataFrame with features and target columns. The model will be trained to predict the target column from the features. - treatment_col: str The name of the column in `df` which contains the names of the treatments or control to which each data sample was subjected. - control_name: str The name of the control group. - prediction_column : str The name of the column with the predictions from the provided learner. - learner: Callable A fklearn classification learner function. - learner_transformers: list A list of fklearn transformer functions to be applied after the learner and before estimating the CATE. This parameter may be useful, for example, to estimate the CATE with calibrated classifiers. 
def _simulate_t_learner_treatment_effect(
    df: pd.DataFrame,
    learners: dict,
    treatments: list,
    control_name: str,
    prediction_column: str,
) -> pd.DataFrame:
    """
    Scores `df` with the control model and with every treatment model and derives
    the uplift of each treatment over control.

    Parameters
    ----------
    df : pd.DataFrame
        Data to be scored. It is not modified; a copy with extra columns is returned.

    learners : dict
        Maps the control name and each treatment name to a prediction function that
        receives a DataFrame and returns a DataFrame containing `prediction_column`.

    treatments : list
        Names of the treatments to be compared against control.

    control_name : str
        Key of the control prediction function in `learners`.

    prediction_column : str
        Name of the column, in the output of each prediction function, holding the score.

    Returns
    -------
    pd.DataFrame
        A copy of `df` with, for each treatment `t`, the columns
        `treatment_{t}__{prediction_column}_on_treatment` and `treatment_{t}__uplift`,
        plus `uplift` (the largest uplift of the row) and `suggested_treatment`
        (`treatment_{t}` of the largest uplift, or `control_name` when no uplift is positive).

    Raises
    ------
    ValueError
        If `treatments` is empty.
    """
    if len(treatments) == 0:
        raise ValueError("At least one treatment is required to estimate the uplift.")

    control_scores = learners[control_name](df)[prediction_column].values

    scored_df = df.copy()

    # Maps each uplift column to the label reported in `suggested_treatment`. Looking the
    # label up (rather than stripping "__uplift" from the final column) keeps control and
    # treatment names intact even when they happen to contain that substring.
    label_by_uplift_col: Dict[str, str] = {}
    for treatment_name in treatments:
        treatment_scores = learners[treatment_name](df)[prediction_column].values
        label = f"treatment_{treatment_name}"

        scored_df[f"{label}__{prediction_column}_on_treatment"] = treatment_scores

        uplift_col = f"{label}__uplift"
        scored_df[uplift_col] = treatment_scores - control_scores
        label_by_uplift_col[uplift_col] = label

    uplifts = scored_df[list(label_by_uplift_col)]
    scored_df["uplift"] = uplifts.max(axis=1).values

    best_label = np.asarray(
        uplifts.idxmax(axis=1).map(label_by_uplift_col), dtype=object
    )
    # A treatment is only suggested when it strictly improves over control.
    scored_df["suggested_treatment"] = np.where(
        scored_df["uplift"].values <= 0, control_name, best_label
    )

    return scored_df


def _get_model_fcn(
    df: pd.DataFrame,
    treatment_col: str,
    treatment_name: str,
    learner: Callable,
) -> Tuple[Callable, pd.DataFrame, dict]:
    """
    Fits `learner` using only the rows of `df` that belong to `treatment_name`.

    Parameters
    ----------
    df : pd.DataFrame
        Data with features, target and the treatment column.

    treatment_col : str
        Name of the column holding the treatment (or control) name of each row.

    treatment_name : str
        Value of `treatment_col` selecting the rows used for fitting.

    learner : Callable
        Function that receives the filtered DataFrame; its return value is passed through.

    Returns
    -------
    Whatever `learner` returns for the filtered data; callers in this module unpack
    it as (prediction function, scored DataFrame, logs).

    Raises
    ------
    MissingTreatmentError
        If no row of `df` has `treatment_name` in `treatment_col`.
    """
    in_group = df[treatment_col] == treatment_name

    if not in_group.any():
        raise MissingTreatmentError()

    # Boolean indexing followed by reset_index yields a new frame with a 0..n-1 index,
    # so the learner never sees (or mutates) the caller's DataFrame.
    group_df = df.loc[in_group].reset_index(drop=True)

    return learner(group_df)


def _get_learners(
    df: pd.DataFrame,
    control_learner: Callable,
    treatment_learner: Callable,
    unique_treatments: List[str],
    control_name: str,
    treatment_col: str,
) -> Tuple[Dict[str, Callable], Dict[str, dict]]:
    """
    Fits one model for the control group and one model for each treatment.

    Parameters
    ----------
    df : pd.DataFrame
        Data with features, target and the treatment column.

    control_learner : Callable
        Learner fitted on the rows of the control group.

    treatment_learner : Callable
        Learner fitted, separately, on the rows of each treatment.

    unique_treatments : List[str]
        Names of the treatments (control excluded).

    control_name : str
        Name of the control group in `treatment_col`.

    treatment_col : str
        Name of the column holding the treatment (or control) name of each row.

    Returns
    -------
    Tuple[Dict[str, Callable], Dict[str, dict]]
        Two dictionaries keyed by group name (control first, then the treatments in the
        given order): the fitted prediction functions and the logs of each fit.
    """
    groups = [(control_name, control_learner)] + [
        (treatment_name, treatment_learner) for treatment_name in unique_treatments
    ]

    learners: Dict[str, Callable] = {}
    logs: Dict[str, dict] = {}
    for group_name, group_learner in groups:
        predict_fn, _, group_logs = _get_model_fcn(
            df, treatment_col, group_name, group_learner
        )
        learners[group_name] = predict_fn
        logs[group_name] = group_logs

    return learners, logs
@curry
def causal_t_classification_learner(
    df: pd.DataFrame,
    treatment_col: str,
    control_name: str,
    prediction_column: str,
    learner: LearnerFnType,
    treatment_learner: LearnerFnType = None,
    learner_transformers: List[LearnerFnType] = None,
) -> LearnerReturnType:
    r"""
    Fits a Causal T-Learner classifier. The T-Learner is a meta-learner which learns the
    Conditional Average Treatment Effect (CATE) through the use of one Machine Learning
    model for each treatment and for the control group. Each model is fitted in a subset of
    the data, according to the treatment: the CATE $\tau$ is defined as
    $\tau(x_{i}) = M_{1}(X=x_{i}, T=1) - M_{0}(X=x_{i}, T=0)$, being $M_{1}$ a model fitted
    with treatment data and $M_{0}$ a model fitted with control data. Notice that $M_{0}$
    and $M_{1}$ are traditional Machine Learning models such as a LightGBM Classifier and
    that $x_{i}$ is the feature set of sample $i$.

    References:
    [1] https://matheusfacure.github.io/python-causality-handbook/21-Meta-Learners.html
    [2] https://causalml.readthedocs.io/en/latest/methodology.html

    Parameters
    ----------

    df : pd.DataFrame
        A Pandas' DataFrame with features and target columns.
        The model will be trained to predict the target column
        from the features.

    treatment_col: str
        The name of the column in `df` which contains the names of
        the treatments and control to which each data sample was subjected.

    control_name: str
        The name of the control group.

    prediction_column : str
        The name of the column with the predictions from the provided learner.

    learner: LearnerFnType
        A fklearn classification learner function. It is fitted on the control group
        and, when `treatment_learner` is not given, on each treatment as well.

    treatment_learner: LearnerFnType
        An optional fklearn classification learner function fitted on each treatment.
        Defaults to `learner`.

    learner_transformers: List[LearnerFnType]
        A list of fklearn transformer functions to be applied after the learner and
        before estimating the CATE. This parameter may be useful, for example, to
        estimate the CATE with calibrated classifiers.
    """

    # Work on private copies so the objects handed in by the caller are never shared
    # between the control and the treatment pipelines.
    control_learner = copy.deepcopy(learner)
    treatment_learner = copy.deepcopy(
        learner if treatment_learner is None else treatment_learner
    )

    # pipeline
    if learner_transformers is not None:
        transformers = copy.deepcopy(learner_transformers)
        # Unpacking (instead of list concatenation) accepts any sequence of transformers.
        control_learner_pipe = build_pipeline(control_learner, *transformers)
        treatment_learner_pipe = build_pipeline(treatment_learner, *transformers)
    else:
        control_learner_pipe = control_learner
        treatment_learner_pipe = treatment_learner

    # learners
    unique_treatments = _get_unique_treatments(df, treatment_col, control_name)

    learners, learners_logs = _get_learners(
        df=df,
        control_learner=control_learner_pipe,
        treatment_learner=treatment_learner_pipe,
        unique_treatments=unique_treatments,
        control_name=control_name,
        treatment_col=treatment_col,
    )

    def p(new_df: pd.DataFrame) -> pd.DataFrame:
        return _simulate_t_learner_treatment_effect(
            new_df,
            learners,
            unique_treatments,
            control_name,
            prediction_column,
        )

    p.__doc__ = learner_pred_fn_docstring("causal_t_classification_learner")

    log = {"causal_t_classification_learner": dict(learners_logs)}

    return p, p(df), log


# Append (not assign) the shared "Returns" section so the docstring above is kept,
# as done for causal_s_classification_learner.
causal_t_classification_learner.__doc__ += learner_return_docstring(
    "Causal T-Learner Classifier"
)
_create_treatment_flag, - _filter_by_treatment, _fit_by_treatment, _get_unique_treatments, - _predict_by_treatment_flag, _simulate_treatment_effect, - causal_s_classification_learner) + _filter_by_treatment, _fit_by_treatment, _get_learners, _get_model_fcn, + _get_unique_treatments, _predict_by_treatment_flag, + _simulate_t_learner_treatment_effect, _simulate_treatment_effect, + causal_s_classification_learner, causal_t_classification_learner) from fklearn.exceptions.exceptions import (MissingControlError, MissingTreatmentError, MultipleTreatmentsError) from fklearn.training.classification import logistic_classification_learner from fklearn.types import LearnerFnType -from pandas import DataFrame -from pandas.testing import assert_frame_equal + + +@pytest.fixture +def base_input_df(): + return pd.DataFrame( + { + "x1": [1.3, 1.0, 1.8, -0.1, 0.0, 1.0, 2.2, 0.4, -5.0], + "x2": [10, 4, 15, 6, 5, 12, 14, 5, 12], + "treatment": [ + "A", + "B", + "A", + "A", + "B", + "control", + "control", + "B", + "control", + ], + "target": [1, 1, 1, 0, 0, 1, 0, 0, 1], + } + ) def test__append_treatment_feature(): @@ -191,26 +216,7 @@ def test__create_treatment_flag(): assert_frame_equal(results, expected) -def test__fit_by_treatment(): - df = pd.DataFrame( - { - "x1": [1.3, 1.0, 1.8, -0.1, 0.0, 1.0, 2.2, 0.4, -5.0], - "x2": [10, 4, 15, 6, 5, 12, 14, 5, 12], - "treatment": [ - "A", - "B", - "A", - "A", - "B", - "control", - "control", - "B", - "control", - ], - "target": [1, 1, 1, 0, 0, 1, 0, 0, 1], - } - ) - +def test__fit_by_treatment(base_input_df): learner_binary = logistic_classification_learner( features=["x1", "x2", TREATMENT_FEATURE], target="target", @@ -220,7 +226,7 @@ def test__fit_by_treatment(): treatments = ["A", "B"] learners, logs = _fit_by_treatment( - df, + base_input_df, learner=learner_binary, treatment_col="treatment", control_name="control", @@ -352,27 +358,8 @@ def test_causal_s_classification_learner( mock_get_unique_treatments, mock_fit_by_treatment, 
def test_simulate_t_learner_treatment_effect():
    """
    Test if the per-treatment prediction and uplift columns, the overall uplift
    and the suggested treatment are computed from the learners' predictions.
    """
    df = pd.DataFrame(
        {
            "x1": [1.3, 1.0, 1.8, -0.1],
            "x2": [10, 4, 15, 6],
            "treatment": ["A", "B", "A", "control"],
            "target": [0, 0, 0, 1],
        }
    )

    treatments = ["A", "B"]
    control_name = "control"
    prediction_column = "prediction"

    # The mocked prediction functions ignore their input and return fixed scores.
    control_learner = MagicMock()
    control_learner.side_effect = lambda _: pd.DataFrame({"prediction": [1, 2, 3, 4]})

    treatment_learner = MagicMock()
    treatment_learner.side_effect = lambda _: pd.DataFrame({"prediction": [3, 2, 4, 4]})

    learners = {
        "control": control_learner,
        "A": treatment_learner,
        "B": treatment_learner,
    }

    result = _simulate_t_learner_treatment_effect(
        df,
        learners,
        treatments,
        control_name,
        prediction_column,
    )

    expected = pd.DataFrame(
        {
            "x1": [1.3, 1.0, 1.8, -0.1],
            "x2": [10, 4, 15, 6],
            "treatment": ["A", "B", "A", "control"],
            "target": [0, 0, 0, 1],
            "treatment_A__prediction_on_treatment": [3, 2, 4, 4],
            "treatment_A__uplift": [2, 0, 1, 0],
            "treatment_B__prediction_on_treatment": [3, 2, 4, 4],
            "treatment_B__uplift": [2, 0, 1, 0],
            "uplift": [2, 0, 1, 0],
            "suggested_treatment": ["treatment_A", "control", "treatment_A", "control"],
        }
    )

    assert isinstance(result, pd.DataFrame)
    assert_frame_equal(result, expected)

    # Control is scored once; the shared treatment mock once per treatment.
    control_learner.assert_called_once()
    assert treatment_learner.call_count == len(treatments)


def test_get_model_fcn(base_input_df):
    """
    Test if the fn is filtering the data
    Test if the learner is called with the filtered data
    """

    fake_prediction_column = [0.1, 0.2, 0.3]
    df_expected = pd.DataFrame(
        {
            "x1": [1.3, 1.8, -0.1],
            "x2": [10, 15, 6],
            "treatment": [
                "A",
                "A",
                "A",
            ],
            "target": [1, 1, 0],
            "prediction": fake_prediction_column,
        }
    )

    def mock_learner(df):
        df["prediction"] = fake_prediction_column

        return (lambda x: x, df, dict())

    learner = MagicMock()
    learner.side_effect = mock_learner

    mock_fcn, mock_p_df, mock_logs = _get_model_fcn(
        base_input_df, "treatment", "A", learner
    )

    learner.assert_called_once()
    assert isinstance(mock_fcn, Callable)
    assert_frame_equal(mock_p_df, df_expected)
    assert isinstance(mock_logs, dict)
    # The learner writes into the frame it receives; the input must stay untouched.
    assert "prediction" not in base_input_df.columns


def test_get_model_fcn_exception(base_input_df):
    """
    Test if the fn is raising an exception when treatment name
    is not in treatment list.
    """

    learner = MagicMock()

    # Expect the specific error type so that an unrelated failure cannot pass the test.
    with pytest.raises(MissingTreatmentError):
        _get_model_fcn(base_input_df, "treatment", "C", learner)

    learner.assert_not_called()


@patch("fklearn.causal.cate_learning.meta_learners._get_model_fcn")
def test_get_learners(mock_get_model_fcn):
    """
    Test if it is receiving a list of treatments and is returning a dict
    of learners.
    """
    unique_treatments = ["treatment_a", "treatment_b", "treatment_c"]

    mock_get_model_fcn.side_effect = [
        ("mocked_control_fcn", None, None),
        ("mocked_treatment_fcn_filtering_treatment_a", None, None),
        ("mocked_treatment_fcn_filtering_treatment_b", None, None),
        ("mocked_treatment_fcn_filtering_treatment_c", None, None),
    ]

    learners, logs = _get_learners(
        df="mocked_df",
        unique_treatments=unique_treatments,
        treatment_col="treatment",
        control_name="control",
        control_learner="mocked_control_fcn",
        treatment_learner="mocked_treatment_fcn",
    )

    assert isinstance(learners, dict)
    assert isinstance(logs, dict)
    assert learners == {
        "control": "mocked_control_fcn",
        "treatment_a": "mocked_treatment_fcn_filtering_treatment_a",
        "treatment_b": "mocked_treatment_fcn_filtering_treatment_b",
        "treatment_c": "mocked_treatment_fcn_filtering_treatment_c",
    }
    assert set(logs) == {"control", *unique_treatments}

    calls = [
        call("mocked_df", "treatment", "control", "mocked_control_fcn"),
        call("mocked_df", "treatment", "treatment_a", "mocked_treatment_fcn"),
        call("mocked_df", "treatment", "treatment_b", "mocked_treatment_fcn"),
        call("mocked_df", "treatment", "treatment_c", "mocked_treatment_fcn"),
    ]

    mock_get_model_fcn.assert_has_calls(calls)
    assert mock_get_model_fcn.call_count == len(calls)


@patch(
    "fklearn.causal.cate_learning.meta_learners._simulate_t_learner_treatment_effect"
)
@patch("fklearn.causal.cate_learning.meta_learners._get_learners")
@patch("fklearn.causal.cate_learning.meta_learners._get_unique_treatments")
def test_causal_t_classification_learner(
    mock_get_unique_treatments,
    mock_get_learners,
    mock_simulate_t_learner_treatment_effect,
    base_input_df,
):
    """
    Test if the learner wires the helpers together and returns the
    (prediction function, scored data, logs) triple.
    """
    mock_get_learners.return_value = (dict(), dict())
    mock_model = create_autospec(logistic_classification_learner)

    predict_fn, scored_df, logs = causal_t_classification_learner(
        df=base_input_df,
        treatment_col="treatment",
        control_name="control",
        prediction_column="prediction",
        learner=mock_model,
    )

    mock_get_unique_treatments.assert_called_once()
    mock_get_learners.assert_called_once()
    # The training data is scored once while fitting.
    mock_simulate_t_learner_treatment_effect.assert_called_once()

    assert callable(predict_fn)
    assert scored_df is mock_simulate_t_learner_treatment_effect.return_value
    assert logs == {"causal_t_classification_learner": {}}