Fixes to auto-architecture
ElliotStein committed Oct 22, 2024
1 parent 14a72e3 commit 1571e57
Showing 2 changed files with 87 additions and 50 deletions.
5 changes: 2 additions & 3 deletions mergekit/architecture.py
@@ -232,6 +232,8 @@ def _hierarchy(self, names):
                param_name = name[match.end() :]  # e.g., 'input_layernorm.weight'
                # Add the parameter name to the corresponding layer in the hierarchy
                hierarchy[layer_prefix].append(param_name)
            else:
                hierarchy[name].append("")

        return hierarchy
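A minimal sketch (not part of this commit) of the grouping behaviour this hunk changes; the layer-prefix regex and helper name below are illustrative assumptions, the real pattern lives elsewhere in architecture.py:

```python
import re
from collections import defaultdict

# Hypothetical stand-in for the real layer-prefix pattern used by _hierarchy.
LAYER_RE = re.compile(r"model\.layers\.\d+\.")

def toy_hierarchy(names):
    hierarchy = defaultdict(list)
    for name in names:
        match = LAYER_RE.match(name)
        if match:
            layer_prefix = name[: match.end() - 1]   # e.g. 'model.layers.0'
            param_name = name[match.end():]          # e.g. 'input_layernorm.weight'
            hierarchy[layer_prefix].append(param_name)
        else:
            # New in this commit: tensors without a layer prefix (embeddings,
            # lm_head, final norm, ...) become their own entry with an empty suffix.
            hierarchy[name].append("")
    return dict(hierarchy)

# toy_hierarchy(["model.layers.0.input_layernorm.weight", "lm_head.weight"])
# -> {'model.layers.0': ['input_layernorm.weight'], 'lm_head.weight': ['']}
```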

@@ -252,9 +254,6 @@ def layer_weights(
            WeightInfo(name=(layer_name + ("." + param if param else "")))
            for param in self.layered_parameter_names[layer_name]
        ]

    def all_weights(self, config):
        return self.parameter_names

    def sliceable(self) -> bool:
        return True
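The empty suffix appended in the first hunk is what the `("." + param if param else "")` guard in layer_weights relies on; a quick illustration:

```python
# For a real layer: 'model.layers.0' + '.' + 'input_layernorm.weight'
# For a non-layer tensor the stored suffix is '', so no trailing dot is produced:
layer_name, params = "lm_head.weight", [""]
print([layer_name + ("." + param if param else "") for param in params])
# -> ['lm_head.weight']
```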
132 changes: 85 additions & 47 deletions mergekit/merge.py
@@ -33,51 +33,21 @@
from mergekit.config import MergeConfiguration
from mergekit.graph import Executor
from mergekit.io.tasks import LoaderCache
from mergekit.io.lazy_tensor_loader import ShardedTensorIndex
from mergekit.options import MergeOptions
from mergekit.plan import MergePlanner
from mergekit.tokenizer import TokenizerInfo

import os
from safetensors import safe_open
from transformers.configuration_utils import is_remote_url, download_url
from huggingface_hub import snapshot_download
from pathlib import Path
from huggingface_hub import model_info
from huggingface_hub.utils import HfHubHTTPError

def get_model_parameter_names(repo_id: str):
    # Get the directory where the model is stored locally
    hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/hub")
# Default Hugging Face cache location; overridden by the HF_HOME environment variable if set
HF_HOME_DEFAULT = "~/.cache/huggingface"

    # Expand the user directory if the path contains ~
    hf_home = os.path.expanduser(hf_home)

    # Construct the model directory path
    model_dir = os.path.join(hf_home, "models--" + repo_id.replace("/", "--"))

    # Check if model exists locally
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model repository {repo_id} not found locally.")

    # Find all safetensors files in the directory (e.g., model-00001-of-00003.safetensors)
    safetensors_files = [
        os.path.join(root, f)
        for root, dirs, files in os.walk(model_dir)
        for f in files
        if f.endswith(".safetensors")
    ]

    if not safetensors_files:
        raise FileNotFoundError(f"No safetensors files found for {repo_id}.")

    # Initialize a set to store unique parameter names across all safetensors files
    param_names = set()

    # Loop through all safetensors files and extract keys
    for safetensors_file in safetensors_files:
        safetensors_path = os.path.join(model_dir, safetensors_file)

        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            param_names.update(f.keys())  # Add all parameter names (keys) from this file

    return sorted(param_names)

def run_merge(
    merge_config: MergeConfiguration,
@@ -94,19 +64,11 @@ def run_merge(
    model_arch_info = [
        AutomaticArchitectureInfo(
            arch_name=source_model.model.path,
            parameter_names=get_model_parameter_names(source_model.model.path),
            # _get_model_parameter_names could live inside AutomaticArchitectureInfo,
            # but keeping it separate lets AutomaticArchitectureInfo be reused for arbitrary PyTorch models.
            parameter_names=_get_model_parameter_names(source_model.model.path),
        )
        for source_model in merge_config.referenced_models()
    ]

    if not options.allow_crimes:
        if not all(a.all_weights(None) == model_arch_info[0].all_weights(None) for a in model_arch_info[1:]):
            # Current implementation has name = repo_id so will be different for each model. Can change if necessary.
            raise RuntimeError(
                "Must specify --allow-crimes to attempt to mix different architectures"
            )
    arch_info = model_arch_info[0]

    # initialize loader cache and set options
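A hedged sketch of using these pieces outside run_merge, in the spirit of the comment above about reusing AutomaticArchitectureInfo for arbitrary PyTorch models; the repo id is hypothetical, and both the import path and constructor signature are taken from this diff rather than verified against the class definition:

```python
from mergekit.architecture import AutomaticArchitectureInfo  # import path assumed

repo = "some-org/some-model"  # hypothetical repo id
arch = AutomaticArchitectureInfo(
    arch_name=repo,
    parameter_names=_get_model_parameter_names(repo),
)
print(arch.sliceable())  # True for architectures discovered this way
```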
@@ -324,4 +286,80 @@ def _update_config_vocab(
)


def _get_model_parameter_names(repo_id: str):
    """
    Get the parameter names of a model.
    Supports local directory paths, remote URLs, and Hugging Face repository IDs.
    :param repo_id: The model's repo ID, URL, or local directory path.
    :return: A list of parameter names.
    """
    # Determine if repo_id is a local path, remote URL, or Hugging Face repo
    if Path(repo_id).is_dir():
        model_dir = Path(repo_id)
    elif is_remote_url(repo_id):
        model_dir = Path(download_url(repo_id))
    elif _is_hf_repo(repo_id):
        hf_home = Path(os.getenv("HF_HOME", HF_HOME_DEFAULT)).expanduser()
        snapshot_download(repo_id)
        model_dir = hf_home / "hub" / f"models--{repo_id.replace('/', '--')}"
    else:
        raise ValueError(f"Invalid repo_id: {repo_id}")

    # Try to get the model parameter names
    try:
        return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys())
    except Exception as e:
        print(f"Error loading tensor paths: {e}")
        snapshot_path = _most_recent_snapshot_path(model_dir)
        try:
            return list(ShardedTensorIndex.from_disk(str(snapshot_path)).tensor_paths.keys())
        except Exception as e:
            print(f"Error loading tensor paths from snapshot: {e}")
            raise
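A usage sketch under stated assumptions: the local path and repo id below are hypothetical, and passing a Hub repo id triggers a snapshot_download into the HF cache before the shard index is read:

```python
# Local checkpoint directory containing *.safetensors shards and an index file:
names = _get_model_parameter_names("/path/to/local/checkpoint")

# Hugging Face repo id (hypothetical); downloaded into the cache first:
names = _get_model_parameter_names("some-org/some-model")
print(names[:3])  # e.g. ['lm_head.weight', 'model.embed_tokens.weight', ...]
```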


def _most_recent_snapshot_path(model_dir: Path) -> Path:
    """
    Get the most recently created snapshot directory within a model directory.
    :param model_dir: The directory where model snapshots are stored.
    :return: The path of the most recent snapshot directory.
    """
    snapshots_dir = model_dir / "snapshots"

    if not snapshots_dir.exists():
        raise FileNotFoundError(f"Snapshot directory does not exist: {snapshots_dir}")

    # List all directories in the snapshots directory
    snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]

    # Sort directories by creation time (most recent first)
    snapshot_dirs.sort(key=lambda d: d.stat().st_ctime, reverse=True)

    if not snapshot_dirs:
        raise FileNotFoundError(f"No snapshot directories found in {snapshots_dir}")

    most_recent_snapshot = snapshot_dirs[0]

    if len(snapshot_dirs) > 1:
        print(f"Found {len(snapshot_dirs)} snapshot directories; using the most recent: {most_recent_snapshot}")

    return most_recent_snapshot
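The fallback above assumes the standard Hugging Face cache layout; a rough picture, with illustrative hash names and a hypothetical repo:

```python
# ~/.cache/huggingface/hub/models--some-org--some-model/
# ├── refs/
# ├── blobs/
# └── snapshots/
#     ├── 0a1b2c3.../   <- returned by _most_recent_snapshot_path (newest ctime)
#     └── 9f8e7d6.../
snap = _most_recent_snapshot_path(
    Path("~/.cache/huggingface/hub/models--some-org--some-model").expanduser()
)
```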


def _is_hf_repo(repo_id: str) -> bool:
    """
    Check if a given repo_id is a valid Hugging Face repository.
    :param repo_id: The Hugging Face repository ID.
    :return: True if the repo exists, False otherwise.
    """
    try:
        model_info(repo_id)
        return True
    except HfHubHTTPError:
        return False
    except Exception as e:
        print(f"Unexpected error while checking repo: {e}")
        return False
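This check hits the Hub, so it needs network access (and a token for private repos); a small hedged example with a hypothetical id:

```python
candidate = "some-org/some-model"  # hypothetical id
if _is_hf_repo(candidate):
    print(f"{candidate} resolves on the Hub")
else:
    print(f"{candidate} is not a reachable Hub repo; treat it as a local path or URL")
```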


__all__ = ["MergeOptions", "run_merge"]
