From 0f445834d07221c6f050d62e9e060fe99ead5c36 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Tue, 19 Nov 2024 10:49:54 +0100
Subject: [PATCH] [`training`] Pass `steps`/`epoch`/`output_path` to Evaluator during training (#3066)

Both in fit_mixin and the trainer
Plus fixing 2 evaluators
---
 .../evaluation/EmbeddingSimilarityEvaluator.py     |  4 ++--
 .../evaluation/NanoBEIREvaluator.py                | 12 ++++++------
 sentence_transformers/fit_mixin.py                 | 13 ++++++++++---
 sentence_transformers/trainer.py                   |  8 +++++++-
 .../test_information_retrieval_evaluator.py        |  9 +++++----
 tests/evaluation/test_label_accuracy_evaluator.py  |  5 +++--
 .../evaluation/test_paraphrase_mining_evaluator.py |  8 ++++++--
 7 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py b/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
index 9b271ac7d..459ead3cc 100644
--- a/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
+++ b/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
@@ -225,9 +225,9 @@ def __call__(
                         steps,
                     ]
                     + [
-                        metrics[f"{fn_name}_{m}"]
+                        metrics[f"{metric}_{fn_name}"]
                         for fn_name in self.similarity_fn_names
-                        for m in ["pearson", "spearman"]
+                        for metric in ["pearson", "spearman"]
                     ]
                 )
diff --git a/sentence_transformers/evaluation/NanoBEIREvaluator.py b/sentence_transformers/evaluation/NanoBEIREvaluator.py
index 9b5669cb6..ec29ad094 100644
--- a/sentence_transformers/evaluation/NanoBEIREvaluator.py
+++ b/sentence_transformers/evaluation/NanoBEIREvaluator.py
@@ -326,20 +326,20 @@ def __call__(
             output_data = [epoch, steps]
             for name in self.score_function_names:
                 for k in self.accuracy_at_k:
-                    output_data.append(per_dataset_results[name]["accuracy@k"][k])
+                    output_data.append(agg_results[f"{name}_accuracy@{k}"])
 
                 for k in self.precision_recall_at_k:
-                    output_data.append(per_dataset_results[name]["precision@k"][k])
-                    output_data.append(per_dataset_results[name]["recall@k"][k])
+                    output_data.append(agg_results[f"{name}_precision@{k}"])
+                    output_data.append(agg_results[f"{name}_recall@{k}"])
 
                 for k in self.mrr_at_k:
-                    output_data.append(per_dataset_results[name]["mrr@k"][k])
+                    output_data.append(agg_results[f"{name}_mrr@{k}"])
 
                 for k in self.ndcg_at_k:
-                    output_data.append(per_dataset_results[name]["ndcg@k"][k])
+                    output_data.append(agg_results[f"{name}_ndcg@{k}"])
 
                 for k in self.map_at_k:
-                    output_data.append(per_dataset_results[name]["map@k"][k])
+                    output_data.append(agg_results[f"{name}_map@{k}"])
 
             fOut.write(",".join(map(str, output_data)))
             fOut.write("\n")
diff --git a/sentence_transformers/fit_mixin.py b/sentence_transformers/fit_mixin.py
index e4cc4b046..230251511 100644
--- a/sentence_transformers/fit_mixin.py
+++ b/sentence_transformers/fit_mixin.py
@@ -100,9 +100,14 @@ class EvaluatorCallback(TrainerCallback):
     The `.trainer` must be provided after the trainer has been created.
""" - def __init__(self, evaluator: SentenceEvaluator) -> None: + def __init__(self, evaluator: SentenceEvaluator, output_path: str | None = None) -> None: super().__init__() self.evaluator = evaluator + self.output_path = output_path + if self.output_path is not None: + self.output_path = os.path.join(self.output_path, "eval") + os.makedirs(self.output_path, exist_ok=True) + self.metric_key_prefix = "eval" self.trainer = None @@ -114,7 +119,9 @@ def on_epoch_end( model: SentenceTransformer, **kwargs, ) -> None: - evaluator_metrics = self.evaluator(model, epoch=state.epoch) + evaluator_metrics = self.evaluator( + model, output_path=self.output_path, epoch=state.epoch, steps=state.global_step + ) if not isinstance(evaluator_metrics, dict): evaluator_metrics = {"evaluator": evaluator_metrics} @@ -353,7 +360,7 @@ def _default_checkpoint_dir() -> str: # Create callbacks callbacks = [] if evaluator is not None: - callbacks.append(EvaluatorCallback(evaluator)) + callbacks.append(EvaluatorCallback(evaluator, output_path)) if callback is not None: callbacks.append(OriginalCallback(callback, evaluator)) diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index ef36d922f..182908656 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -478,7 +478,13 @@ def evaluation_loop( return output with nullcontext() if self.is_local_process_zero() else disable_logging(logging.INFO): - evaluator_metrics = self.evaluator(self.model) + output_path = self.args.output_dir + if output_path is not None: + output_path = os.path.join(output_path, "eval") + os.makedirs(output_path, exist_ok=True) + evaluator_metrics = self.evaluator( + self.model, output_path=output_path, epoch=self.state.epoch, steps=self.state.global_step + ) if not isinstance(evaluator_metrics, dict): evaluator_metrics = {"evaluator": evaluator_metrics} diff --git a/tests/evaluation/test_information_retrieval_evaluator.py b/tests/evaluation/test_information_retrieval_evaluator.py index 77745e033..b875313d0 100644 --- a/tests/evaluation/test_information_retrieval_evaluator.py +++ b/tests/evaluation/test_information_retrieval_evaluator.py @@ -1,5 +1,6 @@ from __future__ import annotations +from pathlib import Path from unittest.mock import Mock, PropertyMock import pytest @@ -63,7 +64,7 @@ def test_data(): return queries, corpus, relevant_docs -def test_simple(test_data): +def test_simple(test_data, tmp_path: Path): queries, corpus, relevant_docs = test_data model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") @@ -78,7 +79,7 @@ def test_simple(test_data): ndcg_at_k=[3], map_at_k=[5], ) - results = ir_evaluator(model) + results = ir_evaluator(model, output_path=str(tmp_path)) expected_keys = [ "test_cosine_accuracy@1", "test_cosine_accuracy@3", @@ -93,7 +94,7 @@ def test_simple(test_data): assert set(results.keys()) == set(expected_keys) -def test_metrices(test_data, mock_model): +def test_metrices(test_data, mock_model, tmp_path: Path): queries, corpus, relevant_docs = test_data ir_evaluator = InformationRetrievalEvaluator( @@ -107,7 +108,7 @@ def test_metrices(test_data, mock_model): ndcg_at_k=[3], map_at_k=[5], ) - results = ir_evaluator(mock_model) + results = ir_evaluator(mock_model, output_path=str(tmp_path)) # We expect test_cosine_precision@3 to be 0.4, since 6 out of 15 (5 queries * 3) are True Positives # We expect test_cosine_recall@1 to be 0.9; the average of 4 times a recall of 1 and once a recall of 0.5 expected_results = { diff --git 
a/tests/evaluation/test_label_accuracy_evaluator.py b/tests/evaluation/test_label_accuracy_evaluator.py index d753c7dcf..b431829ac 100644 --- a/tests/evaluation/test_label_accuracy_evaluator.py +++ b/tests/evaluation/test_label_accuracy_evaluator.py @@ -7,6 +7,7 @@ import csv import gzip import os +from pathlib import Path from torch.utils.data import DataLoader @@ -19,7 +20,7 @@ ) -def test_LabelAccuracyEvaluator(paraphrase_distilroberta_base_v1_model: SentenceTransformer) -> None: +def test_LabelAccuracyEvaluator(paraphrase_distilroberta_base_v1_model: SentenceTransformer, tmp_path: Path) -> None: """Tests that the LabelAccuracyEvaluator can be loaded correctly""" model = paraphrase_distilroberta_base_v1_model nli_dataset_path = "datasets/AllNLI.tsv.gz" @@ -45,6 +46,6 @@ def test_LabelAccuracyEvaluator(paraphrase_distilroberta_base_v1_model: Sentence dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=16) evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss) - metrics = evaluator(model) + metrics = evaluator(model, output_path=str(tmp_path)) assert "accuracy" in metrics assert metrics["accuracy"] > 0.2 diff --git a/tests/evaluation/test_paraphrase_mining_evaluator.py b/tests/evaluation/test_paraphrase_mining_evaluator.py index bc14317c3..3941c6074 100644 --- a/tests/evaluation/test_paraphrase_mining_evaluator.py +++ b/tests/evaluation/test_paraphrase_mining_evaluator.py @@ -4,13 +4,17 @@ from __future__ import annotations +from pathlib import Path + from sentence_transformers import ( SentenceTransformer, evaluation, ) -def test_ParaphraseMiningEvaluator(paraphrase_distilroberta_base_v1_model: SentenceTransformer) -> None: +def test_ParaphraseMiningEvaluator( + paraphrase_distilroberta_base_v1_model: SentenceTransformer, tmp_path: Path +) -> None: """Tests that the ParaphraseMiningEvaluator can be loaded""" model = paraphrase_distilroberta_base_v1_model sentences = { @@ -20,5 +24,5 @@ def test_ParaphraseMiningEvaluator(paraphrase_distilroberta_base_v1_model: Sente 3: "On the table the cat is", } data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0, 1), (2, 3)]) - metrics = data_eval(model) + metrics = data_eval(model, output_path=str(tmp_path)) assert metrics[data_eval.primary_metric] > 0.99
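
Note (not part of the patch): a minimal sketch of what this change enables during training. After the fix, the evaluator handed to SentenceTransformerTrainer is called with output_path (resolved to <output_dir>/eval), epoch, and steps, so CSV-writing evaluators log per-step results automatically. The model, dataset, output_dir, and evaluator name below are illustrative placeholders; eval_strategy may be named evaluation_strategy on older transformers versions.

from datasets import load_dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

# Tiny test model also used in the patched tests; any SentenceTransformer works here.
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

# Placeholder data: STSb sentence pairs with 0..1 similarity scores.
train_dataset = load_dataset("sentence-transformers/stsb", split="train")
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")

dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    scores=eval_dataset["score"],
    name="stsb-dev",
)

args = SentenceTransformerTrainingArguments(
    output_dir="output/stsb-demo",  # placeholder; evaluator CSVs land in output/stsb-demo/eval
    num_train_epochs=1,
    eval_strategy="steps",
    eval_steps=100,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=losses.CosineSimilarityLoss(model),
    evaluator=dev_evaluator,
)
trainer.train()

# With this patch, each evaluation call receives output_path="output/stsb-demo/eval",
# epoch=state.epoch, and steps=state.global_step, so the evaluator's results CSV is
# written there with the epoch/step columns filled in instead of being skipped.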