Commit: runnable tgi version
System administrator committed Dec 12, 2024
1 parent ba31db0 commit 2b02689
Showing 3 changed files with 105 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -20,6 +20,8 @@

__version__ = "4.48.0.dev0"

print("CORRECT DEV VERSION")

from typing import TYPE_CHECKING

# Check the dependencies satisfy the minimal versions required.
88 changes: 88 additions & 0 deletions src/transformers/modeling_utils.py
@@ -405,6 +405,94 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
    return False



def shard_checkpoint(
    state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).
        weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
            The name of the model save file.
    """
    logger.warning(
        "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
        "split_torch_state_dict_into_shards from huggingface_hub library"
    )
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = [{}]
    last_block_size = 0
    total_size = 0
    storage_id_to_block = {}

    for key, weight in state_dict.items():
        # when bnb serialization is used the weights in the state dict can be strings
        # check: https://github.com/huggingface/transformers/pull/24416 for more details
        if isinstance(weight, str):
            continue
        else:
            storage_id = id_tensor_storage(weight)

        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
        if storage_id in storage_id_to_block and weight.device != torch.device("meta"):
            block_id = storage_id_to_block[storage_id]
            sharded_state_dicts[block_id][key] = weight
            continue

        weight_size = weight.numel() * dtype_byte_size(weight.dtype)
        # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
        # weight in the current shard.
        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
            sharded_state_dicts.append({})
            last_block_size = 0

        sharded_state_dicts[-1][key] = weight
        last_block_size += weight_size
        total_size += weight_size
        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {weights_name: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
        shard_file = shard_file.replace(
            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
        )
        shards[shard_file] = shard
        for key in shard.keys():
            weight_map[key] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
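As a rough usage sketch of the re-added helper above (not part of the commit): a toy state dict of three 10x10 fp32 tensors (400 bytes each) with a deliberately tiny shard limit, to show the greedy, key-ordered split and the returned index. The tensor names and the 900-byte limit are made up for illustration, and per the deprecation warning new code should prefer split_torch_state_dict_into_shards from huggingface_hub.

import torch
from transformers.modeling_utils import shard_checkpoint  # re-added by this commit

state_dict = {
    "layer1.weight": torch.zeros(10, 10),  # 100 elements * 4 bytes = 400 bytes
    "layer2.weight": torch.zeros(10, 10),
    "layer3.weight": torch.zeros(10, 10),
}

# With a 900-byte limit the greedy split yields two shards:
# {layer1, layer2} (800 bytes) and {layer3} (400 bytes).
shards, index = shard_checkpoint(state_dict, max_shard_size=900, weights_name="pytorch_model.bin")

for filename, shard in shards.items():
    print(filename, sorted(shard))
# pytorch_model-00001-of-00002.bin ['layer1.weight', 'layer2.weight']
# pytorch_model-00002-of-00002.bin ['layer3.weight']
print(index["metadata"]["total_size"])  # 1200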



def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
    """
    This is the same as
17 changes: 15 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -404,7 +404,14 @@ def forward(

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()
        if position_ids.dim() > 1:
            bsz, q_len, _ = hidden_states.size()
        else:
            q_len = None

        input_shape = hidden_states.shape[:-1]
        print(f"INPUT SHAPE: {input_shape}")


        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
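For context on the guard added above (not part of the commit): in a TGI-style flattened batch, hidden_states is 2-D (total_tokens, hidden_size) and position_ids is 1-D, so the usual (bsz, q_len, _) unpacking is skipped. A small sketch with made-up shapes:

import torch

hidden_size = 16

# Regular padded batch: 2-D position_ids, 3-D hidden states.
hidden_states_padded = torch.randn(2, 5, hidden_size)
position_ids_padded = torch.arange(5).unsqueeze(0).expand(2, -1)  # (2, 5)

# Flattened (TGI-style) batch: 1-D position_ids, 2-D hidden states.
hidden_states_flat = torch.randn(8, hidden_size)
position_ids_flat = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4])        # (8,)

for hidden_states, position_ids in [
    (hidden_states_padded, position_ids_padded),
    (hidden_states_flat, position_ids_flat),
]:
    if position_ids.dim() > 1:
        bsz, q_len, _ = hidden_states.size()
    else:
        q_len = None
    input_shape = hidden_states.shape[:-1]
    print(input_shape, q_len)  # torch.Size([2, 5]) 5, then torch.Size([8]) None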
@@ -446,6 +453,7 @@ def forward(
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)


        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
@@ -1195,7 +1203,12 @@ def forward(

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if hidden_states.dim() == 2:
            if kwargs.get("lm_head_indices", None) is not None:
                hidden_states = hidden_states[kwargs["lm_head_indices"]]
        else:
            hidden_states = hidden_states[:, -num_logits_to_keep:, :]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
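The change to the lm_head call above handles the same two layouts (not part of the commit): with flattened 2-D hidden states, an optional lm_head_indices tensor gathers only the positions whose logits are needed, while the padded 3-D path keeps the usual last-num_logits_to_keep slice. A minimal sketch with made-up shapes and a stand-in projection:

import torch

hidden_size, vocab_size = 16, 32
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# Flattened batch of 8 tokens (two sequences of lengths 3 and 5);
# only the last token of each sequence needs logits.
hidden_states = torch.randn(8, hidden_size)
lm_head_indices = torch.tensor([2, 7])
logits = lm_head(hidden_states[lm_head_indices])  # (2, vocab_size)

# Regular padded batch: slice the last num_logits_to_keep positions instead.
hidden_states_padded = torch.randn(2, 5, hidden_size)
num_logits_to_keep = 1
logits_padded = lm_head(hidden_states_padded[:, -num_logits_to_keep:, :])  # (2, 1, vocab_size)
print(logits.shape, logits_padded.shape)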
