-
Notifications
You must be signed in to change notification settings - Fork 30
/
pruning.py
421 lines (346 loc) · 16 KB
/
pruning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
# Copyright 2022 Feedzai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Union, Tuple, List
import numpy as np
import pandas as pd
from timeshap.explainer.kernel import TimeShapKernel
import os
import csv
from pathlib import Path
from timeshap.utils import convert_to_indexes, convert_data_to_3d
def calc_prun_indexes(df: pd.DataFrame,
                      tol: Union[float, int, list]
                      ) -> pd.DataFrame:
    """Builds, per entity, the pruning index obtained for each tolerance.

    Parameters
    ----------
    df: pd.DataFrame
        Pruning data to be analysed, as produced by `prune_all`. The entity
        identifier is read from the LAST column. If the frame already has a
        ``'Tolerance'`` column it is assumed to be the output of this function
        and is returned unchanged.
    tol: Union[float, int, list]
        One tolerance or a list of tolerances to analyse the pruning with.

    Returns
    -------
    pd.DataFrame
        Columns ``['Entity', 'Tolerance', 'Pruning idx']``. For every entity
        there is one ``Tolerance == -1`` row holding the unpruned index,
        followed by one row per requested tolerance.
    """
    if "Tolerance" in list(df.columns):
        return df

    tolerances = tol if isinstance(tol, list) else [tol]
    entity_column = df.iloc[:, -1]
    rows = []
    for entity in np.unique(entity_column.values):
        entity_rows = df[entity_column == entity]
        # `temp_coalition_pruning` writes two rows (one per coalition) for each
        # of its seq_len + 1 cut points, so this evaluates to -seq_len: the
        # index that keeps the whole sequence.
        rows.append([entity, -1, -(entity_rows.shape[0] / 2) + 1])
        rows.extend([entity, tolerance, prune_given_data(entity_rows, tolerance)]
                    for tolerance in tolerances)
    return pd.DataFrame(rows, columns=["Entity", 'Tolerance', 'Pruning idx'])
def pruning_statistics(df: pd.DataFrame,
                       tol: Union[float, list],
                       ) -> pd.DataFrame:
    """Calculates global pruning statistics with the given tolerances.

    Parameters
    ----------
    df: pd.DataFrame
        Pruning data to be analysed, either as produced by `prune_all` or
        already reduced by `calc_prun_indexes` (i.e. with a ``'Tolerance'``
        column).
    tol: Union[float, list]
        One tolerance (any numeric scalar) or a list of tolerances to analyse
        the pruning with.

    Returns
    -------
    pd.DataFrame
        One row per tolerance, plus a ``'No Pruning'`` row for the original
        sequences. It holds the mean (``'Mean'``) and standard deviation
        (``'Std'``) of the resulting sequence lengths across entities.
    """
    # Wrap any single numeric tolerance. Checking only `float` let an int such
    # as `tol=0` through, and iterating it below then raised TypeError.
    if np.isscalar(tol):
        tol = [tol]
    if "Tolerance" not in list(df.columns):
        df = calc_prun_indexes(df, tol)

    resume = []
    original = df[df['Tolerance'] == -1]
    for _, row in original.iterrows():
        # Pruning indexes are negative offsets, so negating gives a length.
        resume.append(["Original", 'No Pruning', row["Entity"], -row['Pruning idx']])
    for tolerance in tol:
        pruned = df[df['Tolerance'] == tolerance]
        for _, row in pruned.iterrows():
            resume.append(["Pruning", tolerance, row["Entity"], -row['Pruning idx']])

    resume_df = pd.DataFrame(resume, columns=["Algorithm", "Tolerance", "Entity", "Sequence Length"])
    resume_df['Mean'] = resume_df['Sequence Length']
    resume_df['Std'] = resume_df['Sequence Length']
    resume_df = resume_df.groupby("Tolerance").agg({"Mean": "mean", "Std": "std"})
    resume_df.reset_index(inplace=True)
    resume_df = resume_df.rename(columns={'index': 'Tolerance'})
    return resume_df
