Merge pull request #25 from st-tech/fix/def-sndr
Hotfix: fix the definition of the SNDR estimator
usaito authored Nov 12, 2020
2 parents cd919bc + 8ad89a2 commit fc7d0c2
Showing 10 changed files with 474 additions and 291 deletions.
6 changes: 6 additions & 0 deletions benchmark/README.md
@@ -0,0 +1,6 @@
# Benchmark Experiments
---
This directory includes benchmark experiments and demonstrations of off-policy evaluation using [the full-size Open Bandit Dataset](https://research.zozo.com/data.html). Detailed descriptions, results, and discussions can be found in [the relevant paper](https://arxiv.org/abs/2008.07146).

- `cf_policy_search`: counterfactual policy search using OPE
- `ope`: estimation performance comparisons on a variety of OPE estimators
70 changes: 33 additions & 37 deletions benchmark/cf_policy_search/run_cf_policy_search.py
@@ -1,10 +1,10 @@
import argparse
from pathlib import Path
import yaml
import time

import pandas as pd
import numpy as np
from pandas import DataFrame
from joblib import Parallel, delayed
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.experimental import enable_hist_gradient_boosting
@@ -13,7 +13,6 @@
from custom_dataset import OBDWithInteractionFeatures
from obp.policy import IPWLearner
from obp.ope import InverseProbabilityWeighting
from obp.utils import estimate_confidence_interval_by_bootstrap

# hyperparameter for the regression model used in model dependent OPE estimators
with open("./conf/hyperparams.yaml", "rb") as f:
@@ -28,10 +27,10 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="run evaluation policy selection.")
parser.add_argument(
"--n_boot_samples",
"--n_runs",
type=int,
default=5,
help="number of bootstrap samples in the experiment.",
help="number of bootstrap sampling in the experiment.",
)
parser.add_argument(
"--context_set",
@@ -67,17 +66,24 @@
default=0.5,
help="the proportion of the dataset to include in the test split.",
)
parser.add_argument(
"--n_jobs",
type=int,
default=1,
help="the maximum number of concurrently running jobs.",
)
parser.add_argument("--random_state", type=int, default=12345)
args = parser.parse_args()
print(args)

# configurations
n_boot_samples = args.n_boot_samples
n_runs = args.n_runs
context_set = args.context_set
base_model = args.base_model
behavior_policy = args.behavior_policy
campaign = args.campaign
test_size = args.test_size
n_jobs = args.n_jobs
random_state = args.random_state
np.random.seed(random_state)
data_path = Path("../open_bandit_dataset")
@@ -89,8 +95,8 @@
data_path=data_path,
context_set=context_set,
)
# define an evaluation policy
evaluation_policy = IPWLearner(
# define a counterfactual policy based on IPWLearner
counterfactual_policy = IPWLearner(
base_model=base_model_dict[base_model](**hyperparams[base_model]),
n_actions=obd.n_actions,
len_list=obd.len_list,
@@ -107,54 +113,44 @@
is_timeseries_split=True,
)

start = time.time()
ope_results = np.zeros(n_boot_samples)
for b in np.arange(n_boot_samples):
def process(b: int):
# sample bootstrap from batch logged bandit feedback
boot_bandit_feedback = obd.sample_bootstrap_bandit_feedback(
test_size=test_size, is_timeseries_split=True, random_state=b
)
# train an evaluation policy on the training set of the logged bandit feedback data
action_dist = evaluation_policy.fit(
action_dist = counterfactual_policy.fit(
context=boot_bandit_feedback["context"],
action=boot_bandit_feedback["action"],
reward=boot_bandit_feedback["reward"],
pscore=boot_bandit_feedback["pscore"],
position=boot_bandit_feedback["position"],
)
# make action selections (predictions)
action_dist = evaluation_policy.predict(
action_dist = counterfactual_policy.predict(
context=boot_bandit_feedback["context_test"]
)
# estimate the policy value of the counterfactual policy with the IPW estimator
ipw = InverseProbabilityWeighting()
ope_results[b] = (
ipw.estimate_policy_value(
reward=boot_bandit_feedback["reward_test"],
action=boot_bandit_feedback["action_test"],
position=boot_bandit_feedback["position_test"],
pscore=boot_bandit_feedback["pscore_test"],
action_dist=action_dist,
)
/ ground_truth
return ipw.estimate_policy_value(
reward=boot_bandit_feedback["reward_test"],
action=boot_bandit_feedback["action_test"],
position=boot_bandit_feedback["position_test"],
pscore=boot_bandit_feedback["pscore_test"],
action_dist=action_dist,
)

print(f"{b+1}th iteration: {np.round((time.time() - start) / 60, 2)}min")
ope_results_dict = estimate_confidence_interval_by_bootstrap(
samples=ope_results, random_state=random_state
processed = Parallel(backend="multiprocessing", n_jobs=n_jobs, verbose=50,)(
[delayed(process)(i) for i in np.arange(n_runs)]
)
ope_results_dict["mean(no-boot)"] = ope_results.mean()
ope_results_dict["std"] = np.std(ope_results, ddof=1)
ope_results_df = pd.DataFrame(ope_results_dict, index=["ipw"])

# calculate estimated policy value relative to that of the behavior policy
print("=" * 70)
print(f"random_state={random_state}: evaluation policy={policy_name}")
print("-" * 70)
print(ope_results_df)
print("=" * 70)

# save evaluation policy evaluation results in `./logs` directory
# save counterfactual policy evaluation results in `./logs` directory
ope_results = np.zeros((n_runs, 2))
for b, estimated_policy_value_b in enumerate(processed):
ope_results[b, 0] = estimated_policy_value_b
ope_results[b, 1] = estimated_policy_value_b / ground_truth
save_path = Path("./logs") / behavior_policy / campaign
save_path.mkdir(exist_ok=True, parents=True)
ope_results_df.to_csv(save_path / f"{policy_name}.csv")
DataFrame(
ope_results, columns=["policy_value", "relative_policy_value"]
).describe().round(6).to_csv(save_path / f"{policy_name}.csv")
153 changes: 137 additions & 16 deletions benchmark/ope/README.md
@@ -1,45 +1,166 @@
# Benchmarking Off-Policy Evaluation

## Description
We use the (full-size) Open Bandit Dataset to evaluate and compare OPE estimators in a *realistic* and *reproducible* manner. Specifically, we evaluate the estimation performance of a wide variety of existing estimators by comparing their estimated policy values with the ground-truth policy value of an evaluation policy computed from the data.
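
As a concrete reading of this protocol, the following minimal sketch (ours, not part of the benchmark code; names are illustrative) computes the relative estimation error that such a comparison typically reports:

```python
import numpy as np

def relative_estimation_error(estimated_policy_value: float, ground_truth_policy_value: float) -> float:
    """Relative gap between an OPE estimate and the on-policy ground truth."""
    return float(np.abs(estimated_policy_value - ground_truth_policy_value) / ground_truth_policy_value)

# e.g., an estimate of 0.0046 against a ground truth of 0.0050 gives an error of 0.08
```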

### Dataset
Please download the full [open bandit dataset](https://research.zozo.com/data.html) and put it in the `../open_bandit_dataset/` directory.

## Training Regression Model

Model-dependent estimators such as DM and DR need a pre-trained regression model.
Here, we train a regression model using several machine learning methods.

We define hyperparameters for the machine learning methods in [`conf/hyperparams.yaml`](https://github.com/st-tech/zr-obp/blob/master/benchmark/ope/conf/hyperparams.yaml).
[train_regression_model.py](https://github.com/st-tech/zr-obp/blob/master/benchmark/ope/train_regression_model.py) implements the training process of the regression model.

```bash
python train_regression_model.py\
--n_runs $n_runs\
--base_model $base_model\
--behavior_policy $behavior_policy\
--campaign $campaign\
--n_sim_to_compute_action_dist $n_sim_to_compute_action_dist\
--is_timeseries_split $is_timeseries_split\
--test_size $test_size\
--is_mrdr $is_mrdr\
--n_jobs $n_jobs\
--random_state $random_state
```

where
- `$n_runs` specifies the number of simulation runs with different bootstrap samples in the experiment.
- `$base_model` specifies the base ML model for defining the regression model and should be one of "logistic_regression", "random_forest", or "lightgbm".
- `$behavior_policy` specifies the behavior policy that collected the logged data and should be either "random" or "bts".
- `$campaign` specifies the campaign considered in ZOZOTOWN and should be one of "all", "men", or "women".
- `$n_sim_to_compute_action_dist` is the number of Monte Carlo simulations used to compute the action choice probabilities of a given evaluation policy.
- `$is_timeseries_split` specifies whether the data is split based on timestamp. If true, the out-sample performance of OPE is tested. See the relevant paper for details.
- `$test_size` specifies the proportion of the dataset to include in the test split when `$is_timeseries_split=True`.
- `$is_mrdr` specifies whether the regression model is trained with the more robust doubly robust (MRDR) objective. See the relevant paper for details.
- `$n_jobs` is the maximum number of concurrently running jobs.

For example, the following command trains a regression model based on logistic regression using the logged bandit feedback collected by the Random policy (as a behavior policy) in the "All" campaign.

```bash
python train_regression_model.py\
--n_runs 10\
--base_model logistic_regression\
--behavior_policy random\
--campaign all\
--is_mrdr False\
--is_timeseries_split False
```
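
For readers who prefer the library API over the CLI, here is a minimal sketch of a regression-model fit with `obp` (our illustration, not an excerpt of `train_regression_model.py`; method names follow the obp documentation and may differ slightly in the exact version pinned here):

```python
from pathlib import Path

from sklearn.linear_model import LogisticRegression
from obp.dataset import OpenBanditDataset
from obp.ope import RegressionModel

# load the logged bandit feedback of the Random policy in the "all" campaign
dataset = OpenBanditDataset(
    behavior_policy="random", campaign="all", data_path=Path("../open_bandit_dataset")
)
bandit_feedback = dataset.obtain_batch_bandit_feedback()

# fit q_hat(x, a), the reward model used by DM, DR, and their variants
regression_model = RegressionModel(
    n_actions=dataset.n_actions,
    len_list=dataset.len_list,
    base_model=LogisticRegression(max_iter=10000, random_state=12345),
)
estimated_rewards_by_reg_model = regression_model.fit_predict(
    context=bandit_feedback["context"],
    action=bandit_feedback["action"],
    reward=bandit_feedback["reward"],
    position=bandit_feedback["position"],
)
```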

<!--
```
for model in random_forest
do
for pi_b in random
for pi_b in bts
do
for camp in men
for camp in all
do
python train_regression_model.py\
--n_boot_samples 5\
--base_model $model\
--behavior_policy $pi_b\
--campaign $camp
for is_mrdr in True False
do
for is_timeseries in True False
do
python train_regression_model.py\
--n_runs 30\
--base_model $model\
--behavior_policy $pi_b\
--campaign $camp\
--is_mrdr $is_mrdr\
--n_jobs 1\
--is_timeseries_split $is_timeseries
done
done
done
done
done
```
``` -->


## Evaluating Off-Policy Estimators

Next, we evaluate and compare the estimation performances of the following OPE estimators:

- Direct Method (DM)
- Inverse Probability Weighting (IPW)
- Self-Normalized Inverse Probability Weighting (SNIPW)
- Doubly Robust (DR)
- Self-Normalized Doubly Robust (SNDR)
- Switch Doubly Robust (Switch-DR)
- Doubly Robust with Optimistic Shrinkage (DRos)
- More Robust Doubly Robust (MRDR)

For Switch-DR and DRos, we test several different hyperparameter values.
See our [documentation](https://zr-obp.readthedocs.io/en/latest/estimators.html) for details about these estimators.
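
Since this commit's hotfix concerns the definition of SNDR, here is a hedged NumPy sketch of a self-normalized doubly robust estimate (single-position case; variable names are ours, and the authoritative definition is the one in the obp documentation linked above):

```python
import numpy as np

def sndr_estimate(reward, importance_weight, q_hat_factual, q_hat_policy):
    """Sketch of a self-normalized DR estimate.

    reward:            observed rewards r_i
    importance_weight: w(x_i, a_i) = pi_e(a_i | x_i) / pi_b(a_i | x_i)
    q_hat_factual:     q_hat(x_i, a_i), the regression model at the logged action
    q_hat_policy:      E_{a ~ pi_e(.|x_i)}[q_hat(x_i, a)], its expectation under pi_e
    """
    # self-normalize the importance-weighted residual by the empirical mean weight
    correction = importance_weight * (reward - q_hat_factual) / importance_weight.mean()
    return float(np.mean(q_hat_policy + correction))
```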


[benchmark_off_policy_estimators.py](https://github.com/st-tech/zr-obp/blob/master/benchmark/ope/benchmark_off_policy_estimators.py) implements the evaluation and comparison of OPE estimators using the open bandit dataset.
Note that you have to finish training a regression model (see the above section) before conducting the evaluation of OPE in the corresponding setting.
We summarize the detailed experimental protocol for evaluating OPE estimators using real-world data [here](https://zr-obp.readthedocs.io/en/latest/evaluation_ope.html).

```bash
# run evaluation of OPE estimators with the full open bandit dataset
python benchmark_off_policy_estimators.py\
--n_runs $n_runs\
--base_model $base_model\
--behavior_policy $behavior_policy\
--campaign $campaign\
--n_sim_to_compute_action_dist $n_sim_to_compute_action_dist\
--is_timeseries_split\
--test_size $test_size\
--n_jobs $n_jobs\
--random_state $random_state
```
where
- `$n_runs` specifies the number of simulation runs with different bootstrap samples in the experiment, used to estimate the standard deviation of each OPE estimator's performance.
- `$base_model` specifies the base ML model for defining the regression model and should be one of "logistic_regression", "random_forest", or "lightgbm".
- `$behavior_policy` specifies the behavior policy that collected the logged data and should be either "random" or "bts".
- `$campaign` specifies the campaign considered in ZOZOTOWN and should be one of "all", "men", or "women".
- `$n_sim_to_compute_action_dist` is the number of Monte Carlo simulations used to compute the action choice probabilities of a given evaluation policy.
- `$is_timeseries_split` specifies whether the data is split based on timestamp. If true, the out-sample performance of OPE is tested. See the relevant paper for details.
- `$test_size` specifies the proportion of the dataset to include in the test split when `$is_timeseries_split=True`.
- `$n_jobs` is the maximum number of concurrently running jobs.

For example, the following command compares the estimation performance of the OPE estimators listed above using Bernoulli TS as an evaluation policy and Random as a behavior policy in the "All" campaign in the out-sample setting.

```bash
python benchmark_off_policy_estimators.py\
--n_runs 10\
--base_model logistic_regression\
--behavior_policy random\
--campaign all\
--test_size 0.3\
--is_timeseries_split True
```
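
As a rough illustration of the obp interface that such an evaluation builds on (our sketch; the estimator list and keyword names follow the obp documentation, not the benchmark script itself):

```python
from obp.ope import (
    OffPolicyEvaluation,
    DirectMethod,
    InverseProbabilityWeighting,
    SelfNormalizedInverseProbabilityWeighting,
    DoublyRobust,
    SelfNormalizedDoublyRobust,
)

# `bandit_feedback`, `action_dist`, and `estimated_rewards_by_reg_model` are assumed to
# come from the dataset, the evaluation policy, and the regression-model step above.
ope = OffPolicyEvaluation(
    bandit_feedback=bandit_feedback,
    ope_estimators=[
        DirectMethod(),
        InverseProbabilityWeighting(),
        SelfNormalizedInverseProbabilityWeighting(),
        DoublyRobust(),
        SelfNormalizedDoublyRobust(),
    ],
)
estimated_policy_values = ope.estimate_policy_values(
    action_dist=action_dist,
    estimated_rewards_by_reg_model=estimated_rewards_by_reg_model,
)
```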

The results of our benchmark experiments can be found in Section 5 of [our paper](https://arxiv.org/abs/2008.07146).

<!--
```
for model in logistic_regression
do
for pi_b in random
do
for camp in men
for camp in women all
do
python benchmark_off_policy_estimators.py\
--n_boot_samples 5\
--base_model $model\
--behavior_policy $pi_b\
--campaign $camp
for is_timeseries in True False
do
python benchmark_off_policy_estimators.py\
--n_runs 30\
--base_model $model\
--behavior_policy $pi_b\
--campaign $camp\
--n_jobs 10\
--is_timeseries_split $is_timeseries
done
done
done
done
```
-->

<!-- ## Results
## Results
We report the results of the benchmark experiments on the three campaigns (all, men, women) in the following tables.
We describe **Random -> Bernoulli TS** to represent the OPE situation where we use Bernoulli TS as a hypothetical evaluation policy and Random as a hypothetical behavior policy.
In contrast, we use **Bernoulli TS -> Random** to represent the situation where we use Random as a hypothetical evaluation policy and Bernoulli TS as a hypothetical behavior policy. -->