From c9f191a0b7f391594901a960671f2a199122ef48 Mon Sep 17 00:00:00 2001 From: Merve Noyan Date: Thu, 27 Jun 2024 12:46:36 +0300 Subject: [PATCH 01/18] Fix ONNX exports for Optimum compatible models (#31311) * fixed models * format with bumped ruff version on my local * fix copies * add tracing checks * format * Update src/transformers/utils/generic.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * format * style fix * Update modeling_mobilevit.py * add docstring and change name * Update __init__.py * Update __init__.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/clap/modeling_clap.py | 7 ++++-- .../models/donut/modeling_donut_swin.py | 7 ++++-- src/transformers/models/dpt/modeling_dpt.py | 4 ++-- .../models/imagegpt/modeling_imagegpt.py | 10 ++++++-- .../models/layoutlmv3/modeling_layoutlmv3.py | 12 +++++++--- .../models/mobilevit/modeling_mobilevit.py | 13 ++++++++-- src/transformers/models/sam/modeling_sam.py | 3 +-- src/transformers/models/swin/modeling_swin.py | 7 ++++-- src/transformers/utils/__init__.py | 2 ++ src/transformers/utils/generic.py | 24 +++++++++++++++++++ 10 files changed, 72 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 1c236d29d4e734..3e83daa942c022 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -37,6 +37,7 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig @@ -590,8 +591,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): def set_shift_and_window_size(self, input_resolution): if min(input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(input_resolution) + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 7e899f453f1c0f..115808a6b11a71 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -35,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, logging, + torch_int, ) from .configuration_donut_swin import DonutSwinConfig @@ -562,8 +563,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): def set_shift_and_window_size(self, input_resolution): if min(input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(input_resolution) + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index a7e554742f2de2..db5db0eae1189b 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -39,7 +39,7 @@ from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ModelOutput, logging +from ...utils import ModelOutput, logging, torch_int from ...utils.backbone_utils import load_backbone from .configuration_dpt import DPTConfig @@ -226,7 +226,7 @@ def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_ind posemb_tok = posemb[:, :start_index] posemb_grid = posemb[0, start_index:] - old_grid_size = int(math.sqrt(len(posemb_grid))) + old_grid_size = torch_int(posemb_grid.size(0) ** 0.5) posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2) posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear") diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index c0b0a83c24d66f..5d59a4ed90e4c9 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -33,7 +33,13 @@ ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_float, +) from .configuration_imagegpt import ImageGPTConfig @@ -229,7 +235,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) + attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 941ff860042adf..629490350c7dc3 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -33,7 +33,13 @@ ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) from .configuration_layoutlmv3 import LayoutLMv3Config @@ -910,8 +916,8 @@ def forward( patch_height = patch_width = None if pixel_values is not None: patch_height, patch_width = ( - int(pixel_values.shape[2] / self.config.patch_size), - int(pixel_values.shape[3] / self.config.patch_size), + torch_int(pixel_values.shape[2] / self.config.patch_size), + torch_int(pixel_values.shape[3] / self.config.patch_size), ) visual_embeddings = self.forward_image(pixel_values) visual_attention_mask = torch.ones( diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 551b4ee734b511..59c191b3789641 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -39,6 +39,7 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_mobilevit import MobileViTConfig @@ -437,8 +438,16 @@ def unfolding(self, features: torch.Tensor) -> Tuple[torch.Tensor, Dict]: batch_size, channels, orig_height, orig_width = features.shape - new_height = int(math.ceil(orig_height / patch_height) * patch_height) - new_width = int(math.ceil(orig_width / patch_width) * patch_width) + new_height = ( + torch_int(torch.ceil(orig_height / patch_height) * patch_height) + if torch.jit.is_tracing() + else int(math.ceil(orig_height / patch_height) * patch_height) + ) + new_width = ( + torch_int(torch.ceil(orig_width / patch_width) * patch_width) + if torch.jit.is_tracing() + else int(math.ceil(orig_width / patch_width) * patch_width) + ) interpolate = False if new_width != orig_width or new_height != orig_height: diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index f5baf5bcf3bfd0..c99fb9d7e869f8 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -15,7 +15,6 @@ """PyTorch SAM model.""" import collections -import math from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union @@ -232,7 +231,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarit # SamAttention _, _, _, c_per_head = query.shape attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens - attn = attn / math.sqrt(c_per_head) + attn = attn / (c_per_head**0.5) attn = torch.softmax(attn, dim=-1) if attention_similarity is not None: diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index f3f2dedeb6f3dd..8813d555968880 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -36,6 +36,7 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from ...utils.backbone_utils import BackboneMixin from .configuration_swin import SwinConfig @@ -639,8 +640,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): def set_shift_and_window_size(self, input_resolution): if min(input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(input_resolution) + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index ce87bc8623132e..ddd329817aff24 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -60,6 +60,8 @@ tensor_size, to_numpy, to_py_obj, + torch_float, + torch_int, transpose, working_or_temp_dir, ) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 4a3c1d970116ae..80232898ce4707 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -753,6 +753,30 @@ def infer_framework(model_class): raise TypeError(f"Could not infer framework from class {model_class}.") +def torch_int(x): + """ + Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int. + """ + if not is_torch_available(): + return int(x) + + import torch + + return x.to(torch.int64) if torch.jit.is_tracing() else int(x) + + +def torch_float(x): + """ + Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float. + """ + if not is_torch_available(): + return int(x) + + import torch + + return x.to(torch.float32) if torch.jit.is_tracing() else int(x) + + def filter_out_non_signature_kwargs(extra: Optional[list] = None): """ Decorator to filter out named arguments that are not in the function signature. From 11138ca013add0041753aab276e83261c38b08ac Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 27 Jun 2024 12:35:19 +0200 Subject: [PATCH 02/18] [`Llama`] Conversion: fix and simplify the script! (#31591) * fix and simplify the script! * add co-author --------- Co-authored-by: crackalamoo --- .../llama/convert_llama_weights_to_hf.py | 62 +++++++++++-------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py index a98d44b7484ada..a0fbe4680addf2 100644 --- a/src/transformers/models/llama/convert_llama_weights_to_hf.py +++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py @@ -105,21 +105,18 @@ def write_json(text, path): def write_model( model_path, input_base_path, - model_size, + model_size=None, safe_serialization=True, llama_version=1, vocab_size=None, + num_shards=None, ): - # for backward compatibility, before you needed the repo to be called `my_repo/model_size` - if not os.path.isfile(os.path.join(input_base_path, "params.json")): - input_base_path = os.path.join(input_base_path, model_size) - os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") os.makedirs(tmp_model_path, exist_ok=True) params = read_json(os.path.join(input_base_path, "params.json")) - num_shards = NUM_SHARDS[model_size] + num_shards = NUM_SHARDS[model_size] if num_shards is None else num_shards params = params.get("model", params) n_layers = params["n_layers"] n_heads = params["n_heads"] @@ -142,12 +139,13 @@ def write_model( vocab_size = vocab_size if vocab_size is not None else 32000 if params.get("n_kv_heads", None) is not None: num_key_value_heads = params["n_kv_heads"] # for GQA / MQA - num_local_key_value_heads = n_heads_per_shard // num_key_value_heads - key_value_dim = dim // num_key_value_heads + num_key_value_heads_per_shard = num_key_value_heads // num_shards + key_value_dim = dims_per_head * num_key_value_heads else: # compatibility with other checkpoints num_key_value_heads = n_heads - num_local_key_value_heads = n_heads_per_shard - key_value_dim = dim + num_key_value_heads_per_shard = n_heads_per_shard + key_value_dim = dims_per_head * num_key_value_heads + print(num_shards, num_key_value_heads, num_key_value_heads_per_shard, key_value_dim) # permute for sliced rotary def permute(w, n_heads, dim1=dim, dim2=dim): @@ -162,8 +160,9 @@ def permute(w, n_heads, dim1=dim, dim2=dim): else: # Sharded loaded = [ - torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") - for i in range(num_shards) + torch.load(os.path.join(input_base_path, file), map_location="cpu") + for file in os.listdir(input_base_path) + if file.endswith(".pth") ] param_count = 0 index_dict = {"weight_map": {}} @@ -178,7 +177,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim): f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( loaded[f"layers.{layer_i}.attention.wk.weight"], n_heads=num_key_value_heads, - dim1=dim // num_local_key_value_heads, + dim1=key_value_dim, ), f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], @@ -206,7 +205,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim): torch.cat( [ loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) - for i in range(num_shards) + for i in range(len(loaded)) ], dim=0, ).reshape(dim, dim), @@ -216,9 +215,9 @@ def permute(w, n_heads, dim1=dim, dim2=dim): torch.cat( [ loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( - num_local_key_value_heads, dims_per_head, dim + num_key_value_heads_per_shard, dims_per_head, dim ) - for i in range(num_shards) + for i in range(len(loaded)) ], dim=0, ).reshape(key_value_dim, dim), @@ -229,24 +228,24 @@ def permute(w, n_heads, dim1=dim, dim2=dim): state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( [ loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( - num_local_key_value_heads, dims_per_head, dim + num_key_value_heads_per_shard, dims_per_head, dim ) - for i in range(num_shards) + for i in range(len(loaded)) ], dim=0, ).reshape(key_value_dim, dim) state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 + [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(len(loaded))], dim=1 ) state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(len(loaded))], dim=0 ) state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 + [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(len(loaded))], dim=1 ) state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(len(loaded))], dim=0 ) state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq @@ -268,9 +267,9 @@ def permute(w, n_heads, dim1=dim, dim2=dim): state_dict = { "model.norm.weight": loaded[0]["norm.weight"], "model.embed_tokens.weight": torch.cat( - [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=concat_dim + [loaded[i]["tok_embeddings.weight"] for i in range(len(loaded))], dim=concat_dim ), - "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), + "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(len(loaded))], dim=0), } for k, v in state_dict.items(): @@ -310,7 +309,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim): model.config.torch_dtype = torch.float16 print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) - shutil.rmtree(tmp_model_path) + shutil.rmtree(tmp_model_path, ignore_errors=True) class Llama3Converter(TikTokenConverter): @@ -371,8 +370,8 @@ def main(): ) parser.add_argument( "--model_size", - choices=["7B", "8B", "8Bf", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"], - help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", + default=None, + help="'f' Deprecated in favor of `num_shards`: models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", ) parser.add_argument( "--output_dir", @@ -389,7 +388,15 @@ def main(): type=int, help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size", ) + parser.add_argument( + "--num_shards", + default=None, + type=int, + help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth", + ) args = parser.parse_args() + if args.model_size is None and args.num_shards is None: + raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`") spm_path = os.path.join(args.input_dir, "tokenizer.model") vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version)) if args.model_size != "tokenizer_only": @@ -400,6 +407,7 @@ def main(): safe_serialization=args.safe_serialization, llama_version=args.llama_version, vocab_size=vocab_size, + num_shards=args.num_shards, ) From 3a028101e91069b51629f5e74096ae78e490022b Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Thu, 27 Jun 2024 18:41:49 +0800 Subject: [PATCH 03/18] [QoL] Allow dtype str for torch_dtype arg of from_pretrained (#31590) * Allow dtype str for torch_dtype in from_pretrained * Update docstring * Add tests for str torch_dtype --- src/transformers/modeling_utils.py | 6 +++++- tests/test_modeling_utils.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f7b0db6d77f8e4..c991c1c95ba24b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2958,6 +2958,8 @@ def from_pretrained( using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32. + 3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc. + For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or @@ -3661,9 +3663,11 @@ def from_pretrained( "Since the `torch_dtype` attribute can't be found in model's config object, " "will use torch_dtype={torch_dtype} as derived from model's weights" ) + elif hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) else: raise ValueError( - f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' + f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}' ) dtype_orig = cls._set_default_torch_dtype(torch_dtype) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index f5b30d50339093..758fe4d1fdf398 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -445,6 +445,18 @@ def test_model_from_config_torch_dtype(self): with self.assertRaises(ValueError): model = AutoModel.from_config(config, torch_dtype=torch.int64) + def test_model_from_config_torch_dtype_str(self): + # test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend + model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32") + self.assertEqual(model.dtype, torch.float32) + + model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16") + self.assertEqual(model.dtype, torch.float16) + + # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type + with self.assertRaises(ValueError): + model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64") + def test_model_from_pretrained_torch_dtype(self): # test that the model can be instantiated with dtype of either # 1. explicit from_pretrained's torch_dtype argument From be50a0338b9d7b76448fcc9c5046a78118a4d968 Mon Sep 17 00:00:00 2001 From: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Date: Thu, 27 Jun 2024 20:36:55 +0900 Subject: [PATCH 04/18] change anchor_image_size None for compatibility (#31640) * change anchor_image_size None for compatibility * make fix-copies --- src/transformers/models/rt_detr/configuration_rt_detr.py | 6 +++--- tests/models/rt_detr/test_modeling_rt_detr.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/rt_detr/configuration_rt_detr.py b/src/transformers/models/rt_detr/configuration_rt_detr.py index a3d49fafeaedc7..d0f4bb17562b3a 100644 --- a/src/transformers/models/rt_detr/configuration_rt_detr.py +++ b/src/transformers/models/rt_detr/configuration_rt_detr.py @@ -118,8 +118,8 @@ class RTDetrConfig(PretrainedConfig): Scale or magnitude of noise to be added to the bounding boxes. learn_initial_query (`bool`, *optional*, defaults to `False`): Indicates whether the initial query embeddings for the decoder should be learned during training - anchor_image_size (`Tuple[int, int]`, *optional*, defaults to `[640, 640]`): - Height and width of the input image used during evaluation to generate the bounding box anchors. + anchor_image_size (`Tuple[int, int]`, *optional*): + Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied. disable_custom_kernels (`bool`, *optional*, defaults to `True`): Whether to disable custom kernels. with_box_refine (`bool`, *optional*, defaults to `True`): @@ -218,7 +218,7 @@ def __init__( label_noise_ratio=0.5, box_noise_scale=1.0, learn_initial_query=False, - anchor_image_size=[640, 640], + anchor_image_size=None, disable_custom_kernels=True, with_box_refine=True, is_encoder_decoder=True, diff --git a/tests/models/rt_detr/test_modeling_rt_detr.py b/tests/models/rt_detr/test_modeling_rt_detr.py index 44647be5ac6762..2d3d48dba33125 100644 --- a/tests/models/rt_detr/test_modeling_rt_detr.py +++ b/tests/models/rt_detr/test_modeling_rt_detr.py @@ -91,7 +91,7 @@ def __init__( label_noise_ratio=0.5, box_noise_scale=1.0, learn_initial_query=False, - anchor_image_size=[64, 64], + anchor_image_size=None, image_size=64, disable_custom_kernels=True, with_box_refine=True, From 4aa17d00690b7f82c95bb2949ea57e22c35b4336 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 27 Jun 2024 16:54:41 +0500 Subject: [PATCH 05/18] Remove deprecated config attribute in VLMs (#31655) remove --- .../models/llava/configuration_llava.py | 18 --------------- .../models/vipllava/configuration_vipllava.py | 23 ------------------- 2 files changed, 41 deletions(-) diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py index 6930dcc78c46f7..34e67ee4221f35 100644 --- a/src/transformers/models/llava/configuration_llava.py +++ b/src/transformers/models/llava/configuration_llava.py @@ -131,23 +131,5 @@ def __init__( text_config = CONFIG_MAPPING["llama"]() self.text_config = text_config - self._vocab_size = self.text_config.vocab_size super().__init__(**kwargs) - - @property - def vocab_size(self): - warnings.warn( - "The `vocab_size` attribute is deprecated and will be removed in v4.42, Please use `text_config.vocab_size` instead.", - FutureWarning, - ) - return self._vocab_size - - @vocab_size.setter - def vocab_size(self, value): - self._vocab_size = value - - def to_dict(self): - output = super().to_dict() - output.pop("_vocab_size", None) - return output diff --git a/src/transformers/models/vipllava/configuration_vipllava.py b/src/transformers/models/vipllava/configuration_vipllava.py index d98099b21b0546..c80487702c6525 100644 --- a/src/transformers/models/vipllava/configuration_vipllava.py +++ b/src/transformers/models/vipllava/configuration_vipllava.py @@ -13,8 +13,6 @@ # limitations under the License. """VipLlava model configuration""" -import warnings - from ...configuration_utils import PretrainedConfig from ...utils import logging from ..auto import CONFIG_MAPPING @@ -90,13 +88,6 @@ def __init__( self.projector_hidden_act = projector_hidden_act self.projector_layernorm_eps = projector_layernorm_eps self.vision_feature_layers = vision_feature_layers - - if "vocab_size" in kwargs: - warnings.warn( - "The `vocab_size` argument is deprecated and will be removed in v4.42, since it can be inferred from the `text_config`. Passing this argument has no effect", - FutureWarning, - ) - self.vision_config = vision_config if isinstance(self.vision_config, dict): @@ -123,19 +114,5 @@ def __init__( text_config = CONFIG_MAPPING["llama"]() self.text_config = text_config - self._vocab_size = self.text_config.vocab_size super().__init__(**kwargs) - - @property - def vocab_size(self): - warnings.warn( - "The `vocab_size` attribute is deprecated and will be removed in v4.42, Please use `text_config.vocab_size` instead.", - FutureWarning, - ) - return self._vocab_size - - def to_dict(self): - output = super().to_dict() - output.pop("_vocab_size", None) - return output From 0cf60f13ab1c857c17fc3fb127129048c93bf06c Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:36:19 +0200 Subject: [PATCH 06/18] Add gemma 2 (#31659) * inital commit * Add doc * protect? * fixup stuffs * update tests * fix build documentation * mmmmmmm config attributes * style * nit * uodate * nit * Fix docs * protect some stuff --------- Co-authored-by: Lysandre --- docs/source/en/index.md | 1 + docs/source/en/model_doc/gemma2.md | 58 + src/transformers/__init__.py | 18 + src/transformers/cache_utils.py | 122 ++ .../generation/configuration_utils.py | 2 +- src/transformers/generation/utils.py | 17 +- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 4 + .../models/auto/tokenization_auto.py | 7 + src/transformers/models/gemma/diff_gemma.py | 3 +- .../models/gemma/modeling_gemma.py | 14 +- src/transformers/models/gemma2/__init__.py | 61 + .../models/gemma2/configuration_gemma2.py | 149 ++ .../gemma2/convert_gemma2_weights_to_hf.py | 239 +++ src/transformers/models/gemma2/diff_gemma2.py | 781 +++++++++ .../models/gemma2/modeling_gemma2.py | 1392 +++++++++++++++++ .../models/mistral/modeling_mistral.py | 1 - src/transformers/utils/dummy_pt_objects.py | 35 + tests/models/gemma/test_modeling_gemma.py | 75 +- tests/models/gemma2/__init__.py | 0 tests/models/gemma2/test_modeling_gemma2.py | 141 ++ tests/test_modeling_common.py | 2 +- utils/check_config_attributes.py | 1 + 24 files changed, 3057 insertions(+), 69 deletions(-) create mode 100644 docs/source/en/model_doc/gemma2.md create mode 100644 src/transformers/models/gemma2/__init__.py create mode 100644 src/transformers/models/gemma2/configuration_gemma2.py create mode 100644 src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py create mode 100644 src/transformers/models/gemma2/diff_gemma2.py create mode 100644 src/transformers/models/gemma2/modeling_gemma2.py create mode 100644 tests/models/gemma2/__init__.py create mode 100644 tests/models/gemma2/test_modeling_gemma2.py diff --git a/docs/source/en/index.md b/docs/source/en/index.md index bf1fd008692ea5..ac026067ac24b7 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -145,6 +145,7 @@ Flax), PyTorch, and/or TensorFlow. | [Funnel Transformer](model_doc/funnel) | ✅ | ✅ | ❌ | | [Fuyu](model_doc/fuyu) | ✅ | ❌ | ❌ | | [Gemma](model_doc/gemma) | ✅ | ❌ | ✅ | +| [Gemma2](model_doc/gemma2) | ✅ | ❌ | ❌ | | [GIT](model_doc/git) | ✅ | ❌ | ❌ | | [GLPN](model_doc/glpn) | ✅ | ❌ | ❌ | | [GPT Neo](model_doc/gpt_neo) | ✅ | ❌ | ✅ | diff --git a/docs/source/en/model_doc/gemma2.md b/docs/source/en/model_doc/gemma2.md new file mode 100644 index 00000000000000..fa16dfbc4ba0fc --- /dev/null +++ b/docs/source/en/model_doc/gemma2.md @@ -0,0 +1,58 @@ + + + +# Gemma2 + +## Overview + +The Gemma2 model was proposed in [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/Gemma2-open-models/) by Gemma2 Team, Google. +Gemma2 models are trained on 6T tokens, and released with 2 versions, 2b and 7b. + +The abstract from the paper is the following: + +*This work introduces Gemma2, a new family of open language models demonstrating strong performance across academic benchmarks for language understanding, reasoning, and safety. We release two sizes of models (2 billion and 7 billion parameters), and provide both pretrained and fine-tuned checkpoints. Gemma2 outperforms similarly sized open models on 11 out of 18 text-based tasks, and we present comprehensive evaluations of safety and responsibility aspects of the models, alongside a detailed description of our model development. We believe the responsible release of LLMs is critical for improving the safety of frontier models, and for enabling the next wave of LLM innovations* + +Tips: + +- The original checkpoints can be converted using the conversion script `src/transformers/models/Gemma2/convert_Gemma2_weights_to_hf.py` + +This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen](). + + +## Gemma2Config + +[[autodoc]] Gemma2Config + +## Gemma2Model + +[[autodoc]] Gemma2Model + - forward + +## Gemma2ForCausalLM + +[[autodoc]] Gemma2ForCausalLM + - forward + +## Gemma2ForSequenceClassification + +[[autodoc]] Gemma2ForSequenceClassification + - forward + +## Gemma2ForTokenClassification + +[[autodoc]] Gemma2ForTokenClassification + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 922e19915d8040..7b39fd479edf82 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -435,6 +435,7 @@ ], "models.fuyu": ["FuyuConfig"], "models.gemma": ["GemmaConfig"], + "models.gemma2": ["Gemma2Config"], "models.git": [ "GitConfig", "GitProcessor", @@ -2181,6 +2182,15 @@ "GemmaPreTrainedModel", ] ) + _import_structure["models.gemma2"].extend( + [ + "Gemma2ForCausalLM", + "Gemma2ForSequenceClassification", + "Gemma2ForTokenClassification", + "Gemma2Model", + "Gemma2PreTrainedModel", + ] + ) _import_structure["models.git"].extend( [ "GitForCausalLM", @@ -5062,6 +5072,7 @@ ) from .models.fuyu import FuyuConfig from .models.gemma import GemmaConfig + from .models.gemma2 import Gemma2Config from .models.git import ( GitConfig, GitProcessor, @@ -6694,6 +6705,13 @@ GemmaModel, GemmaPreTrainedModel, ) + from .models.gemma2 import ( + Gemma2ForCausalLM, + Gemma2ForSequenceClassification, + Gemma2ForTokenClassification, + Gemma2Model, + Gemma2PreTrainedModel, + ) from .models.git import ( GitForCausalLM, GitModel, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 532b921a3697a3..b167cd1d117085 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -970,3 +970,125 @@ def get_max_length(self) -> Optional[int]: # in theory there is no limit because the sliding window size is fixed # no matter how long the sentence is return None + + +class HybridCache(Cache): + def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None: + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + self.max_cache_len = max_cache_len + self.max_batch_size = max_batch_size + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype if dtype is not None else torch.float32 + self.num_key_value_heads = ( + config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + ) + self.is_sliding = torch.tensor( + [i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device + ) + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim) + sliding_cache_shape = ( + max_batch_size, + self.num_key_value_heads, + min(config.sliding_window, max_cache_len), + self.head_dim, + ) + for i in range(config.num_hidden_layers): + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + if cache_position.shape[0] > max_cache_len: + k_out = key_states[:, :, -max_cache_len:, :] + v_out = value_states[:, :, -max_cache_len:, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, max_cache_len - 1) + to_shift = cache_position >= max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % max_cache_len + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + return k_out, v_out + + def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + sliding_window: Optional[int] = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + if sliding_window: + update_fn = self._sliding_update + else: + update_fn = self._static_update + + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + k_out, + v_out, + k_out.shape[2], + ) + + def get_max_length(self) -> Optional[int]: + # in theory there is no limit because the sliding window size is fixed + # no matter how long the sentence is + return self.max_cache_len + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + return None + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 8bb5e091d6db59..8ba17a6a350f78 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -400,7 +400,7 @@ def __init__(self, **kwargs): # Cache implementation self.cache_implementation = kwargs.pop("cache_implementation", None) self.cache_config = kwargs.pop("cache_config", None) - if self.cache_implementation is not None: + if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG: cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] if self.cache_config is None: self.cache_config = cache_config_class() diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3abd604cdb8611..2686f3af7af3a1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,6 +28,7 @@ Cache, DynamicCache, HQQQuantizedCache, + HybridCache, QuantizedCacheConfig, QuantoQuantizedCache, SlidingWindowCache, @@ -112,7 +113,7 @@ if is_accelerate_available(): from accelerate.hooks import AlignDevicesHook, add_hook_to_module -NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} +NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache} QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} @@ -1395,10 +1396,12 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): past_length = 0 if model_kwargs.get("past_key_values") is not None: - if isinstance(model_kwargs["past_key_values"], Cache): - past_length = model_kwargs["past_key_values"].get_seq_length() - else: - past_length = model_kwargs["past_key_values"][0][0].shape[2] + cache = model_kwargs["past_key_values"] + if not isinstance(cache, Cache): + past_length = cache[0][0].shape[2] + elif hasattr(cache, "get_seq_length"): + past_length = cache.get_seq_length() + if "inputs_embeds" in model_kwargs: cur_len = model_kwargs["inputs_embeds"].shape[1] else: @@ -1739,7 +1742,9 @@ def generate( "issue: https://github.com/huggingface/transformers/issues/28981" ) model_kwargs["past_key_values"] = self._get_cache( - generation_config.cache_implementation, batch_size, generation_config.max_length + generation_config.cache_implementation, + getattr(generation_config, "num_beams", 1) * batch_size, + generation_config.max_length, ) elif generation_config.cache_implementation == "quantized": if not self._supports_quantized_cache: diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 87586686a02669..f4c33491472833 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -92,6 +92,7 @@ funnel, fuyu, gemma, + gemma2, git, glpn, gpt2, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index dab9244dd0171d..7f52b3dc280ac6 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -108,6 +108,7 @@ ("funnel", "FunnelConfig"), ("fuyu", "FuyuConfig"), ("gemma", "GemmaConfig"), + ("gemma2", "Gemma2Config"), ("git", "GitConfig"), ("glpn", "GLPNConfig"), ("gpt-sw3", "GPT2Config"), @@ -385,6 +386,7 @@ ("funnel", "Funnel Transformer"), ("fuyu", "Fuyu"), ("gemma", "Gemma"), + ("gemma2", "Gemma2"), ("git", "GIT"), ("glpn", "GLPN"), ("gpt-sw3", "GPT-Sw3"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7190d75a873c60..f674b777fca7be 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -105,6 +105,7 @@ ("fsmt", "FSMTModel"), ("funnel", ("FunnelModel", "FunnelBaseModel")), ("gemma", "GemmaModel"), + ("gemma2", "Gemma2Model"), ("git", "GitModel"), ("glpn", "GLPNModel"), ("gpt-sw3", "GPT2Model"), @@ -454,6 +455,7 @@ ("falcon", "FalconForCausalLM"), ("fuyu", "FuyuForCausalLM"), ("gemma", "GemmaForCausalLM"), + ("gemma2", "Gemma2ForCausalLM"), ("git", "GitForCausalLM"), ("gpt-sw3", "GPT2LMHeadModel"), ("gpt2", "GPT2LMHeadModel"), @@ -863,6 +865,7 @@ ("fnet", "FNetForSequenceClassification"), ("funnel", "FunnelForSequenceClassification"), ("gemma", "GemmaForSequenceClassification"), + ("gemma2", "Gemma2ForSequenceClassification"), ("gpt-sw3", "GPT2ForSequenceClassification"), ("gpt2", "GPT2ForSequenceClassification"), ("gpt_bigcode", "GPTBigCodeForSequenceClassification"), @@ -1044,6 +1047,7 @@ ("fnet", "FNetForTokenClassification"), ("funnel", "FunnelForTokenClassification"), ("gemma", "GemmaForTokenClassification"), + ("gemma2", "Gemma2ForTokenClassification"), ("gpt-sw3", "GPT2ForTokenClassification"), ("gpt2", "GPT2ForTokenClassification"), ("gpt_bigcode", "GPTBigCodeForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 7dd805ae7f186b..dddab5379f5657 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -188,6 +188,13 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "gemma2", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py index 1165b05483fc82..d1df9d8cfb07d6 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/diff_gemma.py @@ -257,6 +257,7 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) if self.hidden_size % self.num_heads != 0: raise ValueError( @@ -305,7 +306,7 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ed6d61793bcc6f..c0da2530fe2c4e 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -240,6 +240,7 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) if self.hidden_size % self.num_heads != 0: raise ValueError( @@ -288,7 +289,7 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] @@ -898,6 +899,13 @@ def forward( # See https://github.com/huggingface/transformers/pull/29402 normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) hidden_states = hidden_states * normalizer + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + ) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -1397,7 +1405,7 @@ def set_input_embeddings(self, value): @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -1407,7 +1415,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., diff --git a/src/transformers/models/gemma2/__init__.py b/src/transformers/models/gemma2/__init__.py new file mode 100644 index 00000000000000..0d0aa148be5e33 --- /dev/null +++ b/src/transformers/models/gemma2/__init__.py @@ -0,0 +1,61 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_gemma2": ["Gemma2Config"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_gemma2"] = [ + "Gemma2ForCausalLM", + "Gemma2Model", + "Gemma2PreTrainedModel", + "Gemma2ForSequenceClassification", + "Gemma2ForTokenClassification", + ] + +if TYPE_CHECKING: + from .configuration_gemma import Gemma2Config + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_gemma import ( + Gemma2ForCausalLM, + Gemma2ForSequenceClassification, + Gemma2ForTokenClassification, + Gemma2Model, + Gemma2PreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py new file mode 100644 index 00000000000000..47207d7ca12436 --- /dev/null +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -0,0 +1,149 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from transformers import PretrainedConfig + + +class Gemma2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate an Gemma2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma2-7B. + e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Gemma2Model`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. + query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the + size of the sliding window. + ```python + >>> from transformers import Gemma2Model, Gemma2Config + >>> # Initializing a Gemma2 gemma2-9b style configuration + >>> configuration = Gemma2Config() + >>> # Initializing a model from the gemma2-9b style configuration + >>> model = Gemma2Model(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256000, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + final_logit_softcapping=30.0, + query_pre_attn_scalar=224, + sliding_window=4096, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.final_logit_softcapping = final_logit_softcapping + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.cache_implementation = "hybrid" diff --git a/src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py b/src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py new file mode 100644 index 00000000000000..1ad7d23c3c3e3c --- /dev/null +++ b/src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py @@ -0,0 +1,239 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +import warnings + +import torch +from accelerate import init_empty_weights + +from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer + + +try: + from transformers import GemmaTokenizerFast +except ImportError as e: + warnings.warn(e) + warnings.warn( + "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" + ) + GemmaTokenizerFast = None + +""" +Sample usage: + +``` +python src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py \ + --input_dir /path/to/downloaded/gemma/weights --model_size 9B --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import Gemma2ForCausalLM, GemmaTokenizerFast + +model = Gemma2ForCausalLM.from_pretrained("/output/path") +tokenizer = GemmaTokenizerFast.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). +""" + +gemma_9b_config = Gemma2Config( + num_hidden_layers=42, + num_attention_heads=16, + num_key_value_heads=8, + hidden_size=3584, + intermediate_size=14336, + final_logit_softcapping=30.0, + attn_logit_softcapping=50.0, + head_dim=256, + sliding_window=4096, + query_pre_attn_scalar=224, +) + +gemma_27b_config = Gemma2Config( + num_hidden_layers=46, + num_attention_heads=32, + num_key_value_heads=16, + hidden_size=4608, + intermediate_size=36864, + final_logit_softcapping=30.0, + attn_logit_softcapping=50.0, + head_dim=128, + sliding_window=4096, + query_pre_attn_scalar=144, +) + +CONFIG_MAPPING = {"9B": gemma_9b_config, "27B": gemma_27b_config} +LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"} + + +def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False, dtype=torch.float32): + num_attn_heads = config.num_attention_heads + hidden_size = config.hidden_size + num_kv_heads = config.num_key_value_heads + head_dim = config.head_dim + + print(f"Fetching all parameters from the checkpoint at '{input_base_path}'") + + if os.path.isdir(input_base_path): + print("Model seems sharded") + + model_state_dict = {} + files = [file for file in os.listdir(input_base_path) if file.endswith(".bin")] + + for file in files: + print(file) + loaded_state_dict = torch.load(os.path.join(input_base_path, file), map_location="cpu") + model_state_dict.update(loaded_state_dict) + else: + print("Model does not seem to be sharded") + model_state_dict = torch.load(input_base_path, map_location="cpu")["model_state_dict"] + model_state_dict.pop("freqs_cis") + + state_dict = {} + for k, v in model_state_dict.items(): + if "qkv_proj" in k: + if num_kv_heads == 1: + v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim, hidden_size) + q_proj = v[:num_attn_heads, ...] + k_proj = v[num_attn_heads : num_attn_heads + num_kv_heads, ...].repeat(num_kv_heads, 1, 1) + v_proj = v[-num_kv_heads:, ...].repeat(num_kv_heads, 1, 1) + + state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape( + num_attn_heads * head_dim, hidden_size + ).clone() + state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape( + num_kv_heads * head_dim, hidden_size + ).clone() + state_dict[k.replace("qkv_proj", "v_proj")] = v_proj[0].clone() + else: + q_proj, k_proj, v_proj = torch.split( + v, [num_attn_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], 0 + ) + state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape( + num_attn_heads * head_dim, hidden_size + ).clone() + state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape( + num_kv_heads * head_dim, hidden_size + ).clone() + state_dict[k.replace("qkv_proj", "v_proj")] = v_proj.reshape( + num_kv_heads * head_dim, hidden_size + ).clone() + + elif k == "embedder.weight": + state_dict[LAYER_NAME_MAPPING[k]] = v + state_dict["lm_head.weight"] = v + else: + state_dict[k] = v + + torch.set_default_dtype(dtype) + + print("Loading the checkpoint in a Gemma2 model.") + with init_empty_weights(): + model = Gemma2ForCausalLM(config) + model.load_state_dict(state_dict, assign=True, strict=False) + + model.config.torch_dtype = torch.float32 + del model.config._name_or_path + print("Saving in the Transformers format.") + + if push_to_hub: + print(f"pushing the model to {save_path}") + model.push_to_hub(save_path, safe_serialization=safe_serialization, private=True) + else: + model.save_pretrained(save_path, safe_serialization=safe_serialization) + + +def write_tokenizer(input_tokenizer_path, save_path, push_to_hub=False): + # Initialize the tokenizer based on the `spm` model + tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast + print(f"Saving a {tokenizer_class.__name__} to {save_path}.") + tokenizer = tokenizer_class(input_tokenizer_path) + if push_to_hub: + tokenizer.push_to_hub(save_path) + else: + tokenizer.save_pretrained(save_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_checkpoint", + help="Absolute path to the target Gemma2 weights.", + required=True, + ) + parser.add_argument( + "--tokenizer_checkpoint", + help="Location of Gemma2 tokenizer model", + ) + parser.add_argument( + "--model_size", + default="9B", + choices=["9B", "27B", "tokenizer_only"], + help="'f' models correspond to the finetuned versions, and are specific to the Gemma22 official release. For more details on Gemma2, checkout the original repo: https://huggingface.co/google/gemma-7b", + ) + parser.add_argument( + "--output_dir", + default="google/gemma-9b", + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--pickle_serialization", + help="Whether or not to save using `safetensors`.", + action="store_true", + default=False, + ) + parser.add_argument( + "--convert_tokenizer", + help="Whether or not to convert the tokenizer as well.", + action="store_true", + default=False, + ) + parser.add_argument( + "--push_to_hub", + help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.", + action="store_true", + default=False, + ) + parser.add_argument( + "--dtype", + default="float32", + help="Target dtype of the converted model", + ) + args = parser.parse_args() + + if args.convert_tokenizer: + if args.tokenizer_checkpoint is None: + raise ValueError("Path to the tokenizer is required when passing --convert_tokenizer") + + spm_path = os.path.join(args.tokenizer_checkpoint) + write_tokenizer(spm_path, args.output_dir, args.push_to_hub) + if not args.model_size == "tokenizer_only": + config = CONFIG_MAPPING[args.model_size] + dtype = getattr(torch, args.dtype) + write_model( + config=config, + input_base_path=args.input_checkpoint, + save_path=args.output_dir, + safe_serialization=not args.pickle_serialization, + push_to_hub=args.push_to_hub, + dtype=dtype, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py new file mode 100644 index 00000000000000..443be0cf87f5de --- /dev/null +++ b/src/transformers/models/gemma2/diff_gemma2.py @@ -0,0 +1,781 @@ +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from transformers.models.gemma.configuration_gemma import GemmaConfig +from transformers.models.gemma.modeling_gemma import ( + GemmaAttention, + GemmaDecoderLayer, + GemmaForCausalLM, + GemmaForSequenceClassification, + GemmaForTokenClassification, + GemmaModel, + GemmaRMSNorm, + apply_rotary_pos_emb, + repeat_kv, +) + +from ...cache_utils import Cache +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +logger = logging.get_logger(__name__) + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +logger = logging.get_logger(__name__) + + +class Gemma2Config(GemmaConfig): + cache_implementation = "hybrid" # TODO this is not properly ported, but cls attr is better + + def __init__( + self, + query_pre_attn_scalar=224, + sliding_window=4096, + final_logit_softcapping=30.0, + **super_kwargs, + ): + super().__init__(self, **super_kwargs) + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.cache_implementation = "hybrid" + self.final_logit_softcapping = final_logit_softcapping + + +class Gemma2RMSNorm(GemmaRMSNorm): + pass + + +class Gemma2Attention(GemmaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + self.scaling = config.query_pre_attn_scalar**-0.5 + + super().__init__(config, layer_idx) + + +class Gemma2FlashAttention2(Gemma2Attention): + """ + Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (Gemma2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + ########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.scaling, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + cache_position=0, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in Gemma2FlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + use_sliding_windows = ( + _flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window + ) + flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {} + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class Gemma2SdpaAttention(Gemma2Attention): + """ + Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Gemma2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + scale=self.scaling, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class Gemma2DecoderLayer(GemmaDecoderLayer): + def __init__(self, config: Gemma2Config, layer_idx: int): + super().__init__(config, layer_idx) + + self.is_sliding = bool(layer_idx % 2) + self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + attention_mask = attention_mask * torch.tril( + torch.ones_like(attention_mask), diagonal=(self.sliding_window - cache_position[-1]) + ) + if cache_position[0] > 0: + attention_mask = attention_mask[:, -self.sliding_window :] + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Gemma2Model(GemmaModel): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # normalized + # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = past_key_values if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + @torch.no_grad() + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if past_key_values is not None: + target_length = past_key_values.get_max_length() + else: + target_length = attention_mask.shape[-1] + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class Gemma2ForCausalLM(GemmaForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + logits = logits.float() + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else torch.tensor(0, device=input_ids.device) + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + +class Gemma2ForSequenceClassification(GemmaForSequenceClassification): + pass + + +class Gemma2ForTokenClassification(GemmaForTokenClassification): + pass diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py new file mode 100644 index 00000000000000..2f4768e59f46e9 --- /dev/null +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -0,0 +1,1392 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. If any change should be done, please apply the change to the +# diff.py file directly. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_gemma2 import Gemma2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +logger = logging.get_logger(__name__) + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class Gemma2RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + +class Gemma2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Gemma2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Gemma2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.scaling = config.query_pre_attn_scalar**-0.5 + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = Gemma2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + self.sliding_window = config.sliding_window if layer_idx % 2 else None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "sliding_window": self.sliding_window, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Gemma2FlashAttention2(Gemma2Attention): + """ + Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "sliding_window": self.sliding_window, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (Gemma2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.scaling, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + cache_position=0, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in Gemma2FlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # TODO this is not compile compatible + use_sliding_windows = ( + _flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window + ) + flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {} + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class Gemma2SdpaAttention(Gemma2Attention): + """ + Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Gemma2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Gemma2Model is using Gemma2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "sliding_window": self.sliding_window, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + scale=self.scaling, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +GEMMA2_ATTENTION_CLASSES = { + "eager": Gemma2Attention, + "flash_attention_2": Gemma2FlashAttention2, + "sdpa": Gemma2SdpaAttention, +} + + +class Gemma2DecoderLayer(nn.Module): + def __init__(self, config: Gemma2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = Gemma2MLP(config) + self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.is_sliding = bool(layer_idx % 2) + self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + attention_mask = attention_mask * torch.tril( + torch.ones_like(attention_mask), diagonal=-self.sliding_window + ) + if attention_mask.shape[1] <= 1: # when decoding + attention_mask = attention_mask[:, -self.sliding_window :] + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +GEMMA2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Gemma2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.", + GEMMA2_START_DOCSTRING, +) +class Gemma2PreTrainedModel(PreTrainedModel): + config_class = Gemma2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Gemma2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = False + _supports_quantized_cache = False + _supports_static_cache = True + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +_CONFIG_FOR_DOC = "Gemma2Config" + + +GEMMA2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.", + GEMMA2_START_DOCSTRING, +) +class Gemma2Model(Gemma2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`] + + Args: + config: Gemma2Config + """ + + def __init__(self, config: Gemma2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # normalized + # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = past_key_values if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if past_key_values is not None: + target_length = past_key_values.get_max_length() + else: + target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class Gemma2ForCausalLM(Gemma2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Gemma2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma2 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + logits = logits.float() + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else torch.tensor(0, device=input_ids.device) + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The Gemma2 Model transformer with a sequence classification head on top (linear layer). + + [`Gemma2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GEMMA2_START_DOCSTRING, +) +class Gemma2ForSequenceClassification(Gemma2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Gemma2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Gemma2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + GEMMA2_START_DOCSTRING, +) +class Gemma2ForTokenClassification(Gemma2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Gemma2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 07942e87e68c29..475dda72c59295 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -227,7 +227,6 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): base=self.rope_theta, ) - # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Mistral def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 9132d161820b11..c9267debc5de81 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4197,6 +4197,41 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Gemma2ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Gemma2ForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Gemma2ForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Gemma2Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Gemma2PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class GitForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index bdf012774371de..c7fb55f682ed0e 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -47,11 +47,18 @@ GemmaForSequenceClassification, GemmaForTokenClassification, GemmaModel, - GemmaTokenizer, ) +@require_torch class GemmaModelTester: + config_class = GemmaConfig + if is_torch_available(): + model_class = GemmaModel + for_causal_lm_class = GemmaForCausalLM + for_sequence_class = GemmaForSequenceClassification + for_token_class = GemmaForTokenClassification + def __init__( self, parent, @@ -129,9 +136,8 @@ def prepare_config_and_inputs(self): return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - # Ignore copy def get_config(self): - return GemmaConfig( + return self.config_class( vocab_size=self.vocab_size, hidden_size=self.hidden_size, num_hidden_layers=self.num_hidden_layers, @@ -149,18 +155,16 @@ def get_config(self): head_dim=self.head_dim, ) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Gemma def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): - model = GemmaModel(config=config) + model = self.model_class(config=config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=input_mask) result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Gemma def create_and_check_model_as_decoder( self, config, @@ -174,7 +178,7 @@ def create_and_check_model_as_decoder( encoder_attention_mask, ): config.add_cross_attention = True - model = GemmaModel(config) + model = self.model_class(config) model.to(torch_device) model.eval() result = model( @@ -191,7 +195,6 @@ def create_and_check_model_as_decoder( result = model(input_ids, attention_mask=input_mask) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Gemma def create_and_check_for_causal_lm( self, config, @@ -204,13 +207,12 @@ def create_and_check_for_causal_lm( encoder_hidden_states, encoder_attention_mask, ): - model = GemmaForCausalLM(config=config) + model = self.for_causal_lm_class(config=config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=input_mask, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Gemma def create_and_check_decoder_model_past_large_inputs( self, config, @@ -225,7 +227,7 @@ def create_and_check_decoder_model_past_large_inputs( ): config.is_decoder = True config.add_cross_attention = True - model = GemmaForCausalLM(config=config) + model = self.for_causal_lm_class(config=config) model.to(torch_device) model.eval() @@ -348,7 +350,7 @@ def test_Gemma_sequence_classification_model(self): input_ids = input_dict["input_ids"] attention_mask = input_ids.ne(1).to(torch_device) sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = GemmaForSequenceClassification(config) + model = self.model_tester.for_sequence_class(config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) @@ -361,7 +363,7 @@ def test_Gemma_sequence_classification_model_for_single_label(self): input_ids = input_dict["input_ids"] attention_mask = input_ids.ne(1).to(torch_device) sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = GemmaForSequenceClassification(config) + model = self.model_tester.for_sequence_class(config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) @@ -376,20 +378,19 @@ def test_Gemma_sequence_classification_model_for_multi_label(self): sequence_labels = ids_tensor( [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size ).to(torch.float) - model = GemmaForSequenceClassification(config) + model = self.model_tester.for_sequence_class(config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma def test_Gemma_token_classification_model(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3 input_ids = input_dict["input_ids"] attention_mask = input_ids.ne(1).to(torch_device) token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = GemmaForTokenClassification(config=config) + model = self.model_tester.for_token_class(config=config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=attention_mask, labels=token_labels) @@ -539,47 +540,9 @@ def setUpClass(cls): # 8 is for A100 / A10 and 7 for T4 cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] - @require_read_token - def test_model_2b_fp32(self): - model_id = "google/gemma-2b" - EXPECTED_TEXTS = [ - "Hello I am doing a project on the 1990s and I need to know what the most popular music", - "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", - ] - - model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(torch_device) - - tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) - - output = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_text = tokenizer.batch_decode(output, skip_special_tokens=True) - - self.assertEqual(output_text, EXPECTED_TEXTS) - @require_read_token def test_model_2b_fp16(self): - model_id = "google/gemma-2b" - EXPECTED_TEXTS = [ - "Hello I am doing a project on the 1990s and I need to know what the most popular music", - "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", - ] - - model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to( - torch_device - ) - - tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) - - output = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_text = tokenizer.batch_decode(output, skip_special_tokens=True) - - self.assertEqual(output_text, EXPECTED_TEXTS) - - @require_read_token - def test_model_2b_fp16_static_cache(self): - model_id = "google/gemma-2b" + model_id = "google/gemma-2-9b" EXPECTED_TEXTS = [ "Hello I am doing a project on the 1990s and I need to know what the most popular music", "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", @@ -903,7 +866,7 @@ def test_compile_static_cache(self): } prompts = ["Hello I am doing", "Hi today"] - tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", pad_token="", padding_side="right") + tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="", padding_side="right") model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) diff --git a/tests/models/gemma2/__init__.py b/tests/models/gemma2/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py new file mode 100644 index 00000000000000..6a6c5688d5d68c --- /dev/null +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Gemma2 model.""" + +import unittest + +from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available +from transformers.testing_utils import ( + require_read_token, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester +from ...test_configuration_common import ConfigTester + + +if is_torch_available(): + import torch + + from transformers import ( + Gemma2ForCausalLM, + Gemma2ForSequenceClassification, + Gemma2ForTokenClassification, + Gemma2Model, + ) + + +class Gemma2ModelTester(GemmaModelTester): + config_class = Gemma2Config + model_class = Gemma2Model + for_causal_lm_class = Gemma2ForCausalLM + for_sequence_class = Gemma2ForSequenceClassification + for_token_class = Gemma2ForTokenClassification + + +@require_torch +class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): + all_model_classes = ( + (Gemma2Model, Gemma2ForCausalLM, Gemma2ForSequenceClassification, Gemma2ForTokenClassification) + if is_torch_available() + else () + ) + all_generative_model_classes = () + pipeline_model_mapping = ( + { + "feature-extraction": Gemma2Model, + "text-classification": Gemma2ForSequenceClassification, + "token-classification": Gemma2ForTokenClassification, + "text-generation": Gemma2ForCausalLM, + "zero-shot": Gemma2ForSequenceClassification, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + _is_stateful = True + model_split_percents = [0.5, 0.6] + _torch_compile_test_ckpt = "google/gemma-2-9b" + + def setUp(self): + self.model_tester = Gemma2ModelTester(self) + self.config_tester = ConfigTester(self, config_class=Gemma2Config, hidden_size=37) + + @unittest.skip("Eager and SDPA do not produce the same outputs, thus this test fails") + def test_model_outputs_equivalence(self, **kwargs): + pass + + @unittest.skip("Gemma2's outputs are expected to be different") + def test_eager_matches_sdpa_inference(self): + pass + + +@slow +@require_torch_gpu +class Gemma2IntegrationTest(unittest.TestCase): + input_text = ["Hello I am doing", "Hi today"] + # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # Depending on the hardware we get different logits / generations + cuda_compute_capability_major_version = None + + @classmethod + def setUpClass(cls): + if is_torch_available() and torch.cuda.is_available(): + # 8 is for A100 / A10 and 7 for T4 + cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + + @require_read_token + def test_model_2b_bf16(self): + model_id = "google/gemma-2-9b" + EXPECTED_TEXTS = [ + "Hello I am doing a project for a class and I am trying to use the ", + "Hi today. So, I'm going to show you how to do a problem from the textbook. So", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to( + torch_device + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + @require_read_token + def test_model_2b_fp16(self): + model_id = "google/gemma-2-9b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the effect of the temperature on the rate of a reaction. I am using a ", + "Hi today I'm going to be talking about the 1000-4000-", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to( + torch_device + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 11c34462ba5d2c..4561c93c21db31 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -505,7 +505,7 @@ def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path): # Check that the parameters are equal. for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()): - self.assertEquals(p1.data.ne(p2.data).sum(), 0) + self.assertEqual(p1.data.ne(p2.data).sum(), 0) # Check that the state dict keys are equal. self.assertEqual(set(model_low_usage.state_dict().keys()), set(model_non_low_usage.state_dict().keys())) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index e6edcf517a0936..91113717610a9b 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -41,6 +41,7 @@ "expert_layer_offset", "expert_layer_period", ], + "Gemma2Config": ["tie_word_embeddings"], # used to compute the property `self.chunk_length` "EncodecConfig": ["overlap"], # used to compute the property `self.layers_block_type` From 727eea4ab0aa50b1b79df3de37640fd10010fd71 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 27 Jun 2024 17:40:07 +0200 Subject: [PATCH 07/18] v4.43.0.dev0 --- examples/flax/question-answering/run_qa.py | 2 +- .../speech-recognition/run_flax_speech_recognition_seq2seq.py | 2 +- examples/flax/text-classification/run_flax_glue.py | 2 +- examples/flax/token-classification/run_flax_ner.py | 2 +- .../pytorch/audio-classification/run_audio_classification.py | 2 +- examples/pytorch/contrastive-image-text/run_clip.py | 2 +- .../pytorch/image-classification/run_image_classification.py | 2 +- .../run_image_classification_no_trainer.py | 2 +- examples/pytorch/image-pretraining/run_mae.py | 2 +- examples/pytorch/image-pretraining/run_mim.py | 2 +- examples/pytorch/image-pretraining/run_mim_no_trainer.py | 2 +- .../pytorch/instance-segmentation/run_instance_segmentation.py | 3 ++- .../run_instance_segmentation_no_trainer.py | 3 ++- examples/pytorch/language-modeling/run_clm.py | 2 +- examples/pytorch/language-modeling/run_clm_no_trainer.py | 2 +- examples/pytorch/language-modeling/run_fim.py | 2 +- examples/pytorch/language-modeling/run_fim_no_trainer.py | 2 +- examples/pytorch/language-modeling/run_mlm.py | 2 +- examples/pytorch/language-modeling/run_mlm_no_trainer.py | 2 +- examples/pytorch/language-modeling/run_plm.py | 2 +- examples/pytorch/multiple-choice/run_swag.py | 2 +- examples/pytorch/multiple-choice/run_swag_no_trainer.py | 2 +- examples/pytorch/object-detection/run_object_detection.py | 2 +- .../object-detection/run_object_detection_no_trainer.py | 2 +- examples/pytorch/question-answering/run_qa.py | 2 +- examples/pytorch/question-answering/run_qa_beam_search.py | 2 +- .../question-answering/run_qa_beam_search_no_trainer.py | 2 +- examples/pytorch/question-answering/run_qa_no_trainer.py | 2 +- examples/pytorch/question-answering/run_seq2seq_qa.py | 2 +- .../pytorch/semantic-segmentation/run_semantic_segmentation.py | 2 +- .../run_semantic_segmentation_no_trainer.py | 2 +- .../pytorch/speech-recognition/run_speech_recognition_ctc.py | 2 +- .../speech-recognition/run_speech_recognition_ctc_adapter.py | 2 +- .../speech-recognition/run_speech_recognition_seq2seq.py | 2 +- examples/pytorch/summarization/run_summarization.py | 2 +- examples/pytorch/summarization/run_summarization_no_trainer.py | 2 +- examples/pytorch/text-classification/run_classification.py | 2 +- examples/pytorch/text-classification/run_glue.py | 2 +- examples/pytorch/text-classification/run_glue_no_trainer.py | 2 +- examples/pytorch/text-classification/run_xnli.py | 2 +- examples/pytorch/token-classification/run_ner.py | 2 +- examples/pytorch/token-classification/run_ner_no_trainer.py | 2 +- examples/pytorch/translation/run_translation.py | 2 +- examples/pytorch/translation/run_translation_no_trainer.py | 2 +- examples/tensorflow/contrastive-image-text/run_clip.py | 2 +- .../image-classification/run_image_classification.py | 2 +- examples/tensorflow/multiple-choice/run_swag.py | 2 +- examples/tensorflow/question-answering/run_qa.py | 2 +- examples/tensorflow/summarization/run_summarization.py | 2 +- examples/tensorflow/text-classification/run_glue.py | 2 +- examples/tensorflow/translation/run_translation.py | 2 +- setup.py | 2 +- src/transformers/__init__.py | 2 +- 53 files changed, 55 insertions(+), 53 deletions(-) diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index f80cd8a0341462..1819b0235fafc9 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -61,7 +61,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") Array = Any Dataset = datasets.arrow_dataset.Dataset diff --git a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py index d911797cb9f26d..501aaf18642400 100644 --- a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py +++ b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py @@ -60,7 +60,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risk. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt") diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index d1234db015dc5b..4494243c9a5d2b 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -56,7 +56,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") Array = Any Dataset = datasets.arrow_dataset.Dataset diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index 51f66777cd8ac2..6eb162adcb0784 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -57,7 +57,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") diff --git a/examples/pytorch/audio-classification/run_audio_classification.py b/examples/pytorch/audio-classification/run_audio_classification.py index 269199a5b708f5..3c75f0b1504d19 100644 --- a/examples/pytorch/audio-classification/run_audio_classification.py +++ b/examples/pytorch/audio-classification/run_audio_classification.py @@ -45,7 +45,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") diff --git a/examples/pytorch/contrastive-image-text/run_clip.py b/examples/pytorch/contrastive-image-text/run_clip.py index 726ce3c4421c41..fed4d0bf6fab56 100644 --- a/examples/pytorch/contrastive-image-text/run_clip.py +++ b/examples/pytorch/contrastive-image-text/run_clip.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt") diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py index 63c1a0a76005bc..b7557b903fdf06 100755 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -56,7 +56,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py index 2d0dd070e86a10..e67424f6819ca9 100644 --- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py +++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py @@ -49,7 +49,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) diff --git a/examples/pytorch/image-pretraining/run_mae.py b/examples/pytorch/image-pretraining/run_mae.py index dd94c3e5104e47..8af6e18b1ca37c 100644 --- a/examples/pytorch/image-pretraining/run_mae.py +++ b/examples/pytorch/image-pretraining/run_mae.py @@ -43,7 +43,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") diff --git a/examples/pytorch/image-pretraining/run_mim.py b/examples/pytorch/image-pretraining/run_mim.py index ce90aeb75c09eb..c2c3ff818b5b6b 100644 --- a/examples/pytorch/image-pretraining/run_mim.py +++ b/examples/pytorch/image-pretraining/run_mim.py @@ -48,7 +48,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") diff --git a/examples/pytorch/image-pretraining/run_mim_no_trainer.py b/examples/pytorch/image-pretraining/run_mim_no_trainer.py index 0008f2bd7fccaf..e3efbec76c4419 100644 --- a/examples/pytorch/image-pretraining/run_mim_no_trainer.py +++ b/examples/pytorch/image-pretraining/run_mim_no_trainer.py @@ -53,7 +53,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt") diff --git a/examples/pytorch/instance-segmentation/run_instance_segmentation.py b/examples/pytorch/instance-segmentation/run_instance_segmentation.py index e8d7ee04891e24..9a29e43d7d304c 100644 --- a/examples/pytorch/instance-segmentation/run_instance_segmentation.py +++ b/examples/pytorch/instance-segmentation/run_instance_segmentation.py @@ -46,7 +46,8 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") + require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt") diff --git a/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py b/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py index 7c0eb31068bd42..8f57997deacbc7 100644 --- a/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py +++ b/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py @@ -52,7 +52,8 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") + require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt") diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index c0db5703722ef4..cf80ae83ab2e0a 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -55,7 +55,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py index f93935f406ea06..7ef8d94f3e3cfb 100755 --- a/examples/pytorch/language-modeling/run_clm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py @@ -57,7 +57,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) diff --git a/examples/pytorch/language-modeling/run_fim.py b/examples/pytorch/language-modeling/run_fim.py index fa7ebbfd747e03..7154f1ffcd71e5 100644 --- a/examples/pytorch/language-modeling/run_fim.py +++ b/examples/pytorch/language-modeling/run_fim.py @@ -58,7 +58,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/pytorch/language-modeling/run_fim_no_trainer.py b/examples/pytorch/language-modeling/run_fim_no_trainer.py index 23de80fc829e4b..11c64c7c4849ef 100644 --- a/examples/pytorch/language-modeling/run_fim_no_trainer.py +++ b/examples/pytorch/language-modeling/run_fim_no_trainer.py @@ -60,7 +60,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 3485b9ca1b083c..f40015d9701dc1 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -54,7 +54,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py index 42cb0008e7b298..75c3b1936fd024 100755 --- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py @@ -57,7 +57,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py index b51f1acbf376c8..33dc8baaa6e93f 100755 --- a/examples/pytorch/language-modeling/run_plm.py +++ b/examples/pytorch/language-modeling/run_plm.py @@ -47,7 +47,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") diff --git a/examples/pytorch/multiple-choice/run_swag.py b/examples/pytorch/multiple-choice/run_swag.py index 0e9b1390664b68..51c13458ed44c0 100755 --- a/examples/pytorch/multiple-choice/run_swag.py +++ b/examples/pytorch/multiple-choice/run_swag.py @@ -47,7 +47,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/pytorch/multiple-choice/run_swag_no_trainer.py b/examples/pytorch/multiple-choice/run_swag_no_trainer.py index bf293bb190ca6d..8356493762ed7c 100755 --- a/examples/pytorch/multiple-choice/run_swag_no_trainer.py +++ b/examples/pytorch/multiple-choice/run_swag_no_trainer.py @@ -56,7 +56,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) # You should update this to your particular problem to have better documentation of `model_type` diff --git a/examples/pytorch/object-detection/run_object_detection.py b/examples/pytorch/object-detection/run_object_detection.py index e82d902913996e..3f1eb681df225a 100644 --- a/examples/pytorch/object-detection/run_object_detection.py +++ b/examples/pytorch/object-detection/run_object_detection.py @@ -48,7 +48,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt") diff --git a/examples/pytorch/object-detection/run_object_detection_no_trainer.py b/examples/pytorch/object-detection/run_object_detection_no_trainer.py index f79c9ddf9bbda9..296f045a5234e3 100644 --- a/examples/pytorch/object-detection/run_object_detection_no_trainer.py +++ b/examples/pytorch/object-detection/run_object_detection_no_trainer.py @@ -51,7 +51,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logging.basicConfig(level=logging.INFO) logger = get_logger(__name__) diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index 6e5ddbf0810742..83eda3e98a75ea 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -50,7 +50,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/pytorch/question-answering/run_qa_beam_search.py b/examples/pytorch/question-answering/run_qa_beam_search.py index e528a1bc53b19a..4ba3564d6e505e 100755 --- a/examples/pytorch/question-answering/run_qa_beam_search.py +++ b/examples/pytorch/question-answering/run_qa_beam_search.py @@ -48,7 +48,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py index b05f0c6f503d88..db6256aedbacfd 100644 --- a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py +++ b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py @@ -56,7 +56,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/pytorch/question-answering/run_qa_no_trainer.py b/examples/pytorch/question-answering/run_qa_no_trainer.py index c0f98ce2331475..202cc7d661db87 100755 --- a/examples/pytorch/question-answering/run_qa_no_trainer.py +++ b/examples/pytorch/question-answering/run_qa_no_trainer.py @@ -57,7 +57,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/pytorch/question-answering/run_seq2seq_qa.py b/examples/pytorch/question-answering/run_seq2seq_qa.py index fffd643a650918..1932df9677ce6c 100644 --- a/examples/pytorch/question-answering/run_seq2seq_qa.py +++ b/examples/pytorch/question-answering/run_seq2seq_qa.py @@ -46,7 +46,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py index 6aa66b9c48a8cb..a5929205b2f5fe 100644 --- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py +++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt") diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py index 4aa26f9aab1914..5ff906c22cba16 100644 --- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py +++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py @@ -50,7 +50,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py index ac92612ea8ff13..b4019d64f774ce 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py @@ -50,7 +50,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc_adapter.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc_adapter.py index 9a93910739e4ee..3da281430ec272 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc_adapter.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc_adapter.py @@ -53,7 +53,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py index 666bf075f29e24..501b2df1c5eb6a 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py @@ -48,7 +48,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt") diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index 7b9a126bda90d5..225b20fcd63572 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -52,7 +52,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py index 2af52b935a65c8..325cbebd9634b0 100644 --- a/examples/pytorch/summarization/run_summarization_no_trainer.py +++ b/examples/pytorch/summarization/run_summarization_no_trainer.py @@ -56,7 +56,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") diff --git a/examples/pytorch/text-classification/run_classification.py b/examples/pytorch/text-classification/run_classification.py index ff05b78cb538ec..b4520d6af340eb 100755 --- a/examples/pytorch/text-classification/run_classification.py +++ b/examples/pytorch/text-classification/run_classification.py @@ -47,7 +47,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index b566c45215183f..967b6c7be5bf28 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -48,7 +48,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py index 6a8123f076decd..e66b6a35150f3d 100644 --- a/examples/pytorch/text-classification/run_glue_no_trainer.py +++ b/examples/pytorch/text-classification/run_glue_no_trainer.py @@ -49,7 +49,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) diff --git a/examples/pytorch/text-classification/run_xnli.py b/examples/pytorch/text-classification/run_xnli.py index 127f06e0f67f57..7fa0d3bc9005bf 100755 --- a/examples/pytorch/text-classification/run_xnli.py +++ b/examples/pytorch/text-classification/run_xnli.py @@ -48,7 +48,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index c9bb86588fd645..9229b50063c842 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -49,7 +49,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py index b7d3e3df0078bb..a7d67942979dcc 100755 --- a/examples/pytorch/token-classification/run_ner_no_trainer.py +++ b/examples/pytorch/token-classification/run_ner_no_trainer.py @@ -56,7 +56,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index 35a2ab0ef23c19..7ccfc30802f5f0 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -52,7 +52,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") diff --git a/examples/pytorch/translation/run_translation_no_trainer.py b/examples/pytorch/translation/run_translation_no_trainer.py index f30d12a77e8793..e2e9f3d3d3daa0 100644 --- a/examples/pytorch/translation/run_translation_no_trainer.py +++ b/examples/pytorch/translation/run_translation_no_trainer.py @@ -57,7 +57,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = get_logger(__name__) require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") diff --git a/examples/tensorflow/contrastive-image-text/run_clip.py b/examples/tensorflow/contrastive-image-text/run_clip.py index 786e9800007d73..0644ab25bafc69 100644 --- a/examples/tensorflow/contrastive-image-text/run_clip.py +++ b/examples/tensorflow/contrastive-image-text/run_clip.py @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version( "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt" diff --git a/examples/tensorflow/image-classification/run_image_classification.py b/examples/tensorflow/image-classification/run_image_classification.py index 1cdb6ef2950138..a3ea7cf0b10d99 100644 --- a/examples/tensorflow/image-classification/run_image_classification.py +++ b/examples/tensorflow/image-classification/run_image_classification.py @@ -55,7 +55,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") diff --git a/examples/tensorflow/multiple-choice/run_swag.py b/examples/tensorflow/multiple-choice/run_swag.py index 02c55bc771a2b6..7d8189b087efd8 100644 --- a/examples/tensorflow/multiple-choice/run_swag.py +++ b/examples/tensorflow/multiple-choice/run_swag.py @@ -50,7 +50,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/tensorflow/question-answering/run_qa.py b/examples/tensorflow/question-answering/run_qa.py index 821c8529e54563..d9758c401826b7 100755 --- a/examples/tensorflow/question-answering/run_qa.py +++ b/examples/tensorflow/question-answering/run_qa.py @@ -62,7 +62,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/tensorflow/summarization/run_summarization.py b/examples/tensorflow/summarization/run_summarization.py index 8aaa033c1bc324..3f5190186dc3f5 100644 --- a/examples/tensorflow/summarization/run_summarization.py +++ b/examples/tensorflow/summarization/run_summarization.py @@ -53,7 +53,7 @@ # region Checking dependencies # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") diff --git a/examples/tensorflow/text-classification/run_glue.py b/examples/tensorflow/text-classification/run_glue.py index 9f3893e8873452..e2eb2157fa9e65 100644 --- a/examples/tensorflow/text-classification/run_glue.py +++ b/examples/tensorflow/text-classification/run_glue.py @@ -47,7 +47,7 @@ # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") task_to_keys = { "cola": ("sentence", None), diff --git a/examples/tensorflow/translation/run_translation.py b/examples/tensorflow/translation/run_translation.py index f183657e49a143..7280f7a95b37a5 100644 --- a/examples/tensorflow/translation/run_translation.py +++ b/examples/tensorflow/translation/run_translation.py @@ -56,7 +56,7 @@ # region Dependencies and constants # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.42.0.dev0") +check_min_version("4.43.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") diff --git a/setup.py b/setup.py index f438dd8225a4b9..8f28f0f9e4adb7 100644 --- a/setup.py +++ b/setup.py @@ -430,7 +430,7 @@ def run(self): setup( name="transformers", - version="4.42.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="4.43.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)", author_email="transformers@huggingface.co", description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow", diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7b39fd479edf82..c559ed61acad03 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -18,7 +18,7 @@ # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names # in the namespace without actually importing anything (and especially none of the backends). -__version__ = "4.42.0.dev0" +__version__ = "4.43.0.dev0" from typing import TYPE_CHECKING From 75a6319864225b8350c31b623ea2c73c23012a40 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:51:42 +0200 Subject: [PATCH 08/18] Fix post gemma merge (#31660) * nit * toctree issue * protect gemma2 tests as well * sdpa supported --- docs/source/en/_toctree.yml | 2 ++ docs/source/en/perf_infer_gpu_one.md | 2 ++ tests/models/gemma2/test_modeling_gemma2.py | 11 ++++++----- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 94f5d8d19e6f38..e48378d8c25377 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -382,6 +382,8 @@ title: Fuyu - local: model_doc/gemma title: Gemma + - local: model_doc/gemma2 + title: Gemma2 - local: model_doc/openai-gpt title: GPT - local: model_doc/gpt_neo diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index add92a9440c2b8..1569bef1f6ba1f 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -43,6 +43,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) +* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel) @@ -202,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) +* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 6a6c5688d5d68c..870265f9460f7b 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -41,11 +41,12 @@ class Gemma2ModelTester(GemmaModelTester): - config_class = Gemma2Config - model_class = Gemma2Model - for_causal_lm_class = Gemma2ForCausalLM - for_sequence_class = Gemma2ForSequenceClassification - for_token_class = Gemma2ForTokenClassification + if is_torch_available(): + config_class = Gemma2Config + model_class = Gemma2Model + for_causal_lm_class = Gemma2ForCausalLM + for_sequence_class = Gemma2ForSequenceClassification + for_token_class = Gemma2ForTokenClassification @require_torch From e44b878c0252ac1c841afcd68dd873c7fe307289 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Fri, 28 Jun 2024 01:07:33 +0800 Subject: [PATCH 09/18] Fix float out of range in owlvit and owlv2 when using FP16 or lower precision (#31657) --- src/transformers/models/owlv2/modeling_owlv2.py | 2 +- src/transformers/models/owlvit/modeling_owlvit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 05c5cd4595b5df..638a9d966e0c7f 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -1276,7 +1276,7 @@ def forward( if query_mask.ndim > 1: query_mask = torch.unsqueeze(query_mask, dim=-2) - pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) pred_logits = pred_logits.to(torch.float32) return (pred_logits, image_class_embeds) diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index ee6d8aa423d1cf..32e2012b214646 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -1257,7 +1257,7 @@ def forward( if query_mask.ndim > 1: query_mask = torch.unsqueeze(query_mask, dim=-2) - pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) pred_logits = pred_logits.to(torch.float32) return (pred_logits, image_class_embeds) From 464aa74659b9711d2d64159ab82e1c49fa739fb7 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Thu, 27 Jun 2024 10:32:51 -0700 Subject: [PATCH 10/18] [docs] Llama3 (#31662) quick usage to top --- docs/source/en/model_doc/llama3.md | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/docs/source/en/model_doc/llama3.md b/docs/source/en/model_doc/llama3.md index 067d2e9ba934d5..996be7f5c10b7d 100644 --- a/docs/source/en/model_doc/llama3.md +++ b/docs/source/en/model_doc/llama3.md @@ -16,6 +16,15 @@ rendered properly in your Markdown viewer. # Llama3 +```py3 +import transformers +import torch + +model_id = "meta-llama/Meta-Llama-3-8B" + +pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto") +pipeline("Hey how are you doing today?") +``` ## Overview @@ -66,20 +75,7 @@ model = AutoModelForCausalLM.from_pretrained("/output/path") Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even if the biggest versions come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). For the 75B model, it's thus 145GB of RAM needed. - - When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, it is simply specifying either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because the Flash Attention only support `fp16` and `bf16` data type. -## Quick usage - -```py3 -import transformers -import torch - -model_id = "meta-llama/Meta-Llama-3-8B" - -pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto") -pipeline("Hey how are you doing today?") -``` - ## Resources -A ton of cool resources are already available on the documentation page of [~llama2], inviting contributors to add new resources curated for Llama3 here! 🤗 +A ton of cool resources are already available on the documentation page of [Llama2](./llama2), inviting contributors to add new resources curated for Llama3 here! 🤗 From 1c68f2cafb4ca54562f74b66d1085b68dd6682f5 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Thu, 27 Jun 2024 18:40:40 +0100 Subject: [PATCH 11/18] [HybridCache] Fix `get_seq_length` method (#31661) * fix gemma2 * handle in generate --- src/transformers/cache_utils.py | 2 +- src/transformers/generation/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index b167cd1d117085..d572b8c8c71636 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1083,7 +1083,7 @@ def get_max_length(self) -> Optional[int]: # no matter how long the sentence is return self.max_cache_len - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + def get_seq_length(self, layer_idx: Optional[int] = 0): return None def reset(self): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2686f3af7af3a1..9c69bb35d264fe 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1399,7 +1399,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): cache = model_kwargs["past_key_values"] if not isinstance(cache, Cache): past_length = cache[0][0].shape[2] - elif hasattr(cache, "get_seq_length"): + elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: past_length = cache.get_seq_length() if "inputs_embeds" in model_kwargs: From 0142aab7f8546bc5617832740c560ab1caf2186c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 27 Jun 2024 22:59:54 -0700 Subject: [PATCH 12/18] don't zero out the attention_mask when using sliding window with flash attention (#31670) * don't zero out the attention_mask when using sliding window with flash attention * chore: lint --- src/transformers/models/gemma2/modeling_gemma2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 2f4768e59f46e9..09ce72c8b1b231 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -602,6 +602,7 @@ def forward( class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) @@ -625,7 +626,9 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + if ( + self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None + ): # efficient SDPA and no padding attention_mask = attention_mask * torch.tril( torch.ones_like(attention_mask), diagonal=-self.sliding_window ) From 5e89b335ab088b52e5dbb50c5368d944e9052ed4 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Fri, 28 Jun 2024 14:20:30 +0800 Subject: [PATCH 13/18] Fix Gemma2 4d attention mask (#31674) Update modeling_gemma2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/gemma2/modeling_gemma2.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 09ce72c8b1b231..24e60eddba221b 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -629,11 +629,13 @@ def forward( if ( self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None ): # efficient SDPA and no padding - attention_mask = attention_mask * torch.tril( - torch.ones_like(attention_mask), diagonal=-self.sliding_window + min_dtype = torch.finfo(hidden_states.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window ) - if attention_mask.shape[1] <= 1: # when decoding - attention_mask = attention_mask[:, -self.sliding_window :] + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + if attention_mask.shape[-1] <= 1: # when decoding + attention_mask = attention_mask[:, :, :, -self.sliding_window :] residual = hidden_states From 82a1fc7256bf27f83aec3a93543b6d156add09cf Mon Sep 17 00:00:00 2001 From: Jacky Lee <39754370+jla524@users.noreply.github.com> Date: Fri, 28 Jun 2024 04:18:01 -0700 Subject: [PATCH 14/18] Fix return_dict in encodec (#31646) * fix: use return_dict parameter * fix: type checks * fix: unused imports * update: one-line if else * remove: recursive check --- .../models/encodec/modeling_encodec.py | 4 +- tests/models/encodec/test_modeling_encodec.py | 41 +++++++------------ 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index 9627742b9eee6b..f325a6adbe6c1a 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -729,7 +729,7 @@ def decode( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - return_dict = return_dict or self.config.return_dict + return_dict = return_dict if return_dict is not None else self.config.return_dict chunk_length = self.config.chunk_length if chunk_length is None: @@ -786,7 +786,7 @@ def forward( >>> audio_codes = outputs.audio_codes >>> audio_values = outputs.audio_values ```""" - return_dict = return_dict or self.config.return_dict + return_dict = return_dict if return_dict is not None else self.config.return_dict if padding_mask is None: padding_mask = torch.ones_like(input_values).bool() diff --git a/tests/models/encodec/test_modeling_encodec.py b/tests/models/encodec/test_modeling_encodec.py index e4f66d85641bed..0a023894d8a00a 100644 --- a/tests/models/encodec/test_modeling_encodec.py +++ b/tests/models/encodec/test_modeling_encodec.py @@ -19,7 +19,6 @@ import os import tempfile import unittest -from typing import Dict, List, Tuple import numpy as np from datasets import Audio, load_dataset @@ -385,31 +384,21 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs) - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip( - tuple_object.values(), dict_object.values() - ): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - torch.allclose( - set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 - ), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" - f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." - ), - ) - - recursive_check(tuple_output, dict_output) + self.assertTrue(isinstance(tuple_output, tuple)) + self.assertTrue(isinstance(dict_output, dict)) + + for tuple_value, dict_value in zip(tuple_output, dict_output.values()): + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:" + f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has" + f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}." + ), + ) for model_class in self.all_model_classes: model = model_class(config) From cb298978ade3f1edb0ffd02ee079a69f08917a2a Mon Sep 17 00:00:00 2001 From: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Date: Fri, 28 Jun 2024 21:50:27 +0900 Subject: [PATCH 15/18] add gather_use_object arguments (#31514) * add gather_use_object arguments * fix name and pass the CI test for Seq2SeqTrainer * make style * make it to functools * fix typo * add accelerate version: * adding warning * Update src/transformers/trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * make style * Update src/transformers/training_args.py * check function move to initial part * add test for eval_use_gather_object --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/trainer.py | 5 +++++ src/transformers/training_args.py | 18 +++++++++++++++++- tests/trainer/test_trainer.py | 12 ++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 71c3ee43af2c9c..affc7b725e8a70 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -4605,6 +4605,11 @@ def create_accelerator_and_postprocess(self): # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag self.gather_function = self.accelerator.gather_for_metrics + if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): + self.gather_function = functools.partial( + self.gather_function, use_gather_object=self.args.eval_use_gather_object + ) + # deepspeed and accelerate flags covering both trainer args and accelerate launcher self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 41a9607e312105..5eff032774e203 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -773,8 +773,11 @@ class TrainingArguments: that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global summary statistics from the batch-level summary statistics you've accumulated over the evaluation set. - eval_on_start(`bool`, *optional*, defaults to `False`): + eval_on_start (`bool`, *optional*, defaults to `False`): Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly. + + eval_use_gather_object (`bool`, *optional*, defaults to `False`): + Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. """ framework = "pt" @@ -1465,6 +1468,13 @@ class TrainingArguments: }, ) + eval_use_gather_object: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices." + }, + ) + def __post_init__(self): # Parse in args that could be `dict` sent in from the CLI as a string for field in _VALID_DICT_FIELDS: @@ -1992,6 +2002,12 @@ def __post_init__(self): FutureWarning, ) + if self.eval_use_gather_object and not is_accelerate_available("0.30.0"): + raise ValueError( + "--eval_use_gather_object requires Accelerate to be version of `accelerate` < 0.30.0." + "This is not supported and we recommend you to update your version." + ) + def __str__(self): self_as_dict = asdict(self) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2c6793e39ca5a5..26fa4624674ec5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -132,6 +132,7 @@ # for version specific tests in TrainerIntegrationTest require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28") +require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30") GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28") if is_accelerate_available(): from accelerate import Accelerator @@ -3565,6 +3566,17 @@ class TorchDtypeTrainingArguments(TrainingArguments): self.assertIn("torch_dtype", args_dict) self.assertEqual(args_dict["torch_dtype"], dtype) + @require_accelerate_version_min_0_30 + def test_eval_use_gather_object(self): + train_dataset = RegressionDataset() + eval_dataset = RegressionDataset() + model = RegressionDictModel() + args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True) + trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + trainer.train() + _ = trainer.evaluate() + _ = trainer.predict(eval_dataset) + @require_torch @is_staging_test From bbf1e618642946f81de903d5bf2944a8b1f45df4 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 28 Jun 2024 17:16:17 +0200 Subject: [PATCH 16/18] Gemma capping is a must for big models (#31698) * softcapping * soft cap before the mask * style * ... * super nit --- src/transformers/models/gemma2/configuration_gemma2.py | 3 +++ src/transformers/models/gemma2/modeling_gemma2.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 47207d7ca12436..7da541207bfe76 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig): attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window. @@ -116,6 +117,7 @@ def __init__( attention_bias=False, attention_dropout=0.0, final_logit_softcapping=30.0, + attn_logit_softcapping=50.0, query_pre_attn_scalar=224, sliding_window=4096, **kwargs, @@ -135,6 +137,7 @@ def __init__( self.rope_theta = rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout + self.attn_logit_softcapping = attn_logit_softcapping super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 24e60eddba221b..6b2b47b5159e28 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -256,6 +256,11 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + if self.config.attn_logit_softcapping is not None: + attn_weights = attn_weights / self.config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * self.config.attn_logit_softcapping + if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask From e65502951593a76844e872fee9c56b805598538a Mon Sep 17 00:00:00 2001 From: Jade Choghari <78852495+jadechoghari@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:02:30 -0400 Subject: [PATCH 17/18] Add French version of run scripts tutorial (#31483) * Add French translation of run scripts tutorial * Update docs/source/fr/run_scripts_fr.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/fr/run_scripts_fr.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/fr/run_scripts_fr.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/fr/run_scripts_fr.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/fr/run_scripts_fr.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Jade Choghari Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/fr/_toctree.yml | 2 +- docs/source/fr/run_scripts_fr.md | 355 +++++++++++++++++++++++++++++++ 2 files changed, 356 insertions(+), 1 deletion(-) create mode 100644 docs/source/fr/run_scripts_fr.md diff --git a/docs/source/fr/_toctree.yml b/docs/source/fr/_toctree.yml index 93597ed1850f01..8f1e1046b0260d 100755 --- a/docs/source/fr/_toctree.yml +++ b/docs/source/fr/_toctree.yml @@ -15,7 +15,7 @@ title: Préparation des données - local: in_translation title: Fine-tune un modèle pré-entraîné - - local: in_translation + - local: run_scripts_fr title: Entraînement avec un script - local: in_translation title: Entraînement distribué avec 🤗 Accelerate diff --git a/docs/source/fr/run_scripts_fr.md b/docs/source/fr/run_scripts_fr.md new file mode 100644 index 00000000000000..0344ff2cec3d2d --- /dev/null +++ b/docs/source/fr/run_scripts_fr.md @@ -0,0 +1,355 @@ + + +# Entraîner avec un script + +En plus des [notebooks](./notebooks) de 🤗 Transformers, il existe également des exemples de scripts démontrant comment entraîner un modèle pour une tâche avec [PyTorch](https://github.com/huggingface/transformers/tree/main/examples/pytorch), [TensorFlow](https://github.com/huggingface/transformers/tree/main/examples/tensorflow) ou [JAX/Flax](https://github.com/huggingface/transformers/tree/main/examples/flax). + + +Vous trouverez également des scripts que nous avons utilisé dans nos [projets de recherche](https://github.com/huggingface/transformers/tree/main/examples/research_projects) et des [exemples "legacy"](https://github.com/huggingface/transformers/tree/main/examples/legacy) qui sont des contributions de la communauté. Ces scripts ne sont pas activement maintenus et nécessitent une version spécifique de 🤗 Transformers qui sera probablement incompatible avec la dernière version de la librairie. + +Les exemples de scripts ne sont pas censés fonctionner immédiatement pour chaque problème, et il se peut que vous ayez besoin d'adapter le script au problème que vous essayez de résoudre. Pour vous aider dans cette tâche, la plupart des scripts exposent entièrement la manière dont les données sont prétraitées, vous permettant de les modifier selon vos besoins. + +Pour toute fonctionnalité que vous souhaitez implémenter dans un script d'exemple, veuillez en discuter sur le [forum](https://discuss.huggingface.co/) ou dans une [issue](https://github.com/huggingface/transformers/issues) avant de soumettre une Pull Request. Bien que nous acceptions les corrections de bugs, il est peu probable que nous fusionnions une Pull Request (opération "merge" dans Git) ajoutant plus de fonctionnalités au détriment de la lisibilité. + +Ce guide vous montrera comment exécuter un script d'entraînement de résumé en exemple avec [PyTorch](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization) et [TensorFlow](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/summarization). Tous les exemples sont censés fonctionner avec les deux frameworks, sauf indication contraire. + +## Configuration + +Pour exécuter avec succès la dernière version des scripts d'exemple, vous devez **installer 🤗 Transformers à partir du code source** dans un nouvel environnement virtuel : + +```bash +git clone https://github.com/huggingface/transformers +cd transformers +pip install . +``` + +Pour les versions plus anciennes des exemples de scripts, cliquez sur le bouton ci-dessous : + +
+ Exemples pour les anciennes versions de Transformers 🤗 + +
+ +Ensuite, changez votre clone actuel de 🤗 Transformers pour une version spécifique, comme par exemple v3.5.1 : + +```bash +git checkout tags/v3.5.1 +``` + +Après avoir configuré la bonne version de la librairie, accédez au dossier d'exemple de votre choix et installez les prérequis spécifiques à l'exemple. + +```bash +pip install -r requirements.txt +``` + +## Exécuter un script + + + + +Le script d'exemple télécharge et prétraite un jeu de données à partir de la bibliothèque 🤗 [Datasets](https://huggingface.co/docs/datasets/). Ensuite, le script affine un ensemble de données à l'aide de [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) sur une architecture qui prend en charge la tâche de résumé. L'exemple suivant montre comment ajuster le modèle [T5-small](https://huggingface.co/google-t5/t5-small) sur les données [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail). Le modèle T5 nécessite un argument supplémentaire `source_prefix` en raison de la façon dont il a été entraîné. Cette invite permet à T5 de savoir qu'il s'agit d'une tâche de résumé. + +```bash +python examples/pytorch/summarization/run_summarization.py \ + --model_name_or_path google-t5/t5-small \ + --do_train \ + --do_eval \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --source_prefix "summarize: " \ + --output_dir /tmp/tst-summarization \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=4 \ + --overwrite_output_dir \ + --predict_with_generate +``` + + + +Le script d'exemple télécharge et prétraite un jeu de données à partir de la bibliothèque 🤗 [Datasets](https://huggingface.co/docs/datasets/). Ensuite, le script ajuste un modèle à l'aide de Keras sur une architecture qui prend en charge la tâche de résumé. L'exemple suivant montre comment ajuster le modèle [T5-small](https://huggingface.co/google-t5/t5-small) sur le jeu de données [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail). Le modèle T5 nécessite un argument supplémentaire source_prefix en raison de la façon dont il a été entraîné. Cette invite permet à T5 de savoir qu'il s'agit d'une tâche de résumé. + +```bash +python examples/tensorflow/summarization/run_summarization.py \ + --model_name_or_path google-t5/t5-small \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --output_dir /tmp/tst-summarization \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 16 \ + --num_train_epochs 3 \ + --do_train \ + --do_eval +``` + + + +## Entraînement distribué et précision mixte + +[Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) prend en charge l'entraînement distribué et la précision mixte, ce qui signifie que vous pouvez également les utiliser dans un script. Pour activer ces deux fonctionnalités : + +- Ajoutez l'argument fp16 pour activer la précision mixte. +- Définissez le nombre de GPU à utiliser avec l'argument `nproc_per_node`. + +```bash +torchrun \ + --nproc_per_node 8 pytorch/summarization/run_summarization.py \ + --fp16 \ + --model_name_or_path google-t5/t5-small \ + --do_train \ + --do_eval \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --source_prefix "summarize: " \ + --output_dir /tmp/tst-summarization \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=4 \ + --overwrite_output_dir \ + --predict_with_generate +``` + +Les scripts TensorFlow utilisent une Strategie en Miroir [`MirroredStrategy`](https://www.tensorflow.org/guide/distributed_training#mirroredstrategy) pour l'entraînement distribué, et vous n'avez pas besoin d'ajouter d'arguments supplémentaires au script d'entraînement. Le script TensorFlow utilisera plusieurs GPU par défaut s'ils sont disponibles. + +## Exécuter un script sur un TPU + + + + +Les unités de traitement de tenseurs (UTT) (TPU) sont spécialement conçues pour accélérer les performances. PyTorch prend en charge les TPU avec le compilateur de deep learning [XLA](https://www.tensorflow.org/xla). Pour utiliser un TPU, lancez le script xla_spawn.py et utilisez l'argument num_cores pour définir le nombre de cœurs TPU que vous souhaitez utilise + +```bash +python xla_spawn.py --num_cores 8 \ + summarization/run_summarization.py \ + --model_name_or_path google-t5/t5-small \ + --do_train \ + --do_eval \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --source_prefix "summarize: " \ + --output_dir /tmp/tst-summarization \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=4 \ + --overwrite_output_dir \ + --predict_with_generate +``` + + +Les scripts TensorFlow utilisent une [`TPUStrategy`](https://www.tensorflow.org/guide/distributed_training#tpustrategy) pour l'entraînement sur TPU. Pour utiliser un TPU, passez le nom de la ressource TPU à l'argument tpu. + +```bash +python run_summarization.py \ + --tpu name_of_tpu_resource \ + --model_name_or_path google-t5/t5-small \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --output_dir /tmp/tst-summarization \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 16 \ + --num_train_epochs 3 \ + --do_train \ + --do_eval +``` + + + +## Exécuter un script avec 🤗 Accelerate + +🤗 [Accelerate](https://huggingface.co/docs/accelerate) est une bibliothèque uniquement pour PyTorch qui offre une méthode unifiée pour entraîner un modèle sur plusieurs types de configurations (CPU uniquement, plusieurs GPU, TPU) tout en maintenant une visibilité complète sur la boucle d'entraînement PyTorch. Assurez-vous que vous avez installé 🤗 Accelerate si ce n'est pas déjà le cas. + +> Note : Comme Accelerate est en développement rapide, la version git d'accelerate doit être installée pour exécuter les scripts. +```bash +pip install git+https://github.com/huggingface/accelerate +``` + +Au lieu du script `run_summarization.py`, vous devez utiliser le script `run_summarization_no_trainer.py`. Les scripts compatibles avec 🤗 Accelerate auront un fichier `task_no_trainer.py` dans le dossier. Commencez par exécuter la commande suivante pour créer et enregistrer un fichier de configuration. + +```bash +accelerate config +``` + +Testez votre configuration pour vous assurer qu'elle est correctement configurée : + +```bash +accelerate test +``` + +Maintenant, vous êtes prêt à lancer l'entraînement : + +```bash +accelerate launch run_summarization_no_trainer.py \ + --model_name_or_path google-t5/t5-small \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --source_prefix "summarize: " \ + --output_dir ~/tmp/tst-summarization +``` + +## Utiliser un jeu de données personnalisé + +Le script de résumé prend en charge les jeux de données personnalisés tant qu'ils sont au format CSV ou JSON Line. Lorsque vous utilisez votre propre jeu de données, vous devez spécifier plusieurs arguments supplémentaires : + +- `train_file` et `validation_file` spécifient le chemin vers vos fichiers d'entraînement et de validation. +- `text_column` est le texte d'entrée à résumer. +- `summary_column` est le texte cible à produire. + +Un exemple de script de résumé utilisant un ensemble de données personnalisé ressemblerait à ceci : + +```bash +python examples/pytorch/summarization/run_summarization.py \ + --model_name_or_path google-t5/t5-small \ + --do_train \ + --do_eval \ + --train_file path_to_csv_or_jsonlines_file \ + --validation_file path_to_csv_or_jsonlines_file \ + --text_column text_column_name \ + --summary_column summary_column_name \ + --source_prefix "summarize: " \ + --output_dir /tmp/tst-summarization \ + --overwrite_output_dir \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=4 \ + --predict_with_generate +``` + +## Tester un script +Il est souvent judicieux d'exécuter votre script sur un plus petit nombre d'exemples de jeu de données pour s'assurer que tout fonctionne comme prévu avant de s'engager sur un jeu de données complet qui pourrait prendre des heures à traiter. Utilisez les arguments suivants pour tronquer le jeu de données à un nombre maximal d'échantillons : + +- `max_train_samples` +- `max_eval_samples` +- `max_predict_samples` + +```bash +python examples/pytorch/summarization/run_summarization.py \ + --model_name_or_path google-t5/t5-small \ + --max_train_samples 50 \ + --max_eval_samples 50 \ + --max_predict_samples 50 \ + --do_train \ + --do_eval \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --source_prefix "summarize: " \ + --output_dir /tmp/tst-summarization \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=4 \ + --overwrite_output_dir \ + --predict_with_generate +``` + +Tous les scripts d'exemple ne prennent pas en charge l'argument `max_predict_samples`. Si vous n'êtes pas sûr que votre script prenne en charge cet argument, ajoutez l'argument `-h` pour vérifier. + +```bash +examples/pytorch/summarization/run_summarization.py -h +``` + +## Reprendre l'entraînement à partir d'un point de contrôle + +Une autre option utile est de reprendre l'entraînement à partir d'un point de contrôle précédent. Cela vous permettra de reprendre là où vous vous étiez arrêté sans recommencer si votre entraînement est interrompu. Il existe deux méthodes pour reprendre l'entraînement à partir d'un point de contrôle. + +La première méthode utilise l'argument `output_dir previous_output_dir` pour reprendre l'entraînement à partir du dernier point de contrôle stocké dans `output_dir`. Dans ce cas, vous devez supprimer l'argument `overwrite_output_dir`. + +```bash +python examples/pytorch/summarization/run_summarization.py + --model_name_or_path google-t5/t5-small \ + --do_train \ + --do_eval \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --source_prefix "summarize: " \ + --output_dir /tmp/tst-summarization \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=4 \ + --output_dir previous_output_dir \ + --predict_with_generate +``` + +La seconde méthode utilise l'argument `resume_from_checkpoint path_to_specific_checkpoint` pour reprendre l'entraînement à partir d'un dossier de point de contrôle spécifique. + +```bash +python examples/pytorch/summarization/run_summarization.py + --model_name_or_path google-t5/t5-small \ + --do_train \ + --do_eval \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --source_prefix "summarize: " \ + --output_dir /tmp/tst-summarization \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=4 \ + --overwrite_output_dir \ + --resume_from_checkpoint path_to_specific_checkpoint \ + --predict_with_generate +``` + +## Partage ton modèle + +Tous les scripts peuvent télécharger votre modèle final sur le Model Hub. Assurez-vous que vous êtes connecté à Hugging Face avant de commencer : + +```bash +huggingface-cli login +``` + +Ensuite, ajoutez l'argument `push_to_hub` au script. Cet argument créera un dépôt avec votre nom d'utilisateur Hugging Face et le nom du dossier spécifié dans `output_dir`. + + +Pour donner un nom spécifique à votre dépôt, utilisez l'argument `push_to_hub_model_id` pour l'ajouter. Le dépôt sera automatiquement listé sous votre namespace. + +L'exemple suivant montre comment télécharger un modèle avec un nom de dépôt spécifique : + +```bash +python examples/pytorch/summarization/run_summarization.py + --model_name_or_path google-t5/t5-small \ + --do_train \ + --do_eval \ + --dataset_name cnn_dailymail \ + --dataset_config "3.0.0" \ + --source_prefix "summarize: " \ + --push_to_hub \ + --push_to_hub_model_id finetuned-t5-cnn_dailymail \ + --output_dir /tmp/tst-summarization \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=4 \ + --overwrite_output_dir \ + --predict_with_generate +``` \ No newline at end of file From 3345ae733b6f4aeb7204a0f3e646a3cdbaad0023 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 1 Jul 2024 17:39:33 +0100 Subject: [PATCH 18/18] dependencies: `keras-nlp<0.14` pin (#31684) * keras nlp pin * this should use the new docker images:dev * dev-ci --- setup.py | 2 +- src/transformers/dependency_versions_table.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8f28f0f9e4adb7..af2cec8aa9d5e3 100644 --- a/setup.py +++ b/setup.py @@ -128,7 +128,7 @@ "kenlm", # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support. "keras>2.9,<2.16", - "keras-nlp>=0.3.1", + "keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras. "librosa", "nltk", "natten>=0.14.6,<0.15.0", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 7f8c3e4433fe3c..fcb2c7a29f63bb 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -34,7 +34,7 @@ "jinja2": "jinja2>=3.1.0", "kenlm": "kenlm", "keras": "keras>2.9,<2.16", - "keras-nlp": "keras-nlp>=0.3.1", + "keras-nlp": "keras-nlp>=0.3.1,<0.14.0", "librosa": "librosa", "nltk": "nltk", "natten": "natten>=0.14.6,<0.15.0",