[training] Pass steps/epoch/output_path to Evaluator during training (#3066)

Both in fit_mixin and the trainer, plus fixes to the CSV output of 2 evaluators.
tomaarsen authored Nov 19, 2024
1 parent e156f38 commit 0f44583
Showing 7 changed files with 39 additions and 20 deletions.
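
To illustrate what this change enables at the call site, below is a minimal sketch of invoking an evaluator directly with the arguments that fit_mixin and the trainer now forward. It is not part of the commit: the model name, sentence pairs, and the "eval_output" directory are illustrative placeholders.

# Minimal sketch (not part of this commit): calling an evaluator with the
# arguments that fit_mixin and the trainer now forward during training.
# The model name, sentence pairs, and "eval_output" directory are illustrative.
import os

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

evaluator = EmbeddingSimilarityEvaluator(
    sentences1=["A cat sits on the mat", "The weather is sunny", "He plays guitar"],
    sentences2=["A cat is on the mat", "It is raining heavily", "She reads a book"],
    scores=[0.9, 0.1, 0.2],
    name="sts-dev",
)

# Mirror what the trainer now does: build an "eval" subdirectory, create it,
# and pass output_path/epoch/steps so the written CSV row is labeled with the step.
output_path = os.path.join("eval_output", "eval")
os.makedirs(output_path, exist_ok=True)
metrics = evaluator(model, output_path=output_path, epoch=0, steps=100)
print(metrics)  # e.g. Pearson/Spearman scores per similarity function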
@@ -225,9 +225,9 @@ def __call__(
                 steps,
             ]
             + [
-                metrics[f"{fn_name}_{m}"]
+                metrics[f"{metric}_{fn_name}"]
                 for fn_name in self.similarity_fn_names
-                for m in ["pearson", "spearman"]
+                for metric in ["pearson", "spearman"]
             ]
         )

12 changes: 6 additions & 6 deletions sentence_transformers/evaluation/NanoBEIREvaluator.py
@@ -326,20 +326,20 @@ def __call__(
             output_data = [epoch, steps]
             for name in self.score_function_names:
                 for k in self.accuracy_at_k:
-                    output_data.append(per_dataset_results[name]["accuracy@k"][k])
+                    output_data.append(agg_results[f"{name}_accuracy@{k}"])

                 for k in self.precision_recall_at_k:
-                    output_data.append(per_dataset_results[name]["precision@k"][k])
-                    output_data.append(per_dataset_results[name]["recall@k"][k])
+                    output_data.append(agg_results[f"{name}_precision@{k}"])
+                    output_data.append(agg_results[f"{name}_recall@{k}"])

                 for k in self.mrr_at_k:
-                    output_data.append(per_dataset_results[name]["mrr@k"][k])
+                    output_data.append(agg_results[f"{name}_mrr@{k}"])

                 for k in self.ndcg_at_k:
-                    output_data.append(per_dataset_results[name]["ndcg@k"][k])
+                    output_data.append(agg_results[f"{name}_ndcg@{k}"])

                 for k in self.map_at_k:
-                    output_data.append(per_dataset_results[name]["map@k"][k])
+                    output_data.append(agg_results[f"{name}_map@{k}"])

             fOut.write(",".join(map(str, output_data)))
             fOut.write("\n")
13 changes: 10 additions & 3 deletions sentence_transformers/fit_mixin.py
@@ -100,9 +100,14 @@ class EvaluatorCallback(TrainerCallback):
     The `.trainer` must be provided after the trainer has been created.
     """

-    def __init__(self, evaluator: SentenceEvaluator) -> None:
+    def __init__(self, evaluator: SentenceEvaluator, output_path: str | None = None) -> None:
         super().__init__()
         self.evaluator = evaluator
+        self.output_path = output_path
+        if self.output_path is not None:
+            self.output_path = os.path.join(self.output_path, "eval")
+            os.makedirs(self.output_path, exist_ok=True)
+
         self.metric_key_prefix = "eval"
         self.trainer = None

@@ -114,7 +119,9 @@ def on_epoch_end(
         model: SentenceTransformer,
         **kwargs,
     ) -> None:
-        evaluator_metrics = self.evaluator(model, epoch=state.epoch)
+        evaluator_metrics = self.evaluator(
+            model, output_path=self.output_path, epoch=state.epoch, steps=state.global_step
+        )
         if not isinstance(evaluator_metrics, dict):
             evaluator_metrics = {"evaluator": evaluator_metrics}

@@ -353,7 +360,7 @@ def _default_checkpoint_dir() -> str:
         # Create callbacks
         callbacks = []
         if evaluator is not None:
-            callbacks.append(EvaluatorCallback(evaluator))
+            callbacks.append(EvaluatorCallback(evaluator, output_path))
         if callback is not None:
             callbacks.append(OriginalCallback(callback, evaluator))

8 changes: 7 additions & 1 deletion sentence_transformers/trainer.py
@@ -478,7 +478,13 @@ def evaluation_loop(
             return output

         with nullcontext() if self.is_local_process_zero() else disable_logging(logging.INFO):
-            evaluator_metrics = self.evaluator(self.model)
+            output_path = self.args.output_dir
+            if output_path is not None:
+                output_path = os.path.join(output_path, "eval")
+                os.makedirs(output_path, exist_ok=True)
+            evaluator_metrics = self.evaluator(
+                self.model, output_path=output_path, epoch=self.state.epoch, steps=self.state.global_step
+            )
         if not isinstance(evaluator_metrics, dict):
             evaluator_metrics = {"evaluator": evaluator_metrics}

9 changes: 5 additions & 4 deletions tests/evaluation/test_information_retrieval_evaluator.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from pathlib import Path
 from unittest.mock import Mock, PropertyMock

 import pytest
@@ -63,7 +64,7 @@ def test_data():
     return queries, corpus, relevant_docs


-def test_simple(test_data):
+def test_simple(test_data, tmp_path: Path):
     queries, corpus, relevant_docs = test_data
     model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

@@ -78,7 +79,7 @@ def test_simple(test_data):
         ndcg_at_k=[3],
         map_at_k=[5],
     )
-    results = ir_evaluator(model)
+    results = ir_evaluator(model, output_path=str(tmp_path))
     expected_keys = [
         "test_cosine_accuracy@1",
         "test_cosine_accuracy@3",
@@ -93,7 +94,7 @@ def test_simple(test_data):
     assert set(results.keys()) == set(expected_keys)


-def test_metrices(test_data, mock_model):
+def test_metrices(test_data, mock_model, tmp_path: Path):
     queries, corpus, relevant_docs = test_data

     ir_evaluator = InformationRetrievalEvaluator(
@@ -107,7 +108,7 @@ def test_metrices(test_data, mock_model):
         ndcg_at_k=[3],
         map_at_k=[5],
     )
-    results = ir_evaluator(mock_model)
+    results = ir_evaluator(mock_model, output_path=str(tmp_path))
     # We expect test_cosine_precision@3 to be 0.4, since 6 out of 15 (5 queries * 3) are True Positives
     # We expect test_cosine_recall@1 to be 0.9; the average of 4 times a recall of 1 and once a recall of 0.5
     expected_results = {
5 changes: 3 additions & 2 deletions tests/evaluation/test_label_accuracy_evaluator.py
@@ -7,6 +7,7 @@
 import csv
 import gzip
 import os
+from pathlib import Path

 from torch.utils.data import DataLoader

@@ -19,7 +20,7 @@
 )


-def test_LabelAccuracyEvaluator(paraphrase_distilroberta_base_v1_model: SentenceTransformer) -> None:
+def test_LabelAccuracyEvaluator(paraphrase_distilroberta_base_v1_model: SentenceTransformer, tmp_path: Path) -> None:
     """Tests that the LabelAccuracyEvaluator can be loaded correctly"""
     model = paraphrase_distilroberta_base_v1_model
     nli_dataset_path = "datasets/AllNLI.tsv.gz"
@@ -45,6 +46,6 @@ def test_LabelAccuracyEvaluator(paraphrase_distilroberta_base_v1_model: Sentence

     dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=16)
     evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss)
-    metrics = evaluator(model)
+    metrics = evaluator(model, output_path=str(tmp_path))
     assert "accuracy" in metrics
     assert metrics["accuracy"] > 0.2
8 changes: 6 additions & 2 deletions tests/evaluation/test_paraphrase_mining_evaluator.py
@@ -4,13 +4,17 @@

 from __future__ import annotations

+from pathlib import Path
+
 from sentence_transformers import (
     SentenceTransformer,
     evaluation,
 )


-def test_ParaphraseMiningEvaluator(paraphrase_distilroberta_base_v1_model: SentenceTransformer) -> None:
+def test_ParaphraseMiningEvaluator(
+    paraphrase_distilroberta_base_v1_model: SentenceTransformer, tmp_path: Path
+) -> None:
     """Tests that the ParaphraseMiningEvaluator can be loaded"""
     model = paraphrase_distilroberta_base_v1_model
     sentences = {
@@ -20,5 +24,5 @@ def test_ParaphraseMiningEvaluator(paraphrase_distilroberta_base_v1_model: Sente
         3: "On the table the cat is",
     }
     data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0, 1), (2, 3)])
-    metrics = data_eval(model)
+    metrics = data_eval(model, output_path=str(tmp_path))
     assert metrics[data_eval.primary_metric] > 0.99
