Fix error when cls_pooling="mean" or cls_pooling="max" for TransformerDocumentEmbeddings #3558
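For context, cls_pooling controls how TransformerDocumentEmbeddings turns the transformer's sub-token states into a single document embedding: "cls" takes a single sub-token position, "mean" averages over all sub-tokens, and "max" takes the element-wise maximum. Before this change, embedding a sentence with cls_pooling="mean" or cls_pooling="max" raised an error because the corresponding pooling helpers never returned their result (see the diff below). A minimal reproduction sketch, assuming flair is installed and the distilbert-base-uncased weights can be downloaded; the model name and sentence are only illustrative:

from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings

# Before this fix, "mean" and "max" failed here while "cls" worked.
embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased", cls_pooling="mean")
sentence = Sentence("Today is a good day.")
embeddings.embed(sentence)
print(sentence.embedding.shape)  # one vector per document after the fix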

30 changes: 14 additions & 16 deletions flair/embeddings/transformer.py
@@ -33,12 +33,7 @@
 
 import flair
 from flair.data import Sentence, Token, log
-from flair.embeddings.base import (
-    DocumentEmbeddings,
-    Embeddings,
-    TokenEmbeddings,
-    register_embeddings,
-)
+from flair.embeddings.base import DocumentEmbeddings, Embeddings, TokenEmbeddings, register_embeddings
 
 SENTENCE_BOUNDARY_TAG: str = "[FLERT]"
 
@@ -198,24 +193,33 @@ def fill_mean_token_embeddings(


 @torch.jit.script_if_tracing
-def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_cls_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
+    return sentence_hidden_states[torch.arange(sentence_hidden_states.shape[0]), sentence_lengths - 1]
+
+
+@torch.jit.script_if_tracing
+def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
 
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0)
 
+    return result
+
 
 @torch.jit.script_if_tracing
-def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
 
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0)
 
+    return result
+
 
 def _legacy_reconstruct_word_ids(
     embedding: "TransformerBaseEmbeddings", flair_tokens: list[list[str]]
@@ -1127,11 +1131,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         if peft_config is not None:
             # add adapters for finetuning
             try:
-                from peft import (
-                    TaskType,
-                    get_peft_model,
-                    prepare_model_for_kbit_training,
-                )
+                from peft import TaskType, get_peft_model, prepare_model_for_kbit_training
             except ImportError:
                 log.error("You cannot use the PEFT finetuning without peft being installed")
                 raise
@@ -1446,9 +1446,7 @@ def forward(
             else:
                 assert sub_token_lengths is not None
                 if self.cls_pooling == "cls":
-                    document_embeddings = sentence_hidden_states[
-                        torch.arange(sentence_hidden_states.shape[0]), sub_token_lengths - 1
-                    ]
+                    document_embeddings = document_cls_pooling(sentence_hidden_states, sub_token_lengths)
                 elif self.cls_pooling == "mean":
                     document_embeddings = document_mean_pooling(sentence_hidden_states, sub_token_lengths)
                 elif self.cls_pooling == "max":
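The root cause is visible in the pooling helpers above: document_mean_pooling and document_max_pooling filled result but never returned it, so both implicitly returned None and the document-embedding step that followed broke. The fix adds the missing return statements, annotates the return types, and moves the existing "cls" indexing into a document_cls_pooling helper that forward now calls. As a side note, the looped mean pooling could also be expressed without the Python loop using a length mask; the sketch below only illustrates the same computation and is not code from this PR:

import torch

def masked_mean_pooling(hidden_states: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden); lengths: (batch,) valid sub-tokens per sentence
    positions = torch.arange(hidden_states.shape[1], device=hidden_states.device)
    mask = (positions[None, :] < lengths[:, None]).unsqueeze(-1).to(hidden_states.dtype)
    # zero out padded positions, sum over the sequence, divide by each sentence's true length
    return (hidden_states * mask).sum(dim=1) / lengths.unsqueeze(-1).to(hidden_states.dtype)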
17 changes: 16 additions & 1 deletion tests/embeddings/test_transformer_document_embeddings.py
@@ -1,4 +1,6 @@
-from flair.data import Dictionary
+import pytest
+
+from flair.data import Dictionary, Sentence
 from flair.embeddings import TransformerDocumentEmbeddings
 from flair.models import TextClassifier
 from flair.nn import Classifier
@@ -37,3 +39,16 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
     # check that context_length and use_context_separator is the same for both
     assert model.embeddings.context_length == loaded_single_task.embeddings.context_length
     assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator
+
+
+@pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
+def test_cls_pooling(cls_pooling):
+    embeddings = TransformerDocumentEmbeddings(
+        model="distilbert-base-uncased",
+        layers="-1",
+        cls_pooling=cls_pooling,
+        allow_long_sentences=True,
+    )
+    sentence = Sentence("Today is a good day.")
+    embeddings.embed(sentence)
+    assert sentence.embedding is not None
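With a development install of flair and pytest available, the new parametrized test can be run on its own using standard pytest selection, e.g. pytest tests/embeddings/test_transformer_document_embeddings.py -k test_cls_pooling; note that the first run downloads the distilbert-base-uncased weights from the Hugging Face hub.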