
Commit

Merge branch 'main' into fix-step-shifting-when-accum-grad
muellerzr authored Oct 25, 2024
2 parents a69efb0 + 1d06379 commit 0333ba0
Showing 4 changed files with 25 additions and 14 deletions.
6 changes: 3 additions & 3 deletions docs/source/en/agents_advanced.md
@@ -66,10 +66,10 @@ manager_agent.run("Who is the CEO of Hugging Face?")

Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator.

-If you need to add variation, like custom attributes for your too, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
+If you need to add variation, like custom attributes for your tool, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.

The custom tool needs:
-- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`.
+- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`.
- An attribute `description` is used to populate the agent's system prompt.
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
- An `output_type` attribute, which specifies the output type.
@@ -240,4 +240,4 @@ with gr.Blocks() as demo:

 if __name__ == "__main__":
     demo.launch()
-```
+```
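
The attribute list in the `agents_advanced.md` hunk above can be sketched as a small class. This is a minimal illustration, not the documentation's own example: the `Tool` base class here is a hypothetical stand-in so the snippet runs without `transformers`, the `forward` method name and the nested layout of `inputs` are assumptions, and the download counts and model names are made up.

```python
class Tool:
    """Hypothetical stand-in for the real `Tool` superclass (assumption)."""

    name: str
    description: str
    inputs: dict
    output_type: str

    def __call__(self, *args, **kwargs):
        # Delegate to the subclass's `forward` method.
        return self.forward(*args, **kwargs)


class ModelDownloadCounter(Tool):
    # The four attributes named in the list above.
    name = "model_download_counter"
    description = "Returns the most downloaded model for a given task."
    inputs = {
        "task": {
            "type": "string",
            "description": "the task category, e.g. text-classification",
        }
    }
    output_type = "string"

    # Made-up download counts, used here instead of a live Hub query.
    _FAKE_DOWNLOADS = {
        "text-classification": {
            "example-org/model-a": 1200,
            "example-org/model-b": 5400,
        }
    }

    def forward(self, task: str) -> str:
        candidates = self._FAKE_DOWNLOADS.get(task, {})
        if not candidates:
            return f"no models known for task {task!r}"
        # Pick the model name with the highest download count.
        return max(candidates, key=candidates.get)


tool = ModelDownloadCounter()
print(tool("text-classification"))
```

Keeping `name`, `description`, `inputs` and `output_type` as class attributes means they can be read without instantiating the tool, which is convenient when they are used to build a prompt.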
10 changes: 9 additions & 1 deletion src/transformers/generation/configuration_utils.py
@@ -172,7 +172,15 @@ class GenerationConfig(PushToHubMixin):
             speed up decoding.
         cache_implementation (`str`, *optional*, default to `None`):
             Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
-            {ALL_CACHE_IMPLEMENTATIONS}. We support other cache types, but they must be manually instantiated and
+            - `"static"`: [`StaticCache`]
+            - `"offloaded_static"`: [`OffloadedStaticCache`]
+            - `"sliding_window"`: [`SlidingWindowCache`]
+            - `"hybrid"`: [`HybridCache`]
+            - `"mamba"`: [`MambaCache`]
+            - `"quantized"`: [`QuantizedCache`]
+            We support other cache types, but they must be manually instantiated and
             passed to `generate` through the `past_key_values` argument. See our
             [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
         cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
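
The name-to-class listing in the `cache_implementation` docstring above can be read as a lookup table. The sketch below only mirrors that listing with plain strings; the `resolve_cache_class_name` helper and its error message are hypothetical and are not part of `transformers`.

```python
# Names and class names copied from the docstring hunk above.
CACHE_IMPLEMENTATIONS = {
    "static": "StaticCache",
    "offloaded_static": "OffloadedStaticCache",
    "sliding_window": "SlidingWindowCache",
    "hybrid": "HybridCache",
    "mamba": "MambaCache",
    "quantized": "QuantizedCache",
}


def resolve_cache_class_name(cache_implementation):
    """Hypothetical helper: map a `cache_implementation` string to a class name."""
    if cache_implementation is None:
        # `None` is the documented default, meaning no named cache was requested.
        return None
    try:
        return CACHE_IMPLEMENTATIONS[cache_implementation]
    except KeyError:
        valid = ", ".join(sorted(CACHE_IMPLEMENTATIONS))
        raise ValueError(
            f"Unknown cache_implementation {cache_implementation!r}; "
            f"expected one of: {valid}"
        ) from None


print(resolve_cache_class_name("static"))
```

With `transformers` installed, the string itself would typically be passed as `GenerationConfig(cache_implementation="static")`; cache types outside this list go through `past_key_values`, as the docstring says.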
5 changes: 3 additions & 2 deletions tests/pipelines/test_pipelines_summarization.py
@@ -85,8 +85,9 @@ def run_pipeline_test(self, summarizer, _):
             and len(summarizer.model.trainable_weights) > 0
             and "GPU" in summarizer.model.trainable_weights[0].device
         ):
-            with self.assertRaises(Exception):
-                outputs = summarizer("This " * 1000)
+            if str(summarizer.device) == "cpu":
+                with self.assertRaises(Exception):
+                    outputs = summarizer("This " * 1000)
         outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)

     @require_torch
18 changes: 10 additions & 8 deletions tests/pipelines/test_pipelines_text_generation.py
@@ -493,17 +493,19 @@ def run_pipeline_test(self, text_generator, _):
             and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
         ):
             # Handling of large generations
-            with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
-                text_generator("This is a test" * 500, max_new_tokens=20)
+            if str(text_generator.device) == "cpu":
+                with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
+                    text_generator("This is a test" * 500, max_new_tokens=20)

             outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
             # Hole strategy cannot work
-            with self.assertRaises(ValueError):
-                text_generator(
-                    "This is a test" * 500,
-                    handle_long_generation="hole",
-                    max_new_tokens=tokenizer.model_max_length + 10,
-                )
+            if str(text_generator.device) == "cpu":
+                with self.assertRaises(ValueError):
+                    text_generator(
+                        "This is a test" * 500,
+                        handle_long_generation="hole",
+                        max_new_tokens=tokenizer.model_max_length + 10,
+                    )

     @require_torch
     @require_accelerate
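
Both test hunks above apply the same pattern: the `assertRaises` check only runs when the pipeline's device is `"cpu"`. A minimal, self-contained sketch of that guard follows; `FakePipeline` and `run_long_input_check` are hypothetical names used for illustration and do not appear in the test suite.

```python
import unittest


class FakePipeline:
    """Hypothetical stand-in for a pipeline: exposes `.device` and rejects long inputs."""

    def __init__(self, device, max_words=8):
        self.device = device
        self.max_words = max_words

    def __call__(self, text):
        if len(text.split()) > self.max_words:
            raise ValueError("input too long")
        return text


def run_long_input_check(test_case, pipe):
    """Assert that an over-long input raises, but only on CPU."""
    long_text = "This " * 1000
    if str(pipe.device) == "cpu":
        with test_case.assertRaises(ValueError):
            pipe(long_text)
        return "checked"
    # On any other device the error-path assertion is skipped entirely.
    return "skipped"


case = unittest.TestCase()
print(run_long_input_check(case, FakePipeline("cpu")))
print(run_long_input_check(case, FakePipeline("cuda:0")))
```

Comparing `str(pipe.device)` rather than the device object keeps the guard working whether the attribute is a plain string or a device object whose string form is `"cpu"`.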
