Add triplet margin for distance functions in TripletEvaluator #2862

Merged
merged 7 commits on Nov 26, 2024
79 changes: 56 additions & 23 deletions sentence_transformers/evaluation/TripletEvaluator.py
@@ -6,12 +6,15 @@
from contextlib import nullcontext
from typing import TYPE_CHECKING, Literal

import numpy as np
from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.readers import InputExample
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.util import (
pairwise_cos_sim,
pairwise_dot_score,
pairwise_euclidean_sim,
pairwise_manhattan_sim,
)

if TYPE_CHECKING:
from sentence_transformers.SentenceTransformer import SentenceTransformer
@@ -22,7 +25,7 @@
class TripletEvaluator(SentenceEvaluator):
"""
Evaluate a model based on a triplet: (sentence, positive_example, negative_example).
Checks if distance(sentence, positive_example) < distance(sentence, negative_example).
Checks if ``similarity(sentence, positive_example) > similarity(sentence, negative_example) + margin``.

Example:
::
@@ -47,7 +50,7 @@ class TripletEvaluator(SentenceEvaluator):
results = triplet_evaluator(model)
'''
TripletEvaluator: Evaluating the model on the all-nli-dev dataset:
Accuracy Cosine Distance: 95.60%
Accuracy Cosine Similarity: 95.60%
'''
print(triplet_evaluator.primary_metric)
# => "all_nli_dev_cosine_accuracy"
@@ -60,13 +63,15 @@ def __init__(
anchors: list[str],
positives: list[str],
negatives: list[str],
main_distance_function: str | SimilarityFunction | None = None,
main_similarity_function: str | SimilarityFunction | None = None,
margin: float | dict[str, float] | None = None,
name: str = "",
batch_size: int = 16,
show_progress_bar: bool = False,
write_csv: bool = True,
truncate_dim: int | None = None,
similarity_fn_names: list[Literal["cosine", "dot", "euclidean", "manhattan"]] | None = None,
main_distance_function: str | SimilarityFunction | None = "deprecated",
):
"""
Initializes a TripletEvaluator object.
@@ -75,17 +80,22 @@ def __init__(
anchors (List[str]): Sentences to check similarity to. (e.g. a query)
positives (List[str]): List of positive sentences
negatives (List[str]): List of negative sentences
main_distance_function (Union[str, SimilarityFunction], optional):
The distance function to use. If not specified, use cosine similarity,
dot product, Euclidean, and Manhattan. Defaults to None.
main_similarity_function (Union[str, SimilarityFunction], optional):
The similarity function to use. If not specified, use cosine similarity,
dot product, Euclidean, and Manhattan similarity. Defaults to None.
margin (Union[float, Dict[str, float]], optional): Margins for various similarity metrics.
If a float is provided, it will be used as the margin for all similarity metrics.
If a dictionary is provided, the keys should be 'cosine', 'dot', 'manhattan', and 'euclidean'.
The value specifies the minimum margin by which the negative sample should be further from
the anchor than the positive sample. Defaults to None.
name (str): Name for the output. Defaults to "".
batch_size (int): Batch size used to compute embeddings. Defaults to 16.
show_progress_bar (bool): If true, prints a progress bar. Defaults to False.
write_csv (bool): Write results to a CSV file. Defaults to True.
truncate_dim (int, optional): The dimension to truncate sentence embeddings to.
`None` uses the model's current truncation dimension. Defaults to None.
similarity_fn_names (List[str], optional): List of similarity function names to evaluate.
If not specified, evaluate using the ``similarity_fn_name`` .
If not specified, evaluate using the ``model.similarity_fn_name``.
Defaults to None.
"""
super().__init__()
@@ -98,9 +108,32 @@ def __init__(
assert len(self.anchors) == len(self.positives)
assert len(self.anchors) == len(self.negatives)

self.main_distance_function = SimilarityFunction(main_distance_function) if main_distance_function else None
if main_distance_function != "deprecated" and main_similarity_function is None:
main_similarity_function = main_distance_function
logger.warning(
"The 'main_distance_function' parameter is deprecated. Please use 'main_similarity_function' instead. "
"'main_distance_function' will be removed in a future release."
)

self.main_similarity_function = (
SimilarityFunction(main_similarity_function) if main_similarity_function else None
)
self.similarity_fn_names = similarity_fn_names or []

if margin is None:
self.margin = {"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0}
elif isinstance(margin, (float, int)):
self.margin = {"cosine": margin, "dot": margin, "manhattan": margin, "euclidean": margin}
elif isinstance(margin, dict):
self.margin = {
**{"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0},
**margin,
}
else:
raise ValueError(
"`margin` should be a float or a dictionary with keys 'cosine', 'dot', 'manhattan', and 'euclidean'"
)

self.batch_size = batch_size
if show_progress_bar is None:
show_progress_bar = (
@@ -171,20 +204,20 @@ def __call__(

similarity_functions = {
"cosine": lambda anchors, positives, negatives: (
paired_cosine_distances(anchors, positives),
paired_cosine_distances(anchors, negatives),
pairwise_cos_sim(anchors, positives),
pairwise_cos_sim(anchors, negatives),
),
"dot": lambda anchors, positives, negatives: (
np.sum(anchors * positives, axis=-1),
np.sum(anchors * negatives, axis=-1),
pairwise_dot_score(anchors, positives),
pairwise_dot_score(anchors, negatives),
),
"manhattan": lambda anchors, positives, negatives: (
paired_manhattan_distances(anchors, positives),
paired_manhattan_distances(anchors, negatives),
pairwise_manhattan_sim(anchors, positives),
pairwise_manhattan_sim(anchors, negatives),
),
"euclidean": lambda anchors, positives, negatives: (
paired_euclidean_distances(anchors, positives),
paired_euclidean_distances(anchors, negatives),
pairwise_euclidean_sim(anchors, positives),
pairwise_euclidean_sim(anchors, negatives),
),
}

@@ -194,9 +227,9 @@ def __call__(
positive_scores, negative_scores = similarity_functions[fn_name](
embeddings_anchors, embeddings_positives, embeddings_negatives
)
accuracy = np.mean(positive_scores < negative_scores)
accuracy = (positive_scores > negative_scores + self.margin[fn_name]).float().mean().item()
metrics[f"{fn_name}_accuracy"] = accuracy
logger.info(f"Accuracy {fn_name.capitalize()} Distance:\t{accuracy:.2%}")
logger.info(f"Accuracy {fn_name.capitalize()} Similarity:\t{accuracy:.2%}")

if output_path is not None and self.write_csv:
csv_path = os.path.join(output_path, self.csv_file)
@@ -214,13 +247,13 @@ def __call__(
if len(self.similarity_fn_names) > 1:
metrics["max_accuracy"] = max(metrics.values())

if self.main_distance_function:
if self.main_similarity_function:
self.primary_metric = {
SimilarityFunction.COSINE: "cosine_accuracy",
SimilarityFunction.DOT_PRODUCT: "dot_accuracy",
SimilarityFunction.EUCLIDEAN: "euclidean_accuracy",
SimilarityFunction.MANHATTAN: "manhattan_accuracy",
}.get(self.main_distance_function)
}.get(self.main_similarity_function)
else:
if len(self.similarity_fn_names) > 1:
self.primary_metric = "max_accuracy"
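To illustrate the change above, here is a minimal usage sketch of the new margin parameter. The model checkpoint, sentences, and evaluator name are illustrative assumptions, not taken from this PR.

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator

# Any SentenceTransformer checkpoint works; this one is only an example.
model = SentenceTransformer("all-MiniLM-L6-v2")

anchors = ["A person on a horse jumps over a broken down airplane."]
positives = ["A person is outdoors, on a horse."]
negatives = ["A person is at a diner, ordering an omelette."]

# A float margin applies to every similarity function: a triplet only counts as
# correct if similarity(anchor, positive) > similarity(anchor, negative) + margin.
evaluator = TripletEvaluator(anchors, positives, negatives, margin=0.1, name="margin_demo")
results = evaluator(model)
print(results)  # e.g. {"margin_demo_cosine_accuracy": ...}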
36 changes: 36 additions & 0 deletions tests/evaluation/test_triplet_evaluator.py
@@ -0,0 +1,36 @@
"""
Tests the correct computation of evaluation scores from TripletEvaluator
"""

from __future__ import annotations

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator


def test_TripletEvaluator(stsb_bert_tiny_model_reused: SentenceTransformer) -> None:
"""Tests that the TripletEvaluator can be loaded & used"""
model = stsb_bert_tiny_model_reused
anchors = [
"A person on a horse jumps over a broken down airplane.",
"Children smiling and waving at camera",
"A boy is jumping on skateboard in the middle of a red bridge.",
]
positives = [
"A person is outdoors, on a horse.",
"There are children looking at the camera.",
"The boy does a skateboarding trick.",
]
negatives = [
"A person is at a diner, ordering an omelette.",
"The kids are frowning",
"The boy skates down the sidewalk.",
]
evaluator = TripletEvaluator(anchors, positives, negatives, name="all_nli_dev")
metrics = evaluator(model)
assert evaluator.primary_metric == "all_nli_dev_cosine_accuracy"
assert metrics[evaluator.primary_metric] == 1.0

evaluator_with_margin = TripletEvaluator(anchors, positives, negatives, margin=0.7, name="all_nli_dev")
metrics = evaluator_with_margin(model)
assert metrics[evaluator.primary_metric] == 0.0
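As a follow-up sketch beyond the test above, the dict form of margin assigns a different margin per similarity function (keys omitted here, such as "dot" and "manhattan", fall back to 0 per the __init__ logic), and main_similarity_function replaces the now-deprecated main_distance_function. The values and evaluator name are illustrative; anchors, positives, negatives, and model are reused from the earlier sketch.

# Per-metric margins; omitted keys default to 0.
evaluator = TripletEvaluator(
    anchors,
    positives,
    negatives,
    margin={"cosine": 0.1, "euclidean": 0.5},
    similarity_fn_names=["cosine", "euclidean"],
    main_similarity_function="cosine",  # replaces the deprecated main_distance_function
    name="per_metric_margins",
)
metrics = evaluator(model)
# With main_similarity_function="cosine", the primary metric is
# "per_metric_margins_cosine_accuracy".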