Download fix (pytorch#1366)
* fix: allow multiple weight mapping files for mistral

Downloading a Mistral model fails because its repository includes multiple
weight mapping files. The regression was introduced in commit
`766bee9f4a1fcb187fae543a525495d3ff482097`. I'm unclear on the original
intent, but perhaps the exception was meant to apply only to Granite
models. This isn't an ideal fix, but it does enable Mistral to be
downloaded and used for chat.

Signed-off-by: Sébastien Han <[email protected]>

* fix(download): Fix safetensors/bin/pth download logic

The previous logic didn't handle .bin files, so if a model (like Mistral)
has both .bin and .safetensors weights, it would download both; see the
short illustration after this message.

Branch: download-fix

Signed-off-by: Gabe Goodhart <[email protected]>
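
(Not part of the commit message: a quick, self-contained illustration of the
check described above, using a made-up sibling file list for a repo that ships
both weight formats.)

# Hypothetical filenames for a repo with both weight formats
model_fnames = [
    "pytorch_model-00001-of-00002.bin",
    "model-00001-of-00002.safetensors",
]
has_pth = any(f.endswith(".pth") for f in model_fnames)                  # False
has_bin = any(f.endswith(".bin") for f in model_fnames)                  # True
has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)  # True

# Old check: has_pth and has_safetensors              -> False -> download everything
# New check: (has_pth or has_bin) and has_safetensors -> True  -> skip *safetensors*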

* fix(convert hf): Better logic to handle multiple weight mapping files

With the download fix above handling .bin files, this is not strictly
needed for Mistral, but it may be needed for other models, so it's worth
having.

Branch: download-fix

Signed-off-by: Gabe Goodhart <[email protected]>

---------

Signed-off-by: Sébastien Han <[email protected]>
Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Sébastien Han <[email protected]>
gabe-l-hart and leseb authored Nov 19, 2024
1 parent a6a6e61 commit 57dee04
Showing 2 changed files with 29 additions and 15 deletions.
35 changes: 24 additions & 11 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -39,19 +39,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -68,11 +63,30 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )
 
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    #   1. The files could be mapped to different weight format files (e.g. .bin
+    #      vs .safetensors)
+    #   2. The files could be split subsets of the mappings that need to be
+    #      merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_index = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+            valid_mappings = {
+                k: model_dir / v
+                for (k, v) in weight_map.get("weight_map", {}).items()
+                if (model_dir / v).is_file()
+            }
+        bin_index.update(valid_mappings)
+    bin_files = set(bin_index.values())
 
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -96,7 +110,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         return (
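
A note on the merge logic above: Hugging Face index files map parameter names
to shard filenames under a top-level "weight_map" key. The following
standalone sketch (not part of the commit; the directory contents and shard
names are made up) shows how keeping only mappings whose target file exists
resolves a repo that ships both a .bin index and a .safetensors index:

import json
import tempfile
from pathlib import Path

# Hypothetical model dir: both index files are present, but only the
# safetensors shards were actually downloaded.
model_dir = Path(tempfile.mkdtemp())
(model_dir / "model-00001-of-00002.safetensors").touch()
(model_dir / "model-00002-of-00002.safetensors").touch()
(model_dir / "model.safetensors.index.json").write_text(json.dumps({
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    }
}))
(model_dir / "pytorch_model.bin.index.json").write_text(json.dumps({
    "weight_map": {"model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin"}
}))

# Same filtering idea as the patch: keep a mapping only if its target exists.
bin_index = {}
for weight_map_file in model_dir.glob("*.index.json"):
    weight_map = json.loads(weight_map_file.read_text())
    bin_index.update({
        k: model_dir / v
        for k, v in weight_map.get("weight_map", {}).items()
        if (model_dir / v).is_file()
    })

print(sorted(p.name for p in set(bin_index.values())))
# -> ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']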
9 changes: 5 additions & 4 deletions torchchat/cli/download.py
@@ -35,22 +35,23 @@ def _download_hf_snapshot(
     model_info = model_info(model_config.distribution_path, token=hf_token)
     model_fnames = [f.rfilename for f in model_info.siblings]
 
-    # Check the model config for preference between safetensors and pth
+    # Check the model config for preference between safetensors and pth/bin
     has_pth = any(f.endswith(".pth") for f in model_fnames)
+    has_bin = any(f.endswith(".bin") for f in model_fnames)
     has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
 
-    # If told to prefer safetensors, ignore pth files
+    # If told to prefer safetensors, ignore pth/bin files
     if model_config.prefer_safetensors:
         if not has_safetensors:
             print(
                 f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
                 file=sys.stderr,
             )
             exit(1)
-        ignore_patterns = "*.pth"
+        ignore_patterns = ["*.pth", "*.bin"]
 
     # If the model has both, prefer pth files over safetensors
-    elif has_pth and has_safetensors:
+    elif (has_pth or has_bin) and has_safetensors:
         ignore_patterns = "*safetensors*"
 
     # Otherwise, download everything
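
For context, a minimal sketch of how the computed patterns are consumed
downstream. It assumes the download path hands ignore_patterns to
huggingface_hub.snapshot_download, which accepts either a single glob string
or a list of globs (which is why both "*safetensors*" and ["*.pth", "*.bin"]
work); the repo id and local dir below are illustrative, not taken from
torchchat's model configs.

from huggingface_hub import snapshot_download

# As set above when prefer_safetensors is True
ignore_patterns = ["*.pth", "*.bin"]

snapshot_download(
    "mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical distribution_path
    local_dir="checkpoints/mistral",       # hypothetical destination
    ignore_patterns=ignore_patterns,       # str or list of glob patterns
)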