def prune_given_data(data: pd.DataFrame,
                     tolerance: float,
                     ) -> int:
    """Calculates the pruning index to prune the sequence to with a given tolerance

    Parameters
    ----------
    data: pd.DataFrame
        Pruning data of one sequence with the columns ``'Coalition'``,
        ``'t (event index)'`` and ``'Shapley Value'``, as produced by
        `temp_coalition_pruning` with ``ret_plot_data=True``.
    tolerance: float
        Maximum absolute Shapley value allowed for the coalition of older
        events (``'Sum of contribution of events <= t'``). A tolerance of 0 is
        widened to 1e-11 to absorb floating-point noise.

    Returns
    -------
    int
        ``t`` of the first row (in the order of ``data``) whose older-events
        coalition stays within ``tolerance``. If no row qualifies, returns the
        smallest ``t``. For data from `temp_coalition_pruning` that is
        ``-seq_len``, i.e. the sequence is kept whole.
    """
    older_events = data[data['Coalition'] == 'Sum of contribution of events \u2264 t']
    if tolerance == 0:
        # to filter float imprecision out
        tolerance = 0.00000000001
    respecting_lens = older_events[older_events['Shapley Value'].abs() <= tolerance]
    if respecting_lens.shape[0] == 0:
        # Event indexes are <= 0, so the minimum is -seq_len: the same
        # "no pruning" index `temp_coalition_pruning` falls back to. Negating
        # it would yield a positive, out-of-range index.
        return older_events['t (event index)'].min()
    return respecting_lens.iloc[0]['t (event index)']
def temp_coalition_pruning(f: Callable,
                           data: np.ndarray,
                           baseline: Union[np.ndarray, pd.DataFrame],
                           tolerance: float = None,
                           ret_plot_data=False,
                           verbose=False,
                           ) -> Union[int, pd.DataFrame, Tuple[int, pd.DataFrame]]:
    """Temporal coalition pruning method

    For every cut point ``t`` (from 0 back to ``-seq_len``) the sequence is
    split into two coalitions: events newer than ``t`` and events up to ``t``.
    Their Shapley values are estimated with `TimeShapKernel`.

    Parameters
    ----------
    f: Callable[[np.ndarray], np.ndarray]
        Point of entry for model being explained.
        This method receives a 3-D np.ndarray (#samples, #seq_len, #features).
        This method returns a 2-D np.ndarray (#samples, 1).
    data: numpy.ndarray
        Input matrix to use. First element of the first dimension is explained,
        using the rest of the elements as context/hidden state.
    baseline: Union[np.ndarray, pd.DataFrame],
        Dataset baseline. Median/Mean of numerical features and mode of categorical.
        In case of np.array feature are assumed to be in order with `model_features`.
        The baseline can be an average event or an average sequence
    tolerance: float
        Temporal coalition explainer tolerance.
        Represents the maximum allowed Shapley Value of the older grouped events.
        ``None`` disables pruning. ``0`` is widened to 1e-11 to absorb
        floating-point noise, as `prune_given_data` does.
    ret_plot_data: bool
        If method returns pruning algorithm across the whole sequence
    verbose: bool
        If process is verbose

    Returns
    -------
    Union[int, pd.DataFrame, Tuple[int, pd.DataFrame]]:
        int:
            Pruning index (when ``tolerance`` is given and ``ret_plot_data``
            is False). ``-seq_len`` means nothing could be pruned.
        pd.DataFrame
            Pruning data over the whole sequence (when ``tolerance`` is None)
        Tuple[int, pd.DataFrame]]
            Pruning index and Pruning data over the whole sequence
    """
    if verbose:
        print("Allowed importance for pruned events: {}".format(tolerance))
    # `tolerance` is compared against None explicitly: a plain truthiness test
    # treated 0 as "no tolerance" and never pruned.
    effective_tol = tolerance
    if tolerance is not None and tolerance == 0:
        # to filter float imprecision out
        effective_tol = 0.00000000001

    n_events = data.shape[1]
    # Always collected, so the `tolerance is None` path can return it even
    # when `ret_plot_data` is False (previously a NameError).
    plot_data = []
    pruning_idx = 0  # 0 == "not found yet": real cut points below are < 0
    for seq_len in range(n_events, -1, -1):
        t = seq_len - n_events
        explainer = TimeShapKernel(f, baseline, 0, "pruning")
        shap_values = explainer.shap_values(data, pruning_idx=seq_len, nsamples=4)
        plot_data.append(['Sum of contribution of events \u003E t', t, shap_values[0]])
        plot_data.append(['Sum of contribution of events \u2264 t', t, shap_values[1]])
        if verbose:
            print("len {} | importance {}".format(t, shap_values[1]))
        if effective_tol is None:
            continue
        within_tolerance = abs(shap_values[1]) <= effective_tol
        if seq_len == n_events:
            # t == 0: the cut that keeps every event is not a pruning candidate.
            if within_tolerance:
                print("Unable to prune sequence.")
        elif within_tolerance and pruning_idx == 0:
            pruning_idx = t
            if not ret_plot_data:
                return pruning_idx

    plot_df_columns = ['Coalition', 't (event index)', 'Shapley Value']
    if effective_tol is not None:
        if pruning_idx == 0:
            # No cut respected the tolerance: keep the whole sequence.
            pruning_idx = -n_events
        if ret_plot_data:
            # used for plotting
            return pruning_idx, pd.DataFrame(plot_data, columns=plot_df_columns)
        # used for event level
        return pruning_idx
    return pd.DataFrame(plot_data, columns=plot_df_columns)
