diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile
index 7ad4e96d62cde7..b597f5a73fb5be 100644
--- a/docker/transformers-all-latest-gpu/Dockerfile
+++ b/docker/transformers-all-latest-gpu/Dockerfile
@@ -9,7 +9,7 @@ SHELL ["sh", "-lc"]
# The following `ARG` are mainly used to specify the versions explicitly & directly in this docker file, and not meant
# to be used as arguments for docker build (so far).
-ARG PYTORCH='2.4.0'
+ARG PYTORCH='2.5.1'
# (not always a valid torch version)
ARG INTEL_TORCH_EXT='2.3.0'
# Example: `cu102`, `cu113`, etc.
diff --git a/docker/transformers-pytorch-gpu/Dockerfile b/docker/transformers-pytorch-gpu/Dockerfile
index 62578ad0f3610f..f22d77b9372d7e 100644
--- a/docker/transformers-pytorch-gpu/Dockerfile
+++ b/docker/transformers-pytorch-gpu/Dockerfile
@@ -11,7 +11,7 @@ ARG REF=main
RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF
# If set to nothing, will install the latest version
-ARG PYTORCH='2.4.0'
+ARG PYTORCH='2.5.1'
ARG TORCH_VISION=''
ARG TORCH_AUDIO=''
# Example: `cu102`, `cu113`, etc.
diff --git a/docs/source/en/model_doc/mllama.md b/docs/source/en/model_doc/mllama.md
index 9cb038ed2e3453..4a6080ea2ce03a 100644
--- a/docs/source/en/model_doc/mllama.md
+++ b/docs/source/en/model_doc/mllama.md
@@ -30,6 +30,25 @@ The Llama 3.2-Vision collection of multimodal large language models (LLMs) is a
- The text passed to the processor should have the `"<|image|>"` tokens where the images should be inserted.
- The processor has its own `apply_chat_template` method to convert chat messages to text that can then be passed as text to the processor.
+
+
+
+Mllama has an extra token used as a placeholder for image positions in the text. This means that the input ids and the input embedding layer have one extra token. However, since the input and output embedding weights are not tied, the `lm_head` layer has one token less and will fail if you want to calculate loss on image tokens or apply some logit processors. If you are training, make sure to mask out the special `"<|image|>"` tokens in the `labels`, as the model should not be trained on predicting them.
+
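+For example, a minimal sketch of the label masking, assuming `inputs` was produced by the Mllama processor (referred to here as `processor`):
+
+```python
+# Build labels from the tokenized inputs and mask out the special "<|image|>"
+# tokens so that no loss is computed on them (-100 is ignored by the loss).
+image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image|>")
+
+labels = inputs["input_ids"].clone()
+labels[labels == image_token_id] = -100
+inputs["labels"] = labels
+```
+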
+Otherwise, if you see CUDA-side indexing errors when generating, use the code below to expand the `lm_head` by one more token.
+
+
+```python
+old_embeddings = model.get_output_embeddings()
+
+num_tokens = model.vocab_size + 1
+resized_embeddings = model._get_resized_lm_head(old_embeddings, new_num_tokens=num_tokens, mean_resizing=True)
+resized_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
+model.set_output_embeddings(resized_embeddings)
+```
+
+
+
## Usage Example
#### Instruct model
diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md
index f9ea3337699444..7bee3472892727 100644
--- a/docs/source/en/trainer.md
+++ b/docs/source/en/trainer.md
@@ -252,7 +252,70 @@ trainer = Trainer(..., args=training_args)
NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.
-## GaLore
+## Liger Kernel
+
+[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels developed by LinkedIn specifically for LLM training. It provides Hugging Face-compatible implementations of RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. The kernels work out of the box with Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed.
+
+
+Gain +20% throughput and reduce memory usage by 60% on LLaMA 3-8B training, enabling longer context lengths and larger batch sizes. It is also useful if you want to scale up to multi-head training (e.g. Medusa) or large vocabulary sizes. See details and examples in the [Liger repository](https://github.com/linkedin/Liger-Kernel/tree/main/examples).
+
+
+First, install the Liger kernel package:
+```bash
+pip install liger-kernel
+```
+
+Then pass `use_liger_kernel=True` to apply the Liger kernel to your model, for example:
+
+```py
+from transformers import TrainingArguments
+
+training_args = TrainingArguments(
+ output_dir="your-model",
+ learning_rate=2e-5,
+ per_device_train_batch_size=16,
+ per_device_eval_batch_size=16,
+ num_train_epochs=2,
+ weight_decay=0.01,
+ eval_strategy="epoch",
+ save_strategy="epoch",
+ load_best_model_at_end=True,
+ push_to_hub=True,
+ use_liger_kernel=True
+)
+```
+
+The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found [here](https://github.com/linkedin/Liger-Kernel). When `use_liger_kernel` is set to `True`, the corresponding layers in the original model will be patched with Liger's efficient implementation, so you don't need to do anything extra other than setting the argument value.
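+
+The resulting `training_args` are then passed to [`Trainer`] as usual. A minimal sketch, assuming `model`, `train_dataset`, and `eval_dataset` are already defined:
+
+```py
+from transformers import Trainer
+
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+)
+trainer.train()
+```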
+
+
+## Optimizers
+
+You can choose a built-in optimizer for training with the `optim` argument of [`TrainingArguments`]:
+
+```python
+from transformers import TrainingArguments
+training_args = TrainingArguments(..., optim="adamw_torch")
+```
+
+See [`OptimizerNames`](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py) for a full list of choices. We include advanced examples in the sections below.
+
+You can also pass any PyTorch optimizer class and its keyword arguments to [`Trainer`] via `optimizer_cls_and_kwargs`:
+
+```python
+import torch
+
+optimizer_cls = torch.optim.AdamW
+optimizer_kwargs = {
+ "lr": 4e-3,
+ "betas": (0.9, 0.999),
+ "weight_decay": 0.05,
+}
+
+from transformers import Trainer
+trainer = Trainer(..., optimizer_cls_and_kwargs=(optimizer_cls, optimizer_kwargs))
+```
+
+### GaLore
Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning but is more memory-efficient than common low-rank adaptation methods, such as LoRA.
@@ -382,42 +445,7 @@ trainer.train()
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
-## Liger Kernel
-
-[Liger-Kernel](https://github.com/linkedin/Liger-Kernel) Kernel is a collection of Triton kernels developed by Linkedin designed specifically for LLM training. We have implemented Hugging Face Compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The kernel works out of the box with flash attention, PyTorch FSDP, and Microsoft DeepSpeed.
-
-
-Gain +20% throughput and reduce memory usage by 60% on LLaMA 3-8B model training. Achieve longer context lengths and larger batch sizes. It’s also useful if you want to scale up your model to multi-head training or large vocabulary sizes. Unleash multi-head training (medusa) and more. See details and examples in [Liger](https://github.com/linkedin/Liger-Kernel/tree/main/examples)
-
-
-First make sure to install Liger official repository:
-```bash
-pip install liger-kernel
-```
-
-You should pass `use_liger_kernel=True` to apply liger kernel on your model, for example:
-
-```py
-from transformers import TrainingArguments
-
-training_args = TrainingArguments(
- output_dir="your-model",
- learning_rate=2e-5,
- per_device_train_batch_size=16,
- per_device_eval_batch_size=16,
- num_train_epochs=2,
- weight_decay=0.01,
- eval_strategy="epoch",
- save_strategy="epoch",
- load_best_model_at_end=True,
- push_to_hub=True,
- use_liger_kernel=True
-)
-```
-
-The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found [here](https://github.com/linkedin/Liger-Kernel). When `use_liger_kernel` is set to `True`, the corresponding layers in the original model will be patched with Liger's efficient implementation, so you don't need to do anything extra other than setting the argument value.
-
-## LOMO optimizer
+### LOMO optimizer
The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
They both consist of an efficient full-parameter fine-tuning method. These optimizers fuse the gradient computation and the parameter update in one step to reduce memory usage. Supported optimizers for LOMO are `"lomo"` and `"adalomo"`. First either install LOMO from pypi `pip install lomo-optim` or install it from source with `pip install git+https://github.com/OpenLMLab/LOMO.git`.
@@ -467,7 +495,7 @@ trainer = trl.SFTTrainer(
trainer.train()
```
-## GrokAdamW optimizer
+### GrokAdamW optimizer
The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. To use GrokAdamW, first install the optimizer package with `pip install grokadamw`.
@@ -518,7 +546,7 @@ trainer.train()
This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
-## Schedule Free Optimizer
+### Schedule Free Optimizer
The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule.
diff --git a/examples/research_projects/decision_transformer/requirements.txt b/examples/research_projects/decision_transformer/requirements.txt
index a54f3d03cab21b..6d42c3256a83e9 100644
--- a/examples/research_projects/decision_transformer/requirements.txt
+++ b/examples/research_projects/decision_transformer/requirements.txt
@@ -233,7 +233,7 @@ urllib3==1.26.19
wasabi==0.9.0
wcwidth==0.2.5
websocket-client==1.3.1
-Werkzeug==3.0.3
+Werkzeug==3.0.6
wrapt==1.14.0
xxhash==3.0.0
yarl==1.7.2
diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py
index 4e0e1dd3430209..bf44d4b427cf7b 100644
--- a/src/transformers/dynamic_module_utils.py
+++ b/src/transformers/dynamic_module_utils.py
@@ -152,7 +152,8 @@ def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
content = f.read()
# filter out try/except block so in custom code we can have try/except imports
- content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", content, flags=re.MULTILINE | re.DOTALL)
+ content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL)
+
# filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment
content = re.sub(
r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", "", content, flags=re.MULTILINE
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index efe953db051cb3..6e6d5b8bdce71d 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -378,10 +378,14 @@ def prepare_inputs_for_generation(
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
- # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
+ # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+ # (we can't check exception 3 while compiling)
if past_key_values is not None:
model_inputs["past_key_values"] = past_key_values
- if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3
+ if (
+ inputs_embeds is not None # Exception 1
+ or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
+ ):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
@@ -414,7 +418,7 @@ def prepare_inputs_for_generation(
for model_input_name in ["position_ids", "token_type_ids"]:
model_input = kwargs.get(model_input_name)
if model_input is not None:
- if past_key_values:
+ if past_key_values is not None:
model_input = model_input[:, -input_ids.shape[1] :]
model_input = model_input.clone(memory_format=torch.contiguous_format)
model_inputs[model_input_name] = model_input
@@ -568,27 +572,34 @@ def _maybe_initialize_input_ids_for_generation(
def _prepare_attention_mask_for_generation(
self,
- inputs: torch.Tensor,
- pad_token_id: Optional[torch.Tensor],
- eos_token_id: Optional[torch.Tensor],
+ inputs_tensor: torch.Tensor,
+ generation_config: GenerationConfig,
+ model_kwargs: Dict[str, Any],
) -> torch.LongTensor:
+ pad_token_id = generation_config._pad_token_tensor
+ eos_token_id = generation_config._eos_token_tensor
+
+ # `input_ids` may be present in the model kwargs, instead of being the main input (e.g. multimodal model)
+ if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
+ inputs_tensor = model_kwargs["input_ids"]
+
# No information for attention mask inference -> return default attention mask
- default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
+ default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
if pad_token_id is None:
return default_attention_mask
- is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
+ is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
if not is_input_ids:
return default_attention_mask
is_pad_token_in_inputs = (pad_token_id is not None) and (
- isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any()
+ isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
)
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
- attention_mask_from_padding = inputs.ne(pad_token_id).long()
+ attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
attention_mask = (
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
@@ -2020,7 +2031,7 @@ def generate(
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+ inputs_tensor, generation_config, model_kwargs
)
elif kwargs_has_attention_mask:
# TODO (joao): generalize this check with other types of inputs
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 797908277930cf..0661da8727996f 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1288,7 +1288,7 @@ def forward(
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
- n_image_features = image_tokens.shape[0]
+ n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
if n_image_tokens_in_text != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py
index d0224e3caa5b28..f422b17b204f13 100644
--- a/src/transformers/models/clap/modeling_clap.py
+++ b/src/transformers/models/clap/modeling_clap.py
@@ -575,7 +575,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Copied from transformers.models.swin.modeling_swin.SwinLayer with SwinDropPath->ClapDropPath, Swin->ClapAudio
class ClapAudioLayer(nn.Module):
- def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
+ def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
@@ -583,7 +583,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = ClapAudioAttention(config, dim, num_heads, window_size=self.window_size)
- self.drop_path = ClapDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ self.drop_path = ClapDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = ClapAudioIntermediate(config, dim)
self.output = ClapAudioOutput(config, dim)
@@ -712,6 +712,7 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
+ drop_path_rate=drop_path[i],
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index 8d639131b841ca..2d5272e8642ee5 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -558,7 +558,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
class DonutSwinLayer(nn.Module):
- def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
+ def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
@@ -566,7 +566,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
- self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ self.drop_path = DonutSwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = DonutSwinIntermediate(config, dim)
self.output = DonutSwinOutput(config, dim)
@@ -695,6 +695,7 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
+ drop_path_rate=drop_path[i],
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index a0079f1787a2e9..6d6bf4a6f38e3f 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -527,8 +527,9 @@ def forward(
# TODO: @raushan retain only the new behavior after v4.47
elif image_features is not None:
- n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
- n_image_features = image_features.shape[1]
+ n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index 44b372535d70bd..85c109919da736 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -911,7 +911,8 @@ def forward(
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+ "and must specify either one"
)
legacy_processing = False
@@ -1020,6 +1021,7 @@ def forward(
if image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
+
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py
index e9974e920493ff..2025140bb6e36a 100644
--- a/src/transformers/models/llava_next_video/modular_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py
@@ -424,7 +424,8 @@ def forward(
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+ "and must specify either one"
)
legacy_processing = False
@@ -533,6 +534,7 @@ def forward(
if image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
+
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index 946688bfcf07f4..2aa6b2fa1d6fa5 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -657,7 +657,8 @@ def forward(
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
- "You cannot specify both pixel_values/pixel_values_videos and inputs_embeds at the same time, and must specify either one"
+ "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+ "and must specify either one"
)
if inputs_embeds is None:
@@ -679,6 +680,7 @@ def forward(
)
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
+
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -704,6 +706,7 @@ def forward(
)
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)
+
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens != n_video_features:
diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py
index 9a40e050459816..598e1d8186a24a 100644
--- a/src/transformers/models/maskformer/modeling_maskformer_swin.py
+++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -520,16 +520,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class MaskFormerSwinLayer(nn.Module):
- def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
+ def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
super().__init__()
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size)
- self.drop_path = (
- MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
- )
+ self.drop_path = MaskFormerSwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = MaskFormerSwinIntermediate(config, dim)
self.output = MaskFormerSwinOutput(config, dim)
@@ -644,6 +642,7 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
+ drop_path_rate=drop_path[i],
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index c18e1d1c9d86b1..109ddfb626d26b 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -1562,7 +1562,7 @@ def generate(
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+ input_ids, generation_config, model_kwargs
)
# 5. Prepare `max_length` depending on other stopping criteria.
@@ -2578,7 +2578,7 @@ def generate(
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+ inputs_tensor, generation_config, model_kwargs
)
if "encoder_outputs" not in model_kwargs:
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index d2f339afc41451..61f2ce414e1ddf 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -1484,7 +1484,7 @@ def generate(
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+ input_ids, generation_config, model_kwargs
)
# 5. Prepare `max_length` depending on other stopping criteria.
@@ -2425,7 +2425,7 @@ def generate(
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+ inputs_tensor, generation_config, model_kwargs
)
if "encoder_hidden_states" not in model_kwargs:
diff --git a/src/transformers/models/pixtral/configuration_pixtral.py b/src/transformers/models/pixtral/configuration_pixtral.py
index 32325a929411ba..14db51b947e664 100644
--- a/src/transformers/models/pixtral/configuration_pixtral.py
+++ b/src/transformers/models/pixtral/configuration_pixtral.py
@@ -52,6 +52,8 @@ class PixtralVisionConfig(PretrainedConfig):
Dropout probability for the attention layers.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
@@ -82,6 +84,7 @@ def __init__(
hidden_act="gelu",
attention_dropout=0.0,
rope_theta=10000.0,
+ initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@@ -97,3 +100,4 @@ def __init__(
self.hidden_act = hidden_act
self.rope_theta = rope_theta
self.head_dim = hidden_size // num_attention_heads
+ self.initializer_range = initializer_range
diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py
index 06b9701a75661a..b65fbd634ba789 100644
--- a/src/transformers/models/pixtral/modeling_pixtral.py
+++ b/src/transformers/models/pixtral/modeling_pixtral.py
@@ -407,7 +407,7 @@ def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
- else self.config.text_config.initializer_range
+ else self.config.initializer_range
)
if isinstance(module, (nn.Linear, nn.Conv2d)):
diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py
index 70d28fb7b79c93..5913e8688d00be 100644
--- a/src/transformers/models/pixtral/processing_pixtral.py
+++ b/src/transformers/models/pixtral/processing_pixtral.py
@@ -206,14 +206,15 @@ def __call__(
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
- images = [images]
- elif (
- not isinstance(images, list)
- and not isinstance(images[0], list)
- and not is_image_or_image_url(images[0][0])
- ):
+ if isinstance(text, list):
+ images = [[im] for im in images]
+ else:
+ images = [images]
+ elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
+ pass
+ else:
raise ValueError(
- "Invalid input images. Please provide a single image or a list of images or a list of list of images."
+ "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
)
images = [[load_image(im) for im in sample] for sample in images]
image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 17e722a217dfd6..9c0d0b45ee8e51 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1503,13 +1503,14 @@ def get_rope_index(
mrope_position_deltas = []
if image_grid_thw is not None or video_grid_thw is not None:
total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
)
image_index, video_index = 0, 0
for i, input_ids in enumerate(total_input_ids):
- if attention_mask is not None:
- input_ids = input_ids[attention_mask[i] == 1]
+ input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 45383a36d9bea8..23f0ba6da620cd 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -635,7 +635,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class SwinLayer(nn.Module):
- def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
+ def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
@@ -643,7 +643,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
- self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ self.drop_path = SwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = SwinIntermediate(config, dim)
self.output = SwinOutput(config, dim)
@@ -771,6 +771,7 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
+ drop_path_rate=drop_path[i],
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py
index 035b31e8d43b80..f1aa0bfef743ad 100644
--- a/src/transformers/models/swin/modeling_tf_swin.py
+++ b/src/transformers/models/swin/modeling_tf_swin.py
@@ -742,7 +742,14 @@ def build(self, input_shape=None):
class TFSwinLayer(keras.layers.Layer):
def __init__(
- self, config, dim, input_resolution: Tuple[int, int], num_heads: int, shift_size: int = 0, **kwargs
+ self,
+ config,
+ dim,
+ input_resolution: Tuple[int, int],
+ num_heads: int,
+ drop_path_rate: float = 0.0,
+ shift_size: int = 0,
+ **kwargs,
) -> None:
super().__init__(**kwargs)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -754,8 +761,8 @@ def __init__(
self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
self.attention = TFSwinAttention(config, dim, num_heads, name="attention")
self.drop_path = (
- TFSwinDropPath(config.drop_path_rate, name="drop_path")
- if config.drop_path_rate > 0.0
+ TFSwinDropPath(drop_path_rate, name="drop_path")
+ if drop_path_rate > 0.0
else keras.layers.Activation("linear", name="drop_path")
)
self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
@@ -913,6 +920,7 @@ def __init__(
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
+ drop_path_rate=drop_path[i],
name=f"blocks.{i}",
)
for i in range(depth)
diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py
index b0a773c8af3472..d6bd8da9bed638 100644
--- a/src/transformers/models/swin2sr/modeling_swin2sr.py
+++ b/src/transformers/models/swin2sr/modeling_swin2sr.py
@@ -482,7 +482,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Copied from transformers.models.swinv2.modeling_swinv2.Swinv2Layer with Swinv2->Swin2SR
class Swin2SRLayer(nn.Module):
- def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0):
+ def __init__(
+ self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0
+ ):
super().__init__()
self.input_resolution = input_resolution
window_size, shift_size = self._compute_window_shift(
@@ -500,7 +502,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretr
else (pretrained_window_size, pretrained_window_size),
)
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
- self.drop_path = Swin2SRDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ self.drop_path = Swin2SRDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.intermediate = Swin2SRIntermediate(config, dim)
self.output = Swin2SROutput(config, dim)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index 0c30e739a48f91..191923958cfbde 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -683,7 +683,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class Swinv2Layer(nn.Module):
- def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0):
+ def __init__(
+ self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0
+ ):
super().__init__()
self.input_resolution = input_resolution
window_size, shift_size = self._compute_window_shift(
@@ -701,7 +703,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretr
else (pretrained_window_size, pretrained_window_size),
)
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
- self.drop_path = Swinv2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ self.drop_path = Swinv2DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.intermediate = Swinv2Intermediate(config, dim)
self.output = Swinv2Output(config, dim)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
@@ -819,6 +821,7 @@ def __init__(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
+ drop_path_rate=drop_path[i],
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
pretrained_window_size=pretrained_window_size,
)
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 30f82e45056c77..a3b3de33fa66ee 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -534,7 +534,8 @@ def forward(
if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
+ "time, and must specify either one"
)
legacy_processing = False
@@ -628,8 +629,8 @@ def forward(
# TODO: @raushan retain only the new behavior after v4.47
else:
if pixel_values_images is not None:
- n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
- n_image_features = image_features.shape[1]
+ n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -644,8 +645,8 @@ def forward(
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None:
- n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
- n_video_features = video_features.shape[1]
+ n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+ n_video_features = video_features.shape[0] * video_features.shape[1]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index c9db6e261c6a72..4060f8c8ecd1bf 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -517,8 +517,8 @@ def forward(
# TODO: @raushan retain only the new behavior after v4.47
elif image_features is not None:
- n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
- n_image_features = image_features.shape[1]
+ n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
diff --git a/src/transformers/pipelines/depth_estimation.py b/src/transformers/pipelines/depth_estimation.py
index ae86c552a720af..2203ac09c9cf9b 100644
--- a/src/transformers/pipelines/depth_estimation.py
+++ b/src/transformers/pipelines/depth_estimation.py
@@ -1,4 +1,3 @@
-import warnings
from typing import List, Union
from ..utils import (
@@ -72,6 +71,9 @@ def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Imag
A dictionary of argument names to parameter values, to control pipeline behaviour.
The only parameter available right now is `timeout`, which is the length of time, in seconds,
that the pipeline should wait before giving up on trying to download an image.
+ timeout (`float`, *optional*, defaults to None):
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+ the call may block forever.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@@ -93,9 +95,6 @@ def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Imag
def _sanitize_parameters(self, timeout=None, parameters=None, **kwargs):
preprocess_params = {}
if timeout is not None:
- warnings.warn(
- "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
- )
preprocess_params["timeout"] = timeout
if isinstance(parameters, dict) and "timeout" in parameters:
preprocess_params["timeout"] = parameters["timeout"]
diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py
index 20ad72e79055e2..0085e5eb73f826 100644
--- a/src/transformers/pipelines/image_classification.py
+++ b/src/transformers/pipelines/image_classification.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import warnings
from typing import List, Union
import numpy as np
@@ -113,9 +112,6 @@ def __init__(self, *args, **kwargs):
def _sanitize_parameters(self, top_k=None, function_to_apply=None, timeout=None):
preprocess_params = {}
if timeout is not None:
- warnings.warn(
- "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
- )
preprocess_params["timeout"] = timeout
postprocess_params = {}
if top_k is not None:
@@ -159,6 +155,9 @@ def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Imag
top_k (`int`, *optional*, defaults to 5):
The number of top labels that will be returned by the pipeline. If the provided number is higher than
the number of labels available in the model configuration, it will default to the number of labels.
+ timeout (`float`, *optional*, defaults to None):
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+ the call may block forever.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py
index 0ac653fd1e8725..d388e591bf9df4 100644
--- a/src/transformers/pipelines/image_segmentation.py
+++ b/src/transformers/pipelines/image_segmentation.py
@@ -1,4 +1,3 @@
-import warnings
from typing import Any, Dict, List, Union
import numpy as np
@@ -91,9 +90,6 @@ def _sanitize_parameters(self, **kwargs):
if "overlap_mask_area_threshold" in kwargs:
postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
if "timeout" in kwargs:
- warnings.warn(
- "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
- )
preprocess_kwargs["timeout"] = kwargs["timeout"]
return preprocess_kwargs, {}, postprocess_kwargs
@@ -122,6 +118,9 @@ def __call__(self, inputs=None, **kwargs) -> Union[Predictions, List[Prediction]
Threshold to use when turning the predicted masks into binary values.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
Mask overlap threshold to eliminate small, disconnected segments.
+ timeout (`float`, *optional*, defaults to None):
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+ the call may block forever.
Return:
A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py
index a4f13d7b352f66..afd67b6ac9edee 100644
--- a/src/transformers/pipelines/image_to_text.py
+++ b/src/transformers/pipelines/image_to_text.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import warnings
from typing import List, Union
from ..utils import (
@@ -81,9 +80,6 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt
if prompt is not None:
preprocess_params["prompt"] = prompt
if timeout is not None:
- warnings.warn(
- "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
- )
preprocess_params["timeout"] = timeout
if max_new_tokens is not None:
@@ -118,6 +114,10 @@ def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Imag
generate_kwargs (`Dict`, *optional*):
Pass it to send all of these arguments directly to `generate` allowing full control of this function.
+ timeout (`float`, *optional*, defaults to None):
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+ the call may block forever.
+
Return:
A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
diff --git a/src/transformers/pipelines/object_detection.py b/src/transformers/pipelines/object_detection.py
index c135b1e131acb9..c84f17b2bd6ad0 100644
--- a/src/transformers/pipelines/object_detection.py
+++ b/src/transformers/pipelines/object_detection.py
@@ -1,4 +1,3 @@
-import warnings
from typing import Any, Dict, List, Union
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
@@ -64,9 +63,6 @@ def __init__(self, *args, **kwargs):
def _sanitize_parameters(self, **kwargs):
preprocess_params = {}
if "timeout" in kwargs:
- warnings.warn(
- "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
- )
preprocess_params["timeout"] = kwargs["timeout"]
postprocess_kwargs = {}
if "threshold" in kwargs:
@@ -89,6 +85,9 @@ def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
same format: all as HTTP(S) links, all as local paths, or all as PIL images.
threshold (`float`, *optional*, defaults to 0.5):
The probability necessary to make a prediction.
+ timeout (`float`, *optional*, defaults to None):
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+ the call may block forever.
Return:
A list of dictionaries or a list of list of dictionaries containing the result. If the input is a single
diff --git a/src/transformers/pipelines/zero_shot_image_classification.py b/src/transformers/pipelines/zero_shot_image_classification.py
index 253c684fcbbdad..c53b515dcccd9c 100644
--- a/src/transformers/pipelines/zero_shot_image_classification.py
+++ b/src/transformers/pipelines/zero_shot_image_classification.py
@@ -94,6 +94,10 @@ def __call__(self, image: Union[str, List[str], "Image", List["Image"]] = None,
replacing the placeholder with the candidate_labels. Pass "{}" if *candidate_labels* are
already formatted.
+ timeout (`float`, *optional*, defaults to None):
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+ the call may block forever.
+
Return:
A list of dictionaries containing one entry per proposed label. Each dictionary contains the
following keys:
@@ -113,9 +117,6 @@ def _sanitize_parameters(self, tokenizer_kwargs=None, **kwargs):
if "candidate_labels" in kwargs:
preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
if "timeout" in kwargs:
- warnings.warn(
- "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
- )
preprocess_params["timeout"] = kwargs["timeout"]
if "hypothesis_template" in kwargs:
preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index 2c26504670e0d8..b5b02f6a00aa09 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -874,12 +874,13 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg
else:
# kwargs is a flat dictionary
for key in kwargs:
- if key not in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__.keys():
- logger.warning_once(
- f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
- )
- elif key not in used_keys:
- output_kwargs["common_kwargs"][key] = kwargs[key]
+ if key not in used_keys:
+ if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__.keys():
+ output_kwargs["common_kwargs"][key] = kwargs[key]
+ else:
+ logger.warning_once(
+ f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
+ )
# all modality-specific kwargs are updated with common kwargs
for modality in output_kwargs:
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 4f3187d510fad1..89ab2dc9260819 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1687,8 +1687,8 @@ def __repr__(self) -> str:
f"{self.__class__.__name__}(name_or_path='{self.name_or_path}',"
f" vocab_size={self.vocab_size}, model_max_length={self.model_max_length}, is_fast={self.is_fast},"
f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}',"
- f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces}), "
- " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}"
+ f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces},"
+ " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}\n)"
)
def __len__(self) -> int:
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 9176bd72a55032..e2ae622e2b6bf3 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -34,7 +34,7 @@
import warnings
from collections.abc import Mapping
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
# Integrations must be imported before ML frameworks:
@@ -358,6 +358,11 @@ class Trainer:
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+ optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
+ A tuple containing the optimizer class and keyword arguments to use.
+ Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
+
+ Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
A function that preprocess the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
@@ -401,7 +406,8 @@ def __init__(
compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
- optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+ optimizer_cls_and_kwargs: Optional[Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] = None,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
if args is None:
@@ -603,6 +609,9 @@ def __init__(
self.compute_metrics = compute_metrics
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
self.optimizer, self.lr_scheduler = optimizers
+ self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs
+ if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None:
+ raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.")
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
raise RuntimeError(
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
@@ -1171,7 +1180,10 @@ def create_optimizer(self):
},
]
- optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
+ if self.optimizer_cls_and_kwargs is not None:
+ optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
+ else:
+ optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index d552bf73442ce7..545b696d67370a 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -29,6 +29,7 @@
from transformers.testing_utils import (
is_flaky,
require_accelerate,
+ require_flash_attn,
require_optimum_quanto,
require_torch,
require_torch_gpu,
@@ -136,6 +137,34 @@ def prepare_config_and_inputs_for_generate(self, batch_size=2):
return config, filtered_inputs_dict
+ def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
+ """
+ Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
+        the following situations:
+ 1. The sequences are the same
+ 2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical
+ """
+        # scores don't include data regarding decoder input tokens
+ decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
+ output_matches = output_1.sequences == output_2.sequences
+ has_matching_outputs = output_matches.all()
+ has_matching_scores = None
+ if not has_matching_outputs:
+ for batch_idx in range(output_1.sequences.shape[0]):
+ batch_matches = output_matches[batch_idx]
+ if batch_matches.all():
+ continue
+ first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
+ first_mismatch_idx -= decoder_input_length
+ output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
+ output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
+ has_matching_scores = torch.allclose(
+                    output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=rtol, atol=atol
+ )
+ if not has_matching_scores:
+ break
+ self.assertTrue(has_matching_outputs or has_matching_scores)
+
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {
"bad_words_ids": [[1, 0]],
@@ -426,7 +455,6 @@ def test_greedy_generate(self):
def test_greedy_generate_dict_outputs(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
@@ -453,13 +481,12 @@ def test_greedy_generate_dict_outputs(self):
# Retrocompatibility check
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
- self._check_outputs(output_generate, main_input, model.config)
+ self._check_outputs(output_generate, model.config)
@pytest.mark.generate
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -486,7 +513,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
)
- self._check_outputs(output_generate, main_input, model.config, use_cache=True)
+ self._check_outputs(output_generate, model.config, use_cache=True)
@pytest.mark.generate
def test_sample_generate(self):
@@ -505,7 +532,6 @@ def test_sample_generate(self):
def test_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
@@ -533,7 +559,7 @@ def test_sample_generate_dict_output(self):
# Retrocompatibility check
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
- self._check_outputs(output_generate, main_input, model.config, num_return_sequences=2)
+ self._check_outputs(output_generate, model.config, num_return_sequences=2)
@pytest.mark.generate
def test_beam_search_generate(self):
@@ -554,7 +580,6 @@ def test_beam_search_generate(self):
def test_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
@@ -583,14 +608,16 @@ def test_beam_search_generate_dict_output(self):
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
- output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
+ output_generate,
+ model.config,
+ num_return_sequences=beam_kwargs["num_return_sequences"],
+ num_beams=beam_kwargs["num_beams"],
)
@pytest.mark.generate
def test_beam_search_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -623,10 +650,10 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
self._check_outputs(
output_generate,
- main_input,
model.config,
use_cache=True,
- num_return_sequences=beam_kwargs["num_beams"],
+ num_return_sequences=beam_kwargs["num_return_sequences"],
+ num_beams=beam_kwargs["num_beams"],
)
@require_accelerate
@@ -675,7 +702,6 @@ def test_beam_sample_generate(self):
def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
@@ -706,7 +732,10 @@ def test_beam_sample_generate_dict_output(self):
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self._check_outputs(
- output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
+ output_generate,
+ model.config,
+ num_return_sequences=beam_kwargs["num_return_sequences"],
+ num_beams=beam_kwargs["num_beams"],
)
@pytest.mark.generate
@@ -765,7 +794,6 @@ def test_group_beam_search_generate(self):
def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_diverse_beam_kwargs()
@@ -794,7 +822,10 @@ def test_group_beam_search_generate_dict_output(self):
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
- output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
+ output_generate,
+ model.config,
+ num_return_sequences=beam_kwargs["num_return_sequences"],
+ num_beams=beam_kwargs["num_beams"],
)
# TODO: @gante check why it is flaky
@@ -859,7 +890,6 @@ def test_constrained_beam_search_generate(self):
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
@@ -899,7 +929,10 @@ def test_constrained_beam_search_generate_dict_output(self):
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
- output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
+ output_generate,
+ model.config,
+ num_return_sequences=beam_kwargs["num_return_sequences"],
+ num_beams=beam_kwargs["num_beams"],
)
@pytest.mark.generate
@@ -942,7 +975,6 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
self.skipTest(reason="Won't fix: old model with different cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -968,7 +1000,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
)
- self._check_outputs(output_generate, main_input, model.config, use_cache=True)
+ self._check_outputs(output_generate, model.config, use_cache=True)
@pytest.mark.generate
def test_contrastive_generate_low_memory(self):
@@ -1064,14 +1096,10 @@ def test_beam_search_low_memory(self):
@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
- @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
- # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
- # shape differences -- and it may result in a different output. The input shape difference happens in the
- # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
- # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
- # NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
+ # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
+ # NOTE: It breaks the pattern in the tests above, for multiple reasons:
# - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_decoding does not support `use_cache = False`
@@ -1100,7 +1128,6 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
- main_input = inputs_dict[model_class.main_input_name]
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -1141,12 +1168,10 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
# The two outputs must match and their shape must be as expected
-
- self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
+ self._check_similar_generate_outputs(output_greedy, output_assisted)
for output in (output_greedy, output_assisted):
- self._check_outputs(output, main_input, model.config, use_cache=True)
+ self._check_outputs(output, model.config, use_cache=True)
- @is_flaky()
@pytest.mark.generate
def test_prompt_lookup_decoding_matches_greedy_search(self):
# This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
@@ -1175,7 +1200,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
- main_input = inputs_dict[model_class.main_input_name]
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -1208,10 +1232,9 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
# The two outputs must match and their shape must be as expected
-
- self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
+ self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
for output in (output_greedy, output_prompt_lookup):
- self._check_outputs(output, main_input, model.config, use_cache=True)
+ self._check_outputs(output, model.config, use_cache=True)
@pytest.mark.generate
def test_dola_decoding_sample(self):
@@ -1231,7 +1254,6 @@ def test_dola_decoding_sample(self):
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- main_input = inputs_dict[model_class.main_input_name]
# Encoder-decoder models are not supported
if config.is_encoder_decoder:
@@ -1259,7 +1281,7 @@ def test_dola_decoding_sample(self):
"dola_layers": "low",
}
output_dola = model.generate(**generation_kwargs, **inputs_dict)
- self._check_outputs(output_dola, main_input, model.config, use_cache=getattr(config, "use_cache", False))
+ self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
@pytest.mark.generate
def test_assisted_decoding_sample(self):
@@ -1289,7 +1311,6 @@ def test_assisted_decoding_sample(self):
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
- main_input = inputs_dict[model_class.main_input_name]
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -1321,7 +1342,7 @@ def test_assisted_decoding_sample(self):
}
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
- self._check_outputs(output_assisted, main_input, config, use_cache=True)
+ self._check_outputs(output_assisted, config, use_cache=True)
@pytest.mark.generate
def test_prompt_lookup_decoding_stops_at_eos(self):
@@ -1547,75 +1568,93 @@ def test_past_key_values_format(self):
)
@pytest.mark.generate
- @parameterized.expand([(1,), (2,)])
- def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
+ @parameterized.expand([("greedy", 1), ("beam search", 2)])
+ def test_generate_from_inputs_embeds(self, _, num_beams):
+        """Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc."""
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if fails, you should probably update the `prepare_inputs_for_generation` function
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- # Ignore:
- # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
- # which would cause a mismatch),
- config.pad_token_id = config.eos_token_id = -1
- # b) embedding scaling, the scaling factor applied after embeding from input_ids (requires knowledge of the
- # variable that holds the scaling factor, which is model-dependent)
- if hasattr(config, "scale_embedding"):
- config.scale_embedding = False
-
# This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
# decoder)
if config.is_encoder_decoder:
continue
+ config.is_decoder = True
# Skip models without explicit support
- config.is_decoder = True
model = model_class(config).to(torch_device).eval()
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
continue
+ # There are a few exception patterns in this test:
+            # 1 - Some models can't generate without `input_ids` when `inputs_embeds` are passed
+ requires_inputs_ids = any(
+ model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl"]
+ )
+ # 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
+ # than calling the embedding layer with `input_ids`. Subcases of this exception:
+ # 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
+ if hasattr(config, "scale_embedding"):
+ config.scale_embedding = False
+ # 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
+            #   exception above (complex `inputs_embeds` computation). Popping `pixel_values` allows us to run the
+ # checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
+ pixel_values_is_mutually_exclusive = any(
+ model_name in model_class.__name__.lower()
+ for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma"]
+ )
+ if pixel_values_is_mutually_exclusive:
+ inputs_dict.pop("pixel_values", None)
+ inputs_dict.pop("pixel_values_videos", None)
+ inputs_dict.pop("pixel_values_images", None)
+ # 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
+ has_complex_embeds_computation = any(
+ model_name in model_class.__name__.lower() for model_name in ["moshi"]
+ )
+ # 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
+ # we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
+ missing_attention_mask = "attention_mask" not in inputs_dict
+
+ # Traditional way of generating text
input_ids = inputs_dict.pop("input_ids")
generation_kwargs = {
"return_dict_in_generate": True,
"output_scores": True,
"num_beams": num_beams,
"do_sample": False,
+ "max_new_tokens": 5,
+ "min_new_tokens": 5, # generate exactly 5 tokens
}
-
- # Traditional way of generating text
- outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs)
+ outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
- # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
+ # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
+ # The output of the two calls should be the same.
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
- input_ids,
- inputs_embeds=inputs_embeds,
- max_new_tokens=5,
- **generation_kwargs,
+ input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
- self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
+ if not has_complex_embeds_computation:
+ self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
- # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
+ # If we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
- input_ids,
- inputs_embeds=random_embeds,
- max_new_tokens=5,
- **generation_kwargs,
+ input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
- # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
- outputs_from_embeds_wo_ids = model.generate(
- inputs_embeds=inputs_embeds, max_new_tokens=5, **generation_kwargs
- )
- self.assertListEqual(
- outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
- outputs_from_embeds_wo_ids.sequences.tolist(),
- )
+ # input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
+ # be the same
+ if not (requires_inputs_ids or missing_attention_mask):
+ outputs_from_embeds_wo_ids = model.generate(
+ inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
+ )
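+                # `generate` without `input_ids` returns only the new tokens, so strip the prompt before comparing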
+ outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
+ self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
@pytest.mark.generate
def test_generate_from_inputs_embeds_with_static_cache(self):
@@ -1829,10 +1868,8 @@ def test_new_cache_format(self, num_beams, do_sample):
@pytest.mark.generate
def test_generate_with_static_cache(self):
"""
- Tests if StaticCache works if we set attn_implementation=static when generation.
- This doesn't test if generation quality is good, but tests that models with
- self._supports_static_cache don't throw an error when generating and return
- a StaticCache object at the end.
+        Tests that generating with a static cache gives almost the same results as with a dynamic cache, and that the
+        output cache has the expected shapes
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
@@ -1851,13 +1888,15 @@ def test_generate_with_static_cache(self):
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
- "max_length": None,
"max_new_tokens": max_new_tokens,
- "cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
+ "output_scores": True,
"use_cache": True,
}
+ static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static")
+
+ # Check 1: The cache shapes must match the expected shapes
max_cache_len = seq_length + max_new_tokens
config = config.text_config if hasattr(config, "text_config") else config
head_dim = (
@@ -1869,12 +1908,14 @@ def test_generate_with_static_cache(self):
else config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
- results = model.generate(**generation_kwargs, **inputs_dict)
-
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
- self.assertTrue(isinstance(results.past_key_values, StaticCache))
- self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
- self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
+ self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
+ self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
+ self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
+
+ # Check 2: The outputs must be similar to the case with dynamic cache
+ dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
+ self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
@require_optimum_quanto
@pytest.mark.generate
@@ -1908,25 +1949,32 @@ def test_generate_with_quant_cache(self):
with self.assertRaises(ValueError):
model.generate(**generation_kwargs, **inputs_dict)
+ @parameterized.expand(
+ [
+            ("forward_only", False),  # TODO (@joao): a few models failing. Once fixed, this should not be "@slow"
+ ("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
+ ]
+ )
@pytest.mark.generate
@require_torch_gpu
@slow
- @is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
- def test_generate_compile_fullgraph(self):
+ def test_generate_compile(self, _, end_to_end):
"""
- Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
+        Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
+        both end-to-end compilation and forward-pass-only compilation.
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache")
+
# TODO (joao) -- fix and enable me :)
- if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
+ if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
self.skipTest("whisper model end-to-end generate compile not yet supported")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# TODO (joao) -- fix and enable me :)
- if config.is_encoder_decoder:
+ if end_to_end and config.is_encoder_decoder:
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
model = model_class(config).to(torch_device)
@@ -1941,27 +1989,33 @@ def test_generate_compile_fullgraph(self):
generation_kwargs = {
"do_sample": False,
"max_new_tokens": 10,
+ "return_dict_in_generate": True,
+ "output_scores": True,
}
+ # end-to-end works best with dynamic cache, forward compilation works best with static cache
+ if not end_to_end:
+ generation_kwargs["cache_implementation"] = "static"
- max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"]
- config = config.get_text_config()
- past_key_values = StaticCache(
- config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device
- )
+ # get eager + dynamic cache results for future comparison
+ dynamic_outputs = []
+ for model_inputs in input_ids_sets:
+ dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs))
+
+ # get compiled results
+ generation_config = copy.deepcopy(model.generation_config)
+ generation_config.update(**generation_kwargs)
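+            # reset the compiler state so each model class is compiled from scratch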
+ torch.compiler.reset()
+ if end_to_end:
+ model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
+ else:
+ model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
+ compiled_outputs = []
for model_inputs in input_ids_sets:
- # eager dynamic cache
- output_dynamic = model.generate(model_inputs, **generation_kwargs)
-
- # end-to-end compiled dynamic cache
- torch.compiler.reset()
- compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
- generation_config = copy.deepcopy(model.generation_config)
- generation_config.update(**generation_kwargs)
- output_compiled = compiled_generate(
- model_inputs, generation_config=generation_config, past_key_values=past_key_values
- )
- self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
+ compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config))
+
+ for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
+ self._check_similar_generate_outputs(dynamic_result, compiled_result)
@pytest.mark.generate
def test_generate_methods_with_num_logits_to_keep(self):
@@ -1989,7 +2043,6 @@ def test_generate_methods_with_num_logits_to_keep(self):
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
@pytest.mark.generate
- @is_flaky() # assisted generation tests are flaky (minor fp ops differences)
def test_assisted_decoding_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
@@ -1998,6 +2051,9 @@ def test_assisted_decoding_with_num_logits_to_keep(self):
self.skipTest(reason="Stateful models don't support assisted generation")
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
+ # NOTE: assisted generation only works with cache on at the moment.
+ if not hasattr(config, "use_cache"):
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True
@@ -2010,14 +2066,16 @@ def test_assisted_decoding_with_num_logits_to_keep(self):
"max_new_tokens": 10,
"do_sample": False,
"assistant_model": assistant_model,
+ "return_dict_in_generate": True,
+ "output_scores": True,
}
- assistant_model.generation_config.assistant_confidence_threshold = None
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
- self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
+
+ self._check_similar_generate_outputs(with_all_logits, without_all_logits)
@pytest.mark.generate
def test_inherits_generation_mixin(self):
@@ -2028,14 +2086,21 @@ def test_inherits_generation_mixin(self):
for model_class in self.all_generative_model_classes:
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
- @require_torch_sdpa
- @slow
- def test_eager_matches_sdpa_generate(self):
+ def _test_attention_implementation(self, attn_implementation):
+ """
+ Compares the output of generate with the eager attention implementation against other implementations.
+        NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
+ this separate function.
+ """
max_new_tokens = 30
+ support_flag = {
+ "sdpa": "_supports_sdpa",
+ "flash_attention_2": "_supports_flash_attn_2",
+ }
for model_class in self.all_generative_model_classes:
- if not model_class._supports_sdpa:
- self.skipTest(f"{model_class.__name__} does not support SDPA")
+ if not getattr(model_class, support_flag[attn_implementation]):
+ self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
inputs_dict = {}
@@ -2062,63 +2127,59 @@ def test_eager_matches_sdpa_generate(self):
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
+ "use_cache": True,
}
- model_sdpa = model_class.from_pretrained(
+ model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
+ attn_implementation="eager",
).to(torch_device)
- res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs)
- del model_sdpa
+ res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
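+                # free the eager model before loading the second copy, to limit peak memory usage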
+ del model_eager
gc.collect()
- model_eager = model_class.from_pretrained(
+ model_attn = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
- attn_implementation="eager",
+ attn_implementation=attn_implementation,
).to(torch_device)
- res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
- del model_eager
+ res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
+ del model_attn
gc.collect()
- # Eager and SDPA are very similar, but not exactly the same. Because we are using random models, this
- # test would be flaky if we only checked the sequences. Two situations in which this test passes:
- # 1. The sequences are the same
- # 2. The sequences are different, but the scores up until the first mismatch are nearly identical
- output_matches = res_eager.sequences == res_sdpa.sequences
- has_matching_outputs = output_matches.all()
- has_matching_scores = None
- if not has_matching_outputs:
- input_length = main_input.shape[1]
- for batch_idx in range(res_eager.sequences.shape[0]):
- batch_matches = output_matches[batch_idx]
- if batch_matches.all():
- continue
- first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
- first_mismatch_idx -= input_length # scores doesn't include data regarding input tokens
- sdpa_first_mismatch_scores = res_sdpa.scores[first_mismatch_idx][batch_idx]
- eager_first_mismatch_scores = res_eager.scores[first_mismatch_idx][batch_idx]
- has_matching_scores = torch.allclose(
- sdpa_first_mismatch_scores, eager_first_mismatch_scores, rtol=1e-3, atol=1e-3
- )
- if not has_matching_scores:
- break
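+                # fp16 + different attention kernels introduce small numerical differences, hence the looser tolerance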
+ self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
- self.assertTrue(has_matching_outputs or has_matching_scores)
+ @pytest.mark.generate
+ @require_torch_sdpa
+ @slow
+ def test_eager_matches_sdpa_generate(self):
+ """Tests that generate has equivalent outputs with SDPA and eager attention implementations."""
+ self._test_attention_implementation("sdpa")
- def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
- # we can be sure what is batch size from main input but seq length depends on model type and whether input is text/audio/image
- # so we infer actual text seq length from model_tester, same was as it is done in `test_modeling_common.py` tests`
- batch_size = main_input.shape[0]
+ @pytest.mark.flash_attn_test
+ @require_flash_attn
+ @require_torch_gpu
+ @slow
+ def test_eager_matches_fa2_generate(self):
+ """Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
+ # TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. After fixing,
+ # check whether we still need the overwrites
+ self._test_attention_implementation("flash_attention_2")
+
+ def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
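+        # `sequences` has `num_return_sequences` rows per input, while decoder-side outputs (scores, attentions,
+        # hidden states, cache) keep `num_beams` rows per input during beam search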
+ input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
+ internal_batch_size = (
+ input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
+ )
seq_length = getattr(self.model_tester, "seq_length", None)
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
config = config.text_config if hasattr(config, "text_config") else config
- num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
@@ -2129,19 +2190,21 @@ def _check_outputs(self, output, main_input, config, use_cache=False, num_return
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
# scores
- self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
+ self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
# unprocessed logits
- self._check_logits(num_sequences_in_output, output.logits, config=config)
+ self._check_logits(internal_batch_size, output.logits, config=config)
# Attentions
if self.has_attentions:
if config.is_encoder_decoder:
# encoder
- self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
+ self._check_encoder_attention_for_generate(
+ output.encoder_attentions, input_batch_size, config, seq_length
+ )
# decoder
self._check_attentions_for_generate(
- num_sequences_in_output,
+ internal_batch_size,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
@@ -2153,7 +2216,7 @@ def _check_outputs(self, output, main_input, config, use_cache=False, num_return
attentions = output.attentions if not use_cache else output.attentions[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_attentions_for_generate(
- num_sequences_in_output,
+ internal_batch_size,
attentions=attentions,
min_length=min_length,
max_length=output.sequences.shape[-1],
@@ -2165,12 +2228,12 @@ def _check_outputs(self, output, main_input, config, use_cache=False, num_return
if config.is_encoder_decoder:
# encoder
self._check_encoder_hidden_states_for_generate(
- output.encoder_hidden_states, batch_size, config, seq_length
+ output.encoder_hidden_states, input_batch_size, config, seq_length
)
# decoder
self._check_hidden_states_for_generate(
- num_sequences_in_output,
+ internal_batch_size,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
@@ -2182,7 +2245,7 @@ def _check_outputs(self, output, main_input, config, use_cache=False, num_return
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_hidden_states_for_generate(
- num_sequences_in_output,
+ internal_batch_size,
hidden_states,
min_length=min_length,
max_length=output.sequences.shape[-1],
@@ -2213,7 +2276,7 @@ def _check_outputs(self, output, main_input, config, use_cache=False, num_return
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
- num_sequences_in_output,
+ internal_batch_size,
past_key_values,
seq_length=past_sequence_length,
config=config,
diff --git a/tests/models/albert/test_modeling_albert.py b/tests/models/albert/test_modeling_albert.py
index d1e5631b342d33..970f1dd8555e47 100644
--- a/tests/models/albert/test_modeling_albert.py
+++ b/tests/models/albert/test_modeling_albert.py
@@ -16,7 +16,9 @@
import unittest
-from transformers import AlbertConfig, is_torch_available
+from packaging import version
+
+from transformers import AlbertConfig, AutoTokenizer, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
@@ -342,3 +344,45 @@ def test_inference_no_head_absolute_embedding(self):
)
self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
+
+ @slow
+ def test_export(self):
+ if version.parse(torch.__version__) < version.parse("2.4.0"):
+ self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        albert_model = "albert/albert-base-v2"
+ device = "cpu"
+ attn_implementation = "sdpa"
+ max_length = 64
+
+        tokenizer = AutoTokenizer.from_pretrained(albert_model)
+ inputs = tokenizer(
+ f"Paris is the {tokenizer.mask_token} of France.",
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ )
+
+ model = AlbertForMaskedLM.from_pretrained(
+            albert_model,
+ device_map=device,
+ attn_implementation=attn_implementation,
+ )
+
+ logits = model(**inputs).logits
+ eg_predicted_mask = tokenizer.decode(logits[0, 4].topk(5).indices)
+ self.assertEqual(
+ eg_predicted_mask.split(),
+ ["capital", "capitol", "comune", "arrondissement", "bastille"],
+ )
+
+ exported_program = torch.export.export(
+ model,
+ args=(inputs["input_ids"],),
+ kwargs={"attention_mask": inputs["attention_mask"]},
+ strict=True,
+ )
+
+ result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+ ep_predicted_mask = tokenizer.decode(result.logits[0, 4].topk(5).indices)
+ self.assertEqual(eg_predicted_mask, ep_predicted_mask)
diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py
index eda51d21199f31..e4d0df141be2b9 100644
--- a/tests/models/bart/test_modeling_bart.py
+++ b/tests/models/bart/test_modeling_bart.py
@@ -1532,8 +1532,3 @@ def test_retain_grad_hidden_states_attentions(self):
@unittest.skip
def test_save_load_fast_init_from_base(self):
pass
-
- @unittest.skip(reason="Generate needs input ids")
- def test_inputs_embeds_matches_input_ids_with_generate(self):
- # generate only works with input ids for bartforcausalLM
- pass
diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py
index aa9835d8cd67c1..25566027742507 100644
--- a/tests/models/bert/test_modeling_bert.py
+++ b/tests/models/bert/test_modeling_bert.py
@@ -511,11 +511,6 @@ def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
- @unittest.skip(reason="Generate needs input ids")
- def test_inputs_embeds_matches_input_ids_with_generate(self):
- # generate only works with input ids for bertforcausalLM
- pass
-
def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3
(
diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py
index 2adca2c6da0668..bb2ba8b3428174 100644
--- a/tests/models/chameleon/test_modeling_chameleon.py
+++ b/tests/models/chameleon/test_modeling_chameleon.py
@@ -16,17 +16,14 @@
import unittest
-import pytest
import requests
from parameterized import parameterized
from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
- require_flash_attn,
require_read_token,
require_torch,
- require_torch_gpu,
slow,
torch_device,
)
@@ -330,43 +327,6 @@ def test_model_rope_scaling(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
- @require_flash_attn
- @require_read_token
- @require_torch_gpu
- @require_bitsandbytes
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- """
- Overwritting the common test as the test is flaky on tiny models
- """
- model = ChameleonForConditionalGeneration.from_pretrained(
- "facebook/chameleon-7b",
- load_in_4bit=True,
- device_map={"": 0},
- )
-
- processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
- texts = ["hi", "Hello this is a very long sentence"]
-
- processor.tokenizer.padding_side = "right"
-
- inputs = processor(text=texts, return_tensors="pt", padding=True).to(0)
-
- output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_native = processor.tokenizer.batch_decode(output_native)
-
- model = ChameleonForConditionalGeneration.from_pretrained(
- "facebook/chameleon-7b",
- load_in_4bit=True,
- attn_implementation="flash_attention_2",
- )
-
- output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_fa_2 = processor.tokenizer.batch_decode(output_fa_2)
-
- self.assertListEqual(output_native, output_fa_2)
-
@unittest.skip("Chameleon forces some token ids to be -inf!")
def test_batching_equivalence(self):
pass
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index a888bdcd3bc7be..e8483f8c7c7d32 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -319,9 +319,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.6]
- # used in `test_torch_compile`
- _torch_compile_test_ckpt = "google/gemma-2b"
-
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None
@@ -419,51 +416,6 @@ def test_save_load_fast_init_from_base(self):
def test_past_key_values_format(self):
pass
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- import torch
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # NOTE: Gemma apparently does not support right padding + use_cache with FA2.
- dummy_attention_mask[:, -1] = 1
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 94670803daa998..7bca83f96d73ab 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -78,7 +78,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
test_pruning = False
_is_stateful = True
model_split_percents = [0.5, 0.6]
- _torch_compile_test_ckpt = "google/gemma-2-9b"
def setUp(self):
self.model_tester = Gemma2ModelTester(self)
diff --git a/tests/models/glm/test_modeling_glm.py b/tests/models/glm/test_modeling_glm.py
index 32bce7cbfa615e..b92c5db815b77a 100644
--- a/tests/models/glm/test_modeling_glm.py
+++ b/tests/models/glm/test_modeling_glm.py
@@ -28,7 +28,6 @@
require_flash_attn,
require_torch,
require_torch_accelerator,
- require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
@@ -306,10 +305,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
test_headmasking = False
test_pruning = False
- # used in `test_torch_compile`
- _torch_compile_test_ckpt = "THUDM/glm-4-9b"
- _torch_compile_test_revision = "refs/pr/15"
-
def setUp(self):
self.model_tester = GlmModelTester(self)
self.config_tester = ConfigTester(self, config_class=GlmConfig, hidden_size=37)
@@ -426,41 +421,6 @@ def test_custom_4d_attention_mask(self):
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3)
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- """Overwrite the common test as the test is flaky on tiny models."""
- model = GlmForCausalLM.from_pretrained(
- "THUDM/glm-4-9b",
- device_map={"": 0},
- torch_dtype=torch.bfloat16,
- revision="refs/pr/15",
- )
-
- tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b", revision="refs/pr/15")
- tokenizer.padding_side = "right"
-
- texts = ["hi", "Hello this is a very long sentence"]
- inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
- output_native = model.generate(**inputs, max_new_tokens=15, do_sample=False)
- output_native = tokenizer.batch_decode(output_native)
-
- model = GlmForCausalLM.from_pretrained(
- "THUDM/glm-4-9b",
- device_map={"": 0},
- attn_implementation="flash_attention_2",
- torch_dtype=torch.bfloat16,
- revision="refs/pr/15",
- )
-
- output_fa_2 = model.generate(**inputs, max_new_tokens=15, do_sample=False)
- output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
- self.assertListEqual(output_native, output_fa_2)
-
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py
index 6f6fba50dc123a..afc741cd502dec 100644
--- a/tests/models/gptj/test_modeling_gptj.py
+++ b/tests/models/gptj/test_modeling_gptj.py
@@ -17,14 +17,9 @@
import datetime
import unittest
-import pytest
-
-from transformers import BitsAndBytesConfig, GPTJConfig, is_torch_available
+from transformers import GPTJConfig, is_torch_available
from transformers.testing_utils import (
- require_bitsandbytes,
- require_flash_attn,
require_torch,
- require_torch_gpu,
slow,
tooslow,
torch_device,
@@ -505,44 +500,6 @@ def test_model_from_pretrained(self):
model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16)
self.assertIsNotNone(model)
- @require_flash_attn
- @require_torch_gpu
- @require_bitsandbytes
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- """
- Overwritting the common test as the test is flaky on tiny models
- """
- tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
-
- texts = ["hi", "Hello this is a very long sentence"]
- expected_outputs = [
- "hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. I have a question about the",
- "Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it",
- ]
-
- tokenizer.padding_side = "right"
- tokenizer.pad_token = tokenizer.eos_token
-
- inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
- quantization_config = BitsAndBytesConfig(load_in_4bit=True)
-
- model = GPTJForCausalLM.from_pretrained(
- "EleutherAI/gpt-j-6b",
- device_map={"": 0},
- attn_implementation="flash_attention_2",
- revision="float16",
- torch_dtype=torch.float16,
- quantization_config=quantization_config,
- )
-
- output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
- self.assertListEqual(expected_outputs, output_fa_2)
-
@require_torch
class GPTJModelLanguageGenerationTest(unittest.TestCase):
diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py
index 1bcb6641803c04..97b59f5aa50621 100644
--- a/tests/models/granite/test_modeling_granite.py
+++ b/tests/models/granite/test_modeling_granite.py
@@ -17,12 +17,10 @@
import tempfile
import unittest
-import pytest
from parameterized import parameterized
-from transformers import AutoTokenizer, GraniteConfig, is_torch_available, set_seed
+from transformers import GraniteConfig, is_torch_available, set_seed
from transformers.testing_utils import (
- require_bitsandbytes,
require_flash_attn,
require_read_token,
require_torch,
@@ -303,9 +301,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
- # used in `test_torch_compile`
- _torch_compile_test_ckpt = "ibm/PowerLM-3b"
-
def setUp(self):
self.model_tester = GraniteModelTester(self)
self.config_tester = ConfigTester(self, config_class=GraniteConfig, hidden_size=37)
@@ -423,46 +418,6 @@ def test_model_rope_scaling(self):
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
- @require_flash_attn
- @require_torch_gpu
- @require_bitsandbytes
- @pytest.mark.flash_attn_test
- @require_read_token
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- """
- Overwritting the common test as the test is flaky on tiny models
- """
- model = GraniteForCausalLM.from_pretrained(
- "ibm/PowerLM-3b",
- load_in_4bit=True,
- device_map={"": 0},
- )
-
- tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b")
-
- texts = ["hi", "Hello this is a very long sentence"]
-
- tokenizer.padding_side = "right"
- tokenizer.pad_token = tokenizer.eos_token
-
- inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
- output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_native = tokenizer.batch_decode(output_native)
-
- model = GraniteForCausalLM.from_pretrained(
- "ibm/PowerLM-3b",
- load_in_4bit=True,
- device_map={"": 0},
- attn_implementation="flash_attention_2",
- )
-
- output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
- self.assertListEqual(output_native, output_fa_2)
-
@require_flash_attn
@require_torch_gpu
@slow
diff --git a/tests/models/granitemoe/test_modeling_granitemoe.py b/tests/models/granitemoe/test_modeling_granitemoe.py
index 124ce0c3bb5ae6..f2f76b9fa75bf3 100644
--- a/tests/models/granitemoe/test_modeling_granitemoe.py
+++ b/tests/models/granitemoe/test_modeling_granitemoe.py
@@ -17,12 +17,10 @@
import tempfile
import unittest
-import pytest
from parameterized import parameterized
from transformers import AutoTokenizer, GraniteMoeConfig, is_torch_available, set_seed
from transformers.testing_utils import (
- require_bitsandbytes,
require_flash_attn,
require_read_token,
require_torch,
@@ -302,9 +300,6 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
- # used in `test_torch_compile`
- _torch_compile_test_ckpt = "ibm/PowerMoE-3b"
-
def setUp(self):
self.model_tester = GraniteMoeModelTester(self)
self.config_tester = ConfigTester(self, config_class=GraniteMoeConfig, hidden_size=37)
@@ -422,46 +417,6 @@ def test_model_rope_scaling(self):
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
- @require_flash_attn
- @require_torch_gpu
- @require_bitsandbytes
- @pytest.mark.flash_attn_test
- @require_read_token
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- """
- Overwritting the common test as the test is flaky on tiny models
- """
- model = GraniteMoeForCausalLM.from_pretrained(
- "ibm-granite/granitemoe-3b",
- load_in_4bit=True,
- device_map={"": 0},
- )
-
- tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granitemoe-3b")
-
- texts = ["hi", "Hello this is a very long sentence"]
-
- tokenizer.padding_side = "right"
- tokenizer.pad_token = tokenizer.eos_token
-
- inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
- output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_native = tokenizer.batch_decode(output_native)
-
- model = GraniteMoeForCausalLM.from_pretrained(
- "ibm-granite/granitemoe-3b",
- load_in_4bit=True,
- device_map={"": 0},
- attn_implementation="flash_attention_2",
- )
-
- output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
- self.assertListEqual(output_native, output_fa_2)
-
@require_flash_attn
@require_torch_gpu
@slow
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index 158831ad0cda27..7be87fd78390ab 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -774,13 +774,6 @@ def test_contrastive_generate_low_memory(self):
def test_custom_4d_attention_mask(self):
pass
- @unittest.skip(
- reason="IDEFICS has specific requirements for working with inputs embeds like passing also the ids and pixels"
- )
- @parameterized.expand([(1,), (2,)])
- def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
- pass
-
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
def test_generate_compile_fullgraph(self):
pass
diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py
index 49828679146781..81f71a4746ebf7 100644
--- a/tests/models/idefics2/test_modeling_idefics2.py
+++ b/tests/models/idefics2/test_modeling_idefics2.py
@@ -20,7 +20,6 @@
import unittest
from io import BytesIO
-import pytest
import requests
from transformers import (
@@ -421,50 +420,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
def test_flash_attn_2_fp32_ln(self):
pass
- @pytest.mark.generate
- def test_generate_from_inputs_embeds_decoder_only(self):
- # overwrite because IDEFICS needs ids and embeds at the input to be not None
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-
- # Ignore:
- # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
- # which would cause a mismatch),
- config.pad_token_id = config.eos_token_id = -1
- config.is_decoder = True
- model = model_class(config).to(torch_device).eval()
- input_ids = inputs_dict.pop("input_ids")
-
- # Traditional way of generating text
- outputs_from_ids = model.generate(
- input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
- )
- self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
-
- # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
- inputs_embeds = model.get_input_embeddings()(input_ids)
- outputs_from_embeds = model.generate(
- input_ids,
- inputs_embeds=inputs_embeds,
- max_new_tokens=5,
- return_dict_in_generate=True,
- output_scores=True,
- )
- self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
-
- # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
- # same, but the logits will almost surely be different)
- random_embeds = torch.rand_like(inputs_embeds)
- outputs_from_rand_embeds = model.generate(
- input_ids,
- inputs_embeds=random_embeds,
- max_new_tokens=5,
- return_dict_in_generate=True,
- output_scores=True,
- )
- for i in range(len(outputs_from_rand_embeds.scores)):
- self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
-
# We need to override as we need to prepare such that the image token is the last token
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py
index 2ce58bbcc57d73..35f6387b323ddc 100644
--- a/tests/models/idefics3/test_modeling_idefics3.py
+++ b/tests/models/idefics3/test_modeling_idefics3.py
@@ -19,7 +19,6 @@
import unittest
from io import BytesIO
-import pytest
import requests
from transformers import (
@@ -180,10 +179,6 @@ def test_inputs_embeds():
def test_inputs_embeds_matches_input_ids(self):
pass
- @unittest.skip(reason="Model does not support padding right")
- def test_flash_attn_2_generate_padding_right(self):
- pass
-
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_inference_padding_right(self):
pass
@@ -338,10 +333,6 @@ def setUp(self):
def test_inputs_embeds():
pass
- @unittest.skip(reason="Model does not support padding right")
- def test_flash_attn_2_generate_padding_right(self):
- pass
-
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_inference_padding_right(self):
pass
@@ -368,50 +359,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
def test_flash_attn_2_fp32_ln(self):
pass
- @pytest.mark.generate
- def test_generate_from_inputs_embeds_decoder_only(self):
- # overwrite because IDEFICS needs ids and embeds at the input to be not None
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-
- # Ignore:
- # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
- # which would cause a mismatch),
- config.pad_token_id = config.eos_token_id = -1
- config.is_decoder = True
- model = model_class(config).to(torch_device).eval()
- input_ids = inputs_dict.pop("input_ids")
-
- # Traditional way of generating text
- outputs_from_ids = model.generate(
- input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
- )
- self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
-
- # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
- inputs_embeds = model.get_input_embeddings()(input_ids)
- outputs_from_embeds = model.generate(
- input_ids,
- inputs_embeds=inputs_embeds,
- max_new_tokens=5,
- return_dict_in_generate=True,
- output_scores=True,
- )
- self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
-
- # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
- # same, but the logits will almost surely be different)
- random_embeds = torch.rand_like(inputs_embeds)
- outputs_from_rand_embeds = model.generate(
- input_ids,
- inputs_embeds=random_embeds,
- max_new_tokens=5,
- return_dict_in_generate=True,
- output_scores=True,
- )
- for i in range(len(outputs_from_rand_embeds.scores)):
- self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
-
# We need to override as we need to prepare such that the image token is the last token
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
@@ -527,31 +474,6 @@ def test_resize_embeddings_untied(self):
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
- def test_inputs_embeds_matches_input_ids_with_generate(self):
- # overwrite because IDEFICS needs ids and embeds at the input to be not None
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- for model_class in self.all_model_classes:
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
- pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
-
- wte = model.get_input_embeddings()
-
- input_ids = inputs["input_ids"]
- # some models infer position ids/attn mask differently when input ids
- # by check if pad_token let's make sure no padding is in input ids
- not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
- input_ids[input_ids == pad_token_id] = not_pad_token_id
- del inputs["input_ids"]
- inputs_embeds = wte(input_ids)
- out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
- out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
-
- self.assertTrue(torch.allclose(out_embeds, out_ids))
-
@require_torch
class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
index 251f293f722661..ef0b5831587be1 100644
--- a/tests/models/jamba/test_modeling_jamba.py
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -539,93 +539,6 @@ def test_flash_attn_2_fp32_ln(self):
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- r"""
- Overriding the test_flash_attn_2_generate_padding_right test as the Jamba model, like Mixtral, doesn't support
- right padding + use cache with FA2
- """
- import torch
-
- for model_class in self.all_generative_model_classes:
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
- dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
- model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- with self.assertRaises(ValueError):
- _ = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- r"""
- Overriding the test_flash_attn_2_generate_use_cache test as the Jamba model, like Mixtral, doesn't support
- right padding + use cache with FA2
- """
- import torch
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # NOTE: Jamba does not support right padding + use_cache with FA2.
- dummy_attention_mask[:, -1] = 1
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py
index a04d8bba741a23..dc510f0ff040bb 100644
--- a/tests/models/jetmoe/test_modeling_jetmoe.py
+++ b/tests/models/jetmoe/test_modeling_jetmoe.py
@@ -15,7 +15,6 @@
"""Testing suite for the PyTorch JetMoe model."""
import gc
-import tempfile
import unittest
import pytest
@@ -377,85 +376,6 @@ def test_save_load_fast_init_from_base(self):
def test_past_key_values_format(self):
pass
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- import torch
-
- for model_class in self.all_generative_model_classes:
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
- dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
- model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- with self.assertRaises(ValueError):
- _ = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- import torch
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # NOTE: JetMoe apparently does not support right padding + use_cache with FA2.
- dummy_attention_mask[:, -1] = 1
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py
index f1ec65113e1843..43266a750b8d6c 100644
--- a/tests/models/kosmos2/test_modeling_kosmos2.py
+++ b/tests/models/kosmos2/test_modeling_kosmos2.py
@@ -446,12 +446,6 @@ def check_same_values(layer_1, layer_2):
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
- @unittest.skip(
- "KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking ipnut args because KOSMOS-2 has `generate()` overwritten"
- )
- def test_inputs_embeds_matches_input_ids_with_generate(self):
- pass
-
@slow
def test_model_from_pretrained(self):
model_name = "microsoft/kosmos-2-patch14-224"
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 824337d8bdda01..375ec1dd3e6f3a 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -26,7 +26,6 @@
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
backend_empty_cache,
- require_bitsandbytes,
require_flash_attn,
require_read_token,
require_torch,
@@ -316,9 +315,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
- # used in `test_torch_compile`
- _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
-
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
@@ -585,43 +581,6 @@ def _reinitialize_config(base_config, new_kwargs):
with self.assertRaises(KeyError):
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
- @require_flash_attn
- @require_torch_gpu
- @require_bitsandbytes
- @pytest.mark.flash_attn_test
- @require_read_token
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- """
- Overwritting the common test as the test is flaky on tiny models
- """
- model = LlamaForCausalLM.from_pretrained(
- "meta-llama/Llama-2-7b-hf",
- load_in_4bit=True,
- device_map={"": 0},
- )
-
- tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-
- texts = ["hi", "Hello this is a very long sentence"]
-
- tokenizer.padding_side = "right"
- tokenizer.pad_token = tokenizer.eos_token
-
- inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
- output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_native = tokenizer.batch_decode(output_native)
-
- model = LlamaForCausalLM.from_pretrained(
- "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
- )
-
- output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
- self.assertListEqual(output_native, output_fa_2)
-
@require_flash_attn
@require_torch_gpu
@slow
diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py
index 0401f0e68bf72e..2051218f8db398 100644
--- a/tests/models/llava/test_modeling_llava.py
+++ b/tests/models/llava/test_modeling_llava.py
@@ -239,6 +239,35 @@ def test_inputs_embeds_matches_input_ids(self):
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
+ def test_mismatching_num_image_tokens(self):
+ """
+ Tests that VLMs throw an error with an explicit message saying what is wrong
+ when the number of images doesn't match the number of image tokens in the text.
+ Also we need to test multi-image cases where one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ _ = model(**input_dict) # successful forward with no modifications
+
+ # remove one image but leave the image token in text
+ input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = input_dict["input_ids"][:1]
+ pixel_values = input_dict["pixel_values"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index bfc56edbdb91b9..b19f3ec95a8c2a 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -284,6 +284,38 @@ def test_inputs_embeds_matches_input_ids(self):
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
+ def test_mismatching_num_image_tokens(self):
+ """
+ Tests that VLMs throw an error with an explicit message saying what is wrong
+ when the number of images doesn't match the number of image tokens in the text.
+ Also we need to test multi-image cases where one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ _ = model(**input_dict) # successful forward with no modifications
+
+ # remove one image but leave the image token in text
+ input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+ input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = input_dict["input_ids"][:1]
+ pixel_values = input_dict["pixel_values"][:1]
+ image_sizes = input_dict["image_sizes"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py
index 05fc8a49e1e9b9..edf1dd2d4c07a4 100644
--- a/tests/models/llava_next_video/test_modeling_llava_next_video.py
+++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -303,6 +303,38 @@ def test_inputs_embeds_matches_input_ids(self):
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
+ def test_mismatching_num_image_tokens(self):
+ """
+ Tests that VLMs throw an error with an explicit message saying what is wrong
+ when the number of images doesn't match the number of image tokens in the text.
+ Also we need to test multi-image cases where one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ _ = model(**input_dict) # successful forward with no modifications
+
+ # remove one image but leave the image token in text
+ input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+ input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = input_dict["input_ids"][:1]
+ pixel_values = input_dict["pixel_values"][:1]
+ image_sizes = input_dict["image_sizes"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index 1a8cf04774531f..9b3a9563b58ddc 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -204,8 +204,8 @@ def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
- @parameterized.expand([(1,), (2,)])
- def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
+ @parameterized.expand([("greedy", 1), ("beam search", 2)])
+ def test_generate_from_inputs_embeds(self, _, num_beams):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
@@ -276,12 +276,6 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
- @unittest.skip(
- reason="Mamba2 does not support generating with input embeddings (custom cache_position computation)"
- )
- def test_inputs_embeds_matches_input_ids_with_generate(self):
- pass
-
@require_torch
@slow
diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py
index 074dceae155214..df0007d666a077 100644
--- a/tests/models/mimi/test_modeling_mimi.py
+++ b/tests/models/mimi/test_modeling_mimi.py
@@ -21,7 +21,6 @@
import numpy as np
from datasets import Audio, load_dataset
-from packaging import version
from parameterized import parameterized
from pytest import mark
@@ -745,22 +744,6 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
def test_sdpa_can_compile_dynamic(self):
pass
- # For now, Let's focus only on GPU for `torch.compile`
- @slow
- @require_torch_gpu
- def test_torch_compile(self):
- if version.parse(torch.__version__) < version.parse("2.3"):
- self.skipTest(reason="This test requires torch >= 2.3 to run.")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- n_iter = 3
- for model_class in self.all_model_classes:
- model = model_class(config).to(torch_device)
- model.forward = torch.compile(model.forward)
- for i in range(n_iter):
- _ = model(inputs_dict["input_values"].to(torch_device))
-
@is_flaky()
def test_batching_equivalence(self):
super().test_batching_equivalence()
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index f2ee714bcdbafc..1538735ad78bd7 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -15,7 +15,6 @@
"""Testing suite for the PyTorch Mistral model."""
import gc
-import tempfile
import unittest
import pytest
@@ -416,85 +415,6 @@ def test_save_load_fast_init_from_base(self):
def test_past_key_values_format(self):
pass
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- import torch
-
- for model_class in self.all_generative_model_classes:
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
- dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
- model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- with self.assertRaises(ValueError):
- _ = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- import torch
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # NOTE: Mistral apparently does not support right padding + use_cache with FA2.
- dummy_attention_mask[:, -1] = 1
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py
index b9b5faed851fe4..931bb1f17beccf 100644
--- a/tests/models/mixtral/test_modeling_mixtral.py
+++ b/tests/models/mixtral/test_modeling_mixtral.py
@@ -14,7 +14,6 @@
# limitations under the License.
"""Testing suite for the PyTorch Mixtral model."""
-import tempfile
import unittest
import pytest
@@ -415,85 +414,6 @@ def test_save_load_fast_init_from_base(self):
def test_past_key_values_format(self):
pass
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- import torch
-
- for model_class in self.all_generative_model_classes:
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
- dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
- model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- with self.assertRaises(ValueError):
- _ = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- import torch
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # NOTE: Mixtral apparently does not support right padding + use_cache with FA2.
- dummy_attention_mask[:, -1] = 1
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py
index 8df8197a2c3ccd..ac46930895f360 100644
--- a/tests/models/mllama/test_modeling_mllama.py
+++ b/tests/models/mllama/test_modeling_mllama.py
@@ -126,7 +126,6 @@ class MllamaForCausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
all_generative_model_classes = (MllamaForCausalLM,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
- _torch_compile_test_ckpt = "nltpt/Llama-3.2-11B-Vision"
def setUp(self):
self.model_tester = MllamaText2TextModelTester(self)
diff --git a/tests/models/mobilebert/test_modeling_mobilebert.py b/tests/models/mobilebert/test_modeling_mobilebert.py
index d7a409427c9c51..d2bc11d09f1797 100644
--- a/tests/models/mobilebert/test_modeling_mobilebert.py
+++ b/tests/models/mobilebert/test_modeling_mobilebert.py
@@ -16,7 +16,9 @@
import unittest
-from transformers import MobileBertConfig, is_torch_available
+from packaging import version
+
+from transformers import AutoTokenizer, MobileBertConfig, MobileBertForMaskedLM, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -384,3 +386,42 @@ def test_inference_no_head(self):
upper_bound = torch.all((expected_slice / output[..., :3, :3]) <= 1 + TOLERANCE)
self.assertTrue(lower_bound and upper_bound)
+
+ @slow
+ def test_export(self):
+ if version.parse(torch.__version__) < version.parse("2.4.0"):
+ self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+ mobilebert_model = "google/mobilebert-uncased"
+ device = "cpu"
+ attn_implementation = "eager"
+ max_length = 512
+
+ tokenizer = AutoTokenizer.from_pretrained(mobilebert_model)
+ inputs = tokenizer(
+ f"the man worked as a {tokenizer.mask_token}.",
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ )
+
+ model = MobileBertForMaskedLM.from_pretrained(
+ mobilebert_model,
+ device_map=device,
+ attn_implementation=attn_implementation,
+ )
+
+ logits = model(**inputs).logits
+ eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
+ self.assertEqual(eg_predicted_mask.split(), ["carpenter", "waiter", "mechanic", "teacher", "clerk"])
+
+ exported_program = torch.export.export(
+ model,
+ args=(inputs["input_ids"],),
+ kwargs={"attention_mask": inputs["attention_mask"]},
+ strict=True,
+ )
+
+ result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+ ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
+ self.assertEqual(eg_predicted_mask, ep_predicted_mask)
diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py
index b77a6ff10364ca..7d4b855c10d8bf 100644
--- a/tests/models/moshi/test_modeling_moshi.py
+++ b/tests/models/moshi/test_modeling_moshi.py
@@ -560,7 +560,7 @@ def _get_input_ids_and_config(self, batch_size=2):
return config, input_ids, attention_mask, inputs_dict
def prepare_config_and_inputs_for_generate(self, batch_size=2):
- config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate()
+ config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size)
# Make sure we only return `input_ids`.
# Note that audio_codes will still be generated internally, so the ability to test audio codes is still there.
@@ -591,9 +591,11 @@ def _check_hidden_states_for_generate(
[expected_shape] * len(iter_hidden_states),
)
- def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
+ def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
# Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True`
- super()._check_outputs(output, input_ids, config, use_cache=True, num_return_sequences=num_return_sequences)
+ super()._check_outputs(
+ output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams
+ )
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
@@ -655,59 +657,6 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
- @pytest.mark.generate
- @parameterized.expand([(1,), (2,)])
- def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
- for model_class in self.all_generative_model_classes:
- config, input_ids, _, inputs_dict = self._get_input_ids_and_config()
-
- model = model_class(config).to(torch_device).eval()
- generation_kwargs = {
- "return_dict_in_generate": True,
- "output_scores": True,
- "num_beams": num_beams,
- "do_sample": False,
- }
-
- # Traditional way of generating text
- outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs, **inputs_dict)
- self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
-
- # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
- inputs_embeds = model.get_input_embeddings()(input_ids)
- outputs_from_embeds = model.generate(
- input_ids,
- inputs_embeds=inputs_embeds,
- max_new_tokens=5,
- **generation_kwargs,
- **inputs_dict,
- )
-
- # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
- # same, but the logits will almost surely be different)
- random_embeds = torch.rand_like(inputs_embeds)
- outputs_from_rand_embeds = model.generate(
- input_ids,
- inputs_embeds=random_embeds,
- max_new_tokens=5,
- **generation_kwargs,
- **inputs_dict,
- )
- for i in range(len(outputs_from_rand_embeds.scores)):
- self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
-
- # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
- outputs_from_embeds_wo_ids = model.generate(
- inputs_embeds=inputs_embeds,
- max_new_tokens=5,
- **generation_kwargs,
- **inputs_dict,
- )
- self.assertListEqual(
- outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
- outputs_from_embeds_wo_ids.sequences.tolist(),
- )
-
@unittest.skip(reason="Continuing from past key values is not straightforward as we're dealing with 3 inputs")
def test_generate_continue_from_past_key_values(self):
pass
diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py
index 20412da2e1db06..1628d3a5893eaa 100644
--- a/tests/models/mt5/test_modeling_mt5.py
+++ b/tests/models/mt5/test_modeling_mt5.py
@@ -576,9 +576,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# The small MT5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9]
- # used in `test_torch_compile`
- _torch_compile_test_ckpt = "google/mt5-small"
-
def setUp(self):
self.model_tester = MT5ModelTester(self)
self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37)
diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py
index 346ad60debe23f..963cace28d6e41 100644
--- a/tests/models/musicgen/test_modeling_musicgen.py
+++ b/tests/models/musicgen/test_modeling_musicgen.py
@@ -450,144 +450,6 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
- def test_flash_attn_2_generate_left_padding(self):
- # Ignore copy
- for model_class in self.greedy_sample_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = inputs_dict[model.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # make sure we do left padding
- dummy_attention_mask[:, :-1] = 0
- dummy_attention_mask[:, -1:] = 1
-
- out = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- out_fa = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- self.assertTrue(torch.allclose(out, out_fa))
-
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
- def test_flash_attn_2_generate_padding_right(self):
- # Ignore copy
- for model_class in self.greedy_sample_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = inputs_dict[model.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # make sure we do right padding
- dummy_attention_mask[:, :-1] = 1
- dummy_attention_mask[:, -1:] = 0
-
- out = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- out_fa = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- self.assertTrue(torch.allclose(out, out_fa))
-
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
- def test_flash_attn_2_generate_use_cache(self):
- max_new_tokens = 30
-
- # Ignore copy
- for model_class in self.greedy_sample_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@@ -1585,149 +1447,6 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
- def test_flash_attn_2_generate_left_padding(self):
- # Ignore copy
- for model_class in self.greedy_sample_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = inputs_dict[model.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask")
- if dummy_attention_mask is None:
- dummy_attention_mask = torch.ones_like(dummy_input)
-
- # make sure we do left padding
- dummy_attention_mask[:, :-1] = 0
- dummy_attention_mask[:, -1:] = 1
-
- out = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- out_fa = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- self.assertTrue(torch.allclose(out, out_fa))
-
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
- def test_flash_attn_2_generate_padding_right(self):
- # Ignore copy
- for model_class in self.greedy_sample_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = inputs_dict[model.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask")
- if dummy_attention_mask is None:
- dummy_attention_mask = torch.ones_like(dummy_input)
- # make sure we do right padding
- dummy_attention_mask[:, :-1] = 1
- dummy_attention_mask[:, -1:] = 0
-
- out = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- out_fa = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- self.assertTrue(torch.allclose(out, out_fa))
-
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
- def test_flash_attn_2_generate_use_cache(self):
- max_new_tokens = 30
-
- # Ignore copy
- for model_class in self.greedy_sample_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
if not self.has_attentions:
diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
index f3b6be0ac652eb..957db9f23b0f21 100644
--- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -1437,149 +1437,6 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
- def test_flash_attn_2_generate_left_padding(self):
- # Ignore copy
- for model_class in self.greedy_sample_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = inputs_dict[model.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask")
- if dummy_attention_mask is None:
- dummy_attention_mask = torch.ones_like(dummy_input)
-
- # make sure we do left padding
- dummy_attention_mask[:, :-1] = 0
- dummy_attention_mask[:, -1:] = 1
-
- out = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- out_fa = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- self.assertTrue(torch.allclose(out, out_fa))
-
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
- def test_flash_attn_2_generate_padding_right(self):
- # Ignore copy
- for model_class in self.greedy_sample_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = inputs_dict[model.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask")
- if dummy_attention_mask is None:
- dummy_attention_mask = torch.ones_like(dummy_input)
- # make sure we do right padding
- dummy_attention_mask[:, :-1] = 1
- dummy_attention_mask[:, -1:] = 0
-
- out = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- out_fa = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
- )
-
- self.assertTrue(torch.allclose(out, out_fa))
-
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
- def test_flash_attn_2_generate_use_cache(self):
- max_new_tokens = 30
-
- # Ignore copy
- for model_class in self.greedy_sample_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
if not self.has_attentions:
diff --git a/tests/models/nemotron/test_modeling_nemotron.py b/tests/models/nemotron/test_modeling_nemotron.py
index 13adfe1e579489..37a581a33866ce 100644
--- a/tests/models/nemotron/test_modeling_nemotron.py
+++ b/tests/models/nemotron/test_modeling_nemotron.py
@@ -92,8 +92,6 @@ class NemotronModelTest(GemmaModelTest):
test_pruning = False
fx_compatible = False
- # used in `test_torch_compile`
- _torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None
diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py
index 7818f174508fc8..76a1375f754a16 100644
--- a/tests/models/paligemma/test_modeling_paligemma.py
+++ b/tests/models/paligemma/test_modeling_paligemma.py
@@ -237,6 +237,36 @@ def test_inputs_embeds_matches_input_ids(self):
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
+ # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
+ def test_mismatching_num_image_tokens(self):
+ """
+ Tests that VLMs throw an error with an explicit message saying what is wrong
+ when the number of images doesn't match the number of image tokens in the text.
+ Also we need to test multi-image cases where one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ _ = model(**input_dict) # successful forward with no modifications
+
+ # remove one image but leave the image token in text
+ input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = input_dict["input_ids"][:1]
+ pixel_values = input_dict["pixel_values"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
@@ -317,10 +347,6 @@ def test_save_load_low_cpu_mem_usage_no_safetensors(self):
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
- @unittest.skip(reason="TODO (@joao): fix me -- failing to produce similar results")
- def test_static_cache_matches_dynamic(self):
- pass
-
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py
index c17f69a499866b..eae6789bef252e 100644
--- a/tests/models/phi/test_modeling_phi.py
+++ b/tests/models/phi/test_modeling_phi.py
@@ -17,15 +17,11 @@
import unittest
-import pytest
from parameterized import parameterized
from transformers import PhiConfig, is_torch_available, set_seed
from transformers.testing_utils import (
- require_bitsandbytes,
- require_flash_attn,
require_torch,
- require_torch_gpu,
slow,
torch_device,
)
@@ -468,43 +464,6 @@ def test_model_rope_scaling(self):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
- @require_flash_attn
- @require_torch_gpu
- @require_bitsandbytes
- @pytest.mark.flash_attn_test
- @slow
- # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1
- def test_flash_attn_2_generate_padding_right(self):
- """
- Overwritting the common test as the test is flaky on tiny models
- """
- model = PhiForCausalLM.from_pretrained(
- "microsoft/phi-1",
- load_in_4bit=True,
- device_map={"": 0},
- )
-
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
-
- texts = ["hi", "Hello this is a very long sentence"]
-
- tokenizer.padding_side = "right"
- tokenizer.pad_token = tokenizer.eos_token
-
- inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
- output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_native = tokenizer.batch_decode(output_native)
-
- model = PhiForCausalLM.from_pretrained(
- "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
- )
-
- output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
- self.assertListEqual(output_native, output_fa_2)
-
@slow
@require_torch
diff --git a/tests/models/pixtral/test_modeling_pixtral.py b/tests/models/pixtral/test_modeling_pixtral.py
index 9a128f6ad28823..0c36cb5a4e0554 100644
--- a/tests/models/pixtral/test_modeling_pixtral.py
+++ b/tests/models/pixtral/test_modeling_pixtral.py
@@ -14,22 +14,16 @@
# limitations under the License.
"""Testing suite for the PyTorch Pixtral model."""
-import gc
import unittest
-import requests
-
from transformers import (
- AutoProcessor,
PixtralVisionConfig,
PixtralVisionModel,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import (
- require_bitsandbytes,
require_torch,
- slow,
torch_device,
)
@@ -43,7 +37,7 @@
is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
- from PIL import Image
+ pass
class PixtralVisionModelTester:
@@ -148,6 +142,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
+ test_torchscript = False
def setUp(self):
self.model_tester = PixtralVisionModelTester(self)
@@ -258,35 +253,3 @@ def test_disk_offload_safetensors(self):
@unittest.skip(reason="Not supported yet")
def test_determinism(self):
pass
-
-
-@require_torch
-class PixtralVisionModelIntegrationTest(unittest.TestCase):
- def setUp(self):
- self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")
-
- def tearDown(self):
- gc.collect()
- torch.cuda.empty_cache()
-
- @slow
- @require_bitsandbytes
- def test_small_model_integration_test(self):
- # Let' s make sure we test the preprocessing to replace what is used
- model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
-
- prompt = "[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
- image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
- raw_image = Image.open(requests.get(image_file, stream=True).raw)
- inputs = self.processor(prompt, raw_image, return_tensors="pt")
-
- EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
- self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
-
- output = model.generate(**inputs, max_new_tokens=20)
- EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
-
- self.assertEqual(
- self.processor.decode(output[0], skip_special_tokens=True),
- EXPECTED_DECODED_TEXT,
- )
diff --git a/tests/models/pixtral/test_processor_pixtral.py b/tests/models/pixtral/test_processor_pixtral.py
index 8cdbf93c6476b8..c3496dff3cdf81 100644
--- a/tests/models/pixtral/test_processor_pixtral.py
+++ b/tests/models/pixtral/test_processor_pixtral.py
@@ -171,7 +171,7 @@ def test_processor_with_multiple_images_single_list(self):
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
- )
+ )
# fmt: on
# Test passing in a url
@@ -246,6 +246,25 @@ def test_processor_with_multiple_images_multiple_lists(self):
)
# fmt: on
+ def test_processor_returns_full_length_batches(self):
+        # regression test for https://github.com/huggingface/transformers/issues/34204
+ processor = self.processor_class.from_pretrained(self.tmpdirname)
+ prompt_string = [
+ "USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
+ ] * 5
+        processor.tokenizer.pad_token = "</s>"
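+        # a pad token is needed so the five prompts can be batched together with `padding=True`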
+ image_inputs = [self.image_0] * 5
+
+ # Make small for checking image token expansion
+ processor.image_processor.size = {"longest_edge": 30}
+ processor.image_processor.patch_size = {"height": 2, "width": 2}
+
+ # Test passing in an image
+ inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
+ self.assertIn("input_ids", inputs_image)
+ self.assertTrue(len(inputs_image["input_ids"]) == 5)
+ self.assertTrue(len(inputs_image["pixel_values"]) == 5)
+
# Override as PixtralProcessor needs nested images to work properly with batched inputs
@require_vision
def prepare_image_inputs(self, batch_size: Optional[int] = None):
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index 4e57f8e0f002fb..f51dc2e0a5e26f 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -15,7 +15,6 @@
"""Testing suite for the PyTorch Qwen2 model."""
import gc
-import tempfile
import unittest
import pytest
@@ -428,85 +427,6 @@ def test_save_load_fast_init_from_base(self):
def test_past_key_values_format(self):
pass
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- import torch
-
- for model_class in self.all_generative_model_classes:
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
- dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
- model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- with self.assertRaises(ValueError):
- _ = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- import torch
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # NOTE: Qwen2 apparently does not support right padding + use_cache with FA2.
- dummy_attention_mask[:, -1] = 1
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
index c545e882faeeb3..abc7b57919b083 100644
--- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
+++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
@@ -15,7 +15,6 @@
"""Testing suite for the PyTorch Qwen2MoE model."""
import gc
-import tempfile
import unittest
import pytest
@@ -453,85 +452,6 @@ def test_save_load_fast_init_from_base(self):
def test_past_key_values_format(self):
pass
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- import torch
-
- for model_class in self.all_generative_model_classes:
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
- dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
- model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- with self.assertRaises(ValueError):
- _ = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- import torch
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # NOTE: Qwen2Moe apparently does not support right padding + use_cache with FA2.
- dummy_attention_mask[:, -1] = 1
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index 76951d07fce2fb..6c04ba40df19d6 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -58,7 +58,7 @@ class Qwen2VLVisionText2TextModelTester:
def __init__(
self,
parent,
- batch_size=2,
+ batch_size=3,
seq_length=7,
num_channels=3,
ignore_index=-100,
@@ -246,6 +246,40 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
+ def test_mismatching_num_image_tokens(self):
+ """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+ # remove one image but leave the image token in text
+ patch_size = config.vision_config.patch_size
+ one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
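+            # `pixel_values` holds flattened patches for all images, so keeping only the last
+            # `one_img_length` rows (and the last row of `image_grid_thw`) leaves exactly one image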
+ input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
+ input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = input_dict["input_ids"][:1]
+ pixel_values = input_dict["pixel_values"][:one_img_length]
+ image_grid_thw = input_dict["image_grid_thw"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
+ _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
@@ -268,10 +302,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self):
def test_feed_forward_chunking(self):
pass
- @unittest.skip(reason="Generate needs input ids")
- def test_inputs_embeds_matches_input_ids_with_generate(self):
- pass
-
@unittest.skip(reason="CPU offload is not yet supported")
def test_cpu_offload(self):
pass
diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
index 542955f9fa4511..985115d7707b6e 100644
--- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
+++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
@@ -420,10 +420,6 @@ def _check_hidden_states_for_generate(
def test_initialization(self):
pass
- @unittest.skip(reason="RecurrentGemma does not support generating with input embeddings (missing position_ids)")
- def test_inputs_embeds_matches_input_ids_with_generate(self):
- pass
-
@require_torch_accelerator
@slow
diff --git a/tests/models/roberta/test_modeling_roberta.py b/tests/models/roberta/test_modeling_roberta.py
index ca557937803cff..1c128513b17d13 100644
--- a/tests/models/roberta/test_modeling_roberta.py
+++ b/tests/models/roberta/test_modeling_roberta.py
@@ -16,7 +16,7 @@
import unittest
-from transformers import RobertaConfig, is_torch_available
+from transformers import AutoTokenizer, RobertaConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
@@ -41,6 +41,7 @@
RobertaEmbeddings,
create_position_ids_from_input_ids,
)
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
@@ -576,3 +577,43 @@ def test_inference_classification_head(self):
# expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
+
+ @slow
+ def test_export(self):
+ if not is_torch_greater_or_equal_than_2_4:
+ self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+ roberta_model = "FacebookAI/roberta-base"
+ device = "cpu"
+ attn_implementation = "sdpa"
+ max_length = 512
+
+ tokenizer = AutoTokenizer.from_pretrained(roberta_model)
+ inputs = tokenizer(
+            "The goal of life is <mask>.",
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ )
+
+ model = RobertaForMaskedLM.from_pretrained(
+ roberta_model,
+ device_map=device,
+ attn_implementation=attn_implementation,
+ use_cache=True,
+ )
+
+ logits = model(**inputs).logits
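+        # position 6 in the tokenized prompt is the `<mask>` token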
+ eager_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
+ self.assertEqual(eager_predicted_mask.split(), ["happiness", "love", "peace", "freedom", "simplicity"])
+
+ exported_program = torch.export.export(
+ model,
+ args=(inputs["input_ids"],),
+ kwargs={"attention_mask": inputs["attention_mask"]},
+ strict=True,
+ )
+
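+        # running the exported program should reproduce the eager top-5 predictions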
+ result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+ exported_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
+ self.assertEqual(eager_predicted_mask, exported_predicted_mask)
diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py
index 32d28143d72ffa..df743f132c1140 100644
--- a/tests/models/starcoder2/test_modeling_starcoder2.py
+++ b/tests/models/starcoder2/test_modeling_starcoder2.py
@@ -14,7 +14,6 @@
# limitations under the License.
"""Testing suite for the PyTorch Starcoder2 model."""
-import tempfile
import unittest
import pytest
@@ -404,85 +403,6 @@ def test_save_load_fast_init_from_base(self):
def test_past_key_values_format(self):
pass
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- import torch
-
- for model_class in self.all_generative_model_classes:
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
- dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
- model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- with self.assertRaises(ValueError):
- _ = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- import torch
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # NOTE: Starcoder2 apparently does not support right padding + use_cache with FA2.
- dummy_attention_mask[:, -1] = 1
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py
index 68dd5a52b3d69b..b03416390766d0 100644
--- a/tests/models/t5/test_modeling_t5.py
+++ b/tests/models/t5/test_modeling_t5.py
@@ -580,9 +580,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# The small T5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9]
- # used in `test_torch_compile`
- _torch_compile_test_ckpt = "google-t5/t5-small"
-
def setUp(self):
self.model_tester = T5ModelTester(self)
self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
diff --git a/tests/models/umt5/test_modeling_umt5.py b/tests/models/umt5/test_modeling_umt5.py
index ec4c1d019b6d17..377668851c5815 100644
--- a/tests/models/umt5/test_modeling_umt5.py
+++ b/tests/models/umt5/test_modeling_umt5.py
@@ -317,9 +317,6 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# The small UMT5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9]
- # used in `test_torch_compile`
- _torch_compile_test_ckpt = "google/umt5-small"
-
def setUp(self):
self.model_tester = UMT5ModelTester(self)
diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py
index fd4c49f4a6966d..e25ad1d44460c7 100644
--- a/tests/models/video_llava/test_modeling_video_llava.py
+++ b/tests/models/video_llava/test_modeling_video_llava.py
@@ -123,9 +123,9 @@ def __init__(
self.batch_size = 5
self.num_channels = 3
self.image_size = 224
- self.encoder_seq_length = 64
+ self.encoder_seq_length = 246
self.num_image_tokens = 25
- self.num_video_tokens = 26
+ self.num_video_tokens = 26 * self.num_frames
self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens
def get_config(self):
@@ -267,7 +267,7 @@ def test_mixed_input(self):
# if we remove some images from inputs leaving only one
# image number mismatch error should raise
inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
- with self.assertRaises(RuntimeError):
+ with self.assertRaises(ValueError):
_ = model(**inputs)
def test_video_only_input(self):
@@ -401,6 +401,35 @@ def test_inputs_embeds_matches_input_ids(self):
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
+ def test_mismatching_num_image_tokens(self):
+ """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+ # remove one image but leave the image token in text
+ input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = input_dict["input_ids"][:1]
+ pixel_values = input_dict["pixel_values_images"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(input_ids=input_ids, pixel_values_images=pixel_values)
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ _ = model(input_ids=input_ids, pixel_values_images=pixel_values)
+
@require_torch
class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py
index 0cb9c2751e78d6..c9946422fde9d5 100644
--- a/tests/models/vipllava/test_modeling_vipllava.py
+++ b/tests/models/vipllava/test_modeling_vipllava.py
@@ -218,6 +218,36 @@ def test_inputs_embeds_matches_input_ids(self):
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
+ # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
+ def test_mismatching_num_image_tokens(self):
+ """
+        Tests that VLMs throw an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+ # remove one image but leave the image token in text
+ input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**input_dict)
+
+ # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+ input_ids = input_dict["input_ids"][:1]
+ pixel_values = input_dict["pixel_values"][:1]
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+ # one image and two image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+ # two images and two image tokens don't raise an error
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+ _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index b24c577a16e575..12aedaca8cf986 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -1574,59 +1574,6 @@ def test_generate_output_type(self, return_dict_in_generate):
)
assert isinstance(pred_ids, expected_output_type)
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_reuse_cache(self):
- max_new_tokens = 2
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name][..., :10]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # run generate once to get filled cache
- output = model.generate(
- dummy_input,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- return_dict_in_generate=True,
- )
- past_key_values = output.past_key_values
-
- # Try to continue generation from where we left, given that we have more than 1 new token to process
- # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
- _ = model.generate(
- dummy_input,
- decoder_input_ids=output.sequences,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- past_key_values=past_key_values,
- )
-
def test_labels_sequence_max_length_correct(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -3961,11 +3908,6 @@ def test_generate_without_input_ids(self):
# generate only works with input ids for whisper
pass
- @unittest.skip(reason="Generate needs input ids")
- def test_inputs_embeds_matches_input_ids_with_generate(self):
- # generate only works with input ids for whisper
- pass
-
@unittest.skip(reason="Decoder can't keep attention grads")
def test_retain_grad_hidden_states_attentions(self):
return
@@ -3974,18 +3916,6 @@ def test_retain_grad_hidden_states_attentions(self):
def test_save_load_fast_init_from_base(self):
pass
- @unittest.skip(
- reason="FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
- )
- def test_flash_attn_2_generate_reuse_cache(self):
- pass
-
- @unittest.skip(
- "Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
- )
- def test_flash_attn_2_generate_padding_right(self):
- pass
-
@unittest.skip(
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
)
diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py
index c0a8020bedd76a..a6dd516f98a412 100644
--- a/tests/models/zamba/test_modeling_zamba.py
+++ b/tests/models/zamba/test_modeling_zamba.py
@@ -542,93 +542,6 @@ def test_flash_attn_2_fp32_ln(self):
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- r"""
- Overriding the test_flash_attn_2_generate_padding_right test as the Zamba model, like Mixtral, doesn't support
- right padding + use cache with FA2
- """
- import torch
-
- for model_class in self.all_generative_model_classes:
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
- dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
- model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- with self.assertRaises(ValueError):
- _ = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- r"""
- Overriding the test_flash_attn_2_generate_use_cache test as the Zamba model, like Mixtral, doesn't support
- right padding + use cache with FA2
- """
- import torch
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # NOTE: Zamba does not support right padding + use_cache with FA2.
- dummy_attention_mask[:, -1] = 1
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index d88b0dc5f02f83..e2719d8cf1b600 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -22,7 +22,6 @@
import random
import re
import tempfile
-import time
import warnings
from collections import defaultdict
from contextlib import contextmanager
@@ -37,10 +36,7 @@
from transformers import (
AutoModel,
AutoModelForCausalLM,
- AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
- AutoTokenizer,
- GenerationConfig,
PretrainedConfig,
PreTrainedModel,
is_torch_available,
@@ -86,7 +82,6 @@
require_deepspeed,
require_flash_attn,
require_non_xpu,
- require_read_token,
require_safetensors,
require_torch,
require_torch_accelerator,
@@ -3000,71 +2995,6 @@ def test_inputs_embeds_matches_input_ids(self):
)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
- def test_inputs_embeds_matches_input_ids_with_generate(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- for model_class in self.all_generative_model_classes:
- if model_class.__name__ not in [
- *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
- *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
- ]:
- continue
-
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- model_forward_args = inspect.signature(model.forward).parameters
- if any(argument not in model_forward_args for argument in ["inputs_embeds", "position_ids"]):
- self.skipTest(reason="This model doesn't use `inputs_embeds` or `position_ids`.")
- has_inputs_embeds_forwarding = "inputs_embeds" in set(
- inspect.signature(model.prepare_inputs_for_generation).parameters.keys()
- )
- if not has_inputs_embeds_forwarding:
- self.skipTest(reason="This model doesn't support `inputs_embeds` passed to `generate`.")
- inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
- pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
-
- # VLMs can't generate with embeds and pixels at the same time. We expect the user to pass merged
- # embeds already
- if model_class.__name__ in get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES):
- inputs.pop("pixel_values", None)
- inputs.pop("pixel_values_videos", None)
- inputs.pop("pixel_values_images", None)
-
- wte = model.get_input_embeddings()
- if not self.is_encoder_decoder:
- input_ids = inputs["input_ids"]
- # some models infer position ids/attn mask differently when input ids
- # by check if pad_token let's make sure no padding is in input ids
- not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
- input_ids[input_ids == pad_token_id] = not_pad_token_id
- del inputs["input_ids"]
- inputs_embeds = wte(input_ids)
- out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)[:, -2:]
- out_embeds = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
- else:
- encoder_input_ids = inputs["input_ids"]
- decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
- encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
- decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
- del inputs["input_ids"]
- inputs.pop("decoder_input_ids", None)
- inputs_embeds = wte(encoder_input_ids)
- decoder_inputs_embeds = wte(decoder_input_ids)
- out_ids = model.generate(
- input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs, max_new_tokens=2
- )[:, -2:]
- out_embeds = model.generate(
- inputs_embeds=inputs_embeds,
- decoder_inputs_embeds=decoder_inputs_embeds,
- **inputs,
- max_new_tokens=2,
- )
- # NOTE: this test changes the order of FP ops, there may be tiny differences in the output
- number_of_different_tokens = (out_ids != out_embeds).sum()
- max_differences = int(out_ids.shape[0] * out_ids.shape[1] * 0.1)
- self.assertTrue(number_of_different_tokens <= max_differences) # accept up to 10% mismatch
-
@require_non_xpu
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
@@ -3857,102 +3787,6 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- @is_flaky()
- def test_flash_attn_2_generate_left_padding(self):
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = inputs_dict[model.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # make sure we do left padding
- dummy_attention_mask[:, :-1] = 0
- dummy_attention_mask[:, -1:] = 1
-
- out = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- out_fa = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- self.assertTrue(torch.allclose(out, out_fa))
-
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @is_flaky()
- @slow
- def test_flash_attn_2_generate_padding_right(self):
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
- torch_device
- )
-
- dummy_input = inputs_dict[model.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- # make sure we do right padding
- dummy_attention_mask[:, :-1] = 1
- dummy_attention_mask[:, -1:] = 0
-
- out = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- out_fa = model.generate(
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
- )
-
- self.assertTrue(torch.allclose(out, out_fa))
-
def test_attn_implementation_composite_models(self):
"""
Tests if composite models can receive a dict object as attn_implementation, where each key should be
@@ -4525,65 +4359,6 @@ def test_sdpa_matches_eager_sliding_window(self):
torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4)
)
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_use_cache(self):
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # Just test that a large cache works as expected
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
- # Generate with one batch only to test generation when attention mask will be None
- # when real inputs are used, because there is no padding. See issue #32237 for more
- dummy_input = dummy_input[:1, ...]
- dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
- _ = model.generate(
- dummy_input,
- attention_mask=dummy_attention_mask,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- )
-
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -4640,62 +4415,6 @@ def test_flash_attn_2_can_dispatch_composite_models(self):
if not has_fa2:
raise ValueError("The FA2 model should have FA2 layers")
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- def test_flash_attn_2_generate_reuse_cache(self):
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- max_new_tokens = 2
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- low_cpu_mem_usage=True,
- ).to(torch_device)
-
- # run generate once to get filled cache
- output = model.generate(
- dummy_input,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- return_dict_in_generate=True,
- )
- past_key_values = output.past_key_values
-
- # Try to continue generation from where we left, given that we have more than 1 new token to process
- # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
- dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1)
- _ = model.generate(
- dummy_input_updated,
- max_new_tokens=max_new_tokens,
- do_sample=False,
- use_cache=True,
- past_key_values=past_key_values,
- )
-
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@@ -4999,82 +4718,6 @@ def test_custom_4d_attention_mask(self):
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
- def test_static_cache_matches_dynamic(self):
- """
- Tests that generating with static cache give almost same results as with dynamic cache.
- This test does not compile the model and check only logits similarity for numerical precision
- errors.
- """
- if len(self.all_generative_model_classes) == 0:
- self.skipTest(
- reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
- )
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_static_cache:
- self.skipTest(f"{model_class.__name__} does not support static cache")
-
- if not model_class._supports_cache_class:
- self.skipTest(f"{model_class.__name__} does not support cache class")
-
- config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
- if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
- self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
-
- model = model_class(config).to(device=torch_device, dtype=torch.float32)
- model.eval()
-
- dynamic_out = model.generate(
- **inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True
- )
- static_out = model.generate(
- **inputs,
- do_sample=False,
- max_new_tokens=10,
- cache_implementation="static",
- output_logits=True,
- return_dict_in_generate=True,
- )
- self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4))
-
- # For now, Let's focus only on GPU for `torch.compile`
- @slow
- @require_torch_accelerator
- @require_read_token
- def test_torch_compile(self):
- if version.parse(torch.__version__) < version.parse("2.3"):
- self.skipTest(reason="This test requires torch >= 2.3 to run.")
- torch.compiler.reset()
- if not hasattr(self, "_torch_compile_test_ckpt"):
- self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
- ckpt = self._torch_compile_test_ckpt
- revision = "main" if not hasattr(self, "_torch_compile_test_revision") else self._torch_compile_test_revision
-
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- batch_size = 1
- n_iter = 3
-
- tokenizer = AutoTokenizer.from_pretrained(ckpt)
- if self.is_encoder_decoder:
- model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
- torch_device
- )
- else:
- model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
- torch_device
- )
-
- model.generation_config.max_new_tokens = 4
-
- model.generation_config.cache_implementation = "static"
- model.forward = torch.compile(model.forward, mode="reduce-overhead")
-
- input_text = "Why dogs are cute?"
- input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
-
- for i in range(n_iter):
- _ = model.generate(**input_ids, do_sample=False)
-
@slow
@require_torch_gpu
def test_torch_compile_for_training(self):
@@ -5118,74 +4761,6 @@ def test_torch_compile_for_training(self):
for name, param in model._orig_mod.named_parameters():
torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
- @slow
- @require_torch_gpu # Testing cuda graphs.
- @require_read_token
- def test_compile_cuda_graph_time(self):
- if version.parse(torch.__version__) < version.parse("2.3"):
- self.skipTest(reason="This test requires torch >= 2.3 to run.")
-
- # TODO felix: All models supporting `StaticCache` or `torch.compile` should be tested.
- # At the moment, only llama, gemma and gemma2 are tested here!
- if not hasattr(self, "_torch_compile_test_ckpt"):
- self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
- ckpt = self._torch_compile_test_ckpt
- revision = "main" if not hasattr(self, "_torch_compile_test_revision") else self._torch_compile_test_revision
-
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- tokenizer = AutoTokenizer.from_pretrained(ckpt)
- if self.is_encoder_decoder:
- model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
- torch_device
- )
- else:
- model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
- torch_device
- )
-
- cache_implementation = "static"
- if model.config.model_type == "gemma2":
- cache_implementation = "hybrid"
-
- new_tokens = 50
- gen_config = GenerationConfig(
- max_new_tokens=new_tokens,
- min_new_tokens=new_tokens,
- use_cache=True,
- pad_token_id=tokenizer.pad_token_id,
- num_beams=1,
- do_sample=False,
- eos_token_id=None, # This is required for min_new_tokens to actually have an effect.
- )
- model.generation_config.eos_token_id = None # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.
-
- model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
-
- inp = tokenizer("Why cats are cute?", return_tensors="pt").to(torch_device)
-
- # First run: the first run warms up each graph, which does things like CuBlas or Triton benchmarking
- start = time.perf_counter()
- _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
- end = time.perf_counter()
- graph_warmup_time = end - start
-
- # Second run: CUDA Graph recording, and replays it
- start = time.perf_counter()
- _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
- end = time.perf_counter()
- record_time = end - start
-
- # Finally: we hit the optimized, CUDA Graph replay path
- start = time.perf_counter()
- _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
- end = time.perf_counter()
- opt_time = end - start
-
- # For the recording step, we expect only two cuda graphs and this step should be much faster than the first.
- self.assertTrue(record_time < 0.15 * graph_warmup_time)
- self.assertTrue(opt_time < record_time)
-
def test_forward_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py
index acc8ba79bb43a0..94bc3d5fae1ad2 100644
--- a/tests/test_pipeline_mixin.py
+++ b/tests/test_pipeline_mixin.py
@@ -930,6 +930,8 @@ def parse_args_from_docstring_by_indentation(docstring):
def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
+ ALLOWED_TRANSFORMERS_ONLY_ARGS = ["timeout"]
+
docstring = inspect.getdoc(pipeline_class.__call__).strip()
docstring_args = set(parse_args_from_docstring_by_indentation(docstring))
hub_args = set(get_arg_names_from_hub_spec(hub_spec))
@@ -947,6 +949,11 @@ def compare_pipeline_args_to_hub_spec(pipeline_class, hub_spec):
hub_args.remove(js_generate_args[0])
docstring_args.remove(docstring_generate_args[0])
+ # Special casing 2: We permit some transformers-only arguments that don't affect pipeline output
+ for arg in ALLOWED_TRANSFORMERS_ONLY_ARGS:
+ if arg in docstring_args and arg not in hub_args:
+ docstring_args.remove(arg)
+
if hub_args != docstring_args:
error = [f"{pipeline_class.__name__} differs from JS spec {hub_spec.__name__}"]
matching_args = hub_args & docstring_args