Move prompts to SentenceTransformersArguments
tomaarsen committed Nov 6, 2024
1 parent e3b334c commit 22fc64a
Showing 3 changed files with 35 additions and 27 deletions.
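For context, here is a minimal before/after sketch of what this move means for callers. The model, dataset, loss, and prompt string below are placeholders for illustration and are not taken from this commit.

from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Placeholder model, dataset, and loss, purely for illustration.
model = SentenceTransformer("all-MiniLM-L6-v2")
train_dataset = Dataset.from_dict({
    "anchor": ["What is the capital of France?"],
    "positive": ["Paris is the capital of France."],
})
loss = MultipleNegativesRankingLoss(model)

# Before this commit, prompts were passed to the trainer itself:
# trainer = SentenceTransformerTrainer(
#     model=model, train_dataset=train_dataset, loss=loss, prompts="query: "
# )

# After this commit, prompts live on the training arguments instead:
args = SentenceTransformerTrainingArguments(output_dir="output", prompts="query: ")
trainer = SentenceTransformerTrainer(
    model=model, args=args, train_dataset=train_dataset, loss=loss
)

Because prompts now travel with the rest of the training configuration, dataset preprocessing and evaluate() read them from args.prompts rather than from trainer state, as the diff below shows.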
42 changes: 18 additions & 24 deletions sentence_transformers/trainer.py
@@ -84,19 +84,6 @@ class SentenceTransformerTrainer(Trainer):
or a dictionary mapping dataset names to functions that return a loss class instance given a model.
In practice, the latter two are primarily used for hyper-parameter optimization. Will default to
:class:`~sentence_transformers.losses.CoSENTLoss` if no ``loss`` is provided.
prompts (Union[Dict[str, Dict[str, str]], Dict[str, str], str], *optional*):
The prompts to use for each column in the training, evaluation and test datasets. Four formats are accepted:
1. `str`: A single prompt to use for all columns in the datasets, regardless of whether the training/evaluation/test
datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
2. `Dict[str, str]`: A dictionary mapping column names to prompts, regardless of whether the training/evaluation/test
datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
3. `Dict[str, str]`: A dictionary mapping dataset names to prompts. This should only be used if your training/evaluation/test
datasets are a :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.
4. `Dict[str, Dict[str, str]]`: A dictionary mapping dataset names to dictionaries mapping column names to
prompts. This should only be used if your training/evaluation/test datasets are a
:class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.
evaluator (Union[:class:`~sentence_transformers.evaluation.SentenceEvaluator`,\
List[:class:`~sentence_transformers.evaluation.SentenceEvaluator`]], *optional*):
The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
@@ -141,7 +128,6 @@ def __init__(
| Callable[[SentenceTransformer], torch.nn.Module]
| dict[str, Callable[[SentenceTransformer], torch.nn.Module]]
| None = None,
prompts: dict[str, dict[str, str]] | dict[str, str] | str | None = None,
evaluator: SentenceEvaluator | list[SentenceEvaluator] | None = None,
data_collator: DataCollator | None = None,
tokenizer: PreTrainedTokenizerBase | Callable | None = None,
@@ -248,7 +234,6 @@ def __init__(
# to avoid having to specify it in the data collator or model's forward
self.can_return_loss = True

self.prompts = prompts
self._prompt_length_mapping = {}

self.model: SentenceTransformer
@@ -285,11 +270,13 @@ def __init__(
self.evaluator = evaluator

if self.train_dataset is not None:
self.validate_column_names(self.train_dataset, dataset_name="train")
self.train_dataset = self.maybe_add_prompts_or_dataset_name_column(train_dataset)
self.train_dataset = self.maybe_add_prompts_or_dataset_name_column(
train_dataset, args.prompts, dataset_name="train"
)
if self.eval_dataset is not None:
self.validate_column_names(self.eval_dataset, dataset_name="eval")
self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column(eval_dataset)
self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column(
eval_dataset, args.prompts, dataset_name="eval"
)

# Add a callback responsible for automatically tracking data required for the automatic model card generation
model_card_callback = ModelCardCallback(self, default_args_dict)
@@ -438,8 +425,9 @@ def evaluate(
metric_key_prefix: str = "eval",
) -> dict[str, float]:
if eval_dataset:
self.validate_column_names(eval_dataset, dataset_name="eval")
eval_dataset = self.maybe_add_prompts_or_dataset_name_column(eval_dataset)
eval_dataset = self.maybe_add_prompts_or_dataset_name_column(
eval_dataset, self.args.prompts, dataset_name="eval"
)
else:
eval_dataset = self.eval_dataset
return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
@@ -995,7 +983,10 @@ def add_prompts_or_dataset_name_transform(
return batch

def maybe_add_prompts_or_dataset_name_column(
self, dataset_dict: DatasetDict | Dataset | None
self,
dataset_dict: DatasetDict | Dataset | None,
prompts: dict[str, dict[str, str]] | dict[str, str] | str | None = None,
dataset_name: str | None = None,
) -> DatasetDict | Dataset | None:
"""
Maybe add prompts or dataset names to the dataset. We add the dataset_name column to the dataset if:
@@ -1032,12 +1023,15 @@ def maybe_add_prompts_or_dataset_name_column(
if hasattr(dataset_dict, "_sentence_transformers_preprocessed"):
return dataset_dict

# Ensure that there's no "dataset_name"/"return_loss" columns in the unprocessed datasets
self.validate_column_names(dataset_dict, dataset_name=dataset_name)

# Only add if 1) we have prompts or 2) we need the dataset name for the loss dictionary
if self.prompts or include_dataset_name:
if prompts or include_dataset_name:
include_prompt_lengths = self._include_prompt_length()
dataset_dict = self.add_prompts_or_dataset_name_column(
dataset_dict,
prompts=self.prompts,
prompts=prompts,
include_prompt_lengths=include_prompt_lengths,
)
return dataset_dict
14 changes: 14 additions & 0 deletions sentence_transformers/training_args.py
@@ -149,6 +149,19 @@ class SentenceTransformerTrainingArguments(TransformersTrainingArguments):
Args:
output_dir (`str`):
The output directory where the model checkpoints will be written.
prompts (Union[Dict[str, Dict[str, str]], Dict[str, str], str], *optional*):
The prompts to use for each column in the training, evaluation and test datasets. Four formats are accepted:
1. `str`: A single prompt to use for all columns in the datasets, regardless of whether the training/evaluation/test
datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
2. `Dict[str, str]`: A dictionary mapping column names to prompts, regardless of whether the training/evaluation/test
datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
3. `Dict[str, str]`: A dictionary mapping dataset names to prompts. This should only be used if your training/evaluation/test
datasets are a :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.
4. `Dict[str, Dict[str, str]]`: A dictionary mapping dataset names to dictionaries mapping column names to
prompts. This should only be used if your training/evaluation/test datasets are a
:class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.
batch_sampler (Union[:class:`~sentence_transformers.training_args.BatchSamplers`, `str`], *optional*):
The batch sampler to use. See :class:`~sentence_transformers.training_args.BatchSamplers` for valid options.
Defaults to ``BatchSamplers.BATCH_SAMPLER``.
@@ -157,6 +170,7 @@ class SentenceTransformerTrainingArguments(TransformersTrainingArguments):
for valid options. Defaults to ``MultiDatasetBatchSamplers.PROPORTIONAL``.
"""

prompts: dict[str, dict[str, str]] | dict[str, str] | str | None = None
batch_sampler: BatchSamplers | str = field(
default=BatchSamplers.BATCH_SAMPLER, metadata={"help": "The batch sampler to use."}
)
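To make the four accepted formats from the new docstring concrete, here are illustrative values for the prompts training argument; the column names ("query", "answer", "sentence1", "sentence2") and dataset names ("stsb", "nq") are hypothetical.

# 1. One prompt for every column in every dataset:
prompts = "Represent this sentence for retrieval: "

# 2. One prompt per column name:
prompts = {"query": "query: ", "answer": "document: "}

# 3. One prompt per dataset name (datasets given as a DatasetDict or dict of Dataset):
prompts = {"stsb": "Represent this sentence: ", "nq": "query: "}

# 4. One prompt per dataset name and column name:
prompts = {
    "stsb": {"sentence1": "sts sentence: ", "sentence2": "sts sentence: "},
    "nq": {"query": "query: ", "answer": "document: "},
}

Any of these values can then be passed as SentenceTransformerTrainingArguments(output_dir=..., prompts=prompts), as the test change below demonstrates.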
6 changes: 3 additions & 3 deletions tests/test_trainer.py
@@ -450,6 +450,7 @@ def upper_transform(batch):
with tempfile.TemporaryDirectory() as temp_dir:
args = SentenceTransformerTrainingArguments(
output_dir=str(temp_dir),
prompts=prompts,
max_steps=2,
eval_steps=2,
eval_strategy="steps",
@@ -464,17 +465,16 @@
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
prompts=prompts,
)
if not isinstance(context, nullcontext):
return

datacollator_keys = set()
old_compute_loss = trainer.compute_loss

def compute_loss_tracker(model, inputs, return_outputs=False):
def compute_loss_tracker(model, inputs, **kwargs):
datacollator_keys.update(set(inputs.keys()))
loss = old_compute_loss(model, inputs, return_outputs)
loss = old_compute_loss(model, inputs, **kwargs)
return loss

trainer.compute_loss = compute_loss_tracker
