From 2b02689e6154864c7329de66845444c7ed97f8b5 Mon Sep 17 00:00:00 2001
From: System administrator
Date: Thu, 12 Dec 2024 14:28:32 +0000
Subject: [PATCH] runnable tgi version

---
 src/transformers/__init__.py                   |  2 +
 src/transformers/modeling_utils.py             | 88 +++++++++++++++++++
 .../models/llama/modeling_llama.py             | 17 +++-
 3 files changed, 105 insertions(+), 2 deletions(-)

diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 2eaec8f1def96e..30c4a38dcbdb12 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -20,6 +20,8 @@
 
 __version__ = "4.48.0.dev0"
 
+print("CORRECT DEV VERSION")
+
 from typing import TYPE_CHECKING
 
 # Check the dependencies satisfy the minimal versions required.
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index dae29111c8dcc0..794a457b6fc3fd 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -405,6 +405,94 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
     return False
 
 
+
+def shard_checkpoint(
+    state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
+):
+    """
+    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
+    given size.
+
+    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
+    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
+    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
+    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
+
+    <Tip warning={true}>
+
+    If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
+    have a size greater than `max_shard_size`.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
+        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
+            (like `"5MB"`).
+        weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
+            The name of the model save file.
+    """
+    logger.warning(
+        "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
+        "split_torch_state_dict_into_shards from huggingface_hub library"
+    )
+    max_shard_size = convert_file_size_to_int(max_shard_size)
+
+    sharded_state_dicts = [{}]
+    last_block_size = 0
+    total_size = 0
+    storage_id_to_block = {}
+
+    for key, weight in state_dict.items():
+        # when bnb serialization is used the weights in the state dict can be strings
+        # check: https://github.com/huggingface/transformers/pull/24416 for more details
+        if isinstance(weight, str):
+            continue
+        else:
+            storage_id = id_tensor_storage(weight)
+
+        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
+        if storage_id in storage_id_to_block and weight.device != torch.device("meta"):
+            block_id = storage_id_to_block[storage_id]
+            sharded_state_dicts[block_id][key] = weight
+            continue
+
+        weight_size = weight.numel() * dtype_byte_size(weight.dtype)
+        # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
+        # weight in the current shard.
+        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
+            sharded_state_dicts.append({})
+            last_block_size = 0
+
+        sharded_state_dicts[-1][key] = weight
+        last_block_size += weight_size
+        total_size += weight_size
+        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1
+
+    # If we only have one shard, we return it
+    if len(sharded_state_dicts) == 1:
+        return {weights_name: sharded_state_dicts[0]}, None
+
+    # Otherwise, let's build the index
+    weight_map = {}
+    shards = {}
+    for idx, shard in enumerate(sharded_state_dicts):
+        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
+        shards[shard_file] = shard
+        for key in shard.keys():
+            weight_map[key] = shard_file
+
+    # Add the metadata
+    metadata = {"total_size": total_size}
+    index = {"metadata": metadata, "weight_map": weight_map}
+    return shards, index
+
+
+
 def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     """
     This is the same as
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 6f614e61da869d..87c9db39bdbd04 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -404,7 +404,14 @@ def forward(
 
         output_attentions = False
 
-        bsz, q_len, _ = hidden_states.size()
+        if position_ids.dim() > 1:
+            bsz, q_len, _ = hidden_states.size()
+        else:
+            q_len = None
+
+        input_shape = hidden_states.shape[:-1]
+        print(f"INPUT SHAPE: {input_shape}")
+
 
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
@@ -446,6 +453,7 @@
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
+
         dropout_rate = self.attention_dropout if self.training else 0.0
 
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
@@ -1195,7 +1203,12 @@ def forward(
         hidden_states = outputs[0]
 
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if hidden_states.dim() == 2:
+            if kwargs.get("lm_head_indices", None) is not None:
+                hidden_states = hidden_states[kwargs["lm_head_indices"]]
+        else:
+            hidden_states = hidden_states[:, -num_logits_to_keep:, :]
+        logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
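
Reviewer note, not part of the patch: a minimal usage sketch of the re-added
`shard_checkpoint` helper, assuming the patch above has been applied to this
transformers checkout. The tensor shapes and the "200KB" shard limit below are
arbitrary illustration values, not anything TGI itself uses.

    import torch
    from transformers.modeling_utils import shard_checkpoint

    # Ten 128x128 fp32 tensors, roughly 64 KiB each.
    state_dict = {f"layer_{i}.weight": torch.zeros(128, 128) for i in range(10)}

    # Tensors are assigned to shards in key order; a new shard is started once
    # the running size would exceed the limit, so about three of these tensors
    # land in each 200KB shard.
    shards, index = shard_checkpoint(state_dict, max_shard_size="200KB")

    for shard_file, shard in shards.items():
        print(shard_file, sorted(shard))    # e.g. pytorch_model-00001-of-00004.bin
    print(index["metadata"]["total_size"])  # total parameter bytes across all shards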