def local_pruning(f: Callable[[np.ndarray], np.ndarray],
                  data: np.ndarray,
                  pruning_dict: dict,
                  baseline: Union[np.ndarray, pd.DataFrame],
                  entity_uuid: Union[str, int, float] = None,
                  entity_col: str = None,
                  verbose: bool = False,
                  ) -> Tuple[pd.DataFrame, int]:
    """Prunes one sequence, or reuses stored pruning data when a file exists.

    Parameters
    ----------
    f: Callable[[np.ndarray], np.ndarray]
        Point of entry for model being explained.
        This method receives a 3-D np.ndarray (#samples, #seq_len, #features).
        This method returns a 2-D np.ndarray (#samples, 1).
    data: numpy.ndarray
        Input matrix to use. First element of the first dimension is explained,
        using the rest of the elements as context/hidden state.
    pruning_dict: dict
        Pruning configuration. ``'tol'`` is the tolerance. The optional
        ``'path'`` is a CSV that is read when it exists and written otherwise.
    baseline: Union[np.ndarray, pd.DataFrame],
        Dataset baseline. Median/Mean of numerical features and mode of categorical.
        In case of np.array feature are assumed to be in order with `model_features`.
        The baseline can be an average event or an average sequence
    entity_uuid: Union[str, int, float]
        Identifier of the sequence being pruned. Required when the stored CSV
        contains several sequences.
    entity_col: str
        Column of the stored CSV holding the sequence identifiers.
    verbose: bool
        If process is verbose

    Returns
    -------
    Tuple[pd.DataFrame, int]
        Pruning data over the whole sequence and the pruning index.

    Raises
    ------
    ValueError
        If the pruning data must be computed and ``baseline`` is None.
    """
    path = pruning_dict.get("path")

    if path is not None and os.path.exists(path):
        stored = pd.read_csv(path)
        if len(stored.columns) > 3:
            # global df: an extra identifier column means several sequences
            assert entity_uuid is not None, "When using a dataset with several instances, a uuid needs to be provided"
            stored = stored[stored[entity_col] == entity_uuid]
        return stored, prune_given_data(stored, pruning_dict.get('tol'))

    if baseline is None:
        raise ValueError("Baseline is not defined")
    prun_idx, plot_data = temp_coalition_pruning(f,
                                                 data,
                                                 baseline,
                                                 pruning_dict['tol'],
                                                 ret_plot_data=True,
                                                 verbose=verbose)
    if path is not None:
        # Cache the result, creating the parent directory when one is named.
        if '/' in path:
            Path(path.rsplit("/", 1)[0]).mkdir(parents=True, exist_ok=True)
        plot_data.to_csv(path, index=False)
    return plot_data, prun_idx
def verify_pruning_dict(pruning_dict: dict):
    """Verifies the format of the pruning dict, normalising ``'tol'`` in place.

    Parameters
    ----------
    pruning_dict: dict
        Pruning configuration. The optional ``'path'`` must be a string when
        set. The required ``'tol'`` must be a float or a non-empty list of
        floats. A single float is replaced in place by a one-element list.

    Raises
    ------
    AssertionError
        If ``'path'`` is not a string, ``'tol'`` is missing/None/an empty
        list, or a listed tolerance is not a float.
    ValueError
        If ``'tol'`` is neither a float nor a list.
    """
    if pruning_dict.get('path'):
        assert isinstance(pruning_dict.get('path'), str)
    tolerances = pruning_dict.get('tol')
    # Only reject a missing value or an empty list. A plain truthiness check
    # also rejected the valid scalar tolerance 0.0, although `[0.0]` was
    # accepted and `prune_given_data` handles a zero tolerance explicitly.
    assert tolerances is not None and not (isinstance(tolerances, list) and len(tolerances) == 0), \
        "Tolerance(s) must be provided on the pruning dict"
    if isinstance(tolerances, float):
        pruning_dict['tol'] = [tolerances]
    elif isinstance(tolerances, list):
        assert all(isinstance(x, float) for x in tolerances), "All provided tolerances must be floats."
    else:
        raise ValueError("Unsuported format of pruning tolerance(s). Please provide one tolerance or a list of them.")
