run_spflow.py · 216 lines (192 loc) · 8.12 KB
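"""Benchmark script for SPFlow 0.0.41.

Times complete-evidence (evi) and marginal (mar) log-likelihood queries,
most probable explanation (mpe), conditional sampling (csampling) and
binary Chow-Liu tree learning (learnclt) on the binary and continuous
benchmark datasets, then writes timing and log-likelihood statistics to a
JSON file.
"""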
import gc
import time
import json
import random
import argparse
import numpy as np
from typing import Union, List, Tuple
from spn.algorithms.Sampling import sample_instances
from spn.structure.Base import Node
from spn.structure.leaves.cltree.CLTree import CLTree
from benchmark.utils import spflow_learn_binary_spn, spflow_learn_continuous_spn, spflow_learn_binary_clt
from spn.algorithms.Inference import log_likelihood
from spn.algorithms.MPE import mpe
from spn.structure.leaves.cltree.Inference import cltree_log_likelihood
from spn.structure.leaves.cltree.MPE import cltree_mpe
from deeprob.utils import DataStandardizer
from experiments.datasets import BINARY_DATASETS, CONTINUOUS_DATASETS, load_binary_dataset, load_continuous_dataset


def benchmark_log_likelihood(model: Union[Node, CLTree], data: np.ndarray) -> Tuple[List[float], np.ndarray]:
    """Time log-likelihood evaluation over args.num_reps repetitions and return (timings, log-likelihoods)."""
    dts = list()
    if isinstance(model, CLTree):
        for i in range(args.num_reps):
            start_time = time.perf_counter()
            cltree_log_likelihood(model, data, dtype=np.float32)
            end_time = time.perf_counter()
            dts.append(end_time - start_time)
        lls = cltree_log_likelihood(model, data, dtype=np.float32)
    elif isinstance(model, Node):
        for i in range(args.num_reps):
            start_time = time.perf_counter()
            log_likelihood(model, data, dtype=np.float32)
            end_time = time.perf_counter()
            dts.append(end_time - start_time)
        lls = log_likelihood(model, data, dtype=np.float32)
    else:
        raise ValueError("Unknown model")
    return dts, lls


def benchmark_mpe(model: Union[Node, CLTree], data: np.ndarray) -> Tuple[List[float], np.ndarray]:
    """Time MPE queries on data with missing values and return (timings, log-likelihoods of the completions)."""
    dts = list()
    if isinstance(model, CLTree):
        logprobs = np.empty(len(data), dtype=np.float32)
        for i in range(args.num_reps):
            start_time = time.perf_counter()
            # cltree_mpe fills the missing values in place, so time a fresh copy each repetition
            data_copied = data.copy()
            cltree_mpe(model, data_copied, logprobs=logprobs)
            end_time = time.perf_counter()
            dts.append(end_time - start_time)
        mpe_data = data.copy()
        cltree_mpe(model, mpe_data, logprobs=logprobs)
        lls = cltree_log_likelihood(model, mpe_data.astype(np.int64), dtype=np.float32)
    elif isinstance(model, Node):
        for i in range(args.num_reps):
            start_time = time.perf_counter()
            mpe(model, data, in_place=False)
            end_time = time.perf_counter()
            dts.append(end_time - start_time)
        mpe_data = mpe(model, data, in_place=False)
        lls = log_likelihood(model, mpe_data, dtype=np.float32)
    else:
        raise ValueError("Unknown model")
    return dts, lls


def benchmark_csampling(model: Node, data: np.ndarray) -> Tuple[List[float], np.ndarray]:
    """Time conditional sampling of the missing values and return (timings, log-likelihoods of the samples)."""
    assert not isinstance(model, CLTree)
    dts = list()
    rand_gen = np.random.RandomState(42)
    if isinstance(model, Node):
        for i in range(args.num_reps):
            start_time = time.perf_counter()
            sample_instances(model, data, rand_gen=rand_gen, in_place=False)
            end_time = time.perf_counter()
            dts.append(end_time - start_time)
        sampled_data = sample_instances(model, data, rand_gen=rand_gen, in_place=False)
        lls = log_likelihood(model, sampled_data, dtype=np.float32)
    else:
        raise ValueError("Unknown model")
    return dts, lls


def benchmark_learnclt(data: np.ndarray) -> List[float]:
    """Time the learning of a binary Chow-Liu tree from data and return the timings."""
    dts = list()
    for i in range(args.num_reps):
        start_time = time.perf_counter()
        spflow_learn_binary_clt(data)
        end_time = time.perf_counter()
        dts.append(end_time - start_time)
    return dts


if __name__ == '__main__':
    # Parse the arguments
    parser = argparse.ArgumentParser(
        description="SPFlow==0.0.41 Benchmark"
    )
    parser.add_argument(
        'model', choices=['spn', 'binary-clt'], help="The model to benchmark"
    )
    parser.add_argument(
        'dataset', choices=BINARY_DATASETS + CONTINUOUS_DATASETS, help="The dataset"
    )
    parser.add_argument(
        '--num-reps', type=int, default=10, help="Number of repetitions"
    )
    parser.add_argument(
        '--mar-prob', type=float, default=0.5,
        help="Marginalization probability (used to benchmark marginal queries and sampling)"
    )
    parser.add_argument(
        '--algs', type=str, help="The algorithms to benchmark, separated by a dot:\n"
                                 + "Complete Evidence (evi), Marginal (mar), Most Probable Explanation (mpe), "
                                 + "Conditional Sampling (csampling), Learn Chow-Liu Tree (learnclt)",
        default="evi.mar.mpe.csampling"
    )
    parser.add_argument(
        '--out-filepath', type=str, help="JSON results filepath, defaults to spflow==0.0.41-{model}-{dataset}.json",
        default=""
    )
    parser.add_argument(
        '--verbose', dest='verbose', action='store_true', help="Whether to enable verbose mode."
    )
    args = parser.parse_args()
    # Check arguments
    if args.model == 'binary-clt' and args.dataset in CONTINUOUS_DATASETS:
        raise ValueError("Cannot benchmark Binary-CLT on a continuous dataset")
    if args.model != 'binary-clt' and 'learnclt' in args.algs:
        raise ValueError("Cannot benchmark the `learnclt` algorithm on a non-Binary-CLT model")
    if args.model == 'binary-clt' and 'csampling' in args.algs:
        raise ValueError("Cannot benchmark the `csampling` algorithm on a Binary-CLT model")
    if args.mar_prob <= 0.0 or args.mar_prob >= 1.0:
        raise ValueError("Invalid marginalization probability")

    # Always set the same seed
    random.seed(42)
    np.random.seed(42)

    # Load the dataset
    if args.verbose:
        print(f"Preparing {args.dataset} ...")
    if args.dataset in BINARY_DATASETS:
        data, _, _ = load_binary_dataset(
            '../experiments/datasets', args.dataset, raw=True
        )
    else:
        transform = DataStandardizer()
        data, _, _ = load_continuous_dataset(
            '../experiments/datasets', args.dataset, raw=True, random_state=42  # fixed seed; no --seed flag is defined
        )
        transform.fit(data)
        data = transform.forward(data)

    # Marginalize some variables at random with probability --mar-prob
    random_state = np.random.RandomState(1234)
    mar_data = data.copy().astype(np.float32, copy=False)
    mar_data[random_state.rand(*mar_data.shape) <= args.mar_prob] = np.nan
    # Initialize the model
    if args.verbose:
        print(f"Initializing {args.model} ...")
    if args.model == 'spn':
        if args.dataset in BINARY_DATASETS:
            model = spflow_learn_binary_spn(data)
        else:
            model = spflow_learn_continuous_spn(data)
    elif args.model == 'binary-clt':
        model = spflow_learn_binary_clt(data)
    else:
        raise ValueError("Unknown model name")

    # The results dictionary
    results = {
        'model': args.model,
        'dataset': args.dataset,
        'num-reps': args.num_reps
    }
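
    # Each benchmarked algorithm below adds an entry of the following shape
    # (values are illustrative, not actual results):
    #   results[alg] = {'dt': {'mu': <mean seconds>, 'std': <2 * std of seconds>},
    #                   'll': {'mu': <mean log-likelihood>, 'std': <2 * std>}}
    # The 'll' entry is omitted for 'learnclt'.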
    # Benchmark algorithms
    for alg in args.algs.split('.'):
        if args.verbose:
            print(f"Benchmarking {alg} ...")
        gc.disable()
        if alg == 'evi':
            dts, lls = benchmark_log_likelihood(model, data)
        elif alg == 'mar':
            dts, lls = benchmark_log_likelihood(model, mar_data)
        elif alg == 'mpe':
            dts, lls = benchmark_mpe(model, mar_data)
        elif alg == 'csampling':
            dts, lls = benchmark_csampling(model, mar_data)
        elif alg == 'learnclt':
            dts = benchmark_learnclt(data)
        else:
            raise ValueError("Unknown algorithm identifier")
        gc.enable()
        results[alg] = {'dt': {'mu': np.mean(dts), 'std': 2.0 * np.std(dts)}}
        if alg not in ['learnclt']:
            results[alg].update({'ll': {'mu': np.mean(lls).item(), 'std': 2.0 * np.std(lls).item()}})

    # Save the benchmark results to file
    out_filepath = args.out_filepath
    if out_filepath == "":
        out_filepath = f"spflow==0.0.41-{args.model}-{args.dataset}.json"
    with open(out_filepath, 'w') as f:
        results = json.loads(json.dumps(results), parse_float=lambda x: round(float(x), 2))
        json.dump(results, f, indent=2)
    print(f"Saved benchmark results to {out_filepath}")
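
# Example invocations (sketches only; 'nltcs' is assumed to be one of the
# names listed in BINARY_DATASETS from experiments.datasets):
#   python run_spflow.py spn nltcs --num-reps 10 --algs evi.mar.mpe.csampling
#   python run_spflow.py binary-clt nltcs --algs evi.mar.mpe.learnclt --verbose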