Skip to content

Commit

Permalink
Only include dataset_name if strictly needed, stricter tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Nov 6, 2024
1 parent 22fc64a commit 50e1613
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
5 changes: 4 additions & 1 deletion sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,7 @@ def maybe_add_prompts_or_dataset_name_column(
dataset_dict,
prompts=prompts,
include_prompt_lengths=include_prompt_lengths,
include_dataset_name=include_dataset_name,
)
return dataset_dict

Expand All @@ -1042,6 +1043,7 @@ def add_prompts_or_dataset_name_column(
prompts: dict[str, str] | str | None = None,
dataset_name: str | None = None,
include_prompt_lengths: bool = False,
include_dataset_name: bool = False,
) -> DatasetDict | Dataset | None:
# If we have DatasetDict, recurse
if isinstance(dataset_dict, (IterableDatasetDict, DatasetDict)):
Expand All @@ -1051,8 +1053,9 @@ def add_prompts_or_dataset_name_column(
dataset_dict[dataset_name] = self.add_prompts_or_dataset_name_column(
dataset_dict=dataset,
prompts=nested_prompts,
dataset_name=dataset_name,
dataset_name=dataset_name if include_dataset_name else None,
include_prompt_lengths=include_prompt_lengths,
include_dataset_name=include_dataset_name,
)
return dataset_dict

Expand Down
44 changes: 21 additions & 23 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,13 +480,24 @@ def compute_loss_tracker(model, inputs, **kwargs):
trainer.compute_loss = compute_loss_tracker
trainer.train()

if prompts and not pool_include_prompt:
# In this one edge case, the prompts won't be used because the datasets aren't dictionaries, so the prompts
# are seen as column names & ignored as they don't exist.
if not (
# In this one edge case, the prompts won't be used: the datasets aren't dictionaries, so the prompt keys are
# treated as column names and ignored because no such columns exist.
if (
prompts
and not pool_include_prompt
and not (
prompts == {"stsb-1": "Prompt 1: ", "stsb-2": "Prompt 2: "} and (train_dict, eval_dict) == (False, False)
):
assert "prompt_length" in tracked_forward_keys
)
):
assert "prompt_length" in tracked_forward_keys
else:
assert "prompt_length" not in tracked_forward_keys

# We only need the dataset_name if the loss requires it
if loss_dict:
assert "dataset_name" in datacollator_keys
else:
assert "dataset_name" not in datacollator_keys

if prompts is None:
if (train_dict, eval_dict) == (False, False):
Expand All @@ -498,9 +509,6 @@ def compute_loss_tracker(model, inputs, **kwargs):
elif (train_dict, eval_dict) == (True, True):
expected = all_train | all_eval

if loss_dict:
assert "dataset_name" in datacollator_keys

elif prompts == "Prompt: ":
if (train_dict, eval_dict) == (False, False):
expected = {prompts + sample for sample in all_train_1} | {prompts + sample for sample in all_eval_1}
Expand All @@ -514,10 +522,6 @@ def compute_loss_tracker(model, inputs, **kwargs):
if not pool_include_prompt:
expected.add(prompts)

# We only need the dataset_name if the loss requires it
if loss_dict:
assert "dataset_name" in datacollator_keys

elif prompts == {"stsb-1": "Prompt 1: ", "stsb-2": "Prompt 2: "}:
# If we don't have dataset dictionaries, the prompts will be seen as column names
if (train_dict, eval_dict) == (False, False):
Expand All @@ -542,12 +546,10 @@ def compute_loss_tracker(model, inputs, **kwargs):
| {prompts["stsb-2"] + sample for sample in all_eval_2}
)

if (train_dict, eval_dict) != (False, False):
if not pool_include_prompt:
expected.update(set(prompts.values()))

# We need the dataset_name because the prompts need it - except if the datasets aren't dictionaries
assert "dataset_name" in datacollator_keys
# If not pool_include_prompt, we need to collect prompt lengths, so the prompts themselves must be added to the
# expected set - except if the datasets aren't dictionaries
if (train_dict, eval_dict) != (False, False) and not pool_include_prompt:
expected.update(set(prompts.values()))

elif prompts == {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "}:
if (train_dict, eval_dict) == (False, False):
Expand Down Expand Up @@ -590,10 +592,6 @@ def compute_loss_tracker(model, inputs, **kwargs):
if not pool_include_prompt:
expected.update(set(prompts.values()))

# We only need the dataset_name if the loss requires it, because the prompts don't need it
if loss_dict:
assert "dataset_name" in datacollator_keys

elif prompts == {
"stsb-1": {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "},
"stsb-2": {"sentence1": "Prompt 3: ", "sentence2": "Prompt 4: "},
Expand Down

0 comments on commit 50e1613

Please sign in to comment.