diff --git a/flair/models/__init__.py b/flair/models/__init__.py
index e75daf074..d9fca4a70 100644
--- a/flair/models/__init__.py
+++ b/flair/models/__init__.py
@@ -1,3 +1,4 @@
+from .deepncm_classification_model import DeepNCMDecoder
 from .entity_linker_model import SpanClassifier
 from .entity_mention_linking import EntityMentionLinker
 from .language_model import LanguageModel
@@ -37,4 +38,5 @@
     "TextClassifier",
     "TextRegressor",
     "MultitaskModel",
+    "DeepNCMDecoder",
 ]
diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py
new file mode 100644
index 000000000..ec3385a78
--- /dev/null
+++ b/flair/models/deepncm_classification_model.py
@@ -0,0 +1,219 @@
+import logging
+from typing import Literal, Optional
+
+import torch
+
+import flair
+from flair.data import Dictionary
+
+log = logging.getLogger("flair")
+
+
+class DeepNCMDecoder(torch.nn.Module):
+    """Deep Nearest Class Mean (DeepNCM) decoder for text classification tasks.
+
+    This decoder combines deep learning with the Nearest Class Mean (NCM) approach.
+    It takes document embeddings as input, optionally applies an encoder,
+    and classifies based on the nearest class prototype in the embedded space.
+
+    The decoder supports several methods for updating class prototypes during
+    training, making it adaptable to different learning scenarios.
+
+    This implementation is based on the research paper:
+    Guerriero, S., Caputo, B., & Mensink, T. (2018). DeepNCM: Deep Nearest Class Mean Classifiers.
+    In International Conference on Learning Representations (ICLR) 2018 Workshop.
+    URL: https://openreview.net/forum?id=rkPLZ4JPM
+    """
+
+    def __init__(
+        self,
+        label_dictionary: Dictionary,
+        embeddings_size: int,
+        encoding_dim: Optional[int] = None,
+        alpha: float = 0.9,
+        mean_update_method: Literal["online", "condensation", "decay"] = "online",
+        use_encoder: bool = True,
+        multi_label: bool = False,  # ideally this should be taken from the model this decoder belongs to
+    ) -> None:
+        """Initialize a DeepNCMDecoder.
+
+        Args:
+            label_dictionary: Dictionary of the labels to predict; one prototype is kept per label.
+            embeddings_size: The dimensionality of the input embeddings.
+            encoding_dim: The dimensionality of the encoded embeddings (default is the same as the input embeddings).
+            alpha: The decay factor for updating class prototypes (default is 0.9). Only used when mean_update_method is 'decay'.
+            mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay').
+            use_encoder: Whether to apply an encoder to the input embeddings (default is True).
+            multi_label: Whether to predict multiple labels per sentence (default is False).
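+
+        Example (an illustrative sketch, not runnable on its own; assumes `embeddings`
+        is a document embedding such as TransformerDocumentEmbeddings and `label_dict`
+        is a Dictionary built from the corpus)::
+
+            decoder = DeepNCMDecoder(
+                label_dictionary=label_dict,
+                embeddings_size=embeddings.embedding_length,
+                mean_update_method="condensation",
+            )
+            classifier = TextClassifier(
+                embeddings=embeddings,
+                label_dictionary=label_dict,
+                label_type="class",
+                decoder=decoder,
+            )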
+ """ + + super().__init__() + + self.label_dictionary = label_dictionary + self._num_prototypes = len(label_dictionary) + + self.alpha = alpha + self.mean_update_method = mean_update_method + self.use_encoder = use_encoder + self.multi_label = multi_label + + self.embedding_dim = embeddings_size + + if use_encoder: + self.encoding_dim = encoding_dim or self.embedding_dim + else: + self.encoding_dim = self.embedding_dim + + self.class_prototypes = torch.nn.Parameter( + torch.nn.functional.normalize(torch.randn(self._num_prototypes, self.encoding_dim)), requires_grad=False + ) + + self.class_counts = torch.nn.Parameter(torch.zeros(self._num_prototypes), requires_grad=False) + self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) + self.prototype_update_counts = torch.zeros(self._num_prototypes).to(flair.device) + self.to(flair.device) + + self._validate_parameters() + + if self.use_encoder: + self.encoder = torch.nn.Sequential( + torch.nn.Linear(self.embedding_dim, self.encoding_dim * 2), + torch.nn.ReLU(), + torch.nn.Linear(self.encoding_dim * 2, self.encoding_dim), + ) + else: + self.encoder = torch.nn.Sequential(torch.nn.Identity()) + + # all parameters will be pushed internally to the specified device + self.to(flair.device) + + def _validate_parameters(self) -> None: + """Validate that the input parameters have valid and compatible values.""" + assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" + assert self.mean_update_method in [ + "online", + "condensation", + "decay", + ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" + assert self.encoding_dim > 0, "encoding_dim must be greater than 0" + + @property + def num_prototypes(self) -> int: + """The number of class prototypes.""" + return self.class_prototypes.size(0) + + def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: + """Calculate the squared Euclidean distance between encoded embeddings and class prototypes. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + + Returns: + torch.Tensor: Distances between encoded embeddings and class prototypes. + """ + return torch.cdist(encoded_embeddings, self.class_prototypes).pow(2) + + def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: + """Calculate updates for class prototypes based on the current batch. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + labels: True labels for the input sentences. 
+ """ + one_hot = ( + labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_prototypes).float() + ) + + updates = torch.matmul(one_hot.t(), encoded_embeddings) + counts = one_hot.sum(dim=0) + mask = counts > 0 + self.prototype_updates[mask] += updates[mask] + self.prototype_update_counts[mask] += counts[mask] + + def update_prototypes(self) -> None: + """Apply accumulated updates to class prototypes.""" + with torch.no_grad(): + update_mask = self.prototype_update_counts > 0 + if update_mask.any(): + if self.mean_update_method in ["online", "condensation"]: + new_counts = self.class_counts[update_mask] + self.prototype_update_counts[update_mask] + self.class_prototypes[update_mask] = ( + self.class_counts[update_mask].unsqueeze(1) * self.class_prototypes[update_mask] + + self.prototype_updates[update_mask] + ) / new_counts.unsqueeze(1) + self.class_counts[update_mask] = new_counts + elif self.mean_update_method == "decay": + new_prototypes = self.prototype_updates[update_mask] / self.prototype_update_counts[ + update_mask + ].unsqueeze(1) + self.class_prototypes[update_mask] = ( + self.alpha * self.class_prototypes[update_mask] + (1 - self.alpha) * new_prototypes + ) + self.class_counts[update_mask] += self.prototype_update_counts[update_mask] + + # Reset prototype updates + self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) + self.prototype_update_counts = torch.zeros(self.num_prototypes, device=flair.device) + + def forward(self, embedded: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of the decoder, which calculates the scores as prototype distances. + + :param embedded: Embedded representations of the input sentences. + :param label_tensor: True labels for the input sentences as a tensor. + :return: Scores as a tensor of distances to class prototypes. + """ + encoded_embeddings = self.encoder(embedded) + + distances = self._calculate_distances(encoded_embeddings) + + if label_tensor is not None: + self._calculate_prototype_updates(encoded_embeddings, label_tensor) + + scores = -distances + + return scores + + def get_prototype(self, class_name: str) -> torch.Tensor: + """Get the prototype vector for a given class name. + + Args: + class_name: The name of the class whose prototype vector is requested. + + Returns: + torch.Tensor: The prototype vector for the given class. + + Raises: + ValueError: If the class name is not found in the label dictionary. + """ + try: + class_idx = self.label_dictionary.get_idx_for_item(class_name) + except IndexError as exc: + raise ValueError(f"Class name '{class_name}' not found in the label dictionary") from exc + + return self.class_prototypes[class_idx].clone() + + def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> list[tuple[str, float]]: + """Get the k closest prototype vectors to the given input vector using the configured distance metric. + + Args: + input_vector (torch.Tensor): The input vector to compare against prototypes. + top_k (int): The number of closest prototypes to return (default is 5). + + Returns: + list[tuple[str, float]]: Each tuple contains (class_name, distance). 
+ """ + if input_vector.dim() != 1: + raise ValueError("Input vector must be a 1D tensor") + if input_vector.size(0) != self.class_prototypes.size(1): + raise ValueError( + f"Input vector dimension ({input_vector.size(0)}) does not match prototype dimension ({self.class_prototypes.size(1)})" + ) + + input_vector = input_vector.unsqueeze(0) + distances = self._calculate_distances(input_vector) + top_k_values, top_k_indices = torch.topk(distances.squeeze(), k=top_k, largest=False) + + nearest_prototypes = [] + for idx, value in zip(top_k_indices, top_k_values): + class_name = self.label_dictionary.get_item_for_index(idx.item()) + nearest_prototypes.append((class_name, value.item())) + + return nearest_prototypes diff --git a/flair/nn/model.py b/flair/nn/model.py index f670c969a..d4062c89c 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import Counter from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, List, Optional, Tuple, Union import torch.nn from torch import Tensor @@ -765,8 +765,11 @@ def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]: # pass data points through network to get encoded data point tensor data_point_tensor = self._encode_data_points(sentences, data_points) - # decode - scores = self.decoder(data_point_tensor) + # decode, passing label tensor if needed, such as for prototype updates + if "label_tensor" in inspect.signature(self.decoder.forward).parameters: + scores = self.decoder(data_point_tensor, label_tensor) + else: + scores = self.decoder(data_point_tensor) # an optional masking step (no masking in most cases) scores = self._mask_scores(scores, data_points) @@ -801,7 +804,7 @@ def predict( label_name: Optional[str] = None, return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", - ): + ) -> Optional[Union[List[DT], Tuple[float, int]]]: """Predicts the class labels for the given sentences. The labels are directly added to the sentences. Args: diff --git a/flair/trainers/plugins/__init__.py b/flair/trainers/plugins/__init__.py index 373fdf969..c3b1c1bab 100644 --- a/flair/trainers/plugins/__init__.py +++ b/flair/trainers/plugins/__init__.py @@ -1,6 +1,7 @@ from .base import BasePlugin, Pluggable, TrainerPlugin, TrainingInterrupt from .functional.anneal_on_plateau import AnnealingPlugin from .functional.checkpoints import CheckpointPlugin +from .functional.deepncm_trainer_plugin import DeepNCMPlugin from .functional.linear_scheduler import LinearSchedulerPlugin from .functional.reduce_transformer_vocab import ReduceTransformerVocabPlugin from .functional.weight_extractor import WeightExtractorPlugin @@ -15,6 +16,7 @@ "AnnealingPlugin", "CheckpointPlugin", "ClearmlLoggerPlugin", + "DeepNCMPlugin", "LinearSchedulerPlugin", "WeightExtractorPlugin", "LogFilePlugin", diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py new file mode 100644 index 000000000..e5394debd --- /dev/null +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -0,0 +1,42 @@ +import torch + +from flair.models import MultitaskModel +from flair.models.deepncm_classification_model import DeepNCMDecoder +from flair.trainers.plugins.base import TrainerPlugin + + +class DeepNCMPlugin(TrainerPlugin): + """Plugin for training DeepNCMClassifier. + + Handles both multitask and single-task scenarios. 
+ """ + + def _process_models(self, operation: str): + """Process updates for all DeepNCMDecoder decoders in the trainer. + + Args: + operation (str): The operation to perform ('condensation' or 'update') + """ + model = self.trainer.model + + models = model.tasks.values() if isinstance(model, MultitaskModel) else [model] + + for sub_model in models: + if hasattr(sub_model, "decoder") and isinstance(sub_model.decoder, DeepNCMDecoder): + if operation == "condensation" and sub_model.decoder.mean_update_method == "condensation": + sub_model.decoder.class_counts.data = torch.ones_like(sub_model.decoder.class_counts) + elif operation == "update": + sub_model.decoder.update_prototypes() + + @TrainerPlugin.hook + def after_training_epoch(self, **kwargs): + """Update prototypes after each training epoch.""" + self._process_models("condensation") + + @TrainerPlugin.hook + def after_training_batch(self, **kwargs): + """Update prototypes after each training batch.""" + self._process_models("update") + + def __str__(self) -> str: + return "DeepNCMPlugin" diff --git a/tests/models/test_deepncm_classifier.py b/tests/models/test_deepncm_classifier.py new file mode 100644 index 000000000..b587a3314 --- /dev/null +++ b/tests/models/test_deepncm_classifier.py @@ -0,0 +1,188 @@ +import pytest +import torch + +from flair.data import Sentence +from flair.datasets import ClassificationCorpus +from flair.embeddings import TransformerDocumentEmbeddings +from flair.models import DeepNCMDecoder, TextClassifier +from flair.trainers import ModelTrainer +from flair.trainers.plugins import DeepNCMPlugin +from tests.model_test_utils import BaseModelTest + + +class TestDeepNCMDecoder(BaseModelTest): + model_cls = TextClassifier + train_label_type = "class" + multiclass_prediction_labels = ["POSITIVE", "NEGATIVE"] + training_args = { + "max_epochs": 2, + "mini_batch_size": 4, + "learning_rate": 1e-5, + } + + @pytest.fixture() + def embeddings(self): + return TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True) + + @pytest.fixture() + def corpus(self, tasks_base_path): + return ClassificationCorpus(tasks_base_path / "imdb", label_type=self.train_label_type) + + @pytest.fixture() + def multiclass_train_test_sentence(self): + return Sentence("This movie was great!") + + def build_model(self, embeddings, label_dict, **kwargs): + + model_args = { + "embeddings": embeddings, + "label_dictionary": label_dict, + "label_type": self.train_label_type, + "use_encoder": False, + "encoding_dim": 64, + "alpha": 0.95, + "mean_update_method": "online", + } + model_args.update(kwargs) + + deepncm_decoder = DeepNCMDecoder( + label_dictionary=model_args["label_dictionary"], + embeddings_size=model_args["embeddings"].embedding_length, + alpha=model_args["alpha"], + encoding_dim=model_args["encoding_dim"], + mean_update_method=model_args["mean_update_method"], + ) + + model = self.model_cls( + embeddings=model_args["embeddings"], + label_dictionary=model_args["label_dictionary"], + label_type=model_args["label_type"], + multi_label=model_args.get("multi_label", False), + decoder=deepncm_decoder, + ) + + return model + + @pytest.mark.integration() + def test_train_load_use_classifier( + self, results_base_path, corpus, embeddings, example_sentence, train_test_sentence + ): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + + model = self.build_model(embeddings, label_dict, mean_update_method="condensation") + + trainer = ModelTrainer(model, corpus) + trainer.fine_tune( + results_base_path, 
+            optimizer=torch.optim.AdamW,
+            plugins=[DeepNCMPlugin()],
+            **self.training_args,
+        )
+
+        model.predict(train_test_sentence)
+
+        for label in train_test_sentence.get_labels(self.train_label_type):
+            assert label.value is not None
+            assert 0.0 <= label.score <= 1.0
+            assert isinstance(label.score, float)
+
+        del trainer, model, corpus
+
+        loaded_model = self.model_cls.load(results_base_path / "final-model.pt")
+
+        loaded_model.predict(example_sentence)
+        loaded_model.predict([example_sentence, self.empty_sentence])
+        loaded_model.predict([self.empty_sentence])
+
+    def test_get_prototype(self, corpus, embeddings):
+        label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
+        model = self.build_model(embeddings, label_dict)
+
+        prototype = model.decoder.get_prototype(next(iter(label_dict.get_items())))
+        assert isinstance(prototype, torch.Tensor)
+        assert prototype.shape == (model.decoder.encoding_dim,)
+
+        with pytest.raises(ValueError):
+            model.decoder.get_prototype("NON_EXISTENT_CLASS")
+
+    def test_get_closest_prototypes(self, corpus, embeddings):
+        label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
+        model = self.build_model(embeddings, label_dict)
+        input_vector = torch.randn(model.decoder.encoding_dim)
+        closest_prototypes = model.decoder.get_closest_prototypes(input_vector, top_k=2)
+
+        assert len(closest_prototypes) == 2
+        assert all(isinstance(item, tuple) and len(item) == 2 for item in closest_prototypes)
+
+        with pytest.raises(ValueError):
+            model.decoder.get_closest_prototypes(torch.randn(model.decoder.encoding_dim + 1))
+
+    def test_forward_loss(self, corpus, embeddings):
+        label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
+        model = self.build_model(embeddings, label_dict)
+
+        sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")]
+        for sentence, label in zip(sentences, list(label_dict.get_items())[:2]):
+            sentence.add_label(self.train_label_type, label)
+
+        loss, count = model.forward_loss(sentences)
+        assert isinstance(loss, torch.Tensor)
+        assert loss.item() > 0
+        assert count == len(sentences)
+
+    @pytest.mark.parametrize("mean_update_method", ["online", "condensation", "decay"])
+    def test_mean_update_methods(self, corpus, embeddings, mean_update_method):
+        label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
+        model = self.build_model(embeddings, label_dict, mean_update_method=mean_update_method)
+
+        initial_prototypes = model.decoder.class_prototypes.clone()
+
+        sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")]
+        for sentence, label in zip(sentences, list(label_dict.get_items())[:2]):
+            sentence.add_label(self.train_label_type, label)
+
+        model.forward_loss(sentences)
+        model.decoder.update_prototypes()
+
+        assert not torch.all(torch.eq(initial_prototypes, model.decoder.class_prototypes))
+
+    @pytest.mark.parametrize("mean_update_method", ["online", "condensation", "decay"])
+    def test_deepncm_plugin(self, corpus, embeddings, mean_update_method):
+        label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
+        model = self.build_model(embeddings, label_dict, mean_update_method=mean_update_method)
+
+        trainer = ModelTrainer(model, corpus)
+        plugin = DeepNCMPlugin()
+        plugin.attach_to(trainer)
+
+        initial_class_counts = model.decoder.class_counts.clone()
+        initial_prototypes = model.decoder.class_prototypes.clone()
+
+        # Simulate training epoch
+        plugin.after_training_epoch()
+
+        if mean_update_method == "condensation":
+            assert torch.all(
+                model.decoder.class_counts == 1
+            ), "Class counts should be 1 for the condensation method after an epoch"
+        elif mean_update_method == "online":
+            assert torch.all(
+                torch.eq(model.decoder.class_counts, initial_class_counts)
+            ), "Class counts should not change for the online method after an epoch"
+
+        # Simulate training batch
+        sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")]
+        for sentence, label in zip(sentences, list(label_dict.get_items())[:2]):
+            sentence.add_label(self.train_label_type, label)
+        model.forward_loss(sentences)
+        plugin.after_training_batch()
+
+        assert not torch.all(
+            torch.eq(initial_prototypes, model.decoder.class_prototypes)
+        ), "Prototypes should be updated after a batch"
+
+        if mean_update_method == "condensation":
+            assert torch.all(
+                model.decoder.class_counts >= 1
+            ), "Class counts should be >= 1 for the condensation method after a batch"
+        elif mean_update_method == "online":
+            assert torch.all(
+                model.decoder.class_counts > initial_class_counts
+            ), "Class counts should increase for the online method after a batch"