Expands the `mergekit-moe` script to support two new output architectures, Deepseek MoE and Qwen 2 MoE. Both architectures include support for "shared" experts; currently the script supports adding a single shared expert. The Deepseek architecture uses shared experts ungated and unweighted, so you probably want to set the new `residual_scale` option on the shared expert to a relatively low value (around 0.1) to keep the model from being completely overcooked. Qwen 2 MoE has a gate parameter associated with the shared expert, so this is less necessary there, but still advisable. Deepseek MoE accepts Llama- or Mistral-based models as inputs; Qwen 2 MoE accepts Llama-, Mistral-, or Qwen2-based models. Addresses #117, #244, and #134.
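For context, a merge configuration exercising the new shared-expert support might look roughly like the sketch below. Only `residual_scale` is named in this commit; the other field names (`shared_experts`, the gate mode, the model names) follow the existing mergekit-moe config conventions and are assumptions rather than anything guaranteed by this diff.

import yaml

# Hypothetical mergekit-moe config: one shared expert, scaled down so the
# ungated Deepseek-style shared path does not swamp the routed experts.
CONFIG_YAML = """
base_model: mistralai/Mistral-7B-v0.1       # assumed base model
gate_mode: hidden
dtype: bfloat16
experts:
  - source_model: example/math-expert       # hypothetical expert models
    positive_prompts: ["solve this equation"]
  - source_model: example/code-expert
    positive_prompts: ["write a function"]
shared_experts:
  - source_model: example/generalist        # a single shared expert (current limit)
    residual_scale: 0.1                     # keep the ungated shared expert tame
"""

config = yaml.safe_load(CONFIG_YAML)
assert config["shared_experts"][0]["residual_scale"] == 0.1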
Showing 15 changed files with 1,327 additions and 492 deletions.
@@ -0,0 +1,19 @@
from typing import List

from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.deepseek import DeepseekMoE
from mergekit.moe.mixtral import MixtralMoE

ALL_OUTPUT_ARCHITECTURES: List[MoEOutputArchitecture] = [MixtralMoE(), DeepseekMoE()]

# Qwen 2 MoE support is optional: if the import fails (for example because the
# installed transformers release does not ship the Qwen2 MoE model classes),
# the architecture is simply not registered.
try:
    from mergekit.moe.qwen import QwenMoE
except ImportError:
    pass
else:
    ALL_OUTPUT_ARCHITECTURES.append(QwenMoE())

__all__ = [
    "ALL_OUTPUT_ARCHITECTURES",
    "MoEOutputArchitecture",
]
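The registry above is presumably what the rest of the script consults when deciding which output format can realize a given merge config. A minimal sketch of that selection logic (assumed, not taken from this diff; `select_output_arch` is a hypothetical helper name) could look like:

from mergekit.moe import ALL_OUTPUT_ARCHITECTURES


def select_output_arch(config, trust_remote_code: bool = False):
    """Return the first registered architecture that can realize `config`, or None."""
    for arch in ALL_OUTPUT_ARCHITECTURES:
        # explain=True asks an unsupported architecture to log why it was rejected
        if arch.supports_config(config, explain=True, trust_remote_code=trust_remote_code):
            return arch
    return None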
@@ -0,0 +1,53 @@
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from abc import ABC, abstractmethod
from typing import List, Optional

import torch

from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions


class MoEOutputArchitecture(ABC):
    @abstractmethod
    def name(self) -> str:
        """Return a human-readable name for the architecture."""
        pass

    @abstractmethod
    def supports_config(
        self,
        config: MoEMergeConfig,
        explain: bool = False,
        trust_remote_code: bool = False,
    ) -> bool:
        """Return whether this architecture supports the given config.

        If `explain` is True, log an explanation of why the config is not supported."""
        pass

    @abstractmethod
    def write_model(
        self,
        out_path: str,
        config: MoEMergeConfig,
        merge_options: MergeOptions,
        router_weights: List[torch.Tensor],
        shared_router_weights: Optional[List[torch.Tensor]] = None,
    ):
        """Write the config and tensors for the output MoE to the given path."""
        pass
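To make the contract concrete, here is a toy subclass (illustrative only, not part of this commit) showing the three methods a new output architecture has to implement:

from typing import List, Optional

import torch

from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions


class ToyMoE(MoEOutputArchitecture):
    """Hypothetical output architecture used only to illustrate the interface."""

    def name(self) -> str:
        return "Toy MoE"

    def supports_config(
        self,
        config: MoEMergeConfig,
        explain: bool = False,
        trust_remote_code: bool = False,
    ) -> bool:
        # A real implementation would inspect the base and expert model configs;
        # this toy version only requires that at least one expert is present.
        return len(config.experts) > 0

    def write_model(
        self,
        out_path: str,
        config: MoEMergeConfig,
        merge_options: MergeOptions,
        router_weights: List[torch.Tensor],
        shared_router_weights: Optional[List[torch.Tensor]] = None,
    ):
        raise NotImplementedError("illustration only")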
@@ -0,0 +1,75 @@
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import Dict, Optional

import torch
import tqdm
import transformers

from mergekit.common import ModelReference, dtype_from_name
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.merge import MergeOptions
from mergekit.moe.config import Expert, MoEMergeConfig


def initialize_io(
    config: MoEMergeConfig,
    out_path: str,
    merge_options: MergeOptions,
) -> tuple[Dict[ModelReference, LazyTensorLoader], LazyTensorLoader, TensorWriter]:
    """Set up lazy tensor loaders for the base and expert models plus a writer for the output."""
    base_model = config.base_model
    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    for model in tqdm.tqdm(
        [base_model] + [e.source_model for e in config.experts], desc="Warm up loaders"
    ):
        loaders[model] = model.lazy_loader(
            cache_dir=merge_options.transformers_cache,
            lazy_unpickle=merge_options.lazy_unpickle,
        )

    base_loader = loaders.get(base_model)
    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )

    return loaders, base_loader, writer


def select_dtype(
    config: MoEMergeConfig, base_cfg: transformers.PretrainedConfig
) -> Optional[torch.dtype]:
    """Pick the output dtype: the config's explicit dtype if set, else the base model's torch_dtype."""
    out_dtype = None
    if config.dtype:
        out_dtype = dtype_from_name(config.dtype)

    if out_dtype is None and base_cfg.torch_dtype:
        out_dtype = base_cfg.torch_dtype
        if isinstance(out_dtype, str):
            out_dtype = dtype_from_name(out_dtype)
    return out_dtype


def noise_and_scale(
    tensor: torch.Tensor, expert: Expert, is_residual: bool = False
) -> torch.Tensor:
    """Apply the expert's optional Gaussian noise and, for residual (shared) experts, its residual_scale."""
    if expert.noise_scale is not None:
        noise = torch.randn_like(tensor) * expert.noise_scale
        tensor = tensor + noise
    if is_residual and expert.residual_scale is not None:
        tensor = tensor * expert.residual_scale
    return tensor
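Putting these helpers together, the per-architecture writers presumably follow a loop along these lines when copying expert weights into the output model. This is a sketch under assumptions, not code from this commit: the tensor names are illustrative, and the `get_tensor`, `save_tensor`, and `finalize` method names on the loader/writer objects are assumed rather than shown in this diff.

# config, out_path, merge_options, base_cfg (the base model's PretrainedConfig)
# and num_layers are taken as already in scope.
loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
out_dtype = select_dtype(config, base_cfg)

for layer_idx in range(num_layers):
    for expert_idx, expert in enumerate(config.experts):
        loader = loaders[expert.source_model]
        name = f"model.layers.{layer_idx}.mlp.up_proj.weight"  # example Llama/Mistral tensor name
        tensor = noise_and_scale(loader.get_tensor(name), expert)  # routed expert: is_residual stays False
        if out_dtype is not None:
            tensor = tensor.to(dtype=out_dtype)
        # Illustrative output tensor name; the real mapping is architecture-specific.
        writer.save_tensor(
            f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", tensor
        )
# A shared expert would go through the same path with is_residual=True so that
# residual_scale is applied.
writer.finalize()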