From 01dbd8c8815a4579550f0acf79cf5dbd8aa0aa40 Mon Sep 17 00:00:00 2001
From: Maximilian Nasert
Date: Thu, 1 Sep 2022 15:34:21 +0200
Subject: [PATCH] Update; refactor

---
 demo.py                           |   5 +-
 requirements.txt                  |   3 +-
 src/dataloader.py                 |  10 +-
 src/search_methods/fastexplain.py | 146 +++++++++++-----------------
 src/search_methods/tools.py       | 156 +++++++++++++++++-------
 5 files changed, 160 insertions(+), 160 deletions(-)

diff --git a/demo.py b/demo.py
index 9cdf645..948fed7 100644
--- a/demo.py
+++ b/demo.py
@@ -2,5 +2,6 @@
 
 if __name__ == "__main__":
-    explanations = fe.explain("configs/quantile_dev.yml", to_json=True)
-    #print(explanations)
+    explanations = fe.explain("configs/mean_dev.yml")
+    for i in explanations:
+        print(i)
diff --git a/requirements.txt b/requirements.txt
index 2e078ab..6a0a0c9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 numpy
 numba
 transformers
-thermostat-datasets
\ No newline at end of file
+thermostat-datasets
+imgkit
\ No newline at end of file
diff --git a/src/dataloader.py b/src/dataloader.py
index ecdb311..8d3aad3 100644
--- a/src/dataloader.py
+++ b/src/dataloader.py
@@ -217,6 +217,9 @@ def doit(self, modes: list = None, n_samples: int = None):
         if "total order" in modes:
             explanations["total order"] = t.verbalize_total_order(t.total_order(sample_array))
 
+        if "compare_searches" in modes:
+            explanations["compare searches"] = t.compare_searches(orders_and_searches, sample_array)
+
         # TODO: Maybe detokenize input_ids using tokenizer from self?
         if not self.dev:
             return explanations, sample_array, None
@@ -247,9 +250,14 @@ def convolution_search(self, _dict, len_filters, sgn=None, metric=None):
         return explanations, prepared_data
 
     def compare_search(self, previous_searches, samples):
-        coincedences = t.compare_searches(previous_searches, samples)
-        return coincedences
+        coincidences = t.compare_search(previous_searches, samples)
+        return coincidences
 
+    def compare_searches(self, previous_searches, samples):
+        v = t.compare_searches(previous_searches, samples)
+        return v
+
     def filter_verbalizations(self, verbalizations, samples, orders_and_searches, maxwords=100, mincoverage=.1, *args):
         """
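
Review note — usage sketch for the new "compare_searches" mode. The driver below
is an assumption (no such call exists in this patch); it only illustrates that
doit() now exposes the comparison verbalizations under the "compare searches" key:

    # hypothetical: `dl` is an already-configured dataloader instance
    explanations, sample_array, _ = dl.doit(modes=["compare_searches"], n_samples=5)
    for key, verbalization in explanations["compare searches"].items():
        print(key, verbalization)
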
if texts[key]["was_correct"] else "\nPredicton was incorrect" + txt, sample_text = txt = to_string(explanations, texts, key, cutoff_top_k_single) if to_json: - key_verbalization_attribs[key] = {"modelname": modelname, - "sample": sample, + key_verbalization_attribs[key] = {"sample": texts[key]["input_ids"], "verbalization": txt, "attributions": texts[key]["attributions"]} else: - returnstr.append(sample + "\n" + txt) + returnstr.append(sample_text + "\n" + txt) if to_json: + key_verbalization_attribs["modelname"] = modelname res = json.dumps(key_verbalization_attribs) return res else: @@ -126,66 +95,20 @@ def explain(config_path, to_json=False): if not to_json: returnstr = [] + cutoff_top_k_single = 5 for key in texts.keys(): if key in valid_keys: - if thermo_config: - thermounit = thermo_config[key] - """ Only execute the code once! """ - save_heatmap_as_image(thermounit.heatmap, filename=f"{thermo_config_name}/{key}.png") - - sample = " ".join(texts[key]["input_ids"]) - if key not in explanations["compare search"]: - continue - smv = explanations["compare search"][key] - """ - cutoff_top_k_single = 5 - - txt = "\nSAMPLE:\n" - fmtd_tokens = [] - for i, token in enumerate(texts[key]["input_ids"]): - if texts[key]["attributions"][i] >= sorted( - texts[key]["attributions"], reverse=True)[cutoff_top_k_single-1]: - fmtd_tokens.append(color_str(color_str(token, Color.RED), Color.BOLD)) - elif texts[key]["attributions"][i] > 0: - fmtd_tokens.append(color_str(token, Color.BOLD)) - else: - fmtd_tokens.append(token) - txt += " ".join(fmtd_tokens) - c = 0 - txt_ = "" - for i in txt: - c += 1 - txt_ += i - if c > 150: - if i == " ": - txt_ += "\n" - c = 0 - else: - pass - sample = txt_ - txt = "" - # makeshift \n-ing - for expl_subclass in explanations.keys(): - txt += "\nsubclass '{}'".format(expl_subclass) - _ = explanations[expl_subclass][key][:cutoff_top_k_single] - for __ in _: - txt += "\n"+__ - txt += "\nPrediction was correct." 
if texts[key]["was_correct"] else "\nPredicton was incorrect" - """ + key_verbalization_attribs[key] = {} + txt, sample_text = to_string(explanations, texts, key, cutoff_top_k_single) if to_json: - key_verbalization_attribs[key] = {"sample": sample, - "verbalization": smv, + key_verbalization_attribs[key] = {"sample": texts[key]["input_ids"], + "verbalization": txt, "attributions": texts[key]["attributions"]} else: - returnstr.append(sample + "\n" + smv) + returnstr.append(sample_text + "\n" + txt) if to_json: - pd.DataFrame.from_dict(key_verbalization_attribs, orient="index").to_csv( - f"verbalizations/SMV_{modelname}.csv") - - """ key_verbalization_attribs["modelname"] = modelname - """ res = json.dumps(key_verbalization_attribs) return res else: @@ -193,6 +116,55 @@ def explain(config_path, to_json=False): else: return explain_nodev(config, to_json) + def explain_json(config_path): return explain(config_path, True) + +def to_string(explanations, texts, key, cutoff_top_k_single): + txt = "\nSAMPLE:\n" + fmtd_tokens = [] + for i, token in enumerate(texts[key]["input_ids"]): + if texts[key]["attributions"][i] >= sorted( + texts[key]["attributions"], reverse=True)[cutoff_top_k_single - 1]: + fmtd_tokens.append(color_str(color_str(token, Color.RED), Color.BOLD)) + elif texts[key]["attributions"][i] > 0: + fmtd_tokens.append(color_str(token, Color.BOLD)) + else: + fmtd_tokens.append(token) + txt += " ".join(fmtd_tokens) + sample_text = txt + c = 0 + txt_ = "" + for i in txt: + c += 1 + txt_ += i + if c > 150: + if i == " ": + txt_ += "\n" # makeshift \n-ing + c = 0 + else: + pass + sample = txt_ + txt = "" + for expl_subclass in explanations.keys(): + txt += "\nsubclass '{}'".format(expl_subclass) + _ = explanations[expl_subclass][key][:cutoff_top_k_single] + for __ in _: + txt += "\n" + __ + txt += "\nPrediction was correct." if texts[key]["was_correct"] else "\nPredicton was incorrect" + return txt, sample_text + + +def heatmap_com_verb(key, valid_keys, thermo_config, thermo_config_name, explanations): + if key in valid_keys: + #if thermo_config: + # thermounit = thermo_config[key] + # """ Only execute the code once! 
""" + # save_heatmap_as_image(thermounit.heatmap, filename=f"{thermo_config_name}/{key}.png") + + # sample = " ".join(texts[key]["input_ids"]) ##UNUSED?## + if key in explanations["compare searches"]: + smv = explanations["compare searches"][key] + + return smv \ No newline at end of file diff --git a/src/search_methods/tools.py b/src/search_methods/tools.py index fb3bb64..fef5cdb 100644 --- a/src/search_methods/tools.py +++ b/src/search_methods/tools.py @@ -256,6 +256,57 @@ def verbalize_field_span_search(prepared_data, samples, sgn="+"): return verbalization_dict +def compare_search(searches: dict, samples): + """ + + :param searches: + :param samples: + :return: + """ + search_types = searches.keys() + coincidences = {} + for subclass in search_types: + for subclass_2 in search_types: + if subclass == subclass_2: + pass + else: + for sample_key in searches[subclass].keys(): + sum_values = 0 + for i in samples[sample_key]["attributions"]: + sum_values += i if i > 0 else 0 + + _ = [] + for value_1 in searches[subclass][sample_key]["indices"]: + for value_2 in searches[subclass_2][sample_key]["indices"]: + if value_1 is None or value_2 is None: + continue + if value_1 == value_2 and value_1 not in coincidences.items(): + _.append(value_1) + + verbalizations = [] + for snippet in _: + verbalization = "snippet: '" + snippet_tokens = [] + for word_index in snippet: + if word_index is not None: + snippet_tokens.append(samples[sample_key]["input_ids"][word_index]) #.replace("▁", " ") + verbalization += ' '.join(snippet_tokens) + try: + verbalization += "' occurs in all searches and accounts for {}% of prediction score".format( + str(round( + (sum([samples[sample_key]["attributions"][i] for i in snippet])/sum_values)*100, 2 + ))) + except Exception as e: + pass + + verbalizations.append(verbalization) + if not verbalizations: + verbalizations = ["No snippet occurs in all searches simultaneously"] + coincidences[sample_key] = verbalizations + + return coincidences + + def compare_searches(searches: dict, samples): """ @@ -263,29 +314,18 @@ def compare_searches(searches: dict, samples): :param samples: :return: """ - #search_types = searches.keys() - def coverage(span, attributions): - if span[0]: - pos_att_sum = sum([float(a) if a > 0 else 0 for a in attributions]) - if pos_att_sum > 0: - return sum([attributions[w] for w in span]) / pos_att_sum - return 0 + # search_types = searches.keys() ##UNUSED## - sample_info = [] + # sample_info = [] ##UNUSED## verbalized_explanations = {} for sample_key in tqdm(searches[list(searches.keys())[0]].keys()): sample_atts = samples[sample_key]["attributions"] input_ids = samples[sample_key]["input_ids"] candidates = defaultdict(dict) - def explore_search(search_type): - candidates[search_type] = {} - for indices in searches[search_type][sample_key]["indices"]: - candidates[search_type][','.join([str(idx) for idx in indices])] = coverage(indices, sample_atts) - for stype in list(searches.keys()): - explore_search(stype) + explore_search(candidates, stype, searches, sample_key, sample_atts) for i, attr in enumerate(sample_atts): candidates['total search'][str(i)] = coverage([i], sample_atts) @@ -296,20 +336,9 @@ def explore_search(search_type): combined_candidate_indices = [] - def combine_results(result_dict): - for idx_cov_tuple in result_dict: - if ',' in idx_cov_tuple[0]: - indices = idx_cov_tuple[0].split(',') - else: - indices = [idx_cov_tuple[0]] - for idx in indices: - if idx == 'None': - continue - if int(idx) not in combined_candidate_indices: - 
@@ -263,29 +314,18 @@ def compare_searches(searches: dict, samples):
     :param searches:
     :param samples:
     :return:
     """
-    #search_types = searches.keys()
-    def coverage(span, attributions):
-        if span[0]:
-            pos_att_sum = sum([float(a) if a > 0 else 0 for a in attributions])
-            if pos_att_sum > 0:
-                return sum([attributions[w] for w in span]) / pos_att_sum
-        return 0
+    # search_types = searches.keys()  ##UNUSED##
 
-    sample_info = []
+    # sample_info = []  ##UNUSED##
     verbalized_explanations = {}
     for sample_key in tqdm(searches[list(searches.keys())[0]].keys()):
         sample_atts = samples[sample_key]["attributions"]
         input_ids = samples[sample_key]["input_ids"]
         candidates = defaultdict(dict)
 
-        def explore_search(search_type):
-            candidates[search_type] = {}
-            for indices in searches[search_type][sample_key]["indices"]:
-                candidates[search_type][','.join([str(idx) for idx in indices])] = coverage(indices, sample_atts)
-
         for stype in list(searches.keys()):
-            explore_search(stype)
+            explore_search(candidates, stype, searches, sample_key, sample_atts)
 
         for i, attr in enumerate(sample_atts):
             candidates['total search'][str(i)] = coverage([i], sample_atts)
 
         combined_candidate_indices = []
-        def combine_results(result_dict):
-            for idx_cov_tuple in result_dict:
-                if ',' in idx_cov_tuple[0]:
-                    indices = idx_cov_tuple[0].split(',')
-                else:
-                    indices = [idx_cov_tuple[0]]
-                for idx in indices:
-                    if idx == 'None':
-                        continue
-                    if int(idx) not in combined_candidate_indices:
-                        combined_candidate_indices.append(int(idx))
-        combine_results(conv_top5)
-        combine_results(span_top5)
-        combine_results(total_top5)
+        combine_results(conv_top5, combined_candidate_indices)
+        combine_results(span_top5, combined_candidate_indices)
+        combine_results(total_top5, combined_candidate_indices)
 
         final_spans = []
         for i in sorted(combined_candidate_indices):
@@ -363,54 +392,43 @@ def combine_results(result_dict):
         # TODO: The span.replace(...) has to be different for other models/tokenizers
         # BERT
-        #verbalized_explanations[sample_key] = " ".join([span.replace(' ##', '') for i, span in ranked_spans])
+        # verbalized_explanations[sample_key] = " ".join([span.replace(' ##', '') for i, span in ranked_spans])
         # RoBERTa
         verbalized_explanations[sample_key] = " ".join([span for i, span in ranked_spans])
-        #sample_info.append(samples[sample_key])
+        # sample_info.append(samples[sample_key])
+
+    return verbalized_explanations
 
-    """
-    coincidences = {}
-    for subclass in search_types:
-        for subclass_2 in search_types:
-            if subclass == subclass_2:
-                pass
-            else:
-                for sample_key in searches[subclass].keys():
-                    sum_values = 0
-                    for i in samples[sample_key]["attributions"]:
-                        sum_values += i if i > 0 else 0
-
-                    _ = []
-                    for value_1 in searches[subclass][sample_key]["indices"]:
-                        for value_2 in searches[subclass_2][sample_key]["indices"]:
-                            if value_1 is None or value_2 is None:
-                                continue
-                            if value_1 == value_2 and value_1 not in coincidences.items():
-                                _.append(value_1)
-
-                    verbalizations = []
-                    for snippet in _:
-                        verbalization = "snippet: '"
-                        snippet_tokens = []
-                        for word_index in snippet:
-                            if word_index is not None:
-                                snippet_tokens.append(samples[sample_key]["input_ids"][word_index]) #.replace("▁", " ")
-                        verbalization += ' '.join(snippet_tokens)
-                        try:
-                            verbalization += "' occurs in all searches and accounts for {}% of prediction score".format(
-                                str(round(
-                                    (sum([samples[sample_key]["attributions"][i] for i in snippet])/sum_values)*100, 2
-                                )))
-                        except Exception as e:
-                            pass
-
-                        verbalizations.append(verbalization)
-                    if not verbalizations:
-                        verbalizations = ["No snippet occurs in all searches simultaneously"]
-                    coincidences[sample_key] = verbalizations
-    """
-    return verbalized_explanations
+
+@jit(nopython=True)
+def explore_search(candidates, search_type, searches, sample_key, sample_atts):
+    candidates[search_type] = {}
+    for indices in searches[search_type][sample_key]["indices"]:
+        candidates[search_type][','.join([str(idx) for idx in indices])] = coverage(indices, sample_atts)
+    return candidates
+
+
+@jit(nopython=True)
+def coverage(span, attributions):
+    if span[0]:
+        pos_att_sum = sum([float(a) if a > 0 else 0 for a in attributions])
+        if pos_att_sum > 0:
+            return sum([attributions[w] for w in span]) / pos_att_sum
+    return 0
+
+
+@jit(nopython=True)
+def combine_results(result_dict, combined_candidate_indices):
+    for idx_cov_tuple in result_dict:
+        if ',' in idx_cov_tuple[0]:
+            indices = idx_cov_tuple[0].split(',')
+        else:
+            indices = [idx_cov_tuple[0]]
+        for idx in indices:
+            if idx == 'None':
+                continue
+            if int(idx) not in combined_candidate_indices:
+                combined_candidate_indices.append(int(idx))
 
 
 def get_binary_attributions_from_annotator_rationales(text: str, rationales: List[str]):
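
Review note — explore_search(), coverage() and combine_results() operate on
Python dicts, strings and heterogeneous lists, which numba's nopython mode
cannot type; as decorated above, the first call will raise a TypingError
instead of compiling. Object mode does accept such arguments; a toy stand-in
(not code from this patch) to illustrate the distinction:

    from numba import jit

    @jit(forceobj=True)        # object mode tolerates Python dicts and strings
    def count_keys(d):         # toy stand-in for the dict-heavy helpers above
        n = 0
        for _ in d:
            n += 1
        return n

    print(count_keys({"a": 1, "b": 2}))  # 2; with nopython=True this raises TypingError
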