diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index 0fd441bc05fe4a..2b74dca91f30bc 100755 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 LABEL maintainer="Hugging Face" ARG DEBIAN_FRONTEND=noninteractive diff --git a/docs/source/en/model_doc/gemma.md b/docs/source/en/model_doc/gemma.md index f55995b6d85b6a..abd077af8da170 100644 --- a/docs/source/en/model_doc/gemma.md +++ b/docs/source/en/model_doc/gemma.md @@ -60,6 +60,11 @@ This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [ [[autodoc]] GemmaForSequenceClassification - forward +## GemmaForTokenClassification + +[[autodoc]] GemmaForTokenClassification + - forward + ## FlaxGemmaModel [[autodoc]] FlaxGemmaModel diff --git a/docs/source/en/model_doc/llama.md b/docs/source/en/model_doc/llama.md index 915d5ecc70b554..2f0eb63da00a84 100644 --- a/docs/source/en/model_doc/llama.md +++ b/docs/source/en/model_doc/llama.md @@ -121,6 +121,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] LlamaForQuestionAnswering - forward +## LlamaForTokenClassification + +[[autodoc]] LlamaForTokenClassification + - forward + ## FlaxLlamaModel [[autodoc]] FlaxLlamaModel diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index a2a3913fcad7b8..a4a1419ee00ac8 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -68,6 +68,8 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/ ## Usage example +### Single image inference + Here's how to load the model and perform inference in half-precision (`torch.float16`): ```python @@ -94,6 +96,45 @@ output = model.generate(**inputs, max_new_tokens=100) print(processor.decode(output[0], skip_special_tokens=True)) ``` +### Multi image inference + +LLaVa-Next can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). Here is how you can do it: + +```python +import requests +from PIL import Image +import torch +from transformers import AutoProcessor, LlavaNextForConditionalGeneration + +# Load the model in half-precision +model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, device_map="auto") +processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") + +# Get three different images +url = "https://www.ilankelman.org/stopsigns/australia.jpg" +image_stop = Image.open(requests.get(url, stream=True).raw) + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image_cats = Image.open(requests.get(url, stream=True).raw) + +url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" +image_snowman = Image.open(requests.get(url, stream=True).raw) + +# Prepare a batched prompt, where the first one is a multi-turn conversation and the second is not +prompt = [ + "[INST] <image>\nWhat is shown in this image? [/INST] There is a red stop sign in the image. [INST] <image>\nWhat about this image? How many cats do you see [/INST]", + "[INST] <image>\nWhat is shown in this image? 
[/INST]" +] + +# We can simply feed images in the order they have to be used in the text prompt +# Each "" token uses one image leaving the next for the subsequent "" tokens +inputs = processor(text=prompt, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device) + +# Generate +generate_ids = model.generate(**inputs, max_new_tokens=30) +processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) +``` + ## Model optimization ### Quantization using Bitsandbytes diff --git a/docs/source/en/model_doc/mistral.md b/docs/source/en/model_doc/mistral.md index 0ab214206165f1..d4bc76106083c6 100644 --- a/docs/source/en/model_doc/mistral.md +++ b/docs/source/en/model_doc/mistral.md @@ -203,6 +203,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] MistralForSequenceClassification - forward +## MistralForTokenClassification + +[[autodoc]] MistralForTokenClassification + - forward + ## FlaxMistralModel [[autodoc]] FlaxMistralModel diff --git a/docs/source/en/model_doc/mixtral.md b/docs/source/en/model_doc/mixtral.md index 942b040c3f2fd5..b93acdec581525 100644 --- a/docs/source/en/model_doc/mixtral.md +++ b/docs/source/en/model_doc/mixtral.md @@ -204,3 +204,8 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] MixtralForSequenceClassification - forward + +## MixtralForTokenClassification + +[[autodoc]] MixtralForTokenClassification + - forward diff --git a/docs/source/en/model_doc/persimmon.md b/docs/source/en/model_doc/persimmon.md index fe9e66a0b7175e..7a105ac5543d60 100644 --- a/docs/source/en/model_doc/persimmon.md +++ b/docs/source/en/model_doc/persimmon.md @@ -96,3 +96,8 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece. 
T [[autodoc]] PersimmonForSequenceClassification - forward + +## PersimmonForTokenClassification + +[[autodoc]] PersimmonForTokenClassification + - forward diff --git a/docs/source/en/model_doc/qwen2.md b/docs/source/en/model_doc/qwen2.md index 5f9e5dba22b844..ac0e25e02c35f9 100644 --- a/docs/source/en/model_doc/qwen2.md +++ b/docs/source/en/model_doc/qwen2.md @@ -80,3 +80,8 @@ In the following, we demonstrate how to use `Qwen2-7B-Chat-beta` for the inferen [[autodoc]] Qwen2ForSequenceClassification - forward + +## Qwen2ForTokenClassification + +[[autodoc]] Qwen2ForTokenClassification + - forward diff --git a/docs/source/en/model_doc/qwen2_moe.md b/docs/source/en/model_doc/qwen2_moe.md index 8a546c4016ad5e..9c6dc80beb61e5 100644 --- a/docs/source/en/model_doc/qwen2_moe.md +++ b/docs/source/en/model_doc/qwen2_moe.md @@ -75,3 +75,8 @@ In the following, we demonstrate how to use `Qwen1.5-MoE-A2.7B-Chat` for the inf [[autodoc]] Qwen2MoeForSequenceClassification - forward + +## Qwen2MoeForTokenClassification + +[[autodoc]] Qwen2MoeForTokenClassification + - forward diff --git a/docs/source/en/model_doc/stablelm.md b/docs/source/en/model_doc/stablelm.md index 6a50995ca086e8..09c0e5855c3a1d 100644 --- a/docs/source/en/model_doc/stablelm.md +++ b/docs/source/en/model_doc/stablelm.md @@ -104,3 +104,8 @@ Now, to run the model with Flash Attention 2, refer to the snippet below: [[autodoc]] StableLmForSequenceClassification - forward + +## StableLmForTokenClassification + +[[autodoc]] StableLmForTokenClassification + - forward diff --git a/docs/source/en/model_doc/starcoder2.md b/docs/source/en/model_doc/starcoder2.md index 9e2e547b8c3eae..1d107b3855564a 100644 --- a/docs/source/en/model_doc/starcoder2.md +++ b/docs/source/en/model_doc/starcoder2.md @@ -66,3 +66,8 @@ These ready-to-use checkpoints can be downloaded and used via the HuggingFace Hu [[autodoc]] Starcoder2ForSequenceClassification - forward + +## Starcoder2ForTokenClassification + +[[autodoc]] Starcoder2ForTokenClassification + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 18faabc807cfea..4255e303799442 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2031,6 +2031,7 @@ [ "GemmaForCausalLM", "GemmaForSequenceClassification", + "GemmaForTokenClassification", "GemmaModel", "GemmaPreTrainedModel", ] @@ -2288,6 +2289,7 @@ "LlamaForCausalLM", "LlamaForQuestionAnswering", "LlamaForSequenceClassification", + "LlamaForTokenClassification", "LlamaModel", "LlamaPreTrainedModel", ] @@ -2435,12 +2437,19 @@ [ "MistralForCausalLM", "MistralForSequenceClassification", + "MistralForTokenClassification", "MistralModel", "MistralPreTrainedModel", ] ) _import_structure["models.mixtral"].extend( - ["MixtralForCausalLM", "MixtralForSequenceClassification", "MixtralModel", "MixtralPreTrainedModel"] + [ + "MixtralForCausalLM", + "MixtralForSequenceClassification", + "MixtralForTokenClassification", + "MixtralModel", + "MixtralPreTrainedModel", + ] ) _import_structure["models.mobilebert"].extend( [ @@ -2714,6 +2723,7 @@ [ "PersimmonForCausalLM", "PersimmonForSequenceClassification", + "PersimmonForTokenClassification", "PersimmonModel", "PersimmonPreTrainedModel", ] @@ -2810,6 +2820,7 @@ [ "Qwen2ForCausalLM", "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", "Qwen2Model", "Qwen2PreTrainedModel", ] @@ -2818,6 +2829,7 @@ [ "Qwen2MoeForCausalLM", "Qwen2MoeForSequenceClassification", + "Qwen2MoeForTokenClassification", "Qwen2MoeModel", "Qwen2MoePreTrainedModel", ] @@ -3066,6 
+3078,7 @@ [ "StableLmForCausalLM", "StableLmForSequenceClassification", + "StableLmForTokenClassification", "StableLmModel", "StableLmPreTrainedModel", ] @@ -3074,6 +3087,7 @@ [ "Starcoder2ForCausalLM", "Starcoder2ForSequenceClassification", + "Starcoder2ForTokenClassification", "Starcoder2Model", "Starcoder2PreTrainedModel", ] @@ -6489,6 +6503,7 @@ from .models.gemma import ( GemmaForCausalLM, GemmaForSequenceClassification, + GemmaForTokenClassification, GemmaModel, GemmaPreTrainedModel, ) @@ -6686,6 +6701,7 @@ LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, + LlamaForTokenClassification, LlamaModel, LlamaPreTrainedModel, ) @@ -6801,12 +6817,14 @@ from .models.mistral import ( MistralForCausalLM, MistralForSequenceClassification, + MistralForTokenClassification, MistralModel, MistralPreTrainedModel, ) from .models.mixtral import ( MixtralForCausalLM, MixtralForSequenceClassification, + MixtralForTokenClassification, MixtralModel, MixtralPreTrainedModel, ) @@ -7025,6 +7043,7 @@ from .models.persimmon import ( PersimmonForCausalLM, PersimmonForSequenceClassification, + PersimmonForTokenClassification, PersimmonModel, PersimmonPreTrainedModel, ) @@ -7099,12 +7118,14 @@ from .models.qwen2 import ( Qwen2ForCausalLM, Qwen2ForSequenceClassification, + Qwen2ForTokenClassification, Qwen2Model, Qwen2PreTrainedModel, ) from .models.qwen2_moe import ( Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, + Qwen2MoeForTokenClassification, Qwen2MoeModel, Qwen2MoePreTrainedModel, ) @@ -7306,12 +7327,14 @@ from .models.stablelm import ( StableLmForCausalLM, StableLmForSequenceClassification, + StableLmForTokenClassification, StableLmModel, StableLmPreTrainedModel, ) from .models.starcoder2 import ( Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, + Starcoder2ForTokenClassification, Starcoder2Model, Starcoder2PreTrainedModel, ) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index fc7f782e348e8d..86b946d16267bd 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -32,6 +32,7 @@ CONFIG_NAME, PushToHubMixin, add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, cached_file, copy_func, download_url, @@ -736,6 +737,10 @@ def _get_config_dict( config_dict["auto_map"] = add_model_info_to_auto_map( config_dict["auto_map"], pretrained_model_name_or_path ) + if "custom_pipelines" in config_dict and not is_local: + config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( + config_dict["custom_pipelines"], pretrained_model_name_or_path + ) return config_dict, kwargs @classmethod diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 76070ebeb81b7a..44e01f8d85bac3 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -31,6 +31,7 @@ PushToHubMixin, TensorType, add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, cached_file, copy_func, download_url, @@ -539,10 +540,15 @@ def get_feature_extractor_dict( f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}" ) - if "auto_map" in feature_extractor_dict and not is_local: - feature_extractor_dict["auto_map"] = add_model_info_to_auto_map( - feature_extractor_dict["auto_map"], pretrained_model_name_or_path - ) + if not is_local: + if "auto_map" in feature_extractor_dict: + feature_extractor_dict["auto_map"] = add_model_info_to_auto_map( + 
feature_extractor_dict["auto_map"], pretrained_model_name_or_path + ) + if "custom_pipelines" in feature_extractor_dict: + feature_extractor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( + feature_extractor_dict["custom_pipelines"], pretrained_model_name_or_path + ) return feature_extractor_dict, kwargs diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index efd1e04a62106f..c42378d8f3a59e 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -31,6 +31,7 @@ IMAGE_PROCESSOR_NAME, PushToHubMixin, add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, cached_file, copy_func, download_url, @@ -375,11 +376,15 @@ def get_image_processor_dict( f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}" ) - if "auto_map" in image_processor_dict and not is_local: - image_processor_dict["auto_map"] = add_model_info_to_auto_map( - image_processor_dict["auto_map"], pretrained_model_name_or_path - ) - + if not is_local: + if "auto_map" in image_processor_dict: + image_processor_dict["auto_map"] = add_model_info_to_auto_map( + image_processor_dict["auto_map"], pretrained_model_name_or_path + ) + if "custom_pipelines" in image_processor_dict: + image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( + image_processor_dict["custom_pipelines"], pretrained_model_name_or_path + ) return image_processor_dict, kwargs @classmethod diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7825a0217fcc13..c16ab83bdc6b01 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1038,6 +1038,7 @@ ("flaubert", "FlaubertForTokenClassification"), ("fnet", "FNetForTokenClassification"), ("funnel", "FunnelForTokenClassification"), + ("gemma", "GemmaForTokenClassification"), ("gpt-sw3", "GPT2ForTokenClassification"), ("gpt2", "GPT2ForTokenClassification"), ("gpt_bigcode", "GPTBigCodeForTokenClassification"), @@ -1048,11 +1049,14 @@ ("layoutlmv2", "LayoutLMv2ForTokenClassification"), ("layoutlmv3", "LayoutLMv3ForTokenClassification"), ("lilt", "LiltForTokenClassification"), + ("llama", "LlamaForTokenClassification"), ("longformer", "LongformerForTokenClassification"), ("luke", "LukeForTokenClassification"), ("markuplm", "MarkupLMForTokenClassification"), ("mega", "MegaForTokenClassification"), ("megatron-bert", "MegatronBertForTokenClassification"), + ("mistral", "MistralForTokenClassification"), + ("mixtral", "MixtralForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), ("mpnet", "MPNetForTokenClassification"), ("mpt", "MptForTokenClassification"), @@ -1060,15 +1064,20 @@ ("mt5", "MT5ForTokenClassification"), ("nezha", "NezhaForTokenClassification"), ("nystromformer", "NystromformerForTokenClassification"), + ("persimmon", "PersimmonForTokenClassification"), ("phi", "PhiForTokenClassification"), ("phi3", "Phi3ForTokenClassification"), ("qdqbert", "QDQBertForTokenClassification"), + ("qwen2", "Qwen2ForTokenClassification"), + ("qwen2_moe", "Qwen2MoeForTokenClassification"), ("rembert", "RemBertForTokenClassification"), ("roberta", "RobertaForTokenClassification"), ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"), ("roc_bert", "RoCBertForTokenClassification"), ("roformer", "RoFormerForTokenClassification"), ("squeezebert", "SqueezeBertForTokenClassification"), + ("stablelm", 
"StableLmForTokenClassification"), + ("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 33fea9a0183afe..bcc9ac3e07461b 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -592,6 +592,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -600,8 +605,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index b516a97187b8bc..033dc6ba666461 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -428,9 +428,11 @@ def forward( key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal - # mask in case tgt_len == 1. - is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. 
+ is_causal = True if self.is_decoder and attention_mask is None and tgt_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index d6f37af8dabae5..41c4e151a3da13 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -587,8 +587,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an - # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index e77bc728ab365c..8837c278c50981 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -788,6 +788,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -796,8 +801,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
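The `is_causal` hoisting above is applied uniformly across the SDPA attention paths touched by this diff (BART, BERT, Cohere, Data2Vec-Audio, Falcon, GPT-BigCode, Hubert, Idefics, Llama, Mistral, and others). A minimal standalone sketch of the pattern, with illustrative names that are not taken from the diff:

```python
import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, attention_mask=None, is_decoder=True, dropout_p=0.0):
    # Hoist the flag into its own statement instead of writing the conditional inline in the
    # SDPA call, so torch.compile can trace it with dynamic shapes and fullgraph=True.
    # The tgt_len > 1 check mirrors AttentionMaskConverter.to_causal_4d, which builds no
    # causal mask when tgt_len == 1.
    tgt_len = query.shape[-2]
    is_causal = True if is_decoder and attention_mask is None and tgt_len > 1 else False
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )

# Shapes: (batch, num_heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 16, 64)
print(sdpa_attention(q, k, v).shape)  # torch.Size([1, 8, 16, 64])
```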
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index bab552da35d424..b77ae4f8f8be55 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1616,8 +1616,8 @@ def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes) valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1) grid_y, grid_x = meshgrid( - torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device), - torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device), + torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device), + torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device), indexing="ij", ) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index e48e1ddfe14cb2..c1b27cd1806cd0 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -166,10 +166,48 @@ def __init__(self, config, use_mask_token=False): self.norm = nn.LayerNorm(config.embed_dim) self.dropout = nn.Dropout(config.hidden_dropout_prob) + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + def forward( - self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, ) -> Tuple[torch.Tensor]: - embeddings, output_dimensions = self.patch_embeddings(pixel_values) + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) embeddings = self.norm(embeddings) batch_size, seq_len, _ = embeddings.size() @@ -180,7 +218,10 @@ def forward( embeddings = embeddings * (1.0 - mask) + mask_tokens * mask if self.position_embeddings is not None: - embeddings = embeddings + self.position_embeddings + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings embeddings = self.dropout(embeddings) @@ -219,7 +260,9 @@ def maybe_pad(self, pixel_values, height, width): pixel_values = nn.functional.pad(pixel_values, pad_values) return pixel_values - def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + def forward( + self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False + ) -> Tuple[torch.Tensor, Tuple[int]]: _, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -227,6 +270,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens ) # pad the input to be divisible by self.patch_size, if needed pixel_values = self.maybe_pad(pixel_values, height, width) + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) embeddings = self.projection(pixel_values) _, _, height, width = embeddings.shape output_dimensions = (height, width) @@ -849,6 +897,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. 
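The `interpolate_pos_encoding` method added to the Swin-style embeddings here (and mirrored for MaskFormerSwin below) resizes the pre-trained patch position-embedding grid with bicubic interpolation so that larger input images can be processed. A self-contained sketch of that resizing step with illustrative dimensions (a 7x7 pre-trained grid stretched to 10x10); `reshape` is used in place of the diff's `view` for simplicity:

```python
import math
import torch
from torch import nn

num_positions, dim = 49, 96                       # pre-trained 7x7 grid of patch position embeddings
patch_pos_embed = torch.randn(1, num_positions, dim)

grid = int(math.sqrt(num_positions))              # 7
h0, w0 = 10 + 0.1, 10 + 0.1                       # target 10x10 grid (+0.1 avoids a floating point edge case)
resized = nn.functional.interpolate(
    patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2),
    scale_factor=(h0 / grid, w0 / grid),
    mode="bicubic",
    align_corners=False,
)
resized = resized.permute(0, 2, 3, 1).reshape(1, -1, dim)
print(resized.shape)                              # torch.Size([1, 100, 96])
```

At call time the behaviour is opt-in via the new `interpolate_pos_encoding=False` argument on the model `forward`.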
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -899,6 +949,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, DonutSwinModelOutput]: r""" @@ -921,7 +972,9 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, len(self.config.depths)) - embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 5e1e3d7ed909d1..75346601d75b41 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -434,16 +434,19 @@ def forward( if alibi is None: if self._use_sdpa and not output_attentions: - attn_output = F.scaled_dot_product_attention( + # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an + # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not + # create a causal mask in case query_length == 1. + is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. 
- is_causal=self.is_causal and attention_mask is None and query_length > 1, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=is_causal, ) - attention_scores = None else: attention_scores = query_layer @ key_layer.transpose(-1, -2) @@ -466,13 +469,16 @@ def forward( else: if self._use_sdpa and not output_attentions and head_mask is None: - attn_output = F.scaled_dot_product_attention( + # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an + # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, attn_mask=attention_mask, dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and query_length > 1, + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) diff --git a/src/transformers/models/gemma/__init__.py b/src/transformers/models/gemma/__init__.py index 1c832e9051b38c..1aafae6e88c2f1 100644 --- a/src/transformers/models/gemma/__init__.py +++ b/src/transformers/models/gemma/__init__.py @@ -55,6 +55,7 @@ "GemmaModel", "GemmaPreTrainedModel", "GemmaForSequenceClassification", + "GemmaForTokenClassification", ] try: @@ -98,6 +99,7 @@ from .modeling_gemma import ( GemmaForCausalLM, GemmaForSequenceClassification, + GemmaForTokenClassification, GemmaModel, GemmaPreTrainedModel, ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c6ccb9b6972742..6003c91ceea579 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -30,7 +30,12 @@ AttentionMaskConverter, _prepare_4d_causal_attention_mask, ) -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 from ...utils import ( @@ -565,8 +570,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an - # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -1345,3 +1350,88 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Gemma Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + GEMMA_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA +class GemmaForTokenClassification(GemmaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = GemmaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index 79f26f3bf24b14..98649c644e728c 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -76,15 +76,21 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
""" + tokenizer_kwargs, image_processor_kwargs = {}, {} + if kwargs: + tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys} + image_processor_kwargs = { + k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys + } if text is None and images is None: raise ValueError("You have to specify either text or images. Both cannot be none.") if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) + encoding = self.tokenizer(text, return_tensors=return_tensors, **tokenizer_kwargs) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) + image_features = self.image_processor(images, return_tensors=return_tensors, **image_processor_kwargs) if text is not None and images is not None: encoding["pixel_values"] = image_features.pixel_values diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 37ed2aba620861..7b7bfaf1d42325 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -549,14 +549,19 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): key = key.contiguous() value = value.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not + # create a causal mask in case query_length == 1. + is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False + sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=self.attn_pdrop if self.training else 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, + is_causal=is_causal, scale=scale, ) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 3d1d0884c6aebc..83d90f5ded6be5 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -852,6 +852,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -860,8 +865,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 83a5cb65106383..2ed51413d9ab00 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -660,14 +660,18 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - attn_output = nn.functional.scaled_dot_product_attention( + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, - dropout_p=self.dropout, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index cfa80d2ce34bcd..5a582acdd2b473 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -962,15 +962,15 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa # 3.b. 
Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size] - discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediade_size, seq_len, ssm_state_size] + discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size] deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) scan_outputs = [] for i in range(seq_len): - ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state] - scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1] + ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state] + scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1] scan_outputs.append(scan_output[:, :, 0]) - scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediade_size, seq_len] + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len] scan_output = scan_output + (hidden_states * self.D[None, :, None]) scan_output = (scan_output * self.act(gate)) @@ -978,7 +978,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa cache_params.ssm_states[self.layer_idx] = ssm_state # 4. Final linear projection - contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] return contextualized_states # fmt: on diff --git a/src/transformers/models/llama/__init__.py b/src/transformers/models/llama/__init__.py index 4b8a33118ccc8e..3f6461c4c093f2 100644 --- a/src/transformers/models/llama/__init__.py +++ b/src/transformers/models/llama/__init__.py @@ -55,6 +55,7 @@ "LlamaPreTrainedModel", "LlamaForSequenceClassification", "LlamaForQuestionAnswering", + "LlamaForTokenClassification", ] try: @@ -95,6 +96,7 @@ LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, + LlamaForTokenClassification, LlamaModel, LlamaPreTrainedModel, ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a74863e6e3a121..1f4f6ac9a0660d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -36,6 +36,7 @@ CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, + TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -640,8 +641,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an - # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -1514,3 +1515,87 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/llava_next/image_processing_llava_next.py b/src/transformers/models/llava_next/image_processing_llava_next.py index 34de0f4db0bb59..6295fb9562458b 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next.py +++ b/src/transformers/models/llava_next/image_processing_llava_next.py @@ -37,6 +37,7 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, + is_valid_image, make_list_of_images, to_numpy_array, valid_images, @@ -52,6 +53,29 @@ from PIL import Image +def make_batched_images(images) -> List[List[ImageInput]]: + """ + Accepts images in list or nested list format, and makes a list of images for preprocessing. + + Args: + images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): + The input image. + + Returns: + list: A list of images. + """ + if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + return images + + elif is_valid_image(images): + return [images] + + raise ValueError(f"Could not make batched video from {images}") + + def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: """ Divides an image into patches of a specified size. 
@@ -651,7 +675,7 @@ def preprocess( do_pad = do_pad if do_pad is not None else self.do_pad do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - images = make_list_of_images(images) + images = make_batched_images(images) if not valid_images(images): raise ValueError( diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index 8a4b76e9c68a6c..91cd544ab6484e 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -40,8 +40,8 @@ class LlavaNextProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - image_processor_class = "LlavaNextImageProcessor" - tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" def __init__(self, image_processor=None, tokenizer=None): super().__init__(image_processor, tokenizer) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 4834fd62473195..a8309e043015d3 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -279,16 +279,16 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None): # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size] - discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediade_size, seq_len, ssm_state_size] + discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size] deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) scan_outputs = [] for i in range(seq_len): - ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state] - scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1] + ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state] + scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1] scan_outputs.append(scan_output[:, :, 0]) - scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size] + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len] scan_output = scan_output + (hidden_states * self.D[None, :, None]) scan_output = (scan_output * self.act(gate)) @@ -296,7 +296,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None): cache_params.ssm_states[self.layer_idx].copy_(ssm_state) # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] return contextualized_states # fmt: on @@ -399,7 +399,7 @@ def _init_weights(self, module): # Having just p *= scale would repeatedly scale it down nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): - p /= math.sqrt(self.config.num_layers) + p /= math.sqrt(self.config.num_hidden_layers) @dataclass diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 1c358c88de4e7f..fc9c642adc8124 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -163,12 +163,50 @@ def __init__(self, config): self.norm = nn.LayerNorm(config.embed_dim) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, pixel_values): - embeddings, output_dimensions = self.patch_embeddings(pixel_values) + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values, interpolate_pos_encoding): + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) embeddings = self.norm(embeddings) if self.position_embeddings is not None: - embeddings = embeddings + self.position_embeddings + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings embeddings = self.dropout(embeddings) @@ -207,7 +245,9 @@ def maybe_pad(self, pixel_values, height, width): pixel_values = nn.functional.pad(pixel_values, pad_values) return pixel_values - def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + def forward( + self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False + ) -> Tuple[torch.Tensor, 
Tuple[int]]: _, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -215,6 +255,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens ) # pad the input to be divisible by self.patch_size, if needed pixel_values = self.maybe_pad(pixel_values, height, width) + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) embeddings = self.projection(pixel_values) _, _, height, width = embeddings.shape output_dimensions = (height, width) @@ -780,6 +825,7 @@ def forward( head_mask=None, output_attentions=None, output_hidden_states=None, + interpolate_pos_encoding=False, return_dict=None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -798,7 +844,9 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, len(self.config.depths)) - embedding_output, input_dimensions = self.embeddings(pixel_values) + embedding_output, input_dimensions = self.embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, diff --git a/src/transformers/models/mistral/__init__.py b/src/transformers/models/mistral/__init__.py index dc0b85980ff600..abf1e32a4b4845 100644 --- a/src/transformers/models/mistral/__init__.py +++ b/src/transformers/models/mistral/__init__.py @@ -32,6 +32,7 @@ "MistralModel", "MistralPreTrainedModel", "MistralForSequenceClassification", + "MistralForTokenClassification", ] try: @@ -59,6 +60,7 @@ from .modeling_mistral import ( MistralForCausalLM, MistralForSequenceClassification, + MistralForTokenClassification, MistralModel, MistralPreTrainedModel, ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index e692e009e2c29f..c54b8774eea5d4 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -32,7 +32,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -97,30 +102,6 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), 
persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) - - @property - def sin_cached(self): - logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. It is not used in the `MistralAttention` class" - ) - return self._sin_cached - - @property - def cos_cached(self): - logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. It is not used in the `MistralAttention` class" - ) - return self._cos_cached @torch.no_grad() # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward @@ -661,8 +642,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an - # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -1469,3 +1450,88 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForTokenClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/mixtral/__init__.py b/src/transformers/models/mixtral/__init__.py index 7b8f061dac8362..b124d41dfbec10 100644 --- a/src/transformers/models/mixtral/__init__.py +++ b/src/transformers/models/mixtral/__init__.py @@ -36,6 +36,7 @@ "MixtralModel", "MixtralPreTrainedModel", "MixtralForSequenceClassification", + "MixtralForTokenClassification", ] @@ -51,6 +52,7 @@ from .modeling_mixtral import ( MixtralForCausalLM, MixtralForSequenceClassification, + MixtralForTokenClassification, MixtralModel, MixtralPreTrainedModel, ) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9bf580fb7adb3a..70e1746392ea62 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -38,6 +38,7 @@ MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast, + TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 @@ -753,14 +754,18 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -1588,3 +1593,88 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL +class MixtralForTokenClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 8e8b1fe2842f23..d6f0ae96f40e4d 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -618,6 +618,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -626,8 +631,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 9865a4b9179ac8..6458df0a1b7be7 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -634,6 +634,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -642,8 +647,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 8bfa3dc60626a9..9b4b08239bc4d9 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -617,8 +617,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an - # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
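The `is_causal` dispatch comment repeated across these attention modules boils down to the pattern below. This is a minimal sketch, not part of the diff: the helper name and tensor shapes are made up, and it only illustrates that a plain Python `if` over the query length compiles cleanly under `torch.compile` with `fullgraph=True` and dynamic shapes, which is what the comment says an inline conditional inside the SDPA call would break.

```python
import torch
import torch.nn.functional as F

def sdpa_block(query, key, value, attention_mask=None):
    # Mirror of the dispatch pattern above: decide `is_causal` with a plain
    # Python if/else so torch.compile can guard on the shape instead of
    # tracing an inline conditional into the SDPA call.
    q_len = query.shape[-2]
    is_causal = True if attention_mask is None and q_len > 1 else False
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )

compiled = torch.compile(sdpa_block, fullgraph=True, dynamic=True)
q = k = v = torch.randn(2, 8, 16, 64)   # (batch, heads, seq_len, head_dim)
out = compiled(q, k, v)                  # prefill: seq_len > 1 -> causal path
print(out.shape)                         # torch.Size([2, 8, 16, 64])
```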
is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( diff --git a/src/transformers/models/persimmon/__init__.py b/src/transformers/models/persimmon/__init__.py index 75bc218a2913c7..e1f24ca1b7c23d 100644 --- a/src/transformers/models/persimmon/__init__.py +++ b/src/transformers/models/persimmon/__init__.py @@ -36,6 +36,7 @@ "PersimmonModel", "PersimmonPreTrainedModel", "PersimmonForSequenceClassification", + "PersimmonForTokenClassification", ] @@ -51,6 +52,7 @@ from .modeling_persimmon import ( PersimmonForCausalLM, PersimmonForSequenceClassification, + PersimmonForTokenClassification, PersimmonModel, PersimmonPreTrainedModel, ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 01d124fb9873fe..ea2bd074ee72fc 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -29,7 +29,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_persimmon import PersimmonConfig @@ -1011,3 +1016,88 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Persimmon Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + PERSIMMON_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Persimmon, LLAMA->PERSIMMON +class PersimmonForTokenClassification(PersimmonPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = PersimmonModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 7b79643a17ba8c..1436138f91949d 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -709,13 +709,17 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 41765632b5ec27..224aad00858e0e 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -778,14 +778,18 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() diff --git a/src/transformers/models/qwen2/__init__.py b/src/transformers/models/qwen2/__init__.py index 3409f28214d1fd..35df37e91a98c4 100644 --- a/src/transformers/models/qwen2/__init__.py +++ b/src/transformers/models/qwen2/__init__.py @@ -45,6 +45,7 @@ "Qwen2Model", "Qwen2PreTrainedModel", "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", ] @@ -69,6 +70,7 @@ from .modeling_qwen2 import ( Qwen2ForCausalLM, Qwen2ForSequenceClassification, + Qwen2ForTokenClassification, Qwen2Model, Qwen2PreTrainedModel, ) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 3112929a8b718f..98a9e27c9b7914 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -31,7 +31,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -676,14 +681,18 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -1375,3 +1384,88 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN2_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/qwen2_moe/__init__.py b/src/transformers/models/qwen2_moe/__init__.py index fb123832787f1f..e2b73ba2d1f9c4 100644 --- a/src/transformers/models/qwen2_moe/__init__.py +++ b/src/transformers/models/qwen2_moe/__init__.py @@ -36,6 +36,7 @@ "Qwen2MoeModel", "Qwen2MoePreTrainedModel", "Qwen2MoeForSequenceClassification", + "Qwen2MoeForTokenClassification", ] @@ -51,6 +52,7 @@ from .modeling_qwen2_moe import ( Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, + Qwen2MoeForTokenClassification, Qwen2MoeModel, Qwen2MoePreTrainedModel, ) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 4a26976bb6b9cb..0e4b4b75e8120d 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -32,7 +32,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -754,14 +759,18 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -1571,3 +1580,88 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN2MOE_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE +class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2MoeModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 60cbb777c795a3..199d0e543a8b6e 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -852,6 +852,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -860,8 +865,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
- is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/stablelm/__init__.py b/src/transformers/models/stablelm/__init__.py index 7fc3a6857fa55a..c00c045f7f81a4 100644 --- a/src/transformers/models/stablelm/__init__.py +++ b/src/transformers/models/stablelm/__init__.py @@ -36,6 +36,7 @@ "StableLmModel", "StableLmPreTrainedModel", "StableLmForSequenceClassification", + "StableLmForTokenClassification", ] @@ -51,6 +52,7 @@ from .modeling_stablelm import ( StableLmForCausalLM, StableLmForSequenceClassification, + StableLmForTokenClassification, StableLmModel, StableLmPreTrainedModel, ) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index f6a8a8a2be2be3..5b81abc693cf69 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -30,7 +30,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -482,14 +487,18 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout.p if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -1383,3 +1392,88 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The StableLm Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + STABLELM_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->StableLm, LLAMA->STABLELM +class StableLmForTokenClassification(StableLmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = StableLmModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/starcoder2/__init__.py b/src/transformers/models/starcoder2/__init__.py index 1eb195fde16b03..d9dc2cd1e5001c 100644 --- a/src/transformers/models/starcoder2/__init__.py +++ b/src/transformers/models/starcoder2/__init__.py @@ -36,6 +36,7 @@ "Starcoder2Model", "Starcoder2PreTrainedModel", "Starcoder2ForSequenceClassification", + "Starcoder2ForTokenClassification", ] @@ -51,6 +52,7 @@ from .modeling_starcoder2 import ( Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, + Starcoder2ForTokenClassification, Starcoder2Model, Starcoder2PreTrainedModel, ) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 99abe6919e1e12..e0abef4b41b207 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -31,7 +31,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -655,14 +660,18 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal=self.is_causal and attention_mask is None and q_len > 1, + is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -1358,3 +1367,88 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Starcoder2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + STARCODER2_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Starcoder2, LLAMA->STARCODER2 +class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Starcoder2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index cb0eff88abc26f..2a6363c8e69b7f 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -252,10 +252,48 @@ def __init__(self, config, use_mask_token=False): self.norm = nn.LayerNorm(config.embed_dim) self.dropout = nn.Dropout(config.hidden_dropout_prob) + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + def forward( - self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, ) -> Tuple[torch.Tensor]: - embeddings, output_dimensions = self.patch_embeddings(pixel_values) + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) embeddings = self.norm(embeddings) batch_size, seq_len, _ = embeddings.size() @@ -266,7 +304,10 @@ def forward( embeddings = embeddings * (1.0 - mask) + mask_tokens * mask if self.position_embeddings is not 
None: - embeddings = embeddings + self.position_embeddings + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings embeddings = self.dropout(embeddings) @@ -304,7 +345,9 @@ def maybe_pad(self, pixel_values, height, width): pixel_values = nn.functional.pad(pixel_values, pad_values) return pixel_values - def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + def forward( + self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False + ) -> Tuple[torch.Tensor, Tuple[int]]: _, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -312,6 +355,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens ) # pad the input to be divisible by self.patch_size, if needed pixel_values = self.maybe_pad(pixel_values, height, width) + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) embeddings = self.projection(pixel_values) _, _, height, width = embeddings.shape output_dimensions = (height, width) @@ -924,6 +972,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -981,6 +1031,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, SwinModelOutput]: r""" @@ -1003,7 +1054,9 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, len(self.config.depths)) - embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, @@ -1074,6 +1127,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, SwinMaskedImageModelingOutput]: r""" @@ -1113,6 +1167,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1156,6 +1211,14 @@ def forward( """ Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. 
This will interpolate the pre-trained + position embeddings to the higher resolution. + + """, SWIN_START_DOCSTRING, ) @@ -1188,6 +1251,7 @@ def forward( labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, SwinImageClassifierOutput]: r""" @@ -1203,6 +1267,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 213d60a386dcc8..ac8ec197e599d1 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -295,10 +295,48 @@ def __init__(self, config, use_mask_token=False): self.norm = nn.LayerNorm(config.embed_dim) self.dropout = nn.Dropout(config.hidden_dropout_prob) + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + def forward( - self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, ) -> Tuple[torch.Tensor]: - embeddings, output_dimensions = self.patch_embeddings(pixel_values) + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) embeddings = self.norm(embeddings) batch_size, seq_len, _ = embeddings.size() @@ -309,7 +347,10 @@ def forward( embeddings = embeddings * (1.0 - mask) + mask_tokens * mask if self.position_embeddings is not None: - embeddings = embeddings + self.position_embeddings + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings embeddings = self.dropout(embeddings) 
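The `interpolate_pos_encoding` helpers added to Swin and SwinV2 reduce to reshaping the flat position table into its 2-D patch grid, resizing it bicubically, and flattening it back. The sketch below strips that step down to toy sizes: it passes `size=` directly and omits the class-token handling and the `+ 0.1` scale-factor workaround used in the method above.

```python
import math
import torch
import torch.nn as nn

# A 14x14 grid of pre-trained patch position embeddings, dim 96 (toy sizes).
num_positions, dim = 14 * 14, 96
patch_pos_embed = torch.randn(1, num_positions, dim)

# Target: a 24x24 patch grid, e.g. a larger input image with the same patch size.
new_h, new_w = 24, 24
grid = int(math.sqrt(num_positions))

resized = nn.functional.interpolate(
    patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2),  # (1, dim, 14, 14)
    size=(new_h, new_w),
    mode="bicubic",
    align_corners=False,
).permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)

print(resized.shape)  # torch.Size([1, 576, 96])
```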
@@ -348,7 +389,9 @@ def maybe_pad(self, pixel_values, height, width): pixel_values = nn.functional.pad(pixel_values, pad_values) return pixel_values - def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + def forward( + self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False + ) -> Tuple[torch.Tensor, Tuple[int]]: _, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -356,6 +399,11 @@ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tens ) # pad the input to be divisible by self.patch_size, if needed pixel_values = self.maybe_pad(pixel_values, height, width) + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) embeddings = self.projection(pixel_values) _, _, height, width = embeddings.shape output_dimensions = (height, width) @@ -979,6 +1027,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, default `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -1031,6 +1081,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, Swinv2ModelOutput]: r""" @@ -1053,7 +1104,9 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, len(self.config.depths)) - embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, @@ -1126,6 +1179,7 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, Swinv2MaskedImageModelingOutput]: r""" @@ -1165,6 +1219,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1208,6 +1263,14 @@ def forward( """ Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. 
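Passing the new flag at inference time looks like the sketch below, with a randomly initialised `Swinv2Model` standing in for a pre-trained checkpoint. The default config keeps `use_absolute_embeddings=False`, so in this toy setup the flag mainly lifts the new input-size check; with absolute embeddings enabled the position table would also be resized as described in the tip above.

```python
import torch
from transformers import Swinv2Config, Swinv2Model

# Randomly initialised toy model; the default config declares image_size=224.
model = Swinv2Model(Swinv2Config()).eval()

# A 384x384 input: without interpolate_pos_encoding=True the size check added
# above raises a ValueError because 384 != 224.
pixel_values = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    outputs = model(pixel_values, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)  # expected: torch.Size([1, 144, 768])
```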
+ + """, SWINV2_START_DOCSTRING, ) @@ -1241,6 +1304,7 @@ def forward( labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, Swinv2ImageClassifierOutput]: r""" @@ -1256,6 +1320,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 5c1557fb1f1834..aa4bb7827eb760 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -888,6 +888,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -896,8 +901,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 853c521e5e9cfa..4663fc05d8ba36 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -905,6 +905,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -913,8 +918,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index 0355d756ce2700..b71600501b4ca3 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -42,7 +42,7 @@ class VideoLlavaProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "VideoLlavaImageProcessor" - tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + tokenizer_class = "AutoTokenizer" def __init__(self, image_processor=None, tokenizer=None): super().__init__(image_processor, tokenizer) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ec928762b58733..5fb64c1f2cf4ef 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -953,6 +953,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -961,8 +966,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index b388be245f1389..9fceb1e61a1bf1 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -70,15 +70,15 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): with language model support into a single processor for language model boosted speech recognition decoding. Args: - feature_extractor ([`Wav2Vec2FeatureExtractor`]): - An instance of [`Wav2Vec2FeatureExtractor`]. 
The feature extractor is a required input. + feature_extractor ([`Wav2Vec2FeatureExtractor`] or [`SeamlessM4TFeatureExtractor`]): + An instance of [`Wav2Vec2FeatureExtractor`] or [`SeamlessM4TFeatureExtractor`]. The feature extractor is a required input. tokenizer ([`Wav2Vec2CTCTokenizer`]): An instance of [`Wav2Vec2CTCTokenizer`]. The tokenizer is a required input. decoder (`pyctcdecode.BeamSearchDecoderCTC`): An instance of [`pyctcdecode.BeamSearchDecoderCTC`]. The decoder is a required input. """ - feature_extractor_class = "Wav2Vec2FeatureExtractor" + feature_extractor_class = "AutoFeatureExtractor" tokenizer_class = "Wav2Vec2CTCTokenizer" def __init__( @@ -93,6 +93,11 @@ def __init__( if not isinstance(decoder, BeamSearchDecoderCTC): raise ValueError(f"`decoder` has to be of type {BeamSearchDecoderCTC.__class__}, but is {type(decoder)}") + if feature_extractor.__class__.__name__ not in ["Wav2Vec2FeatureExtractor", "SeamlessM4TFeatureExtractor"]: + raise ValueError( + f"`feature_extractor` has to be of type `Wav2Vec2FeatureExtractor` or `SeamlessM4TFeatureExtractor`, but is {type(feature_extractor)}" + ) + # make sure that decoder's alphabet and tokenizer's vocab match in content missing_decoder_tokens = self.get_missing_alphabet_tokens(decoder, tokenizer) if len(missing_decoder_tokens) > 0: @@ -117,7 +122,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - This class method is simply calling Wav2Vec2FeatureExtractor's + This class method is simply calling the feature extractor's [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and [`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`]. @@ -213,8 +218,8 @@ def get_missing_alphabet_tokens(decoder, tokenizer): def __call__(self, *args, **kwargs): """ - When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's - [`~Wav2Vec2FeatureExtractor.__call__`] and returns its output. If used in the context + When used in normal mode, this method forwards all its arguments to the feature extractor's + [`~FeatureExtractionMixin.__call__`] and returns its output. If used in the context [`~Wav2Vec2ProcessorWithLM.as_target_processor`] this method forwards all its arguments to Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information. @@ -252,8 +257,8 @@ def __call__(self, *args, **kwargs): def pad(self, *args, **kwargs): """ - When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's - [`~Wav2Vec2FeatureExtractor.pad`] and returns its output. If used in the context + When used in normal mode, this method forwards all its arguments to the feature extractor's + [`~FeatureExtractionMixin.pad`] and returns its output. If used in the context [`~Wav2Vec2ProcessorWithLM.as_target_processor`] this method forwards all its arguments to Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.pad`]. Please refer to the docstring of the above two methods for more information. 
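With `feature_extractor_class` switched to `"AutoFeatureExtractor"` and the new type check accepting either `Wav2Vec2FeatureExtractor` or `SeamlessM4TFeatureExtractor`, the LM-boosted processor can now be assembled around a SeamlessM4T-style front end. A rough sketch under stated assumptions: `facebook/w2v-bert-2.0` is the checkpoint exercised by the new test further down in this patch, while `patrickvonplaten/wav2vec2-base-100h-with-lm` is only an assumed source for a compatible CTC tokenizer and pyctcdecode decoder.

```python
from transformers import AutoFeatureExtractor, Wav2Vec2ProcessorWithLM

# A SeamlessM4TFeatureExtractor resolved through AutoFeatureExtractor
# (same checkpoint as the new processor test in this patch).
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

# Reuse the CTC tokenizer and pyctcdecode decoder from an existing LM-boosted
# processor; this repo id is an assumption made for the example.
base = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")

processor = Wav2Vec2ProcessorWithLM(
    feature_extractor=feature_extractor,
    tokenizer=base.tokenizer,
    decoder=base.decoder,
)

# Any other feature extractor type now fails fast with the ValueError added above.
```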
diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 508e85b91ffd2d..f2d6da56608961 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -188,6 +188,7 @@ def __call__( sampling_rate: Optional[int] = None, do_normalize: Optional[bool] = None, device: Optional[str] = "cpu", + return_token_timestamps: Optional[bool] = None, **kwargs, ) -> BatchFeature: """ @@ -237,6 +238,9 @@ def __call__( device (`str`, *optional*, defaults to `'cpu'`): Specifies the device for computation of the log-mel spectrogram of audio signals in the `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda") + return_token_timestamps (`bool`, *optional*, defaults to `None`): + Whether or not to return the number of frames of the input raw_speech. + These num_frames can be used by the model to compute word level timestamps. """ if sampling_rate is not None: @@ -302,6 +306,7 @@ def __call__( if isinstance(input_features[0], List): padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] + else: padded_inputs["input_features"] = input_features @@ -309,6 +314,9 @@ def __call__( # rescale from sample (48000) to feature (3000) padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length] + if return_token_timestamps is not None: + padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech] + if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index c58b0d35e55618..2bdff6e534dcd8 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -209,11 +209,15 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec # 2. num_frames is different, compute the DTW matrix for each sample sequentially # we're using np.unique because num_frames can be int/list/tuple - if len(np.unique(num_frames)) == 1: - # if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch - num_frames = num_frames if isinstance(num_frames, int) else num_frames[0] - + if isinstance(num_frames, int): weights = weights[..., : num_frames // 2] + + elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1: + weights = weights[..., : num_frames[0] // 2] + + elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1: + weights = weights[..., : num_frames[0] // 2] + else: # num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames) @@ -231,7 +235,7 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec # Perform dynamic time warping on each element of the batch. for batch_idx in range(batch_size): - if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)): + if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)): matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2] # Normalize and smoothen the weights. @@ -475,6 +479,7 @@ def generate( "The input name `inputs` is deprecated. 
Please make sure to use `input_features` instead.", FutureWarning, ) + # 1. prepare generation config generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index c0db404e5c88a5..e4fda437bfa4ab 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -686,6 +686,11 @@ def forward( query_states = self._shape(query_states, tgt_len, bsz) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -694,8 +699,7 @@ def forward( value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index de1a9b57ac6e3e..123dbcdb67afd7 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -443,11 +443,18 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None): return_tensors="pt", ) else: - processed = self.feature_extractor( - inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" - ) - if stride is None: - extra["segment_size"] = len(inputs) + if self.type == "seq2seq_whisper" and stride is None: + processed = self.feature_extractor( + inputs, + sampling_rate=self.feature_extractor.sampling_rate, + return_tensors="pt", + return_token_timestamps=True, + ) + extra["num_frames"] = processed.pop("num_frames") + else: + processed = self.feature_extractor( + inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" + ) if self.torch_dtype is not None: processed = processed.to(dtype=self.torch_dtype) @@ -461,11 +468,11 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None): def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): attention_mask = model_inputs.pop("attention_mask", None) stride = model_inputs.pop("stride", None) - segment_size = model_inputs.pop("segment_size", None) + num_frames = model_inputs.pop("num_frames", None) is_last = model_inputs.pop("is_last") - if stride is not None and segment_size is not None: - raise ValueError("segment_size must be used only when stride is None") + if stride is not None and num_frames is not None: + raise ValueError("num_frames must 
be used only when stride is None") if self.type in {"seq2seq", "seq2seq_whisper"}: encoder = self.model.get_encoder() @@ -495,10 +502,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs): generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride] else: - if isinstance(segment_size, int): - generate_kwargs["num_frames"] = segment_size // self.feature_extractor.hop_length - else: - generate_kwargs["num_frames"] = segment_size[0] // self.feature_extractor.hop_length + generate_kwargs["num_frames"] = num_frames if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames: generate_kwargs["input_features"] = inputs diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 463591032ceb7c..a8e47fb6831e3d 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -30,6 +30,7 @@ PROCESSOR_NAME, PushToHubMixin, add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, cached_file, copy_func, direct_transformers_import, @@ -355,10 +356,15 @@ def get_processor_dict( else: logger.info(f"loading configuration file {processor_file} from cache at {resolved_processor_file}") - if "auto_map" in processor_dict and not is_local: - processor_dict["auto_map"] = add_model_info_to_auto_map( - processor_dict["auto_map"], pretrained_model_name_or_path - ) + if not is_local: + if "auto_map" in processor_dict: + processor_dict["auto_map"] = add_model_info_to_auto_map( + processor_dict["auto_map"], pretrained_model_name_or_path + ) + if "custom_pipelines" in processor_dict: + processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( + processor_dict["custom_pipelines"], pretrained_model_name_or_path + ) return processor_dict, kwargs diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 395f9859cd68ce..4cb75c98646ce1 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -42,6 +42,7 @@ TensorType, add_end_docstrings, add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, cached_file, copy_func, download_url, @@ -2177,13 +2178,18 @@ def _from_pretrained( config_tokenizer_class = None init_kwargs = init_configuration - if "auto_map" in init_kwargs and not _is_local: - # For backward compatibility with odl format. - if isinstance(init_kwargs["auto_map"], (tuple, list)): - init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]} - init_kwargs["auto_map"] = add_model_info_to_auto_map( - init_kwargs["auto_map"], pretrained_model_name_or_path - ) + if not _is_local: + if "auto_map" in init_kwargs: + # For backward compatibility with odl format. + if isinstance(init_kwargs["auto_map"], (tuple, list)): + init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]} + init_kwargs["auto_map"] = add_model_info_to_auto_map( + init_kwargs["auto_map"], pretrained_model_name_or_path + ) + if "custom_pipelines" in init_kwargs: + init_kwargs["custom_pipelines"] = add_model_info_to_custom_pipelines( + init_kwargs["custom_pipelines"], pretrained_model_name_or_path + ) if config_tokenizer_class is None: # Matt: This entire block is only used to decide if the tokenizer class matches the class in the repo. 
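Both loading paths above now route any `custom_pipelines` entry fetched from the Hub through `add_model_info_to_custom_pipelines`, whose implementation is added to `utils/generic.py` later in this diff. The snippet below only restates the rewrite that helper performs; the repo id and task name are made up for illustration.

```python
# Shape of the metadata, as noted in the helper's comment:
# {custom_pipelines: {task: {"impl": "path.to.task"}, ...}}
custom_pipelines = {
    "pair-classification": {"impl": "pair_classification.PairClassificationPipeline"},
}
repo_id = "username/custom-pipeline-model"

for task, spec in custom_pipelines.items():
    # Prefix un-namespaced "impl" entries with "<repo_id>--" so they record
    # which Hub repo the pipeline implementation should be loaded from.
    if "impl" in spec and "--" not in spec["impl"]:
        spec["impl"] = f"{repo_id}--{spec['impl']}"

print(custom_pipelines["pair-classification"]["impl"])
# username/custom-pipeline-model--pair_classification.PairClassificationPipeline
```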
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 9defa91b2b8bc8..8ac0281912ce19 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -1250,6 +1250,10 @@ class AcceleratorConfig: Whether to use non-blocking CUDA calls to help minimize synchronization during distributed training with prepared `DataLoader` inputs being moved to device. Best if used with `pin_memory=True` in the `TrainingArguments`. + use_configured_state (`bool*, *optional*, defaults to `False`): + Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined + before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState` + must be initialized. May lead to issues using sweeps or hyperparameter tuning. """ @@ -1312,6 +1316,13 @@ class AcceleratorConfig: " The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`." }, ) + use_configured_state: bool = field( + default=False, + metadata={ + "help": "Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`." + "If `True`, an `Accelerator` or `PartialState` must be initialized. May lead to issues using sweeps or hyperparameter tuning." + }, + ) @classmethod def from_json_file(cls, json_file): @@ -1331,6 +1342,9 @@ def from_json_file(cls, json_file): def to_dict(self): return copy.deepcopy(self.__dict__) + def pop(self, key, default=None): + return self.__dict__.pop(key, default) + class LayerWiseDummyOptimizer(torch.optim.Optimizer): """ diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 2807c9951aa6d6..08d9000cb6258c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -572,6 +572,10 @@ class TrainingArguments: training results are fully reproducable using a different sampling technique. While seed-to-seed results may differ, on average the differences are neglible when using multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results. + - use_configured_state (`bool`, *optional*, defaults to `False`): + Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`. + If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues + with hyperparameter tuning. label_smoothing_factor (`float`, *optional*, defaults to 0.0): The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded @@ -1635,6 +1639,39 @@ def __post_init__(self): if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16: raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0") + # We need to setup the accelerator config here *before* the first call to `self.device` + if is_accelerate_available(): + if not isinstance(self.accelerator_config, (AcceleratorConfig)): + if self.accelerator_config is None: + self.accelerator_config = AcceleratorConfig() + elif isinstance(self.accelerator_config, dict): + self.accelerator_config = AcceleratorConfig(**self.accelerator_config) + # Check that a user didn't pass in the class instantiator + # such as `accelerator_config = AcceleratorConfig` + elif isinstance(self.accelerator_config, type): + raise NotImplementedError( + "Tried passing in a callable to `accelerator_config`, but this is not supported. 
" + "Please pass in a fully constructed `AcceleratorConfig` object instead." + ) + else: + self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) + + if self.dispatch_batches is not None: + warnings.warn( + "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" + " `--accelerator_config {'dispatch_batches':VALUE} instead", + FutureWarning, + ) + self.accelerator_config.dispatch_batches = self.dispatch_batches + + if self.split_batches is not None: + warnings.warn( + "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" + " `--accelerator_config {'split_batches':VALUE} instead", + FutureWarning, + ) + self.accelerator_config.split_batches = self.split_batches + if ( self.framework == "pt" and is_torch_available() @@ -1873,37 +1910,6 @@ def __post_init__(self): os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower() - if is_accelerate_available(): - if not isinstance(self.accelerator_config, (AcceleratorConfig)): - if self.accelerator_config is None: - self.accelerator_config = AcceleratorConfig() - elif isinstance(self.accelerator_config, dict): - self.accelerator_config = AcceleratorConfig(**self.accelerator_config) - # Check that a user didn't pass in the class instantiator - # such as `accelerator_config = AcceleratorConfig` - elif isinstance(self.accelerator_config, type): - raise NotImplementedError( - "Tried passing in a callable to `accelerator_config`, but this is not supported. " - "Please pass in a fully constructed `AcceleratorConfig` object instead." - ) - else: - self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) - if self.dispatch_batches is not None: - warnings.warn( - "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" - " `--accelerator_config {'dispatch_batches':VALUE} instead", - FutureWarning, - ) - self.accelerator_config.dispatch_batches = self.dispatch_batches - - if self.split_batches is not None: - warnings.warn( - "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" - " `--accelerator_config {'split_batches':VALUE} instead", - FutureWarning, - ) - self.accelerator_config.split_batches = self.split_batches - if self.tpu_metrics_debug: warnings.warn( "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" @@ -2056,32 +2062,62 @@ def _setup_devices(self) -> "torch.device": f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: " "Please run `pip install transformers[torch]` or `pip install accelerate -U`" ) + # We delay the init of `PartialState` to the end for clarity + accelerator_state_kwargs = {"enabled": True, "use_configured_state": False} + if isinstance(self.accelerator_config, AcceleratorConfig): + accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop( + "use_configured_state", False + ) + if accelerator_state_kwargs["use_configured_state"]: + if PartialState._shared_state == {}: + raise ValueError( + "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured " + "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. 
" + ) + # We rely on `PartialState` to yell if there's issues here (which it will) + self.distributed_state = PartialState(cpu=self.use_cpu) + if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED: + raise RuntimeError( + "Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, " + "but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set " + "`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly." + ) + else: AcceleratorState._reset_state(reset_partial_state=True) - self.distributed_state = None + self.distributed_state = None if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ: os.environ["ACCELERATE_USE_IPEX"] = "false" + + self._n_gpu = 1 if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")): - self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend) + accelerator_state_kwargs["cpu"] = True + accelerator_state_kwargs["backend"] = self.ddp_backend self._n_gpu = 0 elif is_sagemaker_mp_enabled(): + accelerator_state_kwargs["enabled"] = False local_rank = smp.local_rank() device = torch.device("cuda", local_rank) - self._n_gpu = 1 torch.cuda.set_device(device) elif is_sagemaker_dp_enabled(): - self.distributed_state = PartialState(_use_sagemaker_dp=True) - self._n_gpu = 1 + accelerator_state_kwargs["_use_sagemaker_dp"] = True elif self.deepspeed: - # Need to do similar for Accelerator init - os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" - self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout)) - del os.environ["ACCELERATE_USE_DEEPSPEED"] - self._n_gpu = 1 + accelerator_state_kwargs["use_deepspeed"] = True + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) else: - self.distributed_state = PartialState( - backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout) - ) - self._n_gpu = 1 + accelerator_state_kwargs["backend"] = self.ddp_backend + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) + + # Now we pop everything + if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop( + "use_configured_state", False + ): + # We need to patch this env var when enabling to detect deepspeed + use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False) + if use_deepspeed: + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.distributed_state = PartialState(**accelerator_state_kwargs) + if use_deepspeed: + del os.environ["ACCELERATE_USE_DEEPSPEED"] if not is_sagemaker_mp_enabled(): device = self.distributed_state.device self.local_rank = self.distributed_state.local_process_index @@ -2108,23 +2144,17 @@ def _setup_devices(self) -> "torch.device": "Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ " "or current PyTorch install was not built with MPS enabled." ) - if device.type == "mps": - self._n_gpu = 1 - elif self.use_cpu: + if self.use_cpu: device = torch.device("cpu") - self._n_gpu = 0 elif is_torch_xpu_available(): device = torch.device("xpu:0") torch.xpu.set_device(device) - self._n_gpu = 1 elif is_torch_mlu_available(): device = torch.device("mlu:0") torch.mlu.set_device(device) - self._n_gpu = 1 elif is_torch_npu_available(): device = torch.device("npu:0") torch.npu.set_device(device) - self._n_gpu = 1 else: # if n_gpu is > 1 we'll use nn.DataParallel. 
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 8c91463322ab8b..82d0b5001c7786 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -37,6 +37,7 @@ PaddingStrategy, TensorType, add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, cached_property, can_return_loss, expand_dims, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 681f8585566f3c..5e00230aed4cf5 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3692,6 +3692,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class GemmaForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class GemmaModel(metaclass=DummyObject): _backends = ["torch"] @@ -4642,6 +4649,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class LlamaForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LlamaModel(metaclass=DummyObject): _backends = ["torch"] @@ -5237,6 +5251,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class MistralForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MistralModel(metaclass=DummyObject): _backends = ["torch"] @@ -5265,6 +5286,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class MixtralForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MixtralModel(metaclass=DummyObject): _backends = ["torch"] @@ -6373,6 +6401,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class PersimmonForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class PersimmonModel(metaclass=DummyObject): _backends = ["torch"] @@ -6734,6 +6769,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Qwen2ForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Qwen2Model(metaclass=DummyObject): _backends = ["torch"] @@ -6762,6 +6804,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Qwen2MoeForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Qwen2MoeModel(metaclass=DummyObject): _backends = ["torch"] @@ -7793,6 +7842,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class StableLmForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class StableLmModel(metaclass=DummyObject): _backends = ["torch"] @@ -7821,6 +7877,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Starcoder2ForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + 
requires_backends(self, ["torch"]) + + class Starcoder2Model(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index a8277588ffdee7..1f332434a9c815 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -721,6 +721,19 @@ def add_model_info_to_auto_map(auto_map, repo_id): return auto_map +def add_model_info_to_custom_pipelines(custom_pipeline, repo_id): + """ + Adds the information of the repo_id to a given custom pipeline. + """ + # {custom_pipelines : {task: {"impl": "path.to.task"},...} } + for task in custom_pipeline.keys(): + if "impl" in custom_pipeline[task]: + module = custom_pipeline[task]["impl"] + if "--" not in module: + custom_pipeline[task]["impl"] = f"{repo_id}--{module}" + return custom_pipeline + + def infer_framework(model_class): """ Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 7cc2ebe7bd24b7..8281f5999371ee 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -41,7 +41,13 @@ if is_torch_available(): import torch - from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer + from transformers import ( + GemmaForCausalLM, + GemmaForSequenceClassification, + GemmaForTokenClassification, + GemmaModel, + GemmaTokenizer, + ) class GemmaModelTester: @@ -284,12 +290,17 @@ def prepare_config_and_inputs_for_common(self): @require_torch class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification) if is_torch_available() else () + all_model_classes = ( + (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification) + if is_torch_available() + else () + ) all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": GemmaModel, "text-classification": GemmaForSequenceClassification, + "token-classification": GemmaForTokenClassification, "text-generation": GemmaForCausalLM, "zero-shot": GemmaForSequenceClassification, } @@ -370,6 +381,22 @@ def test_Gemma_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma + def test_Gemma_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = GemmaForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Gemma buffers include complex numbers, which breaks this test") def 
test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 9fcdcc6d7e5cdc..58269d62e08c2b 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -47,6 +47,7 @@ LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, + LlamaForTokenClassification, LlamaModel, LlamaTokenizer, ) @@ -286,7 +287,13 @@ def prepare_config_and_inputs_for_common(self): @require_torch class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering) + ( + LlamaModel, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForQuestionAnswering, + LlamaForTokenClassification, + ) if is_torch_available() else () ) @@ -298,6 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi "text-generation": LlamaForCausalLM, "zero-shot": LlamaForSequenceClassification, "question-answering": LlamaForQuestionAnswering, + "token-classification": LlamaForTokenClassification, } if is_torch_available() else {} @@ -370,6 +378,21 @@ def test_llama_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + def test_llama_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = LlamaForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Llama buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/llava_next/test_image_processor_llava_next.py b/tests/models/llava_next/test_image_processor_llava_next.py index 7369f8a9186951..8b1f98bbcaefc4 100644 --- a/tests/models/llava_next/test_image_processor_llava_next.py +++ b/tests/models/llava_next/test_image_processor_llava_next.py @@ -199,3 +199,21 @@ def test_call_pytorch(self): @unittest.skip("LlavaNextImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy def test_call_numpy_4_channels(self): pass + + def test_nested_input(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + + # Test batched as a list of images + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched as a nested list of images, where each sublist is one batch + image_inputs_nested = [image_inputs[:3], image_inputs[3:]] + encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 1445, 3, 18, 
18) + self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) + + # Image processor should return same pixel values, independently of ipnut format + self.assertTrue((encoded_images_nested == encoded_images).all()) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 13076c14424603..13bdf83c5bbe09 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -46,6 +46,7 @@ from transformers import ( MistralForCausalLM, MistralForSequenceClassification, + MistralForTokenClassification, MistralModel, ) @@ -288,13 +289,16 @@ def prepare_config_and_inputs_for_common(self): @require_torch class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (MistralModel, MistralForCausalLM, MistralForSequenceClassification) if is_torch_available() else () + (MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification) + if is_torch_available() + else () ) all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": MistralModel, "text-classification": MistralForSequenceClassification, + "token-classification": MistralForTokenClassification, "text-generation": MistralForCausalLM, "zero-shot": MistralForSequenceClassification, } @@ -376,6 +380,22 @@ def test_Mistral_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral + def test_Mistral_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = MistralForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Mistral buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 0d92595d8cfa85..0972207fce4a55 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -40,7 +40,12 @@ if is_torch_available(): import torch - from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel + from transformers import ( + MixtralForCausalLM, + MixtralForSequenceClassification, + MixtralForTokenClassification, + MixtralModel, + ) class MixtralModelTester: @@ -287,13 +292,16 @@ def prepare_config_and_inputs_for_common(self): # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (MixtralModel, 
MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else () + (MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification) + if is_torch_available() + else () ) all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": MixtralModel, "text-classification": MixtralForSequenceClassification, + "token-classification": MixtralForTokenClassification, "text-generation": MixtralForCausalLM, "zero-shot": MixtralForSequenceClassification, } @@ -375,6 +383,22 @@ def test_Mixtral_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral + def test_Mixtral_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = MixtralForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Mixtral buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 86a69d774f1681..46a650c55abfed 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -44,6 +44,7 @@ AutoTokenizer, PersimmonForCausalLM, PersimmonForSequenceClassification, + PersimmonForTokenClassification, PersimmonModel, ) from transformers.models.persimmon.modeling_persimmon import ( @@ -283,12 +284,15 @@ def prepare_config_and_inputs_for_common(self): @require_torch class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification) if is_torch_available() else () + (PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification) + if is_torch_available() + else () ) pipeline_model_mapping = ( { "feature-extraction": PersimmonModel, "text-classification": PersimmonForSequenceClassification, + "token-classification": PersimmonForTokenClassification, # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way. 
# "text-generation": PersimmonForCausalLM, # "zero-shot": PersimmonForSequenceClassification, @@ -365,6 +369,22 @@ def test_persimmon_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon + def test_persimmon_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = PersimmonForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Persimmon buffers include complex numbers, which breaks this test") # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base def test_save_load_fast_init_from_base(self): diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index f4e88a97f06a53..54718c430387a8 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -45,6 +45,7 @@ from transformers import ( Qwen2ForCausalLM, Qwen2ForSequenceClassification, + Qwen2ForTokenClassification, Qwen2Model, ) @@ -299,12 +300,17 @@ def prepare_config_and_inputs_for_common(self): @require_torch # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2 class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification) if is_torch_available() else () + all_model_classes = ( + (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification) + if is_torch_available() + else () + ) all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": Qwen2Model, "text-classification": Qwen2ForSequenceClassification, + "token-classification": Qwen2ForTokenClassification, "text-generation": Qwen2ForCausalLM, "zero-shot": Qwen2ForSequenceClassification, } @@ -387,6 +393,22 @@ def test_Qwen2_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2 + def test_Qwen2_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = Qwen2ForTokenClassification(config=config) + model.to(torch_device) + 
model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Qwen2 buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index f0818e680d3da8..48c6ccf78ac447 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -45,6 +45,7 @@ from transformers import ( Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, + Qwen2MoeForTokenClassification, Qwen2MoeModel, ) @@ -327,13 +328,16 @@ def prepare_config_and_inputs_for_common(self): # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification) if is_torch_available() else () + (Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification) + if is_torch_available() + else () ) all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": Qwen2MoeModel, "text-classification": Qwen2MoeForSequenceClassification, + "token-classification": Qwen2MoeForTokenClassification, "text-generation": Qwen2MoeForCausalLM, "zero-shot": Qwen2MoeForSequenceClassification, } @@ -414,6 +418,22 @@ def test_Qwen2Moe_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe + def test_Qwen2Moe_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = Qwen2MoeForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index f5e74ead9b8502..083f928612a03e 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -43,6 +43,7 @@ AutoTokenizer, StableLmForCausalLM, StableLmForSequenceClassification, + StableLmForTokenClassification, StableLmModel, ) from transformers.models.stablelm.modeling_stablelm import ( @@ -287,12 +288,15 @@ def prepare_config_and_inputs_for_common(self): # Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with 
Persimmon -> StableLm class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification) if is_torch_available() else () + (StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification, StableLmForTokenClassification) + if is_torch_available() + else () ) pipeline_model_mapping = ( { "feature-extraction": StableLmModel, "text-classification": StableLmForSequenceClassification, + "token-classification": StableLmForTokenClassification, # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way. # "text-generation": StableLmForCausalLM, # "zero-shot": StableLmForSequenceClassification, @@ -356,6 +360,22 @@ def test_stablelm_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->StableLm,llama->stablelm + def test_stablelm_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = StableLmForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @parameterized.expand([("linear",), ("dynamic",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm def test_model_rope_scaling_from_config(self, scaling_type): diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index 95f604d06b3713..faba4d254ba8cc 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -43,6 +43,7 @@ AutoTokenizer, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, + Starcoder2ForTokenClassification, Starcoder2Model, ) @@ -290,13 +291,16 @@ def prepare_config_and_inputs_for_common(self): # Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2 class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification) if is_torch_available() else () + (Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, Starcoder2ForTokenClassification) + if is_torch_available() + else () ) all_generative_model_classes = (Starcoder2ForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": Starcoder2Model, "text-classification": Starcoder2ForSequenceClassification, + "token-classification": Starcoder2ForTokenClassification, "text-generation": Starcoder2ForCausalLM, "zero-shot": Starcoder2ForSequenceClassification, } @@ -370,6 +374,22 @@ def test_Starcoder2_sequence_classification_model_for_multi_label(self): 
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Starcoder2,llama->Starcoder2 + def test_Starcoder2_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = Starcoder2ForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index 9220784e23029a..699171722d0db9 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -493,6 +493,26 @@ def test_inference_image_classification_head(self): expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + @slow + def test_inference_interpolate_pos_encoding(self): + # Swin models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. + model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device) + + image_processor = self.default_image_processor + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt") + pixel_values = inputs.pixel_values.to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 256, 768)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + @require_torch class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin): diff --git a/tests/models/swinv2/test_modeling_swinv2.py b/tests/models/swinv2/test_modeling_swinv2.py index b8f97ee7c23bc6..7a948d1282c1b6 100644 --- a/tests/models/swinv2/test_modeling_swinv2.py +++ b/tests/models/swinv2/test_modeling_swinv2.py @@ -485,6 +485,26 @@ def test_inference_image_classification_head(self): expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + @slow + def test_inference_interpolate_pos_encoding(self): + # Swinv2 models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. 
+ model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to(torch_device) + + image_processor = self.default_image_processor + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt") + pixel_values = inputs.pixel_values.to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 256, 768)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + @require_torch class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin): diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py index 2c52a921653c61..61dee30091d282 100644 --- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py +++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py @@ -25,7 +25,7 @@ from datasets import load_dataset from parameterized import parameterized -from transformers import AutoProcessor +from transformers import AutoFeatureExtractor, AutoProcessor from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow @@ -157,6 +157,35 @@ def test_feature_extractor(self): for key in input_feat_extract.keys(): self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + def test_another_feature_extractor(self): + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0") + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + + raw_speech = floats_list((3, 1000)) + + input_feat_extract = feature_extractor(raw_speech, return_tensors="np") + input_processor = processor(raw_speech, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + self.assertListEqual( + processor.model_input_names, + feature_extractor.model_input_names, + msg="`processor` and `feature_extractor` model input names do not match", + ) + + def test_wrong_feature_extractor_raises_error(self): + feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3") + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + with self.assertRaises(ValueError): + Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + def test_tokenizer(self): feature_extractor = self.get_feature_extractor() tokenizer = self.get_tokenizer() diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index fed1b9c0592522..58acb5f2fdd451 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1950,6 +1950,69 @@ def test_tiny_timestamp_generation(self): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + @slow + def test_large_timestamp_generation(self): + set_seed(0) + processor = 
WhisperProcessor.from_pretrained("openai/whisper-large-v3") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + model.to(torch_device) + + input_speech = np.concatenate(self._load_datasamples(4)) + input_features = processor( + input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True + ).input_features + input_features = input_features.to(torch_device) + + generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu") + + # fmt: off + EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50257]) + # fmt: on + self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT)) + + EXPECTED_TRANSCRIPT = [ + { + "text": ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + " Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive" + " season of the year, with Christmas and roast beef looming before us, similes drawn from eating" + " and its results occur most readily to the mind. He has grave doubts whether Sir Frederick " + "Leighton's work is really Greek after all," + ), + "offsets": [ + { + "text": ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ), + "timestamp": (0.0, 5.28), + }, + { + "text": " Nor is Mr. 
Quilter's manner less interesting than his matter.", + "timestamp": (6.34, 10.1), + }, + { + "text": ( + " He tells us that at this festive season of the year, with Christmas and roast beef looming before us," + ), + "timestamp": (10.92, 17.6), + }, + { + "text": (" similes drawn from eating and its results occur most readily to the mind."), + "timestamp": (18.44, 22.580000000000002), + }, + { + "text": ( + " He has grave doubts whether Sir Frederick Leighton's work is really Greek after all," + ), + "timestamp": (23.16, 28.68), + }, + ], + } + ] + + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + @slow def test_tiny_token_timestamp_generation(self): set_seed(0) @@ -1979,6 +2042,36 @@ def test_tiny_token_timestamp_generation(self): self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT)) + @slow + def test_large_token_timestamp_generation(self): + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3") + model.to(torch_device) + + input_speech = self._load_datasamples(4) + input_features = processor( + input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True + ) + input_features = input_features.to(torch_device) + + generate_outputs = model.generate( + **input_features, max_length=448, return_timestamps=True, return_token_timestamps=True + ) + + self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape) + + # fmt: off + EXPECTED_OUTPUT = torch.tensor([ + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000, 12.4600], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800, 9.8800] + ]) + # fmt: on + + self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT)) + @slow def test_tiny_token_timestamp_batch_generation(self): set_seed(0) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 763c7d1a883314..6ede7d1c7ac3c9 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -925,6 +925,24 @@ def test_push_to_hub_dynamic_pipeline(self): # Can't make an isinstance check 
because the new_classifier is from the PairClassificationPipeline class of a # dynamic module self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline") + # check for tag existence; the tag needs to be added when we are calling a custom pipeline from the hub + # useful for cases such as finetuning + self.assertDictEqual( + new_classifier.model.config.custom_pipelines, + { + "pair-classification": { + "impl": f"{USER}/test-dynamic-pipeline--custom_pipeline.PairClassificationPipeline", + "pt": ("AutoModelForSequenceClassification",), + "tf": (), + } + }, + ) + # test if the pipeline still works after the model is finetuned + # (we are actually testing if the pipeline still works from the final repo) + # this is what the user/repo--module.class reference is used for + new_classifier.model.push_to_hub(repo_name=f"{USER}/test-pipeline-for-a-finetuned-model", token=self._token) + del new_classifier # free up memory + new_classifier = pipeline(model=f"{USER}/test-pipeline-for-a-finetuned-model", trust_remote_code=True) results = classifier("I hate you", second_text="I love you") new_results = new_classifier("I hate you", second_text="I love you") diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 533eedb7f336c1..5f3ac898daeea6 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4014,6 +4014,47 @@ def test_sdpa_can_dispatch_on_flash(self): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): _ = model(**inputs_dict) + @require_torch_sdpa + @require_torch_gpu + @slow + def test_sdpa_can_compile_dynamic(self): + compute_capability = torch.cuda.get_device_capability() + major, _ = compute_capability + + if not torch.version.cuda or major < 8: + self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0") + + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + if config.model_type in ["dbrx"]: + self.skipTest( + "DBRX (transformers==4.40) requires a modification to support dynamic shapes with compile." + ) + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa") + model.to(torch_device) + + # For PyTorch 2.1 - 2.3.0 set `dynamic=True`. In the future setting `dynamic=None` and using `torch._dynamo.mark_dynamic()` + # on input tensors will be required. `mark_dynamic` currently raises inconsistent shape errors.
+ model = torch.compile(model, dynamic=True) + + inputs_dict.pop("attention_mask", None) + inputs_dict.pop("decoder_attention_mask", None) + for name, inp in inputs_dict.items(): + if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]: + inputs_dict[name] = inp.to(torch.float16) + + # use no_grad to save some memory + with torch.no_grad(): + _ = model(**inputs_dict) + @require_torch_sdpa @slow def test_eager_matches_sdpa_generate(self): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c420da4052f186..1711f600cebbc8 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -131,6 +131,10 @@ # for version specific tests in TrainerIntegrationTest require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28") GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28") +if is_accelerate_available(): + from accelerate import Accelerator + from accelerate.state import AcceleratorState + PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" @@ -3266,6 +3270,16 @@ def test_accelerator_config_only_deprecated_args(self): trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) self.assertEqual(trainer.accelerator.split_batches, True) + def test_accelerator_custom_state(self): + AcceleratorState._reset_state(reset_partial_state=True) + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertRaises(ValueError) as cm: + _ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True}) + self.assertIn("Please define this beforehand", str(cm.exception)) + _ = Accelerator() + _ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True}) + AcceleratorState._reset_state(reset_partial_state=True) + @require_accelerate_version_min_0_28 def test_accelerator_config_from_dict_grad_accum_num_steps(self): with tempfile.TemporaryDirectory() as tmp_dir:
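For context on the new `*ForTokenClassification` heads exercised by the tests above, here is a minimal usage sketch. The checkpoint name and label count are assumptions for illustration; the classification head is freshly initialized on load and would need fine-tuning before its predictions mean anything.

```python
import torch
from transformers import AutoTokenizer, StableLmForTokenClassification

# Assumed base checkpoint; any StableLM checkpoint should behave the same way.
checkpoint = "stabilityai/stablelm-3b-4e1t"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# num_labels is arbitrary here; the token-classification head starts out randomly initialized.
model = StableLmForTokenClassification.from_pretrained(checkpoint, num_labels=3)

inputs = tokenizer("Hugging Face is based in New York City", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, sequence_length, num_labels)
predicted_labels = logits.argmax(dim=-1)
```

The same pattern applies to the other heads added alongside it (Gemma, Llama, Mistral, Mixtral, Persimmon, Starcoder2), since they share the token-classification interface.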
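The Swin and Swinv2 tests above exercise the `interpolate_pos_encoding` forward argument. A sketch of how a user might call it on a higher-resolution input follows; the 481x481 size and the checkpoint mirror the Swin test, everything else is an assumption.

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, SwinModel

checkpoint = "microsoft/swin-tiny-patch4-window7-224"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = SwinModel.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Resize to a resolution larger than the 224x224 the checkpoint was pre-trained on.
inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")

with torch.no_grad():
    # interpolate_pos_encoding=True adapts the pre-trained position embeddings
    # so the model can run on the larger input, as the test comment describes.
    outputs = model(**inputs, interpolate_pos_encoding=True)

print(outputs.last_hidden_state.shape)  # the test above expects (1, 256, 768) at 481x481
```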
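The Whisper tests above verify segment-level and token-level timestamps for `openai/whisper-large-v3`. A sketch of the corresponding user-facing calls, assuming the same dummy LibriSpeech split that the test helpers load:

```python
import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")

# Assumed to match the samples the tests load via _load_datasamples.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]

inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt")

# Segment-level timestamps; the second timestamp test additionally requests
# token-level timestamps via return_token_timestamps=True on both calls.
generated_ids = model.generate(inputs.input_features, max_length=448, return_timestamps=True)

transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
print(transcript[0]["offsets"])  # list of {"text": ..., "timestamp": (start, end)} segments
```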
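Finally, a sketch of the usage pattern that `test_sdpa_can_compile_dynamic` covers: loading a model with the SDPA attention implementation and compiling it with dynamic shapes. The checkpoint is an assumption (any class with `_supports_sdpa` should work), and an Ampere-or-newer CUDA GPU is assumed, as in the test.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "mistralai/Mistral-7B-v0.1"  # assumed SDPA-capable checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, attn_implementation="sdpa"
).to("cuda")

# As the test comment notes, dynamic=True is the setting for PyTorch 2.1 - 2.3.0;
# later versions may instead use dynamic=None with torch._dynamo.mark_dynamic on the inputs.
model = torch.compile(model, dynamic=True)

inputs = tokenizer("def fib(n):", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model(**inputs)  # the first call triggers compilation; later calls reuse the graph
print(out.logits.shape)
```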