diff --git a/docs/tutorial/index.rst b/docs/tutorial/index.rst index 9636c33c5..288c9363c 100644 --- a/docs/tutorial/index.rst +++ b/docs/tutorial/index.rst @@ -10,4 +10,5 @@ Tutorials intro tutorial-basics/index tutorial-training/index - tutorial-embeddings/index \ No newline at end of file + tutorial-embeddings/index + tutorial-hunflair2/index \ No newline at end of file diff --git a/docs/tutorial/tutorial-basics/entity-mention-linking.md b/docs/tutorial/tutorial-basics/entity-mention-linking.md index e15395500..409a394f5 100644 --- a/docs/tutorial/tutorial-basics/entity-mention-linking.md +++ b/docs/tutorial/tutorial-basics/entity-mention-linking.md @@ -1,6 +1,7 @@ # Using and creating entity mention linker -As of Flair 0.14 we ship the [entity mention linker](#flair.models.EntityMentionLinker) - the core framework behind the [Hunflair BioNEN aproach](https://huggingface.co/hunflair)]. +As of Flair 0.14 we ship the [entity mention linker](#flair.models.EntityMentionLinker) - the core framework behind the [Hunflair BioNEN approach](https://huggingface.co/hunflair)]. +You can read more at the [Hunflair2 tutorials](project:../tutorial-hunflair2/overview.md) ## Example 1: Printing Entity linking outputs to console @@ -19,7 +20,7 @@ sentence = Sentence( use_tokenizer=SciSpacyTokenizer() ) -ner_tagger = Classifier.load("hunflair") +ner_tagger = Classifier.load("hunflair2") ner_tagger.predict(sentence) nen_tagger = EntityMentionLinker.load("disease-linker-no-ab3p") @@ -31,7 +32,7 @@ for tag in sentence.get_labels(): ```{note} Here we use the `disease-linker-no-ab3p` model, as it is the simplest model to run. You might get better results by using `disease-linker` instead, - but under the hood ab3p uses an executeable that is only compiled for linux and therefore won't run on every system. + but that would require you to install `pyab3p` via `pip install pyab3p`. Analogously to `disease` there are also linker for `chemical`, `species` and `gene` all work with the `{entity_type}-linker` or `{entity_type}-linker-no-ab3p` naming-schema diff --git a/docs/tutorial/tutorial-hunflair2/customize-linking.md b/docs/tutorial/tutorial-hunflair2/customize-linking.md new file mode 100644 index 000000000..bef150782 --- /dev/null +++ b/docs/tutorial/tutorial-hunflair2/customize-linking.md @@ -0,0 +1,146 @@ +# HunFlair2 Tutorial 4: Customizing linking models + +In this tutorial you'll find information on how to customize the entity linking models according to your needs. +As of now, fine-tuning the models is not supported. + +## Customize dictionary + +All linking models come with a pre-defined pairing of entity type and dictionary, +e.g. "Disease" mentions are linked by default to the [CTD Diseases](https://ctdbase.org/help/diseaseDetailHelp.jsp). +You can change the dictionary to which mentions are linked by following the steps below. +We'll be using the [Human Phenotype Ontology](https://hpo.jax.org/app/) in our example +(Download the `hp.json` file you find [here](https://hpo.jax.org/app/data/ontology) if you want to follow along). + +First we load from the original data a python dictionary mapping names to concept identifiers + +```python +import json +from collections import defaultdict +with open("hp.json") as fp: + data = json.load(fp) + +nodes = [n for n in data['graphs'][0]['nodes'] if n.get('type') == 'CLASS'] +hpo = defaultdict(list) +for node in nodes: + concept_id = node['id'].replace('http://purl.obolibrary.org/obo/', '') + names = [node['lbl']] + [s['val'] for s in node.get('synonym', [])] + for name in names: + hpo[name].append(concept_id) +``` + +Then we can convert this mapping into a [`InMemoryEntityLinkingDictionary`](#flair.datasets.entity_linking.InMemoryEntityLinkingDictionary) that can be used by our linking model: + +```python +from flair.datasets.entity_linking import ( + InMemoryEntityLinkingDictionary, + EntityCandidate, +) + +database_name="HPO" + +candidates = [ + EntityCandidate( + concept_id=ids[0], + concept_name=name, + additional_ids=ids[1:], + database_name=database_name, + ) + for name, ids in hpo.items() +] + +dictionary = InMemoryEntityLinkingDictionary( + candidates=candidates, dataset_name=database_name +) +``` + +To use this dictionary you need to initialize a new linker model with it. +See the section below for that. + +## Custom pre-trained model + +You can initialize a new [`EntityMentionLinker`](#flair.models.EntityMentionLinker) with both a custom model and custom dictionary (see section above) like this: + +```python +from flair.models import EntityMentionLinker +pretrained_model="cambridgeltl/SapBERT-from-PubMedBERT-fulltext" +linker = EntityMentionLinker.build( + pretrained_model, + dictionary=dictionary, + hybrid_search=False, + entity_type="disease", + ) +``` + +Omitting the `dictionary` parameter will load the default dictionary for the specified `entity_type`. + +## Customizing Prediction Labels + +In the default setup all linker models output their prediction into the same annotation category *link*. +To record the NEN annotation in separate categories, you can use the `pred_label_type` parameter of the +[`predict()`](#flair.models.EntityMentionLinker.predict) method: + +```python +gene_linker.predict(sentence, pred_label_type="my-genes") +disease_linker.predict(sentence, pred_label_type="my-diseases") + +print("Diseases:") +for disease_tag in sentence.get_labels("my-diseases"): + print(disease_tag) + +print("\nGenes:") +for gene_tag in sentence.get_labels("my-genes"): + print(gene_tag) +``` + +This will output: + +``` +Diseases: +Span[7:9]: "X-linked adrenoleukodystrophy" → MESH:D000326/name=Adrenoleukodystrophy (195.30780029296875) +Span[11:13]: "neurodegenerative disease" → MESH:D019636/name=Neurodegenerative Diseases (201.1804962158203) + +Genes: +Span[4:5]: "ABCD1" → 215/name=ABCD1 (210.89810180664062) +``` + +Moreover, each linker has a pre-defined configuration specifying for which NER annotations it should compute +entity links: + +```python +print(gene_linker.entity_label_types) +print(disease_linker.entity_label_types) +``` + +By default all models will use the *ner* annotation category and apply the linking algorithm for annotations +of the respective entity type: + +```python +{'ner': {'gene'}} +{'ner': {'disease'}} +``` + +You can customize this by using the `entity_label_types` parameter of the [`predict()`](#flair.models.EntityMentionLinker.predict) method: + +```python +sentence = Sentence( + "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, " + "a neurodegenerative disease, which is exacerbated by exposure to high " + "levels of mercury in mouse populations." +) + +from flair.models import SequenceTagger + +# Use disease ner tagger from HunFlair v1 +hunflair1_tagger = SequenceTagger.load("hunflair-disease") +hunflair1_tagger.predict(sentence, label_name="my-diseases") + +# Use the entity_label_types parameter in predict() to specify the annotation category +disease_linker.predict(sentence, entity_label_types="my-diseases") +``` + +If you are using annotated texts with more fine-granular NER annotations you are able to specify the +annotation category and tag type using a dictionary. For instance: + +```python +gene_linker.predict(sentence, entity_label_types={"ner": {"gene": "protein"}}) +``` diff --git a/docs/tutorial/tutorial-hunflair2/index.rst b/docs/tutorial/tutorial-hunflair2/index.rst new file mode 100644 index 000000000..7097f4280 --- /dev/null +++ b/docs/tutorial/tutorial-hunflair2/index.rst @@ -0,0 +1,17 @@ +Tutorial: HunFlair2 +=================== + +*HunFlair2* is a state-of-the-art named entity tagger and linker for biomedical texts. It comes with +models for genes/proteins, chemicals, diseases, species and cell lines. *HunFlair2* +builds on pretrained domain-specific language models and outperforms other biomedical +NER tools on unseen corpora. + +.. toctree:: + :glob: + :maxdepth: 1 + + overview + tagging + linking + training-ner-models + customize-linking diff --git a/docs/tutorial/tutorial-hunflair2/linking.md b/docs/tutorial/tutorial-hunflair2/linking.md new file mode 100644 index 000000000..d6ff1d408 --- /dev/null +++ b/docs/tutorial/tutorial-hunflair2/linking.md @@ -0,0 +1,90 @@ +# HunFlair2 - Tutorial 2: Entity Linking + +[Part 1](project:./tagging.md) of the tutorial, showed how to use our pre-trained *HunFlair2* models to +tag biomedical entities in your text. However, documents from different biomedical (sub-) fields may use different +terms to refer to the exact same concept, e.g., “_tumor protein p53_”, “_tumor suppressor p53_”, “_TRP53_” are all +valid names for the gene “TP53” ([NCBI Gene:7157](https://www.ncbi.nlm.nih.gov/gene/7157)). +For improved integration and aggregation of entity mentions from multiple different documents linking / normalizing +the entities to standardized ontologies or knowledge bases is required. + +## Linking with pre-trained HunFlair2 Models + +After adding named entity recognition tags to your sentence, you can link the entities to standard ontologies +using distinct, type-specific linking models: + +```python +from flair.models import EntityMentionLinker +from flair.nn import Classifier +from flair.data import Sentence + +sentence = Sentence( + "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, " + "a neurodegenerative disease, which is exacerbated by exposure to high " + "levels of mercury in mouse populations." +) + +# Tag named entities in the text +ner_tagger = Classifier.load("hunflair2") +ner_tagger.predict(sentence) + +# Load disease linker and perform disease linking +disease_linker = EntityMentionLinker.load("disease-linker") +disease_linker.predict(sentence) + +# Load gene linker and perform gene linking +gene_linker = EntityMentionLinker.load("gene-linker") +gene_linker.predict(sentence) + +# Load chemical linker and perform chemical linking +chemical_linker = EntityMentionLinker.load("chemical-linker") +chemical_linker.predict(sentence) + +# Load species linker and perform species linking +species_linker = EntityMentionLinker.load("species-linker") +species_linker.predict(sentence) +``` + +```{note} +the ontologies and knowledge bases used are pre-processed the first time the normalisation is executed, +which might takes a certain amount of time. All further calls are then based on this pre-processing and run +much faster. +``` + +After running the code we can inspect and output the linked entities via: + +```python +for tag in sentence.get_labels("link"): + print(tag) +``` + +This should print: + +``` +Span[4:5]: "ABCD1" → 215/name=ABCD1 (210.89810180664062) +Span[7:9]: "X-linked adrenoleukodystrophy" → MESH:D000326/name=Adrenoleukodystrophy (195.30780029296875) +Span[11:13]: "neurodegenerative disease" → MESH:D019636/name=Neurodegenerative Diseases (201.1804962158203) +Span[23:24]: "mercury" → MESH:D008628/name=Mercury (220.39199829101562) +Span[25:26]: "mouse" → 10090/name=Mus musculus (213.6201934814453) +``` + +For each entity, the output contains both the NER mention annotations and their ontology identifiers to which +the mentions were mapped. Moreover, the official name of the entity in the ontology and the similarity score +of the entity mention and the ontology concept is given. For instance, the official name for the disease +"_X-linked adrenoleukodystrophy_" is adrenoleukodystrophy. The similarity scores are specific to entity type, +ontology and linking model used and can therefore only be compared and related to those using the exact same +setup. + +## Overview of pre-trained Entity Linking Models + +HunFlair2 comes with the following pre-trained linking models: + +| Entity Type | Model Name | Ontology / Dictionary | Linking Algorithm / Base Model (Data Set) | +| ----------- | ----------------- | ---------------------------------------------------------- | --------------------------------------------------------------------------------------- | +| Chemical | `chemical-linker` | [CTD Chemicals](https://ctdbase.org/downloads/#allchems) | [SapBERT (BC5CDR)](https://huggingface.co/dmis-lab/biosyn-sapbert-bc5cdr-chemical) | +| Disease | `disease-linker` | [CTD Diseases](https://ctdbase.org/downloads/#alldiseases) | [SapBERT (NCBI Disease)](https://huggingface.co/dmis-lab/biosyn-sapbert-bc5cdr-disease) | +| Gene | `gene-linker` | [NCBI Gene (Human)](https://www.ncbi.nlm.nih.gov/gene) | [SapBERT (BC2GN)](https://huggingface.co/dmis-lab/biosyn-sapbert-bc2gn) | +| Species | `species-linker` | [NCBI Taxonmy](https://www.ncbi.nlm.nih.gov/taxonomy) | [SapBERT (UMLS)](https://huggingface.co/cambridgeltl/SapBERT-from-PubMedBERT-fulltext) | + +For detailed information concerning the different models and their integration please refer to [our paper](https://arxiv.org/abs/2402.12372). + +If you wish to customize the models and dictionaries please refer to the [dedicated tutorial](project:./customize-linking.md). diff --git a/docs/tutorial/tutorial-hunflair2/overview.md b/docs/tutorial/tutorial-hunflair2/overview.md new file mode 100644 index 000000000..bf62d7fa8 --- /dev/null +++ b/docs/tutorial/tutorial-hunflair2/overview.md @@ -0,0 +1,121 @@ +# HunFlair2 - Overview + +*HunFlair2* is a state-of-the-art named entity tagger and linker for biomedical texts. It comes with +models for genes/proteins, chemicals, diseases, species and cell lines. *HunFlair2* +builds on pretrained domain-specific language models and outperforms other biomedical +NER tools on unseen corpora. + +## Quick Start + + +### Example 1: Biomedical NER +Let's run named entity recognition (NER) over an example sentence. All you need to do is +make a Sentence, load a pre-trained model and use it to predict tags for the sentence: +```python +from flair.data import Sentence +from flair.nn import Classifier + +# make a sentence +sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + +# load biomedical NER tagger +tagger = Classifier.load("hunflair2") + +# tag sentence +tagger.predict(sentence) +``` +Done! The [`Sentence`](#flair.data.Sentence) now has entity annotations. Let's print the entities found by the tagger: +```python +for entity in sentence.get_labels(): + print(entity) +``` +This should print: +```console +Span[0:2]: "Behavioral abnormalities" → Disease (1.0) +Span[4:5]: "Fmr1" → Gene (1.0) +Span[6:7]: "Mouse" → Species (1.0) +Span[9:12]: "Fragile X Syndrome" → Disease (1.0) +``` + +### Example 2: Biomedical NEN +For improved integration and aggregation from multiple different documents linking / normalizing the entities to +standardized ontologies or knowledge bases is required. Let's perform entity normalization by using +specialized models per entity type: +```python +from flair.data import Sentence +from flair.models import EntityMentionLinker +from flair.nn import Classifier + +# make a sentence +sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + +# load biomedical NER tagger + predict entities +tagger = Classifier.load("hunflair2") +tagger.predict(sentence) + +# load gene linker and perform normalization +gene_linker = EntityMentionLinker.load("gene-linker") +gene_linker.predict(sentence) + +# load disease linker and perform normalization +disease_linker = EntityMentionLinker.load("disease-linker") +disease_linker.predict(sentence) + +# load species linker and perform normalization +species_linker = EntityMentionLinker.load("species-linker") +species_linker.predict(sentence) +``` + +```{note} +the ontologies and knowledge bases used are pre-processed the first time the normalisation is executed, +which might takes a certain amount of time. All further calls are then based on this pre-processing and run +much faster. +``` + +Done! The Sentence now has entity normalizations. Let's print the entity identifiers found by the linkers: +```python +for entity in sentence.get_labels("link"): + print(entity) +``` +This should print: +```console +Span[0:2]: "Behavioral abnormalities" → MESH:D001523/name=Mental Disorders (197.9467010498047) +Span[4:5]: "Fmr1" → 108684022/name=FRAXA (219.9510040283203) +Span[6:7]: "Mouse" → 10090/name=Mus musculus (213.6201934814453) +Span[9:12]: "Fragile X Syndrome" → MESH:D005600/name=Fragile X Syndrome (193.7115020751953) +``` + +## Comparison to other biomedical entity extraction tools +Tools for biomedical entity extraction are typically trained and evaluated on single, rather small gold standard +data sets. However, they are applied "in the wild" to a much larger collection of texts, often varying in +topic, entity distribution, genre (e.g. patents vs. scientific articles) and text type (e.g. abstract +vs. full text), which can lead to severe drops in performance. + +*HunFlair2* outperforms other biomedical entity extraction tools on corpora not used for training of neither +*HunFlair2* nor any of the competitor tools. + +| Corpus | Entity Type | BENT | BERN2 | PubTator Central | SciSpacy | HunFlair | +|----------------------------------------------------------------------------------------------|-------------|-------|-------|------------------|----------|-------------| +| [MedMentions](https://github.com/chanzuckerberg/MedMentions) | Chemical | 40.90 | 41.79 | 31.28 | 34.95 | *__51.17__* | +| | Disease | 45.94 | 47.33 | 41.11 | 40.78 | *__57.27__* | +| [tmVar (v3)](https://github.com/ncbi/tmVar3?tab=readme-ov-file) | Gene | 0.54 | 43.96 | *__86.02__* | - | 76.75 | +| [BioID](https://biocreative.bioinformatics.udel.edu/media/store/files/2018/BC6_track1_1.pdf) | Species | 10.35 | 14.35 | *__58.90__* | 37.14 | 49.66 | +||||| +| Average | All | 24.43 | 36.86 | 54.33 | 37.61 | *__58.79__* | + +All results are F1 scores highlighting end-to-end performance, i.e., named entity recognition and normalization, +using partial matching of predicted text offsets with the original char offsets of the gold standard data. +We allow a shift by max one character. + +You can find detailed evaluations and discussions in [our paper](https://arxiv.org/abs/2402.12372). + +## Citing HunFlair2 +Please cite the following paper when using *HunFlair2*: +~~~ +@article{sanger2024hunflair2, + title={HunFlair2 in a cross-corpus evaluation of biomedical named entity recognition and normalization tools}, + author={S{\"a}nger, Mario and Garda, Samuele and Wang, Xing David and Weber-Genzel, Leon and Droop, Pia and Fuchs, Benedikt and Akbik, Alan and Leser, Ulf}, + journal={arXiv preprint arXiv:2402.12372}, + year={2024} +} +~~~ diff --git a/docs/tutorial/tutorial-hunflair2/tagging.md b/docs/tutorial/tutorial-hunflair2/tagging.md new file mode 100644 index 000000000..c89be4d59 --- /dev/null +++ b/docs/tutorial/tutorial-hunflair2/tagging.md @@ -0,0 +1,119 @@ +# HunFlair2 - Tutorial 1: Tagging + +This is part 1 of the tutorial, in which we show how to use our pre-trained *HunFlair2* models to tag your text. + +## Tagging with Pre-trained HunFlair2-Models +Let's use the pre-trained *HunFlair2* model for biomedical named entity recognition (NER). +This model was trained over multiple biomedical NER data sets and can recognize 5 different entity types, +i.e. cell lines, chemicals, disease, gene / proteins and species. +```python +from flair.nn import Classifier + +tagger = Classifier.load("hunflair2") +``` +All you need to do is use the predict() method of the tagger on a sentence. +This will add predicted tags to the tokens in the sentence. +Lets use a sentence with four named entities: +```python +from flair.data import Sentence + +sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome") + +# predict NER tags +tagger.predict(sentence) + +# print the predicted tags +for entity in sentence.get_labels(): + print(entity) +``` +This should print: +```console +Span[0:2]: "Behavioral abnormalities" → Disease (1.0) +Span[4:5]: "Fmr1" → Gene (1.0) +Span[6:7]: "Mouse" → Species (1.0) +Span[9:12]: "Fragile X Syndrome" → Disease (1.0) +``` +The output indicates that there are two diseases mentioned in the text ("_Behavioral Abnormalities_" and +"_Fragile X Syndrome_") as well as one gene ("_fmr1_") and one species ("_Mouse_"). For each entity the +text span in the sentence mention it is given and Label with a value and a score (confidence in the +prediction). You can also get additional information, such as the position offsets of each entity +in the sentence in a structured way by calling the `to_dict()` method: + +```python +print(sentence.to_dict()) +``` +This should print: +```python +{ + 'text': 'Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome', + 'labels': [], + 'entities': [ + {'text': 'Behavioral abnormalities', 'start_pos': 0, 'end_pos': 24, 'labels': [{'value': 'Disease', 'confidence': 0.9999860525131226}]}, + {'text': 'Fmr1', 'start_pos': 32, 'end_pos': 36, 'labels': [{'value': 'Gene', 'confidence': 0.9999895095825195}]}, + {'text': 'Mouse', 'start_pos': 41, 'end_pos': 46, 'labels': [{'value': 'Species', 'confidence': 0.9999873638153076}]}, + {'text': 'Fragile X Syndrome', 'start_pos': 56, 'end_pos': 74, 'labels': [{'value': 'Disease', 'confidence': 0.9999928871790568}]} + ], + # further sentence information +} +``` + +## Using a Biomedical Tokenizer +Tokenization, i.e. separating a text into tokens / words, is an important issue in natural language processing +in general and biomedical text mining in particular. So far, we used a tokenizer for general domain text. +This can be unfavourable if applied to biomedical texts. + +*HunFlair2* integrates [SciSpaCy](https://allenai.github.io/scispacy/), a library specially designed to work with scientific text. +To use the library we first have to install it and download one of its models: +~~~ +pip install scispacy==0.5.1 +pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_sm-0.5.1.tar.gz +~~~ + +Then we can use the [`SciSpacyTokenizer`](#flair.tokenization.SciSpacyTokenizer), we just have to pass it as parameter to when instancing a sentence: +```python +from flair.tokenization import SciSpacyTokenizer + +tokenizer = SciSpacyTokenizer() + +sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome", + use_tokenizer=tokenizer) +``` + +## Working with longer Texts +Often, we are concerned with complete scientific abstracts or full-texts when performing biomedical text mining, e.g. +```python +abstract = "Fragile X syndrome (FXS) is a developmental disorder caused by a mutation in the X-linked FMR1 gene, " \ + "coding for the FMRP protein which is largely involved in synaptic function. FXS patients present several " \ + "behavioral abnormalities, including hyperactivity, anxiety, sensory hyper-responsiveness, and cognitive " \ + "deficits. Autistic symptoms, e.g., altered social interaction and communication, are also often observed: " \ + "FXS is indeed the most common monogenic cause of autism." +``` + +To work with complete abstracts or full-text, we first have to split them into separate sentences. +We can apply the [`SciSpacySentenceSplitter`](#flair.splitter.SciSpacySentenceSplitter), an integration of the [SciSpaCy](https://allenai.github.io/scispacy/) library: +```python +from flair.splitter import SciSpacySentenceSplitter + +# initialize the sentence splitter +splitter = SciSpacySentenceSplitter() + +# split text into a list of Sentence objects +sentences = splitter.split(abstract) + +# you can apply the HunFlair tagger directly to this list +tagger.predict(sentences) +``` +We can access the annotations of the single sentences by just iterating over the list: +```python +for sentence in sentences: + print(sentence.to_tagged_string()) +``` +This should print: +~~~ +Sentence[35]: "Fragile X syndrome (FXS) is a developmental disorder caused by a mutation in the X-linked FMR1 gene, coding for the FMRP protein which is largely involved in synaptic function." \ + → ["Fragile X syndrome"/Disease, "FXS"/Disease, "developmental disorder"/Disease, "X-linked"/Gene, "FMR1"/Gene, "FMRP"/Gene] +Sentence[23]: "FXS patients present several behavioral abnormalities, including hyperactivity, anxiety, sensory hyper-responsiveness, and cognitive deficits." \ + → ["FXS"/Disease, "patients"/Species, "behavioral abnormalities"/Disease, "hyperactivity"/Disease, "anxiety"/Disease, "sensory hyper-responsiveness"/Disease, "cognitive deficits"/Disease] +Sentence[27]: "Autistic symptoms, e.g., altered social interaction and communication, are also often observed: FXS is indeed the most common monogenic cause of autism." \ + → ["Autistic symptoms"/Disease, "altered social interaction and communication"/Disease, "FXS"/Disease, "autism"/Disease] +~~~ diff --git a/docs/tutorial/tutorial-hunflair2/training-ner-models.md b/docs/tutorial/tutorial-hunflair2/training-ner-models.md new file mode 100644 index 000000000..f768b0abf --- /dev/null +++ b/docs/tutorial/tutorial-hunflair2/training-ner-models.md @@ -0,0 +1,223 @@ +# HunFlair2 Tutorial 3: Training NER models + +This part of the tutorial shows how you can train your own biomedical named entity recognition models +using state-of-the-art pretrained Transformers embeddings. + +For this tutorial, we assume that you're familiar with the [base types](project:../tutorial-basics/basic-types.md) of Flair +and how [transformers_word embeddings](project:../tutorial-training/how-to-train-sequence-tagger.md). +You should also know how to [load a corpus](project:../tutorial-training/how-to-load-prepared-dataset.md). + +## Train a biomedical NER model from scratch: single entity type + +Here is example code for a biomedical NER model trained on the [`NCBI_DISEASE`](#flair.datasets.biomedical.NCBI_DISEASE) corpus using Transformer word embeddings. +This will result in a tagger specialized for a single entity type, i.e. "Disease". + +```python +from flair.datasets import NCBI_DISEASE + +# 1. get the corpus +corpus = NCBI_DISEASE() +print(corpus) + +# 2. make the tag dictionary from the corpus +tag_dictionary = corpus.make_label_dictionary(label_type="ner", add_unk=False) + +# 3. initialize embeddings +from flair.embeddings import TransformerWordEmbeddings + +embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings( + "michiyasunaga/BioLinkBERT-base", + layers="-1", + subtoken_pooling="first", + fine_tune=True, + use_context=True, + model_max_length=512, +) + +# 4. initialize sequence tagger +from flair.models import SequenceTagger + +tagger: SequenceTagger = SequenceTagger( + hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_format="BIOES", + tag_type="ner", + use_crf=False, + use_rnn=False, + reproject_embeddings=False, +) + +# 5. initialize trainer +from flair.trainers import ModelTrainer + +trainer: ModelTrainer = ModelTrainer(tagger, corpus) + +trainer.fine_tune( + base_path="taggers/ncbi-disease", + train_with_dev=False, + max_epochs=16, + learning_rate=2.0e-5, + mini_batch_size=16, + shuffle=False, +) +``` + +Once the model is trained you can use it to predict tags for new sentences. +Just call the predict method of the model. + +```python +# load the model you trained +model = SequenceTagger.load("taggers/ncbi-disease/best-model.pt") + +# create example sentence +from flair.data import Sentence +sentence = Sentence("Women who smoke 20 cigarettes a day are four times more likely to develop breast cancer.") + +# predict tags and print +model.predict(sentence) + +print(sentence.to_tagged_string()) +``` + +If the model works well, it will correctly tag "breast cancer" as disease in this example: + +``` +Women who smoke 20 cigarettes a day are four times more likely to develop breast cancer . +``` + +## Train a biomedical NER model: multiple entity types + +If you are dealing with multiple entity types, e.g. "Disease" and "Chemicals", you can opt +to train a single model capable of handling multiple entity types at once. +This can be achieved by using the [`PrefixedSequenceTagger()`](#flair.models.PrefixedSequenceTagger) class, which implements the method described in \[1\]. +This model uses prompting, i.e. it adds a prefix (hence the name) string in front of specifying the +entity types requested for tagging: `[Tag , , ...]`. +This is especially useful for training, as it allows to combine multiple corpora even if they cover different subsets of entity types. + +```python +# 1. get the corpora +from flair.datasets.biomedical import HUNER_ALL_CDR, HUNER_CHEMICAL_NLM_CHEM +corpora = (HUNER_ALL_CDR(), HUNER_CHEMICAL_NLM_CHEM()) + +# 2. add prefixed strings to each corpus by prepending its tagged entity +# types "[Tag , , ...]" +from flair.data import MultiCorpus +from flair.models.prefixed_tagger import EntityTypeTaskPromptAugmentationStrategy +from flair.datasets.biomedical import ( + BIGBIO_NER_CORPUS, + CELL_LINE_TAG, + CHEMICAL_TAG, + DISEASE_TAG, + GENE_TAG, + SPECIES_TAG, +) + +mapping = { + CELL_LINE_TAG: "cell lines", + CHEMICAL_TAG: "chemicals", + DISEASE_TAG: "diseases", + GENE_TAG: "genes", + SPECIES_TAG: "species", +} + +prefixed_corpora = [] +all_entity_types = set() +for corpus in corpora: + entity_types = sorted( + set( + [ + mapping[tag] + for tag in corpus.get_entity_type_mapping().values() + ] + ) + ) + all_entity_types.update(set(entity_types)) + + print(f"Entity types in {corpus}: {entity_types}") + + augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy( + entity_types + ) + prefixed_corpora.append( + augmentation_strategy.augment_corpus(corpus) + ) + +corpora = MultiCorpus(prefixed_corpora) +all_entity_types = sorted(all_entity_types) + +# 3. make the tag dictionary from the corpus +tag_dictionary = corpus.make_label_dictionary(label_type="ner") + +# 4. the final model will on default predict all the entity types seen +# in the training corpora, e.g., disease and chemicals here +augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy( + all_entity_types +) + +# 5. initialize embeddings +from flair.embeddings import TransformerWordEmbeddings + +embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings( + "michiyasunaga/BioLinkBERT-base", + layers="-1", + subtoken_pooling="first", + fine_tune=True, + use_context=True, + model_max_length=512, +) + +# 4. initialize sequence tagger +from flair.models.prefixed_tagger import PrefixedSequenceTagger + +tagger: SequenceTagger = PrefixedSequenceTagger( + hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_format="BIOES", + tag_type="ner", + use_crf=False, + use_rnn=False, + reproject_embeddings=False, + augmentation_strategy=augmentation_strategy, +) + +# 5. initialize trainer +from flair.trainers import ModelTrainer + +trainer: ModelTrainer = ModelTrainer(tagger, corpus) + +trainer.fine_tune( + base_path="taggers/cdr_nlm_chem", + train_with_dev=False, + max_epochs=16, + learning_rate=2.0e-5, + mini_batch_size=16, + shuffle=False, +) +``` + +## Training HunFlair2 from scratch + +*HunFlair2* uses the `PrefixedSequenceTagger()` class as defined above but adds the following corpora to the training set instead: + +```python +from flair.datasets.biomedical import ( + HUNER_ALL_BIORED, HUNER_GENE_NLM_GENE, + HUNER_GENE_GNORMPLUS, HUNER_ALL_SCAI, + HUNER_CHEMICAL_NLM_CHEM, HUNER_SPECIES_LINNEAUS, + HUNER_SPECIES_S800, HUNER_DISEASE_NCBI +) + +corpora = ( + HUNER_ALL_BIORED(), HUNER_GENE_NLM_GENE(), + HUNER_GENE_GNORMPLUS(), HUNER_ALL_SCAI(), + HUNER_CHEMICAL_NLM_CHEM(), HUNER_SPECIES_LINNEAUS(), + HUNER_SPECIES_S800(), HUNER_DISEASE_NCBI() +) + +``` + +## References + +\[1\] Luo, L., Wei, C. H., Lai, P. T., Leaman, R., Chen, Q., & Lu, Z. (2023). AIONER: all-in-one scheme-based biomedical named entity recognition using deep learning. Bioinformatics, 39(5), btad310. diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 9d0e28a43..25a67cde8 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -1,14 +1,9 @@ +import importlib.util import inspect import logging -import os -import platform import re -import stat import string -import subprocess -import tempfile from abc import ABC, abstractmethod -from collections import defaultdict from enum import Enum, auto from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast @@ -33,7 +28,7 @@ from flair.datasets.entity_linking import InMemoryEntityLinkingDictionary from flair.embeddings import DocumentEmbeddings, DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings from flair.embeddings.base import load_embeddings -from flair.file_utils import cached_path, hf_download +from flair.file_utils import hf_download from flair.training_utils import Result logger = logging.getLogger("flair") @@ -261,22 +256,19 @@ class Ab3PEntityPreprocessor(EntityPreprocessor): def __init__( self, - ab3p_path: Optional[Path] = None, - word_data_dir: Optional[Path] = None, preprocessor: Optional[EntityPreprocessor] = None, ) -> None: """Creates the mention pre-processor. Args: - ab3p_path: Path to the folder containing the Ab3P implementation - word_data_dir: Path to the word data directory preprocessor: Basic entity preprocessor """ - if ab3p_path is not None and word_data_dir is not None: - self.ab3p_path = ab3p_path - self.word_data_dir = word_data_dir - else: - self.ab3p_path, self.word_data_dir = self._get_biosyn_ab3p_paths() + try: + import pyab3p + except ImportError: + raise ImportError("Please install pyab3p to use the `Ab3PEntityPreprocessor`") + self.ab3p = pyab3p.Ab3p() + self.preprocessor = preprocessor self.abbreviation_dict: Dict[str, Dict[str, str]] = {} @@ -298,7 +290,7 @@ def process_mention(self, entity_mention: str, sentence: Optional[Sentence] = No if self.preprocessor is not None: entity_mention = self.preprocessor.process_entity_name(entity_mention) - # NOTE: Avoid emtpy string if mentions are just punctutations (e.g. `-` or `(`) + # NOTE: Avoid emtpy string if mentions are just punctuations (e.g. `-` or `(`) entity_mention = original if len(entity_mention) == 0 else entity_mention return entity_mention @@ -311,53 +303,6 @@ def process_entity_name(self, entity_name: str) -> str: return entity_name - def _get_biosyn_ab3p_paths(self) -> Tuple[Path, Path]: - data_dir = flair.cache_root / "ab3p_biosyn" - if not data_dir.exists(): - data_dir.mkdir(parents=True) - - word_data_dir = data_dir / "word_data" - if not word_data_dir.exists(): - word_data_dir.mkdir() - - ab3p_path = self._download_biosyn_ab3p(data_dir, word_data_dir) - - return ab3p_path, word_data_dir - - def _download_biosyn_ab3p(self, data_dir: Path, word_data_dir: Path) -> Path: - """Downloads the Ab3P tool and all necessary data files.""" - # Download word data for Ab3P if not already downloaded - ab3p_url = "https://raw.githubusercontent.com/dmis-lab/BioSyn/master/Ab3P/WordData/" - - ab3p_files = [ - "Ab3P_prec.dat", - "Lf1chSf", - "SingTermFreq.dat", - "cshset_wrdset3.ad", - "cshset_wrdset3.ct", - "cshset_wrdset3.ha", - "cshset_wrdset3.nm", - "cshset_wrdset3.str", - "hshset_Lf1chSf.ad", - "hshset_Lf1chSf.ha", - "hshset_Lf1chSf.nm", - "hshset_Lf1chSf.str", - "hshset_stop.ad", - "hshset_stop.ha", - "hshset_stop.nm", - "hshset_stop.str", - "stop", - ] - for file in ab3p_files: - cached_path(ab3p_url + file, word_data_dir) - - # Download Ab3P executable - ab3p_path = cached_path("https://github.com/dmis-lab/BioSyn/raw/master/Ab3P/identify_abbr", data_dir) - - # Make Ab3P executable - ab3p_path.chmod(ab3p_path.stat().st_mode | stat.S_IXUSR) - return ab3p_path - def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict[str, Dict[str, str]]: """Processes the given sentences with the Ab3P tool. @@ -376,62 +321,13 @@ def _build_abbreviation_dict(self, sentences: List[flair.data.Sentence]) -> Dict Returns: abbreviation_dict: abbreviations and their resolution detected in each input sentence """ - abbreviation_dict: Dict = defaultdict(dict) - - # Create a temp file which holds the sentences we want to process with Ab3P - with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as temp_file: - for sentence in sentences: - temp_file.write(sentence.to_original_text() + "\n") - temp_file.flush() - - # Temporarily create path file in the current working directory for Ab3P - with open(os.path.join(os.getcwd(), "path_Ab3P"), "w") as path_file: - path_file.write(str(self.word_data_dir) + "/\n") - - # Run Ab3P with the temp file containing the dataset - # https://pylint.pycqa.org/en/latest/user_guide/messages/warning/subprocess-run-check.html - try: - result = subprocess.run( - [self.ab3p_path, temp_file.name], - capture_output=True, - check=True, - ) - except subprocess.CalledProcessError: - logger.error( - """The abbreviation resolver Ab3P could not be run on your system. To ensure maximum accuracy, please - install Ab3P yourself. See https://github.com/ncbi-nlp/Ab3P""" - ) - else: - line = result.stdout.decode("utf-8") - if "Path file for type cshset does not exist!" in line: - logger.error( - "Error when using Ab3P for abbreviation resolution. A file named path_Ab3p needs to exist in your current directory containing the path to the WordData directory for Ab3P to work!" - ) - elif "Cannot open" in line or "failed to open" in line: - logger.error( - "Error when using Ab3P for abbreviation resolution. Could not open the WordData directory for Ab3P!" - ) - - lines = line.split("\n") - cur_sentence = None - for line in lines: - if len(line.split("|")) == 3: - if cur_sentence is None: - continue - - sf, lf, _ = line.split("|") - sf = sf.strip() - lf = lf.strip() - abbreviation_dict[cur_sentence][sf] = lf + abbreviation_dict: Dict[str, Dict[str, str]] = {} - elif len(line.strip()) > 0: - cur_sentence = line - else: - cur_sentence = None - - finally: - # remove the path file - os.remove(os.path.join(os.getcwd(), "path_Ab3P")) + for sentence in sentences: + sentence_text = sentence.to_original_text() + abbreviation_dict[sentence_text] = { + abbr_out.short_form: abbr_out.long_form for abbr_out in self.ab3p.get_abbrs(sentence_text) + } return abbreviation_dict @@ -951,11 +847,12 @@ def _fetch_model(model_name: str) -> str: if model_name in hf_model_map: model_name = hf_model_map[model_name] - if platform.system() == "Windows": + if not model_name.endswith("-no-ab3p") and importlib.util.find_spec("pyab3p") is None: logger.warning( - "You seem to run your application on a Windows system. Unfortunately, the abbreviation " - "resolution of HunFlair2 is only available on Linux/Mac systems. Therefore, a model " - "without abbreviation resolution is therefore loaded" + "'pyab3p' is not found, switching to a model without abbreviation resolution. " + "This might impact the model performance. To reach full performance, please install" + "pyab3p by running:" + " pip install pyab3p" ) model_name += "-no-ab3p" @@ -969,7 +866,6 @@ def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs) -> "Entity label_type = state["label_type"] dictionary = InMemoryEntityLinkingDictionary.from_state(state["dictionary"]) batch_size = state.get("batch_size", 128) - return cls(candidate_generator, preprocessor, entity_label_types, label_type, dictionary, batch_size=batch_size) def _get_state_dict(self): diff --git a/flair/training_utils.py b/flair/training_utils.py index c57c74210..2c3ce9d5f 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -1,6 +1,5 @@ import logging import random -import sys from collections import defaultdict from enum import Enum from functools import reduce @@ -350,10 +349,7 @@ def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionar def log_line(log): - if sys.version_info >= (3, 8): - log.info("-" * 100, stacklevel=3) - else: - log.info("-" * 100) + log.info("-" * 100, stacklevel=3) def add_file_handler(log, output_file): diff --git a/pyproject.toml b/pyproject.toml index 62da09bfc..8d2181d28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ filterwarnings = [ "ignore:pkg_resources", # huggingface has deprecated calls. 'ignore:Deprecated call to `pkg_resources', # huggingface has deprecated calls. 'ignore:distutils Version classes are deprecated.', # faiss uses deprecated distutils. + 'ignore:`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.', # transformers calls deprecated hf_hub ] markers = [ "integration", @@ -46,7 +47,7 @@ ignore_errors = true [tool.ruff] line-length = 120 -target-version = "py37" +target-version = "py38" [tool.ruff.lint] #select = ["ALL"] # Uncommit to autofix all the things diff --git a/requirements-dev.txt b/requirements-dev.txt index d3e818d70..61d45acf8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,3 +11,5 @@ types-dataclasses>=0.6.6 types-Deprecated>=1.2.9.2 types-requests>=2.28.11.17 types-tabulate>=0.9.0.2 +pyab3p +transformers!=4.40.1,!=4.40.0 \ No newline at end of file diff --git a/resources/docs/HUNFLAIR2.md b/resources/docs/HUNFLAIR2.md index bef0f42ad..6f2c1474b 100644 --- a/resources/docs/HUNFLAIR2.md +++ b/resources/docs/HUNFLAIR2.md @@ -14,7 +14,7 @@ NER tools on unseen corpora. ## Quick Start #### Requirements and Installation -*HunFlair2* is based on Flair 0.13+ and Python 3.8+. If you do not have Python 3.8, install it first. +*HunFlair2* is based on Flair 0.14+ and Python 3.8+. If you do not have Python 3.8, install it first. Then, in your favorite virtual environment, simply do: ``` pip install flair diff --git a/tests/embeddings/test_transformer_word_embeddings.py b/tests/embeddings/test_transformer_word_embeddings.py index 13e933016..a2ca3716a 100644 --- a/tests/embeddings/test_transformer_word_embeddings.py +++ b/tests/embeddings/test_transformer_word_embeddings.py @@ -215,8 +215,8 @@ def test_layoutlm_embeddings_with_context_warns_user(self): sentence[2].add_metadata("bbox", (0, 12, 10, 22)) with pytest.warns(UserWarning) as record: TransformerWordEmbeddings("microsoft/layoutlm-base-uncased", layers="-1,-2,-3,-4", use_context=True) - assert len(record) == 1 - assert "microsoft/layoutlm" in record[0].message.args[0] + assert len(record) > 0 + assert "microsoft/layoutlm" in record[-1].message.args[0] @pytest.mark.integration() def test_layoutlmv3_without_image_embeddings_fails(self):