Skip to content

Commit

Permalink
Generate: GenerationConfig throws an exception when generate args…
Browse files Browse the repository at this point in the history
… are passed (huggingface#27757)
  • Loading branch information
gante authored Nov 30, 2023
1 parent fe41647 commit 510270a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,24 @@ def validate(self, is_init=False):
f"({self.num_beams})."
)

# 5. check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
"stopping_criteria",
"prefix_allowed_tokens_fn",
"synced_gpus",
"assistant_model",
"streamer",
"negative_prompt_ids",
"negative_prompt_attention_mask",
)
for arg in generate_arguments:
if hasattr(self, arg):
raise ValueError(
f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to "
"`generate()` (or a pipeline) directly."
)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
Expand Down
28 changes: 28 additions & 0 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,34 @@ def test_kwarg_init(self):
self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value

def test_validate(self):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
# Case 1: A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig()
self.assertEqual(len(captured_warnings), 0)

# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(temperature=0.5)
self.assertEqual(len(captured_warnings), 1)

# Case 3: Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(num_return_sequences=2)

# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(logits_processor="foo")

# Case 5: Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0)

def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""

Expand Down

0 comments on commit 510270a

Please sign in to comment.