diff --git a/README.md b/README.md
index 6287b3261..3ea845188 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ Flair is:
* **A powerful NLP library.** Flair allows you to apply our state-of-the-art natural language processing (NLP)
models to your text, such as named entity recognition (NER), sentiment analysis, part-of-speech tagging (PoS),
- special support for [biomedical data](/resources/docs/HUNFLAIR.md),
+ special support for [biomedical texts](/resources/docs/HUNFLAIR2.md),
sense disambiguation and classification, with support for a rapidly growing number of languages.
* **A text embedding library.** Flair has simple interfaces that allow you to use and combine different word and
diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py
index b4e3edc1f..76173bac8 100644
--- a/flair/embeddings/token.py
+++ b/flair/embeddings/token.py
@@ -204,7 +204,7 @@ def __init__(
self.__embedding_length: int = precomputed_word_embeddings.vector_size
- vectors = np.row_stack(
+ vectors = np.vstack(
(
precomputed_word_embeddings.vectors,
np.zeros(self.__embedding_length, dtype="float"),
@@ -399,7 +399,7 @@ def __setstate__(self, state: Dict[str, Any]):
state.setdefault("field", None)
if "precomputed_word_embeddings" in state:
precomputed_word_embeddings: KeyedVectors = state.pop("precomputed_word_embeddings")
- vectors = np.row_stack(
+ vectors = np.vstack(
(
precomputed_word_embeddings.vectors,
np.zeros(precomputed_word_embeddings.vector_size, dtype="float"),
diff --git a/flair/models/__init__.py b/flair/models/__init__.py
index ac69e19aa..452c513f9 100644
--- a/flair/models/__init__.py
+++ b/flair/models/__init__.py
@@ -6,6 +6,7 @@
from .multitask_model import MultitaskModel
from .pairwise_classification_model import TextPairClassifier
from .pairwise_regression_model import TextPairRegressor
+from .prefixed_tagger import PrefixedSequenceTagger # This import has to be after SequenceTagger!
from .regexp_tagger import RegexpTagger
from .relation_classifier_model import RelationClassifier
from .relation_extractor_model import RelationExtractor
@@ -26,6 +27,7 @@
"RelationExtractor",
"RegexpTagger",
"SequenceTagger",
+ "PrefixedSequenceTagger",
"TokenClassifier",
"WordTagger",
"FewshotClassifier",
diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py
index 3f39003b7..9d0e28a43 100644
--- a/flair/models/entity_mention_linking.py
+++ b/flair/models/entity_mention_linking.py
@@ -1,6 +1,7 @@
import inspect
import logging
import os
+import platform
import re
import stat
import string
@@ -648,6 +649,8 @@ def p(text: str) -> str:
emb = emb / torch.norm(emb)
dense_embeddings.append(emb.cpu().numpy())
sent.clear_embeddings()
+
+ # empty cuda cache if device is a cuda device
if flair.device.type == "cuda":
torch.cuda.empty_cache()
@@ -681,6 +684,11 @@ def embed(self, entity_mentions: List[str]) -> Dict[str, np.ndarray]:
emb = emb / torch.norm(emb)
query_embeddings["dense"].append(emb.cpu().numpy())
sent.clear_embeddings(self.embeddings["dense"].get_names())
+
+ # Sanity conversion: if flair.device was set as a string, convert to torch.device
+ if isinstance(flair.device, str):
+ flair.device = torch.device(flair.device)
+
if flair.device.type == "cuda":
torch.cuda.empty_cache()
@@ -836,9 +844,13 @@ def extract_entities_mentions(self, sentence: Sentence, entity_label_types: Dict
if any(label in ["diseases", "genes", "species", "chemical"] for label in sentence.annotation_layers):
if not self._warned_legacy_sequence_tagger:
logger.warning(
- "The tagger `Classifier.load('hunflair') is deprecated. Please update to: `Classifier.load('hunflair2')`."
+ "It appears that the sentences have been annotated with HunFlair (version 1). "
+ "Consider using HunFlair2 for improved extraction performance: Classifier.load('hunflair2')."
+ "See https://github.com/flairNLP/flair/blob/master/resources/docs/HUNFLAIR2.md for further "
+ "information."
)
self._warned_legacy_sequence_tagger = True
+
entity_types = {e for sublist in entity_label_types.values() for e in sublist}
entities_mentions = [
label for label in sentence.get_labels() if normalize_entity_type(label.value) in entity_types
@@ -939,6 +951,14 @@ def _fetch_model(model_name: str) -> str:
if model_name in hf_model_map:
model_name = hf_model_map[model_name]
+ if platform.system() == "Windows":
+ logger.warning(
+ "You seem to run your application on a Windows system. Unfortunately, the abbreviation "
+ "resolution of HunFlair2 is only available on Linux/Mac systems. Therefore, a model "
+ "without abbreviation resolution is therefore loaded"
+ )
+ model_name += "-no-ab3p"
+
return hf_download(model_name)
@classmethod
diff --git a/flair/models/multitask_model.py b/flair/models/multitask_model.py
index d127f0e14..bcd9befb3 100644
--- a/flair/models/multitask_model.py
+++ b/flair/models/multitask_model.py
@@ -260,6 +260,14 @@ def _fetch_model(model_name) -> str:
cache_dir = Path("models")
if model_name in model_map:
+ if model_name in ["hunflair", "hunflair-paper", "bioner"]:
+ log.warning(
+ "HunFlair (version 1) is deprecated. Consider using HunFlair2 for improved extraction performance: "
+ "Classifier.load('hunflair2')."
+ "See https://github.com/flairNLP/flair/blob/master/resources/docs/HUNFLAIR2.md for further "
+ "information."
+ )
+
model_name = cached_path(model_map[model_name], cache_dir=cache_dir)
return model_name
diff --git a/flair/models/prefixed_tagger.py b/flair/models/prefixed_tagger.py
index b8c01c50a..a2b3012c2 100644
--- a/flair/models/prefixed_tagger.py
+++ b/flair/models/prefixed_tagger.py
@@ -9,7 +9,8 @@
import flair.data
from flair.data import Corpus, Sentence, Token
from flair.datasets import DataLoader, FlairDatapointDataset
-from flair.models import SequenceTagger
+from flair.file_utils import hf_download
+from flair.models.sequence_tagger_model import SequenceTagger
class PrefixedSentence(Sentence):
@@ -317,3 +318,21 @@ def augment_sentences(
sentences = [sentences]
return [self.augmentation_strategy.augment_sentence(sentence, annotation_layers) for sentence in sentences]
+
+ @staticmethod
+ def _fetch_model(model_name) -> str:
+ huggingface_model_map = {"hunflair2": "hunflair/hunflair2-ner"}
+
+ # check if model name is a valid local file
+ if Path(model_name).exists():
+ model_path = model_name
+
+ # check if model name is a pre-configured hf model
+ elif model_name in huggingface_model_map:
+ hf_model_name = huggingface_model_map[model_name]
+ return hf_download(hf_model_name)
+
+ else:
+ model_path = hf_download(model_name)
+
+ return model_path
diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index 9e1ff7719..1f2a93c68 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -781,6 +781,14 @@ def _fetch_model(model_name) -> str:
elif model_name in hu_model_map:
model_path = cached_path(hu_model_map[model_name], cache_dir=cache_dir)
+ if model_name.startswith("hunflair-"):
+ log.warning(
+ "HunFlair (version 1) is deprecated. Consider using HunFlair2 for improved extraction performance: "
+ "Classifier.load('hunflair2')."
+ "See https://github.com/flairNLP/flair/blob/master/resources/docs/HUNFLAIR2.md for further "
+ "information."
+ )
+
# special handling for the taggers by the @redewiegergabe project (TODO: move to model hub)
elif model_name == "de-historic-indirect":
model_file = flair.cache_root / cache_dir / "indirect" / "final-model.pt"
diff --git a/resources/docs/HUNFLAIR.md b/resources/docs/HUNFLAIR.md
index a85b88c8b..9ec20f7a6 100644
--- a/resources/docs/HUNFLAIR.md
+++ b/resources/docs/HUNFLAIR.md
@@ -8,6 +8,9 @@ NER data sets](HUNFLAIR_CORPORA.md) and comes with a Flair language model ("pubm
FastText embeddings ("pubmed") that were trained on roughly 3 million full texts and about
25 million abstracts from the biomedical domain.
+**Using HunFlair (version 1) is deprecated, please refer to [HunFlair2](HUNFLAIR2.md)
+for an updated and improved version.**
+
Content:
[Quick Start](#quick-start) |
[BioNER-Tool Comparison](#comparison-to-other-biomedical-ner-tools) |
diff --git a/resources/docs/HUNFLAIR2.md b/resources/docs/HUNFLAIR2.md
new file mode 100644
index 000000000..bef0f42ad
--- /dev/null
+++ b/resources/docs/HUNFLAIR2.md
@@ -0,0 +1,137 @@
+# HunFlair2
+
+*HunFlair2* is a state-of-the-art named entity tagger and linker for biomedical texts. It comes with
+models for genes/proteins, chemicals, diseases, species and cell lines. *HunFlair2*
+builds on pretrained domain-specific language models and outperforms other biomedical
+NER tools on unseen corpora.
+
+Content:
+[Quick Start](#quick-start) |
+[Tool Comparison](#comparison-to-other-biomedical-entity-extraction-tools) |
+[Tutorials](#tutorials) |
+[Citing HunFlair](#citing-hunflair2)
+
+## Quick Start
+
+#### Requirements and Installation
+*HunFlair2* is based on Flair 0.13+ and Python 3.8+. If you do not have Python 3.8, install it first.
+Then, in your favorite virtual environment, simply do:
+```
+pip install flair
+```
+
+#### Example 1: Biomedical NER
+Let's run named entity recognition (NER) over an example sentence. All you need to do is
+make a Sentence, load a pre-trained model and use it to predict tags for the sentence:
+```python
+from flair.data import Sentence
+from flair.nn import Classifier
+
+# make a sentence
+sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome")
+
+# load biomedical NER tagger
+tagger = Classifier.load("hunflair2")
+
+# tag sentence
+tagger.predict(sentence)
+```
+Done! The Sentence now has entity annotations. Let's print the entities found by the tagger:
+```python
+for entity in sentence.get_labels():
+ print(entity)
+```
+This should print:
+```console
+Span[0:2]: "Behavioral abnormalities" → Disease (1.0)
+Span[4:5]: "Fmr1" → Gene (1.0)
+Span[6:7]: "Mouse" → Species (1.0)
+Span[9:12]: "Fragile X Syndrome" → Disease (1.0)
+```
+
+#### Example 2: Biomedical NEN
+For improved integration and aggregation from multiple different documents linking / normalizing the entities to
+standardized ontologies or knowledge bases is required. Let's perform entity normalization by using
+specialized models per entity type:
+```python
+from flair.data import Sentence
+from flair.models import EntityMentionLinker
+from flair.nn import Classifier
+
+# make a sentence
+sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome")
+
+# load biomedical NER tagger + predict entities
+tagger = Classifier.load("hunflair2")
+tagger.predict(sentence)
+
+# load gene linker and perform normalization
+gene_linker = EntityMentionLinker.load("gene-linker")
+gene_linker.predict(sentence)
+
+# load disease linker and perform normalization
+disease_linker = EntityMentionLinker.load("disease-linker")
+disease_linker.predict(sentence)
+
+# load species linker and perform normalization
+species_linker = EntityMentionLinker.load("species-linker")
+species_linker.predict(sentence)
+```
+**Note**, the ontologies and knowledge bases used are pre-processed the first time the normalisation is executed,
+which might takes a certain amount of time. All further calls are then based on this pre-processing and run
+much faster.
+
+Done! The Sentence now has entity normalizations. Let's print the entity identifiers found by the linkers:
+```python
+for entity in sentence.get_labels("link"):
+ print(entity)
+```
+This should print:
+```console
+Span[0:2]: "Behavioral abnormalities" → MESH:D001523/name=Mental Disorders (197.9467010498047)
+Span[4:5]: "Fmr1" → 108684022/name=FRAXA (219.9510040283203)
+Span[6:7]: "Mouse" → 10090/name=Mus musculus (213.6201934814453)
+Span[9:12]: "Fragile X Syndrome" → MESH:D005600/name=Fragile X Syndrome (193.7115020751953)
+```
+
+## Comparison to other biomedical entity extraction tools
+Tools for biomedical entity extraction are typically trained and evaluated on single, rather small gold standard
+data sets. However, they are applied "in the wild" to a much larger collection of texts, often varying in
+topic, entity distribution, genre (e.g. patents vs. scientific articles) and text type (e.g. abstract
+vs. full text), which can lead to severe drops in performance.
+
+*HunFlair2* outperforms other biomedical entity extraction tools on corpora not used for training of neither
+*HunFlair2* or any of the competitor tools.
+
+| Corpus | Entity Type | BENT | BERN2 | PubTator Central | SciSpacy | HunFlair |
+|----------------------------------------------------------------------------------------------|-------------|-------|-------|------------------|----------|-------------|
+| [MedMentions](https://github.com/chanzuckerberg/MedMentions) | Chemical | 40.90 | 41.79 | 31.28 | 34.95 | *__51.17__* |
+| | Disease | 45.94 | 47.33 | 41.11 | 40.78 | *__57.27__* |
+| [tmVar (v3)](https://github.com/ncbi/tmVar3?tab=readme-ov-file) | Gene | 0.54 | 43.96 | *__86.02__* | - | 76.75 |
+| [BioID](https://biocreative.bioinformatics.udel.edu/media/store/files/2018/BC6_track1_1.pdf) | Species | 10.35 | 14.35 | *__58.90__* | 37.14 | 49.66 |
+|||||
+| Average | All | 24.43 | 36.86 | 54.33 | 37.61 | *__58.79__* |
+
+All results are F1 scores highlighting end-to-end performance, i.e., named entity recognition and normalization,
+using partial matching of predicted text offsets with the original char offsets of the gold standard data.
+We allow a shift by max one character.
+
+You can find detailed evaluations and discussions in [our paper](https://arxiv.org/abs/2402.12372).
+
+## Tutorials
+We provide a set of quick tutorials to get you started with *HunFlair2*:
+* [Tutorial 1: Tagging biomedical named entities](HUNFLAIR2_TUTORIAL_1_TAGGING.md)
+* [Tutorial 2: Linking biomedical named entities](HUNFLAIR2_TUTORIAL_2_LINKING.md)
+* [Tutorial 3: Training NER models](HUNFLAIR2_TUTORIAL_3_TRAINING_NER.md)
+* [Tutorial 4: Customizing linking](HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md)
+
+## Citing HunFlair2
+Please cite the following paper when using *HunFlair2*:
+~~~
+@article{sanger2024hunflair2,
+ title={HunFlair2 in a cross-corpus evaluation of biomedical named entity recognition and normalization tools},
+ author={S{\"a}nger, Mario and Garda, Samuele and Wang, Xing David and Weber-Genzel, Leon and Droop, Pia and Fuchs, Benedikt and Akbik, Alan and Leser, Ulf},
+ journal={arXiv preprint arXiv:2402.12372},
+ year={2024}
+}
+~~~
diff --git a/resources/docs/HUNFLAIR2_TUTORIAL_1_TAGGING.md b/resources/docs/HUNFLAIR2_TUTORIAL_1_TAGGING.md
new file mode 100644
index 000000000..20d9b643b
--- /dev/null
+++ b/resources/docs/HUNFLAIR2_TUTORIAL_1_TAGGING.md
@@ -0,0 +1,121 @@
+# HunFlair2 - Tutorial 1: Tagging
+
+This is part 1 of the tutorial, in which we show how to use our pre-trained *HunFlair2* models to tag your text.
+
+### Tagging with Pre-trained HunFlair2-Models
+Let's use the pre-trained *HunFlair2* model for biomedical named entity recognition (NER).
+This model was trained over multiple biomedical NER data sets and can recognize 5 different entity types,
+i.e. cell lines, chemicals, disease, gene / proteins and species.
+```python
+from flair.nn import Classifier
+
+tagger = Classifier.load("hunflair2")
+```
+All you need to do is use the predict() method of the tagger on a sentence.
+This will add predicted tags to the tokens in the sentence.
+Lets use a sentence with four named entities:
+```python
+from flair.data import Sentence
+
+sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome")
+
+# predict NER tags
+tagger.predict(sentence)
+
+# print the predicted tags
+for entity in sentence.get_labels():
+ print(entity)
+```
+This should print:
+```console
+Span[0:2]: "Behavioral abnormalities" → Disease (1.0)
+Span[4:5]: "Fmr1" → Gene (1.0)
+Span[6:7]: "Mouse" → Species (1.0)
+Span[9:12]: "Fragile X Syndrome" → Disease (1.0)
+```
+The output indicates that there are two diseases mentioned in the text ("_Behavioral Abnormalities_" and
+"_Fragile X Syndrome_") as well as one gene ("_fmr1_") and one species ("_Mouse_"). For each entity the
+text span in the sentence mention it is given and Label with a value and a score (confidence in the
+prediction). You can also get additional information, such as the position offsets of each entity
+in the sentence in a structured way by calling the `to_dict()` method:
+
+```python
+print(sentence.to_dict())
+```
+This should print:
+```python
+{
+ 'text': 'Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome',
+ 'labels': [],
+ 'entities': [
+ {'text': 'Behavioral abnormalities', 'start_pos': 0, 'end_pos': 24, 'labels': [{'value': 'Disease', 'confidence': 0.9999860525131226}]},
+ {'text': 'Fmr1', 'start_pos': 32, 'end_pos': 36, 'labels': [{'value': 'Gene', 'confidence': 0.9999895095825195}]},
+ {'text': 'Mouse', 'start_pos': 41, 'end_pos': 46, 'labels': [{'value': 'Species', 'confidence': 0.9999873638153076}]},
+ {'text': 'Fragile X Syndrome', 'start_pos': 56, 'end_pos': 74, 'labels': [{'value': 'Disease', 'confidence': 0.9999928871790568}]}
+ ],
+ # further sentence information
+}
+```
+
+### Using a Biomedical Tokenizer
+Tokenization, i.e. separating a text into tokens / words, is an important issue in natural language processing
+in general and biomedical text mining in particular. So far, we used a tokenizer for general domain text.
+This can be unfavourable if applied to biomedical texts.
+
+*HunFlair2* integrates [SciSpaCy](https://allenai.github.io/scispacy/), a library specially designed to work with scientific text.
+To use the library we first have to install it and download one of it's models:
+~~~
+pip install scispacy==0.5.1
+pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_sm-0.5.1.tar.gz
+~~~
+
+To use the tokenizer we just have to pass it as parameter to when instancing a sentence:
+```python
+from flair.tokenization import SciSpacyTokenizer
+
+sentence = Sentence("Behavioral abnormalities in the Fmr1 KO2 Mouse Model of Fragile X Syndrome",
+ use_tokenizer=SciSpacyTokenizer())
+```
+
+### Working with longer Texts
+Often, we are concerned with complete scientific abstracts or full-texts when performing biomedical text mining, e.g.
+```python
+abstract = "Fragile X syndrome (FXS) is a developmental disorder caused by a mutation in the X-linked FMR1 gene, " \
+ "coding for the FMRP protein which is largely involved in synaptic function. FXS patients present several " \
+ "behavioral abnormalities, including hyperactivity, anxiety, sensory hyper-responsiveness, and cognitive " \
+ "deficits. Autistic symptoms, e.g., altered social interaction and communication, are also often observed: " \
+ "FXS is indeed the most common monogenic cause of autism."
+```
+
+To work with complete abstracts or full-text, we first have to split them into separate sentences.
+Again we can apply the integration of the [SciSpaCy](https://allenai.github.io/scispacy/) library:
+```python
+from flair.splitter import SciSpacySentenceSplitter
+
+# initialize the sentence splitter
+splitter = SciSpacySentenceSplitter()
+
+# split text into a list of Sentence objects
+sentences = splitter.split(abstract)
+
+# you can apply the HunFlair tagger directly to this list
+tagger.predict(sentences)
+```
+We can access the annotations of the single sentences by just iterating over the list:
+```python
+for sentence in sentences:
+ print(sentence.to_tagged_string())
+```
+This should print:
+~~~
+Sentence[35]: "Fragile X syndrome (FXS) is a developmental disorder caused by a mutation in the X-linked FMR1 gene, coding for the FMRP protein which is largely involved in synaptic function." \
+ → ["Fragile X syndrome"/Disease, "FXS"/Disease, "developmental disorder"/Disease, "X-linked"/Gene, "FMR1"/Gene, "FMRP"/Gene]
+Sentence[23]: "FXS patients present several behavioral abnormalities, including hyperactivity, anxiety, sensory hyper-responsiveness, and cognitive deficits." \
+ → ["FXS"/Disease, "patients"/Species, "behavioral abnormalities"/Disease, "hyperactivity"/Disease, "anxiety"/Disease, "sensory hyper-responsiveness"/Disease, "cognitive deficits"/Disease]
+Sentence[27]: "Autistic symptoms, e.g., altered social interaction and communication, are also often observed: FXS is indeed the most common monogenic cause of autism." \
+ → ["Autistic symptoms"/Disease, "altered social interaction and communication"/Disease, "FXS"/Disease, "autism"/Disease]
+~~~
+
+### Next
+Now, let us look at how to [link / normalize the entities to standard ontologies](HUNFLAIR2_TUTORIAL_2_LINKING.md)
+in the second tutorial.
diff --git a/resources/docs/HUNFLAIR2_TUTORIAL_2_LINKING.md b/resources/docs/HUNFLAIR2_TUTORIAL_2_LINKING.md
new file mode 100644
index 000000000..182fd0cd9
--- /dev/null
+++ b/resources/docs/HUNFLAIR2_TUTORIAL_2_LINKING.md
@@ -0,0 +1,88 @@
+# HunFlair2 - Tutorial 2: Entity Linking
+
+[Part 1](HUNFLAIR2_TUTORIAL_1_TAGGING.md) of the tutorial, showed how to use our pre-trained *HunFlair2* models to
+tag biomedical entities in your text. However, documents from different biomedical (sub-) fields may use different
+terms to refer to the exact same concept, e.g., “_tumor protein p53_”, “_tumor suppressor p53_”, “_TRP53_” are all
+valid names for the gene “TP53” ([NCBI Gene:7157](https://www.ncbi.nlm.nih.gov/gene/7157)).
+For improved integration and aggregation of entity mentions from multiple different documents linking / normalizing
+the entities to standardized ontologies or knowledge bases is required.
+
+### Linking with pre-trained HunFlair2 Models
+
+After adding named entity recognition tags to your sentence, you can link the entities to standard ontologies
+using distinct, type-specific linking models:
+
+```python
+from flair.models import EntityMentionLinker
+from flair.nn import Classifier
+from flair.data import Sentence
+
+sentence = Sentence(
+ "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, "
+ "a neurodegenerative disease, which is exacerbated by exposure to high "
+ "levels of mercury in mouse populations."
+)
+
+# Tag named entities in the text
+ner_tagger = Classifier.load("hunflair2")
+ner_tagger.predict(sentence)
+
+# Load disease linker and perform disease linking
+disease_linker = EntityMentionLinker.load("disease-linker")
+disease_linker.predict(sentence)
+
+# Load gene linker and perform gene linking
+gene_linker = EntityMentionLinker.load("gene-linker")
+gene_linker.predict(sentence)
+
+# Load chemical linker and perform chemical linking
+chemical_linker = EntityMentionLinker.load("chemical-linker")
+chemical_linker.predict(sentence)
+
+# Load species linker and perform species linking
+species_linker = EntityMentionLinker.load("species-linker")
+species_linker.predict(sentence)
+```
+
+**Note**: the ontologies and knowledge bases used are pre-processed the first time the normalisation is executed,
+which might takes a certain amount of time. All further calls are then based on this pre-processing and run
+much faster.
+
+After running the code we can inspect and output the linked entities via:
+
+```python
+for tag in sentence.get_labels("link"):
+ print(tag)
+```
+
+This should print:
+
+```
+Span[4:5]: "ABCD1" → 215/name=ABCD1 (210.89810180664062)
+Span[7:9]: "X-linked adrenoleukodystrophy" → MESH:D000326/name=Adrenoleukodystrophy (195.30780029296875)
+Span[11:13]: "neurodegenerative disease" → MESH:D019636/name=Neurodegenerative Diseases (201.1804962158203)
+Span[23:24]: "mercury" → MESH:D008628/name=Mercury (220.39199829101562)
+Span[25:26]: "mouse" → 10090/name=Mus musculus (213.6201934814453)
+```
+
+For each entity, the output contains both the NER mention annotations and their ontology identifiers to which
+the mentions were mapped. Moreover, the official name of the entity in the ontology and the similarity score
+of the entity mention and the ontology concept is given. For instance, the official name for the disease
+"_X-linked adrenoleukodystrophy_" is adrenoleukodystrophy. The similarity scores are specific to entity type,
+ontology and linking model used and can therefore only be compared and related to those using the exact same
+setup.
+
+### Overview of pre-trained Entity Linking Models
+
+HunFlair2 comes with the following pre-trained linking models:
+
+| Entity Type | Model Name | Ontology / Dictionary | Linking Algorithm / Base Model (Data Set) |
+| ----------- | ----------------- | ---------------------------------------------------------- | --------------------------------------------------------------------------------------- |
+| Chemical | `chemical-linker` | [CTD Chemicals](https://ctdbase.org/downloads/#allchems) | [SapBERT (BC5CDR)](https://huggingface.co/dmis-lab/biosyn-sapbert-bc5cdr-chemical) |
+| Disease | `disease-linker` | [CTD Diseases](https://ctdbase.org/downloads/#alldiseases) | [SapBERT (NCBI Disease)](https://huggingface.co/dmis-lab/biosyn-sapbert-bc5cdr-disease) |
+| Gene | `gene-linker` | [NCBI Gene (Human)](https://www.ncbi.nlm.nih.gov/gene) | [SapBERT (BC2GN)](https://huggingface.co/dmis-lab/biosyn-sapbert-bc2gn) |
+| Species | `species-linker` | [NCBI Taxonmy](https://www.ncbi.nlm.nih.gov/taxonomy) | [SapBERT (UMLS)](https://huggingface.co/cambridgeltl/SapBERT-from-PubMedBERT-fulltext) |
+
+For detailed information concerning the different models and their integration please refer to [our paper](https://arxiv.org/abs/2402.12372).
+
+If you wish to customize the models and dictionaries please refer to the [dedicated tutorial](HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md).
diff --git a/resources/docs/HUNFLAIR2_TUTORIAL_3_TRAINING_NER.md b/resources/docs/HUNFLAIR2_TUTORIAL_3_TRAINING_NER.md
new file mode 100644
index 000000000..762fe14b9
--- /dev/null
+++ b/resources/docs/HUNFLAIR2_TUTORIAL_3_TRAINING_NER.md
@@ -0,0 +1,223 @@
+# HunFlair2 Tutorial 3: Training NER models
+
+This part of the tutorial shows how you can train your own biomedical named entity recognition models
+using state-of-the-art pretrained Transformers embeddings.
+
+For this tutorial, we assume that you're familiar with the [base types](https://flairnlp.github.io/docs/tutorial-basics/basic-types) of Flair
+and how [transformers_word embeddings](https://flairnlp.github.io/docs/tutorial-training/how-to-train-sequence-tagger).
+You should also know how to [load a corpus](https://flairnlp.github.io/docs/tutorial-training/how-to-load-prepared-dataset).
+
+## Train a biomedical NER model from scratch: single entity type
+
+Here is example code for a biomedical NER model trained on the `NCBI_DISEASE` corpus using Transformer word embeddings.
+This will result in a tagger specialized for a single entity type, i.e. "Disease".
+
+```python
+from flair.datasets import NCBI_DISEASE
+
+# 1. get the corpus
+corpus = NCBI_DISEASE()
+print(corpus)
+
+# 2. make the tag dictionary from the corpus
+tag_dictionary = corpus.make_label_dictionary(label_type="ner", add_unk=False)
+
+# 3. initialize embeddings
+from flair.embeddings import TransformerWordEmbeddings
+
+embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings(
+ "michiyasunaga/BioLinkBERT-base",
+ layers="-1",
+ subtoken_pooling="first",
+ fine_tune=True,
+ use_context=True,
+ model_max_length=512,
+)
+
+# 4. initialize sequence tagger
+from flair.models import SequenceTagger
+
+tagger: SequenceTagger = SequenceTagger(
+ hidden_size=256,
+ embeddings=embeddings,
+ tag_dictionary=tag_dictionary,
+ tag_format="BIOES",
+ tag_type="ner",
+ use_crf=False,
+ use_rnn=False,
+ reproject_embeddings=False,
+)
+
+# 5. initialize trainer
+from flair.trainers import ModelTrainer
+
+trainer: ModelTrainer = ModelTrainer(tagger, corpus)
+
+trainer.fine_tune(
+ base_path="taggers/ncbi-disease",
+ train_with_dev=False,
+ max_epochs=16,
+ learning_rate=2.0e-5,
+ mini_batch_size=16,
+ shuffle=False,
+)
+```
+
+Once the model is trained you can use it to predict tags for new sentences.
+Just call the predict method of the model.
+
+```python
+# load the model you trained
+model = SequenceTagger.load("taggers/ncbi-disease/best-model.pt")
+
+# create example sentence
+from flair.data import Sentence
+sentence = Sentence("Women who smoke 20 cigarettes a day are four times more likely to develop breast cancer.")
+
+# predict tags and print
+model.predict(sentence)
+
+print(sentence.to_tagged_string())
+```
+
+If the model works well, it will correctly tag "breast cancer" as disease in this example:
+
+```
+Women who smoke 20 cigarettes a day are four times more likely to develop breast cancer .
+```
+
+## Train a biomedical NER model: multiple entity types
+
+If you are dealing with multiple entity types, e.g. "Disease" and "Chemicals", you can opt
+to train a single model capable of handling multiple entity types at once.
+This can be achieved by using the `PrefixedSequenceTagger()` class, which implements the method described in \[1\].
+This model uses prompting, i.e. it adds a prefix (hence the name) string in front of specifying the
+entity types requested for tagging: `[Tag , , ...]`.
+Thist is especially usefull for training, as it allows to combine multiple corpora even if they cover different subsets of entity types.
+
+```python
+# 1. get the corpora
+from flair.datasets.biomedical import HUNER_ALL_CDR, HUNER_CHEMICAL_NLM_CHEM
+corpora = (HUNER_ALL_CDR(), HUNER_CHEMICAL_NLM_CHEM())
+
+# 2. add prefixed strings to each corpus by prepending its tagged entity
+# types "[Tag , , ...]"
+from flair.data import MultiCorpus
+from flair.models.prefixed_tagger import EntityTypeTaskPromptAugmentationStrategy
+from flair.datasets.biomedical import (
+ BIGBIO_NER_CORPUS,
+ CELL_LINE_TAG,
+ CHEMICAL_TAG,
+ DISEASE_TAG,
+ GENE_TAG,
+ SPECIES_TAG,
+)
+
+mapping = {
+ CELL_LINE_TAG: "cell lines",
+ CHEMICAL_TAG: "chemicals",
+ DISEASE_TAG: "diseases",
+ GENE_TAG: "genes",
+ SPECIES_TAG: "species",
+}
+
+prefixed_corpora = []
+all_entity_types = set()
+for corpus in corpora:
+ entity_types = sorted(
+ set(
+ [
+ mapping[tag]
+ for tag in corpus.get_entity_type_mapping().values()
+ ]
+ )
+ )
+ all_entity_types.update(set(entity_types))
+
+ print(f"Entity types in {corpus}: {entity_types}")
+
+ augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy(
+ entity_types
+ )
+ prefixed_corpora.append(
+ augmentation_strategy.augment_corpus(corpus)
+ )
+
+corpora = MultiCorpus(prefixed_corpora)
+all_entity_types = sorted(all_entity_types)
+
+# 3. make the tag dictionary from the corpus
+tag_dictionary = corpus.make_label_dictionary(label_type="ner")
+
+# 4. the final model will on default predict all the entity types seen
+# in the training corpora, e.g., disease and chemicals here
+augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy(
+ all_entity_types
+)
+
+# 5. initialize embeddings
+from flair.embeddings import TransformerWordEmbeddings
+
+embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings(
+ "michiyasunaga/BioLinkBERT-base",
+ layers="-1",
+ subtoken_pooling="first",
+ fine_tune=True,
+ use_context=True,
+ model_max_length=512,
+)
+
+# 4. initialize sequence tagger
+from flair.models.prefixed_tagger import PrefixedSequenceTagger
+
+tagger: SequenceTagger = PrefixedSequenceTagger(
+ hidden_size=256,
+ embeddings=embeddings,
+ tag_dictionary=tag_dictionary,
+ tag_format="BIOES",
+ tag_type="ner",
+ use_crf=False,
+ use_rnn=False,
+ reproject_embeddings=False,
+ augmentation_strategy=augmentation_strategy,
+)
+
+# 5. initialize trainer
+from flair.trainers import ModelTrainer
+
+trainer: ModelTrainer = ModelTrainer(tagger, corpus)
+
+trainer.fine_tune(
+ base_path="taggers/cdr_nlm_chem",
+ train_with_dev=False,
+ max_epochs=16,
+ learning_rate=2.0e-5,
+ mini_batch_size=16,
+ shuffle=False,
+)
+```
+
+## Training HunFlair2 from scratch
+
+*HunFlair2* uses the `PrefixedSequenceTagger()` class as defined above but adds the following corpora to the training set instead:
+
+```python
+from flair.datasets.biomedical import (
+ HUNER_ALL_BIORED, HUNER_GENE_NLM_GENE,
+ HUNER_GENE_GNORMPLUS, HUNER_ALL_SCAI,
+ HUNER_CHEMICAL_NLM_CHEM, HUNER_SPECIES_LINNEAUS,
+ HUNER_SPECIES_S800, HUNER_DISEASE_NCBI
+)
+
+corpora = (
+ HUNER_ALL_BIORED(), HUNER_GENE_NLM_GENE(),
+ HUNER_GENE_GNORMPLUS(), HUNER_ALL_SCAI(),
+ HUNER_CHEMICAL_NLM_CHEM(), HUNER_SPECIES_LINNEAUS(),
+ HUNER_SPECIES_S800(), HUNER_DISEASE_NCBI()
+)
+
+```
+
+## References
+
+\[1\] Luo, L., Wei, C. H., Lai, P. T., Leaman, R., Chen, Q., & Lu, Z. (2023). AIONER: all-in-one scheme-based biomedical named entity recognition using deep learning. Bioinformatics, 39(5), btad310.
diff --git a/resources/docs/HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md b/resources/docs/HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md
new file mode 100644
index 000000000..d435aefcc
--- /dev/null
+++ b/resources/docs/HUNFLAIR2_TUTORIAL_4_CUSTOMIZE_LINKING.md
@@ -0,0 +1,145 @@
+# HunFlair2 Tutorial 4: Customizing linking models
+
+In this tutorial you'll find information on how to customize the entity linking models according to your needs.
+As of now, fine-tuning the models is not supported.
+
+## Customize dictionary
+
+All linking models come with a pre-defined pairing of entity type and dictionary,
+e.g. "Disease" mentions are linked by default to the [CTD Diseases](https://ctdbase.org/help/diseaseDetailHelp.jsp).
+You can change the dictionary to which mentions are linked by following the steps below.
+We'll be using the [Human Phenotype Ontology](https://hpo.jax.org/app/) in our example
+(Download the `hp.json` file you find [here](https://hpo.jax.org/app/data/ontology) if you want to follow along).
+
+First we load from the original data a python dictionary mapping names to concept identifiers
+
+```python
+import json
+from collections import defaultdict
+with open("hp.json") as fp:
+ data = json.load(fp)
+
+nodes = [n for n in data['graphs'][0]['nodes'] if n.get('type') == 'CLASS']
+hpo = defaultdict(list)
+for node in nodes:
+ concept_id = node['id'].replace('http://purl.obolibrary.org/obo/', '')
+ names = [node['lbl']] + [s['val'] for s in node.get('synonym', [])]
+ for name in names:
+ hpo[name].append(concept_id)
+```
+
+Then we can convert this mapping into a dictionary that can be used by our linking model:
+
+```python
+from flair.datasets.entity_linking import (
+ InMemoryEntityLinkingDictionary,
+ EntityCandidate,
+)
+
+database_name="HPO"
+
+candidates = [
+ EntityCandidate(
+ concept_id=ids[0],
+ concept_name=name,
+ additional_ids=ids[1:],
+ database_name=database_name,
+ )
+ for name, ids in hpo.items()
+]
+
+dictionary = InMemoryEntityLinkingDictionary(
+ candidates=candidates, dataset_name=database_name
+)
+```
+
+To use this dictionary you need to initialize a new linker model with it.
+See the section below for that.
+
+## Custom pre-trained model
+
+You can initialize a new linker model with both a custom model and custom dictionary (see section above) like this:
+
+```python
+pretrained_model="cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
+linker = EntityMentionLinker.build(
+ pretrained_model,
+ dictionary=dictionary,
+ hybrid_search=False,
+ entity_type="disease",
+ )
+```
+
+Omitting the `dictionary` parameter will load the default dictionary for the specified `entity_type`.
+
+## Customizing Prediction Labels
+
+In the default setup all linker models output their prediction into the same annotation category *link*.
+To record the NEN annotation in separate categories, you can use the `pred_label_type` parameter of the
+`predict()` method:
+
+```python
+gene_linker.predict(sentence, pred_label_type="my-genes")
+disease_linker.predict(sentence, pred_label_type="my-diseases")
+
+print("Diseases:")
+for disease_tag in sentence.get_labels("my-diseases"):
+ print(disease_tag)
+
+print("\nGenes:")
+for gene_tag in sentence.get_labels("my-genes"):
+ print(gene_tag)
+```
+
+This will output:
+
+```
+Diseases:
+Span[7:9]: "X-linked adrenoleukodystrophy" → MESH:D000326/name=Adrenoleukodystrophy (195.30780029296875)
+Span[11:13]: "neurodegenerative disease" → MESH:D019636/name=Neurodegenerative Diseases (201.1804962158203)
+
+Genes:
+Span[4:5]: "ABCD1" → 215/name=ABCD1 (210.89810180664062)
+```
+
+Moreover, each linker has a pre-defined configuration specifying for which NER annotations it should compute
+entity links:
+
+```python
+print(gene_linker.entity_label_types)
+print(disease_linker.entity_label_types)
+```
+
+By default all models will use the *ner* annotation category and apply the linking algorithm for annotations
+of the respective entity type:
+
+```python
+{'ner': {'gene'}}
+{'ner': {'disease'}}
+```
+
+You can customize this by using the `entity_label_types` parameter of the `predict()` method:
+
+```python
+sentence = Sentence(
+ "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, "
+ "a neurodegenerative disease, which is exacerbated by exposure to high "
+ "levels of mercury in mouse populations."
+)
+
+from flair.models import SequenceTagger
+
+# Use disease ner tagger from HunFlair v1
+hunflair1_tagger = SequenceTagger.load("hunflair-disease")
+hunflair1_tagger.predict(sentence, label_name="my-diseases")
+
+# Use the entity_label_types parameter in predict() to specify the annotation category
+disease_linker.predict(sentence, entity_label_types="my-diseases")
+```
+
+If you are using annotated texts with more fine-granular NER annotations you are able to specify the
+annotation category and tag type using a dictionary. For instance:
+
+```python
+gene_linker.predict(sentence, entity_label_types={"ner": {"gene": "protein"}})
+```