-
Notifications
You must be signed in to change notification settings - Fork 458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Weights Metrics #340
base: main
Are you sure you want to change the base?
Weights Metrics #340
Changes from 13 commits
c2b0f06
64e54bf
7e37d69
2effd41
ed59d9b
564d45c
e530568
207e874
b30175e
dfc0603
f88f904
f89df57
74c5d33
7c209d2
db65f83
3fe66a8
62692d1
2a0c520
f71e360
7e14266
0f3430f
8d68e39
59e23fe
89ecbf5
404e395
8e3c861
a0e8c27
7e2b552
69c3b15
f36fd70
50a5716
73dd3ae
ddbb475
36d28b0
fa6d098
4949458
bd62ed9
9c87514
4ce67b0
88ed18e
1041298
01d5a2b
f64a71e
8b05c2a
b8f7e32
c4df572
5e051a8
bb8fce2
9c92efd
e1c1ecd
9db9d3b
af2bf1a
6045980
cb5e31a
dfdcb00
7bc8001
a5c5811
3d3fd33
2dfa4f9
60726d4
891bbb4
04b64e8
1b5a3c0
6bd12e7
5ac4c78
68a6a30
20dd0f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
models:
  - model: BEE-spoke-data/smol_llama-220M-GQA
  - model: BEE-spoke-data/smol_llama-220M-openhermes
metric_method: all
dtype: float32
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,7 @@ class Task(ABC, BaseModel, Generic[ValueT], frozen=True): | |
Abstract base class representing a task in a computational graph. | ||
|
||
This class should be extended to define specific tasks. Each task can have arguments (dependencies) and a defined execution strategy. | ||
Note that Pydantic BaseModel requires that all attributes are defined in the class initialisation, and cannot be changed after. |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Super nitpick here: I think the official capitalization is Pydantic, not PyDantic. |
||
|
||
Attributes: | ||
Generic[ValueT] (TypeVar): The type of the value that the task returns upon execution. | ||
|
@@ -106,7 +107,6 @@ def uses_accelerator(self) -> bool: | |
""" | ||
return False | ||
|
||
|
||
class Executor: | ||
""" | ||
Schedules and executes a set of tasks and their dependencies. | ||
|
@@ -241,13 +241,20 @@ def _make_schedule(self, targets: List[Task]) -> List[Task]: | |
# they will be included in the final schedule | ||
edge_tups.append((Executor.DUMMY_TASK_VALUE, task)) | ||
|
||
def _pad_numbers(s): | ||
parts = s.split('.') | ||
for i, part in enumerate(parts): | ||
if part.isdigit(): | ||
parts[i] = part.zfill(3) | ||
return '.'.join(parts) | ||
|
||
def _compare_key(task: Union[Task, str]): | ||
if task == Executor.DUMMY_TASK_VALUE: | ||
return ("", 0) | ||
return ( | ||
task.group_label() or "", | ||
-task.priority(), | ||
) | ||
group_label = task.group_label() or "" | ||
padded_label = _pad_numbers(group_label) | ||
priority = -task.priority() | ||
return (padded_label, priority) | ||
|
||
graph = networkx.DiGraph(edge_tups) | ||
res = [ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (C) 2024 Charles O. Goddard | ||
# | ||
# This software is free software: you can redistribute it and/or | ||
# modify it under the terms of the GNU Lesser General Public License as | ||
# published by the Free Software Foundation, either version 3 of the | ||
# License, or (at your option) any later version. | ||
# | ||
# This software is distributed in the hope that it will be useful, but | ||
# WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU | ||
# Lesser General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU Lesser General Public License | ||
# along with this program. If not, see http://www.gnu.org/licenses/. | ||
|
||
import logging | ||
|
||
import tqdm | ||
import transformers | ||
|
||
from mergekit.architecture import get_architecture_info | ||
from mergekit.config import MergeConfiguration | ||
from mergekit.graph import Executor | ||
from mergekit.io.tasks import LoaderCache | ||
from mergekit.options import MergeOptions | ||
from mergekit.plan import MergePlanner | ||
from mergekit.merge import _model_out_config | ||
|
||
|
||
def run_measure( | ||
merge_config: MergeConfiguration, | ||
out_path: str, | ||
options: MergeOptions, | ||
): | ||
if options.random_seed is not None: | ||
transformers.trainer_utils.set_seed(options.random_seed) | ||
|
||
if not merge_config.models and not merge_config.slices: | ||
raise RuntimeError("No output requested") | ||
|
||
model_arch_info = [ | ||
get_architecture_info(m.config(trust_remote_code=options.trust_remote_code)) | ||
for m in merge_config.referenced_models() | ||
] | ||
if not options.allow_crimes: | ||
if not all(a == model_arch_info[0] for a in model_arch_info[1:]): | ||
raise RuntimeError( | ||
"Must specify --allow-crimes to attempt to mix different architectures" | ||
) | ||
arch_info = model_arch_info[0] | ||
|
||
# initialize loader cache and set options | ||
loader_cache = LoaderCache() | ||
loader_cache.setup(options=options) | ||
|
||
# create config for output model | ||
cfg_out = _model_out_config( | ||
merge_config, arch_info, trust_remote_code=options.trust_remote_code | ||
) | ||
|
||
# warm up loader cache | ||
for model in ( | ||
pbar := tqdm.tqdm( | ||
merge_config.referenced_models(), | ||
desc="Warmup loader cache", | ||
disable=options.quiet, | ||
) | ||
): | ||
loader_cache.get(model) | ||
del pbar | ||
|
||
logging.info("Planning operations") | ||
targets = MergePlanner( | ||
merge_config, | ||
arch_info, | ||
options=options, | ||
out_model_config=cfg_out, | ||
).plan_to_disk(out_path=out_path) | ||
|
||
exec = Executor( | ||
tasks=targets, | ||
math_device="cuda" if options.cuda else "cpu", | ||
storage_device="cuda" if options.low_cpu_memory else "cpu", | ||
) | ||
|
||
res = [] | ||
for _task, value in exec.run(quiet=options.quiet): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking this over, I kinda think we might not need a separate file here - maybe it should just early out in |
||
res.append((_task, value)) | ||
|
||
return res | ||
|
||
__all__ = ["MergeOptions", "run_merge"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (C) 2024 Charles O. Goddard | ||
# | ||
# This software is free software: you can redistribute it and/or | ||
# modify it under the terms of the GNU Lesser General Public License as | ||
# published by the Free Software Foundation, either version 3 of the | ||
# License, or (at your option) any later version. | ||
# | ||
# This software is distributed in the hope that it will be useful, but | ||
# WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU | ||
# Lesser General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU Lesser General Public License | ||
# along with this program. If not, see http://www.gnu.org/licenses/. | ||
|
||
from mergekit.metric_methods.base import MetricMethod | ||
from mergekit.metric_methods.all_metrics import AllMetric | ||
|
||
|
||
def get(method: str) -> MetricMethod: | ||
if method == "all": | ||
return AllMetric() | ||
raise RuntimeError(f"Unimplemented metric method {method}") | ||
|
||
|
||
__all__ = [ | ||
"MetricMethod", | ||
"get", | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be
gqa_groups