From 22fc64a0d554110e9f2b96b0b68ffe644e0ef383 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Wed, 6 Nov 2024 12:47:29 +0100
Subject: [PATCH] Move prompts to SentenceTransformerTrainingArguments

---
 sentence_transformers/trainer.py       | 42 +++++++++++---------
 sentence_transformers/training_args.py | 14 +++++++++
 tests/test_trainer.py                  |  6 ++--
 3 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py
index c815620ad..a4e9a2602 100644
--- a/sentence_transformers/trainer.py
+++ b/sentence_transformers/trainer.py
@@ -84,19 +84,6 @@ class SentenceTransformerTrainer(Trainer):
             or a dictionary mapping dataset names to functions that return a loss class instance given a model.
             In practice, the latter two are primarily used for hyper-parameter optimization. Will default to
             :class:`~sentence_transformers.losses.CoSENTLoss` if no ``loss`` is provided.
-        prompts (Union[Dict[str, Dict[str, str]], Dict[str, str], str], *optional*):
-            The prompts to use for each column in the training, evaluation and test datasets. Four formats are accepted:
-
-            1. `str`: A single prompt to use for all columns in the datasets, regardless of whether the training/evaluation/test
-               datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
-            2. `Dict[str, str]`: A dictionary mapping column names to prompts, regardless of whether the training/evaluation/test
-               datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
-            3. `Dict[str, str]`: A dictionary mapping dataset names to prompts. This should only be used if your training/evaluation/test
-               datasets are a :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.
-            4. `Dict[str, Dict[str, str]]`: A dictionary mapping dataset names to dictionaries mapping column names to
-               prompts. This should only be used if your training/evaluation/test datasets are a
-               :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.
-
         evaluator (Union[:class:`~sentence_transformers.evaluation.SentenceEvaluator`,\
             List[:class:`~sentence_transformers.evaluation.SentenceEvaluator`]], *optional*):
             The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
