diff --git a/docs/source/api/fklearn.causal.cate_learning.rst b/docs/source/api/fklearn.causal.cate_learning.rst
new file mode 100644
index 00000000..35888232
--- /dev/null
+++ b/docs/source/api/fklearn.causal.cate_learning.rst
@@ -0,0 +1,22 @@
+fklearn.causal.cate\_learning package
+=====================================
+
+Submodules
+----------
+
+fklearn.causal.cate\_learning.double\_machine\_learning module
+--------------------------------------------------------------
+
+.. automodule:: fklearn.causal.cate_learning.double_machine_learning
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+
+Module contents
+---------------
+
+.. automodule:: fklearn.causal.cate_learning
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/notebooks/causal_s_learner_demo.ipynb b/notebooks/causal_s_learner_demo.ipynb
index ab017263..d5cb5b8f 100644
--- a/notebooks/causal_s_learner_demo.ipynb
+++ b/notebooks/causal_s_learner_demo.ipynb
@@ -2439,4 +2439,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
+}
\ No newline at end of file
diff --git a/notebooks/causal_t_learner_demo.ipynb b/notebooks/causal_t_learner_demo.ipynb
new file mode 100644
index 00000000..fe848b5e
--- /dev/null
+++ b/notebooks/causal_t_learner_demo.ipynb
@@ -0,0 +1,2434 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "f84aeac6",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T17:38:44.522833Z",
+ "start_time": "2022-08-01T17:38:44.516580Z"
+ }
+ },
+ "source": [
+ "## Causal T-Learner Classifier\n",
+ "\n",
+ "### TL;DR\n",
+ "---\n",
+ "This notebook demonstrates how to use the causal T-learner through fklearn.\n",
+ "\n",
+ "### Long\n",
+ "---\n",
+ "Uplift models are a useful class of causal models: they estimate how each sample responds to a given treatment and how large that effect is relative to a control group. Their main goal is therefore to learn the difference in the probability that a sample converts (for example, starts using a product) when it is subjected to some action (a nudge) versus when it is not. Meta-learners are one family of such models; they estimate the Conditional Average Treatment Effect (CATE), which represents how each unit will respond to a given treatment [1]. The uplift can thus be understood as the incremental gain in conversion probability if a sample were in the treatment group instead of the control group. Meta-learners also have the advantage of being built on conventional machine learning models, such as LightGBM.\n",
+ "\n",
+ "More specifically, the T-Learner is a meta-learner that estimates the Conditional Average Treatment Effect (CATE) using multiple models, one for each treatment plus one for the control group. Each model is fitted on the subset of the data that received the corresponding treatment. The CATE $\\tau$ is defined as $\\tau(x_{i}) = M_{1}(X=x_{i}, T=1) - M_{0}(X=x_{i}, T=0)$, where $M_{1}$ is a model fitted on treatment data, $M_{0}$ is a model fitted on control data, and $x_{i}$ is the feature set of sample $i$. Both can be any machine learning model, such as a LightGBM classifier.\n",
+ "\n",
+ "### Data\n",
+ "---\n",
+ "The data used here is provided in [1] and [2].\n",
+ "\n",
+ "### References\n",
+ "---\n",
+ "\n",
+ "[1] https://matheusfacure.github.io/python-causality-handbook/21-Meta-Learners.html\n",
+ "\n",
+ "[2] https://github.com/matheusfacure/python-causality-handbook/tree/master/causal-inference-for-the-brave-and-true/data\n",
+ "\n",
+ "[3] https://causalml.readthedocs.io/en/latest/methodology.html"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "651c52e1",
+ "metadata": {},
+ "source": [
+ "### 1. Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "209dd195",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%reload_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "b9ae70be",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "909075ef",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'/Users/eduardo.souza/dev/nu/fklearn/src'"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "os.chdir(\"../src\")\n",
+ "os.getcwd()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "334c24ac",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:54.942876Z",
+ "start_time": "2022-08-01T21:08:52.183053Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "\n",
+ "from fklearn.causal.cate_learning.meta_learners import causal_t_classification_learner\n",
+ "from fklearn.training.classification import lgbm_classification_learner\n",
+ "from fklearn.training.calibration import isotonic_calibration_learner\n",
+ "from fklearn.causal.validation.curves import cumulative_gain_curve\n",
+ "from fklearn.training.pipeline import build_pipeline\n",
+ "from fklearn.training.transformation import ecdfer\n",
+ "\n",
+ "sns.set_style(\"darkgrid\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7951044d",
+ "metadata": {},
+ "source": [
+ "### 2. Functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "3b3820cf",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:54.949172Z",
+ "start_time": "2022-08-01T21:08:54.944256Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plot_cumulative_gain_curve(\n",
+ " train_gain: np.ndarray,\n",
+ " test_gain: np.ndarray,\n",
+ " random_gain: np.ndarray,\n",
+ " fontsize: int = 16,\n",
+ " figsize: tuple = (15,5)\n",
+ ") -> None:\n",
+ " \"\"\"\n",
+ "    Plots the cumulative gain curves of the training, testing, and random scores.\n",
+ " \"\"\"\n",
+ "\n",
+ " xaxis = np.arange(len(train_gain))/len(train_gain)\n",
+ " \n",
+ " plt.figure(figsize=figsize);\n",
+ " plt.plot(xaxis, train_gain, label=\"Training Data\");\n",
+ " plt.plot(xaxis, test_gain, label=\"Testing Data\");\n",
+ " plt.plot(xaxis, random_gain, \"--\", label=\"Random\");\n",
+ " \n",
+ " plt.ylabel(\"Cumulative Gain\", fontsize=fontsize);\n",
+ " plt.xlabel(\"Population Proportion\", fontsize=fontsize);\n",
+ " plt.title(\"Cumulative Gain Curve\", fontsize=fontsize);\n",
+ " \n",
+ " plt.legend(fontsize=fontsize);\n",
+ " plt.xticks(fontsize=fontsize);\n",
+ " plt.yticks(fontsize=fontsize);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1bb4b598",
+ "metadata": {},
+ "source": [
+ "### 3. Read Data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "99ced58c",
+ "metadata": {},
+ "source": [
+ "Note that the data used here is provided in [1] and [2]."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "164af8a7",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:55.654415Z",
+ "start_time": "2022-08-01T21:08:54.952419Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "test_data = pd.read_csv(\n",
+ " \"https://raw.githubusercontent.com/matheusfacure/python-causality-handbook/master/causal-inference-for-the-brave-and-true/data/invest_email_rnd.csv\"\n",
+ ")\n",
+ "train_data = pd.read_csv(\n",
+ " \"https://raw.githubusercontent.com/matheusfacure/python-causality-handbook/master/causal-inference-for-the-brave-and-true/data/invest_email_biased.csv\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "id": "c4100764",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:55.665495Z",
+ "start_time": "2022-08-01T21:08:55.656451Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(15000, 8)"
+ ]
+ },
+ "execution_count": 52,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_data.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "id": "9dbc21fd",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:55.671480Z",
+ "start_time": "2022-08-01T21:08:55.667548Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(15000, 8)"
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_data.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bfba7573",
+ "metadata": {},
+ "source": [
+ "#### 3.1 Include Treatment Column"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "bed6542e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:55.684291Z",
+ "start_time": "2022-08-01T21:08:55.673873Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "train_data[\"control\"] = np.where((train_data[\"em1\"]+train_data[\"em2\"]+train_data[\"em3\"])==0,1,0)\n",
+ "train_data[\"treatment_col\"] = train_data[[\"em1\",\"em2\",\"em3\",\"control\"]].idxmax(axis=1).values\n",
+ "\n",
+ "test_data[\"control\"] = np.where((test_data[\"em1\"]+test_data[\"em2\"]+test_data[\"em3\"])==0,1,0)\n",
+ "test_data[\"treatment_col\"] = test_data[[\"em1\",\"em2\",\"em3\",\"control\"]].idxmax(axis=1).values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "b951bc0c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:55.701127Z",
+ "start_time": "2022-08-01T21:08:55.686721Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<table>\n",
+ "  <thead>\n",
+ "    <tr><th></th><th>age</th><th>income</th><th>insurance</th><th>invested</th><th>em1</th><th>em2</th><th>em3</th><th>converted</th><th>control</th><th>treatment_col</th></tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr><th>0</th><td>44.1</td><td>5483.80</td><td>6155.29</td><td>14294.81</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td><td>em3</td></tr>\n",
+ "    <tr><th>1</th><td>39.8</td><td>2737.92</td><td>50069.40</td><td>7468.15</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0</td><td>em1</td></tr>\n",
+ "    <tr><th>2</th><td>49.0</td><td>2712.51</td><td>5707.08</td><td>5095.65</td><td>0</td><td>0</td><td>1</td><td>1</td><td>0</td><td>em3</td></tr>\n",
+ "    <tr><th>3</th><td>39.7</td><td>2326.37</td><td>15657.97</td><td>6345.20</td><td>0</td><td>0</td><td>0</td><td>0</td><td>1</td><td>control</td></tr>\n",
+ "    <tr><th>4</th><td>35.3</td><td>2787.26</td><td>27074.44</td><td>14114.86</td><td>1</td><td>1</td><td>0</td><td>0</td><td>0</td><td>em1</td></tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " age income insurance invested em1 em2 em3 converted control \\\n",
+ "0 44.1 5483.80 6155.29 14294.81 0 0 1 0 0 \n",
+ "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n",
+ "2 49.0 2712.51 5707.08 5095.65 0 0 1 1 0 \n",
+ "3 39.7 2326.37 15657.97 6345.20 0 0 0 0 1 \n",
+ "4 35.3 2787.26 27074.44 14114.86 1 1 0 0 0 \n",
+ "\n",
+ " treatment_col \n",
+ "0 em3 \n",
+ "1 em1 \n",
+ "2 em3 \n",
+ "3 control \n",
+ "4 em1 "
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_data.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "2dc0eb3e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:55.717300Z",
+ "start_time": "2022-08-01T21:08:55.703374Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<table>\n",
+ "  <thead>\n",
+ "    <tr><th></th><th>age</th><th>income</th><th>insurance</th><th>invested</th><th>em1</th><th>em2</th><th>em3</th><th>converted</th><th>control</th><th>treatment_col</th></tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr><th>0</th><td>44.1</td><td>5483.80</td><td>6155.29</td><td>14294.81</td><td>0</td><td>1</td><td>1</td><td>0</td><td>0</td><td>em2</td></tr>\n",
+ "    <tr><th>1</th><td>39.8</td><td>2737.92</td><td>50069.40</td><td>7468.15</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0</td><td>em1</td></tr>\n",
+ "    <tr><th>2</th><td>49.0</td><td>2712.51</td><td>5707.08</td><td>5095.65</td><td>1</td><td>0</td><td>1</td><td>1</td><td>0</td><td>em1</td></tr>\n",
+ "    <tr><th>3</th><td>39.7</td><td>2326.37</td><td>15657.97</td><td>6345.20</td><td>1</td><td>1</td><td>1</td><td>0</td><td>0</td><td>em1</td></tr>\n",
+ "    <tr><th>4</th><td>35.3</td><td>2787.26</td><td>27074.44</td><td>14114.86</td><td>1</td><td>1</td><td>1</td><td>0</td><td>0</td><td>em1</td></tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " age income insurance invested em1 em2 em3 converted control \\\n",
+ "0 44.1 5483.80 6155.29 14294.81 0 1 1 0 0 \n",
+ "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n",
+ "2 49.0 2712.51 5707.08 5095.65 1 0 1 1 0 \n",
+ "3 39.7 2326.37 15657.97 6345.20 1 1 1 0 0 \n",
+ "4 35.3 2787.26 27074.44 14114.86 1 1 1 0 0 \n",
+ "\n",
+ " treatment_col \n",
+ "0 em2 \n",
+ "1 em1 \n",
+ "2 em1 \n",
+ "3 em1 \n",
+ "4 em1 "
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_data.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b02dd80c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T17:34:49.478324Z",
+ "start_time": "2022-08-01T17:34:49.472616Z"
+ }
+ },
+ "source": [
+ "### 4. Causal T-Learner"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "f2e41cb9",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:55.727534Z",
+ "start_time": "2022-08-01T21:08:55.724439Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "target_column = \"converted\"\n",
+ "features = [\"age\", \"income\", \"insurance\", \"invested\"]\n",
+ "treatment_column = \"treatment_col\"\n",
+ "control_name = \"control\"\n",
+ "prediction_column = \"prediction\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f13a389d",
+ "metadata": {},
+ "source": [
+ "#### 4.1 Using T-Learner with LightGBM"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "86154394",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:55.732761Z",
+ "start_time": "2022-08-01T21:08:55.729735Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "clf_learner = lgbm_classification_learner(\n",
+ " features = features,\n",
+ " target = target_column,\n",
+ " prediction_column = prediction_column\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "885a539f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:55.738793Z",
+ "start_time": "2022-08-01T21:08:55.735586Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "t_learner = causal_t_classification_learner(\n",
+ " treatment_col=treatment_column,\n",
+ " control_name=control_name,\n",
+ " prediction_column=prediction_column,\n",
+ " learner=clf_learner,\n",
+ " treatment_learner=clf_learner,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e2d4cd23",
+ "metadata": {},
+ "source": [
+ "**Training the model**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "68c21ad5",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.027747Z",
+ "start_time": "2022-08-01T21:08:55.740932Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[LightGBM] [Info] Number of positive: 504, number of negative: 3808\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000434 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 963\n",
+ "[LightGBM] [Info] Number of data points in the train set: 4312, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.116883 -> initscore=-2.022283\n",
+ "[LightGBM] [Info] Start training from score -2.022283\n",
+ "[LightGBM] [Info] Number of positive: 1392, number of negative: 5088\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000334 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 898\n",
+ "[LightGBM] [Info] Number of data points in the train set: 6480, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.214815 -> initscore=-1.296143\n",
+ "[LightGBM] [Info] Start training from score -1.296143\n",
+ "[LightGBM] [Info] Number of positive: 802, number of negative: 2568\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000364 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 958\n",
+ "[LightGBM] [Info] Number of data points in the train set: 3370, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.237982 -> initscore=-1.163774\n",
+ "[LightGBM] [Info] Start training from score -1.163774\n",
+ "[LightGBM] [Info] Number of positive: 309, number of negative: 529\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000133 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 906\n",
+ "[LightGBM] [Info] Number of data points in the train set: 838, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.368735 -> initscore=-0.537647\n",
+ "[LightGBM] [Info] Start training from score -0.537647\n",
+ "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n"
+ ]
+ }
+ ],
+ "source": [
+ "t_learner_fcn, t_learner_train_df, t_learner_log = t_learner(train_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "629ee628",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.034377Z",
+ "start_time": "2022-08-01T21:08:56.030279Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ".p(new_df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame>"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t_learner_fcn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "bb863168",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.051585Z",
+ "start_time": "2022-08-01T21:08:56.036919Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<table>\n",
+ "  <thead>\n",
+ "    <tr><th></th><th>age</th><th>income</th><th>insurance</th><th>invested</th><th>em1</th><th>em2</th><th>em3</th><th>converted</th><th>control</th><th>treatment_col</th><th>treatment_em3__prediction_on_treatment</th><th>treatment_em3__uplift</th><th>treatment_em1__prediction_on_treatment</th><th>treatment_em1__uplift</th><th>treatment_em2__prediction_on_treatment</th><th>treatment_em2__uplift</th><th>uplift</th><th>suggested_treatment</th></tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr><th>0</th><td>44.1</td><td>5483.80</td><td>6155.29</td><td>14294.81</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td><td>em3</td><td>0.259551</td><td>0.079825</td><td>0.447992</td><td>0.268265</td><td>0.308317</td><td>0.128590</td><td>0.268265</td><td>treatment_em1</td></tr>\n",
+ "    <tr><th>1</th><td>39.8</td><td>2737.92</td><td>50069.40</td><td>7468.15</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0</td><td>em1</td><td>0.091645</td><td>-0.003786</td><td>0.063676</td><td>-0.031755</td><td>0.015090</td><td>-0.080341</td><td>-0.003786</td><td>control</td></tr>\n",
+ "    <tr><th>2</th><td>49.0</td><td>2712.51</td><td>5707.08</td><td>5095.65</td><td>0</td><td>0</td><td>1</td><td>1</td><td>0</td><td>em3</td><td>0.464448</td><td>0.418073</td><td>0.682684</td><td>0.636309</td><td>0.061211</td><td>0.014836</td><td>0.636309</td><td>treatment_em1</td></tr>\n",
+ "    <tr><th>3</th><td>39.7</td><td>2326.37</td><td>15657.97</td><td>6345.20</td><td>0</td><td>0</td><td>0</td><td>0</td><td>1</td><td>control</td><td>0.086145</td><td>0.036714</td><td>0.256501</td><td>0.207071</td><td>0.196768</td><td>0.147338</td><td>0.207071</td><td>treatment_em1</td></tr>\n",
+ "    <tr><th>4</th><td>35.3</td><td>2787.26</td><td>27074.44</td><td>14114.86</td><td>1</td><td>1</td><td>0</td><td>0</td><td>0</td><td>em1</td><td>0.149771</td><td>0.140358</td><td>0.201414</td><td>0.192001</td><td>0.082254</td><td>0.072842</td><td>0.192001</td><td>treatment_em1</td></tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " age income insurance invested em1 em2 em3 converted control \\\n",
+ "0 44.1 5483.80 6155.29 14294.81 0 0 1 0 0 \n",
+ "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n",
+ "2 49.0 2712.51 5707.08 5095.65 0 0 1 1 0 \n",
+ "3 39.7 2326.37 15657.97 6345.20 0 0 0 0 1 \n",
+ "4 35.3 2787.26 27074.44 14114.86 1 1 0 0 0 \n",
+ "\n",
+ " treatment_col treatment_em3__prediction_on_treatment \\\n",
+ "0 em3 0.259551 \n",
+ "1 em1 0.091645 \n",
+ "2 em3 0.464448 \n",
+ "3 control 0.086145 \n",
+ "4 em1 0.149771 \n",
+ "\n",
+ " treatment_em3__uplift treatment_em1__prediction_on_treatment \\\n",
+ "0 0.079825 0.447992 \n",
+ "1 -0.003786 0.063676 \n",
+ "2 0.418073 0.682684 \n",
+ "3 0.036714 0.256501 \n",
+ "4 0.140358 0.201414 \n",
+ "\n",
+ " treatment_em1__uplift treatment_em2__prediction_on_treatment \\\n",
+ "0 0.268265 0.308317 \n",
+ "1 -0.031755 0.015090 \n",
+ "2 0.636309 0.061211 \n",
+ "3 0.207071 0.196768 \n",
+ "4 0.192001 0.082254 \n",
+ "\n",
+ " treatment_em2__uplift uplift suggested_treatment \n",
+ "0 0.128590 0.268265 treatment_em1 \n",
+ "1 -0.080341 -0.003786 control \n",
+ "2 0.014836 0.636309 treatment_em1 \n",
+ "3 0.147338 0.207071 treatment_em1 \n",
+ "4 0.072842 0.192001 treatment_em1 "
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t_learner_train_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "id": "ae0405f3",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.059283Z",
+ "start_time": "2022-08-01T21:08:56.054044Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'causal_t_classification_learner': {'control': {'lgbm_classification_learner': {'features': ['age',\n",
+ " 'income',\n",
+ " 'insurance',\n",
+ " 'invested'],\n",
+ " 'target': 'converted',\n",
+ " 'prediction_column': 'prediction',\n",
+ " 'package': 'lightgbm',\n",
+ " 'package_version': '3.3.2',\n",
+ " 'parameters': {'eta': 0.1, 'objective': 'binary', 'num_estimators': 100},\n",
+ " 'feature_importance': {'age': 739,\n",
+ " 'income': 811,\n",
+ " 'insurance': 724,\n",
+ " 'invested': 726},\n",
+ " 'training_samples': 4312,\n",
+ " 'running_time': '0.087 s'},\n",
+ " 'object': },\n",
+ " 'em3': {'lgbm_classification_learner': {'features': ['age',\n",
+ " 'income',\n",
+ " 'insurance',\n",
+ " 'invested'],\n",
+ " 'target': 'converted',\n",
+ " 'prediction_column': 'prediction',\n",
+ " 'package': 'lightgbm',\n",
+ " 'package_version': '3.3.2',\n",
+ " 'parameters': {'eta': 0.1, 'objective': 'binary', 'num_estimators': 100},\n",
+ " 'feature_importance': {'age': 627,\n",
+ " 'income': 772,\n",
+ " 'insurance': 816,\n",
+ " 'invested': 785},\n",
+ " 'training_samples': 6480,\n",
+ " 'running_time': '0.064 s'},\n",
+ " 'object': },\n",
+ " 'em1': {'lgbm_classification_learner': {'features': ['age',\n",
+ " 'income',\n",
+ " 'insurance',\n",
+ " 'invested'],\n",
+ " 'target': 'converted',\n",
+ " 'prediction_column': 'prediction',\n",
+ " 'package': 'lightgbm',\n",
+ " 'package_version': '3.3.2',\n",
+ " 'parameters': {'eta': 0.1, 'objective': 'binary', 'num_estimators': 100},\n",
+ " 'feature_importance': {'age': 769,\n",
+ " 'income': 747,\n",
+ " 'insurance': 748,\n",
+ " 'invested': 736},\n",
+ " 'training_samples': 3370,\n",
+ " 'running_time': '0.051 s'},\n",
+ " 'object': },\n",
+ " 'em2': {'lgbm_classification_learner': {'features': ['age',\n",
+ " 'income',\n",
+ " 'insurance',\n",
+ " 'invested'],\n",
+ " 'target': 'converted',\n",
+ " 'prediction_column': 'prediction',\n",
+ " 'package': 'lightgbm',\n",
+ " 'package_version': '3.3.2',\n",
+ " 'parameters': {'eta': 0.1, 'objective': 'binary', 'num_estimators': 100},\n",
+ " 'feature_importance': {'age': 618,\n",
+ " 'income': 745,\n",
+ " 'insurance': 801,\n",
+ " 'invested': 835},\n",
+ " 'training_samples': 838,\n",
+ " 'running_time': '0.085 s'},\n",
+ " 'object': }}}"
+ ]
+ },
+ "execution_count": 63,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t_learner_log"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "2ed58ac7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "score_of_interest = \"treatment_em1__uplift\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dadfd813",
+ "metadata": {},
+ "source": [
+ "**Making Predictions**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "19c87675",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.138488Z",
+ "start_time": "2022-08-01T21:08:56.061810Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "t_learner_test_df = t_learner_fcn(test_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "d992f3c8",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.155056Z",
+ "start_time": "2022-08-01T21:08:56.141081Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<table>\n",
+ "  <thead>\n",
+ "    <tr><th></th><th>age</th><th>income</th><th>insurance</th><th>invested</th><th>em1</th><th>em2</th><th>em3</th><th>converted</th><th>control</th><th>treatment_col</th><th>treatment_em3__prediction_on_treatment</th><th>treatment_em3__uplift</th><th>treatment_em1__prediction_on_treatment</th><th>treatment_em1__uplift</th><th>treatment_em2__prediction_on_treatment</th><th>treatment_em2__uplift</th><th>uplift</th><th>suggested_treatment</th></tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr><th>0</th><td>44.1</td><td>5483.80</td><td>6155.29</td><td>14294.81</td><td>0</td><td>1</td><td>1</td><td>0</td><td>0</td><td>em2</td><td>0.259551</td><td>0.079825</td><td>0.447992</td><td>0.268265</td><td>0.308317</td><td>0.128590</td><td>0.268265</td><td>treatment_em1</td></tr>\n",
+ "    <tr><th>1</th><td>39.8</td><td>2737.92</td><td>50069.40</td><td>7468.15</td><td>1</td><td>0</td><td>0</td><td>0</td><td>0</td><td>em1</td><td>0.091645</td><td>-0.003786</td><td>0.063676</td><td>-0.031755</td><td>0.015090</td><td>-0.080341</td><td>-0.003786</td><td>control</td></tr>\n",
+ "    <tr><th>2</th><td>49.0</td><td>2712.51</td><td>5707.08</td><td>5095.65</td><td>1</td><td>0</td><td>1</td><td>1</td><td>0</td><td>em1</td><td>0.464448</td><td>0.418073</td><td>0.682684</td><td>0.636309</td><td>0.061211</td><td>0.014836</td><td>0.636309</td><td>treatment_em1</td></tr>\n",
+ "    <tr><th>3</th><td>39.7</td><td>2326.37</td><td>15657.97</td><td>6345.20</td><td>1</td><td>1</td><td>1</td><td>0</td><td>0</td><td>em1</td><td>0.086145</td><td>0.036714</td><td>0.256501</td><td>0.207071</td><td>0.196768</td><td>0.147338</td><td>0.207071</td><td>treatment_em1</td></tr>\n",
+ "    <tr><th>4</th><td>35.3</td><td>2787.26</td><td>27074.44</td><td>14114.86</td><td>1</td><td>1</td><td>1</td><td>0</td><td>0</td><td>em1</td><td>0.149771</td><td>0.140358</td><td>0.201414</td><td>0.192001</td><td>0.082254</td><td>0.072842</td><td>0.192001</td><td>treatment_em1</td></tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " age income insurance invested em1 em2 em3 converted control \\\n",
+ "0 44.1 5483.80 6155.29 14294.81 0 1 1 0 0 \n",
+ "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n",
+ "2 49.0 2712.51 5707.08 5095.65 1 0 1 1 0 \n",
+ "3 39.7 2326.37 15657.97 6345.20 1 1 1 0 0 \n",
+ "4 35.3 2787.26 27074.44 14114.86 1 1 1 0 0 \n",
+ "\n",
+ " treatment_col treatment_em3__prediction_on_treatment \\\n",
+ "0 em2 0.259551 \n",
+ "1 em1 0.091645 \n",
+ "2 em1 0.464448 \n",
+ "3 em1 0.086145 \n",
+ "4 em1 0.149771 \n",
+ "\n",
+ " treatment_em3__uplift treatment_em1__prediction_on_treatment \\\n",
+ "0 0.079825 0.447992 \n",
+ "1 -0.003786 0.063676 \n",
+ "2 0.418073 0.682684 \n",
+ "3 0.036714 0.256501 \n",
+ "4 0.140358 0.201414 \n",
+ "\n",
+ " treatment_em1__uplift treatment_em2__prediction_on_treatment \\\n",
+ "0 0.268265 0.308317 \n",
+ "1 -0.031755 0.015090 \n",
+ "2 0.636309 0.061211 \n",
+ "3 0.207071 0.196768 \n",
+ "4 0.192001 0.082254 \n",
+ "\n",
+ " treatment_em2__uplift uplift suggested_treatment \n",
+ "0 0.128590 0.268265 treatment_em1 \n",
+ "1 -0.080341 -0.003786 control \n",
+ "2 0.014836 0.636309 treatment_em1 \n",
+ "3 0.147338 0.207071 treatment_em1 \n",
+ "4 0.072842 0.192001 treatment_em1 "
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t_learner_test_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "60a1bfb9",
+ "metadata": {},
+ "source": [
+ "**Create Random Score**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "196a3a23",
+ "metadata": {},
+ "source": [
+ "Let's also create a random score that can be used to compute the Cumulative Gain curve."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "61728dd6",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.162206Z",
+ "start_time": "2022-08-01T21:08:56.157333Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "random_score_df = test_data[[\"em1\", target_column]].copy()\n",
+ "random_score_df[score_of_interest] = np.random.uniform(0,1,random_score_df.shape[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b5b65182",
+ "metadata": {},
+ "source": [
+ "**Checking Cumulative Gain Curve**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c18d8dc7",
+ "metadata": {},
+ "source": [
+ "For more details on evaluating causal models, see the following reference:\n",
+ "\n",
+ "https://matheusfacure.github.io/python-causality-handbook/19-Evaluating-Causal-Models.html?highlight=gain%20curve "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "dada0b5c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.167642Z",
+ "start_time": "2022-08-01T21:08:56.164515Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "gain_curve = cumulative_gain_curve(\n",
+ " treatment = \"em1\",\n",
+ " outcome = target_column,\n",
+ " prediction = score_of_interest\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "44ac0567",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.479996Z",
+ "start_time": "2022-08-01T21:08:56.170218Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "gain_curve_train = gain_curve(t_learner_train_df)\n",
+ "gain_curve_test = gain_curve(t_learner_test_df)\n",
+ "gain_curve_random = gain_curve(random_score_df)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "e0f39cfc",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.695829Z",
+ "start_time": "2022-08-01T21:08:56.481931Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plot_cumulative_gain_curve(\n",
+ " train_gain = gain_curve_train,\n",
+ " test_gain = gain_curve_test,\n",
+ " random_gain = gain_curve_random\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "16b67daa",
+ "metadata": {},
+ "source": [
+ "#### 4.2 Using T-Learner with Fklearn Pipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "5adc2f2f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.701600Z",
+ "start_time": "2022-08-01T21:08:56.698546Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "cdf = ecdfer(\n",
+ " prediction_column=score_of_interest\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "a995e275",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.705705Z",
+ "start_time": "2022-08-01T21:08:56.703511Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "pipeline = build_pipeline(\n",
+ " *[t_learner, cdf]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "21db9ef8",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.885605Z",
+ "start_time": "2022-08-01T21:08:56.707835Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[LightGBM] [Info] Number of positive: 504, number of negative: 3808\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000739 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 963\n",
+ "[LightGBM] [Info] Number of data points in the train set: 4312, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.116883 -> initscore=-2.022283\n",
+ "[LightGBM] [Info] Start training from score -2.022283\n",
+ "[LightGBM] [Info] Number of positive: 1392, number of negative: 5088\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000108 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 898\n",
+ "[LightGBM] [Info] Number of data points in the train set: 6480, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.214815 -> initscore=-1.296143\n",
+ "[LightGBM] [Info] Start training from score -1.296143\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[LightGBM] [Info] Number of positive: 802, number of negative: 2568\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000157 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 958\n",
+ "[LightGBM] [Info] Number of data points in the train set: 3370, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.237982 -> initscore=-1.163774\n",
+ "[LightGBM] [Info] Start training from score -1.163774\n",
+ "[LightGBM] [Info] Number of positive: 309, number of negative: 529\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000331 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 906\n",
+ "[LightGBM] [Info] Number of data points in the train set: 838, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.368735 -> initscore=-0.537647\n",
+ "[LightGBM] [Info] Start training from score -0.537647\n",
+ "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipe_fcn, pipe_train_df, pipe_log = pipeline(\n",
+ " train_data\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "08b227a5",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:56.901902Z",
+ "start_time": "2022-08-01T21:08:56.887944Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " age income insurance invested em1 em2 em3 converted control \\\n",
+ "0 44.1 5483.80 6155.29 14294.81 0 0 1 0 0 \n",
+ "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n",
+ "2 49.0 2712.51 5707.08 5095.65 0 0 1 1 0 \n",
+ "3 39.7 2326.37 15657.97 6345.20 0 0 0 0 1 \n",
+ "4 35.3 2787.26 27074.44 14114.86 1 1 0 0 0 \n",
+ "\n",
+ " treatment_col treatment_em3__prediction_on_treatment \\\n",
+ "0 em3 0.259551 \n",
+ "1 em1 0.091645 \n",
+ "2 em3 0.464448 \n",
+ "3 control 0.086145 \n",
+ "4 em1 0.149771 \n",
+ "\n",
+ " treatment_em3__uplift treatment_em1__prediction_on_treatment \\\n",
+ "0 0.079825 0.447992 \n",
+ "1 -0.003786 0.063676 \n",
+ "2 0.418073 0.682684 \n",
+ "3 0.036714 0.256501 \n",
+ "4 0.140358 0.201414 \n",
+ "\n",
+ " treatment_em1__uplift treatment_em2__prediction_on_treatment \\\n",
+ "0 0.268265 0.308317 \n",
+ "1 -0.031755 0.015090 \n",
+ "2 0.636309 0.061211 \n",
+ "3 0.207071 0.196768 \n",
+ "4 0.192001 0.082254 \n",
+ "\n",
+ " treatment_em2__uplift uplift suggested_treatment prediction_ecdf \n",
+ "0 0.128590 0.268265 treatment_em1 754.266667 \n",
+ "1 -0.080341 -0.003786 control 110.066667 \n",
+ "2 0.014836 0.636309 treatment_em1 989.133333 \n",
+ "3 0.147338 0.207071 treatment_em1 653.533333 \n",
+ "4 0.072842 0.192001 treatment_em1 622.666667 "
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pipe_train_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 77,
+ "id": "3f0ae64f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.152815Z",
+ "start_time": "2022-08-01T21:08:56.904019Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "pipe_train_df[\"uplift\"].hist(bins=50);\n",
+ "plt.title(\"Uplift Score Distribution\", fontsize=16);\n",
+ "plt.xticks(fontsize=16);\n",
+ "plt.yticks(fontsize=16);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "80873d8e",
+ "metadata": {},
+ "source": [
+ "**Checking Gain Curve with ECDF**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 79,
+ "id": "304ce7c2",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.405122Z",
+ "start_time": "2022-08-01T21:08:57.402884Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "gain_curve = cumulative_gain_curve(\n",
+ " treatment = \"em1\",\n",
+ " outcome = target_column,\n",
+ " prediction = \"prediction_ecdf\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "id": "875df4da",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.536943Z",
+ "start_time": "2022-08-01T21:08:57.406693Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "pipeline_train_df = pipe_fcn(train_data)\n",
+ "pipeline_test_df = pipe_fcn(test_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 81,
+ "id": "ed3ee7e7",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.747225Z",
+ "start_time": "2022-08-01T21:08:57.539511Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "gain_curve_train = gain_curve(pipeline_train_df)\n",
+ "gain_curve_test = gain_curve(pipeline_test_df)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "id": "af2707ed",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.939012Z",
+ "start_time": "2022-08-01T21:08:57.749300Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plot_cumulative_gain_curve(\n",
+ " train_gain = gain_curve_train,\n",
+ " test_gain = gain_curve_test,\n",
+ " random_gain = gain_curve_random\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7ee410fb",
+ "metadata": {},
+ "source": [
+    "#### 4.3 Build T-Learner Using LightGBM with Isotonic Calibration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "ea34b176",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.944825Z",
+ "start_time": "2022-08-01T21:08:57.941394Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "clf_learner = lgbm_classification_learner(\n",
+ " features = features,\n",
+ " target = target_column,\n",
+ " prediction_column = prediction_column\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "f328d84c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.950805Z",
+ "start_time": "2022-08-01T21:08:57.947489Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "calibrator = isotonic_calibration_learner(\n",
+ " target_column=target_column,\n",
+ " prediction_column=prediction_column,\n",
+ " output_column=\"calibration_prediction\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "9abd3945",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.955102Z",
+ "start_time": "2022-08-01T21:08:57.952772Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "t_learner = causal_t_classification_learner(\n",
+ " treatment_col=treatment_column,\n",
+ " control_name=control_name,\n",
+ " prediction_column=\"calibration_prediction\",\n",
+ " learner=clf_learner,\n",
+ " learner_transformers=[calibrator]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "47630533",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.959571Z",
+ "start_time": "2022-08-01T21:08:57.957001Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "cdf = ecdfer(\n",
+ " prediction_column=score_of_interest\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "9dfb55d3",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:57.964743Z",
+ "start_time": "2022-08-01T21:08:57.961573Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "pipeline = build_pipeline(\n",
+ " *[t_learner, cdf]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "bb117302",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:58.188382Z",
+ "start_time": "2022-08-01T21:08:57.968255Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[LightGBM] [Info] Number of positive: 504, number of negative: 3808\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000469 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 963\n",
+ "[LightGBM] [Info] Number of data points in the train set: 4312, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.116883 -> initscore=-2.022283\n",
+ "[LightGBM] [Info] Start training from score -2.022283\n",
+ "[LightGBM] [Info] Number of positive: 1392, number of negative: 5088\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000172 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 898\n",
+ "[LightGBM] [Info] Number of data points in the train set: 6480, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.214815 -> initscore=-1.296143\n",
+ "[LightGBM] [Info] Start training from score -1.296143\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
+ "/Users/eduardo.souza/dev/nu/fklearn/venv/lib/python3.9/site-packages/lightgbm/basic.py:1491: UserWarning: 'silent' argument is deprecated and will be removed in a future release of LightGBM. Pass 'verbose' parameter via 'params' instead.\n",
+ " _log_warning(\"'silent' argument is deprecated and will be removed in a future release of LightGBM. \"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[LightGBM] [Info] Number of positive: 802, number of negative: 2568\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000512 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 958\n",
+ "[LightGBM] [Info] Number of data points in the train set: 3370, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.237982 -> initscore=-1.163774\n",
+ "[LightGBM] [Info] Start training from score -1.163774\n",
+ "[LightGBM] [Info] Number of positive: 309, number of negative: 529\n",
+ "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000282 seconds.\n",
+ "You can set `force_col_wise=true` to remove the overhead.\n",
+ "[LightGBM] [Info] Total Bins 906\n",
+ "[LightGBM] [Info] Number of data points in the train set: 838, number of used features: 4\n",
+ "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.368735 -> initscore=-0.537647\n",
+ "[LightGBM] [Info] Start training from score -0.537647\n",
+ "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipe_fcn, pipe_train_df, pipe_log = pipeline(\n",
+ " train_data\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "0a80c08c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:58.205741Z",
+ "start_time": "2022-08-01T21:08:58.190857Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " age income insurance invested em1 em2 em3 converted control \\\n",
+ "0 44.1 5483.80 6155.29 14294.81 0 0 1 0 0 \n",
+ "1 39.8 2737.92 50069.40 7468.15 1 0 0 0 0 \n",
+ "2 49.0 2712.51 5707.08 5095.65 0 0 1 1 0 \n",
+ "3 39.7 2326.37 15657.97 6345.20 0 0 0 0 1 \n",
+ "4 35.3 2787.26 27074.44 14114.86 1 1 0 0 0 \n",
+ "\n",
+ " treatment_col treatment_em3__calibration_prediction_on_treatment \\\n",
+ "0 em3 0.338983 \n",
+ "1 em1 0.000000 \n",
+ "2 em3 0.883333 \n",
+ "3 control 0.000000 \n",
+ "4 em1 0.010050 \n",
+ "\n",
+ " treatment_em3__uplift treatment_em1__calibration_prediction_on_treatment \\\n",
+ "0 0.200429 0.820755 \n",
+ "1 -0.001727 0.000000 \n",
+ "2 0.883333 0.982558 \n",
+ "3 0.000000 0.140625 \n",
+ "4 0.010050 0.033951 \n",
+ "\n",
+ " treatment_em1__uplift treatment_em2__calibration_prediction_on_treatment \\\n",
+ "0 0.682201 0.0 \n",
+ "1 -0.001727 0.0 \n",
+ "2 0.982558 0.0 \n",
+ "3 0.140625 0.0 \n",
+ "4 0.033951 0.0 \n",
+ "\n",
+ " treatment_em2__uplift uplift suggested_treatment prediction_ecdf \n",
+ "0 -0.138554 0.682201 treatment_em1 813.133333 \n",
+ "1 -0.001727 -0.001727 control 225.400000 \n",
+ "2 0.000000 0.982558 treatment_em1 997.333333 \n",
+ "3 0.000000 0.140625 treatment_em1 606.133333 \n",
+ "4 0.000000 0.033951 treatment_em1 518.400000 "
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pipe_train_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 100,
+ "id": "65693eaa",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:58.442712Z",
+ "start_time": "2022-08-01T21:08:58.208095Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "pipe_train_df[\"treatment_em1__calibration_prediction_on_treatment\"].hist(bins=10);\n",
+    "plt.title(\"Calibrated LightGBM Prediction Distribution (em1)\", fontsize=16);\n",
+ "plt.xticks(fontsize=16);\n",
+ "plt.yticks(fontsize=16);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bcf7861e",
+ "metadata": {},
+ "source": [
+ "**Checking Gain Curve with ECDF**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "id": "08cfa1e4",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:58.613042Z",
+ "start_time": "2022-08-01T21:08:58.610279Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "gain_curve = cumulative_gain_curve(\n",
+ " treatment = \"em1\",\n",
+ " outcome = target_column,\n",
+ " prediction = \"prediction_ecdf\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "id": "e996c832",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:58.755011Z",
+ "start_time": "2022-08-01T21:08:58.614934Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "pipeline_train_df = pipe_fcn(train_data)\n",
+ "pipeline_test_df = pipe_fcn(test_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 93,
+ "id": "159c3e84",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:58.968608Z",
+ "start_time": "2022-08-01T21:08:58.757109Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "gain_curve_train = gain_curve(pipeline_train_df)\n",
+ "gain_curve_test = gain_curve(pipeline_test_df)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 94,
+ "id": "d4dc9022",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-01T21:08:59.164538Z",
+ "start_time": "2022-08-01T21:08:58.970696Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plot_cumulative_gain_curve(\n",
+ " train_gain = gain_curve_train,\n",
+ " test_gain = gain_curve_test,\n",
+ " random_gain = gain_curve_random\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.9.12 ('venv': venv)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": false,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": false
+ },
+ "varInspector": {
+ "cols": {
+ "lenName": 16,
+ "lenType": 16,
+ "lenVar": 40
+ },
+ "kernels_config": {
+ "python": {
+ "delete_cmd_postfix": "",
+ "delete_cmd_prefix": "del ",
+ "library": "var_list.py",
+ "varRefreshCmd": "print(var_dic_list())"
+ },
+ "r": {
+ "delete_cmd_postfix": ") ",
+ "delete_cmd_prefix": "rm(",
+ "library": "var_list.r",
+ "varRefreshCmd": "cat(var_dic_list()) "
+ }
+ },
+ "types_to_exclude": [
+ "module",
+ "function",
+ "builtin_function_or_method",
+ "instance",
+ "_Feature"
+ ],
+ "window_display": false
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "7992107b5676fcc2947e8c872ad2cde6c7066279bc6e7ee23218548d8b4682e3"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/src/fklearn/causal/cate_learning/meta_learners.py b/src/fklearn/causal/cate_learning/meta_learners.py
index 73719dd7..039afde7 100644
--- a/src/fklearn/causal/cate_learning/meta_learners.py
+++ b/src/fklearn/causal/cate_learning/meta_learners.py
@@ -1,6 +1,6 @@
import copy
import inspect
-from typing import Callable, List, Tuple
+from typing import Callable, Dict, List, Tuple
import numpy as np
import pandas as pd
@@ -185,32 +185,24 @@ def causal_s_classification_learner(
of a new sample for both scenarios, i.e., with T = 0 and T = 1. The CATE τ
is defined as τ(xi) = M(X=xi, T=1) - M(X=xi, T=0), being M a Machine Learning
Model.
-
References:
[1] https://matheusfacure.github.io/python-causality-handbook/21-Meta-Learners.html
[2] https://causalml.readthedocs.io/en/latest/methodology.html
-
Parameters
----------
-
df : pd.DataFrame
A Pandas' DataFrame with features and target columns.
The model will be trained to predict the target column
from the features.
-
treatment_col: str
The name of the column in `df` which contains the names of
the treatments or control to which each data sample was subjected.
-
control_name: str
The name of the control group.
-
prediction_column : str
The name of the column with the predictions from the provided learner.
-
learner: Callable
A fklearn classification learner function.
-
learner_transformers: list
A list of fklearn transformer functions to be applied after the learner and before estimating the CATE.
This parameter may be useful, for example, to estimate the CATE with calibrated classifiers.
@@ -264,3 +256,194 @@ def p(new_df: pd.DataFrame) -> pd.DataFrame:
causal_s_classification_learner.__doc__ += learner_return_docstring(
"Causal S-Learner Classifier"
)
+
+
+def _simulate_t_learner_treatment_effect(
+ df: pd.DataFrame,
+ learners: dict,
+ treatments: list,
+ control_name: str,
+ prediction_column: str,
+) -> pd.DataFrame:
+ control_fcn = learners[control_name]
+ control_conversion_probability = control_fcn(df)[prediction_column].values
+
+ scored_df = df.copy()
+
+ uplift_cols = []
+ for treatment_name in treatments:
+ treatment_fcn = learners[treatment_name]
+ treatment_conversion_probability = treatment_fcn(df)[prediction_column].values
+
+ scored_df[
+ f"treatment_{treatment_name}__{prediction_column}_on_treatment"
+ ] = treatment_conversion_probability
+
+ uplift_cols.append(f"treatment_{treatment_name}__uplift")
+ scored_df[uplift_cols[-1]] = (
+ treatment_conversion_probability - control_conversion_probability
+ )
+
+ scored_df["uplift"] = scored_df[uplift_cols].max(axis=1).values
+ scored_df["suggested_treatment"] = np.where(
+ scored_df["uplift"].values <= 0,
+ control_name,
+ scored_df[uplift_cols].idxmax(axis=1).values,
+ )
+ scored_df["suggested_treatment"] = (
+ scored_df["suggested_treatment"]
+ .apply(lambda x: x.replace("__uplift", ""))
+ .values
+ )
+
+ return scored_df
+
+
+def _get_model_fcn(
+ df: pd.DataFrame,
+ treatment_col: str,
+ treatment_name: str,
+ learner: Callable,
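+) -> Tuple[Callable, pd.DataFrame, dict]:
+    """
+    Fits `learner` on the rows of `df` whose `treatment_col` equals `treatment_name`
+    and returns the learner's output: the prediction function, the scored DataFrame
+    and the logs. Raises `MissingTreatmentError` if `treatment_name` is not in `df`.
+    """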
+
+ treatment_names = df[treatment_col].unique()
+
+ if treatment_name not in treatment_names:
+ raise MissingTreatmentError()
+
+ df = df.loc[df[treatment_col] == treatment_name].reset_index(drop=True).copy()
+
+ return learner(df)
+
+
+def _get_learners(
+ df: pd.DataFrame,
+ control_learner: Callable,
+ treatment_learner: Callable,
+ unique_treatments: List[str],
+ control_name: str,
+ treatment_col: str,
+) -> Tuple[Dict[str, Callable], Dict[str, dict]]:
+ learners: Dict[str, Callable] = {}
+ logs: Dict[str, dict] = {}
+
+ learner_fcn, _, learner_logs = _get_model_fcn(
+ df, treatment_col, control_name, control_learner
+ )
+ learners[control_name] = learner_fcn
+ logs[control_name] = learner_logs
+
+ for treatment_name in unique_treatments:
+ learner_fcn, _, learner_logs = _get_model_fcn(
+ df, treatment_col, treatment_name, treatment_learner
+ )
+ learners[treatment_name] = learner_fcn
+ logs[treatment_name] = learner_logs
+
+ return learners, logs
+
+
+@curry
+def causal_t_classification_learner(
+ df: pd.DataFrame,
+ treatment_col: str,
+ control_name: str,
+ prediction_column: str,
+ learner: LearnerFnType,
+ treatment_learner: LearnerFnType = None,
+ learner_transformers: List[LearnerFnType] = None,
+) -> LearnerReturnType:
+    r"""
+    Fits a Causal T-Learner classifier. The T-Learner is a meta-learner which learns the
+    Conditional Average Treatment Effect (CATE) through the use of one Machine Learning
+    model for each treatment and one for the control group. Each model is fitted on the
+    subset of the data that received the corresponding treatment. The CATE $\tau$ is defined
+    as $\tau(x_{i}) = M_{1}(X=x_{i}, T=1) - M_{0}(X=x_{i}, T=0)$, where $M_{1}$ is a model
+    fitted on treatment data and $M_{0}$ is a model fitted on control data. Notice that
+    $M_{0}$ and $M_{1}$ are traditional Machine Learning models, such as a LightGBM
+    Classifier, and that $x_{i}$ is the feature set of sample $i$.
+
+ References:
+ [1] https://matheusfacure.github.io/python-causality-handbook/21-Meta-Learners.html
+ [2] https://causalml.readthedocs.io/en/latest/methodology.html
+
+ Parameters
+ ----------
+
+ df : pd.DataFrame
+ A Pandas' DataFrame with features and target columns.
+ The model will be trained to predict the target column
+ from the features.
+
+ treatment_col: str
+ The name of the column in `df` which contains the names of
+ the treatments and control to which each data sample was subjected.
+
+ control_name: str
+ The name of the control group.
+
+ prediction_column : str
+ The name of the column with the predictions from the provided learner.
+
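+    learner: LearnerFnType
+        A fklearn classification learner function. It is used to fit the control model and,
+        when `treatment_learner` is not provided, the treatment models as well.
+
+    treatment_learner: LearnerFnType
+        An optional fklearn classification learner function used to fit the treatment models.
+        If None, `learner` is used instead.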
+
+ learner_transformers: List[LearnerFnType]
+ A list of fklearn transformer functions to be applied after the learner and before estimating the CATE.
+ This parameter may be useful, for example, to estimate the CATE with calibrated classifiers.
+ """
+
+ control_learner = copy.deepcopy(learner)
+
+ if treatment_learner is None:
+ treatment_learner = copy.deepcopy(learner)
+
+    # Chain each learner with the optional transformers into a single pipeline
+ if learner_transformers is not None:
+ learner_transformers = copy.deepcopy(learner_transformers)
+ control_learner_pipe = build_pipeline(*[control_learner] + learner_transformers)
+
+ treatment_learner_pipe = build_pipeline(
+ *[treatment_learner] + learner_transformers
+ )
+ else:
+ control_learner_pipe = copy.deepcopy(control_learner)
+ treatment_learner_pipe = copy.deepcopy(treatment_learner)
+
+    # Fit one model for the control group and one for each treatment
+ unique_treatments = _get_unique_treatments(df, treatment_col, control_name)
+
+ learners, learners_logs = _get_learners(
+ df=df,
+ control_learner=control_learner_pipe,
+ treatment_learner=treatment_learner_pipe,
+ unique_treatments=unique_treatments,
+ control_name=control_name,
+ treatment_col=treatment_col,
+ )
+
+ def p(new_df: pd.DataFrame) -> pd.DataFrame:
+ return _simulate_t_learner_treatment_effect(
+ new_df,
+ learners,
+ unique_treatments,
+ control_name,
+ prediction_column,
+ )
+
+ p.__doc__ = learner_pred_fn_docstring("causal_t_classification_learner")
+
+ log = {"causal_t_classification_learner": {**learners_logs}}
+
+ return p, p(df), log
+
+
+causal_t_classification_learner.__doc__ += learner_return_docstring(
+    "Causal T-Learner Classifier"
+)
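Review note: the docstring formula above, $\tau(x_{i}) = M_{1}(x_{i}) - M_{0}(x_{i})$, can be illustrated with a dependency-free sketch. This is not fklearn code. The names `fit_bucket_mean`, `fit_t_learner`, `predict_uplift`, and `ROWS` are hypothetical, and the base learner is a deliberately trivial stand-in (mean target on each side of a threshold) for a real classifier. The rule that a non-positive uplift falls back to the control group is an assumption taken from the expected frame in `test_simulate_t_learner_treatment_effect`.

```python
from statistics import mean
from typing import Callable, Dict, List, Tuple

Row = Tuple[float, str, int]  # (feature x, group name, binary target)
Model = Callable[[float], float]

# Hypothetical data: x1, treatment and target taken from the test fixture.
ROWS: List[Row] = [
    (1.3, "A", 1), (1.0, "B", 1), (1.8, "A", 1),
    (-0.1, "A", 0), (0.0, "B", 0), (1.0, "control", 1),
    (2.2, "control", 0), (0.4, "B", 0), (-5.0, "control", 1),
]


def fit_bucket_mean(rows: List[Row], threshold: float = 1.0) -> Model:
    """Toy base learner: mean target below and at-or-above `threshold`."""
    overall = mean(y for _, _, y in rows)
    low = [y for x, _, y in rows if x < threshold]
    high = [y for x, _, y in rows if x >= threshold]
    # Fall back to the group's overall mean when a bucket has no rows.
    low_mean = mean(low) if low else overall
    high_mean = mean(high) if high else overall
    return lambda x: high_mean if x >= threshold else low_mean


def fit_t_learner(rows: List[Row]) -> Dict[str, Model]:
    """Fit one independent model per group, control included."""
    groups = sorted({group for _, group, _ in rows})
    return {g: fit_bucket_mean([r for r in rows if r[1] == g]) for g in groups}


def predict_uplift(
    models: Dict[str, Model], control_name: str, x: float
) -> Tuple[Dict[str, float], str]:
    """Uplift per treatment is M_treatment(x) - M_control(x)."""
    baseline = models[control_name](x)
    uplifts = {g: m(x) - baseline for g, m in models.items() if g != control_name}
    best = max(uplifts, key=uplifts.get)  # the first treatment wins a tie
    # Assumption: suggest the control group unless some uplift is positive.
    suggestion = best if uplifts[best] > 0 else control_name
    return uplifts, suggestion
```

Because every group gets its own model, no treatment indicator feature is needed. The trade-off is that each model only sees its own subset of the rows.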
diff --git a/tests/causal/cate_learning/test_meta_learners.py b/tests/causal/cate_learning/test_meta_learners.py
index 4d982623..dd7d297e 100644
--- a/tests/causal/cate_learning/test_meta_learners.py
+++ b/tests/causal/cate_learning/test_meta_learners.py
@@ -1,20 +1,45 @@
-from unittest.mock import create_autospec, patch
+from typing import Callable
+from unittest.mock import MagicMock, call, create_autospec, patch
import numpy as np
import pandas as pd
import pytest
+from pandas import DataFrame
+from pandas.testing import assert_frame_equal
+
from fklearn.causal.cate_learning.meta_learners import (
TREATMENT_FEATURE, _append_treatment_feature, _create_treatment_flag,
- _filter_by_treatment, _fit_by_treatment, _get_unique_treatments,
- _predict_by_treatment_flag, _simulate_treatment_effect,
- causal_s_classification_learner)
+ _filter_by_treatment, _fit_by_treatment, _get_learners, _get_model_fcn,
+ _get_unique_treatments, _predict_by_treatment_flag,
+ _simulate_t_learner_treatment_effect, _simulate_treatment_effect,
+ causal_s_classification_learner, causal_t_classification_learner)
from fklearn.exceptions.exceptions import (MissingControlError,
MissingTreatmentError,
MultipleTreatmentsError)
from fklearn.training.classification import logistic_classification_learner
from fklearn.types import LearnerFnType
-from pandas import DataFrame
-from pandas.testing import assert_frame_equal
+
+
+@pytest.fixture
+def base_input_df():
+ return pd.DataFrame(
+ {
+ "x1": [1.3, 1.0, 1.8, -0.1, 0.0, 1.0, 2.2, 0.4, -5.0],
+ "x2": [10, 4, 15, 6, 5, 12, 14, 5, 12],
+ "treatment": [
+ "A",
+ "B",
+ "A",
+ "A",
+ "B",
+ "control",
+ "control",
+ "B",
+ "control",
+ ],
+ "target": [1, 1, 1, 0, 0, 1, 0, 0, 1],
+ }
+ )
def test__append_treatment_feature():
@@ -191,26 +216,7 @@ def test__create_treatment_flag():
assert_frame_equal(results, expected)
-def test__fit_by_treatment():
- df = pd.DataFrame(
- {
- "x1": [1.3, 1.0, 1.8, -0.1, 0.0, 1.0, 2.2, 0.4, -5.0],
- "x2": [10, 4, 15, 6, 5, 12, 14, 5, 12],
- "treatment": [
- "A",
- "B",
- "A",
- "A",
- "B",
- "control",
- "control",
- "B",
- "control",
- ],
- "target": [1, 1, 1, 0, 0, 1, 0, 0, 1],
- }
- )
-
+def test__fit_by_treatment(base_input_df):
learner_binary = logistic_classification_learner(
features=["x1", "x2", TREATMENT_FEATURE],
target="target",
@@ -220,7 +226,7 @@ def test__fit_by_treatment():
treatments = ["A", "B"]
learners, logs = _fit_by_treatment(
- df,
+ base_input_df,
learner=learner_binary,
treatment_col="treatment",
control_name="control",
@@ -352,27 +358,8 @@ def test_causal_s_classification_learner(
mock_get_unique_treatments,
mock_fit_by_treatment,
mock_simulate_treatment_effect,
+ base_input_df,
):
-
- df = pd.DataFrame(
- {
- "x1": [1.3, 1.0, 1.8, -0.1, 0.0, 1.0, 2.2, 0.4, -5.0],
- "x2": [10, 4, 15, 6, 5, 12, 14, 5, 12],
- "treatment": [
- "A",
- "B",
- "A",
- "A",
- "B",
- "control",
- "control",
- "B",
- "control",
- ],
- "target": [1, 1, 1, 0, 0, 1, 0, 0, 1],
- }
- )
-
mock_model = create_autospec(logistic_classification_learner)
mock_fit_by_treatment.side_effect = [
# treatment = A
@@ -382,7 +369,7 @@ def test_causal_s_classification_learner(
]
causal_s_classification_learner(
- df,
+ base_input_df,
treatment_col="treatment",
control_name="control",
prediction_column="prediction",
@@ -394,3 +381,186 @@ def test_causal_s_classification_learner(
mock_get_unique_treatments.assert_called()
mock_fit_by_treatment.assert_called()
mock_simulate_treatment_effect.assert_called()
+
+
+def test_simulate_t_learner_treatment_effect():
+ df = pd.DataFrame(
+ {
+ "x1": [1.3, 1.0, 1.8, -0.1],
+ "x2": [10, 4, 15, 6],
+ "treatment": ["A", "B", "A", "control"],
+ "target": [0, 0, 0, 1],
+ }
+ )
+
+ treatments = ["A", "B"]
+ control_name = "control"
+ prediction_column = "prediction"
+
+ control_learner = MagicMock()
+ control_learner.side_effect = lambda _: pd.DataFrame({"prediction": [1, 2, 3, 4]})
+
+ treatment_learner = MagicMock()
+ treatment_learner.side_effect = lambda _: pd.DataFrame({"prediction": [3, 2, 4, 4]})
+
+ learners = {
+ "control": control_learner,
+ "A": treatment_learner,
+ "B": treatment_learner,
+ }
+
+ result = _simulate_t_learner_treatment_effect(
+ df,
+ learners,
+ treatments,
+ control_name,
+ prediction_column,
+ )
+
+    # Uplift is the treatment prediction minus the control prediction, row by row.
+
+ expected = pd.DataFrame(
+ {
+ "x1": [1.3, 1.0, 1.8, -0.1],
+ "x2": [10, 4, 15, 6],
+ "treatment": ["A", "B", "A", "control"],
+ "target": [0, 0, 0, 1],
+ "treatment_A__prediction_on_treatment": [3, 2, 4, 4],
+ "treatment_A__uplift": [2, 0, 1, 0],
+ "treatment_B__prediction_on_treatment": [3, 2, 4, 4],
+ "treatment_B__uplift": [2, 0, 1, 0],
+ "uplift": [2, 0, 1, 0],
+ "suggested_treatment": ["treatment_A", "control", "treatment_A", "control"],
+ }
+ )
+
+ assert isinstance(result, pd.DataFrame)
+ assert_frame_equal(result, expected)
+
+
+def test_get_model_fcn(base_input_df):
+    """
+    Checks that _get_model_fcn keeps only the rows of the requested treatment
+    and that the learner is called with this filtered data.
+    """
+
+ fake_prediction_column = [0.1, 0.2, 0.3]
+ df_expected = pd.DataFrame(
+ {
+ "x1": [1.3, 1.8, -0.1],
+ "x2": [10, 15, 6],
+ "treatment": [
+ "A",
+ "A",
+ "A",
+ ],
+ "target": [1, 1, 0],
+ "prediction": fake_prediction_column,
+ }
+ )
+
+ def mock_learner(df):
+ df["prediction"] = fake_prediction_column
+
+ return (lambda x: x, df, dict())
+
+ learner = MagicMock()
+ learner.side_effect = mock_learner
+
+ mock_fcn, mock_p_df, mock_logs = _get_model_fcn(
+ base_input_df, "treatment", "A", learner
+ )
+
+ assert isinstance(mock_fcn, Callable)
+ assert_frame_equal(mock_p_df, df_expected)
+ assert isinstance(mock_logs, dict)
+
+
+def test_get_model_fcn_exception(base_input_df):
+    """
+    Checks that _get_model_fcn raises MissingTreatmentError when the
+    treatment name is not among the treatments in the data.
+    """
+
+ fake_prediction_column = [0.1, 0.2, 0.3]
+
+ def mock_learner(df):
+ df["prediction"] = fake_prediction_column
+
+ return (lambda x: x, df, dict())
+
+ learner = MagicMock()
+ learner.side_effect = mock_learner
+
+    # "C" does not appear in the treatment column of base_input_df,
+    # so _get_model_fcn is expected to raise MissingTreatmentError.
+    with pytest.raises(MissingTreatmentError):
+        _get_model_fcn(base_input_df, "treatment", "C", learner)
+
+
+@patch("fklearn.causal.cate_learning.meta_learners._get_model_fcn")
+def test_get_learners(mock_get_model_fcn):
+    """
+    Checks that _get_learners fits one model for the control group and one per
+    treatment, and returns them in a dict keyed by group name.
+    """
+ unique_treatments = ["treatment_a", "treatment_b", "treatment_c"]
+
+ mock_get_model_fcn.side_effect = [
+ ("mocked_control_fcn", None, None),
+ ("mocked_treatment_fcn_filtering_treatment_a", None, None),
+ ("mocked_treatment_fcn_filtering_treatment_b", None, None),
+ ("mocked_treatment_fcn_filtering_treatment_c", None, None),
+ ]
+
+ learners, logs = _get_learners(
+ df="mocked_df",
+ unique_treatments=unique_treatments,
+ treatment_col="treatment",
+ control_name="control",
+ control_learner="mocked_control_fcn",
+ treatment_learner="mocked_treatment_fcn",
+ )
+
+ assert learners["control"] == "mocked_control_fcn"
+ assert learners["treatment_a"] == "mocked_treatment_fcn_filtering_treatment_a"
+ assert learners["treatment_b"] == "mocked_treatment_fcn_filtering_treatment_b"
+ assert learners["treatment_c"] == "mocked_treatment_fcn_filtering_treatment_c"
+ assert isinstance(learners, dict)
+ assert isinstance(logs, dict)
+
+ calls = [
+ call("mocked_df", "treatment", "control", "mocked_control_fcn"),
+ call("mocked_df", "treatment", "treatment_a", "mocked_treatment_fcn"),
+ call("mocked_df", "treatment", "treatment_b", "mocked_treatment_fcn"),
+ call("mocked_df", "treatment", "treatment_c", "mocked_treatment_fcn"),
+ ]
+
+ mock_get_model_fcn.assert_has_calls(calls)
+
+
+@patch(
+ "fklearn.causal.cate_learning.meta_learners._simulate_t_learner_treatment_effect"
+)
+@patch("fklearn.causal.cate_learning.meta_learners._get_learners")
+@patch("fklearn.causal.cate_learning.meta_learners._get_unique_treatments")
+def test_causal_t_classification_learner(
+ mock_get_unique_treatments,
+ mock_get_learners,
+ mock_simulate_t_learner_treatment_effect,
+ base_input_df,
+):
+    mock_get_learners.return_value = ({}, {})
+ mock_model = create_autospec(logistic_classification_learner)
+
+ causal_t_classification_learner(
+ df=base_input_df,
+ treatment_col="treatment",
+ control_name="control",
+ prediction_column="prediction",
+ learner=mock_model,
+ )
+
+ mock_get_unique_treatments.assert_called()
+ mock_get_learners.assert_called()
+ mock_simulate_t_learner_treatment_effect.assert_called()