Skip to content

Commit

Permalink
add prediction label type for span classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
Benedikt Fuchs authored and helpmefindaname committed Jun 28, 2024
1 parent ca1b90b commit 90480b2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ from flair.nn.multitask import make_multitask_model_and_corpus

# 1. get the corpus
ner_corpus = NER_MULTI_WIKINER()
nel_corpus = ZELDA(column_format={0: "text", 2: "ner"}) # need to set the label type to be the same as the ner one
nel_corpus = ZELDA(column_format={0: "text", 2: "nel"}) # need to set the label type to be the same as the ner one

# --- Embeddings that are shared by both models --- #
shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)
Expand All @@ -171,12 +171,13 @@ ner_model = SequenceTagger(
)


nel_label_dict = nel_corpus.make_label_dictionary("ner", add_unk=True)
nel_label_dict = nel_corpus.make_label_dictionary("nel", add_unk=True)

nel_model = SpanClassifier(
embeddings=shared_embeddings,
label_dictionary=nel_label_dict,
label_type="ner",
label_type="nel",
span_label_type="ner",
decoder=PrototypicalDecoder(
num_prototypes=len(nel_label_dict),
embeddings_size=shared_embeddings.embedding_length * 2, # we use "first_last" encoding for spans
Expand Down
9 changes: 9 additions & 0 deletions flair/models/entity_linker_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
label_dictionary: Dictionary,
pooling_operation: str = "first_last",
label_type: str = "nel",
span_label_type: Optional[str] = None,
candidates: Optional[CandidateGenerator] = None,
**classifierargs,
) -> None:
Expand All @@ -107,6 +108,7 @@ def __init__(
text representation we take the average of the embeddings of the token in the mention.
`first_last` concatenates the embedding of the first and the embedding of the last token.
label_type: name of the label you use.
span_label_type: name of the label you use for inputs of predictions.
candidates: If provided, use a :class:`CandidateGenerator` for prediction candidates.
**classifierargs: The arguments propagated to :meth:`flair.nn.DefaultClassifier.__init__`
"""
Expand All @@ -121,6 +123,7 @@ def __init__(

self.pooling_operation = pooling_operation
self._label_type = label_type
self._span_label_type = span_label_type

cases: Dict[str, Callable[[Span, List[str]], torch.Tensor]] = {
"average": self.emb_mean,
Expand Down Expand Up @@ -153,6 +156,11 @@ def emb_mean(self, span, embedding_names):
return torch.mean(torch.stack([token.get_embedding(embedding_names) for token in span], 0), 0)

def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Span]:
if self._span_label_type is not None:
spans = sentence.get_spans(self._span_label_type)
# only use span label type if there are predictions, otherwise search for output label type (training labels)
if spans:
return spans
return sentence.get_spans(self.label_type)

def _filter_data_point(self, data_point: Sentence) -> bool:
Expand All @@ -170,6 +178,7 @@ def _get_state_dict(self):
"pooling_operation": self.pooling_operation,
"loss_weights": self.weight_dict,
"candidates": self.candidates,
"span_label_type": self._span_label_type,
}
return model_state

Expand Down

0 comments on commit 90480b2

Please sign in to comment.