diff --git a/flair/datasets/biomedical.py b/flair/datasets/biomedical.py
index 4816569c1..b05288c3c 100644
--- a/flair/datasets/biomedical.py
+++ b/flair/datasets/biomedical.py
@@ -46,9 +46,6 @@
 SENTENCE_TAG = "[__SENT__]"
 
-MULTI_TASK_LEARNING = False
-IGNORE_NEGATIVE_SAMPLES = False
-
 logger = logging.getLogger("flair")
 
@@ -363,11 +360,11 @@ def __init__(
 
     def process_dataset(self, datasets: Dict[str, InternalBioNerDataset], out_dir: Path):
        if "train" in datasets:
-            self.write_to_conll(datasets["train"], out_dir / "train.conll")
+            self.write_to_conll(datasets["train"], out_dir / (self.sentence_splitter.name + "_train.conll"))
         if "dev" in datasets:
-            self.write_to_conll(datasets["dev"], out_dir / "dev.conll")
+            self.write_to_conll(datasets["dev"], out_dir / (self.sentence_splitter.name + "_dev.conll"))
         if "test" in datasets:
-            self.write_to_conll(datasets["test"], out_dir / "test.conll")
+            self.write_to_conll(datasets["test"], out_dir / (self.sentence_splitter.name + "_test.conll"))
 
     def write_to_conll(self, dataset: InternalBioNerDataset, output_file: Path):
         os.makedirs(str(output_file.parent), exist_ok=True)
@@ -383,19 +380,6 @@ def write_to_conll(self, dataset: InternalBioNerDataset, output_file: Path):
             GENE_TAG: "genes",
             SPECIES_TAG: "species",
         }
-        if not MULTI_TASK_LEARNING:
-            task_description = ""
-        else:
-            task_description = "[Tag"
-            for i, entity_type in enumerate(entity_types):
-                if i == 0:
-                    task_description += f" {mapping[entity_type]}"
-                elif i == len(entity_types) - 1:
-                    task_description += f" and {mapping[entity_type]}"
-                else:
-                    task_description += f", {mapping[entity_type]}"
-            task_description += "]"
-            task_sentence = self.sentence_splitter.split(task_description)
 
         with output_file.open("w", encoding="utf8") as f:
             for document_id in Tqdm.tqdm(
@@ -407,7 +391,6 @@ def write_to_conll(self, dataset: InternalBioNerDataset, output_file: Path):
                 document_text = re.sub(r"[\u2000-\u200B]", " ", document_text)  # replace unicode space characters!
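
# Sketch (not part of the patch): the hunks above move prompt construction out
# of the corpus writer and into the model layer (see the augmentation strategy
# added further below), and they key the CoNLL cache file by the sentence
# splitter: the splitter name moves from the cache directory into the file
# name, so caches produced with different splitters no longer collide.
# Splitter name and split below are illustrative.
from pathlib import Path

def conll_cache_file(data_folder: Path, splitter_name: str, split: str) -> Path:
    # e.g. data_folder / "SciSpacySentenceSplitter_train.conll"
    # instead of the old data_folder / splitter_name / "train.conll"
    return data_folder / f"{splitter_name}_{split}.conll"
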
                 document_text = document_text.replace("\xa0", " ")  # replace non-break space
                 document_buffer = ""
-                document_had_tags = not IGNORE_NEGATIVE_SAMPLES
 
                 entities = deque(
                     sorted(
@@ -423,18 +406,6 @@ def write_to_conll(self, dataset: InternalBioNerDataset, output_file: Path):
                     in_entity = False
                     sentence_had_tokens = False
 
-                    # Add task description for multi-task learning
-                    if MULTI_TASK_LEARNING:
-                        for i, flair_token in enumerate(task_sentence[0].tokens):
-                            token = flair_token.text.strip()
-                            if flair_token.whitespace_after > 0 or i == len(task_sentence[0].tokens) - 1:
-                                whitespace_after = "+"
-                            else:
-                                whitespace_after = "-"
-                            if len(token) > 0:
-                                tag = "O"
-                                document_buffer += " ".join([token, tag, whitespace_after]) + "\n"
-
                     for flair_token in sentence.tokens:
                         token = flair_token.text.strip()
                         assert sentence.start_position is not None
@@ -462,14 +433,11 @@ def write_to_conll(self, dataset: InternalBioNerDataset, output_file: Path):
                         if len(token) > 0:
                             document_buffer += " ".join([token, tag, whitespace_after]) + "\n"
                             sentence_had_tokens = True
-                        if not tag.startswith("O"):
-                            document_had_tags = True
 
                     if sentence_had_tokens:
                         document_buffer += "\n"
 
-            if document_had_tags:
-                f.write(document_buffer)
+            f.write(document_buffer)
 
 
 class HunerDataset(ColumnCorpus, ABC):
@@ -664,6 +632,13 @@ class HUNER_GENE_BIO_INFER(HunerDataset):
     """HUNER version of the BioInfer corpus containing only gene/protein annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {
+            "Individual_protein": GENE_TAG,
+            "Gene/protein/RNA": GENE_TAG,
+            "Gene": GENE_TAG,
+            "DNA_family_or_group": GENE_TAG,
+            "Protein_family_or_group": GENE_TAG,
+        }
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -675,19 +650,14 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         train_data = BIO_INFER.parse_dataset(corpus_folder / "BioInfer-train.xml")
         test_data = BIO_INFER.parse_dataset(corpus_folder / "BioInfer-test.xml")
 
-        entity_type_mapping = {
-            "Individual_protein": GENE_TAG,
-            "Gene/protein/RNA": GENE_TAG,
-            "Gene": GENE_TAG,
-            "DNA_family_or_group": GENE_TAG,
-            "Protein_family_or_group": GENE_TAG,
-        }
-
-        train_data = filter_and_map_entities(train_data, entity_type_mapping)
-        test_data = filter_and_map_entities(test_data, entity_type_mapping)
+        train_data = filter_and_map_entities(train_data, self.entity_type_mapping)
+        test_data = filter_and_map_entities(test_data, self.entity_type_mapping)
 
         return merge_datasets([train_data, test_data])
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 @deprecated(version="0.13.0", reason="Please use data set implementation from BigBio instead (see BIGBIO_NER_CORPUS)")
 class JNLPBA(ColumnCorpus):
@@ -856,6 +826,9 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
 
         return merge_datasets([train_data, test_data])
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_GENE_JNLPBA(HUNER_JNLPBA):
     """HUNER version of the JNLPBA corpus containing gene annotations."""
@@ -1021,6 +994,11 @@ class HUNER_ALL_CELL_FINDER(HunerDataset):
     """HUNER version of the CellFinder corpus containing cell line, species and gene annotations."""
 
     def __init__(self, *args, **kwargs):
+        self.entity_type_mapping = {
+            "CellLine": CELL_LINE_TAG,
+            "Species": SPECIES_TAG,
+            "GeneProtein": GENE_TAG,
+        }
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -1034,15 +1012,14 @@ def split_url() -> List[str]:
 
     def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         data = CELL_FINDER.download_and_prepare(data_dir)
-        entity_type_mapping = {
-            "CellLine": CELL_LINE_TAG,
-            "Species": SPECIES_TAG,
-            "GeneProtein": GENE_TAG,
-        }
-        data = filter_and_map_entities(data, entity_type_mapping)
+
+        data = filter_and_map_entities(data, self.entity_type_mapping)
 
         return data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class MIRNA(ColumnCorpus):
     """Original miRNA corpus.
@@ -1207,6 +1184,9 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
 
         return merge_datasets([train_data, test_data])
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_GENE_MIRNA(HUNER_MIRNA):
     """HUNER version of the miRNA corpus containing protein / gene annotations."""
@@ -1594,6 +1574,9 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
 
         return filter_and_map_entities(dataset, self.entity_type_mapping)
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_SPECIES_LOCTEXT(HUNER_LOCTEXT):
     """HUNER version of the Loctext corpus containing species annotations."""
@@ -1920,6 +1903,7 @@ class HUNER_SPECIES_LINNEAUS(HunerDataset):
     """HUNER version of the LINNEAUS corpus containing species annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Species": SPECIES_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -1929,6 +1913,9 @@ def split_url() -> str:
     def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         return LINNEAUS.download_and_parse_dataset(data_dir)
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 @deprecated(version="0.13.0", reason="Please use data set implementation from BigBio instead (see BIGBIO_NER_CORPUS)")
 class CDR(ColumnCorpus):
@@ -1998,6 +1985,7 @@ class HUNER_DISEASE_CDR(HunerDataset):
     """HUNER version of the CDR corpus containing disease annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Disease": DISEASE_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2011,15 +1999,19 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         dev_data = bioc_to_internal(data_dir / "CDR_Data" / "CDR.Corpus.v010516" / "CDR_DevelopmentSet.BioC.xml")
         test_data = bioc_to_internal(data_dir / "CDR_Data" / "CDR.Corpus.v010516" / "CDR_TestSet.BioC.xml")
         all_data = merge_datasets([train_data, dev_data, test_data])
-        all_data = filter_and_map_entities(all_data, {"Disease": DISEASE_TAG})
+        all_data = filter_and_map_entities(all_data, self.entity_type_mapping)
 
         return all_data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_CHEMICAL_CDR(HunerDataset):
     """HUNER version of the CDR corpus containing chemical annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Chemical": CHEMICAL_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2033,15 +2025,19 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         dev_data = bioc_to_internal(data_dir / "CDR_Data" / "CDR.Corpus.v010516" / "CDR_DevelopmentSet.BioC.xml")
         test_data = bioc_to_internal(data_dir / "CDR_Data" / "CDR.Corpus.v010516" / "CDR_TestSet.BioC.xml")
         all_data = merge_datasets([train_data, dev_data, test_data])
-        all_data = filter_and_map_entities(all_data, {"Chemical": CHEMICAL_TAG})
+        all_data = filter_and_map_entities(all_data, self.entity_type_mapping)
 
         return all_data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_ALL_CDR(HunerDataset):
     """HUNER version of the CDR corpus containing disease and chemical annotations."""
 
     def __init__(self, *args, **kwargs):
+        self.entity_type_mapping = {"Disease": DISEASE_TAG, "Chemical": CHEMICAL_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2059,11 +2055,14 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         dev_data = bioc_to_internal(data_dir / "CDR_Data" / "CDR.Corpus.v010516" / "CDR_DevelopmentSet.BioC.xml")
         test_data = bioc_to_internal(data_dir / "CDR_Data" / "CDR.Corpus.v010516" / "CDR_TestSet.BioC.xml")
         all_data = merge_datasets([train_data, dev_data, test_data])
-        entity_type_mapping = {"Disease": DISEASE_TAG, "Chemical": CHEMICAL_TAG}
-        all_data = filter_and_map_entities(all_data, entity_type_mapping)
+
+        all_data = filter_and_map_entities(all_data, self.entity_type_mapping)
 
         return all_data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class VARIOME(ColumnCorpus):
     """Variome corpus as provided by http://corpora.informatik.hu-berlin.de/corpora/brat2bioc/hvp_bioc.xml.zip.
@@ -2161,6 +2160,7 @@ class HUNER_GENE_VARIOME(HunerDataset):
     """HUNER version of the Variome corpus containing gene annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"gene": GENE_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2171,15 +2171,19 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         os.makedirs(str(data_dir), exist_ok=True)
         VARIOME.download_dataset(data_dir)
         all_data = VARIOME.parse_corpus(data_dir / "hvp_bioc.xml")
-        all_data = filter_and_map_entities(all_data, {"gene": GENE_TAG})
+        all_data = filter_and_map_entities(all_data, self.entity_type_mapping)
 
         return all_data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_DISEASE_VARIOME(HunerDataset):
     """HUNER version of the Variome corpus containing disease annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Disorder": DISEASE_TAG, "disease": DISEASE_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2190,15 +2194,19 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         os.makedirs(str(data_dir), exist_ok=True)
         VARIOME.download_dataset(data_dir)
         all_data = VARIOME.parse_corpus(data_dir / "hvp_bioc.xml")
-        all_data = filter_and_map_entities(all_data, {"Disorder": DISEASE_TAG, "disease": DISEASE_TAG})
+        all_data = filter_and_map_entities(all_data, self.entity_type_mapping)
 
         return all_data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_SPECIES_VARIOME(HunerDataset):
     """HUNER version of the Variome corpus containing species annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Living_Beings": SPECIES_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2209,15 +2217,24 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         os.makedirs(str(data_dir), exist_ok=True)
         VARIOME.download_dataset(data_dir)
         all_data = VARIOME.parse_corpus(data_dir / "hvp_bioc.xml")
-        all_data = filter_and_map_entities(all_data, {"Living_Beings": SPECIES_TAG})
+        all_data = filter_and_map_entities(all_data, self.entity_type_mapping)
 
         return all_data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_ALL_VARIOME(HunerDataset):
     """HUNER version of the Variome corpus containing gene, disease and species annotations."""
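
# The hunks in this file all repeat one refactoring; a condensed sketch of the
# resulting shape of every HUNER_* corpus class (the class name below is
# hypothetical, tag and helper names come from this module). The mapping must
# be assigned *before* super().__init__(), because the base class may trigger
# to_internal() during construction:
class HUNER_EXAMPLE(HunerDataset):  # hypothetical, for illustration only
    def __init__(self, *args, **kwargs) -> None:
        self.entity_type_mapping = {"Disease": DISEASE_TAG}
        super().__init__(*args, **kwargs)

    def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
        data = self.download_and_parse(data_dir)  # hypothetical helper
        return filter_and_map_entities(data, self.entity_type_mapping)

    def get_entity_type_mapping(self) -> Optional[Dict]:
        return self.entity_type_mapping
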
 
     def __init__(self, *args, **kwargs):
+        self.entity_type_mapping = {
+            "gene": GENE_TAG,
+            "Disorder": DISEASE_TAG,
+            "disease": DISEASE_TAG,
+            "Living_Beings": SPECIES_TAG,
+        }
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2233,16 +2250,14 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         os.makedirs(str(data_dir), exist_ok=True)
         VARIOME.download_dataset(data_dir)
         all_data = VARIOME.parse_corpus(data_dir / "hvp_bioc.xml")
-        entity_type_mapping = {
-            "gene": GENE_TAG,
-            "Disorder": DISEASE_TAG,
-            "disease": DISEASE_TAG,
-            "Living_Beings": SPECIES_TAG,
-        }
-        all_data = filter_and_map_entities(all_data, entity_type_mapping)
+
+        all_data = filter_and_map_entities(all_data, self.entity_type_mapping)
 
         return all_data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class NCBI_DISEASE(ColumnCorpus):
     """Original NCBI disease corpus containing disease annotations.
@@ -2386,6 +2401,7 @@ class HUNER_DISEASE_NCBI(HunerDataset):
     """HUNER version of the NCBI corpus containing disease annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Disease": DISEASE_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2401,6 +2417,9 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
 
         return merge_datasets([train_data, dev_data, test_data])
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class ScaiCorpus(ColumnCorpus):
     """Base class to support the SCAI chemicals and disease corpora."""
@@ -2554,6 +2573,16 @@ class HUNER_CHEMICAL_SCAI(HunerDataset):
     """HUNER version of the SCAI chemicals corpus containing chemical annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {
+            "FAMILY": CHEMICAL_TAG,
+            "TRIVIALVAR": CHEMICAL_TAG,
+            "PARTIUPAC": CHEMICAL_TAG,
+            "TRIVIAL": CHEMICAL_TAG,
+            "ABBREVIATION": CHEMICAL_TAG,
+            "IUPAC": CHEMICAL_TAG,
+            "MODIFIER": CHEMICAL_TAG,
+            "SUM": CHEMICAL_TAG,
+        }
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2564,25 +2593,17 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         original_file = SCAI_CHEMICALS.perform_corpus_download(data_dir)
         corpus = ScaiCorpus.parse_input_file(original_file)
 
-        # Map all entities to chemicals
-        entity_type_mapping = {
-            "FAMILY": CHEMICAL_TAG,
-            "TRIVIALVAR": CHEMICAL_TAG,
-            "PARTIUPAC": CHEMICAL_TAG,
-            "TRIVIAL": CHEMICAL_TAG,
-            "ABBREVIATION": CHEMICAL_TAG,
-            "IUPAC": CHEMICAL_TAG,
-            "MODIFIER": CHEMICAL_TAG,
-            "SUM": CHEMICAL_TAG,
-        }
+        return filter_and_map_entities(corpus, self.entity_type_mapping)
 
-        return filter_and_map_entities(corpus, entity_type_mapping)
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
 
 
 class HUNER_DISEASE_SCAI(HunerDataset):
     """HUNER version of the SCAI disease corpus containing disease annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"DISEASE": DISEASE_TAG, "ADVERSE": DISEASE_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2593,16 +2614,28 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         original_file = SCAI_DISEASE.perform_corpus_download(data_dir)
         corpus = ScaiCorpus.parse_input_file(original_file)
 
-        # Map all entities to disease
-        entity_type_mapping = {"DISEASE": DISEASE_TAG, "ADVERSE": DISEASE_TAG}
+        return filter_and_map_entities(corpus, self.entity_type_mapping)
 
-        return filter_and_map_entities(corpus, entity_type_mapping)
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
 
 
 class HUNER_ALL_SCAI(HunerDataset):
     """HUNER version of the SCAI corpora containing chemical and disease annotations."""
 
     def __init__(self, *args, **kwargs):
+        self.entity_type_mapping = {
+            "DISEASE": DISEASE_TAG,
+            "ADVERSE": DISEASE_TAG,
+            "FAMILY": CHEMICAL_TAG,
+            "TRIVIALVAR": CHEMICAL_TAG,
+            "PARTIUPAC": CHEMICAL_TAG,
+            "TRIVIAL": CHEMICAL_TAG,
+            "ABBREVIATION": CHEMICAL_TAG,
+            "IUPAC": CHEMICAL_TAG,
+            "MODIFIER": CHEMICAL_TAG,
+            "SUM": CHEMICAL_TAG,
+        }
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2617,20 +2650,10 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         original_file = SCAI_DISEASE.perform_corpus_download(data_dir)
         corpus = ScaiCorpus.parse_input_file(original_file)
 
-        entity_type_mapping = {
-            "DISEASE": DISEASE_TAG,
-            "ADVERSE": DISEASE_TAG,
-            "FAMILY": CHEMICAL_TAG,
-            "TRIVIALVAR": CHEMICAL_TAG,
-            "PARTIUPAC": CHEMICAL_TAG,
-            "TRIVIAL": CHEMICAL_TAG,
-            "ABBREVIATION": CHEMICAL_TAG,
-            "IUPAC": CHEMICAL_TAG,
-            "MODIFIER": CHEMICAL_TAG,
-            "SUM": CHEMICAL_TAG,
-        }
+        return filter_and_map_entities(corpus, self.entity_type_mapping)
 
-        return filter_and_map_entities(corpus, entity_type_mapping)
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
 
 
 @deprecated(version="0.13.0", reason="Please use data set implementation from BigBio instead (see BIGBIO_NER_CORPUS)")
@@ -2738,6 +2761,7 @@ class HUNER_GENE_OSIRIS(HunerDataset):
     """HUNER version of the OSIRIS corpus containing (only) gene annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"ge": GENE_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2748,8 +2772,10 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         original_file = OSIRIS.download_dataset(data_dir)
         corpus = OSIRIS.parse_dataset(original_file / "OSIRIScorpusv02")
 
-        entity_type_mapping = {"ge": GENE_TAG}
-        return filter_and_map_entities(corpus, entity_type_mapping)
+        return filter_and_map_entities(corpus, self.entity_type_mapping)
+
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
 
 
 class S800(ColumnCorpus):
@@ -2834,6 +2860,7 @@ class HUNER_SPECIES_S800(HunerDataset):
     """HUNER version of the S800 corpus containing species annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Species": SPECIES_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2843,10 +2870,13 @@ def split_url() -> str:
     def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         S800.download_dataset(data_dir)
         data = S800.parse_dataset(data_dir)
-        data = filter_and_map_entities(data, {"Species": SPECIES_TAG})
+        data = filter_and_map_entities(data, self.entity_type_mapping)
 
         return data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class GPRO(ColumnCorpus):
     """Original GPRO corpus containing gene annotations.
@@ -2971,6 +3001,7 @@ class HUNER_GENE_GPRO(HunerDataset):
     """HUNER version of the GPRO corpus containing gene annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Gene": GENE_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -2990,6 +3021,9 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
 
         return merge_datasets([train_data, dev_data])
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class DECA(ColumnCorpus):
     """Original DECA corpus containing gene annotations.
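
# Rough sketch of the filter_and_map_entities semantics these hunks rely on
# (a hedged approximation, not this module's exact implementation): entities
# whose original type is missing from the mapping are dropped, all others are
# renamed to the unified HUNER tag, and document texts stay untouched.
def filter_and_map_entities_sketch(dataset, type_mapping):
    for document_id, entities in dataset.entities_per_document.items():
        dataset.entities_per_document[document_id] = [
            # Entity keeps its character offsets; only the type is rewritten
            Entity((entity.char_span.start, entity.char_span.stop), type_mapping[entity.type])
            for entity in entities
            if entity.type in type_mapping
        ]
    return dataset
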
@@ -3085,6 +3119,7 @@ class HUNER_GENE_DECA(HunerDataset):
     """HUNER version of the DECA corpus containing gene annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Gene": GENE_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -3098,6 +3133,9 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
 
         return DECA.parse_corpus(text_dir, gold_file)
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class FSU(ColumnCorpus):
     """Original FSU corpus containing protein and derived annotations.
@@ -3233,6 +3271,13 @@ class HUNER_GENE_FSU(HunerDataset):
     """HUNER version of the FSU corpus containing (only) gene annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {
+            "protein": GENE_TAG,
+            "protein_familiy_or_group": GENE_TAG,
+            "protein_complex": GENE_TAG,
+            "protein_variant": GENE_TAG,
+            "protein_enum": GENE_TAG,
+        }
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -3251,14 +3296,10 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
 
         corpus = FSU.parse_corpus(corpus_dir, sentence_separator)
 
-        entity_type_mapping = {
-            "protein": GENE_TAG,
-            "protein_familiy_or_group": GENE_TAG,
-            "protein_complex": GENE_TAG,
-            "protein_variant": GENE_TAG,
-            "protein_enum": GENE_TAG,
-        }
-        return filter_and_map_entities(corpus, entity_type_mapping)
+        return filter_and_map_entities(corpus, self.entity_type_mapping)
+
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
 
 
 class CRAFT(ColumnCorpus):
@@ -3744,6 +3785,18 @@ class HUNER_CHEMICAL_CEMP(HunerDataset):
     """HUNER version of the CEMP corpus containing chemical annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {
+            x: CHEMICAL_TAG
+            for x in [
+                "ABBREVIATION",
+                "FAMILY",
+                "FORMULA",
+                "IDENTIFIERS",
+                "MULTIPLE",
+                "SYSTEMATIC",
+                "TRIVIAL",
+            ]
+        }
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -3762,19 +3815,10 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         dev_data = CEMP.parse_input_file(dev_text_file, dev_ann_file)
         dataset = merge_datasets([train_data, dev_data])
 
-        entity_type_mapping = {
-            x: CHEMICAL_TAG
-            for x in [
-                "ABBREVIATION",
-                "FAMILY",
-                "FORMULA",
-                "IDENTIFIERS",
-                "MULTIPLE",
-                "SYSTEMATIC",
-                "TRIVIAL",
-            ]
-        }
-        return filter_and_map_entities(dataset, entity_type_mapping)
+        return filter_and_map_entities(dataset, self.entity_type_mapping)
+
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
 
 
 @deprecated(version="0.13", reason="Please use data set implementation from BigBio instead (see BIGBIO_NER_CORPUS)")
@@ -3912,6 +3956,9 @@ def to_internal(self, data_dir: Path, annotator: int = 0) -> InternalBioNerDatas
         dataset = CHEBI.parse_dataset(corpus_dir, annotator=annotator)
         return filter_and_map_entities(dataset, self.entity_type_mapping)
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_CHEMICAL_CHEBI(HUNER_CHEBI):
     """HUNER version of the CHEBI corpus containing chemical annotations."""
@@ -4698,6 +4745,9 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
 
         return filter_and_map_entities(corpus, self.entity_type_mapping)
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_CHEMICAL_CRAFT_V4(HUNER_CRAFT_V4):
     """HUNER version of the CRAFT corpus containing (only) chemical annotations."""
@@ -4753,6 +4803,9 @@ def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
 
         return filter_and_map_entities(corpus, self.entity_type_mapping)
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HUNER_CHEMICAL_BIONLP2013_CG(HUNER_BIONLP2013_CG):
     def __init__(self, *args, **kwargs):
@@ -4958,6 +5011,7 @@ class HUNER_DISEASE_PDR(HunerDataset):
     """PDR Dataset with only Disease annotations."""
 
     def __init__(self, *args, **kwargs) -> None:
+        self.entity_type_mapping = {"Disease": DISEASE_TAG}
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -4967,10 +5021,13 @@ def split_url() -> str:
     def to_internal(self, data_dir: Path) -> InternalBioNerDataset:
         corpus_folder = PDR.download_corpus(data_dir)
         corpus_data = brat_to_internal(corpus_folder, ann_file_suffixes=[".ann", ".ann2"])
-        corpus_data = filter_and_map_entities(corpus_data, {"Disease": DISEASE_TAG})
+        corpus_data = filter_and_map_entities(corpus_data, self.entity_type_mapping)
 
         return corpus_data
 
+    def get_entity_type_mapping(self) -> Optional[Dict]:
+        return self.entity_type_mapping
+
 
 class HunerMultiCorpus(MultiCorpus):
     """Base class to build the union of all HUNER data sets considering a particular entity type."""
@@ -5122,9 +5179,9 @@ def __init__(
         self.sentence_splitter = sentence_splitter if sentence_splitter else SciSpacySentenceSplitter()
 
         dataset_dir_name = self.build_corpus_directory_name(dataset_name)
-        data_folder = base_path / dataset_dir_name / self.sentence_splitter.name
+        data_folder = base_path / dataset_dir_name
 
-        train_file = data_folder / "train.conll"
+        train_file = data_folder / (self.sentence_splitter.name + "_train.conll")
 
         # Download data if necessary
         # Some datasets in BigBio only have train or test splits, not both
diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index 22e644a2a..48f463812 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -1,7 +1,6 @@
 import logging
 import tempfile
 from abc import ABC, abstractmethod
-
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
 from urllib.error import HTTPError
@@ -14,7 +13,7 @@
 from tqdm import tqdm
 
 import flair.nn
-from flair.data import Dictionary, Label, Sentence, Span, get_spans_from_bio, Corpus, Token
+from flair.data import Corpus, Dictionary, Label, Sentence, Span, Token, get_spans_from_bio
 from flair.datasets import DataLoader, FlairDatapointDataset
 from flair.embeddings import TokenEmbeddings
 from flair.file_utils import cached_path, unzip_file
@@ -1052,60 +1051,54 @@ class AugmentedSentence(Sentence):
 
 
 class SentenceAugmentationStrategy(ABC):
-
     @abstractmethod
     def augment_sentence(
-        self,
-        sentence: Sentence,
-        annotation_layers: Union[str, List[str]] = None
+        self, sentence: Sentence, annotation_layers: Union[str, List[str]] = None
     ) -> AugmentedSentence:
         """
-            Augments the given sentence text with additional instructions for working / predicting
-            the task on the given annotations.
+        Augments the given sentence text with additional instructions for working / predicting
+        the task on the given annotations.
 
-            Args:
-                sentence: The sentence to be augmented
-                annotation_layers: Annotations which should be predicted.
+        Args:
+            sentence: The sentence to be augmented
+            annotation_layers: Annotations which should be predicted.
         """
         ...
 
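    # Intended call pattern for this interface (a usage sketch, not part of the
    # patch; the concrete strategy and example prompt are defined further below
    # in this diff):
    #
    #   strategy = EntityTypeTaskPromptAugmentationStrategy(entity_types=["gene", "disease"])
    #   augmented = strategy.augment_sentence(Sentence("Mutations in TP53 ..."), annotation_layers="ner")
    #
    # The returned AugmentedSentence starts with the task prompt, e.g.
    # "[Tag gene and disease] Mutations in TP53 ...", and any gold labels are
    # copied over with their token indices shifted by the prompt length.
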
     @abstractmethod
     def apply_predictions(
-        self,
-        augmented_sentence: Sentence,
-        original_sentence: Sentence,
-        source_annotation_layer: str,
-        target_annotation_layer: str
+        self,
+        augmented_sentence: Sentence,
+        original_sentence: Sentence,
+        source_annotation_layer: str,
+        target_annotation_layer: str,
     ):
         """
-            Transfers the predictions made on the augmented sentence to the original one.
+        Transfers the predictions made on the augmented sentence to the original one.
 
-            Args:
-                augmented_sentence: The augmented sentence instance
-                original_sentence: The original sentence before the augmentation was applied
-                source_annotation_layer: Annotation layer of the augmented sentence in which the predictions are stored.
-                target_annotation_layer: Annotation layer in which the predictions should be stored in the original sentence.
+        Args:
+            augmented_sentence: The augmented sentence instance
+            original_sentence: The original sentence before the augmentation was applied
+            source_annotation_layer: Annotation layer of the augmented sentence in which the predictions are stored.
+            target_annotation_layer: Annotation layer in which the predictions should be stored in the original sentence.
         """
         ...
 
     @abstractmethod
     def _get_state_dict(self):
         """
-            Returns the state dict for the given augmentation strategy.
+        Returns the state dict for the given augmentation strategy.
         """
         ...
 
     @classmethod
     def _init_strategy_with_state_dict(cls, state, **kwargs):
         """
-            Initializes the strategy from the given state.
+        Initializes the strategy from the given state.
         """
 
     def augment_dataset(
-        self,
-        dataset: Dataset[Sentence],
-        annotation_layers: Union[str, List[str]] = None
-
+        self, dataset: Dataset[Sentence], annotation_layers: Union[str, List[str]] = None
     ) -> FlairDatapointDataset[AugmentedSentence]:
         """Transforms a dataset into a dataset containing augmented sentences specific to the `HunFlairAllInOneSequenceTagger`.
@@ -1122,18 +1115,12 @@ def augment_dataset(
         data_loader: DataLoader = DataLoader(dataset, batch_size=1)
         original_sentences: List[Sentence] = [batch[0] for batch in iter(data_loader)]
 
-        augmented_sentences = [
-            self.augment_sentence(sentence, annotation_layers)
-            for sentence in original_sentences
-
-        ]
+        augmented_sentences = [self.augment_sentence(sentence, annotation_layers) for sentence in original_sentences]
 
         return FlairDatapointDataset(augmented_sentences)
 
     def augment_corpus(
-        self,
-        corpus: Corpus[Sentence],
-        annotation_layers: Union[str, List[str]] = None
+        self, corpus: Corpus[Sentence], annotation_layers: Union[str, List[str]] = None
     ) -> Corpus[AugmentedSentence]:
         """Transforms a corpus into a corpus containing augmented sentences specific to the `HunFlairAllInOneSequenceTagger`.
@@ -1160,15 +1147,15 @@ def augment_corpus(
 
 class EntityTypeTaskPromptAugmentationStrategy(SentenceAugmentationStrategy):
     """
-        Augmentation strategy that augments a sentence with a task description which specifies
-        which entity types should be tagged.
+    Augmentation strategy that augments a sentence with a task description which specifies
+    which entity types should be tagged.
 
-        Example:
-            "[Tag gene and disease] Mutations in the TP53 tumour suppressor gene are found in ~50% of human cancers"
+    Example:
+        "[Tag gene and disease] Mutations in the TP53 tumour suppressor gene are found in ~50% of human cancers"
 
-        This approach is inspired by the paper from Luo et al.:
-        AIONER: All-in-one scheme-based biomedical named entity recognition using deep learning
-        https://arxiv.org/abs/2211.16944
+    This approach is inspired by the paper from Luo et al.:
+    AIONER: All-in-one scheme-based biomedical named entity recognition using deep learning
+    https://arxiv.org/abs/2211.16944
     """
 
     def __init__(self, entity_types: List[str]):
@@ -1179,16 +1166,14 @@ def __init__(self, entity_types: List[str]):
         self.task_prompt = self._build_tag_prompt_prefix(entity_types)
 
     def augment_sentence(
-        self,
-        sentence: Sentence,
-        annotation_layers: Union[str, List[str]] = None
+        self, sentence: Sentence, annotation_layers: Union[str, List[str]] = None
     ) -> AugmentedSentence:
         # Prepend the task description prompt to the sentence text
         augmented_sentence = AugmentedSentence(
             text=self.task_prompt + [t.text for t in sentence.tokens],
             use_tokenizer=False,
             language_code=sentence.language_code,
-            start_position=sentence.start_position
+            start_position=sentence.start_position,
         )
 
         # Make sure it's a list
@@ -1203,33 +1188,40 @@ def augment_sentence(
             for label in sentence.get_labels(layer):
                 if isinstance(label.data_point, Token):
                     label_span = augmented_sentence[
-                        len_task_prompt + label.data_point.idx - 1:
-                        len_task_prompt + label.data_point.idx]
+                        len_task_prompt + label.data_point.idx - 1 : len_task_prompt + label.data_point.idx
+                    ]
                 else:
                     label_span = augmented_sentence[
-                        len_task_prompt + label.data_point.tokens[0].idx - 1:
-                        len_task_prompt + label.data_point.tokens[-1].idx]
+                        len_task_prompt + label.data_point.tokens[0].idx - 1 : len_task_prompt
+                        + label.data_point.tokens[-1].idx
+                    ]
                 label_span.add_label(layer, label.value, label.score)
 
         return augmented_sentence
 
     def apply_predictions(
-        self,
-        augmented_sentence: Sentence,
-        original_sentence: Sentence,
-        source_annotation_layer: str,
-        target_annotation_layer: str
+        self,
+        augmented_sentence: Sentence,
+        original_sentence: Sentence,
+        source_annotation_layer: str,
+        target_annotation_layer: str,
     ):
         new_labels = augmented_sentence.get_labels(source_annotation_layer)
 
         len_task_prompt = len(self.task_prompt)
 
-        for label in new_labels:
-            orig_span = original_sentence[
-                label.data_point.tokens[0].idx - len_task_prompt - 1:
-                label.data_point.tokens[-1].idx - len_task_prompt
-            ]
-            orig_span.add_label(target_annotation_layer, label.value, label.score)
+        try:
+            for label in new_labels:
+                if label.data_point.tokens[0].idx - len_task_prompt - 1 < 0:
+                    continue
+                orig_span = original_sentence[
+                    label.data_point.tokens[0].idx - len_task_prompt - 1 : label.data_point.tokens[-1].idx
+                    - len_task_prompt
+                ]
+                orig_span.add_label(target_annotation_layer, label.value, label.score)
+        except IndexError:
+            for token in original_sentence:
+                print(token)
 
     def _build_tag_prompt_prefix(self, entity_types: List[str]) -> List[str]:
         if len(self.entity_types) == 1:
@@ -1248,24 +1240,17 @@ def _init_strategy_with_state_dict(cls, state, **kwargs):
 
 
 class AugmentedSentenceSequenceTagger(SequenceTagger):
-
-    def __init__(
-        self,
-        *args,
-        augmentation_strategy: SentenceAugmentationStrategy,
-        **kwargs
-    ):
+    def __init__(self, *args, augmentation_strategy: SentenceAugmentationStrategy, **kwargs):
         super().__init__(*args, **kwargs)
 
         if augmentation_strategy is None:
-            raise AssertionError("Please provide an augmentation strategy")
+            logging.warning("No augmentation strategy provided. Make sure that the strategy is set.")
 
         self.augmentation_strategy = augmentation_strategy
 
     def _get_state_dict(self):
         state = super(AugmentedSentenceSequenceTagger, self)._get_state_dict()
 
-        class_name = ".".join([self.augmentation_strategy.__module__,
-                               self.augmentation_strategy.__class__.__name__])
+        class_name = ".".join([self.augmentation_strategy.__module__, self.augmentation_strategy.__class__.__name__])
         state["augmentation_strategy_cls"] = class_name
         state["augmentation_strategy_state"] = self.augmentation_strategy._get_state_dict()
 
@@ -1284,13 +1269,9 @@ def _init_model_with_state_dict(cls, state, **kwargs):
 
         for subclass_name, subclass in subclasses:
             if aug_strategy_cls_name == subclass_name:
-                strategy = subclass._init_strategy_with_state_dict(
-                    state.get("augmentation_strategy_state"))
+                strategy = subclass._init_strategy_with_state_dict(state.get("augmentation_strategy_state"))
                 break
 
-        if strategy is None:
-            raise AssertionError("Can't load augmentation strategy")
-
         return super()._init_model_with_state_dict(
             state,
             augmentation_strategy=strategy,
@@ -1298,18 +1279,28 @@ def _init_model_with_state_dict(cls, state, **kwargs):
 
     @classmethod
-    def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "AugmentedSentenceSequenceTagger":
+    def load(
+        cls,
+        model_path: Union[str, Path, Dict[str, Any]],
+        augmentation_strategy: EntityTypeTaskPromptAugmentationStrategy,
+    ) -> "AugmentedSentenceSequenceTagger":
         from typing import cast
 
-        return cast("AugmentedSentenceSequenceTagger", super().load(model_path=model_path))
+        if augmentation_strategy is None:
+            raise AssertionError("Please provide an augmentation strategy")
+
+        model_instance = cast("AugmentedSentenceSequenceTagger", super().load(model_path=model_path))
+
+        model_instance.augmentation_strategy = augmentation_strategy
+
+        logging.warning(f"Loaded model '{model_path}' with augmentation strategy '{augmentation_strategy}'")
+
+        return model_instance
 
     def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]:
         # If all sentences are not augmented -> augment them
         if all(isinstance(sentence, Sentence) for sentence in sentences):
-            sentences = self.augment_sentences(
-                sentences=sentences,
-                annotation_layers=self.tag_type
-            )
+            sentences = self.augment_sentences(sentences=sentences, annotation_layers=self.tag_type)
 
         elif not all(isinstance(sentence, AugmentedSentence) for sentence in sentences):
             raise ValueError("All passed sentences must be either uniformly augmented or not.")
@@ -1341,16 +1332,12 @@ def predict(self, sentences: Union[List[Sentence], Sentence], **kwargs):
 
         # Predict on augmented sentence and store it in an internal annotation layer / label
         loss_and_count = super(AugmentedSentenceSequenceTagger, self).predict(
-            sentences=augmented_sentences,
-            label_name=orig_label_name,
-            **kwargs
+            sentences=augmented_sentences, label_name=orig_label_name, **kwargs
         )
 
         # Append predicted labels to the original sentences
         for orig_sent, aug_sent in zip(sentences, augmented_sentences):
-            self.augmentation_strategy.apply_predictions(
-                aug_sent, orig_sent, orig_label_name, orig_label_name
-            )
+            self.augmentation_strategy.apply_predictions(aug_sent, orig_sent, orig_label_name, orig_label_name)
 
             if orig_label_name == "predicted":
                 orig_sent.remove_labels("predicted_bio")
@@ -1360,14 +1347,9 @@ def predict(self, sentences: Union[List[Sentence], Sentence], **kwargs):
 
         return loss_and_count
 
     def augment_sentences(
-        self,
-        sentences: Union[Sentence, List[Sentence]],
-        annotation_layers: Union[str, List[str]] = None
+        self, sentences: Union[Sentence, List[Sentence]], annotation_layers: Union[str, List[str]] = None
     ) -> List[AugmentedSentence]:
         if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset):
             sentences = [sentences]
 
-        return [
-            self.augmentation_strategy.augment_sentence(sentence, annotation_layers)
-            for sentence in sentences
-        ]
+        return [self.augmentation_strategy.augment_sentence(sentence, annotation_layers) for sentence in sentences]
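
# End-to-end usage under the revised API above (a hedged sketch; the model
# path, entity types, and example sentence are illustrative, not from the
# patch). load() now requires the augmentation strategy to be passed
# explicitly, and predict() augments plain Sentences transparently:
from flair.data import Sentence

strategy = EntityTypeTaskPromptAugmentationStrategy(entity_types=["gene", "disease"])
tagger = AugmentedSentenceSequenceTagger.load(
    "path/to/hunflair-all-in-one.pt",  # illustrative path
    augmentation_strategy=strategy,
)
sentence = Sentence("Mutations in the TP53 tumour suppressor gene are found in ~50% of human cancers")
tagger.predict(sentence)
# Predictions made on the prompt-prefixed copy are written back onto the
# original sentence by apply_predictions(), so downstream code can read the
# labels from `sentence` as usual.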