Skip to content

Commit

Permalink
Update; refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
MNasert committed Sep 1, 2022
1 parent a7be2e0 commit 01dbd8c
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 160 deletions.
5 changes: 3 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@


if __name__ == "__main__":
    # NOTE(review): a stale duplicate call (quantile_dev config with
    # to_json=True, plus a commented-out print) was left behind by the diff;
    # only the mean_dev run survives in the committed version.
    explanations = fe.explain("configs/mean_dev.yml")
    for i in explanations:
        print(i)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
numpy
numba
transformers
thermostat-datasets
imgkit
10 changes: 9 additions & 1 deletion src/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def doit(self, modes: list = None, n_samples: int = None):
if "total order" in modes:
explanations["total order"] = t.verbalize_total_order(t.total_order(sample_array))

if "compare_searches" in modes:
explanations["compare searches"] = t.compare_searches(orders_and_searches, sample_array)

# TODO: Maybe detokenize input_ids using tokenizer from self?
if not self.dev:
return explanations, sample_array, None
Expand Down Expand Up @@ -247,9 +250,14 @@ def convolution_search(self, _dict, len_filters, sgn=None, metric=None):
return explanations, prepared_data

def compare_search(self, previous_searches, samples):
coincedences = t.compare_searches(previous_searches, samples)
coincedences = t.compare_search(previous_searches, samples)
return coincedences


def compare_searches(self, previous_searches, samples):
v = t.compare_searches(previous_searches, samples)
return v

