-
Notifications
You must be signed in to change notification settings - Fork 474
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Switch argument parsing from typer to click
- Loading branch information
Showing
10 changed files
with
277 additions
and
242 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (C) 2024 Charles O. Goddard | ||
# | ||
# This software is free software: you can redistribute it and/or | ||
# modify it under the terms of the GNU Lesser General Public License as | ||
# published by the Free Software Foundation, either version 3 of the | ||
# License, or (at your option) any later version. | ||
# | ||
# This software is distributed in the hope that it will be useful, but | ||
# WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU | ||
# Lesser General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU Lesser General Public License | ||
# along with this program. If not, see http://www.gnu.org/licenses/. | ||
|
||
import typing | ||
from typing import Any, Callable, Optional, Union | ||
|
||
import click | ||
from click.core import Context, Parameter | ||
|
||
from mergekit.common import parse_kmb | ||
from mergekit.merge import MergeOptions | ||
|
||
OPTION_HELP = { | ||
"allow_crimes": "Allow mixing architectures", | ||
"transformers_cache": "Override storage path for downloaded models", | ||
"lora_merge_cache": "Path to store merged LORA models", | ||
"cuda": "Perform matrix arithmetic on GPU", | ||
"low_cpu_memory": "Store results and intermediate values on GPU. Useful if VRAM > RAM", | ||
"out_shard_size": "Number of parameters per output shard [default: 5B]", | ||
"copy_tokenizer": "Copy a tokenizer to the output", | ||
"clone_tensors": "Clone tensors before saving, to allow multiple occurrences of the same layer", | ||
"trust_remote_code": "Trust remote code from huggingface repos (danger)", | ||
"random_seed": "Seed for reproducible use of randomized merge methods", | ||
"lazy_unpickle": "Experimental lazy unpickler for lower memory usage", | ||
} | ||
|
||
|
||
class ShardSizeParamType(click.ParamType): | ||
name = "size" | ||
|
||
def convert( | ||
self, value: Any, param: Optional[Parameter], ctx: Optional[Context] | ||
) -> int: | ||
return parse_kmb(value) | ||
|
||
|
||
def add_merge_options(f: Callable) -> Callable: | ||
def wrapper(*args, **kwargs): | ||
arg_dict = {} | ||
for field_name in MergeOptions.model_fields: | ||
if field_name in kwargs: | ||
arg_dict[field_name] = kwargs.pop(field_name) | ||
|
||
kwargs["merge_options"] = MergeOptions(**arg_dict) | ||
f(*args, **kwargs) | ||
|
||
for field_name, info in reversed(MergeOptions.model_fields.items()): | ||
origin = typing.get_origin(info.annotation) | ||
if origin is Union: | ||
ty, prob_none = typing.get_args(info.annotation) | ||
assert prob_none is type(None) | ||
field_type = ty | ||
else: | ||
field_type = info.annotation | ||
|
||
if field_name == "out_shard_size": | ||
field_type = ShardSizeParamType() | ||
|
||
arg_name = field_name.replace("_", "-") | ||
if field_type == bool: | ||
arg_str = f"--{arg_name}/--no-{arg_name}" | ||
else: | ||
arg_str = f"--{arg_name}" | ||
|
||
help_str = OPTION_HELP.get(field_name, None) | ||
wrapper = click.option( | ||
arg_str, | ||
type=field_type, | ||
default=info.default, | ||
help=help_str, | ||
show_default=field_name != "out_shard_size", | ||
)(wrapper) | ||
|
||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.