From e14ab53d8c9a172b5d64a2e4e577306cc9310c3e Mon Sep 17 00:00:00 2001 From: Alan Akbik Date: Sun, 10 Apr 2022 14:12:54 +0200 Subject: [PATCH 1/2] GH-2717: add option to ignore labels to ColumnCorpus --- flair/datasets/relation_extraction.py | 10 ++++++++-- flair/datasets/sequence_labeling.py | 13 +++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/flair/datasets/relation_extraction.py b/flair/datasets/relation_extraction.py index a1f1243f0..178eddf2b 100644 --- a/flair/datasets/relation_extraction.py +++ b/flair/datasets/relation_extraction.py @@ -39,6 +39,7 @@ def __init__( base_path: Union[str, Path] = None, in_memory: bool = True, augment_train: bool = False, + **corpusargs, ): """ SemEval-2010 Task 8 on Multi-Way Classification of Semantic Relations Between Pairs of @@ -83,6 +84,7 @@ def __init__( column_format={1: "text", 2: "ner"}, comment_symbol="# ", in_memory=in_memory, + **corpusargs, ) def extract_and_convert_to_conllu(self, data_file, data_folder, augment_train): @@ -227,7 +229,7 @@ def _semeval_lines_to_token_list(self, raw_lines, augment_relations): class RE_ENGLISH_TACRED(ColumnCorpus): - def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True): + def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True, **corpusargs): """ TAC Relation Extraction Dataset with 41 relations from https://nlp.stanford.edu/projects/tacred/. Manual download is required for this dataset. @@ -260,6 +262,7 @@ def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True): column_format={1: "text", 2: "ner"}, comment_symbol="# ", in_memory=in_memory, + **corpusargs, ) def extract_and_convert_to_conllu(self, data_file, data_folder): @@ -351,7 +354,7 @@ def _tacred_example_to_token_list(self, example: Dict[str, Any]) -> conllu.Token class RE_ENGLISH_CONLL04(ColumnCorpus): - def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True): + def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True, **corpusargs): if not base_path: base_path = flair.cache_root / "datasets" else: @@ -385,6 +388,7 @@ def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True): in_memory=in_memory, column_format={1: "text", 2: "ner"}, comment_symbol="# ", + **corpusargs, ) def _parse_incr(self, source_file) -> Iterable[conllu.TokenList]: @@ -536,6 +540,7 @@ def __init__( base_path: Union[str, Path] = None, in_memory: bool = True, sentence_splitter: SentenceSplitter = SegtokSentenceSplitter(), + **corpusargs, ): """ DrugProt corpus: Biocreative VII Track 1 from https://zenodo.org/record/5119892#.YSdSaVuxU5k/ on @@ -570,6 +575,7 @@ def __init__( sample_missing_splits=False, column_format={1: "text", 2: "ner", 3: "ner"}, comment_symbol="# ", + **corpusargs, ) def extract_and_convert_to_conllu(self, data_file, data_folder): diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index bda7f87cf..600de3cfe 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -662,7 +662,8 @@ def _convert_lines_to_sentence( for span_indices, score, label in predicted_spans: span = sentence[span_indices[0] : span_indices[-1] + 1] value = self._remap_label(label) - span.add_label(span_level_tag_columns[span_column], value=value, score=score) + if value != 'O': + span.add_label(span_level_tag_columns[span_column], value=value, score=score) except Exception: pass @@ -681,7 +682,9 @@ def _convert_lines_to_sentence( relation = Relation( first=sentence[head_start - 1 : head_end], second=sentence[tail_start - 1 : tail_end] ) - relation.add_label(typename="relation", value=self._remap_label(label)) + remapped = self._remap_label(label) + if remapped != 'O': + relation.add_label(typename="relation", value=remapped) if len(sentence) > 0: return sentence @@ -719,7 +722,8 @@ def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: O # add each other feature as label-value pair label_name = feature.split("=")[0] label_value = self._remap_label(feature.split("=")[1]) - token.add_label(label_name, label_value) + if label_value != 'O': + token.add_label(label_name, label_value) else: # get the task name (e.g. 'ner') @@ -727,7 +731,8 @@ def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: O # get the label value label_value = self._remap_label(fields[column]) # add label - token.add_label(label_name, label_value) + if label_value != 'O': + token.add_label(label_name, label_value) if column_name_map[column] == self.SPACE_AFTER_KEY and fields[column] == "-": token.whitespace_after = False From 78b17f29d37f6d4ce6666110eff8940e49aedcf6 Mon Sep 17 00:00:00 2001 From: Alan Akbik Date: Sun, 10 Apr 2022 14:13:28 +0200 Subject: [PATCH 2/2] GH-2717: formatting --- flair/datasets/sequence_labeling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 600de3cfe..1127b1972 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -662,7 +662,7 @@ def _convert_lines_to_sentence( for span_indices, score, label in predicted_spans: span = sentence[span_indices[0] : span_indices[-1] + 1] value = self._remap_label(label) - if value != 'O': + if value != "O": span.add_label(span_level_tag_columns[span_column], value=value, score=score) except Exception: pass @@ -683,7 +683,7 @@ def _convert_lines_to_sentence( first=sentence[head_start - 1 : head_end], second=sentence[tail_start - 1 : tail_end] ) remapped = self._remap_label(label) - if remapped != 'O': + if remapped != "O": relation.add_label(typename="relation", value=remapped) if len(sentence) > 0: @@ -722,7 +722,7 @@ def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: O # add each other feature as label-value pair label_name = feature.split("=")[0] label_value = self._remap_label(feature.split("=")[1]) - if label_value != 'O': + if label_value != "O": token.add_label(label_name, label_value) else: @@ -731,7 +731,7 @@ def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: O # get the label value label_value = self._remap_label(fields[column]) # add label - if label_value != 'O': + if label_value != "O": token.add_label(label_name, label_value) if column_name_map[column] == self.SPACE_AFTER_KEY and fields[column] == "-":