diff --git a/flair/models/__init__.py b/flair/models/__init__.py
index bf3651078..d9fca4a70 100644
--- a/flair/models/__init__.py
+++ b/flair/models/__init__.py
@@ -1,4 +1,4 @@
-from .deepncm_classification_model import DeepNCMClassifier
+from .deepncm_classification_model import DeepNCMDecoder
 from .entity_linker_model import SpanClassifier
 from .entity_mention_linking import EntityMentionLinker
 from .language_model import LanguageModel
@@ -38,5 +38,5 @@
     "TextClassifier",
     "TextRegressor",
     "MultitaskModel",
-    "DeepNCMClassifier",
+    "DeepNCMDecoder",
 ]
diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py
index b942e2891..1619251f9 100644
--- a/flair/models/deepncm_classification_model.py
+++ b/flair/models/deepncm_classification_model.py
@@ -1,20 +1,14 @@
 import logging
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Literal, Optional
 
 import torch
-from tqdm import tqdm
 
 import flair
-from flair.data import Dictionary, Sentence
-from flair.datasets import DataLoader, FlairDatapointDataset
-from flair.embeddings import DocumentEmbeddings
-from flair.embeddings.base import load_embeddings
-from flair.nn import Classifier
 
 log = logging.getLogger("flair")
 
 
