diff --git a/synthesis_action_retriever/build_graph.py b/synthesis_action_retriever/build_graph.py index c772b2a..095e789 100644 --- a/synthesis_action_retriever/build_graph.py +++ b/synthesis_action_retriever/build_graph.py @@ -69,49 +69,54 @@ def __peek_neighbor_nounphrase(self, graph_window, direction): return graph_window['current_graph'] + def __remove_redundant_tags(self, joined_acts, joined_acts_ids): + acts_w_props = [ + act for act in joined_acts if + any([ + prop for prop in [ + act['temp_values'], act['time_values'] + ] + ]) + ] + if not acts_w_props: + true_act = self.graph_data_sent[-1] + joined_acts_ids.remove(self.graph_data_sent.index(true_act)) + for j in joined_acts_ids: + del self.graph_data_sent[j] + return + elif len(acts_w_props) == 1: + true_act = acts_w_props[0] + joined_acts_ids.remove(self.graph_data_sent.index(true_act)) + for j in sorted(joined_acts_ids, reverse=True): + del self.graph_data_sent[j] + return + else: + if self.__verbose: + print( + "In case with consecutive string of actions, multiple tokens" \ + "were assigned properties... all tokens were retained" + ) + return + def __clean_redundancy(self): joined_acts = [] joined_acts_ids = [] for i, act in enumerate(self.graph_data_sent[1:], start=1): - prev_act = self.graph_data_sent[i-1] + prev_act = self.graph_data_sent[i - 1] if ( - prev_act['act_id'] == act['act_id']-1 and - prev_act['act_type'] == act['act_type'] + prev_act['act_id'] == act['act_id'] - 1 and + prev_act['act_type'] == act['act_type'] ): if prev_act not in joined_acts: joined_acts.extend([prev_act, act]) - joined_acts_ids.extend([i-1, i]) + joined_acts_ids.extend([i - 1, i]) else: joined_acts.append(act) joined_acts_ids.append(i) + if i == len(self.graph_data_sent) - 1: + self.__remove_redundant_tags(joined_acts, joined_acts_ids) elif joined_acts: - acts_w_props = [ - act for act in joined_acts if - any([ - prop for prop in [ - act['temp_values'], act['time_values'] - ] - ]) - ] - if not acts_w_props: - true_act = self.graph_data_sent[-1] - joined_acts_ids.remove(self.graph_data_sent.index(true_act)) - for j in joined_acts_ids: - del self.graph_data_sent[j] - return - elif len(acts_w_props) == 1: - true_act = acts_w_props[0] - joined_acts_ids.remove(self.graph_data_sent.index(true_act)) - for j in sorted(joined_acts_ids, reverse=True): - del self.graph_data_sent[j] - return - else: - if self.__verbose: - print( - "In case with consecutive string of actions, multiple tokens" \ - "were assigned properties... all tokens were retained" - ) - return + self.__remove_redundant_tags(joined_acts, joined_acts_ids) def build_graph( self, @@ -240,16 +245,19 @@ def build_graph( return self.graph_data_sent - def refine_graph(self, full_graph, sentences_materials, sent_toks): + def refine_graph(self, full_graph, sentences_materials): """ Refine synthesis workflow graph to incorporate attributes from neighboring sentences :param full_graph: list of dicts returned from build_graph function :param sentences_materials: list of dicts of sentences and associated materials - :param sent_toks: list of lists of raw sentences tokens for full paragraph :return: list of dict """ sentences = [s["sentence"] for s in sentences_materials] + sent_toks = [] + for sent in sentences: + sent_toks.append([t.text for t in make_spacy_tokens(sent)]) + for graph in full_graph: cursor = full_graph.index(graph) if graph: diff --git a/test.py b/test.py index 14d6a88..ce6eef6 100644 --- a/test.py +++ b/test.py @@ -24,9 +24,5 @@ actions = sar.get_action_labels(spacy_tokens) graph.append(gb.build_graph(spacy_tokens, actions, sent["materials"])) -para = ' '.join([s["sentence"] for s in examples]) -para_sent_toks = Paragraph(para).raw_tokens - -refined_graph = gb.refine_graph(graph, examples, para_sent_toks) - +refined_graph = gb.refine_graph(graph, examples) pprint(refined_graph)