Merge pull request #3558 from fkdosilovic/bug-fix-3552-document-embeddings-with-cls-pooling

Fix error when cls_pooling="mean" or cls_pooling="max" for TransformerDocumentEmbeddings
helpmefindaname authored Nov 29, 2024
2 parents a9862df + aa2ebdc commit 9d2c1ed
Showing 2 changed files with 30 additions and 17 deletions.
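
As context for the diff below, here is a minimal usage sketch of the option being fixed, mirroring the new test added in this commit (the distilbert-base-uncased checkpoint, layer selection, and example sentence are taken from that test):

from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings

# Before this fix, cls_pooling="mean" or cls_pooling="max" raised an error while
# embedding; with it, all three pooling modes produce a document embedding.
embeddings = TransformerDocumentEmbeddings(
    model="distilbert-base-uncased",
    layers="-1",
    cls_pooling="mean",  # also accepts "cls" and "max"
    allow_long_sentences=True,
)

sentence = Sentence("Today is a good day.")
embeddings.embed(sentence)
print(sentence.embedding.shape)  # one vector per document, e.g. 768-dim for this checkpoint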
flair/embeddings/transformer.py (14 additions, 16 deletions)
@@ -33,12 +33,7 @@
 
 import flair
 from flair.data import Sentence, Token, log
-from flair.embeddings.base import (
-    DocumentEmbeddings,
-    Embeddings,
-    TokenEmbeddings,
-    register_embeddings,
-)
+from flair.embeddings.base import DocumentEmbeddings, Embeddings, TokenEmbeddings, register_embeddings
 
 SENTENCE_BOUNDARY_TAG: str = "[FLERT]"
 
@@ -198,24 +193,33 @@ def fill_mean_token_embeddings(
 
 
 @torch.jit.script_if_tracing
-def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_cls_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
+    return sentence_hidden_states[torch.arange(sentence_hidden_states.shape[0]), sentence_lengths - 1]
+
+
+@torch.jit.script_if_tracing
+def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
 
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0)
 
     return result
 
 
 @torch.jit.script_if_tracing
-def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
 
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0)
 
     return result
@@ -1127,11 +1131,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         if peft_config is not None:
             # add adapters for finetuning
             try:
-                from peft import (
-                    TaskType,
-                    get_peft_model,
-                    prepare_model_for_kbit_training,
-                )
+                from peft import TaskType, get_peft_model, prepare_model_for_kbit_training
             except ImportError:
                 log.error("You cannot use the PEFT finetuning without peft being installed")
                 raise
Expand Down Expand Up @@ -1446,9 +1446,7 @@ def forward(
else:
assert sub_token_lengths is not None
if self.cls_pooling == "cls":
document_embeddings = sentence_hidden_states[
torch.arange(sentence_hidden_states.shape[0]), sub_token_lengths - 1
]
document_embeddings = document_cls_pooling(sentence_hidden_states, sub_token_lengths)
elif self.cls_pooling == "mean":
document_embeddings = document_mean_pooling(sentence_hidden_states, sub_token_lengths)
elif self.cls_pooling == "max":
Expand Down
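
The three pooling helpers above all take the padded batch of sub-token hidden states together with the per-sentence sub-token lengths and return one vector per sentence. A small illustrative sketch of their behavior on toy tensors (this assumes a flair build that already contains this commit, so that document_cls_pooling is importable):

import torch

from flair.embeddings.transformer import (
    document_cls_pooling,
    document_max_pooling,
    document_mean_pooling,
)

# Two "sentences" padded to 4 sub-tokens each, hidden size 3.
hidden_states = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
lengths = torch.tensor([2, 4])  # sentence 0 uses 2 sub-tokens, sentence 1 uses all 4

cls_vecs = document_cls_pooling(hidden_states, lengths)    # hidden state at position length - 1
mean_vecs = document_mean_pooling(hidden_states, lengths)  # mean over the first `length` positions only
max_vecs = document_max_pooling(hidden_states, lengths)    # element-wise max over the first `length` positions

print(cls_vecs.shape, mean_vecs.shape, max_vecs.shape)  # each is torch.Size([2, 3]): one vector per sentence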
tests/embeddings/test_transformer_document_embeddings.py (16 additions, 1 deletion)
@@ -1,4 +1,6 @@
-from flair.data import Dictionary
+import pytest
+
+from flair.data import Dictionary, Sentence
 from flair.embeddings import TransformerDocumentEmbeddings
 from flair.models import TextClassifier
 from flair.nn import Classifier
@@ -37,3 +39,16 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
     # check that context_length and use_context_separator is the same for both
     assert model.embeddings.context_length == loaded_single_task.embeddings.context_length
     assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator
+
+
+@pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
+def test_cls_pooling(cls_pooling):
+    embeddings = TransformerDocumentEmbeddings(
+        model="distilbert-base-uncased",
+        layers="-1",
+        cls_pooling=cls_pooling,
+        allow_long_sentences=True,
+    )
+    sentence = Sentence("Today is a good day.")
+    embeddings.embed(sentence)
+    assert sentence.embedding is not None
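
The new test is parametrized over all three pooling modes, so a single run such as pytest tests/embeddings/test_transformer_document_embeddings.py -k test_cls_pooling exercises cls, mean, and max; it downloads the distilbert-base-uncased checkpoint on first use unless it is already cached.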