def filter_verbalizations(self, verbalizations, samples, orders_and_searches, maxwords=100, mincoverage=.1, *args):
"""
Expand Down
146 changes: 59 additions & 87 deletions src/search_methods/fastexplain.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,47 +42,16 @@ def explain_nodev(config, to_json=False):

for key in texts.keys():
cutoff_top_k_single = 5

txt = "\nSAMPLE:\n"
fmtd_tokens = []
for i, token in enumerate(texts[key]["input_ids"]):
if texts[key]["attributions"][i] >= sorted(
texts[key]["attributions"], reverse=True)[cutoff_top_k_single-1]:
fmtd_tokens.append(color_str(color_str(token, Color.RED), Color.BOLD))
elif texts[key]["attributions"][i] > 0:
fmtd_tokens.append(color_str(token, Color.BOLD))
else:
fmtd_tokens.append(token)
txt += " ".join(fmtd_tokens)
c = 0
txt_ = ""
for i in txt:
c += 1
txt_ += i
if c > 150:
if i == " ":
txt_ += "\n"
c = 0
else:
pass
sample = txt_
txt = ""
# makeshift \n-ing
for expl_subclass in explanations.keys():
txt += "\nsubclass '{}'".format(expl_subclass)
_ = explanations[expl_subclass][key][:cutoff_top_k_single]
for __ in _:
txt += "\n"+__
txt += "\nPrediction was correct." if texts[key]["was_correct"] else "\nPredicton was incorrect"
txt, sample_text = txt = to_string(explanations, texts, key, cutoff_top_k_single)
if to_json:
key_verbalization_attribs[key] = {"modelname": modelname,
"sample": sample,
key_verbalization_attribs[key] = {"sample": texts[key]["input_ids"],
"verbalization": txt,
"attributions": texts[key]["attributions"]}
else:
returnstr.append(sample + "\n" + txt)
returnstr.append(sample_text + "\n" + txt)

if to_json:
key_verbalization_attribs["modelname"] = modelname
res = json.dumps(key_verbalization_attribs)
return res
else:
Expand Down Expand Up @@ -126,73 +95,76 @@ def explain(config_path, to_json=False):
if not to_json:
returnstr = []

cutoff_top_k_single = 5
for key in texts.keys():
if key in valid_keys:
if thermo_config:
thermounit = thermo_config[key]
""" Only execute the code once! """
save_heatmap_as_image(thermounit.heatmap, filename=f"{thermo_config_name}/{key}.png")

sample = " ".join(texts[key]["input_ids"])
if key not in explanations["compare search"]:
continue
smv = explanations["compare search"][key]
"""
cutoff_top_k_single = 5
txt = "\nSAMPLE:\n"
fmtd_tokens = []
for i, token in enumerate(texts[key]["input_ids"]):
if texts[key]["attributions"][i] >= sorted(
texts[key]["attributions"], reverse=True)[cutoff_top_k_single-1]:
fmtd_tokens.append(color_str(color_str(token, Color.RED), Color.BOLD))
elif texts[key]["attributions"][i] > 0:
fmtd_tokens.append(color_str(token, Color.BOLD))
else:
fmtd_tokens.append(token)
txt += " ".join(fmtd_tokens)
c = 0
txt_ = ""
for i in txt:
c += 1
txt_ += i
if c > 150:
if i == " ":
txt_ += "\n"
c = 0
else:
pass
sample = txt_
txt = ""
# makeshift \n-ing
for expl_subclass in explanations.keys():
txt += "\nsubclass '{}'".format(expl_subclass)
_ = explanations[expl_subclass][key][:cutoff_top_k_single]
for __ in _:
txt += "\n"+__
txt += "\nPrediction was correct." if texts[key]["was_correct"] else "\nPredicton was incorrect"
"""
key_verbalization_attribs[key] = {}
txt, sample_text = to_string(explanations, texts, key, cutoff_top_k_single)
if to_json:
key_verbalization_attribs[key] = {"sample": sample,
"verbalization": smv,
key_verbalization_attribs[key] = {"sample": texts[key]["input_ids"],
"verbalization": txt,
"attributions": texts[key]["attributions"]}
else:
returnstr.append(sample + "\n" + smv)
returnstr.append(sample_text + "\n" + txt)

if to_json:
pd.DataFrame.from_dict(key_verbalization_attribs, orient="index").to_csv(
f"verbalizations/SMV_{modelname}.csv")

"""
key_verbalization_attribs["modelname"] = modelname
"""
res = json.dumps(key_verbalization_attribs)
return res
else:
return returnstr
else:
return explain_nodev(config, to_json)


def explain_json(config_path):
    """Convenience wrapper: run :func:`explain` with JSON output enabled."""
    return explain(config_path, to_json=True)


def to_string(explanations, texts, key, cutoff_top_k_single):
    """Build the printable verbalization for one sample.

    The sample's tokens are colored by attribution: tokens among the top
    ``cutoff_top_k_single`` attributions are bold red, tokens with positive
    attribution are bold, all others are plain.

    :param explanations: mapping subclass-name -> {key -> list of
        verbalization strings}
    :param texts: mapping key -> {"input_ids": tokens, "attributions": floats,
        "was_correct": bool}
    :param key: sample key to render
    :param cutoff_top_k_single: how many top tokens / verbalizations to show
    :return: tuple ``(txt, sample_text)`` — the verbalization text and the
        color-formatted sample line
    """
    tokens = texts[key]["input_ids"]
    attributions = texts[key]["attributions"]
    # Hoist the top-k threshold out of the loop; the original re-sorted the
    # attribution list once per token.
    # NOTE(review): raises IndexError if there are fewer attributions than
    # cutoff_top_k_single — same as the original; confirm callers guarantee
    # enough tokens.
    threshold = sorted(attributions, reverse=True)[cutoff_top_k_single - 1]

    fmtd_tokens = []
    for i, token in enumerate(tokens):
        if attributions[i] >= threshold:
            # top-k token: bold + red
            fmtd_tokens.append(color_str(color_str(token, Color.RED), Color.BOLD))
        elif attributions[i] > 0:
            fmtd_tokens.append(color_str(token, Color.BOLD))
        else:
            fmtd_tokens.append(token)
    sample_text = "\nSAMPLE:\n" + " ".join(fmtd_tokens)
    # NOTE(review): the original also built a 150-column line-wrapped copy
    # (`txt_`/`sample`) that was never returned or used — dead code, removed.

    txt = ""
    for expl_subclass in explanations.keys():
        txt += "\nsubclass '{}'".format(expl_subclass)
        for verbalization in explanations[expl_subclass][key][:cutoff_top_k_single]:
            txt += "\n" + verbalization
    # Fixed typo in the user-facing message ("Predicton" -> "Prediction").
    txt += "\nPrediction was correct." if texts[key]["was_correct"] else "\nPrediction was incorrect"
    return txt, sample_text


def heatmap_com_verb(key, valid_keys, thermo_config, thermo_config_name, explanations):
    """Return the compare-searches verbalization (SMV) for *key*, if any.

    ``thermo_config`` and ``thermo_config_name`` are kept for interface
    compatibility; the heatmap-image export that used them is currently
    disabled (see the commented-out code in the commit history).

    :param key: sample key to look up
    :param valid_keys: collection of keys eligible for verbalization
    :param thermo_config: unused (heatmap export disabled)
    :param thermo_config_name: unused (heatmap export disabled)
    :param explanations: mapping that contains a "compare searches" dict
    :return: the verbalization for ``key``, or ``None`` when the key is not
        valid or has no compare-searches entry
    """
    if key in valid_keys and key in explanations["compare searches"]:
        return explanations["compare searches"][key]
    # Explicit None: the original left `smv` unbound (UnboundLocalError) when
    # the key was valid but had no compare-searches entry.
    return None
Loading

0 comments on commit 01dbd8c

Please sign in to comment.