Handle unsharded Llama2 model types in conversion script (huggingface#27069)

Handle all unsharded model types
coreyhu authored and EduardoPach committed Nov 19, 2023
1 parent 05a5c22 commit daedea6
Showing 1 changed file with 3 additions and 3 deletions.
src/transformers/models/llama/convert_llama_weights_to_hf.py: 3 additions & 3 deletions
@@ -124,7 +124,7 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
 
     print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
     # Load weights
-    if model_size == "7B":
+    if num_shards == 1:
         # Not sharded
         # (The sharded implementation would also work, but this is simpler.)
         loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
@@ -138,7 +138,7 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
     index_dict = {"weight_map": {}}
     for layer_i in range(n_layers):
         filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
-        if model_size == "7B":
+        if num_shards == 1:
             # Unsharded
             state_dict = {
                 f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
@@ -222,7 +222,7 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
         torch.save(state_dict, os.path.join(tmp_model_path, filename))
 
     filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
-    if model_size == "7B":
+    if num_shards == 1:
         # Unsharded
         state_dict = {
             "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
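The change keys the unsharded code path off the shard count rather than the "7B" model-size string, so any checkpoint stored as a single consolidated.00.pth file is handled, not only the 7B release. Below is a minimal sketch (not part of this commit) of how num_shards could be derived before the check runs; the get_num_shards helper, the glob fallback, and the NUM_SHARDS table values are illustrative assumptions, not the script's confirmed implementation.

    import glob
    import os

    # Assumed size-to-shard-count table; values shown here are illustrative.
    NUM_SHARDS = {"7B": 1, "13B": 2, "30B": 4, "65B": 8, "70B": 8}

    def get_num_shards(input_base_path, model_size=None):
        # Prefer counting the consolidated.*.pth files actually present on disk,
        # falling back to the size table when no shards are found.
        shards = glob.glob(os.path.join(input_base_path, "consolidated.*.pth"))
        if shards:
            return len(shards)
        return NUM_SHARDS[model_size]

    # Usage (hypothetical), mirroring the diff above:
    # num_shards = get_num_shards(input_base_path, model_size)
    # if num_shards == 1:
    #     loaded = torch.load(
    #         os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu"
    #     )

With this, a 13B checkpoint that was re-consolidated into one file, or any future single-file release, takes the same simple load path as 7B.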
