Skip to content

Commit

Permalink
Merge branch 'main' into phi-2
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 committed Dec 30, 2023
2 parents 0e9efda + 06a2b11 commit 6ceb102
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 9 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
`mergekit` is a toolkit for merging pre-trained language models, using a variety of merge methods including TIES, linear, and slerp merging. The toolkit also enables piecewise assembly of a language model from layers.

Run `pip install -e .` to install the package and make the scripts available.
If the above fails with the following error:
```
ERROR: File "setup.py" or "setup.cfg" not found. Directory cannot be installed in editable mode:
(A "pyproject.toml" file was found, but editable mode currently requires a setuptools-based build.)
```
You may need to upgrade pip to version 21.3 or newer with the command `python3 -m pip install --upgrade pip`.

The script `mergekit-yaml` takes a YAML configuration file defining the operations to perform.

Expand Down
6 changes: 4 additions & 2 deletions mergekit/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ def merged(

return ModelReference(path=out_path)

def config(self) -> PretrainedConfig:
return AutoConfig.from_pretrained(self.path)
def config(self, trust_remote_code: bool = False) -> PretrainedConfig:
return AutoConfig.from_pretrained(
self.path, trust_remote_code=trust_remote_code
)

def tensor_index(self, cache_dir: Optional[str] = None) -> ShardedTensorIndex:
assert self.lora_path is None
Expand Down
7 changes: 5 additions & 2 deletions mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def run_merge(merge_config: MergeConfiguration, out_path: str, options: MergeOpt

method = merge_methods.get(merge_config.merge_method)
model_arch_info = [
get_architecture_info(m.config()) for m in merge_config.referenced_models()
get_architecture_info(m.config(trust_remote_code=options.trust_remote_code))
for m in merge_config.referenced_models()
]
if not options.allow_crimes:
if not all(a == model_arch_info[0] for a in model_arch_info[1:]):
Expand Down Expand Up @@ -102,7 +103,9 @@ def run_merge(merge_config: MergeConfiguration, out_path: str, options: MergeOpt
clone_tensors=options.clone_tensors,
)

cfg_out = method.model_out_config(merge_config)
cfg_out = method.model_out_config(
merge_config, trust_remote_code=options.trust_remote_code
)
if tokenizer:
try:
cfg_out.vocab_size = len(tokenizer.get_vocab())
Expand Down
12 changes: 9 additions & 3 deletions mergekit/merge_methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,18 @@ def input_layer_dependencies(
"""List any tensors necessary when input includes a specific layer"""
return []

def model_out_config(self, config: MergeConfiguration) -> PretrainedConfig:
def model_out_config(
self, config: MergeConfiguration, trust_remote_code: bool = False
) -> PretrainedConfig:
"""Return a configuration for the resulting model."""
if config.base_model:
res = ModelReference.parse(config.base_model).config()
res = ModelReference.parse(config.base_model).config(
trust_remote_code=trust_remote_code
)
else:
res = config.referenced_models()[0].config()
res = config.referenced_models()[0].config(
trust_remote_code=trust_remote_code
)

if config.dtype:
res.torch_dtype = config.dtype
Expand Down
5 changes: 3 additions & 2 deletions mergekit/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def plan(
merge_config: MergeConfiguration,
arch_info: ArchitectureInfo,
embed_permutations: Optional[Dict[ModelReference, torch.Tensor]] = None,
trust_remote_code: bool = False,
) -> Tuple[List[TensorReference], Dict[TensorReference, Operation]]:
layer_idx = 0

Expand Down Expand Up @@ -62,7 +63,7 @@ def plan(
if base_model and mref == base_model:
base_included = True

model_cfg = mref.config()
model_cfg = mref.config(trust_remote_code=trust_remote_code)
num_layers = arch_info.num_layers(model_cfg)
slices_in.append(
InputSliceDefinition(
Expand All @@ -74,7 +75,7 @@ def plan(

if base_model and not base_included:
logging.info("Base model specified but not in input models - adding")
base_cfg = base_model.config()
base_cfg = base_model.config(trust_remote_code=trust_remote_code)
num_layers = arch_info.num_layers(base_cfg)
slices_in.append(
InputSliceDefinition(
Expand Down

0 comments on commit 6ceb102

Please sign in to comment.