From 90480b224cad74469e0476e04f3b2f6898e4af06 Mon Sep 17 00:00:00 2001
From: Benedikt Fuchs
Date: Fri, 22 Mar 2024 11:35:06 +0100
Subject: [PATCH] add prediction label type for span classifier

---
 .../tutorial-training/how-to-train-span-classifier.md | 7 ++++---
 flair/models/entity_linker_model.py                   | 9 +++++++++
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
index e1d916ff7d..e5d32cb426 100644
--- a/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
+++ b/docs/tutorial/tutorial-training/how-to-train-span-classifier.md
@@ -154,7 +154,7 @@ from flair.nn.multitask import make_multitask_model_and_corpus
 
 # 1. get the corpus
 ner_corpus = NER_MULTI_WIKINER()
-nel_corpus = ZELDA(column_format={0: "text", 2: "ner"})  # need to set the label type to be the same as the ner one
+nel_corpus = ZELDA(column_format={0: "text", 2: "nel"})  # need to set the label type to be the same as the ner one
 
 # --- Embeddings that are shared by both models --- #
 shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)
@@ -171,12 +171,13 @@ ner_model = SequenceTagger(
 )
 
 
-nel_label_dict = nel_corpus.make_label_dictionary("ner", add_unk=True)
+nel_label_dict = nel_corpus.make_label_dictionary("nel", add_unk=True)
 
 nel_model = SpanClassifier(
     embeddings=shared_embeddings,
     label_dictionary=nel_label_dict,
-    label_type="ner",
+    label_type="nel",
+    span_label_type="ner",
     decoder=PrototypicalDecoder(
         num_prototypes=len(nel_label_dict),
         embeddings_size=shared_embeddings.embedding_length * 2,  # we use "first_last" encoding for spans
diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py
index 1d716e7904..d003a99f56 100644
--- a/flair/models/entity_linker_model.py
+++ b/flair/models/entity_linker_model.py
@@ -94,6 +94,7 @@ def __init__(
         label_dictionary: Dictionary,
         pooling_operation: str = "first_last",
         label_type: str = "nel",
+        span_label_type: Optional[str] = None,
         candidates: Optional[CandidateGenerator] = None,
         **classifierargs,
     ) -> None:
@@ -107,6 +108,7 @@ def __init__(
                 text representation we take the average of the embeddings of the token in the mention. `first_last`
                 concatenates the embedding of the first and the embedding of the last token.
             label_type: name of the label you use.
+            span_label_type: name of the label you use for inputs of predictions.
             candidates: If provided, use a :class:`CandidateGenerator` for prediction candidates.
             **classifierargs: The arguments propagated to :meth:`flair.nn.DefaultClassifier.__init__`
         """
@@ -121,6 +123,7 @@ def __init__(
 
         self.pooling_operation = pooling_operation
         self._label_type = label_type
+        self._span_label_type = span_label_type
 
         cases: Dict[str, Callable[[Span, List[str]], torch.Tensor]] = {
             "average": self.emb_mean,
@@ -153,6 +156,11 @@ def emb_mean(self, span, embedding_names):
         return torch.mean(torch.stack([token.get_embedding(embedding_names) for token in span], 0), 0)
 
     def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]:
+        if self._span_label_type is not None:
+            spans = sentence.get_spans(self._span_label_type)
+            # only use span label type if there are predictions, otherwise search for output label type (training labels)
+            if spans:
+                return spans
         return sentence.get_spans(self.label_type)
 
     def _filter_data_point(self, data_point: Sentence) -> bool:
@@ -170,6 +178,7 @@ def _get_state_dict(self):
             "pooling_operation": self.pooling_operation,
             "loss_weights": self.weight_dict,
             "candidates": self.candidates,
+            "span_label_type": self._span_label_type,
         }
 
         return model_state
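Usage note (not part of the patch): a minimal sketch of how the new span_label_type parameter is intended to behave at prediction time, assuming a pretrained NER tagger and an already trained SpanClassifier are available. The model name and file path below are placeholders. The classifier reuses the spans tagged under the "ner" label type as its input data points and writes its own results under its "nel" label type; if no "ner" spans are present (e.g. during training on gold data), it falls back to the spans of its output label type.

# Sketch only: "flair/ner-english" and the classifier path are placeholder names.
from flair.data import Sentence
from flair.models import SequenceTagger, SpanClassifier

# a tagger that adds "ner" spans to the sentence
ner_tagger = SequenceTagger.load("flair/ner-english")

# a span classifier created with label_type="nel", span_label_type="ner"
linker = SpanClassifier.load("path/to/span-classifier.pt")

sentence = Sentence("Kirk and Spock met on the Enterprise.")

# the predicted NER spans become the data points the span classifier classifies
ner_tagger.predict(sentence)
linker.predict(sentence)

# linked entities are stored under the classifier's own label type
for span in sentence.get_spans("nel"):
    print(span.text, "->", span.get_label("nel").value)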