-class DeepNCMClassifier(Classifier[Sentence]):
+class DeepNCMDecoder(torch.nn.Module):
     """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks.
 
     This model combines deep learning with the Nearest Class Mean (NCM) approach.
@@ -32,47 +26,49 @@ class DeepNCMClassifier(Classifier[Sentence]):
     def __init__(
         self,
-        embeddings: DocumentEmbeddings,
-        label_dictionary: Dictionary,
-        label_type: str,
+        num_prototypes: int,
+        embeddings_size: int,
         encoding_dim: Optional[int] = None,
         alpha: float = 0.9,
         mean_update_method: Literal["online", "condensation", "decay"] = "online",
         use_encoder: bool = True,
-        multi_label: bool = False,
-        multi_label_threshold: float = 0.5,
-    ):
-        """Initialize a DeepNCMClassifier.
+        multi_label: bool = False,  # should be taken from the Model it belongs to
+    ) -> None:
+        """Initialize a DeepNCMDecoder.
 
         Args:
-            embeddings: Document embeddings to use for encoding text.
-            label_dictionary: Dictionary containing the label vocabulary.
-            label_type: The type of label to predict.
+            num_prototypes: The number of class prototypes (i.e., the number of classes).
+            embeddings_size: The dimensionality of the input embeddings.
             encoding_dim: The dimensionality of the encoded embeddings (default is the same as the input embeddings).
-            alpha: The decay factor for updating class prototypes (default is 0.9).
+            alpha: The decay factor for updating class prototypes (default is 0.9). This only applies when mean_update_method is 'decay'.
             mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay').
            use_encoder: Whether to apply an encoder to the input embeddings (default is True).
             multi_label: Whether to predict multiple labels per sentence (default is False).
-            multi_label_threshold: The threshold for multi-label prediction (default is 0.5).
""" + super().__init__() - self.embeddings = embeddings - self.label_dictionary = label_dictionary - self._label_type = label_type + self.num_classes = num_prototypes + self.alpha = alpha self.mean_update_method = mean_update_method self.use_encoder = use_encoder self.multi_label = multi_label - self.multi_label_threshold = multi_label_threshold - self.num_classes = len(label_dictionary) - self.embedding_dim = embeddings.embedding_length + + self.embedding_dim = embeddings_size if use_encoder: self.encoding_dim = encoding_dim or self.embedding_dim else: self.encoding_dim = self.embedding_dim + self.class_prototypes = torch.nn.Parameter( + torch.nn.functional.normalize(torch.randn(num_prototypes, self.encoding_dim)), requires_grad=False + ) + + self.class_counts = torch.nn.Parameter(torch.zeros(num_prototypes), requires_grad=False) + self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) + self.prototype_update_counts = torch.zeros(num_prototypes).to(flair.device) + self.to(flair.device) + self._validate_parameters() if self.use_encoder: @@ -84,22 +80,11 @@ def __init__( else: self.encoder = torch.nn.Sequential(torch.nn.Identity()) - self.loss_function = ( - torch.nn.BCEWithLogitsLoss(reduction="sum") - if self.multi_label - else torch.nn.CrossEntropyLoss(reduction="sum") - ) - - self.class_prototypes = torch.nn.Parameter( - torch.nn.functional.normalize(torch.randn(self.num_classes, self.encoding_dim)), requires_grad=False - ) - self.class_counts = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=False) - self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) - self.prototype_update_counts = torch.zeros(self.num_classes).to(flair.device) + # all parameters will be pushed internally to the specified device self.to(flair.device) def _validate_parameters(self) -> None: - """Validate the input parameters.""" + """Validate that the input parameters have valid and compatible values.""" assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" assert self.mean_update_method in [ "online", @@ -108,26 +93,13 @@ def _validate_parameters(self) -> None: ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" assert self.encoding_dim > 0, "encoding_dim must be greater than 0" - def forward(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor: - """Encode the input sentences using embeddings and optional encoder. - - Args: - sentences: Input sentence or list of sentences. - - Returns: - torch.Tensor: Encoded representations of the input sentences. - """ - if not isinstance(sentences, list): - sentences = [sentences] - - self.embeddings.embed(sentences) - sentence_embeddings = torch.stack([sentence.get_embedding() for sentence in sentences]) - encoded_embeddings = self.encoder(sentence_embeddings) - - return encoded_embeddings + @property + def num_prototypes(self) -> int: + """The number of class prototypes.""" + return self.class_prototypes.size(0) def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: - """Calculate distances between encoded embeddings and class prototypes. + """Calculate the squared Euclidean distance between encoded embeddings and class prototypes. Args: encoded_embeddings: Encoded representations of the input sentences. @@ -135,60 +107,7 @@ def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor Returns: torch.Tensor: Distances between encoded embeddings and class prototypes. 
""" - return torch.cdist(encoded_embeddings, self.class_prototypes) - - def forward_loss(self, data_points: List[Sentence]) -> Tuple[torch.Tensor, int]: - """Compute the loss for a batch of sentences. - - Args: - data_points: A list of sentences. - - Returns: - Tuple[torch.Tensor, int]: The total loss and the number of sentences. - """ - encoded_embeddings = self.forward(data_points) - labels = self._prepare_label_tensor(data_points) - distances = self._calculate_distances(encoded_embeddings) - loss = self.loss_function(-distances, labels) - self._calculate_prototype_updates(encoded_embeddings, labels) - - return loss, len(data_points) - - def _prepare_label_tensor(self, sentences: List[Sentence]) -> torch.Tensor: - """Prepare the label tensor for the given sentences. - - Args: - sentences: A list of sentences. - - Returns: - torch.Tensor: The label tensor for the given sentences. - """ - if self.multi_label: - return torch.tensor( - [ - [ - ( - 1 - if label - in [sentence_label.value for sentence_label in sentence.get_labels(self._label_type)] - else 0 - ) - for label in self.label_dictionary.get_items() - ] - for sentence in sentences - ], - dtype=torch.float, - device=flair.device, - ) - else: - return torch.tensor( - [ - self.label_dictionary.get_idx_for_item(sentence.get_label(self._label_type).value) - for sentence in sentences - ], - dtype=torch.long, - device=flair.device, - ) + return torch.cdist(encoded_embeddings, self.class_prototypes).pow(2) def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: """Calculate updates for class prototypes based on the current batch. @@ -198,7 +117,7 @@ def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: labels: True labels for the input sentences. """ one_hot = ( - labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_classes).float() + labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_prototypes).float() ) updates = torch.matmul(one_hot.t(), encoded_embeddings) @@ -232,224 +151,20 @@ def update_prototypes(self) -> None: self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) self.prototype_update_counts = torch.zeros(self.num_classes, device=flair.device) - def predict( - self, - sentences: Union[List[Sentence], Sentence], - mini_batch_size: int = 32, - return_probabilities_for_all_classes: bool = False, - verbose: bool = False, - label_name: Optional[str] = None, - return_loss: bool = False, - embedding_storage_mode: str = "none", - ) -> Union[List[Sentence], Tuple[float, int]]: - """Predict classes for a list of sentences. - - Args: - sentences: A list of sentences or a single sentence. - mini_batch_size: Size of mini batches during prediction. - return_probabilities_for_all_classes: Whether to return probabilities for all classes. - verbose: If True, show progress bar during prediction. - label_name: The name of the label to use for prediction. - return_loss: If True, compute and return loss. - embedding_storage_mode: The mode for storing embeddings ('none', 'cpu', or 'gpu'). - - Returns: - Union[List[Sentence], Tuple[float, int]]: - if return_loss is True, returns a tuple of total loss and total number of sentences; - otherwise, returns the list of sentences with predicted labels. 
- """ - with torch.no_grad(): - if not isinstance(sentences, list): - sentences = [sentences] - if not sentences: - return sentences - - label_name = label_name or self.label_type - Sentence.set_context_for_sentences(sentences) + def forward(self, embedded: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of the decoder, which calculates the scores as prototype distances. - filtered_sentences = [sent for sent in sentences if len(sent) > 0] - reordered_sentences = sorted(filtered_sentences, key=len, reverse=True) - - if len(reordered_sentences) == 0: - return sentences - - dataloader = DataLoader( - dataset=FlairDatapointDataset(reordered_sentences), - batch_size=mini_batch_size, - ) - - if verbose: - progress_bar = tqdm(dataloader) - progress_bar.set_description("Predicting") - dataloader = progress_bar - - total_loss = 0.0 - total_sentences = 0 - - for batch in dataloader: - if not batch: - continue - - encoded_embeddings = self.forward(batch) - distances = self._calculate_distances(encoded_embeddings) - - if self.multi_label: - probabilities = torch.sigmoid(-distances) - else: - probabilities = torch.nn.functional.softmax(-distances, dim=1) - - if return_loss: - labels = self._prepare_label_tensor(batch) - loss = self.loss_function(-distances, labels) - total_loss += loss.item() - total_sentences += len(batch) - - for sentence_index, sentence in enumerate(batch): - sentence.remove_labels(label_name) - - if self.multi_label: - for label_index, probability in enumerate(probabilities[sentence_index]): - if probability > self.multi_label_threshold or return_probabilities_for_all_classes: - label_value = self.label_dictionary.get_item_for_index(label_index) - sentence.add_label(label_name, label_value, probability.item()) - else: - predicted_idx = torch.argmax(probabilities[sentence_index]) - label_value = self.label_dictionary.get_item_for_index(predicted_idx.item()) - sentence.add_label(label_name, label_value, probabilities[sentence_index, predicted_idx].item()) - - if return_probabilities_for_all_classes: - for label_index, probability in enumerate(probabilities[sentence_index]): - label_value = self.label_dictionary.get_item_for_index(label_index) - sentence.add_label(f"{label_name}_all", label_value, probability.item()) - - for sentence in batch: - sentence.clear_embeddings(embedding_storage_mode) - - if return_loss: - return total_loss, total_sentences - return sentences - - def _get_state_dict(self) -> Dict[str, Any]: - """Get the state dictionary of the model. - - Returns: - Dict[str, Any]: The state dictionary containing model parameters and configuration. - """ - model_state = { - "embeddings": self.embeddings.save_embeddings(), - "label_dictionary": self.label_dictionary, - "label_type": self.label_type, - "encoding_dim": self.encoding_dim, - "alpha": self.alpha, - "mean_update_method": self.mean_update_method, - "use_encoder": self.use_encoder, - "multi_label": self.multi_label, - "multi_label_threshold": self.multi_label_threshold, - "class_prototypes": self.class_prototypes.cpu(), - "class_counts": self.class_counts.cpu(), - "encoder": self.encoder.state_dict(), - } - return model_state - - @classmethod - def _init_model_with_state_dict(cls, state, **kwargs) -> "DeepNCMClassifier": - """Initialize the model from a state dictionary. - - Args: - state: The state dictionary containing model parameters and configuration. - **kwargs: Additional keyword arguments for model initialization. 
-
-        Returns:
-            DeepNCMClassifier: An instance of the model initialized with the given state.
+        :param embedded: Embedded representations of the input sentences.
+        :param label_tensor: True labels for the input sentences, used to compute prototype updates.
+        :return: Scores as negated squared distances to the class prototypes.
         """
-        embeddings = state["embeddings"]
-        if isinstance(embeddings, dict):
-            embeddings = load_embeddings(embeddings)
-
-        model = cls(
-            embeddings=embeddings,
-            label_dictionary=state["label_dictionary"],
-            label_type=state["label_type"],
-            encoding_dim=state["encoding_dim"],
-            alpha=state["alpha"],
-            mean_update_method=state["mean_update_method"],
-            use_encoder=state["use_encoder"],
-            multi_label=state.get("multi_label", False),
-            multi_label_threshold=state.get("multi_label_threshold", 0.5),
-            **kwargs,
-        )
-
-        if "encoder" in state:
-            model.encoder.load_state_dict(state["encoder"])
-        if "class_prototypes" in state:
-            model.class_prototypes.data = state["class_prototypes"].to(flair.device)
-        if "class_counts" in state:
-            model.class_counts.data = state["class_counts"].to(flair.device)
+        encoded_embeddings = self.encoder(embedded)
 
-        return model
-
-    def get_prototype(self, class_name: str) -> torch.Tensor:
-        """Get the prototype vector for a given class name.
-
-        Args:
-            class_name: The name of the class whose prototype vector is requested.
-
-        Returns:
-            torch.Tensor: The prototype vector for the given class.
-
-        Raises:
-            ValueError: If the class name is not found in the label dictionary.
-        """
-        try:
-            class_idx = self.label_dictionary.get_idx_for_item(class_name)
-        except IndexError as exc:
-            raise ValueError(f"Class name '{class_name}' not found in the label dictionary") from exc
-
-        return self.class_prototypes[class_idx].clone()
-
-    def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> List[Tuple[str, float]]:
-        """Get the top_k closest prototype vectors to the given input vector using the configured distance metric.
-
-        Args:
-            input_vector (torch.Tensor): The input vector to compare against prototypes.
-            top_k (int): The number of closest prototypes to return (default is 5).
-
-        Returns:
-            List[Tuple[str, float]]: Each tuple contains (class_name, distance).
-        """
-        if input_vector.dim() != 1:
-            raise ValueError("Input vector must be a 1D tensor")
-        if input_vector.size(0) != self.class_prototypes.size(1):
-            raise ValueError(
-                f"Input vector dimension ({input_vector.size(0)}) does not match prototype dimension ({self.class_prototypes.size(1)})"
-            )
-
-        input_vector = input_vector.unsqueeze(0)
-        distances = self._calculate_distances(input_vector)
-        top_k_values, top_k_indices = torch.topk(distances.squeeze(), k=top_k, largest=False)
-
-        nearest_prototypes = []
-        for idx, value in zip(top_k_indices, top_k_values):
-            class_name = self.label_dictionary.get_item_for_index(idx.item())
-            nearest_prototypes.append((class_name, value.item()))
-
-        return nearest_prototypes
+        distances = self._calculate_distances(encoded_embeddings)
 
-    @property
-    def label_type(self) -> str:
-        """Get the label type for this classifier."""
-        return self._label_type
+        if label_tensor is not None:
+            self._calculate_prototype_updates(encoded_embeddings, label_tensor)
 
-    def __str__(self) -> str:
-        """Get a string representation of the model.
+        scores = -distances
 
