improve nel tutorial #3369

Merged
merged 15 commits on Mar 6, 2024
127 changes: 127 additions & 0 deletions docs/tutorial/tutorial-basics/entity-mention-linking.md
@@ -0,0 +1,127 @@
# Using and creating an entity mention linker

As of Flair 0.14 we ship the [entity mention linker](#flair.models.EntityMentionLinker) - the core framework behind the [HunFlair BioNEN approach](https://huggingface.co/hunflair).

## Example 1: Printing Entity linking outputs to console

To illustrate, let's run the HunFlair models on an example biomedical sentence:

```python
from flair.models import EntityMentionLinker
from flair.nn import Classifier
from flair.tokenization import SciSpacyTokenizer
from flair.data import Sentence

sentence = Sentence(
    "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, "
    "a neurodegenerative disease, which is exacerbated by exposure to high "
    "levels of mercury in dolphin populations.",
    use_tokenizer=SciSpacyTokenizer()
)

ner_tagger = Classifier.load("hunflair")
ner_tagger.predict(sentence)

nen_tagger = EntityMentionLinker.load("disease-linker-no-ab3p")
nen_tagger.predict(sentence)

for tag in sentence.get_labels():
    print(tag)
```

```{note}
Here we use the `disease-linker-no-ab3p` model, as it is the simplest model to run. You might get better results by using `disease-linker` instead,
but under the hood Ab3P relies on an executable that is only compiled for Linux and therefore won't run on every system.

Analogously to `disease`, there are also linkers for `chemical`, `species` and `gene`.
All of them follow the `{entity_type}-linker` or `{entity_type}-linker-no-ab3p` naming schema.
```
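
For instance, following this naming schema, a linker for chemicals can be loaded the same way (a minimal sketch; like the disease example above, it requires the mentions to be tagged first):

```python
from flair.models import EntityMentionLinker

# load the platform-independent (Ab3P-free) chemical linker,
# following the {entity_type}-linker-no-ab3p naming schema
chemical_linker = EntityMentionLinker.load("chemical-linker-no-ab3p")
```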


This should print:
```console
Span[4:5]: "ABCD1" → Gene (0.9509)
Span[7:11]: "X-linked adrenoleukodystrophy" → Disease (0.9872)
Span[7:11]: "X-linked adrenoleukodystrophy" → MESH:D000326/name=Adrenoleukodystrophy (195.30780029296875)
Span[13:15]: "neurodegenerative disease" → Disease (0.8988)
Span[13:15]: "neurodegenerative disease" → MESH:D019636/name=Neurodegenerative Diseases (201.1804962158203)
Span[29:30]: "mercury" → Chemical (0.9484)
Span[31:32]: "dolphin" → Species (0.8071)
```

As we can see, the HunFlair NER model tagged entities of several types. However, for the disease linker, only those of type disease were relevant:
- "X-linked adrenoleukodystrophy" refers to the entity "[Adrenoleukodystrophy](https://id.nlm.nih.gov/mesh/D000326.html)"
- "neurodegenerative disease" refers to the entity "[Neurodegenerative Diseases](https://id.nlm.nih.gov/mesh/D019636.html)"


## Example 2: Structured handling of predictions

After prediction, the Flair sentence object carries multiple labels:
* Each NER prediction adds a span referenced by the `label_type` of the span tagger.
* Each NEL prediction adds one or more labels (up to `k`) to the respective span. Those carry the `label_type` of the entity mention linker.
* The NEL labels are ordered by their score. Depending on the exact implementation, the order may be ascending or descending, but the first label is always the best one.

The following example extracts this information into a dictionary that can be used for further processing:

```python
from flair.models import EntityMentionLinker
from flair.nn import Classifier
from flair.tokenization import SciSpacyTokenizer
from flair.data import Sentence

sentence = Sentence(
    "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, "
    "a neurodegenerative disease, which is exacerbated by exposure to high "
    "levels of mercury in dolphin populations.",
    use_tokenizer=SciSpacyTokenizer()
)

ner_tagger = Classifier.load("hunflair")
ner_tagger.predict(sentence)

nen_tagger = EntityMentionLinker.load("disease-linker-no-ab3p")

# top_k = 5 so that a span can have up to 5 labels assigned.
nen_tagger.predict(sentence, top_k=5)

result_mentions = []

for span in sentence.get_spans(ner_tagger.label_type):

    # basic information about the span that is tagged
    span_data = {
        "start": span.start_position + sentence.start_position,
        "end": span.end_position + sentence.start_position,
        "text": span.text,
    }

    # add the NER label; there is always exactly one, so we can use `span.get_label(...)`
    span_data["ner_label"] = span.get_label(ner_tagger.label_type).value

    mentions_found = []

    # since `top_k` is larger than 1, we need to handle multiple NEL labels, hence `span.get_labels(...)`
    for label in span.get_labels(nen_tagger.label_type):
        mentions_found.append({
            "id": label.value,
            "score": label.score,
        })

    # extract the most probable prediction if any prediction was found
    if mentions_found:
        span_data["nen_id"] = mentions_found[0]["id"]
    else:
        span_data["nen_id"] = None

    # keep all candidates with their scores if you want to explore more than just the most probable prediction
    span_data["mention_candidates"] = mentions_found

    result_mentions.append(span_data)

print(result_mentions)
```

```{note}
If you need more than the extracted IDs, you can use `nen_tagger.dictionary[span_data["nen_id"]]`
to look up the [`flair.data.EntityCandidate`](#flair.data.EntityCandidate) which contains further information.
```
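
For example, here is a minimal sketch that reuses `nen_tagger` and `result_mentions` from the example above (printing the candidate object directly, so no specific attribute names are assumed):

```python
# look up the EntityCandidate behind each resolved mention
for span_data in result_mentions:
    if span_data["nen_id"] is not None:
        candidate = nen_tagger.dictionary[span_data["nen_id"]]
        print(span_data["text"], "->", candidate)
```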
1 change: 1 addition & 0 deletions docs/tutorial/tutorial-basics/index.rst
@@ -12,6 +12,7 @@ and showcases various models we ship with Flair.
tagging-entities
tagging-sentiment
entity-linking
entity-mention-linking
part-of-speech-tagging
other-models
how-to-tag-corpus
@@ -7,7 +7,7 @@ This tutorial section show you how to train state-of-the-art NER models and othe

## Training a named entity recognition (NER) model with transformers

For a state-of-the-art NER sytem you should fine-tune transformer embeddings, and use full document context
For a state-of-the-art NER system you should fine-tune transformer embeddings, and use full document context
(see our [FLERT](https://arxiv.org/abs/2011.06993) paper for details).

Use the following script:
233 changes: 233 additions & 0 deletions docs/tutorial/tutorial-training/how-to-train-span-classifier.md
@@ -0,0 +1,233 @@
# Train a span classifier

Span classification models are used for problems such as entity linking, where you have already extracted some relevant spans within the {term}`Sentence` and want to predict more fine-grained labels for them.

This tutorial section shows you how to train models using the [Span Classifier](#flair.models.SpanClassifier) in Flair.

## Training an entity linker (NEL) model with transformers

For a state-of-the-art entity linking system you should fine-tune transformer embeddings, and use full document context
(see our [FLERT](https://arxiv.org/abs/2011.06993) paper for details).

Use the following script:

```python
from flair.datasets import ZELDA
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SpanClassifier
from flair.models.entity_linker_model import CandidateGenerator
from flair.trainers import ModelTrainer
from flair.nn.decoder import PrototypicalDecoder


# 1. get the corpus
corpus = ZELDA()
print(corpus)

# 2. what label do we want to predict?
label_type = 'nel'

# 3. make the label dictionary from the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=True)
print(label_dict)

# 4. initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(
    model="bert-base-uncased",
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=True,
)

# 5. initialize span classifier with a prototypical decoder and candidate generator
tagger = SpanClassifier(
    embeddings=embeddings,
    label_dictionary=label_dict,
    label_type=label_type,
    decoder=PrototypicalDecoder(
        num_prototypes=len(label_dict),
        embeddings_size=embeddings.embedding_length * 2,  # we use "first_last" encoding for spans
        distance_function="dot_product",
    ),
    candidates=CandidateGenerator("zelda"),
)

# 6. initialize trainer
trainer = ModelTrainer(tagger, corpus)

# 7. run fine-tuning
trainer.fine_tune(
    "resources/taggers/zelda-nel",
    learning_rate=5.0e-6,
    mini_batch_size=4,
    mini_batch_chunk_size=1,  # remove this parameter to speed up computation if you have a big GPU
)
```

As you can see, we use [`TransformerWordEmbeddings`](#flair.embeddings.token.TransformerWordEmbeddings) based on [bert-base-uncased](https://huggingface.co/bert-base-uncased). We enable fine-tuning and set `use_context` to True.
We use [Prototypical Networks](https://arxiv.org/abs/1703.05175) to generalize better in the few-shot classification setting.
We also set a `CandidateGenerator` in the [`SpanClassifier`](#flair.models.SpanClassifier).
This way we limit the classification to a small set of candidates that are chosen depending on the text of the respective span.
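
Conceptually, the candidate generator maps a mention's surface text to a small set of plausible entity identifiers. A hypothetical sketch (the lookup method name `get_candidates` is an assumption; check the `CandidateGenerator` API for the exact interface):

```python
from flair.models.entity_linker_model import CandidateGenerator

candidates = CandidateGenerator("zelda")

# hypothetical: for an ambiguous mention, only a handful of plausible
# entities are scored instead of the full label dictionary
print(candidates.get_candidates("Washington"))  # method name is an assumption
```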

## Loading a ColumnCorpus

In case you want to train on a custom named entity linking dataset, you can load it with the [`ColumnCorpus`](#flair.datasets.sequence_labeling.ColumnCorpus) object.
Most sequence labeling datasets in NLP use some sort of column format in which each line is a word and each column is
one level of linguistic annotation. See for instance this sentence:

```console
George B-George_Washington
Washington I-George_Washington
went O
to O
Washington B-Washington_D_C

Sam B-Sam_Houston
Houston I-Sam_Houston
stayed O
home O
```

The first column is the word itself, the second contains BIO-annotated tags that specify the spans to be classified. To read such a
dataset, define the column structure as a dictionary and instantiate a [`ColumnCorpus`](#flair.datasets.sequence_labeling.ColumnCorpus).

```python
from flair.data import Corpus
from flair.datasets import ColumnCorpus

# define columns
columns = {0: "text", 1: "nel"}

# this is the folder in which train, test and dev files reside
data_folder = '/path/to/data/folder'

# init a corpus using column format, data folder and the names of the train, dev and test files
corpus: Corpus = ColumnCorpus(data_folder, columns)
```
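
If your splits don't follow Flair's default file-name detection, you can name them explicitly (a minimal sketch; the file names are placeholders for your own):

```python
# a minimal sketch with explicit (placeholder) split file names
corpus: Corpus = ColumnCorpus(
    data_folder,
    columns,
    train_file="train.txt",
    dev_file="dev.txt",
    test_file="test.txt",
)
```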

## Constructing a dataset in memory

If you have a pipeline where you need to construct your dataset from a different data source,
you can always construct a [Corpus](#flair.data.Corpus) with [FlairDatapointDataset](#flair.datasets.base.FlairDatapointDataset) by hand.
Let's assume you create a function `create_sentence(datapoint) -> Sentence` that looks somewhat like this:
```python
from flair.data import Sentence

def create_sentence(datapoint) -> Sentence:
    tokens = ...  # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary)
    spans = ...  # create a list of tuples (start_token, end_token, label) from your data structure
    sentence = Sentence(tokens)
    for (start, end, label) in spans:
        sentence[start:end+1].add_label("nel", label)
    return sentence
```
Then you can use this function to create a full dataset:
```python
from flair.data import Corpus
from flair.datasets import FlairDatapointDataset

def construct_corpus(data):
    return Corpus(
        train=FlairDatapointDataset([create_sentence(datapoint) for datapoint in data["train"]]),
        dev=FlairDatapointDataset([create_sentence(datapoint) for datapoint in data["dev"]]),
        test=FlairDatapointDataset([create_sentence(datapoint) for datapoint in data["test"]]),
    )
```
You can then use this function to construct a corpus instead of loading a predefined dataset.
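
For instance (a minimal usage sketch; `data` stands for whatever structure your pipeline produces, with `"train"`, `"dev"` and `"test"` keys as assumed above):

```python
# a minimal usage sketch, assuming `data` holds your raw datapoints
data = {"train": [...], "dev": [...], "test": [...]}
corpus = construct_corpus(data)
print(corpus)
```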


## Combining NEL with Mention Detection

Often, you don't just want to use a named entity linking model alone, but want to combine it with a mention detection or named entity recognition model.
For this, you can use a [Multitask Model](#flair.models.MultitaskModel) to combine a [SequenceTagger](#flair.models.SequenceTagger) and a [Span Classifier](#flair.models.SpanClassifier).

```python
from flair.datasets import NER_MULTI_WIKINER, ZELDA
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger, SpanClassifier
from flair.models.entity_linker_model import CandidateGenerator
from flair.trainers import ModelTrainer
from flair.nn import PrototypicalDecoder
from flair.nn.multitask import make_multitask_model_and_corpus

# 1. get the corpus
ner_corpus = NER_MULTI_WIKINER()
nel_corpus = ZELDA(column_format={0: "text", 2: "ner"}) # need to set the label type to be the same as the ner one

# --- Embeddings that are shared by both models --- #
shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)

ner_label_dict = ner_corpus.make_label_dictionary("ner", add_unk=False)

ner_model = SequenceTagger(
    embeddings=shared_embeddings,
    tag_dictionary=ner_label_dict,
    tag_type="ner",
    use_rnn=False,
    use_crf=False,
    reproject_embeddings=False,
)


nel_label_dict = nel_corpus.make_label_dictionary("ner", add_unk=True)

nel_model = SpanClassifier(
    embeddings=shared_embeddings,
    label_dictionary=nel_label_dict,
    label_type="ner",
    decoder=PrototypicalDecoder(
        num_prototypes=len(nel_label_dict),
        embeddings_size=shared_embeddings.embedding_length * 2,  # we use "first_last" encoding for spans
        distance_function="dot_product",
    ),
    candidates=CandidateGenerator("zelda"),
)


# -- Define mapping (which tagger should train on which model) -- #
multitask_model, multicorpus = make_multitask_model_and_corpus(
    [
        (ner_model, ner_corpus),
        (nel_model, nel_corpus),
    ]
)

# -- Create model trainer and train -- #
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune(f"resources/taggers/zelda_with_mention")
```

Here, the [make_multitask_model_and_corpus](#flair.nn.multitask.make_multitask_model_and_corpus) method creates a multitask model and a multicorpus where each sub-model is aligned for a sub-corpus.

### Multitask with aligned training data

If you have sentences with both NER and NEL annotations, you might want to use a single corpus for both models.

This means that you need to manually add the `multitask_id` labels to the sentences:

```python
from flair.data import Sentence

def create_sentence(datapoint) -> Sentence:
    tokens = ...  # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary)
    spans = ...  # create a list of tuples (start_token, end_token, ner_label, nel_label) from your data structure
    sentence = Sentence(tokens)
    for (start, end, ner_label, nel_label) in spans:
        sentence[start:end+1].add_label("ner", ner_label)
        sentence[start:end+1].add_label("nel", nel_label)
    sentence.add_label("multitask_id", "Task_0")  # Task_0 for the NER model
    sentence.add_label("multitask_id", "Task_1")  # Task_1 for the NEL model
    return sentence
```

Then you can run the multitask training script as before, with the exception that you create the [MultitaskModel](#flair.models.MultitaskModel) directly:

```python
from flair.models import MultitaskModel

...  # create ner_model and nel_model as in the multitask script above
multitask_model = MultitaskModel([ner_model, nel_model], use_all_tasks=True)
```

Here, setting `use_all_tasks=True` means that we jointly train on both tasks at the same time. This saves a lot of training time,
as the shared embeddings are computed once but used twice (once for each model).
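
A minimal sketch of the remaining training call (assuming `construct_corpus` from the in-memory section above yields sentences carrying both annotation layers and the `multitask_id` labels):

```python
from flair.trainers import ModelTrainer

# train the jointly-created multitask model on a single corpus whose
# sentences carry both "ner" and "nel" annotations plus multitask_id labels
corpus = construct_corpus(data)  # `data` is your own raw data, as above
trainer = ModelTrainer(multitask_model, corpus)
trainer.fine_tune("resources/taggers/zelda_joint", mini_batch_size=4)
```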

1 change: 1 addition & 0 deletions docs/tutorial/tutorial-training/index.rst
@@ -13,3 +13,4 @@ This tutorial illustrates how you can train your own state-of-the-art NLP models
how-to-load-custom-dataset
how-to-train-sequence-tagger
how-to-train-text-classifier
how-to-train-span-classifier
2 changes: 1 addition & 1 deletion flair/datasets/sequence_labeling.py
@@ -471,7 +471,7 @@ def __init__(
Args:
path_to_column_file: path to the file with the column-formatted data
column_name_map: a map specifying the column format
column_delimiter: default is to split on any separatator, but you can overwrite for instance with "\t" to split only on tabs
column_delimiter: default is to split on any separator, but you can overwrite for instance with "\t" to split only on tabs
comment_symbol: if set, lines that begin with this symbol are treated as comments
in_memory: If set to True, the dataset is kept in memory as Sentence objects, otherwise does disk reads
document_separator_token: If provided, sentences that function as document boundaries are so marked