diff --git a/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md b/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md
index 247b3daa1..fc9bc492b 100644
--- a/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md
+++ b/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md
@@ -7,7 +7,7 @@ This tutorial section show you how to train state-of-the-art NER models and other taggers in Flair.
 
 ## Training a named entity recognition (NER) model with transformers
 
-For a state-of-the-art NER sytem you should fine-tune transformer embeddings, and use full document context
+For a state-of-the-art NER system you should fine-tune transformer embeddings, and use full document context
 (see our [FLERT](https://arxiv.org/abs/2011.06993) paper for details).
 
 Use the following script:
diff --git a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
new file mode 100644
index 000000000..56a5fc1d3
--- /dev/null
+++ b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
@@ -0,0 +1,146 @@
+# Train a span classifier
+
+Span classification models address problems such as entity linking, where you have already extracted some
+relevant spans within the {term}`Sentence` and want to predict more fine-grained labels for them.
+
+This tutorial section shows you how to train state-of-the-art span classification models in Flair.
+
+## Training an entity linker (NEL) model with transformers
+
+For a state-of-the-art entity linking system you should fine-tune transformer embeddings, and use full document context
+(see our [FLERT](https://arxiv.org/abs/2011.06993) paper for details).
+
+Use the following script:
+
+```python
+from flair.datasets import ZELDA
+from flair.embeddings import TransformerWordEmbeddings
+from flair.models import SpanClassifier
+from flair.models.entity_linker_model import CandidateGenerator
+from flair.trainers import ModelTrainer
+from flair.nn.decoder import PrototypicalDecoder
+
+
+# 1. get the corpus
+corpus = ZELDA()
+print(corpus)
+
+# 2. what label do we want to predict?
+label_type = 'nel'
+
+# 3. make the label dictionary from the corpus
+label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False)
+print(label_dict)
+
+# 4. initialize fine-tuneable transformer embeddings WITH document context
+embeddings = TransformerWordEmbeddings(
+    model="bert-base-uncased",
+    layers="-1",
+    subtoken_pooling="first",
+    fine_tune=True,
+    use_context=True,
+)
+
+# 5. initialize the span classifier with a prototypical decoder and candidate generator
+tagger = SpanClassifier(
+    embeddings=embeddings,
+    label_dictionary=label_dict,
+    label_type=label_type,
+    decoder=PrototypicalDecoder(
+        num_prototypes=len(label_dict),
+        embeddings_size=embeddings.embedding_length * 2,  # we use "first_last" encoding for spans
+        distance_function="dot_product",
+    ),
+    candidates=CandidateGenerator("zelda"),
+)
+
+# 6. initialize trainer
+trainer = ModelTrainer(tagger, corpus)
+
+# 7. run fine-tuning
+trainer.fine_tune(
+    "resources/taggers/zelda-nel",
+    learning_rate=5.0e-6,
+    mini_batch_size=4,
+    mini_batch_chunk_size=1,  # remove this parameter to speed up computation if you have a big GPU
+)
+```
+
+As you can see, we use [`TransformerWordEmbeddings`](#flair.embeddings.token.TransformerWordEmbeddings) based on [bert-base-uncased](https://huggingface.co/bert-base-uncased) embeddings. We enable fine-tuning and set `use_context` to True.
+We use [Prototypical Networks](https://arxiv.org/abs/1703.05175) to generalize better in the few-shot classification setting.
+Also, we set a `CandidateGenerator` in the [`SpanClassifier`](#flair.models.SpanClassifier).
+This way we limit the classification to a small set of candidates that are chosen depending on the text of the respective span.
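+
+For example, instead of loading the precompiled `"zelda"` candidate lists, you can pass your own mention-to-candidates dictionary (type `Dict[str, List[str]]`). A minimal sketch, with purely illustrative mentions and candidate labels:
+
+```python
+from flair.models.entity_linker_model import CandidateGenerator
+
+# map each mention string to the candidate labels it may link to;
+# spans whose text matches a key are then only classified among these candidates
+generator = CandidateGenerator(
+    candidates={
+        "Washington": ["George_Washington", "Washington_D_C"],
+        "Houston": ["Sam_Houston", "Houston_Texas"],
+    }
+)
+```
+
+Passing `candidates=generator` to the [`SpanClassifier`](#flair.models.SpanClassifier) restricts predictions accordingly; with `backoff=True` (the default), lower-cased versions of the mentions are matched as well.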
+
+## Loading a ColumnCorpus
+
+In case you want to train on a custom named entity linking dataset, you can load it with the [`ColumnCorpus`](#flair.datasets.sequence_labeling.ColumnCorpus) object.
+Most sequence labeling datasets in NLP use some sort of column format in which each line is a word and each column is
+one level of linguistic annotation. See for instance this sentence:
+
+```console
+George B-George_Washington
+Washington I-George_Washington
+went O
+to O
+Washington B-Washington_D_C
+
+Sam B-Sam_Houston
+Houston I-Sam_Houston
+stayed O
+home O
+```
+
+The first column is the word itself, the second contains the BIO-annotated tags that specify the spans to be classified. To read such a
+dataset, define the column structure as a dictionary and instantiate a [`ColumnCorpus`](#flair.datasets.sequence_labeling.ColumnCorpus).
+
+```python
+from flair.data import Corpus
+from flair.datasets import ColumnCorpus
+
+# define columns
+columns = {0: "text", 1: "nel"}
+
+# this is the folder in which train, test and dev files reside
+data_folder = '/path/to/data/folder'
+
+# init a corpus using column format, data folder and the names of the train, dev and test files
+corpus: Corpus = ColumnCorpus(data_folder, columns)
+```
+
+## Constructing a dataset in memory
+
+If you have a pipeline where you need to construct your dataset from a different data source,
+you can always construct a [Corpus](#flair.data.Corpus) with [FlairDatapointDataset](#flair.datasets.FlairDatapointDataset) by hand.
+Let's assume you create a function `create_sentence(datapoint) -> Sentence` that looks somewhat like this:
+
+```python
+from flair.data import Sentence
+
+def create_sentence(datapoint) -> Sentence:
+    tokens = ...  # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary)
+    spans = ...  # create a list of tuples (start_token, end_token, label) from your data structure
+    sentence = Sentence(tokens)
+    for (start, end, label) in spans:
+        sentence[start:end+1].add_label("nel", label)
+    return sentence
+```
+
+Then you can use this function to create a full dataset:
+
+```python
+from flair.data import Corpus
+from flair.datasets import FlairDatapointDataset
+
+def construct_corpus(data):
+    return Corpus(
+        train=FlairDatapointDataset([create_sentence(datapoint) for datapoint in data["train"]]),
+        dev=FlairDatapointDataset([create_sentence(datapoint) for datapoint in data["dev"]]),
+        test=FlairDatapointDataset([create_sentence(datapoint) for datapoint in data["test"]]),
+    )
+```
+
+And use this function to construct a corpus instead of loading a dataset.
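+
+To make this concrete, here is a small self-contained sketch with a toy in-memory data source. The datapoint layout (a dict with `"tokens"` and `"spans"` keys) is made up for this example; adapt `create_sentence` to whatever structure your pipeline actually produces:
+
+```python
+from flair.data import Corpus, Sentence
+from flair.datasets import FlairDatapointDataset
+
+def create_sentence(datapoint) -> Sentence:
+    # assumes each datapoint is a dict with "tokens" and "spans" keys
+    sentence = Sentence(datapoint["tokens"])
+    for (start, end, label) in datapoint["spans"]:
+        sentence[start:end+1].add_label("nel", label)
+    return sentence
+
+# one toy datapoint: "George Washington" links to George_Washington,
+# the second "Washington" links to Washington_D_C
+datapoint = {
+    "tokens": ["George", "Washington", "went", "to", "Washington"],
+    "spans": [(0, 1, "George_Washington"), (4, 4, "Washington_D_C")],
+}
+
+dataset = FlairDatapointDataset([create_sentence(datapoint)])
+corpus = Corpus(train=dataset, dev=dataset, test=dataset)  # toy corpus that reuses one split
+print(corpus.train[0])
+```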
+
+## Combining NEL with Mention Detection
+
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index b551b1e69..890ac9c36 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -471,7 +471,7 @@ def __init__(
         Args:
             path_to_column_file: path to the file with the column-formatted data
             column_name_map: a map specifying the column format
-            column_delimiter: default is to split on any separatator, but you can overwrite for instance with "\t" to split only on tabs
+            column_delimiter: default is to split on any separator, but you can override it, for instance with "\t" to split only on tabs
             comment_symbol: if set, lines that begin with this symbol are treated as comments
             in_memory: If set to True, the dataset is kept in memory as Sentence objects, otherwise does disk reads
             document_separator_token: If provided, sentences that function as document boundaries are so marked
diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py
index 3e46f3a2d..02a80155e 100644
--- a/flair/models/entity_linker_model.py
+++ b/flair/models/entity_linker_model.py
@@ -6,6 +6,7 @@
 from unicodedata import category
 
 import torch
+from deprecated.sphinx import deprecated
 
 import flair.embeddings
 import flair.nn
@@ -18,7 +19,7 @@
 class CandidateGenerator:
     """Given a string, the CandidateGenerator returns possible target classes as candidates."""
 
-    def __init__(self, candidates: Union[str, Dict], backoff: bool = True) -> None:
+    def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool = True) -> None:
         # internal candidate lists of generator
         self.mention_to_candidates_map: Dict = {}
 
@@ -40,6 +41,8 @@ def __init__(self, candidates: Union[str, Dict], backoff: bool = True) -> None:
 
         elif isinstance(candidates, Dict):
             self.mention_to_candidates_map = candidates
+        else:
+            raise ValueError(f"'{candidates}' could not be loaded.")
 
         # if lower casing is enabled, create candidate lists of lower cased versions
         self.backoff = backoff
@@ -92,7 +95,7 @@ def __init__(
         candidates: Optional[CandidateGenerator] = None,
         **classifierargs,
     ) -> None:
-        """Initializes an EntityLinker.
+        """Initializes a SpanClassifier.
 
         Args:
             embeddings: embeddings used to embed the tokens of the sentences.
@@ -232,8 +235,6 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "SpanClassifier":
 
         return cast("SpanClassifier", super().load(model_path=model_path))
 
-def EntityLinker(**classifierargs):
-    from warnings import warn
-
-    warn("The EntityLinker class is deprecated and will be removed in Flair 1.0. Use SpanClassifier instead!")
-    return SpanClassifier(**classifierargs)
+@deprecated(reason="The EntityLinker was renamed to :class:`flair.models.SpanClassifier`.", version="0.12.2")
+class EntityLinker(SpanClassifier):
+    pass
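
For reference, a small sketch of the behavior after the `EntityLinker` change above: the old name is now a real subclass rather than a factory function, so type checks keep working across the rename, and instantiation emits a `DeprecationWarning` through the `Deprecated` package.

```python
from flair.models.entity_linker_model import EntityLinker, SpanClassifier

# EntityLinker is now a true subclass of SpanClassifier, so isinstance and
# issubclass checks keep working after the rename
assert issubclass(EntityLinker, SpanClassifier)

# instantiating EntityLinker still yields a fully functional SpanClassifier,
# but the @deprecated decorator makes the call emit a DeprecationWarning
```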