@@ -141,7 +128,6 @@ def __init__(
         | Callable[[SentenceTransformer], torch.nn.Module]
         | dict[str, Callable[[SentenceTransformer], torch.nn.Module]]
         | None = None,
-        prompts: dict[str, dict[str, str]] | dict[str, str] | str | None = None,
         evaluator: SentenceEvaluator | list[SentenceEvaluator] | None = None,
         data_collator: DataCollator | None = None,
         tokenizer: PreTrainedTokenizerBase | Callable | None = None,
@@ -248,7 +234,6 @@ def __init__(
         # to avoid having to specify it in the data collator or model's forward
         self.can_return_loss = True
 
-        self.prompts = prompts
         self._prompt_length_mapping = {}
 
         self.model: SentenceTransformer
@@ -285,11 +270,13 @@ def __init__(
         self.evaluator = evaluator
 
         if self.train_dataset is not None:
-            self.validate_column_names(self.train_dataset, dataset_name="train")
-            self.train_dataset = self.maybe_add_prompts_or_dataset_name_column(train_dataset)
+            self.train_dataset = self.maybe_add_prompts_or_dataset_name_column(
+                train_dataset, args.prompts, dataset_name="train"
+            )
         if self.eval_dataset is not None:
-            self.validate_column_names(self.eval_dataset, dataset_name="eval")
-            self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column(eval_dataset)
+            self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column(
+                eval_dataset, args.prompts, dataset_name="eval"
+            )
 
         # Add a callback responsible for automatically tracking data required for the automatic model card generation
         model_card_callback = ModelCardCallback(self, default_args_dict)
@@ -438,8 +425,9 @@ def evaluate(
         metric_key_prefix: str = "eval",
     ) -> dict[str, float]:
         if eval_dataset:
-            self.validate_column_names(eval_dataset, dataset_name="eval")
-            eval_dataset = self.maybe_add_prompts_or_dataset_name_column(eval_dataset)
+            eval_dataset = self.maybe_add_prompts_or_dataset_name_column(
+                eval_dataset, self.args.prompts, dataset_name="eval"
+            )
         else:
             eval_dataset = self.eval_dataset
         return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
@@ -995,7 +983,10 @@ def add_prompts_or_dataset_name_transform(
         return batch
 
     def maybe_add_prompts_or_dataset_name_column(
-        self, dataset_dict: DatasetDict | Dataset | None
+        self,
+        dataset_dict: DatasetDict | Dataset | None,
+        prompts: dict[str, dict[str, str]] | dict[str, str] | str | None = None,
+        dataset_name: str | None = None,
     ) -> DatasetDict | Dataset | None:
         """
         Maybe add prompts or dataset names to the dataset. We add the dataset_name column to the dataset if:
@@ -1032,12 +1023,15 @@ def maybe_add_prompts_or_dataset_name_column(
         if hasattr(dataset_dict, "_sentence_transformers_preprocessed"):
             return dataset_dict
 
+        # Ensure that there's no "dataset_name"/"return_loss" columns in the unprocessed datasets
+        self.validate_column_names(dataset_dict, dataset_name=dataset_name)
+
         # Only add if 1) we have prompts or 2) we need the dataset name for the loss dictionary
-        if self.prompts or include_dataset_name:
+        if prompts or include_dataset_name:
             include_prompt_lengths = self._include_prompt_length()
             dataset_dict = self.add_prompts_or_dataset_name_column(
                 dataset_dict,
-                prompts=self.prompts,
+                prompts=prompts,
                 include_prompt_lengths=include_prompt_lengths,
             )
         return dataset_dict
diff --git a/sentence_transformers/training_args.py b/sentence_transformers/training_args.py
index ca1e15da8..c783892b2 100644
--- a/sentence_transformers/training_args.py
+++ b/sentence_transformers/training_args.py
@@ -149,6 +149,19 @@ class SentenceTransformerTrainingArguments(TransformersTrainingArguments):
     Args:
         output_dir (`str`):
             The output directory where the model checkpoints will be written.
+        prompts (Union[Dict[str, Dict[str, str]], Dict[str, str], str], *optional*):
+            The prompts to use for each column in the training, evaluation and test datasets. Four formats are accepted:
+
+            1. `str`: A single prompt to use for all columns in the datasets, regardless of whether the training/evaluation/test
+               datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
+            2. `Dict[str, str]`: A dictionary mapping column names to prompts, regardless of whether the training/evaluation/test
+               datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
+            3. `Dict[str, str]`: A dictionary mapping dataset names to prompts. This should only be used if your training/evaluation/test
+               datasets are a :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.
+            4. `Dict[str, Dict[str, str]]`: A dictionary mapping dataset names to dictionaries mapping column names to
+               prompts. This should only be used if your training/evaluation/test datasets are a
+               :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.
+
         batch_sampler (Union[:class:`~sentence_transformers.training_args.BatchSamplers`, `str`], *optional*):
             The batch sampler to use. See :class:`~sentence_transformers.training_args.BatchSamplers` for valid options.
             Defaults to ``BatchSamplers.BATCH_SAMPLER``.
@@ -157,6 +170,7 @@ class SentenceTransformerTrainingArguments(TransformersTrainingArguments):
             for valid options. Defaults to ``MultiDatasetBatchSamplers.PROPORTIONAL``.
""" + prompts: dict[str, dict[str, str]] | dict[str, str] | str | None = None batch_sampler: BatchSamplers | str = field( default=BatchSamplers.BATCH_SAMPLER, metadata={"help": "The batch sampler to use."} ) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 74e29945f..31ca93228 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -450,6 +450,7 @@ def upper_transform(batch): with tempfile.TemporaryDirectory() as temp_dir: args = SentenceTransformerTrainingArguments( output_dir=str(temp_dir), + prompts=prompts, max_steps=2, eval_steps=2, eval_strategy="steps", @@ -464,7 +465,6 @@ def upper_transform(batch): train_dataset=train_dataset, eval_dataset=eval_dataset, loss=loss, - prompts=prompts, ) if not isinstance(context, nullcontext): return @@ -472,9 +472,9 @@ def upper_transform(batch): datacollator_keys = set() old_compute_loss = trainer.compute_loss - def compute_loss_tracker(model, inputs, return_outputs=False): + def compute_loss_tracker(model, inputs, **kwargs): datacollator_keys.update(set(inputs.keys())) - loss = old_compute_loss(model, inputs, return_outputs) + loss = old_compute_loss(model, inputs, **kwargs) return loss trainer.compute_loss = compute_loss_tracker