diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 00000000..aa9fe5ab --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,6 @@ +# Benchmark Experiments +--- +This directory includes some benchmark experiments and demonstrations about off-policy evaluation using [the full size Open Bandit Dataset](https://research.zozo.com/data.html). The detailed description, results, and discussions can be found in [the relevant paper](https://arxiv.org/abs/2008.07146). + +- `cf_policy_search`: counterfactual policy search using OPE +- `ope`: estimation performance comparisons on a variety of OPE estimators diff --git a/benchmark/cf_policy_search/run_cf_policy_search.py b/benchmark/cf_policy_search/run_cf_policy_search.py index 14a1d54b..3643451d 100644 --- a/benchmark/cf_policy_search/run_cf_policy_search.py +++ b/benchmark/cf_policy_search/run_cf_policy_search.py @@ -1,10 +1,10 @@ import argparse from pathlib import Path import yaml -import time -import pandas as pd import numpy as np +from pandas import DataFrame +from joblib import Parallel, delayed from sklearn.linear_model import LogisticRegression from sklearn.ensemble import RandomForestClassifier from sklearn.experimental import enable_hist_gradient_boosting @@ -13,7 +13,6 @@ from custom_dataset import OBDWithInteractionFeatures from obp.policy import IPWLearner from obp.ope import InverseProbabilityWeighting -from obp.utils import estimate_confidence_interval_by_bootstrap # hyperparameter for the regression model used in model dependent OPE estimators with open("./conf/hyperparams.yaml", "rb") as f: @@ -28,10 +27,10 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="run evaluation policy selection.") parser.add_argument( - "--n_boot_samples", + "--n_runs", type=int, default=5, - help="number of bootstrap samples in the experiment.", + help="number of bootstrap sampling in the experiment.", ) parser.add_argument( "--context_set", @@ -67,17 +66,24 @@ default=0.5, help="the proportion of the dataset to include in the test split.", ) + parser.add_argument( + "--n_jobs", + type=int, + default=1, + help="the maximum number of concurrently running jobs.", + ) parser.add_argument("--random_state", type=int, default=12345) args = parser.parse_args() print(args) # configurations - n_boot_samples = args.n_boot_samples + n_runs = args.n_runs context_set = args.context_set base_model = args.base_model behavior_policy = args.behavior_policy campaign = args.campaign test_size = args.test_size + n_jobs = args.n_jobs random_state = args.random_state np.random.seed(random_state) data_path = Path("../open_bandit_dataset") @@ -89,8 +95,8 @@ data_path=data_path, context_set=context_set, ) - # define a evaluation policy - evaluation_policy = IPWLearner( + # define a counterfactual policy based on IPWLearner + counterfactual_policy = IPWLearner( base_model=base_model_dict[base_model](**hyperparams[base_model]), n_actions=obd.n_actions, len_list=obd.len_list, @@ -107,15 +113,13 @@ is_timeseries_split=True, ) - start = time.time() - ope_results = np.zeros(n_boot_samples) - for b in np.arange(n_boot_samples): + def process(b: int): # sample bootstrap from batch logged bandit feedback boot_bandit_feedback = obd.sample_bootstrap_bandit_feedback( test_size=test_size, is_timeseries_split=True, random_state=b ) # train an evaluation on the training set of the logged bandit feedback data - action_dist = evaluation_policy.fit( + action_dist = counterfactual_policy.fit( context=boot_bandit_feedback["context"], 
             action=boot_bandit_feedback["action"],
             reward=boot_bandit_feedback["reward"],
@@ -123,38 +127,30 @@
             position=boot_bandit_feedback["position"],
         )
         # make action selections (predictions)
-        action_dist = evaluation_policy.predict(
+        action_dist = counterfactual_policy.predict(
             context=boot_bandit_feedback["context_test"]
         )
         # estimate the policy value of a given counterfactual algorithm by the three OPE estimators.
         ipw = InverseProbabilityWeighting()
-        ope_results[b] = (
-            ipw.estimate_policy_value(
-                reward=boot_bandit_feedback["reward_test"],
-                action=boot_bandit_feedback["action_test"],
-                position=boot_bandit_feedback["position_test"],
-                pscore=boot_bandit_feedback["pscore_test"],
-                action_dist=action_dist,
-            )
-            / ground_truth
+        return ipw.estimate_policy_value(
+            reward=boot_bandit_feedback["reward_test"],
+            action=boot_bandit_feedback["action_test"],
+            position=boot_bandit_feedback["position_test"],
+            pscore=boot_bandit_feedback["pscore_test"],
+            action_dist=action_dist,
         )
-        print(f"{b+1}th iteration: {np.round((time.time() - start) / 60, 2)}min")
-    ope_results_dict = estimate_confidence_interval_by_bootstrap(
-        samples=ope_results, random_state=random_state
+    processed = Parallel(backend="multiprocessing", n_jobs=n_jobs, verbose=50,)(
+        [delayed(process)(i) for i in np.arange(n_runs)]
     )
-    ope_results_dict["mean(no-boot)"] = ope_results.mean()
-    ope_results_dict["std"] = np.std(ope_results, ddof=1)
-    ope_results_df = pd.DataFrame(ope_results_dict, index=["ipw"])
-
-    # calculate estimated policy value relative to that of the behavior policy
-    print("=" * 70)
-    print(f"random_state={random_state}: evaluation policy={policy_name}")
-    print("-" * 70)
-    print(ope_results_df)
-    print("=" * 70)
-    # save evaluation policy evaluation results in `./logs` directory
+    # save counterfactual policy evaluation results in `./logs` directory
+    ope_results = np.zeros((n_runs, 2))
+    for b, estimated_policy_value_b in enumerate(processed):
+        ope_results[b, 0] = estimated_policy_value_b
+        ope_results[b, 1] = estimated_policy_value_b / ground_truth
     save_path = Path("./logs") / behavior_policy / campaign
     save_path.mkdir(exist_ok=True, parents=True)
-    ope_results_df.to_csv(save_path / f"{policy_name}.csv")
+    DataFrame(
+        ope_results, columns=["policy_value", "relative_policy_value"]
+    ).describe().round(6).to_csv(save_path / f"{policy_name}.csv")
diff --git a/benchmark/ope/README.md b/benchmark/ope/README.md
index 44f10d72..e81fae6c 100644
--- a/benchmark/ope/README.md
+++ b/benchmark/ope/README.md
@@ -1,45 +1,166 @@
 # Benchmarking Off-Policy Evaluation

 ## Description
+We use the (full size) Open Bandit Dataset to evaluate and compare OPE estimators in a *realistic* and *reproducible* manner. Specifically, we evaluate the estimation performances of a wide variety of existing estimators by comparing their estimated policy values with the ground-truth policy value of an evaluation policy, which is contained in the data.
+### Dataset
+Please download the full [open bandit dataset](https://research.zozo.com/data.html) and put it in the `../open_bandit_dataset/` directory.

 ## Training Regression Model
+Model-dependent estimators such as DM and DR need a pre-trained regression model.
+Here, we train a regression model using several machine learning methods.
+
+We define the hyperparameters for these machine learning methods in [`conf/hyperparams.yaml`](https://github.com/st-tech/zr-obp/blob/master/benchmark/ope/conf/hyperparams.yaml).
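+
+For reference, the sketch below shows roughly how these hyperparameters are wired into the training script: the YAML file is loaded and the entry for the chosen base model is passed to the corresponding classifier, which is then wrapped by `obp.ope.RegressionModel`. This is a minimal illustration, not the full training loop.
+
+```python
+from pathlib import Path
+
+import yaml
+from sklearn.linear_model import LogisticRegression
+
+from obp.dataset import OpenBanditDataset
+from obp.ope import RegressionModel
+
+# load the hyperparameters defined in conf/hyperparams.yaml
+with open("./conf/hyperparams.yaml", "rb") as f:
+    hyperparams = yaml.safe_load(f)
+
+# logged bandit feedback collected by the Random policy in the "All" campaign
+obd = OpenBanditDataset(
+    behavior_policy="random", campaign="all", data_path=Path("../open_bandit_dataset")
+)
+bandit_feedback = obd.obtain_batch_bandit_feedback()
+
+# plug the "logistic_regression" hyperparameters into the base ML model and
+# wrap it as the regression model used by the model-dependent estimators (DM, DR, ...)
+reg_model = RegressionModel(
+    n_actions=obd.n_actions,
+    len_list=obd.len_list,
+    action_context=bandit_feedback["action_context"],
+    base_model=LogisticRegression(**hyperparams["logistic_regression"]),
+    fitting_method="normal",
+)
+```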
+[train_regression_model.py](https://github.com/st-tech/zr-obp/blob/master/benchmark/ope/train_regression_model.py) implements the training process of the regression model.
+
 ```
-for model in logistic_regression
+python train_regression_model.py\
+    --n_runs $n_runs\
+    --base_model $base_model\ # "logistic_regression", "random_forest", or "lightgbm"
+    --behavior_policy $behavior_policy\ # "random" or "bts"
+    --campaign $campaign\ # "men", "women", or "all"
+    --n_sim_to_compute_action_dist $n_sim_to_compute_action_dist\
+    --is_timeseries_split $is_timeseries_split\ # in-sample or out-sample
+    --test_size $test_size\
+    --is_mrdr $is_mrdr\ # use the "more robust doubly robust" option or not
+    --n_jobs $n_jobs\
+    --random_state $random_state
+```
+
+where
+- `$n_runs` specifies the number of simulation runs with different bootstrap samples in the experiment.
+- `$base_model` specifies the base ML model for defining the regression model and should be one of "logistic_regression", "random_forest", or "lightgbm".
+- `$behavior_policy` specifies the behavior policy that collected the logged bandit feedback data and should be either "random" or "bts".
+- `$campaign` specifies the campaign considered in ZOZOTOWN and should be one of "all", "men", or "women".
+- `$n_sim_to_compute_action_dist` is the number of Monte Carlo simulations used to compute the action choice probabilities of a given evaluation policy.
+- `$is_timeseries_split` specifies whether the data is split based on timestamps. If true, the out-sample performance of OPE is evaluated. See the relevant paper for details.
+- `$test_size` specifies the proportion of the dataset to include in the test split when `$is_timeseries_split=True`.
+- `$is_mrdr` specifies whether the regression model is trained with the more robust doubly robust (MRDR) objective. See the relevant paper for details.
+- `$n_jobs` is the maximum number of concurrently running jobs.
+
+For example, the following command trains a regression model based on logistic regression on the logged bandit feedback data collected by the Random policy (as a behavior policy) in the "All" campaign.
+
+```bash
+python train_regression_model.py\
+    --n_runs 10\
+    --base_model logistic_regression\
+    --behavior_policy random\
+    --campaign all\
+    --is_mrdr False\
+    --is_timeseries_split False
+```
+
+
 ## Evaluating Off-Policy Estimators
+Next, we evaluate and compare the estimation performances of the following OPE estimators:
+
+- Direct Method (DM)
+- Inverse Probability Weighting (IPW)
+- Self-Normalized Inverse Probability Weighting (SNIPW)
+- Doubly Robust (DR)
+- Self-Normalized Doubly Robust (SNDR)
+- Switch Doubly Robust (Switch-DR)
+- Doubly Robust with Optimistic Shrinkage (DRos)
+- More Robust Doubly Robust (MRDR)
+
+For Switch-DR and DRos, we test several different values of their hyperparameters.
+See our [documentation](https://zr-obp.readthedocs.io/en/latest/estimators.html) for the details about these estimators.
+
+
+[benchmark_off_policy_estimators.py](https://github.com/st-tech/zr-obp/blob/master/benchmark/ope/benchmark_off_policy_estimators.py) implements the evaluation and comparison of OPE estimators using the open bandit dataset.
+Note that you have to finish training a regression model (see the above section) before conducting the evaluation of OPE in the corresponding setting.
+We summarize the detailed experimental protocol for evaluating OPE estimators with real-world data [here](https://zr-obp.readthedocs.io/en/latest/evaluation_ope.html).
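+
+Concretely, each estimator is scored by its *relative estimation error* (relative-ee) against the ground-truth policy value of the evaluation policy; `OffPolicyEvaluation.evaluate_performance_of_estimators` computes this metric for every estimator in each run. The snippet below assumes the standard definition of relative-ee and is only meant to make the metric explicit.
+
+```python
+import numpy as np
+
+
+def relative_estimation_error(estimated_policy_value: float, ground_truth: float) -> float:
+    """Relative-ee: |V_hat - V| / |V| (assumed definition of the reported metric)."""
+    return np.abs(estimated_policy_value - ground_truth) / np.abs(ground_truth)
+
+
+# e.g., an estimate of 0.0045 against a ground-truth value of 0.0050 gives a relative-ee of about 0.1
+print(relative_estimation_error(0.0045, 0.0050))  # ~0.1
+```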
+
+```
+# run evaluation of OPE estimators with the full open bandit dataset
+python benchmark_off_policy_estimators.py\
+    --n_runs $n_runs\
+    --base_model $base_model\ # "logistic_regression", "random_forest", or "lightgbm"
+    --behavior_policy $behavior_policy\ # "random" or "bts"
+    --campaign $campaign\ # "men", "women", or "all"
+    --n_sim_to_compute_action_dist $n_sim_to_compute_action_dist\
+    --is_timeseries_split $is_timeseries_split\ # in-sample or out-sample
+    --test_size $test_size\
+    --n_jobs $n_jobs\
+    --random_state $random_state
+```
+where
+- `$n_runs` specifies the number of simulation runs with different bootstrap samples, used to estimate the standard deviations of the performance of the OPE estimators.
+- `$base_model` specifies the base ML model for defining the regression model and should be one of "logistic_regression", "random_forest", or "lightgbm".
+- `$behavior_policy` specifies the behavior policy that collected the logged bandit feedback data and should be either "random" or "bts".
+- `$campaign` specifies the campaign considered in ZOZOTOWN and should be one of "all", "men", or "women".
+- `$n_sim_to_compute_action_dist` is the number of Monte Carlo simulations used to compute the action choice probabilities of a given evaluation policy.
+- `$is_timeseries_split` specifies whether the data is split based on timestamps. If true, the out-sample performance of OPE is evaluated. See the relevant paper for details.
+- `$test_size` specifies the proportion of the dataset to include in the test split when `$is_timeseries_split=True`.
+- `$n_jobs` is the maximum number of concurrently running jobs.
+
+For example, the following command compares the estimation performances of the OPE estimators listed above using Bernoulli TS as the evaluation policy and Random as the behavior policy in the "All" campaign, in the out-sample setting.
+
+```bash
+python benchmark_off_policy_estimators.py\
+    --n_runs 10\
+    --base_model logistic_regression\
+    --behavior_policy random\
+    --campaign all\
+    --test_size 0.3\
+    --is_timeseries_split True
+```
+
+The results of our benchmark experiments can be found in Section 5 of [our paper](https://arxiv.org/abs/2008.07146).
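+
+The summary statistics of the relative-ee values over the `$n_runs` runs (computed with `DataFrame.describe()`) are saved to `eval_ope_results.csv` under `./logs/`. A quick way to inspect them afterwards is sketched below; the exact subdirectory (e.g., `out_sample` vs. an in-sample counterpart) depends on the configuration used, so adjust the path accordingly.
+
+```python
+from pathlib import Path
+
+import pandas as pd
+
+# assumed layout: ./logs/<behavior_policy>/<campaign>/out_sample/<base_model>/eval_ope_results.csv
+result_path = (
+    Path("./logs") / "random" / "all" / "out_sample" / "logistic_regression" / "eval_ope_results.csv"
+)
+results = pd.read_csv(result_path, index_col=0)
+
+# rows are estimators (dm, ipw, snipw, dr, ..., mrdr); columns are describe() statistics
+print(results[["mean", "std"]].sort_values("mean"))
+```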
+ + + + diff --git a/benchmark/ope/benchmark_off_policy_estimators.py b/benchmark/ope/benchmark_off_policy_estimators.py index ca62c964..b1a77e04 100644 --- a/benchmark/ope/benchmark_off_policy_estimators.py +++ b/benchmark/ope/benchmark_off_policy_estimators.py @@ -1,10 +1,11 @@ import argparse import pickle -import time from pathlib import Path +from distutils.util import strtobool import numpy as np -import pandas as pd +from pandas import DataFrame +from joblib import Parallel, delayed from obp.dataset import OpenBanditDataset from obp.policy import BernoulliTS, Random @@ -16,39 +17,41 @@ DoublyRobust, SelfNormalizedDoublyRobust, SwitchDoublyRobust, - SwitchInverseProbabilityWeighting, DoublyRobustWithShrinkage, ) -# compared OPE estimators +# OPE estimators compared ope_estimators = [ DirectMethod(), InverseProbabilityWeighting(), SelfNormalizedInverseProbabilityWeighting(), DoublyRobust(), SelfNormalizedDoublyRobust(), - SwitchInverseProbabilityWeighting(tau=1, estimator_name="switch-ipw (tau=1)"), - SwitchInverseProbabilityWeighting(tau=100, estimator_name="switch-ipw (tau=100)"), - SwitchDoublyRobust(tau=1, estimator_name="switch-dr (tau=1)"), + SwitchDoublyRobust(tau=5, estimator_name="switch-dr (tau=5)"), + SwitchDoublyRobust(tau=10, estimator_name="switch-dr (tau=10)"), + SwitchDoublyRobust(tau=50, estimator_name="switch-dr (tau=50)"), SwitchDoublyRobust(tau=100, estimator_name="switch-dr (tau=100)"), - DoublyRobustWithShrinkage(lambda_=1, estimator_name="dr-os (lambda=1)"), + SwitchDoublyRobust(tau=500, estimator_name="switch-dr (tau=500)"), + SwitchDoublyRobust(tau=1000, estimator_name="switch-dr (tau=1000)"), + DoublyRobustWithShrinkage(lambda_=5, estimator_name="dr-os (lambda=5)"), + DoublyRobustWithShrinkage(lambda_=10, estimator_name="dr-os (lambda=10)"), + DoublyRobustWithShrinkage(lambda_=50, estimator_name="dr-os (lambda=50)"), DoublyRobustWithShrinkage(lambda_=100, estimator_name="dr-os (lambda=100)"), + DoublyRobustWithShrinkage(lambda_=500, estimator_name="dr-os (lambda=500)"), + DoublyRobustWithShrinkage(lambda_=1000, estimator_name="dr-os (lambda=1000)"), ] if __name__ == "__main__": parser = argparse.ArgumentParser(description="evaluate off-policy estimators.") parser.add_argument( - "--n_boot_samples", - type=int, - default=1, - help="number of bootstrap samples in the experiment.", + "--n_runs", type=int, default=1, help="number of experimental runs.", ) parser.add_argument( "--base_model", type=str, - choices=["logistic_regression", "lightgbm"], + choices=["logistic_regression", "random_forest", "lightgbm"], required=True, - help="base ML model for regression model, logistic_regression or lightgbm.", + help="base ML model for regression model, logistic_regression, random_forest, or lightgbm.", ) parser.add_argument( "--behavior_policy", @@ -64,12 +67,6 @@ required=True, help="campaign name, men, women, or all.", ) - parser.add_argument( - "--n_sim_for_action_dist", - type=float, - default=1000000, - help="number of monte carlo simulation to compute the action distribution of bts.", - ) parser.add_argument( "--test_size", type=float, @@ -78,7 +75,8 @@ ) parser.add_argument( "--is_timeseries_split", - action="store_true", + type=strtobool, + default=False, help="If true, split the original logged badnit feedback data by time series.", ) parser.add_argument( @@ -87,22 +85,30 @@ default=1000000, help="number of monte carlo simulation to compute the action distribution of bts.", ) + parser.add_argument( + "--n_jobs", + type=int, + default=1, + help="the maximum 
number of concurrently running jobs.", + ) parser.add_argument("--random_state", type=int, default=12345) args = parser.parse_args() print(args) # configurations of the benchmark experiment - n_boot_samples = args.n_boot_samples + n_runs = args.n_runs base_model = args.base_model behavior_policy = args.behavior_policy evaluation_policy = "bts" if behavior_policy == "random" else "random" campaign = args.campaign - n_sim_for_action_dist = args.n_sim_for_action_dist test_size = args.test_size is_timeseries_split = args.is_timeseries_split n_sim_to_compute_action_dist = args.n_sim_to_compute_action_dist + n_jobs = args.n_jobs random_state = args.random_state + np.random.seed(random_state) data_path = Path("../open_bandit_dataset") + # prepare path log_path = ( Path("./logs") / behavior_policy / campaign / "out_sample" / base_model @@ -124,82 +130,77 @@ test_size=test_size, is_timeseries_split=is_timeseries_split, ) + # compute action distribution by evaluation policy + if evaluation_policy == "bts": + policy = BernoulliTS( + n_actions=obd.n_actions, + len_list=obd.len_list, + is_zozotown_prior=True, # replicate the policy in the ZOZOTOWN production + campaign=campaign, + random_state=random_state, + ) + else: + policy = Random( + n_actions=obd.n_actions, len_list=obd.len_list, random_state=random_state, + ) + action_dist_single_round = policy.compute_batch_action_dist( + n_sim=n_sim_to_compute_action_dist + ) - start = time.time() - relative_ee = { - est.estimator_name: np.zeros(n_boot_samples) for est in ope_estimators - } - for b in np.arange(n_boot_samples): + def process(b: int): # load the pre-trained regression model with open(reg_model_path / f"reg_model_{b}.pkl", "rb") as f: reg_model = pickle.load(f) + with open(reg_model_path / f"reg_model_mrdr_{b}.pkl", "rb") as f: + reg_model_mrdr = pickle.load(f) with open(reg_model_path / f"is_for_reg_model_{b}.pkl", "rb") as f: is_for_reg_model = pickle.load(f) # sample bootstrap samples from batch logged bandit feedback - boot_bandit_feedback = obd.sample_bootstrap_bandit_feedback( - test_size=test_size, is_timeseries_split=is_timeseries_split, random_state=b + bandit_feedback = obd.sample_bootstrap_bandit_feedback( + test_size=test_size, + is_timeseries_split=is_timeseries_split, + random_state=b, ) for key_ in ["context", "action", "reward", "pscore", "position"]: - boot_bandit_feedback[key_] = boot_bandit_feedback[key_][~is_for_reg_model] - if evaluation_policy == "bts": - policy = BernoulliTS( - n_actions=obd.n_actions, - len_list=obd.len_list, - is_zozotown_prior=True, # replicate the policy in the ZOZOTOWN production - campaign=campaign, - random_state=random_state, - ) - action_dist = policy.compute_batch_action_dist( - n_sim=100000, n_rounds=boot_bandit_feedback["n_rounds"] - ) - else: - policy = Random( - n_actions=obd.n_actions, - len_list=obd.len_list, - random_state=random_state, - ) - action_dist = policy.compute_batch_action_dist( - n_sim=100000, n_rounds=boot_bandit_feedback["n_rounds"] - ) + bandit_feedback[key_] = bandit_feedback[key_][~is_for_reg_model] # estimate the mean reward function using the pre-trained reg_model estimated_rewards_by_reg_model = reg_model.predict( - context=boot_bandit_feedback["context"], + context=bandit_feedback["context"], + ) + estimated_rewards_by_reg_model_mrdr = reg_model_mrdr.predict( + context=bandit_feedback["context"], ) # evaluate the estimation performance of OPE estimators ope = OffPolicyEvaluation( - bandit_feedback=boot_bandit_feedback, ope_estimators=ope_estimators, + 
bandit_feedback=bandit_feedback, ope_estimators=ope_estimators, + ) + action_dist = np.tile( + action_dist_single_round, (bandit_feedback["n_rounds"], 1, 1) ) - relative_estimation_errors = ope.evaluate_performance_of_estimators( + relative_ee_b = ope.evaluate_performance_of_estimators( ground_truth_policy_value=ground_truth_policy_value, action_dist=action_dist, estimated_rewards_by_reg_model=estimated_rewards_by_reg_model, ) - # store relative estimation errors of OPE estimators at each bootstrap - for ( - estimator_name, - relative_estimation_error, - ) in relative_estimation_errors.items(): - relative_ee[estimator_name][b] = relative_estimation_error - - print(f"{b+1}th iteration: {np.round((time.time() - start) / 60, 2)}min") + relative_ee_b["mrdr"] = ope.evaluate_performance_of_estimators( + ground_truth_policy_value=ground_truth_policy_value, + action_dist=action_dist, + estimated_rewards_by_reg_model=estimated_rewards_by_reg_model_mrdr, + )["dr"] - # estimate means and standard deviations of relative estimation by nonparametric bootstrap method - evaluation_of_ope_results = {est.estimator_name: dict() for est in ope_estimators} - for estimator_name in evaluation_of_ope_results.keys(): - evaluation_of_ope_results[estimator_name]["mean"] = relative_ee[ - estimator_name - ].mean() - evaluation_of_ope_results[estimator_name]["std"] = np.std( - relative_ee[estimator_name], ddof=1 - ) + return relative_ee_b - evaluation_of_ope_results_df = pd.DataFrame(evaluation_of_ope_results).T - print("=" * 50) - print(f"random_state={random_state}") - print("-" * 50) - print(evaluation_of_ope_results_df) - print("=" * 50) + processed = Parallel(backend="multiprocessing", n_jobs=n_jobs, verbose=50,)( + [delayed(process)(i) for i in np.arange(n_runs)] + ) - # save results of the evaluation of off-policy estimators in './logs' directory. - evaluation_of_ope_results_df.to_csv(log_path / f"relative_ee_of_ope_estimators.csv") + # save results of the evaluation of ope in './logs' directory. 
+ estimator_names = [est.estimator_name for est in ope_estimators] + ["mrdr"] + relative_ee = {est: np.zeros(n_runs) for est in estimator_names} + for b, relative_ee_b in enumerate(processed): + for (estimator_name, relative_ee_,) in relative_ee_b.items(): + relative_ee[estimator_name][b] = relative_ee_ + DataFrame(relative_ee).describe().T.round(6).to_csv( + log_path / f"eval_ope_results.csv" + ) diff --git a/benchmark/ope/conf/hyperparams.yaml b/benchmark/ope/conf/hyperparams.yaml index 331e0a60..92358bcb 100644 --- a/benchmark/ope/conf/hyperparams.yaml +++ b/benchmark/ope/conf/hyperparams.yaml @@ -1,6 +1,6 @@ lightgbm: - max_iter: 300 - learning_rate: 0.005 + max_iter: 100 + learning_rate: 0.01 max_depth: 5 min_samples_leaf: 10 random_state: 12345 @@ -9,7 +9,7 @@ logistic_regression: C: 1000 random_state: 12345 random_forest: - n_estimators: 300 + n_estimators: 100 max_depth: 5 min_samples_leaf: 10 random_state: 12345 diff --git a/benchmark/ope/train_regression_model.py b/benchmark/ope/train_regression_model.py index f33b26b4..bc81116c 100644 --- a/benchmark/ope/train_regression_model.py +++ b/benchmark/ope/train_regression_model.py @@ -1,43 +1,84 @@ -import time import argparse import yaml import pickle +from distutils.util import strtobool from pathlib import Path +from typing import Dict import numpy as np -import pandas as pd +from pandas import DataFrame +from joblib import Parallel, delayed from sklearn.experimental import enable_hist_gradient_boosting -from sklearn.ensemble import HistGradientBoostingClassifier +from sklearn.ensemble import HistGradientBoostingClassifier, RandomForestClassifier from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss, roc_auc_score from obp.dataset import OpenBanditDataset +from obp.policy import BernoulliTS, Random from obp.ope import RegressionModel +from obp.types import BanditFeedback # hyperparameter settings for the base ML model in regression model with open("./conf/hyperparams.yaml", "rb") as f: hyperparams = yaml.safe_load(f) base_model_dict = dict( - logistic_regression=LogisticRegression, lightgbm=HistGradientBoostingClassifier, + logistic_regression=LogisticRegression, + lightgbm=HistGradientBoostingClassifier, + random_forest=RandomForestClassifier, ) -metrics = ["auc", "rce"] + +def relative_ce(y_true: np.ndarray, y_pred: np.ndarray) -> float: + """Calculate relative cross-entropy.""" + naive_pred = np.ones_like(y_true) * y_true.mean() + ce_naive_pred = log_loss(y_true=y_true, y_pred=naive_pred) + ce_y_pred = log_loss(y_true=y_true, y_pred=y_pred) + return 1.0 - (ce_y_pred / ce_naive_pred) + + +def evaluate_reg_model( + bandit_feedback: BanditFeedback, + is_timeseries_split: bool, + estimated_rewards_by_reg_model: np.ndarray, + is_for_reg_model: bool, +) -> Dict[str, float]: + """Evaluate the estimation performance of regression model by AUC and RCE.""" + performance_reg_model = dict(auc=0.0, rce=0.0) + if is_timeseries_split: + factual_rewards = bandit_feedback["reward_test"] + estimated_factual_rewards = estimated_rewards_by_reg_model[ + np.arange(factual_rewards.shape[0]), + bandit_feedback["action_test"].astype(int), + bandit_feedback["position_test"].astype(int), + ] + else: + factual_rewards = bandit_feedback["reward"][~is_for_reg_model] + estimated_factual_rewards = estimated_rewards_by_reg_model[ + np.arange((~is_for_reg_model).sum()), + bandit_feedback["action"][~is_for_reg_model].astype(int), + bandit_feedback["position"][~is_for_reg_model].astype(int), + ] + performance_reg_model["auc"] = 
roc_auc_score( + y_true=factual_rewards, y_score=estimated_factual_rewards + ) + performance_reg_model["rce"] = relative_ce( + y_true=factual_rewards, y_pred=estimated_factual_rewards + ) + return performance_reg_model + if __name__ == "__main__": parser = argparse.ArgumentParser(description="evaluate off-policy estimators.") parser.add_argument( - "--n_boot_samples", - type=int, - default=1, - help="number of bootstrap samples in the experiment.", + "--n_runs", type=int, default=1, help="number of experimental runs.", ) parser.add_argument( "--base_model", type=str, - choices=["logistic_regression", "lightgbm"], + choices=["logistic_regression", "lightgbm", "random_forest"], required=True, - help="base ML model for regression model, logistic_regression, or lightgbm.", + help="base ML model for regression model, logistic_regression, random_forest, or lightgbm.", ) parser.add_argument( "--behavior_policy", @@ -61,22 +102,46 @@ ) parser.add_argument( "--is_timeseries_split", - action="store_true", + type=strtobool, + default=False, help="If true, split the original logged badnit feedback data by time series.", ) + parser.add_argument( + "--is_mrdr", + type=strtobool, + default=False, + help="If true, the regression model is trained by minimizing the empirical variance objective.", + ) + parser.add_argument( + "--n_sim_to_compute_action_dist", + type=float, + default=1000000, + help="number of monte carlo simulation to compute the action distribution of bts.", + ) + parser.add_argument( + "--n_jobs", + type=int, + default=1, + help="the maximum number of concurrently running jobs.", + ) parser.add_argument("--random_state", type=int, default=12345) args = parser.parse_args() print(args) # configurations of the benchmark experiment - n_boot_samples = args.n_boot_samples + n_runs = args.n_runs base_model = args.base_model behavior_policy = args.behavior_policy campaign = args.campaign test_size = args.test_size is_timeseries_split = args.is_timeseries_split + is_mrdr = args.is_mrdr + n_sim_to_compute_action_dist = args.n_sim_to_compute_action_dist + n_jobs = args.n_jobs random_state = args.random_state + np.random.seed(random_state) data_path = Path("../open_bandit_dataset") + # prepare path log_path = ( Path("./logs") / behavior_policy / campaign / "out_sample" / base_model @@ -89,93 +154,112 @@ obd = OpenBanditDataset( behavior_policy=behavior_policy, campaign=campaign, data_path=data_path ) - start_time = time.time() - performance_of_reg_model = { - metrics[i]: np.zeros(n_boot_samples) for i in np.arange(len(metrics)) - } - for b in np.arange(n_boot_samples): - # sample bootstrap samples from batch logged bandit feedback - boot_bandit_feedback = obd.sample_bootstrap_bandit_feedback( - test_size=test_size, is_timeseries_split=is_timeseries_split, random_state=b + # action distribution by evaluation policy + # (more robust doubly robust needs evaluation policy information) + if is_mrdr: + if behavior_policy == "random": + policy = BernoulliTS( + n_actions=obd.n_actions, + len_list=obd.len_list, + is_zozotown_prior=True, # replicate the policy in the ZOZOTOWN production + campaign=campaign, + random_state=random_state, + ) + else: + policy = Random( + n_actions=obd.n_actions, + len_list=obd.len_list, + random_state=random_state, + ) + action_dist_single_round = policy.compute_batch_action_dist( + n_sim=n_sim_to_compute_action_dist + ) + + def process(b: int): + # sample bootstrap from batch logged bandit feedback + bandit_feedback = obd.sample_bootstrap_bandit_feedback( + test_size=test_size, + 
is_timeseries_split=is_timeseries_split, + random_state=b, ) # split data into two folds (data for training reg_model and for ope) is_for_reg_model = np.random.binomial( - n=1, p=0.3, size=boot_bandit_feedback["n_rounds"] + n=1, p=0.3, size=bandit_feedback["n_rounds"] ).astype(bool) - # define regression model - reg_model = RegressionModel( - n_actions=obd.n_actions, - len_list=obd.len_list, - action_context=boot_bandit_feedback["action_context"], - base_model=base_model_dict[base_model](**hyperparams[base_model]), - ) - # train regression model on logged bandit feedback data - reg_model.fit( - context=boot_bandit_feedback["context"][is_for_reg_model], - action=boot_bandit_feedback["action"][is_for_reg_model], - reward=boot_bandit_feedback["reward"][is_for_reg_model], - position=boot_bandit_feedback["position"][is_for_reg_model], - ) - # evaluate the estimation performance of the regression model by AUC and RCE - if is_timeseries_split: - estimated_reward_by_reg_model = reg_model.predict( - context=boot_bandit_feedback["context_test"], + with open(reg_model_path / f"is_for_reg_model_{b}.pkl", "wb") as f: + pickle.dump( + is_for_reg_model, f, + ) + if is_mrdr: + reg_model = RegressionModel( + n_actions=obd.n_actions, + len_list=obd.len_list, + action_context=bandit_feedback["action_context"], + base_model=base_model_dict[base_model](**hyperparams[base_model]), + fitting_method="mrdr", + ) + # train regression model on logged bandit feedback data + reg_model.fit( + context=bandit_feedback["context"][is_for_reg_model], + action=bandit_feedback["action"][is_for_reg_model], + reward=bandit_feedback["reward"][is_for_reg_model], + pscore=bandit_feedback["pscore"][is_for_reg_model], + position=bandit_feedback["position"][is_for_reg_model], + action_dist=np.tile( + action_dist_single_round, (is_for_reg_model.sum(), 1, 1) + ), ) - rewards = boot_bandit_feedback["reward_test"] - estimated_rewards_ = estimated_reward_by_reg_model[ - np.arange(rewards.shape[0]), - boot_bandit_feedback["action_test"].astype(int), - boot_bandit_feedback["position_test"].astype(int), - ] + with open(reg_model_path / f"reg_model_mrdr_{b}.pkl", "wb") as f: + pickle.dump( + reg_model, f, + ) else: - estimated_reward_by_reg_model = reg_model.predict( - context=boot_bandit_feedback["context"][~is_for_reg_model], + reg_model = RegressionModel( + n_actions=obd.n_actions, + len_list=obd.len_list, + action_context=bandit_feedback["action_context"], + base_model=base_model_dict[base_model](**hyperparams[base_model]), + fitting_method="normal", + ) + # train regression model on logged bandit feedback data + reg_model.fit( + context=bandit_feedback["context"][is_for_reg_model], + action=bandit_feedback["action"][is_for_reg_model], + reward=bandit_feedback["reward"][is_for_reg_model], + position=bandit_feedback["position"][is_for_reg_model], + ) + with open(reg_model_path / f"reg_model_{b}.pkl", "wb") as f: + pickle.dump( + reg_model, f, + ) + # evaluate the estimation performance of the regression model by AUC and RCE + if is_timeseries_split: + estimated_rewards_by_reg_model = reg_model.predict( + context=bandit_feedback["context_test"], + ) + else: + estimated_rewards_by_reg_model = reg_model.predict( + context=bandit_feedback["context"][~is_for_reg_model], + ) + performance_reg_model_b = evaluate_reg_model( + bandit_feedback=bandit_feedback, + is_timeseries_split=is_timeseries_split, + estimated_rewards_by_reg_model=estimated_rewards_by_reg_model, + is_for_reg_model=is_for_reg_model, ) - rewards = 
boot_bandit_feedback["reward"][~is_for_reg_model] - estimated_rewards_ = estimated_reward_by_reg_model[ - np.arange((~is_for_reg_model).sum()), - boot_bandit_feedback["action"][~is_for_reg_model].astype(int), - boot_bandit_feedback["position"][~is_for_reg_model].astype(int), - ] - performance_of_reg_model["auc"][b] = roc_auc_score( - y_true=rewards, y_score=estimated_rewards_ - ) - rce_naive = -log_loss( - y_true=rewards, - y_pred=np.ones_like(rewards) - * boot_bandit_feedback["reward"][is_for_reg_model].mean(), - ) - rce_clf = -log_loss(y_true=rewards, y_pred=estimated_rewards_) - performance_of_reg_model["rce"][b] = (rce_naive - rce_clf) / rce_naive - # save trained regression model in a pickled form - pickle.dump( - reg_model, open(reg_model_path / f"reg_model_{b}.pkl", "wb"), - ) - pickle.dump( - is_for_reg_model, open(reg_model_path / f"is_for_reg_model_{b}.pkl", "wb"), - ) - print( - f"Finished {b+1}th bootstrap sample:", - f"{np.round((time.time() - start_time) / 60, 1)}min", - ) + return performance_reg_model_b - # estimate means and standard deviations of the performances of the regression model - performance_of_reg_model_ = {metric: dict() for metric in metrics} - for metric in performance_of_reg_model_.keys(): - performance_of_reg_model_[metric]["mean"] = performance_of_reg_model[ - metric - ].mean() - performance_of_reg_model_[metric]["std"] = np.std( - performance_of_reg_model[metric], ddof=1 + processed = Parallel(backend="multiprocessing", n_jobs=n_jobs, verbose=50,)( + [delayed(process)(i) for i in np.arange(n_runs)] + ) + # save performance of the regression model in './logs' directory. + if not is_mrdr: + performance_reg_model = {metric: dict() for metric in ["auc", "rce"]} + for b, performance_reg_model_b in enumerate(processed): + for metric, metric_value in performance_reg_model_b.items(): + performance_reg_model[metric][b] = metric_value + DataFrame(performance_reg_model).describe().T.round(6).to_csv( + log_path / f"performance_reg_model.csv" ) - performance_of_reg_model_df = pd.DataFrame(performance_of_reg_model_).T - print("=" * 50) - print(f"random_state={random_state}") - print("-" * 50) - print(performance_of_reg_model_df) - print("=" * 50) - - # save performance of the regression model in './logs' directory. - performance_of_reg_model_df.to_csv(log_path / f"performance_of_reg_model.csv") diff --git a/docs/estimators.rst b/docs/estimators.rst index 3c344323..199c2d5e 100644 --- a/docs/estimators.rst +++ b/docs/estimators.rst @@ -57,7 +57,7 @@ IPW does not have these properties. We can define Self-Normalized Doubly Robust (SNDR) in a similar manner as follows. .. math:: - \hat{V}_{\mathrm{SNDR}} (\pi_e; \calD) :=\frac{\E_{\calD} [\hat{q}(x_t, \pi_e) + w(x_t,a_t) (r_t-\hat{q}(x_t, a_t) ) ]}{\E_{\calD} [ w(x_t,a_t) ]}. + \hat{V}_{\mathrm{SNDR}} (\pi_e; \calD) := \E_{\calD} \left[\hat{q}(x_t, \pi_e) + \frac{w(x_t,a_t) (r_t-\hat{q}(x_t, a_t) )}{\E_{\calD} [ w(x_t,a_t) ]} \right]. Switch Estimators diff --git a/examples/examples_with_synthetic/README.md b/examples/examples_with_synthetic/README.md index 0b5624e3..1b279fbc 100644 --- a/examples/examples_with_synthetic/README.md +++ b/examples/examples_with_synthetic/README.md @@ -3,8 +3,8 @@ ## Description -Here, we use synthetic bandit datasets and pipeline to evaluate OPE estimators. -Specifically, we evaluate the estimation performances of well-known off-policy estimators using the ground-truth policy value of an evaluation policy, which is calculable with synthetic data. 
+Here, we use synthetic bandit datasets to evaluate OPE estimators. +Specifically, we evaluate the estimation performances of well-known off-policy estimators using the ground-truth policy value of an evaluation policy calculable with synthetic data. ## Evaluating Off-Policy Estimators @@ -18,7 +18,7 @@ In the following, we evaluate the estimation performances of - Switch Doubly Robust (Switch-DR) - Doubly Robust with Optimistic Shrinkage (DRos) -For Switch-IPW, Switch-DR, and DRos, we tried some different values of hyperparameters. +For Switch-IPW, Switch-DR, and DRos, we try some different values of hyperparameters. See [our documentation](https://zr-obp.readthedocs.io/en/latest/estimators.html) for the details about these estimators. [`./evaluate_off_policy_estimators.py`](./evaluate_off_policy_estimators.py) implements the evaluation of OPE estimators using synthetic bandit feedback data. @@ -43,7 +43,7 @@ python evaluate_off_policy_estimators.py\ - `$base_model_for_reg_model` specifies the base ML model for defining regression model and should be one of "logistic_regression", "random_forest", or "lightgbm". - `$n_jobs` is the maximum number of concurrently running jobs. -For example, the following command compares the estimation performances (relative estimation error; relative-ee) of the OPE estimators using the synthetic bandit feedback data with 100,000 rounds, 30 actions, context vectors with five dimensions. +For example, the following command compares the estimation performances (relative estimation error; relative-ee) of the OPE estimators using the synthetic bandit feedback data with 100,000 rounds, 30 actions, five dimensional context vectors. ```bash python evaluate_off_policy_estimators.py\ @@ -57,22 +57,22 @@ python evaluate_off_policy_estimators.py\ --random_state 12345 # relative-ee of OPE estimators and their standard deviations (lower is better). -# It appears that the performances of some OPE estimators depend on the choice of hyperparameters. +# It appears that the performances of some OPE estimators depend on the choice of their hyperparameters. 
# ============================================= # random_state=12345 # --------------------------------------------- # mean std -# dm 0.010835 0.000693 -# ipw 0.001764 0.000474 -# snipw 0.001630 0.001022 -# dr 0.001265 0.000773 -# sndr 0.002091 0.000115 -# switch-ipw (tau=1) 0.138272 0.000630 -# switch-ipw (tau=100) 0.001764 0.000474 -# switch-dr (tau=1) 0.021673 0.000507 -# switch-dr (tau=100) 0.001265 0.000773 -# dr-os (lambda=1) 0.010676 0.000694 -# dr-os (lambda=100) 0.001404 0.001083 +# dm 0.011110 0.000565 +# ipw 0.001953 0.000387 +# snipw 0.002036 0.000835 +# dr 0.001573 0.000631 +# sndr 0.001578 0.000625 +# switch-ipw (tau=1) 0.138523 0.000514 +# switch-ipw (tau=100) 0.001953 0.000387 +# switch-dr (tau=1) 0.021875 0.000414 +# switch-dr (tau=100) 0.001573 0.000631 +# dr-os (lambda=1) 0.010952 0.000567 +# dr-os (lambda=100) 0.001835 0.000884 # ============================================= ``` diff --git a/obp/dataset/real.py b/obp/dataset/real.py index cdcab98c..401c0cdb 100644 --- a/obp/dataset/real.py +++ b/obp/dataset/real.py @@ -4,7 +4,7 @@ """Dataset Class for Real-World Logged Bandit Feedback.""" from dataclasses import dataclass from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import numpy as np import pandas as pd @@ -171,7 +171,7 @@ def pre_process(self) -> None: def obtain_batch_bandit_feedback( self, test_size: float = 0.3, is_timeseries_split: bool = False - ) -> Union[BanditFeedback, Tuple[BanditFeedback, BanditFeedback]]: + ) -> BanditFeedback: """Obtain batch logged bandit feedback. Parameters @@ -185,10 +185,8 @@ def obtain_batch_bandit_feedback( Returns -------- - bandit_feedback: tuple or BanditFeedback + bandit_feedback: BanditFeedback Batch logged bandit feedback collected by a behavior policy. - When `is_timeseries_split` is true, this method returns a tuple of - train and evaluation sets of bandit feedback, (bandit_feedback_train, bandit_feedback_eval) """ if is_timeseries_split: @@ -196,27 +194,21 @@ def obtain_batch_bandit_feedback( 0 < test_size < 1 ), f"test_size must be a float in the (0,1) interval, but {test_size} is given" n_rounds_train = np.int(self.n_rounds * (1.0 - test_size)) - bandit_feedback_train = dict( + return dict( n_rounds=n_rounds_train, n_actions=self.n_actions, action=self.action[:n_rounds_train], + action_test=self.action[n_rounds_train:], position=self.position[:n_rounds_train], + position_test=self.position[n_rounds_train:], reward=self.reward[:n_rounds_train], + reward_test=self.reward[n_rounds_train:], pscore=self.pscore[:n_rounds_train], + pscore_test=self.pscore[n_rounds_train:], context=self.context[:n_rounds_train], + context_test=self.context[n_rounds_train:], action_context=self.action_context, ) - bandit_feedback_eval = dict( - n_rounds=np.int(self.n_rounds - n_rounds_train), - n_actions=self.n_actions, - action=self.action[n_rounds_train:], - position=self.position[n_rounds_train:], - reward=self.reward[n_rounds_train:], - pscore=self.pscore[n_rounds_train:], - context=self.context[n_rounds_train:], - action_context=self.action_context, - ) - return bandit_feedback_train, bandit_feedback_eval else: return dict( n_rounds=self.n_rounds, @@ -235,7 +227,7 @@ def sample_bootstrap_bandit_feedback( test_size: float = 0.3, is_timeseries_split: bool = False, random_state: Optional[int] = None, - ) -> Union[BanditFeedback, Tuple[BanditFeedback, BanditFeedback]]: + ) -> BanditFeedback: """Obtain bootstrap logged bandit feedback. 
Parameters @@ -254,31 +246,14 @@ def sample_bootstrap_bandit_feedback( -------- bandit_feedback: BanditFeedback Logged bandit feedback sampled independently from the original data with replacement. - When `is_timeseries_split` is true, this method returns a tuple of - train and evaluation sets of bandit feedback, (bandit_feedback_train, bandit_feedback_eval) - where the train set is sampled independently from the original train data with replacement. """ - if is_timeseries_split: - ( - bandit_feedback_train, - bandit_feedback_eval, - ) = self.obtain_batch_bandit_feedback( - test_size=test_size, is_timeseries_split=is_timeseries_split - ) - n_rounds = bandit_feedback_train["n_rounds"] - random_ = check_random_state(random_state) - bootstrap_idx = random_.choice(n_rounds, size=n_rounds, replace=True) - for key_ in ["action", "position", "reward", "pscore", "context"]: - bandit_feedback_train[key_] = bandit_feedback_train[key_][bootstrap_idx] - return bandit_feedback_train, bandit_feedback_eval - else: - bandit_feedback = self.obtain_batch_bandit_feedback( - test_size=test_size, is_timeseries_split=is_timeseries_split - ) - n_rounds = bandit_feedback["n_rounds"] - random_ = check_random_state(random_state) - bootstrap_idx = random_.choice(n_rounds, size=n_rounds, replace=True) - for key_ in ["action", "position", "reward", "pscore", "context"]: - bandit_feedback[key_] = bandit_feedback[key_][bootstrap_idx] - return bandit_feedback + bandit_feedback = self.obtain_batch_bandit_feedback( + test_size=test_size, is_timeseries_split=is_timeseries_split + ) + n_rounds = bandit_feedback["n_rounds"] + random_ = check_random_state(random_state) + bootstrap_idx = random_.choice(np.arange(n_rounds), size=n_rounds, replace=True) + for key_ in ["action", "position", "reward", "pscore", "context"]: + bandit_feedback[key_] = bandit_feedback[key_][bootstrap_idx] + return bandit_feedback diff --git a/obp/ope/estimators.py b/obp/ope/estimators.py index afdc1715..c489e464 100644 --- a/obp/ope/estimators.py +++ b/obp/ope/estimators.py @@ -815,7 +815,7 @@ class SelfNormalizedDoublyRobust(DoublyRobust): .. math:: \\hat{V}_{\\mathrm{SNDR}} (\\pi_e; \\mathcal{D}, \\hat{q}) := - \\frac{\\mathbb{E}_{\\mathcal{D}}[\\hat{q}(x_t,\\pi_e) + w(x_t,a_t) (r_t - \\hat{q}(x_t,a_t))]}{\\mathbb{E}_{\\mathcal{D}}[ w(x_t,a_t) ]}, + \\mathbb{E}_{\\mathcal{D}} \\left[\\hat{q}(x_t,\\pi_e) + \\frac{w(x_t,a_t) (r_t - \\hat{q}(x_t,a_t))}{\\mathbb{E}_{\\mathcal{D}}[ w(x_t,a_t) ]} \\right], where :math:`\\mathcal{D}=\\{(x_t,a_t,r_t)\\}_{t=1}^{T}` is logged bandit feedback data with :math:`T` rounds collected by a behavior policy :math:`\\pi_b`. :math:`w(x,a):=\\pi_e (a|x)/\\pi_b (a|x)` is the importance weight given :math:`x` and :math:`a`. @@ -894,8 +894,8 @@ def _estimate_round_rewards( q_hat_factual = estimated_rewards_by_reg_model[ np.arange(n_rounds), action, position ] - estimated_rewards += iw * (reward - q_hat_factual) - return estimated_rewards / iw.mean() + estimated_rewards += iw * (reward - q_hat_factual) / iw.mean() + return estimated_rewards @dataclass