remove unnecessary max_new_tokens
yonigozlan committed Oct 24, 2024
1 parent 6d44f3c commit d739c0a
Showing 2 changed files with 7 additions and 15 deletions.
8 changes: 0 additions & 8 deletions src/transformers/pipelines/image_text_to_text.py
@@ -208,14 +208,6 @@ def _sanitize_parameters(
                     " please use only one"
                 )
             forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
-        else:
-            if "generate_kwargs" not in forward_kwargs:
-                forward_kwargs["generate_kwargs"] = {}
-            if "max_new_tokens" not in forward_kwargs["generate_kwargs"]:
-                logger.warning_once(
-                    "The `max_new_tokens` parameter is not set. By default, the model will generate up to 20 tokens."
-                )
-                forward_kwargs["generate_kwargs"]["max_new_tokens"] = 20
 
         if return_full_text is not None and return_type is None:
             if return_tensors is not None:
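With the `else` branch gone, the pipeline no longer silently injects `max_new_tokens=20` (nor warns about it) when the caller omits the argument; generation length now falls back to the model's own `generation_config`. A minimal sketch of the resulting call pattern, assuming a placeholder checkpoint (not one named in this commit):

from transformers import pipeline

# Placeholder model id, for illustration only.
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")

image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
text = "<image> What this is? Assistant: This is"

# After this commit, omitting max_new_tokens defers to the model's
# generation_config instead of a pipeline-side default of 20.
outputs = pipe(image, text=text)

# Passing it explicitly still works; defining it both here and inside
# generate_kwargs still raises a ValueError (see the surviving branch above).
outputs = pipe(image, text=text, max_new_tokens=20)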
14 changes: 7 additions & 7 deletions tests/pipelines/test_pipelines_image_text_to_text.py
@@ -71,7 +71,7 @@ def test_small_model_pt_token(self):
         image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
         text = "<image> What this is? Assistant: This is"
 
-        outputs = pipe(image, text=text, max_new_tokens=20)
+        outputs = pipe(image, text=text)
         self.assertEqual(
             outputs,
             [
@@ -82,7 +82,7 @@ def test_small_model_pt_token(self):
             ],
         )
 
-        outputs = pipe([image, image], text=[text, text], max_new_tokens=20)
+        outputs = pipe([image, image], text=[text, text])
         self.assertEqual(
             outputs,
             [
@@ -103,8 +103,8 @@ def test_consistent_batching_behaviour(self):
         image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
         prompt = "a photo of"
 
-        outputs = pipe([image, image], text=[prompt, prompt], max_new_tokens=20)
-        outputs_batched = pipe([image, image], text=[prompt, prompt], max_new_tokens=20, batch_size=2)
+        outputs = pipe([image, image], text=[prompt, prompt])
+        outputs_batched = pipe([image, image], text=[prompt, prompt], batch_size=2)
         self.assertEqual(outputs, outputs_batched)
 
     @slow
@@ -123,7 +123,7 @@ def test_model_pt_chat_template(self):
                 ],
             }
         ]
-        outputs = pipe([image_ny, image_chicago], text=messages, max_new_tokens=20)
+        outputs = pipe([image_ny, image_chicago], text=messages)
         self.assertEqual(
             outputs,
             [
@@ -178,7 +178,7 @@ def test_model_pt_chat_template_continue_final_message(self):
                 ],
             },
         ]
-        outputs = pipe(text=messages, max_new_tokens=20)
+        outputs = pipe(text=messages)
         self.assertEqual(
             outputs,
             [
@@ -237,7 +237,7 @@ def test_model_pt_chat_template_new_text(self):
                 ],
             }
         ]
-        outputs = pipe(text=messages, max_new_tokens=20, return_full_text=False)
+        outputs = pipe(text=messages, return_full_text=False)
         self.assertEqual(
             outputs,
             [
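The chat-template tests above all share the same input shape. A hedged sketch of that pattern, reusing the `pipe` and `image` objects from the earlier note (the question text is a placeholder, not the tests' fixture data):

# Chat-style input: an image slot in the message content, with the image
# itself supplied alongside the text argument.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What city is shown in this image?"},
        ],
    }
]

# return_full_text=False keeps only the newly generated text; its length is
# now governed by the model's generation_config unless max_new_tokens is set.
outputs = pipe(image, text=messages, return_full_text=False)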