-        Returns:
-            str: A string describing the model architecture.
- """ - return ( - f"DeepNCMClassifier(\n" - f" (embeddings): {self.embeddings}\n" - f" (encoder): {self.encoder}\n" - f" (prototypes): {self.class_prototypes.shape}\n" - f")" - ) + return scores diff --git a/flair/nn/model.py b/flair/nn/model.py index f670c969a..9d3f9e2f6 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -765,8 +765,11 @@ def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]: # pass data points through network to get encoded data point tensor data_point_tensor = self._encode_data_points(sentences, data_points) - # decode - scores = self.decoder(data_point_tensor) + # decode, passing label tensor if needed, such as for prototype updates + if "label_tensor" in inspect.signature(self.decoder.forward).parameters: + scores = self.decoder(data_point_tensor, label_tensor) + else: + scores = self.decoder(data_point_tensor) # an optional masking step (no masking in most cases) scores = self._mask_scores(scores, data_points) @@ -801,7 +804,7 @@ def predict( label_name: Optional[str] = None, return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", - ): + ) -> Optional[Union[List[DT], Tuple[float, int]]]: """Predicts the class labels for the given sentences. The labels are directly added to the sentences. Args: diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py index 2c4c0ccb4..d2481c2ae 100644 --- a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -1,6 +1,7 @@ import torch -from flair.models import DeepNCMClassifier, MultitaskModel +from flair.models import MultitaskModel +from flair.models.deepncm_classification_model import DeepNCMDecoder from flair.trainers.plugins.base import TrainerPlugin @@ -21,7 +22,7 @@ def _process_models(self, operation: str): models = model.tasks.values() if isinstance(model, MultitaskModel) else [model] for sub_model in models: - if isinstance(sub_model, DeepNCMClassifier): + if isinstance(sub_model.decoder, DeepNCMDecoder): if operation == "condensation" and sub_model.mean_update_method == "condensation": sub_model.class_counts.data = torch.ones_like(sub_model.class_counts) elif operation == "update":