diff --git a/README.md b/README.md index 6287b3261..3ea845188 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Flair is: * **A powerful NLP library.** Flair allows you to apply our state-of-the-art natural language processing (NLP) models to your text, such as named entity recognition (NER), sentiment analysis, part-of-speech tagging (PoS), - special support for [biomedical data](/resources/docs/HUNFLAIR.md), + special support for [biomedical texts](/resources/docs/HUNFLAIR2.md), sense disambiguation and classification, with support for a rapidly growing number of languages. * **A text embedding library.** Flair has simple interfaces that allow you to use and combine different word and diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index b4e3edc1f..76173bac8 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -204,7 +204,7 @@ def __init__( self.__embedding_length: int = precomputed_word_embeddings.vector_size - vectors = np.row_stack( + vectors = np.vstack( ( precomputed_word_embeddings.vectors, np.zeros(self.__embedding_length, dtype="float"), @@ -399,7 +399,7 @@ def __setstate__(self, state: Dict[str, Any]): state.setdefault("field", None) if "precomputed_word_embeddings" in state: precomputed_word_embeddings: KeyedVectors = state.pop("precomputed_word_embeddings") - vectors = np.row_stack( + vectors = np.vstack( ( precomputed_word_embeddings.vectors, np.zeros(precomputed_word_embeddings.vector_size, dtype="float"), diff --git a/flair/models/__init__.py b/flair/models/__init__.py index ac69e19aa..452c513f9 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -6,6 +6,7 @@ from .multitask_model import MultitaskModel from .pairwise_classification_model import TextPairClassifier from .pairwise_regression_model import TextPairRegressor +from .prefixed_tagger import PrefixedSequenceTagger # This import has to be after SequenceTagger! 
from .regexp_tagger import RegexpTagger from .relation_classifier_model import RelationClassifier from .relation_extractor_model import RelationExtractor @@ -26,6 +27,7 @@ "RelationExtractor", "RegexpTagger", "SequenceTagger", + "PrefixedSequenceTagger", "TokenClassifier", "WordTagger", "FewshotClassifier", diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 3f39003b7..9d0e28a43 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -1,6 +1,7 @@ import inspect import logging import os +import platform import re import stat import string @@ -648,6 +649,8 @@ def p(text: str) -> str: emb = emb / torch.norm(emb) dense_embeddings.append(emb.cpu().numpy()) sent.clear_embeddings() + + # empty cuda cache if device is a cuda device if flair.device.type == "cuda": torch.cuda.empty_cache() @@ -681,6 +684,11 @@ def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]: emb = emb / torch.norm(emb) query_embeddings["dense"].append(emb.cpu().numpy()) sent.clear_embeddings(self.embeddings["dense"].get_names()) + + # Sanity conversion: if flair.device was set as a string, convert to torch.device + if isinstance(flair.device, str): + flair.device = torch.device(flair.device) + if flair.device.type == "cuda": torch.cuda.empty_cache() @@ -836,9 +844,13 @@ def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict if any(label in ["diseases", "genes", "species", "chemical"] for label in sentence.annotation_layers): if not self._warned_legacy_sequence_tagger: logger.warning( - "The tagger `Classifier.load('hunflair') is deprecated. Please update to: `Classifier.load('hunflair2')`." + "It appears that the sentences have been annotated with HunFlair (version 1). " + "Consider using HunFlair2 for improved extraction performance: Classifier.load('hunflair2')." 
+ "See https://github.com/flairNLP/flair/blob/master/resources/docs/HUNFLAIR2.md for further " + "information." ) self._warned_legacy_sequence_tagger = True + entity_types = {e for sublist in entity_label_types.values() for e in sublist} entities_mentions = [ label for label in sentence.get_labels() if normalize_entity_type(label.value) in entity_types @@ -939,6 +951,14 @@ def _fetch_model(model_name: str) -> str: if model_name in hf_model_map: model_name = hf_model_map[model_name] + if platform.system() == "Windows": + logger.warning( + "You seem to run your application on a Windows system. Unfortunately, the abbreviation " + "resolution of HunFlair2 is only available on Linux/Mac systems. Therefore, a model " + "without abbreviation resolution is therefore loaded" + ) + model_name += "-no-ab3p" + return hf_download(model_name) @classmethod diff --git a/flair/models/multitask_model.py b/flair/models/multitask_model.py index d127f0e14..bcd9befb3 100644 --- a/flair/models/multitask_model.py +++ b/flair/models/multitask_model.py @@ -260,6 +260,14 @@ def _fetch_model(model_name) -> str: cache_dir = Path("models") if model_name in model_map: + if model_name in ["hunflair", "hunflair-paper", "bioner"]: + log.warning( + "HunFlair (version 1) is deprecated. Consider using HunFlair2 for improved extraction performance: " + "Classifier.load('hunflair2')." + "See https://github.com/flairNLP/flair/blob/master/resources/docs/HUNFLAIR2.md for further " + "information." 
+ ) + model_name = cached_path(model_map[model_name], cache_dir=cache_dir) return model_name diff --git a/flair/models/prefixed_tagger.py b/flair/models/prefixed_tagger.py index b8c01c50a..a2b3012c2 100644 --- a/flair/models/prefixed_tagger.py +++ b/flair/models/prefixed_tagger.py @@ -9,7 +9,8 @@ import flair.data from flair.data import Corpus, Sentence, Token from flair.datasets import DataLoader, FlairDatapointDataset -from flair.models import SequenceTagger +from flair.file_utils import hf_download +from flair.models.sequence_tagger_model import SequenceTagger class PrefixedSentence(Sentence): @@ -317,3 +318,21 @@ def augment_sentences( sentences = [sentences] return [self.augmentation_strategy.augment_sentence(sentence, annotation_layers) for sentence in sentences] + + @staticmethod + def _fetch_model(model_name) -> str: + huggingface_model_map = {"hunflair2": "hunflair/hunflair2-ner"} + + # check if model name is a valid local file + if Path(model_name).exists(): + model_path = model_name + + # check if model name is a pre-configured hf model + elif model_name in huggingface_model_map: + hf_model_name = huggingface_model_map[model_name] + return hf_download(hf_model_name) + + else: + model_path = hf_download(model_name) + + return model_path diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 9e1ff7719..1f2a93c68 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -781,6 +781,14 @@ def _fetch_model(model_name) -> str: elif model_name in hu_model_map: model_path = cached_path(hu_model_map[model_name], cache_dir=cache_dir) + if model_name.startswith("hunflair-"): + log.warning( + "HunFlair (version 1) is deprecated. Consider using HunFlair2 for improved extraction performance: " + "Classifier.load('hunflair2')." + "See https://github.com/flairNLP/flair/blob/master/resources/docs/HUNFLAIR2.md for further " + "information." 
+ ) + # special handling for the taggers by the @redewiegergabe project (TODO: move to model hub) elif model_name == "de-historic-indirect": model_file = flair.cache_root / cache_dir / "indirect" / "final-model.pt" diff --git a/resources/docs/HUNFLAIR.md b/resources/docs/HUNFLAIR.md index a85b88c8b..9ec20f7a6 100644 --- a/resources/docs/HUNFLAIR.md +++ b/resources/docs/HUNFLAIR.md @@ -8,6 +8,9 @@ NER data sets](HUNFLAIR_CORPORA.md) and comes with a Flair language model ("pubm FastText embeddings ("pubmed") that were trained on roughly 3 million full texts and about 25 million abstracts from the biomedical domain. +**Using HunFlair (version 1) is deprecated, please refer to [HunFlair2](HUNFLAIR2.md) +for an updated and improved version.** + Content: [Quick Start](#quick-start) | [BioNER-Tool Comparison](#comparison-to-other-biomedical-ner-tools) | diff --git a/resources/docs/HUNFLAIR2.md b/resources/docs/HUNFLAIR2.md new file mode 100644 index 000000000..bef0f42ad --- /dev/null +++ b/resources/docs/HUNFLAIR2.md @@ -0,0 +1,137 @@ +# HunFlair2 + +*HunFlair2* is a state-of-the-art named entity tagger and linker for biomedical texts. It comes with +models for genes/proteins, chemicals, diseases, species and cell lines. *HunFlair2* +builds on pretrained domain-specific language models and outperforms other biomedical +NER tools on unseen corpora. + +Content: +[Quick Start](#quick-start) | +[Tool Comparison](#comparison-to-other-biomedical-entity-extraction-tools) | +[Tutorials](#tutorials) | +[Citing HunFlair](#citing-hunflair2) + +## Quick Start + +#### Requirements and Installation +*HunFlair2* is based on Flair 0.13+ and Python 3.8+. If you do not have Python 3.8, install it first. +Then, in your favorite virtual environment, simply do: +``` +pip install flair +``` + +#### Example 1: Biomedical NER +Let's run named entity recognition (NER) over an example sentence. 
All you need to do is +make a Sentence, load a pre-trained model and use it to predict tags for the sentence: +```python +from flair.data import Sentence +from flair.nn import Classifier + +# make a sentence +sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + +# load biomedical NER tagger +tagger = Classifier.load("hunflair2") + +# tag sentence +tagger.predict(sentence) +``` +Done! The Sentence now has entity annotations. Let's print the entities found by the tagger: +```python +for entity in sentence.get_labels(): + print(entity) +``` +This should print: +```console +Span[0:2]: "Behavioral abnormalities" → Disease (1.0) +Span[4:5]: "Fmr1" → Gene (1.0) +Span[6:7]: "Mouse" → Species (1.0) +Span[9:12]: "Fragile X Syndrome" → Disease (1.0) +``` + +#### Example 2: Biomedical NEN +For improved integration and aggregation from multiple different documents linking / normalizing the entities to +standardized ontologies or knowledge bases is required. 
Let's perform entity normalization by using +specialized models per entity type: +```python +from flair.data import Sentence +from flair.models import EntityMentionLinker +from flair.nn import Classifier + +# make a sentence +sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + +# load biomedical NER tagger + predict entities +tagger = Classifier.load("hunflair2") +tagger.predict(sentence) + +# load gene linker and perform normalization +gene_linker = EntityMentionLinker.load("gene-linker") +gene_linker.predict(sentence) + +# load disease linker and perform normalization +disease_linker = EntityMentionLinker.load("disease-linker") +disease_linker.predict(sentence) + +# load species linker and perform normalization +species_linker = EntityMentionLinker.load("species-linker") +species_linker.predict(sentence) +``` +**Note**, the ontologies and knowledge bases used are pre-processed the first time the normalisation is executed, +which might take a certain amount of time. All further calls are then based on this pre-processing and run +much faster. + +Done! The Sentence now has entity normalizations. Let's print the entity identifiers found by the linkers: +```python +for entity in sentence.get_labels("link"): + print(entity) +``` +This should print: +```console +Span[0:2]: "Behavioral abnormalities" → MESH:D001523/name=Mental Disorders (197.9467010498047) +Span[4:5]: "Fmr1" → 108684022/name=FRAXA (219.9510040283203) +Span[6:7]: "Mouse" → 10090/name=Mus musculus (213.6201934814453) +Span[9:12]: "Fragile X Syndrome" → MESH:D005600/name=Fragile X Syndrome (193.7115020751953) +``` + +## Comparison to other biomedical entity extraction tools +Tools for biomedical entity extraction are typically trained and evaluated on single, rather small gold standard +data sets. However, they are applied "in the wild" to a much larger collection of texts, often varying in +topic, entity distribution, genre (e.g. patents vs. 
scientific articles) and text type (e.g. abstract +vs. full text), which can lead to severe drops in performance. + +*HunFlair2* outperforms other biomedical entity extraction tools on corpora not used for training of either +*HunFlair2* or any of the competitor tools. + +| Corpus | Entity Type | BENT | BERN2 | PubTator Central | SciSpacy | HunFlair | +|----------------------------------------------------------------------------------------------|-------------|-------|-------|------------------|----------|-------------| +| [MedMentions](https://github.com/chanzuckerberg/MedMentions) | Chemical | 40.90 | 41.79 | 31.28 | 34.95 | *__51.17__* | +| | Disease | 45.94 | 47.33 | 41.11 | 40.78 | *__57.27__* | +| [tmVar (v3)](https://github.com/ncbi/tmVar3?tab=readme-ov-file) | Gene | 0.54 | 43.96 | *__86.02__* | - | 76.75 | +| [BioID](https://biocreative.bioinformatics.udel.edu/media/store/files/2018/BC6_track1_1.pdf) | Species | 10.35 | 14.35 | *__58.90__* | 37.14 | 49.66 | +||||| +| Average | All | 24.43 | 36.86 | 54.33 | 37.61 | *__58.79__* | + +All results are F1 scores highlighting end-to-end performance, i.e., named entity recognition and normalization, +using partial matching of predicted text offsets with the original char offsets of the gold standard data. +We allow a shift by max one character. + +You can find detailed evaluations and discussions in [our paper](https://arxiv.org/abs/2402.12372). 
+ +## Tutorials +We provide a set of quick tutorials to get you started with *HunFlair2*: +* [Tutorial 1: Tagging biomedical named entities](HUNFLAIR2_TUTORIAL_1_TAGGING.md) +* [Tutorial 2: Linking biomedical named entities](HUNFLAIR2_TUTORIAL_2_LINKING.md) +* [Tutorial 3: Training NER models](HUNFLAIR2_TUTORIAL_3_TRAINING_NER.md) +* [Tutorial 4: Customizing linking](HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md) + +## Citing HunFlair2 +Please cite the following paper when using *HunFlair2*: +~~~ +@article{sanger2024hunflair2, + title={HunFlair2 in a cross-corpus evaluation of biomedical named entity recognition and normalization tools}, + author={S{\"a}nger, Mario and Garda, Samuele and Wang, Xing David and Weber-Genzel, Leon and Droop, Pia and Fuchs, Benedikt and Akbik, Alan and Leser, Ulf}, + journal={arXiv preprint arXiv:2402.12372}, + year={2024} +} +~~~ diff --git a/resources/docs/HUNFLAIR2_TUTORIAL_1_TAGGING.md b/resources/docs/HUNFLAIR2_TUTORIAL_1_TAGGING.md new file mode 100644 index 000000000..20d9b643b --- /dev/null +++ b/resources/docs/HUNFLAIR2_TUTORIAL_1_TAGGING.md @@ -0,0 +1,121 @@ +# HunFlair2 - Tutorial 1: Tagging + +This is part 1 of the tutorial, in which we show how to use our pre-trained *HunFlair2* models to tag your text. + +### Tagging with Pre-trained HunFlair2-Models +Let's use the pre-trained *HunFlair2* model for biomedical named entity recognition (NER). +This model was trained over multiple biomedical NER data sets and can recognize 5 different entity types, +i.e. cell lines, chemicals, disease, gene / proteins and species. +```python +from flair.nn import Classifier + +tagger = Classifier.load("hunflair2") +``` +All you need to do is use the predict() method of the tagger on a sentence. +This will add predicted tags to the tokens in the sentence. 
+Let's use a sentence with four named entities: +```python +from flair.data import Sentence + +sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + +# predict NER tags +tagger.predict(sentence) + +# print the predicted tags +for entity in sentence.get_labels(): + print(entity) +``` +This should print: +```console +Span[0:2]: "Behavioral abnormalities" → Disease (1.0) +Span[4:5]: "Fmr1" → Gene (1.0) +Span[6:7]: "Mouse" → Species (1.0) +Span[9:12]: "Fragile X Syndrome" → Disease (1.0) +``` +The output indicates that there are two diseases mentioned in the text ("_Behavioral Abnormalities_" and +"_Fragile X Syndrome_") as well as one gene ("_fmr1_") and one species ("_Mouse_"). For each entity, the +text span of its mention in the sentence is given together with a label holding a value and a score (confidence in the +prediction). You can also get additional information, such as the position offsets of each entity +in the sentence in a structured way by calling the `to_dict()` method: + +```python +print(sentence.to_dict()) +``` +This should print: +```python +{ + 'text': 'Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome', + 'labels': [], + 'entities': [ + {'text': 'Behavioral abnormalities', 'start_pos': 0, 'end_pos': 24, 'labels': [{'value': 'Disease', 'confidence': 0.9999860525131226}]}, + {'text': 'Fmr1', 'start_pos': 32, 'end_pos': 36, 'labels': [{'value': 'Gene', 'confidence': 0.9999895095825195}]}, + {'text': 'Mouse', 'start_pos': 41, 'end_pos': 46, 'labels': [{'value': 'Species', 'confidence': 0.9999873638153076}]}, + {'text': 'Fragile X Syndrome', 'start_pos': 56, 'end_pos': 74, 'labels': [{'value': 'Disease', 'confidence': 0.9999928871790568}]} + ], + # further sentence information +} +``` + +### Using a Biomedical Tokenizer +Tokenization, i.e. separating a text into tokens / words, is an important issue in natural language processing +in general and biomedical text mining in particular. 
So far, we used a tokenizer for general domain text. +This can be unfavourable if applied to biomedical texts. + +*HunFlair2* integrates [SciSpaCy](https://allenai.github.io/scispacy/), a library specially designed to work with scientific text. +To use the library we first have to install it and download one of its models: +~~~ +pip install scispacy==0.5.1 +pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_sm-0.5.1.tar.gz +~~~ + +To use the tokenizer we just have to pass it as a parameter when instantiating a sentence: +```python +from flair.tokenization import SciSpacyTokenizer + +sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome", + use_tokenizer=SciSpacyTokenizer()) +``` + +### Working with longer Texts +Often, we are concerned with complete scientific abstracts or full-texts when performing biomedical text mining, e.g. +```python +abstract = "Fragile X syndrome (FXS) is a developmental disorder caused by a mutation in the X-linked FMR1 gene, " \ + "coding for the FMRP protein which is largely involved in synaptic function. FXS patients present several " \ + "behavioral abnormalities, including hyperactivity, anxiety, sensory hyper-responsiveness, and cognitive " \ + "deficits. Autistic symptoms, e.g., altered social interaction and communication, are also often observed: " \ + "FXS is indeed the most common monogenic cause of autism." +``` + +To work with complete abstracts or full-text, we first have to split them into separate sentences. 
+Again we can apply the integration of the [SciSpaCy](https://allenai.github.io/scispacy/) library: +```python +from flair.splitter import SciSpacySentenceSplitter + +# initialize the sentence splitter +splitter = SciSpacySentenceSplitter() + +# split text into a list of Sentence objects +sentences = splitter.split(abstract) + +# you can apply the HunFlair tagger directly to this list +tagger.predict(sentences) +``` +We can access the annotations of the single sentences by just iterating over the list: +```python +for sentence in sentences: + print(sentence.to_tagged_string()) +``` +This should print: +~~~ +Sentence[35]: "Fragile X syndrome (FXS) is a developmental disorder caused by a mutation in the X-linked FMR1 gene, coding for the FMRP protein which is largely involved in synaptic function." \ + → ["Fragile X syndrome"/Disease, "FXS"/Disease, "developmental disorder"/Disease, "X-linked"/Gene, "FMR1"/Gene, "FMRP"/Gene] +Sentence[23]: "FXS patients present several behavioral abnormalities, including hyperactivity, anxiety, sensory hyper-responsiveness, and cognitive deficits." \ + → ["FXS"/Disease, "patients"/Species, "behavioral abnormalities"/Disease, "hyperactivity"/Disease, "anxiety"/Disease, "sensory hyper-responsiveness"/Disease, "cognitive deficits"/Disease] +Sentence[27]: "Autistic symptoms, e.g., altered social interaction and communication, are also often observed: FXS is indeed the most common monogenic cause of autism." \ + → ["Autistic symptoms"/Disease, "altered social interaction and communication"/Disease, "FXS"/Disease, "autism"/Disease] +~~~ + +### Next +Now, let us look at how to [link / normalize the entities to standard ontologies](HUNFLAIR2_TUTORIAL_2_LINKING.md) +in the second tutorial. 
diff --git a/resources/docs/HUNFLAIR2_TUTORIAL_2_LINKING.md b/resources/docs/HUNFLAIR2_TUTORIAL_2_LINKING.md new file mode 100644 index 000000000..182fd0cd9 --- /dev/null +++ b/resources/docs/HUNFLAIR2_TUTORIAL_2_LINKING.md @@ -0,0 +1,88 @@ +# HunFlair2 - Tutorial 2: Entity Linking + +[Part 1](HUNFLAIR2_TUTORIAL_1_TAGGING.md) of the tutorial, showed how to use our pre-trained *HunFlair2* models to +tag biomedical entities in your text. However, documents from different biomedical (sub-) fields may use different +terms to refer to the exact same concept, e.g., “_tumor protein p53_”, “_tumor suppressor p53_”, “_TRP53_” are all +valid names for the gene “TP53” ([NCBI Gene:7157](https://www.ncbi.nlm.nih.gov/gene/7157)). +For improved integration and aggregation of entity mentions from multiple different documents linking / normalizing +the entities to standardized ontologies or knowledge bases is required. + +### Linking with pre-trained HunFlair2 Models + +After adding named entity recognition tags to your sentence, you can link the entities to standard ontologies +using distinct, type-specific linking models: + +```python +from flair.models import EntityMentionLinker +from flair.nn import Classifier +from flair.data import Sentence + +sentence = Sentence( + "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, " + "a neurodegenerative disease, which is exacerbated by exposure to high " + "levels of mercury in mouse populations." 
+) + +# Tag named entities in the text +ner_tagger = Classifier.load("hunflair2") +ner_tagger.predict(sentence) + +# Load disease linker and perform disease linking +disease_linker = EntityMentionLinker.load("disease-linker") +disease_linker.predict(sentence) + +# Load gene linker and perform gene linking +gene_linker = EntityMentionLinker.load("gene-linker") +gene_linker.predict(sentence) + +# Load chemical linker and perform chemical linking +chemical_linker = EntityMentionLinker.load("chemical-linker") +chemical_linker.predict(sentence) + +# Load species linker and perform species linking +species_linker = EntityMentionLinker.load("species-linker") +species_linker.predict(sentence) +``` + +**Note**: the ontologies and knowledge bases used are pre-processed the first time the normalisation is executed, +which might take a certain amount of time. All further calls are then based on this pre-processing and run +much faster. + +After running the code we can inspect and output the linked entities via: + +```python +for tag in sentence.get_labels("link"): + print(tag) +``` + +This should print: + +``` +Span[4:5]: "ABCD1" → 215/name=ABCD1 (210.89810180664062) +Span[7:9]: "X-linked adrenoleukodystrophy" → MESH:D000326/name=Adrenoleukodystrophy (195.30780029296875) +Span[11:13]: "neurodegenerative disease" → MESH:D019636/name=Neurodegenerative Diseases (201.1804962158203) +Span[23:24]: "mercury" → MESH:D008628/name=Mercury (220.39199829101562) +Span[25:26]: "mouse" → 10090/name=Mus musculus (213.6201934814453) +``` + +For each entity, the output contains both the NER mention annotations and their ontology identifiers to which +the mentions were mapped. Moreover, the official name of the entity in the ontology and the similarity score +of the entity mention and the ontology concept are given. For instance, the official name for the disease +"_X-linked adrenoleukodystrophy_" is adrenoleukodystrophy. 
The similarity scores are specific to entity type, +ontology and linking model used and can therefore only be compared and related to those using the exact same +setup. + +### Overview of pre-trained Entity Linking Models + +HunFlair2 comes with the following pre-trained linking models: + +| Entity Type | Model Name | Ontology / Dictionary | Linking Algorithm / Base Model (Data Set) | +| ----------- | ----------------- | ---------------------------------------------------------- | --------------------------------------------------------------------------------------- | +| Chemical | `chemical-linker` | [CTD Chemicals](https://ctdbase.org/downloads/#allchems) | [SapBERT (BC5CDR)](https://huggingface.co/dmis-lab/biosyn-sapbert-bc5cdr-chemical) | +| Disease | `disease-linker` | [CTD Diseases](https://ctdbase.org/downloads/#alldiseases) | [SapBERT (NCBI Disease)](https://huggingface.co/dmis-lab/biosyn-sapbert-bc5cdr-disease) | +| Gene | `gene-linker` | [NCBI Gene (Human)](https://www.ncbi.nlm.nih.gov/gene) | [SapBERT (BC2GN)](https://huggingface.co/dmis-lab/biosyn-sapbert-bc2gn) | +| Species | `species-linker` | [NCBI Taxonomy](https://www.ncbi.nlm.nih.gov/taxonomy) | [SapBERT (UMLS)](https://huggingface.co/cambridgeltl/SapBERT-from-PubMedBERT-fulltext) | + +For detailed information concerning the different models and their integration please refer to [our paper](https://arxiv.org/abs/2402.12372). + +If you wish to customize the models and dictionaries please refer to the [dedicated tutorial](HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md). 
diff --git a/resources/docs/HUNFLAIR2_TUTORIAL_3_TRAINING_NER.md b/resources/docs/HUNFLAIR2_TUTORIAL_3_TRAINING_NER.md new file mode 100644 index 000000000..762fe14b9 --- /dev/null +++ b/resources/docs/HUNFLAIR2_TUTORIAL_3_TRAINING_NER.md @@ -0,0 +1,223 @@ +# HunFlair2 Tutorial 3: Training NER models + +This part of the tutorial shows how you can train your own biomedical named entity recognition models +using state-of-the-art pretrained Transformers embeddings. + +For this tutorial, we assume that you're familiar with the [base types](https://flairnlp.github.io/docs/tutorial-basics/basic-types) of Flair +and how to use [transformer word embeddings](https://flairnlp.github.io/docs/tutorial-training/how-to-train-sequence-tagger). +You should also know how to [load a corpus](https://flairnlp.github.io/docs/tutorial-training/how-to-load-prepared-dataset). + +## Train a biomedical NER model from scratch: single entity type + +Here is example code for a biomedical NER model trained on the `NCBI_DISEASE` corpus using Transformer word embeddings. +This will result in a tagger specialized for a single entity type, i.e. "Disease". + +```python +from flair.datasets import NCBI_DISEASE + +# 1. get the corpus +corpus = NCBI_DISEASE() +print(corpus) + +# 2. make the tag dictionary from the corpus +tag_dictionary = corpus.make_label_dictionary(label_type="ner", add_unk=False) + +# 3. initialize embeddings +from flair.embeddings import TransformerWordEmbeddings + +embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings( + "michiyasunaga/BioLinkBERT-base", + layers="-1", + subtoken_pooling="first", + fine_tune=True, + use_context=True, + model_max_length=512, +) + +# 4. initialize sequence tagger +from flair.models import SequenceTagger + +tagger: SequenceTagger = SequenceTagger( + hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_format="BIOES", + tag_type="ner", + use_crf=False, + use_rnn=False, + reproject_embeddings=False, +) + +# 5. 
initialize trainer +from flair.trainers import ModelTrainer + +trainer: ModelTrainer = ModelTrainer(tagger, corpus) + +trainer.fine_tune( + base_path="taggers/ncbi-disease", + train_with_dev=False, + max_epochs=16, + learning_rate=2.0e-5, + mini_batch_size=16, + shuffle=False, +) +``` + +Once the model is trained you can use it to predict tags for new sentences. +Just call the predict method of the model. + +```python +# load the model you trained +model = SequenceTagger.load("taggers/ncbi-disease/best-model.pt") + +# create example sentence +from flair.data import Sentence +sentence = Sentence("Women who smoke 20 cigarettes a day are four times more likely to develop breast cancer.") + +# predict tags and print +model.predict(sentence) + +print(sentence.to_tagged_string()) +``` + +If the model works well, it will correctly tag "breast cancer" as disease in this example: + +``` +Women who smoke 20 cigarettes a day are four times more likely to develop breast cancer . +``` + +## Train a biomedical NER model: multiple entity types + +If you are dealing with multiple entity types, e.g. "Disease" and "Chemicals", you can opt +to train a single model capable of handling multiple entity types at once. +This can be achieved by using the `PrefixedSequenceTagger()` class, which implements the method described in \[1\]. +This model uses prompting, i.e. it adds a prefix (hence the name) string in front of the input text, specifying the +entity types requested for tagging: `[Tag <entity-type-0>, <entity-type-1>, ...]`. +This is especially useful for training, as it allows combining multiple corpora even if they cover different subsets of entity types. + +```python +# 1. get the corpora +from flair.datasets.biomedical import HUNER_ALL_CDR, HUNER_CHEMICAL_NLM_CHEM +corpora = (HUNER_ALL_CDR(), HUNER_CHEMICAL_NLM_CHEM()) + +# 2. 
add prefixed strings to each corpus by prepending its tagged entity +# types "[Tag , , ...]" +from flair.data import MultiCorpus +from flair.models.prefixed_tagger import EntityTypeTaskPromptAugmentationStrategy +from flair.datasets.biomedical import ( + BIGBIO_NER_CORPUS, + CELL_LINE_TAG, + CHEMICAL_TAG, + DISEASE_TAG, + GENE_TAG, + SPECIES_TAG, +) + +mapping = { + CELL_LINE_TAG: "cell lines", + CHEMICAL_TAG: "chemicals", + DISEASE_TAG: "diseases", + GENE_TAG: "genes", + SPECIES_TAG: "species", +} + +prefixed_corpora = [] +all_entity_types = set() +for corpus in corpora: + entity_types = sorted( + set( + [ + mapping[tag] + for tag in corpus.get_entity_type_mapping().values() + ] + ) + ) + all_entity_types.update(set(entity_types)) + + print(f"Entity types in {corpus}: {entity_types}") + + augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy( + entity_types + ) + prefixed_corpora.append( + augmentation_strategy.augment_corpus(corpus) + ) + +corpora = MultiCorpus(prefixed_corpora) +all_entity_types = sorted(all_entity_types) + +# 3. make the tag dictionary from the corpus +tag_dictionary = corpus.make_label_dictionary(label_type="ner") + +# 4. the final model will on default predict all the entity types seen +# in the training corpora, e.g., disease and chemicals here +augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy( + all_entity_types +) + +# 5. initialize embeddings +from flair.embeddings import TransformerWordEmbeddings + +embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings( + "michiyasunaga/BioLinkBERT-base", + layers="-1", + subtoken_pooling="first", + fine_tune=True, + use_context=True, + model_max_length=512, +) + +# 4. 
initialize sequence tagger +from flair.models.prefixed_tagger import PrefixedSequenceTagger + +tagger: SequenceTagger = PrefixedSequenceTagger( + hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_format="BIOES", + tag_type="ner", + use_crf=False, + use_rnn=False, + reproject_embeddings=False, + augmentation_strategy=augmentation_strategy, +) + +# 5. initialize trainer +from flair.trainers import ModelTrainer + +trainer: ModelTrainer = ModelTrainer(tagger, corpus) + +trainer.fine_tune( + base_path="taggers/cdr_nlm_chem", + train_with_dev=False, + max_epochs=16, + learning_rate=2.0e-5, + mini_batch_size=16, + shuffle=False, +) +``` + +## Training HunFlair2 from scratch + +*HunFlair2* uses the `PrefixedSequenceTagger()` class as defined above but adds the following corpora to the training set instead: + +```python +from flair.datasets.biomedical import ( + HUNER_ALL_BIORED, HUNER_GENE_NLM_GENE, + HUNER_GENE_GNORMPLUS, HUNER_ALL_SCAI, + HUNER_CHEMICAL_NLM_CHEM, HUNER_SPECIES_LINNEAUS, + HUNER_SPECIES_S800, HUNER_DISEASE_NCBI +) + +corpora = ( + HUNER_ALL_BIORED(), HUNER_GENE_NLM_GENE(), + HUNER_GENE_GNORMPLUS(), HUNER_ALL_SCAI(), + HUNER_CHEMICAL_NLM_CHEM(), HUNER_SPECIES_LINNEAUS(), + HUNER_SPECIES_S800(), HUNER_DISEASE_NCBI() +) + +``` + +## References + +\[1\] Luo, L., Wei, C. H., Lai, P. T., Leaman, R., Chen, Q., & Lu, Z. (2023). AIONER: all-in-one scheme-based biomedical named entity recognition using deep learning. Bioinformatics, 39(5), btad310. diff --git a/resources/docs/HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md b/resources/docs/HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md new file mode 100644 index 000000000..d435aefcc --- /dev/null +++ b/resources/docs/HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md @@ -0,0 +1,145 @@ +# HunFlair2 Tutorial 4: Customizing linking models + +In this tutorial you'll find information on how to customize the entity linking models according to your needs. +As of now, fine-tuning the models is not supported. 
+ +## Customize dictionary + +All linking models come with a pre-defined pairing of entity type and dictionary, +e.g. "Disease" mentions are linked by default to the [CTD Diseases](https://ctdbase.org/help/diseaseDetailHelp.jsp) vocabulary. +You can change the dictionary to which mentions are linked by following the steps below. +We'll be using the [Human Phenotype Ontology](https://hpo.jax.org/app/) in our example +(download the `hp.json` file from [here](https://hpo.jax.org/app/data/ontology) if you want to follow along). + +First, we build from the original data a Python dictionary that maps names to concept identifiers: + +```python +import json +from collections import defaultdict +with open("hp.json") as fp: + data = json.load(fp) + +nodes = [n for n in data['graphs'][0]['nodes'] if n.get('type') == 'CLASS'] +hpo = defaultdict(list) +for node in nodes: + concept_id = node['id'].replace('http://purl.obolibrary.org/obo/', '') + names = [node['lbl']] + [s['val'] for s in node.get('synonym', [])] + for name in names: + hpo[name].append(concept_id) +``` + +Then we can convert this mapping into a dictionary that can be used by our linking model: + +```python +from flair.datasets.entity_linking import ( + InMemoryEntityLinkingDictionary, + EntityCandidate, +) + +database_name="HPO" + +candidates = [ + EntityCandidate( + concept_id=ids[0], + concept_name=name, + additional_ids=ids[1:], + database_name=database_name, + ) + for name, ids in hpo.items() +] + +dictionary = InMemoryEntityLinkingDictionary( + candidates=candidates, dataset_name=database_name +) +``` + +To use this dictionary you need to initialize a new linker model with it. +See the section below for that. 
+ +## Custom pre-trained model + +You can initialize a new linker model with both a custom model and custom dictionary (see section above) like this: + +```python +pretrained_model="cambridgeltl/SapBERT-from-PubMedBERT-fulltext" +linker = EntityMentionLinker.build( + pretrained_model, + dictionary=dictionary, + hybrid_search=False, + entity_type="disease", + ) +``` + +Omitting the `dictionary` parameter will load the default dictionary for the specified `entity_type`. + +## Customizing Prediction Labels + +In the default setup all linker models output their prediction into the same annotation category *link*. +To record the NEN annotation in separate categories, you can use the `pred_label_type` parameter of the +`predict()` method: + +```python +gene_linker.predict(sentence, pred_label_type="my-genes") +disease_linker.predict(sentence, pred_label_type="my-diseases") + +print("Diseases:") +for disease_tag in sentence.get_labels("my-diseases"): + print(disease_tag) + +print("\nGenes:") +for gene_tag in sentence.get_labels("my-genes"): + print(gene_tag) +``` + +This will output: + +``` +Diseases: +Span[7:9]: "X-linked adrenoleukodystrophy" → MESH:D000326/name=Adrenoleukodystrophy (195.30780029296875) +Span[11:13]: "neurodegenerative disease" → MESH:D019636/name=Neurodegenerative Diseases (201.1804962158203) + +Genes: +Span[4:5]: "ABCD1" → 215/name=ABCD1 (210.89810180664062) +``` + +Moreover, each linker has a pre-defined configuration specifying for which NER annotations it should compute +entity links: + +```python +print(gene_linker.entity_label_types) +print(disease_linker.entity_label_types) +``` + +By default all models will use the *ner* annotation category and apply the linking algorithm for annotations +of the respective entity type: + +```python +{'ner': {'gene'}} +{'ner': {'disease'}} +``` + +You can customize this by using the `entity_label_types` parameter of the `predict()` method: + +```python +sentence = Sentence( + "The mutation in the ABCD1 gene 
causes X-linked adrenoleukodystrophy, " + "a neurodegenerative disease, which is exacerbated by exposure to high " + "levels of mercury in mouse populations." +) + +from flair.models import SequenceTagger + +# Use the disease NER tagger from HunFlair v1 +hunflair1_tagger = SequenceTagger.load("hunflair-disease") +hunflair1_tagger.predict(sentence, label_name="my-diseases") + +# Use the entity_label_types parameter in predict() to specify the annotation category +disease_linker.predict(sentence, entity_label_types="my-diseases") +``` + +If you are using annotated texts with more fine-grained NER annotations, you can specify the +annotation category and the tag types to link using a dictionary that maps each category to a set of tag types. For instance: + +```python +gene_linker.predict(sentence, entity_label_types={"ner": {"gene", "protein"}}) +```