def prune_all(f: Callable,
              data: Union[List[np.ndarray], pd.DataFrame, np.array],
              pruning_dict: dict,
              baseline: Union[pd.DataFrame, np.array] = None,
              model_features: List[Union[int, str]] = None,
              schema: List[str] = None,
              entity_col: Union[int, str] = None,
              time_col: Union[int, str] = None,
              append_to_files: bool = False,
              verbose: bool = False,
              ) -> pd.DataFrame:
    """Applies pruning to a dataset

    Parameters
    ----------
    f: Callable[[np.ndarray], np.ndarray]
        Point of entry for model being explained.
        This method receives a 3-D np.ndarray (#samples, #seq_len, #features).
        This method returns a 2-D np.ndarray (#samples, 1).
    data: Union[pd.DataFrame, np.array]
        Sequence to explain.
    pruning_dict: dict
        Information required for pruning algorithm. Validated (and its
        ``'tol'`` normalised in place) by `verify_pruning_dict`. If ``'path'``
        points to an existing file, that file is loaded instead of recomputing.
    baseline: Union[np.ndarray, pd.DataFrame],
        Dataset baseline. Median/Mean of numerical features and mode of categorical.
        In case of np.array feature are assumed to be in order with `model_features`.
        The baseline can be an average event or an average sequence
    model_features: List[str]
        In-order list of features to select and input to the model
    schema: List[str]
        Schema of provided data
    entity_col: str
        Column that contains the sequence identifiers. When omitted, each
        sequence is identified by its position in ``data``.
    time_col: str
        Data column that represents the time feature in order to sort sequences
        temporally
    append_to_files: bool
        Append explanations to files if file already exists
    verbose: bool
        If process is verbose

    Returns
    -------
    pd.DataFrame
        Per-entity pruning indexes for every tolerance (see `calc_prun_indexes`).
    """
    if schema is None and isinstance(data, pd.DataFrame):
        schema = list(data.columns)
    verify_pruning_dict(pruning_dict)
    file_path = pruning_dict.get('path')
    tolerances = list(np.unique(pruning_dict.get('tol')))
    make_predictions = True
    prun_data = None
    if file_path is not None and os.path.exists(file_path):
        prun_data = pd.read_csv(file_path)
        make_predictions = False
        # TODO resume explanations for missing entities
        # necessary_entities = set(np.unique(data[entity_col].values))
        # loaded_csv = pd.read_csv(file_path)
        # present_entities = set(np.unique(loaded_csv[entity_col].values))
        # if necessary_entities.issubset(present_entities):
        #     make_predictions = False
        #     prun_data = loaded_csv[loaded_csv[entity_col].isin(necessary_entities)]

    if make_predictions:
        ret_prun_data = []
        # The last column always holds the entity id; `calc_prun_indexes`
        # reads it from there.
        names = ["Coalition", "t (event index)", "Shapley Value", entity_col if isinstance(entity_col, str) else "Entity"]
        if file_path is not None:
            # NOTE(review): an existing file is loaded above, so as written this
            # branch is only reachable if the file appears in between; the
            # `append_to_files` flag therefore has no effect yet.
            if os.path.exists(file_path):
                assert append_to_files, "The defined path for pruning data already exists and the append option is turned off. If you wish to append the explanations please use the flag `append_to_files`, otherwise change the provided path."
            else:
                if '/' in file_path:
                    Path(file_path.rsplit("/", 1)[0]).mkdir(parents=True, exist_ok=True)
                with open(file_path, 'w', newline='') as file:
                    writer = csv.writer(file, delimiter=',')
                    writer.writerow(names)

        if time_col is None:
            print("No time col provided, assuming dataset is ordered ascendingly by date")
        model_features_index, entity_col_index, time_col_index = convert_to_indexes(model_features, schema, entity_col, time_col)
        data = convert_data_to_3d(data, entity_col_index, time_col_index)

        for seq_position, sequence in enumerate(data):
            if entity_col is not None:
                entity = sequence[0, 0, entity_col_index]
            else:
                # Without an entity column, use the sequence position as its id.
                # Otherwise rows would have 3 values for the 4 declared `names`.
                entity = seq_position
            if model_features:
                sequence = sequence[:, :, model_features_index]
            sequence = sequence.astype(np.float64)
            local_pruning_data = temp_coalition_pruning(f, sequence, baseline, None, ret_plot_data=True, verbose=verbose)
            local_pruning_data["Entity"] = entity
            ret_prun_data.append(local_pruning_data.values)
            if file_path is not None:
                with open(file_path, 'a', newline='') as file:
                    writer = csv.writer(file, delimiter=',')
                    writer.writerows(local_pruning_data.values)

        if ret_prun_data:
            prun_data = pd.DataFrame(np.concatenate(ret_prun_data), columns=names)
        else:
            # np.concatenate rejects an empty list; keep the schema instead.
            prun_data = pd.DataFrame(columns=names)

    df = calc_prun_indexes(prun_data, tolerances)
    return df