diff --git a/docs/tutorial/tutorial-basics/entity-mention-linking.md b/docs/tutorial/tutorial-basics/entity-mention-linking.md new file mode 100644 index 0000000000..e153955001 --- /dev/null +++ b/docs/tutorial/tutorial-basics/entity-mention-linking.md @@ -0,0 +1,127 @@ +# Using and creating entity mention linker + +As of Flair 0.14 we ship the [entity mention linker](#flair.models.EntityMentionLinker) - the core framework behind the [Hunflair BioNEN aproach](https://huggingface.co/hunflair)]. + +## Example 1: Printing Entity linking outputs to console + +To illustrate, let's use the example the hunflair models on a biomedical sentence: + +```python +from flair.models import EntityMentionLinker +from flair.nn import Classifier +from flair.tokenization import SciSpacyTokenizer +from flair.data import Sentence + +sentence = Sentence( + "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, " + "a neurodegenerative disease, which is exacerbated by exposure to high " + "levels of mercury in dolphin populations.", + use_tokenizer=SciSpacyTokenizer() +) + +ner_tagger = Classifier.load("hunflair") +ner_tagger.predict(sentence) + +nen_tagger = EntityMentionLinker.load("disease-linker-no-ab3p") +nen_tagger.predict(sentence) + +for tag in sentence.get_labels(): + print(tag) +``` + +```{note} + Here we use the `disease-linker-no-ab3p` model, as it is the simplest model to run. You might get better results by using `disease-linker` instead, + but under the hood ab3p uses an executeable that is only compiled for linux and therefore won't run on every system. + + Analogously to `disease` there are also linker for `chemical`, `species` and `gene` + all work with the `{entity_type}-linker` or `{entity_type}-linker-no-ab3p` naming-schema +``` + + +This should print: +```console +Span[4:5]: "ABCD1" → Gene (0.9509) +Span[7:11]: "X-linked adrenoleukodystrophy" → Disease (0.9872) +Span[7:11]: "X-linked adrenoleukodystrophy" → MESH:D000326/name=Adrenoleukodystrophy (195.30780029296875) +Span[13:15]: "neurodegenerative disease" → Disease (0.8988) +Span[13:15]: "neurodegenerative disease" → MESH:D019636/name=Neurodegenerative Diseases (201.1804962158203) +Span[29:30]: "mercury" → Chemical (0.9484) +Span[31:32]: "dolphin" → Species (0.8071) +``` + +As we can see, the huflair-ner model resolved entities of several types, however for the disease linker, only those of type disease were relevant: +- "X-linked adrenoleukodystrophy" refers to the entity "[Adrenoleukodystrophy](https://id.nlm.nih.gov/mesh/D000326.html)" +- "neurodegenerative disease" refers to the "[Neurodegenerative Diseases](https://id.nlm.nih.gov/mesh/D019636.html)" + + +## Example 2: Structured handling of predictions + +After the predictions, the flair sentence has multiple labels added to the sentence object. +* Each NER prediction adds a span referenced by the `label_type` from the span tagger. +* Each NEL prediction adds one or more labels (up to `k`) to the respective span. Those have the `label_type` from the entity mention linker. +* The NEL labels are ordered by their score. Depending on the exact implementation, it is possible that the order is ascending or descending, however the first one is always the best. + +Therefore, an example to extract the information to a dictionary that could be used for further processing is the following: + +```python +from flair.models import EntityMentionLinker +from flair.nn import Classifier +from flair.tokenization import SciSpacyTokenizer +from flair.data import Sentence + +sentence = Sentence( + "The mutation in the ABCD1 gene causes X-linked adrenoleukodystrophy, " + "a neurodegenerative disease, which is exacerbated by exposure to high " + "levels of mercury in dolphin populations.", + use_tokenizer=SciSpacyTokenizer() +) + +ner_tagger = Classifier.load("hunflair") +ner_tagger.predict(sentence) + +nen_tagger = EntityMentionLinker.load("disease-linker-no-ab3p") + +# top_k = 5 so that a span can have up to 5 labels assigned. +nen_tagger.predict(sentence, top_k=5) + +result_mentions = [] + +for span in sentence.get_spans(ner_tagger.label_type): + + # basic information about the span that is tagged. + span_data = { + "start": span.start_position + sentence.start_position, + "end": span.end_position + sentence.start_position, + "text": span.text, + } + + # add the ner label. We always have only one, so we can use `span.get_label(...)` + span_data["ner_label"] = span.get_label(ner_tagger.label_type).value + + mentions_found = [] + + # since `top_k` is larger than 1, we need to handle multiple nen labels. Therefore we use `span.get_labels(...)` + for label in span.get_labels(nen_tagger.label_type): + mentions_found.append({ + "id": label.value, + "score": label.score, + }) + + # extract the most probable prediction if any prediction is found. + if mentions_found: + span_data["nen_id"] = mentions_found[0]["id"] + else: + span_data["nen_id"] = None + + # add all found candidates with rating if you want to explore more than just the most probable prediction. + span_data["mention_candidates"] = mentions_found + + result_mentions.append(span_data) + +print(result_mentions) +``` + +```{note} + If you need more than the extracted ids, you can use `nen_tagger.dictionary[span_data["nen_id"]]` + to look up the [`flair.data.EntityCandidate`](#flair.data.EntityCandidate) which contains further information. +``` \ No newline at end of file diff --git a/docs/tutorial/tutorial-basics/index.rst b/docs/tutorial/tutorial-basics/index.rst index 6e59970237..b5c3a7f474 100644 --- a/docs/tutorial/tutorial-basics/index.rst +++ b/docs/tutorial/tutorial-basics/index.rst @@ -12,6 +12,7 @@ and showcases various models we ship with Flair. tagging-entities tagging-sentiment entity-linking + entity-mention-linking part-of-speech-tagging other-models how-to-tag-corpus diff --git a/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md b/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md index 247b3daa11..fc9bc492b1 100644 --- a/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md +++ b/docs/tutorial/tutorial-training/how-to-train-sequence-tagger.md @@ -7,7 +7,7 @@ This tutorial section show you how to train state-of-the-art NER models and othe ## Training a named entity recognition (NER) model with transformers -For a state-of-the-art NER sytem you should fine-tune transformer embeddings, and use full document context +For a state-of-the-art NER system you should fine-tune transformer embeddings, and use full document context (see our [FLERT](https://arxiv.org/abs/2011.06993) paper for details). Use the following script: diff --git a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md new file mode 100644 index 0000000000..e1d916ff7d --- /dev/null +++ b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md @@ -0,0 +1,233 @@ +# Train a span classifier + +Span Classification models are used to model problems such as entity linking, where you already have extracted some +relevant spans +within the {term}`Sentence` and want to predict some more fine-grained labels. + +This tutorial section show you how to train models using the [Span Classifier](#flair.models.SpanClassifier) in Flair. + +## Training an entity linker (NEL) model with transformers + +For a state-of-the-art NER sytem you should fine-tune transformer embeddings, and use full document context +(see our [FLERT](https://arxiv.org/abs/2011.06993) paper for details). + +Use the following script: + +```python +from flair.datasets import ZELDA +from flair.embeddings import TransformerWordEmbeddings +from flair.models import SpanClassifier +from flair.models.entity_linker_model import CandidateGenerator +from flair.trainers import ModelTrainer +from flair.nn.decoder import PrototypicalDecoder + + +# 1. get the corpus +corpus = ZELDA() +print(corpus) + +# 2. what label do we want to predict? +label_type = 'nel' + +# 3. make the label dictionary from the corpus +label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=True) +print(label_dict) + +# 4. initialize fine-tuneable transformer embeddings WITH document context +embeddings = TransformerWordEmbeddings( + model="bert-base-uncased", + layers="-1", + subtoken_pooling="first", + fine_tune=True, + use_context=True, +) + +# 5. initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection) +tagger = SpanClassifier( + embeddings=embeddings, + label_dictionary=label_dict, + label_type=label_type, + decoder=PrototypicalDecoder( + num_prototypes=len(label_dict), + embeddings_size=embeddings.embedding_length * 2, # we use "first_last" encoding for spans + distance_function="dot_product", + ), + candidates=CandidateGenerator("zelda"), +) + +# 6. initialize trainer +trainer = ModelTrainer(tagger, corpus) + +# 7. run fine-tuning +trainer.fine_tune( + "resources/taggers/zelda-nel", + learning_rate=5.0e-6, + mini_batch_size=4, + mini_batch_chunk_size=1, # remove this parameter to speed up computation if you have a big GPU +) +``` + +As you can see, we use [`TransformerWordEmbeddings`](#flair.embeddings.token.TransformerWordEmbeddings) based on [bert-base-uncased](https://huggingface.co/bert-base-uncased) embeddings. We enable fine-tuning and set `use_context` to True. +We use [Prototypical Networks](https://arxiv.org/abs/1703.05175), to generalize bettwer in the few-shot classification setting. +Also, we set a `CandidateGenerator` in the [`SpanClassifier`](#flair.models.SpanClassifier). +This way we limit the classification to a small set of candidates that are chosen depending on the text of the respective span. + +## Loading a ColumnCorpus + +In cases you want to train over a custom named entity linking dataset, you can load them with the [`ColumnCorpus`](#flair.datasets.sequence_labeling.ColumnCorpus) object. +Most sequence labeling datasets in NLP use some sort of column format in which each line is a word and each column is +one level of linguistic annotation. See for instance this sentence: + +```console +George B-George_Washington +Washington I-George_Washington +went O +to O +Washington B-Washington_D_C + +Sam B-Sam_Houston +Houston I-Sam_Houston +stayed O +home O +``` + +The first column is the word itself, the second BIO-annotated tags used to specify the spans that will be classified. To read such a +dataset, define the column structure as a dictionary and instantiate a [`ColumnCorpus`](#flair.datasets.sequence_labeling.ColumnCorpus). + +```python +from flair.data import Corpus +from flair.datasets import ColumnCorpus + +# define columns +columns = {0: "text", 1: "nel"} + +# this is the folder in which train, test and dev files reside +data_folder = '/path/to/data/folder' + +# init a corpus using column format, data folder and the names of the train, dev and test files +corpus: Corpus = ColumnCorpus(data_folder, columns) +``` + +## constructing a dataset in memory + +If you have a pipeline where you need to construct your dataset from a different data source, +you can always construct a [Corpus](#flair.data.Corpus) with [FlairDatapointDataset](#flair.datasets.base.FlairDatapointDataset) by hand. +Let's assume you create a function `create_datapoint(datapoint) -> Sentence` that looks somewhat like this: +```python +from flair.data import Sentence + +def create_sentence(datapoint) -> Sentence: + tokens = ... # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary) + spans = ... # create a list of tuples (start_token, end_token, label) from your data structure + sentence = Sentence(tokens) + for (start, end, label) in spans: + sentence[start:end+1].add_label("nel", label) +``` +Then you can use this function to create a full dataset: +```python +from flair.data import Corpus +from flair.datasets import FlairDatapointDataset + +def construct_corpus(data): + return Corpus( + train=FlairDatapointDataset([create_sentence(datapoint for datapoint in data["train"])]), + dev=FlairDatapointDataset([create_sentence(datapoint for datapoint in data["dev"])]), + test=FlairDatapointDataset([create_sentence(datapoint for datapoint in data["test"])]), + ) +``` +And use this to construct a corpus instead of loading a dataset. + + +## Combining NEL with Mention Detection + +often, you don't just want to use a Named Entity Linking model alone, but combine it with a Mention Detection or Named Entity Recognition model. +For this, you can use a [Multitask Model](#flair.models.MultitaskModel) to combine a [SequenceTagger](#flair.models.SequenceTagger) and a [Span Classifier](#flair.models.SpanClassifier). + +```python +from flair.datasets import NER_MULTI_WIKINER, ZELDA +from flair.embeddings import TransformerWordEmbeddings +from flair.models import SequenceTagger, SpanClassifier +from flair.models.entity_linker_model import CandidateGenerator +from flair.trainers import ModelTrainer +from flair.nn import PrototypicalDecoder +from flair.nn.multitask import make_multitask_model_and_corpus + +# 1. get the corpus +ner_corpus = NER_MULTI_WIKINER() +nel_corpus = ZELDA(column_format={0: "text", 2: "ner"}) # need to set the label type to be the same as the ner one + +# --- Embeddings that are shared by both models --- # +shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True) + +ner_label_dict = ner_corpus.make_label_dictionary("ner", add_unk=False) + +ner_model = SequenceTagger( + embeddings=shared_embeddings, + tag_dictionary=ner_label_dict, + tag_type="ner", + use_rnn=False, + use_crf=False, + reproject_embeddings=False, +) + + +nel_label_dict = nel_corpus.make_label_dictionary("ner", add_unk=True) + +nel_model = SpanClassifier( + embeddings=shared_embeddings, + label_dictionary=nel_label_dict, + label_type="ner", + decoder=PrototypicalDecoder( + num_prototypes=len(nel_label_dict), + embeddings_size=shared_embeddings.embedding_length * 2, # we use "first_last" encoding for spans + distance_function="dot_product", + ), + candidates=CandidateGenerator("zelda"), +) + + +# -- Define mapping (which tagger should train on which model) -- # +multitask_model, multicorpus = make_multitask_model_and_corpus( + [ + (ner_model, ner_corpus), + (nel_model, nel_corpus), + ] +) + +# -- Create model trainer and train -- # +trainer = ModelTrainer(multitask_model, multicorpus) +trainer.fine_tune(f"resources/taggers/zelda_with_mention") +``` + +Here, the [make_multitask_model_and_corpus](#flair.nn.multitask.make_multitask_model_and_corpus) method creates a multitask model and a multicorpus where each sub-model is aligned for a sub-corpus. + +### Multitask with aligned training data + +If you have sentences with both annotations for ner and for nel, you might want to use a single corpus for both models. + +This means, that you need to manually the `multitask_id` to the sentences: + +```python +from flair.data import Sentence + +def create_sentence(datapoint) -> Sentence: + tokens = ... # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary) + spans = ... # create a list of tuples (start_token, end_token, label) from your data structure + sentence = Sentence(tokens) + for (start, end, ner_label, nel_label) in spans: + sentence[start:end+1].add_label("ner", ner_label) + sentence[start:end+1].add_label("nel", nel_label) + sentence.add_label("multitask_id", "Task_0") # Task_0 for the NER model + sentence.add_label("multitask_id", "Task_1") # Task_1 for the NEL model +``` + +Then you can run the multitask training script with the exception that you create the [MultitaskModel](#flair.models.MultitaskModel) directly. + +```python +... +multitask_model = MultitaskModel([ner_model, nel_model], use_all_tasks=True) +``` + +Here, setting `use_all_tasks=True` means that we will jointly train on both tasks at the same time. This will save a lot of training time, +as the shared embedding will be calculated once but used twice (once for each model). + diff --git a/docs/tutorial/tutorial-training/index.rst b/docs/tutorial/tutorial-training/index.rst index 70209a3f70..c47f12cf83 100644 --- a/docs/tutorial/tutorial-training/index.rst +++ b/docs/tutorial/tutorial-training/index.rst @@ -13,3 +13,4 @@ This tutorial illustrates how you can train your own state-of-the-art NLP models how-to-load-custom-dataset how-to-train-sequence-tagger how-to-train-text-classifier + how-to-train-span-classifier diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 820f3ba570..0d0c78491c 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -471,7 +471,7 @@ def __init__( Args: path_to_column_file: path to the file with the column-formatted data column_name_map: a map specifying the column format - column_delimiter: default is to split on any separatator, but you can overwrite for instance with "\t" to split only on tabs + column_delimiter: default is to split on any separator, but you can overwrite for instance with "\t" to split only on tabs comment_symbol: if set, lines that begin with this symbol are treated as comments in_memory: If set to True, the dataset is kept in memory as Sentence objects, otherwise does disk reads document_separator_token: If provided, sentences that function as document boundaries are so marked diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py index 838664e402..1d716e7904 100644 --- a/flair/models/entity_linker_model.py +++ b/flair/models/entity_linker_model.py @@ -2,10 +2,11 @@ import re from functools import lru_cache from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union, cast from unicodedata import category import torch +from deprecated.sphinx import deprecated import flair.embeddings import flair.nn @@ -18,7 +19,7 @@ class CandidateGenerator: """Given a string, the CandidateGenerator returns possible target classes as candidates.""" - def __init__(self, candidates: Union[str, Dict], backoff: bool = True) -> None: + def __init__(self, candidates: Union[str, Dict[str, List[str]]], backoff: bool = True) -> None: # internal candidate lists of generator self.mention_to_candidates_map: Dict = {} @@ -40,7 +41,9 @@ def __init__(self, candidates: Union[str, Dict], backoff: bool = True) -> None: elif isinstance(candidates, Dict): self.mention_to_candidates_map = candidates - + else: + raise ValueError(f"'{candidates}' could not be loaded.") + self.mention_to_candidates_map = cast(Dict[str, List[str]], self.mention_to_candidates_map) # if lower casing is enabled, create candidate lists of lower cased versions self.backoff = backoff if self.backoff: @@ -48,14 +51,16 @@ def __init__(self, candidates: Union[str, Dict], backoff: bool = True) -> None: lowercased_mention_to_candidates_map: Dict = {} # go through each mention and its candidates - for mention, candidates in self.mention_to_candidates_map.items(): + for mention, candidates_list in self.mention_to_candidates_map.items(): backoff_mention = self._make_backoff_string(mention) # check if backoff mention already seen. If so, add candidates. Else, create new entry. if backoff_mention in lowercased_mention_to_candidates_map: current_candidates = lowercased_mention_to_candidates_map[backoff_mention] - lowercased_mention_to_candidates_map[backoff_mention] = set(current_candidates).union(candidates) + lowercased_mention_to_candidates_map[backoff_mention] = set(current_candidates).union( + candidates_list + ) else: - lowercased_mention_to_candidates_map[backoff_mention] = candidates + lowercased_mention_to_candidates_map[backoff_mention] = candidates_list # set lowercased version as map self.mention_to_candidates_map = lowercased_mention_to_candidates_map @@ -92,7 +97,7 @@ def __init__( candidates: Optional[CandidateGenerator] = None, **classifierargs, ) -> None: - """Initializes an EntityLinker. + """Initializes an SpanClassifier. Args: embeddings: embeddings used to embed the tokens of the sentences. @@ -232,8 +237,6 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "SpanClassifier": return cast("SpanClassifier", super().load(model_path=model_path)) -def EntityLinker(**classifierargs): - from warnings import warn - - warn("The EntityLinker class is deprecated and will be removed in Flair 1.0. Use SpanClassifier instead!") - return SpanClassifier(**classifierargs) +@deprecated(reason="The EntityLinker was renamed to :class:`flair.models.SpanClassifier`.", version="0.12.2") +class EntityLinker(SpanClassifier): + pass diff --git a/flair/models/entity_mention_linking.py b/flair/models/entity_mention_linking.py index 1bb58e4355..3f39003b77 100644 --- a/flair/models/entity_mention_linking.py +++ b/flair/models/entity_mention_linking.py @@ -927,9 +927,13 @@ def _fetch_model(model_name: str) -> str: bio_base_repo = "hunflair" hf_model_map = { "gene-linker": f"{bio_base_repo}/biosyn-sapbert-bc2gn", + "gene-linker-no-ab3p": f"{bio_base_repo}/biosyn-sapbert-bc2gn-no-ab3p", "disease-linker": f"{bio_base_repo}/biosyn-sapbert-bc5cdr-disease", + "disease-linker-no-ab3p": f"{bio_base_repo}/biosyn-sapbert-bc5cdr-disease-no-ab3p", "chemical-linker": f"{bio_base_repo}/biosyn-sapbert-bc5cdr-chemical", + "chemical-linker-no-ab3p": f"{bio_base_repo}/biosyn-sapbert-bc5cdr-chemical-no-ab3p", "species-linker": f"{bio_base_repo}/sapbert-ncbi-taxonomy", + "species-linker-no-ab3p": f"{bio_base_repo}/sapbert-ncbi-taxonomy-no-ab3p", } if model_name in hf_model_map: @@ -963,7 +967,7 @@ def _get_state_dict(self): @classmethod def build( cls, - model_name_or_path: Union[str, Path], + model_name_or_path: str, label_type: str = "link", dictionary_name_or_path: Optional[Union[str, Path]] = None, hybrid_search: bool = True, @@ -975,11 +979,21 @@ def build( dictionary: Optional[EntityLinkingDictionary] = None, dataset_name: Optional[str] = None, ) -> "EntityMentionLinker": - """Loads a model for biomedical named entity normalization.""" - if not isinstance(model_name_or_path, str): - raise ValueError(f"String matching model name has to be an string (and not {type(model_name_or_path)}") - model_name_or_path = cast(str, model_name_or_path) + """Builds a model for biomedical named entity normalization. + Args: + model_name_or_path: the name to an transformer embedding model on the huggingface hub or "exact-string-match" + label_type: the label-type the predictions should be assigned to + dictionary_name_or_path: the name or path to a dictionary. If the model name is a common biomedical model, the dictionary name is asigned by default. Otherwise you can pass any of "gene", "species", "disease", "chemical" to get the respective biomedical dictionary. + hybrid_search: if True add a character-ngram-tfidf embedding on top of the transformer embedding model. + batch_size: the batch_size used when indexing the dictionary. + similarity_metric: the metric used to compare similarity between two embeddings. + preprocessor: The preprocessor used to preprocess. If None is passed, it used an AD3P processor. + sparse_weight: if hybrid_search is added, the sparse weight will weight the importance of the character-ngram-tfidf embedding. For the common models, this will be overwritten with a specific value. + entity_type: the entity type of the mentions + dictionary: the dictionary provided in memory. If None, the dictionary is loaded from dictionary_name_or_path. + dataset_name: the name to assign the dictionary for reference. + """ if dictionary is None: if dictionary_name_or_path is None or isinstance(dictionary_name_or_path, str): dictionary_name_or_path = cls.__get_dictionary_path( @@ -987,15 +1001,11 @@ def build( ) dictionary = load_dictionary(dictionary_name_or_path, dataset_name=dataset_name) - if isinstance(model_name_or_path, str): - model_name_or_path, entity_type = cls.__get_model_path_and_entity_type( - model_name_or_path=model_name_or_path, - entity_type=entity_type, - hybrid_search=hybrid_search, - ) - else: - assert entity_type is not None, "When using a custom model you must specify `entity_type`" - assert entity_type in ENTITY_TYPES, f"Invalid entity type `{entity_type}! Must be one of: {ENTITY_TYPES}" + model_name_or_path, entity_type = cls.__get_model_path_and_entity_type( + model_name_or_path=model_name_or_path, + entity_type=entity_type, + hybrid_search=hybrid_search, + ) preprocessor = ( preprocessor @@ -1032,10 +1042,10 @@ def build( @staticmethod def __get_model_path_and_entity_type( - model_name_or_path: Union[str, Path], + model_name_or_path: str, entity_type: Optional[str] = None, hybrid_search: bool = False, - ) -> Tuple[Union[str, Path], str]: + ) -> Tuple[str, str]: """Try to figure out what model the user wants.""" if model_name_or_path not in MODELS and model_name_or_path not in ENTITY_TYPES: raise ValueError( @@ -1049,7 +1059,7 @@ def __get_model_path_and_entity_type( if hybrid_search: # load model by entity_type - if isinstance(model_name_or_path, str) and model_name_or_path in ENTITY_TYPES: + if model_name_or_path in ENTITY_TYPES: model_name_or_path = cast(str, model_name_or_path) entity_type = model_name_or_path @@ -1063,25 +1073,20 @@ def __get_model_path_and_entity_type( model_name_or_path, ) model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] + elif model_name_or_path not in PRETRAINED_HYBRID_MODELS: + logger.warning( + "EntityMentionLinker: `hybrid_search=True` but model `%s` was not trained for hybrid search." + " Results may be poor.", + model_name_or_path, + ) + assert ( + entity_type is not None + ), f"For non-hybrid model `{model_name_or_path}` with `hybrid_search=True` you must specify `entity_type`" else: - if model_name_or_path not in PRETRAINED_HYBRID_MODELS: - logger.warning( - "EntityMentionLinker: `hybrid_search=True` but model `%s` was not trained for hybrid search." - " Results may be poor.", - model_name_or_path, - ) - assert ( - entity_type is not None - ), f"For non-hybrid model `{model_name_or_path}` with `hybrid_search=True` you must specify `entity_type`" - else: - model_name_or_path = cast(str, model_name_or_path) - entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] - - else: - if isinstance(model_name_or_path, str): model_name_or_path = cast(str, model_name_or_path) - if model_name_or_path in ENTITY_TYPES: - model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] + entity_type = PRETRAINED_HYBRID_MODELS[model_name_or_path] + elif model_name_or_path in ENTITY_TYPES: + model_name_or_path = ENTITY_TYPE_TO_DENSE_MODEL[model_name_or_path] assert ( entity_type is not None diff --git a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md index e7cf794998..3253e72414 100644 --- a/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md +++ b/resources/docs/HUNFLAIR_TUTORIAL_3_ENTITY_LINKING.md @@ -3,7 +3,7 @@ After adding named entity recognition tags to your sentence, you can run named entity linking on these annotations. ```python -from flair.models.biomedical_entity_linking import EntityMentionLinker +from flair.models import EntityMentionLinker from flair.nn import Classifier from flair.tokenization import SciSpacyTokenizer from flair.data import Sentence @@ -56,7 +56,7 @@ a knowledge base or ontology. We have pre-configured combinations of models and You can also provide your own model and dictionary: ```python -from flair.models.biomedical_entity_linking import EntityMentionLinker +from flair.models import EntityMentionLinker nen_tagger = EntityMentionLinker.build("name_or_path_to_your_model", dictionary_names_or_path="name_or_path_to_your_dictionary")