diff --git a/sentence_transformers/evaluation/TripletEvaluator.py b/sentence_transformers/evaluation/TripletEvaluator.py
index 082e51588..ff17e89e9 100644
--- a/sentence_transformers/evaluation/TripletEvaluator.py
+++ b/sentence_transformers/evaluation/TripletEvaluator.py
@@ -6,12 +6,15 @@ from contextlib import nullcontext
 from typing import TYPE_CHECKING, Literal
 
-import numpy as np
-from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
-
 from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
 from sentence_transformers.readers import InputExample
 from sentence_transformers.similarity_functions import SimilarityFunction
+from sentence_transformers.util import (
+    pairwise_cos_sim,
+    pairwise_dot_score,
+    pairwise_euclidean_sim,
+    pairwise_manhattan_sim,
+)
 
 if TYPE_CHECKING:
     from sentence_transformers.SentenceTransformer import SentenceTransformer
 
@@ -22,7 +25,7 @@ class TripletEvaluator(SentenceEvaluator):
     """
     Evaluate a model based on a triplet: (sentence, positive_example, negative_example).
-    Checks if distance(sentence, positive_example) < distance(sentence, negative_example).
+    Checks if ``similarity(sentence, positive_example) > similarity(sentence, negative_example) + margin``.
 
     Example:
         ::
 
@@ -47,7 +50,7 @@ class TripletEvaluator(SentenceEvaluator):
             results = triplet_evaluator(model)
             '''
             TripletEvaluator: Evaluating the model on the all-nli-dev dataset:
-            Accuracy Cosine Distance:       95.60%
+            Accuracy Cosine Similarity:     95.60%
             '''
             print(triplet_evaluator.primary_metric)
             # => "all_nli_dev_cosine_accuracy"
@@ -60,13 +63,15 @@ def __init__(
         anchors: list[str],
         positives: list[str],
         negatives: list[str],
-        main_distance_function: str | SimilarityFunction | None = None,
+        main_similarity_function: str | SimilarityFunction | None = None,
+        margin: float | dict[str, float] | None = None,
         name: str = "",
         batch_size: int = 16,
         show_progress_bar: bool = False,
         write_csv: bool = True,
         truncate_dim: int | None = None,
         similarity_fn_names: list[Literal["cosine", "dot", "euclidean", "manhattan"]] | None = None,
+        main_distance_function: str | SimilarityFunction | None = "deprecated",
     ):
         """
         Initializes a TripletEvaluator object.
@@ -75,9 +80,14 @@ def __init__(
             anchors (List[str]): Sentences to check similarity to. (e.g. a query)
             positives (List[str]): List of positive sentences
             negatives (List[str]): List of negative sentences
-            main_distance_function (Union[str, SimilarityFunction], optional):
-                The distance function to use. If not specified, use cosine similarity,
-                dot product, Euclidean, and Manhattan. Defaults to None.
+            main_similarity_function (Union[str, SimilarityFunction], optional):
+                The similarity function to use. If not specified, use cosine similarity,
+                dot product, Euclidean, and Manhattan similarity. Defaults to None.
+            margin (Union[float, Dict[str, float]], optional): Margins for various similarity metrics.
+                If a float is provided, it will be used as the margin for all similarity metrics.
+                If a dictionary is provided, the keys should be 'cosine', 'dot', 'manhattan', and 'euclidean'.
+                The value specifies the minimum margin by which the negative sample should be further from
+                the anchor than the positive sample. Defaults to None.
             name (str): Name for the output. Defaults to "".
             batch_size (int): Batch size used to compute embeddings. Defaults to 16.
             show_progress_bar (bool): If true, prints a progress bar. Defaults to False.
@@ -85,7 +95,7 @@ def __init__(
             truncate_dim (int, optional): The dimension to truncate sentence embeddings to. `None` uses the
                 model's current truncation dimension. Defaults to None.
             similarity_fn_names (List[str], optional): List of similarity function names to evaluate.
-                If not specified, evaluate using the ``similarity_fn_name``.
+                If not specified, evaluate using the ``model.similarity_fn_name``.
                 Defaults to None.
         """
         super().__init__()
@@ -98,9 +108,32 @@ def __init__(
         assert len(self.anchors) == len(self.positives)
         assert len(self.anchors) == len(self.negatives)
 
-        self.main_distance_function = SimilarityFunction(main_distance_function) if main_distance_function else None
+        if main_distance_function != "deprecated" and main_similarity_function is None:
+            main_similarity_function = main_distance_function
+            logger.warning(
+                "The 'main_distance_function' parameter is deprecated. Please use 'main_similarity_function' instead. "
+                "'main_distance_function' will be removed in a future release."
+            )
+
+        self.main_similarity_function = (
+            SimilarityFunction(main_similarity_function) if main_similarity_function else None
+        )
         self.similarity_fn_names = similarity_fn_names or []
+
+        if margin is None:
+            self.margin = {"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0}
+        elif isinstance(margin, (float, int)):
+            self.margin = {"cosine": margin, "dot": margin, "manhattan": margin, "euclidean": margin}
+        elif isinstance(margin, dict):
+            self.margin = {
+                **{"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0},
+                **margin,
+            }
+        else:
+            raise ValueError(
+                "`margin` should be a float or a dictionary with keys 'cosine', 'dot', 'manhattan', and 'euclidean'"
+            )
 
         self.batch_size = batch_size
         if show_progress_bar is None:
            show_progress_bar = (
@@ -171,20 +204,20 @@ def __call__(
 
         similarity_functions = {
             "cosine": lambda anchors, positives, negatives: (
-                paired_cosine_distances(anchors, positives),
-                paired_cosine_distances(anchors, negatives),
+                pairwise_cos_sim(anchors, positives),
+                pairwise_cos_sim(anchors, negatives),
             ),
             "dot": lambda anchors, positives, negatives: (
-                np.sum(anchors * positives, axis=-1),
-                np.sum(anchors * negatives, axis=-1),
+                pairwise_dot_score(anchors, positives),
+                pairwise_dot_score(anchors, negatives),
             ),
             "manhattan": lambda anchors, positives, negatives: (
-                paired_manhattan_distances(anchors, positives),
-                paired_manhattan_distances(anchors, negatives),
+                pairwise_manhattan_sim(anchors, positives),
+                pairwise_manhattan_sim(anchors, negatives),
             ),
             "euclidean": lambda anchors, positives, negatives: (
-                paired_euclidean_distances(anchors, positives),
-                paired_euclidean_distances(anchors, negatives),
+                pairwise_euclidean_sim(anchors, positives),
+                pairwise_euclidean_sim(anchors, negatives),
             ),
         }
 
@@ -194,9 +227,9 @@ def __call__(
             positive_scores, negative_scores = similarity_functions[fn_name](
                 embeddings_anchors, embeddings_positives, embeddings_negatives
             )
-            accuracy = np.mean(positive_scores < negative_scores)
+            accuracy = (positive_scores > negative_scores + self.margin[fn_name]).float().mean().item()
             metrics[f"{fn_name}_accuracy"] = accuracy
-            logger.info(f"Accuracy {fn_name.capitalize()} Distance:\t{accuracy:.2%}")
+            logger.info(f"Accuracy {fn_name.capitalize()} Similarity:\t{accuracy:.2%}")
 
         if output_path is not None and self.write_csv:
             csv_path = os.path.join(output_path, self.csv_file)
@@ -214,13 +247,13 @@ def __call__(
         if len(self.similarity_fn_names) > 1:
             metrics["max_accuracy"] = max(metrics.values())
 
-        if self.main_distance_function:
+        if self.main_similarity_function:
             self.primary_metric = {
                 SimilarityFunction.COSINE: "cosine_accuracy",
                 SimilarityFunction.DOT_PRODUCT: "dot_accuracy",
                 SimilarityFunction.EUCLIDEAN: "euclidean_accuracy",
                 SimilarityFunction.MANHATTAN: "manhattan_accuracy",
-            }.get(self.main_distance_function)
+            }.get(self.main_similarity_function)
         else:
             if len(self.similarity_fn_names) > 1:
                 self.primary_metric = "max_accuracy"
diff --git a/tests/evaluation/test_triplet_evaluator.py b/tests/evaluation/test_triplet_evaluator.py
new file mode 100644
index 000000000..50bdc84bd
--- /dev/null
+++ b/tests/evaluation/test_triplet_evaluator.py
@@ -0,0 +1,36 @@
+"""
+Tests the correct computation of evaluation scores from TripletEvaluator
+"""
+
+from __future__ import annotations
+
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.evaluation import TripletEvaluator
+
+
+def test_TripletEvaluator(stsb_bert_tiny_model_reused: SentenceTransformer) -> None:
+    """Tests that the TripletEvaluator can be loaded & used"""
+    model = stsb_bert_tiny_model_reused
+    anchors = [
+        "A person on a horse jumps over a broken down airplane.",
+        "Children smiling and waving at camera",
+        "A boy is jumping on skateboard in the middle of a red bridge.",
+    ]
+    positives = [
+        "A person is outdoors, on a horse.",
+        "There are children looking at the camera.",
+        "The boy does a skateboarding trick.",
+    ]
+    negatives = [
+        "A person is at a diner, ordering an omelette.",
+        "The kids are frowning",
+        "The boy skates down the sidewalk.",
+    ]
+    evaluator = TripletEvaluator(anchors, positives, negatives, name="all_nli_dev")
+    metrics = evaluator(model)
+    assert evaluator.primary_metric == "all_nli_dev_cosine_accuracy"
+    assert metrics[evaluator.primary_metric] == 1.0
+
+    evaluator_with_margin = TripletEvaluator(anchors, positives, negatives, margin=0.7, name="all_nli_dev")
+    metrics = evaluator_with_margin(model)
+    assert metrics[evaluator.primary_metric] == 0.0
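
For reviewers, a minimal usage sketch of the new `margin` parameter introduced by this diff. This is a sketch under assumptions: the checkpoint name, the example triplet, and the evaluator name "margin_demo" are illustrative and not taken from this PR.

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import TripletEvaluator

    # Assumed checkpoint; any Sentence Transformers model works here.
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    evaluator = TripletEvaluator(
        anchors=["A man is eating food."],
        positives=["A man is eating a piece of bread."],
        negatives=["The girl is carrying a baby."],
        # A float applies one margin to all similarity functions; a dict sets
        # per-function margins, and missing keys default to 0.
        margin={"cosine": 0.1},
        name="margin_demo",
    )
    results = evaluator(model)
    # A triplet counts as correct only when
    # similarity(anchor, positive) > similarity(anchor, negative) + margin.
    print(results["margin_demo_cosine_accuracy"])

With similarities replacing distances throughout the evaluator, a positive margin tightens the accuracy criterion uniformly for the cosine, dot, Manhattan, and Euclidean scores.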