diff --git a/README.md b/README.md
index de15f239..fed99678 100644
--- a/README.md
+++ b/README.md
@@ -10,8 +10,9 @@ Features:
 - Lazy loading of tensors for low memory use
 - Interpolated gradients for parameter values (inspired by Gryphe's [BlockMerge_Gradient](https://github.com/Gryphe/BlockMerge_Gradient) script)
 - Piecewise assembly of language models from layers ("Frankenmerging")
+- [Mixture of Experts merging](#mixture-of-experts-merging)
 
-🔊 Call to Evolve - to solve evolutionary merge methods as a community - please see https://github.com/arcee-ai/mergekit/issues/207.
+🔊 Call to Evolve - to solve evolutionary merge methods as a community - please see https://github.com/arcee-ai/mergekit/issues/207.
 
 🌐 GUI Launch Alert 🤗 - We are excited to announce the launch of a graphical user interface for mergekit in Hugging Face Spaces! This GUI simplifies the merging process, making it more accessible to a broader audience. Check it out and contribute at [Hugging Face Spaces - mergekit-community](https://huggingface.co/mergekit-community).
 
@@ -179,13 +180,17 @@ Parameters:
 
 Mergekit allows extracting PEFT-compatible low-rank approximations of finetuned models.
 
-### Usage:
+### Usage
 
 ```sh
 mergekit-extract-lora finetuned_model_id_or_path base_model_id_or_path output_path [--no-lazy-unpickle] --rank=desired_rank
 ```
 
-# Citation
+## Mixture of Experts merging
+
+The `mergekit-moe` script supports merging multiple dense models into a mixture of experts, either for direct use or for further training. For more details see the [`mergekit-moe` documentation](docs/moe.md).
+
+## Citation
 
 We now have a [paper](https://arxiv.org/abs/2403.13257) you can cite for the MergeKit library:
 
diff --git a/docs/moe.md b/docs/moe.md
index 890be84f..e3c9de31 100644
--- a/docs/moe.md
+++ b/docs/moe.md
@@ -1,6 +1,12 @@
 # mergekit-moe
 
-`mergekit-moe` is a script for combining Mistral or Llama models of the same size into Mixtral Mixture of Experts models. The script will combine the self-attention and layer normalization parameters from a "base" model with the MLP parameters from a set of "expert" models. `mergekit-moe` uses its own YML configuration syntax, which looks like so:
+`mergekit-moe` is a script for combining Mistral or Llama models of the same size into Mixtral Mixture of Experts models. The script will combine the self-attention and layer normalization parameters from a "base" model with the MLP parameters from a set of "expert" models.
+
+If using the `hidden` or `cheap_embed` gate mode, the output model will be usable without any further training. If you are initializing a model to do further training on, such as for sparse upcycling, then use the `random` gate mode to get a model ready for training.
+
+## Configuration
+
+`mergekit-moe` uses its own YML configuration syntax, which looks like so:
 
 ```yml
 base_model: path/to/self_attn_donor
@@ -21,18 +27,89 @@ experts:
 
 The script takes two arguments, an input config and an output path: `mergekit-moe ./config.yml ./my-clowncar-moe-12x180B`
 
-## Gate Modes
+Currently the script can output models that use the Mixtral, Deepseek MoE, or Qwen MoE architectures. Some output architectures support a shared expert which will be activated for all tokens, which can be configured like this:
+
+```yml
+base_model: path/to/self_attn_donor
+gate_mode: hidden # one of "hidden", "cheap_embed", or "random"
+dtype: bfloat16 # output dtype (float32, float16, or bfloat16)
+experts:
+  ...
+shared_experts:
+  - source_model: model_name
+    positive_prompts: # required by Qwen MoE for "hidden" gate mode, otherwise not allowed
+      - "blah blah"
+    # (optional, but recommended:)
+    residual_scale: 0.1 # downweight output from shared expert to prevent overcooking the model
+```
+
+Currently only up to one shared expert is supported.
+
+An appropriate architecture will be inferred based on the input models and presence or absence of shared experts in your configuration. Alternatively, you can explicitly specify an output architecture by setting the `architecture:` field in your config. For example:
+
+```yml
+base_model: path/to/self_attn_donor
+architecture: qwen
+# ... and so on
+```
+
+### Gate Modes
 
 There are three methods for populating the MoE gates implemented.
 
-### "hidden"
+#### "hidden"
 
 Uses the hidden state representations of the positive/negative prompts for MoE gate parameters. Best quality and most effective option; the default. Requires evaluating each prompt using the base model so you might not be able to use this on constrained hardware (depending on the model). You can use `--load-in-8bit` or `--load-in-4bit` to reduce VRAM usage.
 
-### "cheap_embed"
+#### "cheap_embed"
 
 Uses only the raw token embedding of the prompts, using the same gate parameters for every layer. Distinctly less effective than "hidden". Can be run on much, much lower end hardware.
 
-### "random"
+#### "random"
 
 Randomly initializes the MoE gates. Good for if you are going to fine tune the model afterwards, or maybe if you want something a little unhinged? I won't judge.
+
+## Example Configurations
+
+Sparse upcycling of smol_llama into a 8x220M MoE:
+
+```yml
+base_model: BEE-spoke-data/smol_llama-220M-GQA
+gate_mode: random
+dtype: bfloat16
+experts:
+  - source_model: BEE-spoke-data/smol_llama-220M-GQA
+  - source_model: BEE-spoke-data/smol_llama-220M-GQA
+  - source_model: BEE-spoke-data/smol_llama-220M-GQA
+  - source_model: BEE-spoke-data/smol_llama-220M-GQA
+  - source_model: BEE-spoke-data/smol_llama-220M-GQA
+  - source_model: BEE-spoke-data/smol_llama-220M-GQA
+  - source_model: BEE-spoke-data/smol_llama-220M-GQA
+  - source_model: BEE-spoke-data/smol_llama-220M-GQA
+# and then train the sucker!
+```
+
+Shove some Mistral models in a clown car:
+
+```yml
+base_model: NousResearch/Hermes-2-Pro-Mistral-7B
+gate_mode: hidden
+dtype: bfloat16
+experts:
+  - source_model: NousResearch/Hermes-2-Pro-Mistral-7B
+    positive_prompts:
+      - "<|im_start|>user\nHello, who are you?<|im_end|>"
+      - "<|im_start|>user\nI need help with"
+  - source_model: BioMistral/BioMistral-7B-DARE
+    positive_prompts:
+      - "As a doctor of medicine,"
+  - source_model: PocketDoc/Dans-AdventurousWinds-7b
+    positive_prompts:
+      - "[Genres: Science Fiction]\n[Tags: humor, old school, sci fi]"
+      - "> get ye flask"
+      - "[Mode: Interactive Storyteller]"
+  - source_model: VAGOsolutions/SauerkrautLM-7b-HerO
+    positive_prompts:
+      - "<|im_start|>user\nWie geht es dir?<|im_end|>"
+      - "Das ist ein Satz auf Deutsch."
+``` diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 0fb8c15d..16acbbab 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -350,6 +350,7 @@ def _load_all_architectures() -> ( JSON_ARCHITECTURES, NAME_TO_ARCH = _load_all_architectures() MISTRAL_INFO = _load_json_arch("mistral.json") +QWEN2_INFO = _load_json_arch("qwen2.json") def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: diff --git a/mergekit/common.py b/mergekit/common.py index 1837e5d8..f334c7cd 100644 --- a/mergekit/common.py +++ b/mergekit/common.py @@ -184,7 +184,10 @@ def __str__(self) -> str: return str(self.model) -def dtype_from_name(name: Optional[str]) -> torch.dtype: +def dtype_from_name(name: Optional[str]) -> Optional[torch.dtype]: + if not name: + return None + if name.startswith("torch."): name = name[len("torch.") :] diff --git a/mergekit/moe/__init__.py b/mergekit/moe/__init__.py new file mode 100644 index 00000000..bc1cf067 --- /dev/null +++ b/mergekit/moe/__init__.py @@ -0,0 +1,19 @@ +from typing import List + +from mergekit.moe.arch import MoEOutputArchitecture +from mergekit.moe.deepseek import DeepseekMoE +from mergekit.moe.mixtral import MixtralMoE + +ALL_OUTPUT_ARCHITECTURES: List[MoEOutputArchitecture] = [MixtralMoE(), DeepseekMoE()] + +try: + from mergekit.moe.qwen import QwenMoE +except ImportError: + pass +else: + ALL_OUTPUT_ARCHITECTURES.append(QwenMoE()) + +__all__ = [ + "ALL_OUTPUT_ARCHITECTURES", + "MoEOutputArchitecture", +] diff --git a/mergekit/moe/arch.py b/mergekit/moe/arch.py new file mode 100644 index 00000000..66a54d61 --- /dev/null +++ b/mergekit/moe/arch.py @@ -0,0 +1,53 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch + +from mergekit.moe.config import MoEMergeConfig +from mergekit.options import MergeOptions + + +class MoEOutputArchitecture(ABC): + @abstractmethod + def name(self) -> str: + """Return a human-readable name for the architecture.""" + pass + + @abstractmethod + def supports_config( + self, + config: MoEMergeConfig, + explain: bool = False, + trust_remote_code: bool = False, + ) -> bool: + """Return whether this architecture supports the given config. + + If `explain` is True, log an explanation of why the config is not supported.""" + pass + + @abstractmethod + def write_model( + self, + out_path: str, + config: MoEMergeConfig, + merge_options: MergeOptions, + router_weights: List[torch.Tensor], + shared_router_weights: Optional[List[torch.Tensor]] = None, + ): + """Write the config and tensors for the output MoE to the given path.""" + pass diff --git a/mergekit/moe/common.py b/mergekit/moe/common.py new file mode 100644 index 00000000..ce4525ed --- /dev/null +++ b/mergekit/moe/common.py @@ -0,0 +1,75 @@ +# Copyright (C) 2024 Charles O. 
Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +from typing import Dict, Optional + +import torch +import tqdm +import transformers + +from mergekit.common import ModelReference, dtype_from_name +from mergekit.io import LazyTensorLoader, TensorWriter +from mergekit.merge import MergeOptions +from mergekit.moe.config import Expert, MoEMergeConfig + + +def initialize_io( + config: MoEMergeConfig, + out_path: str, + merge_options: MergeOptions, +) -> tuple[Dict[ModelReference, LazyTensorLoader], LazyTensorLoader, TensorWriter]: + base_model = config.base_model + loaders: Dict[ModelReference, LazyTensorLoader] = {} + for model in tqdm.tqdm( + [base_model] + [e.source_model for e in config.experts], desc="Warm up loaders" + ): + loaders[model] = model.lazy_loader( + cache_dir=merge_options.transformers_cache, + lazy_unpickle=merge_options.lazy_unpickle, + ) + + base_loader = loaders.get(base_model) + writer = TensorWriter( + out_path=out_path, + max_shard_size=merge_options.out_shard_size, + safe_serialization=merge_options.safe_serialization, + ) + + return loaders, base_loader, writer + + +def select_dtype( + config: MoEMergeConfig, base_cfg: transformers.PretrainedConfig +) -> Optional[torch.dtype]: + out_dtype = None + if config.dtype: + out_dtype = dtype_from_name(config.dtype) + + if out_dtype is None and base_cfg.torch_dtype: + out_dtype = base_cfg.torch_dtype + if isinstance(out_dtype, str): + out_dtype = dtype_from_name(out_dtype) + return out_dtype + + +def noise_and_scale( + tensor: torch.Tensor, expert: Expert, is_residual: bool = False +) -> torch.Tensor: + if expert.noise_scale is not None: + noise = torch.randn_like(tensor) * expert.noise_scale + tensor = tensor + noise + if is_residual and expert.residual_scale is not None: + tensor = tensor * expert.residual_scale + return tensor diff --git a/mergekit/moe/config.py b/mergekit/moe/config.py new file mode 100644 index 00000000..2e3f027a --- /dev/null +++ b/mergekit/moe/config.py @@ -0,0 +1,95 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+ +import logging +from typing import List, Optional + +from pydantic import BaseModel + +from mergekit.common import ModelReference + + +class Expert(BaseModel): + """ + Defines a model to be used as a set of layerwise experts in a MoE model. + """ + + source_model: ModelReference + + positive_prompts: Optional[List[str]] = None + negative_prompts: Optional[List[str]] = None + noise_scale: Optional[float] = None + residual_scale: Optional[float] = None + + +class MoEMergeConfig(BaseModel): + """ + Configuration for merging a set of "expert" models into a MoE model. + """ + + base_model: ModelReference + experts: List[Expert] + gate_mode: str = "hidden" # possible values: "hidden", "cheap_embed", "random" + # "hidden" uses hidden state vectors for the given prompts for each layer + # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer + # "random" is random + dtype: Optional[str] = None + experts_per_token: int = 2 + shared_experts: Optional[List[Expert]] = None + architecture: Optional[str] = None + + +def is_bad_config(config: MoEMergeConfig, allow_all_same: bool = False) -> bool: + if config.experts_per_token < 1: + logging.error("Experts per token must be >= 1") + return True + + if len(config.experts) < config.experts_per_token: + logging.error("Must include at least as many experts as experts_per_token.") + return True + + if config.gate_mode == "random": + return False # eh we're good + + for expert_idx, expert in enumerate(config.experts): + if not expert.positive_prompts: + logging.error(f"Expert {expert_idx} has no positive prompts.") + return True + + def prompt_tup(e: Expert): + return (tuple(e.positive_prompts), tuple(e.negative_prompts or [])) + + # let's just nip this trend in the bud + p_first = prompt_tup(config.experts[0]) + if all(prompt_tup(e) == p_first for e in config.experts[1:]): + logging.error( + "Your positive and negative prompts are identical for all experts. This will not produce a functioning MoE." + ) + logging.error( + "For each expert, `positive_prompts` must contain one or more example prompt reflecting what should be routed to that expert." + ) + return True + + if not allow_all_same: + if all( + e.source_model == config.experts[0].source_model for e in config.experts[1:] + ): + logging.error( + "All of your expert models are the same. This will produce " + "a model that uses more resources but gives the exact same output. " + "If you plan to train the model after merging, proceed with the " + "--i-understand-this-is-not-useful-without-training flag." + ) + return True diff --git a/mergekit/moe/deepseek.py b/mergekit/moe/deepseek.py new file mode 100644 index 00000000..1f7226fb --- /dev/null +++ b/mergekit/moe/deepseek.py @@ -0,0 +1,196 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+ +import json +import logging +import os +from typing import Dict, List, Optional + +import torch +import tqdm +import transformers + +from mergekit.architecture import get_architecture_info +from mergekit.moe.arch import MoEOutputArchitecture +from mergekit.moe.common import initialize_io, noise_and_scale, select_dtype +from mergekit.moe.config import MoEMergeConfig +from mergekit.options import MergeOptions + + +class DeepseekMoE(MoEOutputArchitecture): + def name(self) -> str: + return "DeepSeek MoE" + + def supports_config( + self, + config: MoEMergeConfig, + explain: bool = False, + trust_remote_code: bool = False, + ) -> bool: + if config.shared_experts: + if len(config.shared_experts) > 1: + if explain: + logging.warning( + "DeepSeek MoE merge does not support more than one shared expert" + ) + return False + + if ( + config.shared_experts[0].positive_prompts + or config.shared_experts[0].negative_prompts + ): + if explain: + logging.warning( + "DeepSeek MoE merge does not support gating shared experts" + ) + return False + + model_types = [] + for model_ref in ( + [config.base_model] + + [e.source_model for e in config.experts] + + [e.source_model for e in (config.shared_experts or [])] + ): + model_cfg = model_ref.config(trust_remote_code=trust_remote_code) + model_types.append(model_cfg.model_type) + + if len(set(model_types)) != 1: + if explain: + logging.warning( + "Deepseek MoE requires all input models to have the same architecture" + ) + return False + if model_types[0] not in ("llama", "mistral"): + if explain: + logging.warning( + "Deepseek MoE requires all input models to be Llama or Mistral models" + ) + return False + return True + + def _generate_config( + self, + base_config: transformers.PretrainedConfig, + num_experts: int, + shared_experts: Optional[int] = None, + experts_per_token: Optional[int] = None, + ) -> Dict: + if shared_experts and shared_experts > 1: + raise NotImplementedError( + "Shared experts must be 0 or 1 for DeepSeek output" + ) + + res = base_config.to_dict() + res["architectures"] = ["DeepseekForCausalLM"] + res["model_type"] = "deepseek" + res["n_routed_experts"] = num_experts + res["n_shared_experts"] = shared_experts or None + res["num_experts_per_tok"] = experts_per_token or (1 if shared_experts else 2) + res["first_k_dense_replace"] = 0 + res["moe_layer_freq"] = 1 + res["scoring_func"] = "softmax" + res["norm_topk_prob"] = True + res["moe_intermediate_size"] = res["intermediate_size"] + res["auto_map"] = { + "AutoConfig": "deepseek-ai/deepseek-moe-16b-base--configuration_deepseek.DeepseekConfig", + "AutoModel": "deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekModel", + "AutoModelForCausalLM": "deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekForCausalLM", + } + return res + + def write_model( + self, + out_path: str, + config: MoEMergeConfig, + merge_options: MergeOptions, + router_weights: List[torch.Tensor], + shared_router_weights: Optional[List[torch.Tensor]] = None, + ): + base_model = config.base_model + base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code) + + out_dtype = select_dtype(config, base_cfg) + out_cfg = self._generate_config( + base_cfg, + len(config.experts), + len(config.shared_experts or []), + config.experts_per_token, + ) + if out_dtype is not None: + out_cfg["torch_dtype"] = str(out_dtype).removeprefix("torch.") + with open(os.path.join(out_path, "config.json"), "w", encoding="utf-8") as f: + json.dump(out_cfg, f, indent=4) + + shared_def = config.shared_experts[0] 
if config.shared_experts else None + + loaders, base_loader, writer = initialize_io(config, out_path, merge_options) + shared_loader = loaders.get(shared_def.source_model) if shared_def else None + for weight_info in tqdm.tqdm( + get_architecture_info(base_cfg).all_weights(base_cfg), + desc="Weights", + ): + tensor_name = weight_info.name + if ".mlp." in tensor_name: + for expert_idx, expert in enumerate(config.experts): + expert_name = tensor_name.replace( + ".mlp.", f".mlp.experts.{expert_idx}." + ) + expert_loader = loaders.get(expert.source_model) + tensor = expert_loader.get_tensor( + weight_info.name, aliases=weight_info.aliases + ) + tensor = noise_and_scale( + tensor, expert, is_residual="down_proj" in tensor_name + ) + writer.save_tensor( + expert_name, + tensor.to(dtype=out_dtype), + clone=merge_options.clone_tensors, + ) + + if shared_def is not None: + shared_tensor = shared_loader.get_tensor( + weight_info.name, aliases=weight_info.aliases + ) + shared_tensor = noise_and_scale( + shared_tensor, + shared_def, + is_residual="down_proj" in tensor_name, + ) + writer.save_tensor( + tensor_name.replace(".mlp.", ".mlp.shared_experts."), + shared_tensor.to(dtype=out_dtype), + clone=merge_options.clone_tensors, + ) + else: + tensor = base_loader.get_tensor( + tensor_name, aliases=weight_info.aliases + ) + writer.save_tensor( + tensor_name, + tensor.to(dtype=out_dtype), + clone=merge_options.clone_tensors, + ) + + for layer_idx, weight in enumerate( + tqdm.tqdm(router_weights, desc="Router weights") + ): + writer.save_tensor( + f"model.layers.{layer_idx}.mlp.gate.weight", + weight.to(dtype=out_dtype).contiguous(), + clone=merge_options.clone_tensors, + ) + + writer.finalize() diff --git a/mergekit/moe/mixtral.py b/mergekit/moe/mixtral.py new file mode 100644 index 00000000..538cb701 --- /dev/null +++ b/mergekit/moe/mixtral.py @@ -0,0 +1,178 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+ +import logging +from typing import List, Optional + +import torch +import tqdm +import transformers + +from mergekit.architecture import MISTRAL_INFO, WeightInfo +from mergekit.moe.arch import MoEOutputArchitecture +from mergekit.moe.common import initialize_io, noise_and_scale, select_dtype +from mergekit.moe.config import MoEMergeConfig +from mergekit.options import MergeOptions + + +class MixtralMoE(MoEOutputArchitecture): + def name(self) -> str: + return "Mixtral" + + def supports_config( + self, + config: MoEMergeConfig, + explain: bool = False, + trust_remote_code: bool = False, + ) -> bool: + if config.shared_experts: + if explain: + logging.warning("Mixtral does not support shared experts") + return False + + model_types = [] + for model_ref in [config.base_model] + [e.source_model for e in config.experts]: + model_cfg = model_ref.config(trust_remote_code=trust_remote_code) + model_types.append(model_cfg.model_type) + + if len(set(model_types)) != 1: + if explain: + logging.warning( + "Mixtral requires all input models to have the same architecture" + ) + return False + if model_types[0] not in ("llama", "mistral"): + if explain: + logging.warning( + "Mixtral requires all input models to be Llama or Mistral models" + ) + return False + return True + + def _generate_config( + self, + base_config: transformers.PretrainedConfig, + num_experts: int, + shared_experts: Optional[int] = None, + experts_per_token: Optional[int] = None, + ) -> transformers.PretrainedConfig: + if shared_experts: + raise NotImplementedError("Shared experts not supported for Mixtral output") + + if not isinstance(base_config, transformers.MistralConfig): + base_cfg_mistral = transformers.MistralConfig(**base_config.to_dict()) + base_cfg_mistral.sliding_window = None + base_cfg_mistral.max_position_embeddings = ( + base_config.max_position_embeddings + ) + base_config = base_cfg_mistral + + out_cfg = transformers.MixtralConfig(**base_config.to_dict()) + out_cfg.architectures = ["MixtralForCausalLM"] + out_cfg.num_local_experts = num_experts + out_cfg.num_experts_per_tok = experts_per_token or 2 + out_cfg.sliding_window = None + + if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0: + logging.warning( + f"Your model has {out_cfg.num_local_experts} experts, which is " + "not a power of two. The model will not be usable in llama.cpp." + ) + return out_cfg + + def _remap_weight_name(self, weight: WeightInfo) -> str: + if ".mlp." 
not in weight.name: + # Everything but MLP is identical to base Mistral + return weight.name + + res = weight.name + for needle, replacement in [ + (".mlp.gate_proj", ".block_sparse_moe.experts.{expert_idx}.w1"), + (".mlp.down_proj", ".block_sparse_moe.experts.{expert_idx}.w2"), + (".mlp.up_proj", ".block_sparse_moe.experts.{expert_idx}.w3"), + ]: + res = res.replace(needle, replacement) + return res + + def _router_weight_name(self, layer_idx: int) -> str: + return f"model.layers.{layer_idx}.block_sparse_moe.gate.weight" + + def write_model( + self, + out_path: str, + config: MoEMergeConfig, + merge_options: MergeOptions, + router_weights: List[torch.Tensor], + shared_router_weights: Optional[List[torch.Tensor]] = None, + ): + base_model = config.base_model + base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code) + + assert len(router_weights) == base_cfg.num_hidden_layers, ( + f"Expected {base_cfg.num_hidden_layers} router weights, " + f"got {len(router_weights)}" + ) + + out_dtype = select_dtype(config, base_cfg) + out_cfg = self._generate_config( + base_cfg, + len(config.experts), + len(config.shared_experts or []), + config.experts_per_token, + ) + out_cfg.torch_dtype = out_dtype + out_cfg.save_pretrained(out_path) + + loaders, base_loader, writer = initialize_io(config, out_path, merge_options) + for weight_info in tqdm.tqdm( + MISTRAL_INFO.all_weights(base_cfg), + desc="Weights", + ): + tensor_name = self._remap_weight_name(weight_info) + if "{expert_idx}" in tensor_name: + for expert_index, expert in enumerate(config.experts): + expert_name = tensor_name.replace("{expert_idx}", str(expert_index)) + expert_loader = loaders.get(expert.source_model) + tensor = expert_loader.get_tensor( + weight_info.name, aliases=weight_info.aliases + ) + tensor = noise_and_scale( + tensor, expert, is_residual="down_proj" in tensor_name + ) + writer.save_tensor( + expert_name, + tensor.to(dtype=out_dtype), + clone=merge_options.clone_tensors, + ) + else: + tensor = base_loader.get_tensor( + tensor_name, aliases=weight_info.aliases + ) + writer.save_tensor( + tensor_name, + tensor.to(dtype=out_dtype), + clone=merge_options.clone_tensors, + ) + + for layer_idx, weight in enumerate( + tqdm.tqdm(router_weights, desc="Router weights") + ): + writer.save_tensor( + self._router_weight_name(layer_idx), + weight.to(dtype=out_dtype).contiguous(), + clone=merge_options.clone_tensors, + ) + + writer.finalize() diff --git a/mergekit/moe/qwen.py b/mergekit/moe/qwen.py new file mode 100644 index 00000000..76935a46 --- /dev/null +++ b/mergekit/moe/qwen.py @@ -0,0 +1,206 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+ +import logging +from typing import List, Optional + +import torch +import tqdm +import transformers + +# explicitly import the config class so that we can catch errors upstream +# if the transformers version installed is too old +from transformers.models.qwen2_moe import Qwen2MoeConfig + +from mergekit.architecture import QWEN2_INFO +from mergekit.moe.arch import MoEOutputArchitecture +from mergekit.moe.common import initialize_io, noise_and_scale, select_dtype +from mergekit.moe.config import MoEMergeConfig +from mergekit.options import MergeOptions + + +class QwenMoE(MoEOutputArchitecture): + def name(self) -> str: + return "Qwen MoE" + + def supports_config( + self, + config: MoEMergeConfig, + explain: bool = False, + trust_remote_code: bool = False, + ) -> bool: + if len(config.shared_experts or []) != 1: + if explain: + logging.warning("Qwen MoE merge requires exactly one shared expert") + return False + + if ( + config.gate_mode != "random" + and not config.shared_experts[0].positive_prompts + ): + if explain: + logging.warning("Qwen MoE requires the shared expert to have prompts") + return False + + model_types = [] + for model_ref in ( + [config.base_model] + + [e.source_model for e in config.experts] + + [e.source_model for e in (config.shared_experts or [])] + ): + model_cfg = model_ref.config(trust_remote_code=trust_remote_code) + model_types.append(model_cfg.model_type) + + if len(set(model_types)) != 1: + if explain: + logging.warning( + "Qwen MoE requires all input models to have the same architecture" + ) + return False + if model_types[0] not in ("llama", "mistral", "qwen"): + if explain: + logging.warning( + "Qwen MoE requires all input models to be Llama or Mistral models" + ) + return False + return True + + def _generate_config( + self, + base_config: transformers.PretrainedConfig, + num_experts: int, + experts_per_token: Optional[int] = None, + ) -> Qwen2MoeConfig: + out_cfg = Qwen2MoeConfig(**base_config.to_dict()) + out_cfg.architectures = ["Qwen2MoeForCausalLM"] + out_cfg.num_experts = num_experts + out_cfg.num_experts_per_tok = experts_per_token or 2 + out_cfg.decoder_sparse_step = 1 + out_cfg.norm_topk_prob = True + out_cfg.sliding_window = None + out_cfg.use_sliding_window = False + out_cfg.shared_expert_intermediate_size = out_cfg.intermediate_size + out_cfg.moe_intermediate_size = out_cfg.intermediate_size + + if (out_cfg.num_experts & (out_cfg.num_experts - 1)) != 0: + logging.warning( + f"Your model has {out_cfg.num_experts} experts, which is " + "not a power of two. The model will not be usable in llama.cpp." + ) + return out_cfg + + def write_model( + self, + out_path: str, + config: MoEMergeConfig, + merge_options: MergeOptions, + router_weights: List[torch.Tensor], + shared_router_weights: Optional[List[torch.Tensor]] = None, + ): + base_model = config.base_model + base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code) + + out_dtype = select_dtype(config, base_cfg) + out_cfg = self._generate_config( + base_cfg, + len(config.experts), + config.experts_per_token, + ) + if out_dtype is not None: + out_cfg.torch_dtype = out_dtype + out_cfg.save_pretrained(out_path) + + shared_def = config.shared_experts[0] + + loaders, base_loader, writer = initialize_io(config, out_path, merge_options) + shared_loader = loaders.get(shared_def.source_model) if shared_def else None + for weight_info in tqdm.tqdm( + QWEN2_INFO.all_weights(base_cfg), + desc="Weights", + ): + tensor_name = weight_info.name + if ".mlp." 
in tensor_name: + for expert_idx, expert in enumerate(config.experts): + expert_name = tensor_name.replace( + ".mlp.", f".mlp.experts.{expert_idx}." + ) + expert_loader = loaders.get(expert.source_model) + tensor = expert_loader.get_tensor( + weight_info.name, aliases=weight_info.aliases + ) + tensor = noise_and_scale( + tensor, expert, is_residual="down_proj" in tensor_name + ) + writer.save_tensor( + expert_name, + tensor.to(dtype=out_dtype), + clone=merge_options.clone_tensors, + ) + + shared_tensor = shared_loader.get_tensor( + weight_info.name, aliases=weight_info.aliases + ) + shared_tensor = noise_and_scale( + shared_tensor, + shared_def, + is_residual="down_proj" in tensor_name, + ) + writer.save_tensor( + tensor_name.replace(".mlp.", ".mlp.shared_expert."), + shared_tensor.to(dtype=out_dtype), + clone=merge_options.clone_tensors, + ) + else: + try: + tensor = base_loader.get_tensor( + tensor_name, aliases=weight_info.aliases + ) + except KeyError: + if tensor_name.endswith("_proj.bias"): + # qwen 2 moe wants attention bias, give it zeros + head_dim = out_cfg.hidden_size // out_cfg.num_attention_heads + num_heads = ( + out_cfg.num_key_value_heads + if ( + tensor_name.endswith("k_proj.bias") + or tensor_name.endswith("v_proj.bias") + ) + else out_cfg.num_attention_heads + ) + tensor = torch.zeros(num_heads * head_dim, dtype=out_dtype) + else: + raise + + writer.save_tensor( + tensor_name, + tensor.to(dtype=out_dtype), + clone=merge_options.clone_tensors, + ) + + for layer_idx, weight in enumerate( + tqdm.tqdm(router_weights, desc="Router weights") + ): + writer.save_tensor( + f"model.layers.{layer_idx}.mlp.gate.weight", + weight.to(dtype=out_dtype).contiguous(), + clone=merge_options.clone_tensors, + ) + writer.save_tensor( + f"model.layers.{layer_idx}.mlp.shared_expert_gate.weight", + shared_router_weights[layer_idx].to(dtype=out_dtype).contiguous(), + clone=merge_options.clone_tensors, + ) + + writer.finalize() diff --git a/mergekit/moe/router.py b/mergekit/moe/router.py new file mode 100644 index 00000000..3d0d0ec5 --- /dev/null +++ b/mergekit/moe/router.py @@ -0,0 +1,177 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+ +import logging +from typing import Dict, List, Union + +import torch +import tqdm +import transformers +from transformers import AutoModelForCausalLM, LlamaForCausalLM, MistralForCausalLM +from transformers.modeling_outputs import CausalLMOutputWithPast + +from mergekit.common import ModelReference +from mergekit.moe.config import Expert + + +def get_hidden_states( + model: Union[MistralForCausalLM, LlamaForCausalLM], + tokenized: transformers.BatchEncoding, + average: bool = True, +) -> List[torch.Tensor]: + with torch.no_grad(): + output: CausalLMOutputWithPast = model( + **tokenized.to(model.device), output_hidden_states=True, return_dict=True + ) + hidden_states = torch.stack( + output.hidden_states[:-1] + ) # (num_layers, batch_size, seq_len, hidden_size) + if average: + # use average over sequence + hidden_states = hidden_states.sum(dim=2) / hidden_states.shape[2] + else: + # take last value + hidden_states = hidden_states[:, :, -1, :] + return hidden_states.sum(dim=1) / hidden_states.shape[1] + + +def get_cheap_embedding( + embed: torch.Tensor, + tokenized: Dict[str, torch.Tensor], + num_layers: int, + vocab_size: int, +) -> torch.Tensor: + onehot = torch.nn.functional.one_hot( + tokenized["input_ids"], num_classes=vocab_size + ) # (batch_size, seq_len, 32000) + h = onehot.float() @ embed.float() # (batch_size, seq_len, hidden_size) + embedded = ( + (h * tokenized["attention_mask"].unsqueeze(-1)) + .sum(dim=1) + .sum(dim=0, keepdim=True) + ) # (1, hidden_size) + res = embedded / embedded.norm(dim=-1, keepdim=True).clamp( + min=1e-8 + ) # (1, hidden_size) + return res.repeat(num_layers, 1) + + +def tokenize_prompts( + prompts: List[str], tokenizer: transformers.PreTrainedTokenizerBase +): + return tokenizer( + [(tokenizer.bos_token or "") + p for p in prompts], + return_tensors="pt", + padding=True, + add_special_tokens=False, + ) + + +def get_gate_params( + model_ref: ModelReference, + tokenizer: transformers.PreTrainedTokenizerBase, + experts: List[Expert], + mode: str = "hidden", + load_in_4bit: bool = False, + load_in_8bit: bool = False, + lazy_unpickle: bool = False, + trust_remote_code: bool = False, + device: str = "auto", +): + gate_vecs = [] + _do_it = None + + model_cfg = model_ref.config(trust_remote_code=trust_remote_code) + + if mode == "random": + return torch.randn( + (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size) + ) + elif mode == "cheap_embed": + embed = model_ref.lazy_loader(lazy_unpickle=lazy_unpickle).get_tensor( + "model.embed_tokens.weight" + ) + + def _do_it(tokenized): + return get_cheap_embedding( + embed, + tokenized, + num_layers=model_cfg.num_hidden_layers, + vocab_size=model_cfg.vocab_size, + ) + + elif mode in ("hidden", "hidden_avg", "hidden_last"): + model = AutoModelForCausalLM.from_pretrained( + model_ref.model.path, + revision=model_ref.model.revision, + torch_dtype=torch.bfloat16, + device_map=device, + low_cpu_mem_usage=True, + load_in_4bit=load_in_4bit, + load_in_8bit=load_in_8bit, + trust_remote_code=trust_remote_code, + ) + + def _do_it(tokenized): + return get_hidden_states( + model, tokenized=tokenized, average=mode == "hidden_avg" + ) + + gate_vecs = [] + for expert in tqdm.tqdm(experts, desc="expert prompts"): + hidden_states = _do_it(tokenize_prompts(expert.positive_prompts, tokenizer)) + if expert.negative_prompts: + hidden_states -= _do_it( + tokenize_prompts(expert.negative_prompts, tokenizer) + ) + + hidden_states /= hidden_states.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8) + gate_vecs.append(hidden_states) 
+ gate_vecs = torch.stack(gate_vecs, dim=0) # (num_expert, num_layer, hidden_size) + return gate_vecs.permute(1, 0, 2) + + +def warn_degenerate_gates(gate_vecs: torch.Tensor, threshold: float = 5.0): + degen_indices = [] + num_layers, _num_experts, _hidden_size = gate_vecs.shape + for idx in range(num_layers): + c = torch.linalg.cond(gate_vecs[idx, :, :].float()) + if c > threshold: + degen_indices.append(idx) + + if degen_indices: + if len(degen_indices) == 1: + layer_str = f"layer {degen_indices[0]}" + verb = "has" + elif len(degen_indices) == 2: + layer_str = f"layers {' and '.join(map(str, degen_indices))}" + verb = "have" + elif len(degen_indices) >= num_layers: + layer_str = "ALL layers" + verb = "have" + else: + layer_str = ( + "layers " + + ", ".join(map(str, degen_indices[:-1])) + + ", and " + + str(degen_indices[-1]) + ) + verb = "have" + + logging.warning( + f"{layer_str} {verb} degenerate routing parameters " + "- your prompts may be too similar." + ) + logging.warning("One or more experts will be underutilized in your model.") diff --git a/mergekit/scripts/mixtral_moe.py b/mergekit/scripts/mixtral_moe.py deleted file mode 100644 index 5cdb1d63..00000000 --- a/mergekit/scripts/mixtral_moe.py +++ /dev/null @@ -1,482 +0,0 @@ -# Copyright (C) 2024 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. - -import logging -import os -import sys -from typing import Dict, List, Optional, Union - -import click -import torch -import tqdm -import transformers -import yaml -from pydantic import BaseModel -from transformers import ( - AutoModelForCausalLM, - LlamaForCausalLM, - MistralConfig, - MistralForCausalLM, - MixtralConfig, -) -from transformers.modeling_outputs import CausalLMOutputWithPast - -import mergekit.architecture -from mergekit.common import ModelReference, dtype_from_name -from mergekit.io import LazyTensorLoader, TensorWriter -from mergekit.merge import MergeOptions -from mergekit.options import add_merge_options - -# Create a Mixtral MoE from a set of equally-sized Mistral (or Llama) models. -# Takes the path to a yml config and an output path. -# Config schema is the two classes below. 
- - -class Expert(BaseModel): - source_model: str - - positive_prompts: List[str] - negative_prompts: Optional[List[str]] = None - noise_scale: Optional[float] = None - - @property - def model_ref(self): - return ModelReference.parse(self.source_model) - - -class MistralMOEConfig(BaseModel): - base_model: str - experts: List[Expert] - gate_mode: str = "hidden" # possible values: "hidden", "cheap_embed", "random" - # "hidden" uses hidden state vectors for the given prompts for each layer - # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer - # "random" is random - dtype: Optional[str] = None - experts_per_token: int = 2 - - -def get_hidden_states( - model: Union[MistralForCausalLM, LlamaForCausalLM], - tokenized: transformers.BatchEncoding, - average: bool = True, -) -> List[torch.Tensor]: - with torch.no_grad(): - output: CausalLMOutputWithPast = model( - **tokenized.to(model.device), output_hidden_states=True, return_dict=True - ) - hidden_states = torch.stack( - output.hidden_states[:-1] - ) # (num_layers, batch_size, seq_len, hidden_size) - if average: - # use average over sequence - hidden_states = hidden_states.sum(dim=2) / hidden_states.shape[2] - else: - # take last value - hidden_states = hidden_states[:, :, -1, :] - return hidden_states.sum(dim=1) / hidden_states.shape[1] - - -def get_cheap_embedding( - embed: torch.Tensor, - tokenized: Dict[str, torch.Tensor], - num_layers: int, - vocab_size: int, -) -> torch.Tensor: - onehot = torch.nn.functional.one_hot( - tokenized["input_ids"], num_classes=vocab_size - ) # (batch_size, seq_len, 32000) - h = onehot.float() @ embed.float() # (batch_size, seq_len, hidden_size) - embedded = ( - (h * tokenized["attention_mask"].unsqueeze(-1)) - .sum(dim=1) - .sum(dim=0, keepdim=True) - ) # (1, hidden_size) - res = embedded / embedded.norm(dim=-1, keepdim=True).clamp( - min=1e-8 - ) # (1, hidden_size) - return res.repeat(num_layers, 1) - - -def tokenize_prompts( - prompts: List[str], tokenizer: transformers.PreTrainedTokenizerBase -): - return tokenizer( - [(tokenizer.bos_token or "") + p for p in prompts], - return_tensors="pt", - padding=True, - add_special_tokens=False, - ) - - -def get_gate_params( - model_ref: ModelReference, - tokenizer: transformers.PreTrainedTokenizerBase, - experts: List[Expert], - mode: str = "hidden", - load_in_4bit: bool = False, - load_in_8bit: bool = False, - lazy_unpickle: bool = False, - trust_remote_code: bool = False, - device: str = "auto", -): - gate_vecs = [] - _do_it = None - - model_cfg = model_ref.config(trust_remote_code=trust_remote_code) - - if mode == "random": - return torch.randn( - (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size) - ) - elif mode == "cheap_embed": - embed = model_ref.lazy_loader(lazy_unpickle=lazy_unpickle).get_tensor( - "model.embed_tokens.weight" - ) - - def _do_it(tokenized): - return get_cheap_embedding( - embed, - tokenized, - num_layers=model_cfg.num_hidden_layers, - vocab_size=model_cfg.vocab_size, - ) - - elif mode in ("hidden", "hidden_avg", "hidden_last"): - model = AutoModelForCausalLM.from_pretrained( - model_ref.model.path, - revision=model_ref.model.revision, - torch_dtype=torch.bfloat16, - device_map=device, - low_cpu_mem_usage=True, - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, - trust_remote_code=trust_remote_code, - ) - - def _do_it(tokenized): - return get_hidden_states( - model, tokenized=tokenized, average=mode == "hidden_avg" - ) - - gate_vecs = [] - for expert in tqdm.tqdm(experts, desc="expert 
prompts"): - hidden_states = _do_it(tokenize_prompts(expert.positive_prompts, tokenizer)) - if expert.negative_prompts: - hidden_states -= _do_it( - tokenize_prompts(expert.negative_prompts, tokenizer) - ) - - hidden_states /= hidden_states.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8) - gate_vecs.append(hidden_states) - gate_vecs = torch.stack(gate_vecs, dim=0) # (num_expert, num_layer, hidden_size) - return gate_vecs.permute(1, 0, 2) - - -def warn_degenerate_gates(gate_vecs: torch.Tensor, threshold: float = 5.0): - degen_indices = [] - num_layers, _num_experts, _hidden_size = gate_vecs.shape - for idx in range(num_layers): - c = torch.linalg.cond(gate_vecs[idx, :, :].float()) - if c > threshold: - degen_indices.append(idx) - - if degen_indices: - if len(degen_indices) == 1: - layer_str = f"layer {degen_indices[0]}" - verb = "has" - elif len(degen_indices) == 2: - layer_str = f"layers {' and '.join(map(str, degen_indices))}" - verb = "have" - elif len(degen_indices) >= num_layers: - layer_str = "ALL layers" - verb = "have" - else: - layer_str = ( - "layers " - + ", ".join(map(str, degen_indices[:-1])) - + ", and " - + str(degen_indices[-1]) - ) - verb = "have" - - logging.warning( - f"{layer_str} {verb} degenerate routing parameters " - "- your prompts may be too similar." - ) - logging.warning("One or more experts will be underutilized in your model.") - - -def is_bad_config(config: MistralMOEConfig, allow_all_same: bool = False) -> bool: - if len(config.experts) < 2: - logging.error("Must include at least two experts.") - return True - - if config.gate_mode == "random": - return False # eh we're good - - def prompt_tup(e: Expert): - return (tuple(e.positive_prompts), tuple(e.negative_prompts or [])) - - # let's just nip this trend in the bud - p_first = prompt_tup(config.experts[0]) - if all(prompt_tup(e) == p_first for e in config.experts[1:]): - logging.error( - "Your positive and negative prompts are identical for all experts. This will not produce a functioning MoE." - ) - logging.error( - "For each expert, `positive_prompts` must contain one or more example prompt reflecting what should be routed to that expert." - ) - return True - - if not allow_all_same: - if all( - e.source_model == config.experts[0].source_model for e in config.experts[1:] - ): - logging.error( - "All of your expert models are the same. This will produce " - "a model that uses more resources but gives the exact same output. " - "If you plan to train the model after merging, proceed with the " - "--i-understand-this-is-not-useful-without-training flag." 
- ) - return True - - -def build( - config: MistralMOEConfig, - out_path: str, - merge_options: MergeOptions, - load_in_4bit: bool = False, - load_in_8bit: bool = False, - device: str = "auto", - allow_all_same: bool = False, -): - if is_bad_config(config, allow_all_same=allow_all_same): - sys.exit(1) - - if config.experts_per_token < 1: - logging.error("Experts per token must be >= 1") - sys.exit(1) - if config.experts_per_token > len(config.experts): - logging.error("Experts per token must be <= number of experts") - sys.exit(1) - - base_model = ModelReference.parse(config.base_model) - base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code) - if not isinstance(base_cfg, MistralConfig): - base_cfg_mistral = MistralConfig(**base_cfg.to_dict()) - base_cfg_mistral.sliding_window = None - base_cfg_mistral.max_position_embeddings = base_cfg.max_position_embeddings - base_cfg = base_cfg_mistral - - out_cfg = MixtralConfig(**base_cfg.to_dict()) - out_cfg.architectures = ["MixtralForCausalLM"] - out_cfg.num_local_experts = len(config.experts) - out_cfg.num_experts_per_tok = config.experts_per_token - out_cfg.sliding_window = None - if config.dtype: - out_cfg.torch_dtype = config.dtype - out_cfg.save_pretrained(out_path) - - if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0: - logging.warning( - f"Your model has {out_cfg.num_local_experts} experts, which is " - "not a power of two. The model will not be usable in llama.cpp." - ) - - loaders: Dict[ModelReference, LazyTensorLoader] = {} - for model in tqdm.tqdm( - [base_model] + [e.model_ref for e in config.experts], desc="Warm up loaders" - ): - loaders[model] = model.lazy_loader( - cache_dir=merge_options.transformers_cache, - lazy_unpickle=merge_options.lazy_unpickle, - ) - - base_loader = loaders.get(base_model) - writer = TensorWriter( - out_path=out_path, - max_shard_size=merge_options.out_shard_size, - safe_serialization=merge_options.safe_serialization, - ) - - if config.dtype: - out_dtype = dtype_from_name(config.dtype) - elif base_cfg.torch_dtype: - out_dtype = base_cfg.torch_dtype - if isinstance(out_dtype, str): - out_dtype = dtype_from_name(out_dtype) - else: - out_dtype = None - - logging.info("Copying parameters...") - MISTRAL_INFO = mergekit.architecture.MISTRAL_INFO - for weight_info in MISTRAL_INFO.pre_weights(base_cfg) + MISTRAL_INFO.post_weights( - base_cfg - ): - tensor_name = weight_info.name - tensor = base_loader.get_tensor(tensor_name, aliases=weight_info.aliases) - if not out_dtype: - # All else has failed, take the first dtype we see - out_dtype = tensor.dtype - writer.save_tensor( - tensor_name, tensor.to(dtype=out_dtype), clone=merge_options.clone_tensors - ) - - for layer_idx in range(base_cfg.num_hidden_layers): - for weight_info in MISTRAL_INFO.layer_weights(index=layer_idx, config=base_cfg): - tensor_name = weight_info.name - - if ".mlp." 
in tensor_name: - for moe_index, expert in enumerate(config.experts): - expert_name = tensor_name.replace( - ".mlp.gate_proj", f".block_sparse_moe.experts.{moe_index}.w1" - ) - expert_name = expert_name.replace( - ".mlp.down_proj", f".block_sparse_moe.experts.{moe_index}.w2" - ) - expert_name = expert_name.replace( - ".mlp.up_proj", f".block_sparse_moe.experts.{moe_index}.w3" - ) - expert_loader = loaders.get(expert.model_ref) - tensor = expert_loader.get_tensor( - tensor_name, aliases=weight_info.aliases - ) - if expert.noise_scale: - tensor += torch.randn_like(tensor) * expert.noise_scale - writer.save_tensor( - expert_name, tensor.to(dtype=out_dtype), clone=True - ) - continue - writer.save_tensor( - tensor_name, - base_loader.get_tensor(tensor_name, aliases=weight_info.aliases).to( - dtype=out_dtype - ), - ) - - tokenizer = transformers.AutoTokenizer.from_pretrained( - base_model.model.path, revision=base_model.model.revision - ) - tokenizer.padding_side = "left" - tokenizer.pad_token_id = tokenizer.bos_token_id - if tokenizer.pad_token_id is None: - tokenizer.pad_token = tokenizer.eos_token - - logging.info("Getting gate parameters...") - gate_vecs = get_gate_params( - base_model, - tokenizer, - config.experts, - mode=config.gate_mode, - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, - lazy_unpickle=merge_options.lazy_unpickle, - trust_remote_code=merge_options.trust_remote_code, - device=device, - ) - # gate_vecs: (num_layers, num_experts, hidden_size) - - warn_degenerate_gates(gate_vecs) - - for layer_idx in range(base_cfg.num_hidden_layers): - writer.save_tensor( - f"model.layers.{layer_idx}.block_sparse_moe.gate.weight", - gate_vecs[layer_idx, :, :].contiguous().to(dtype=out_dtype), - ) - writer.finalize() - - if merge_options.copy_tokenizer: - logging.info("Saving tokenizer...") - tokenizer.save_pretrained(out_path, safe_serialization=True) - - logging.info("Done.") - - -@click.command("mergekit-moe") -@click.argument("config_path", type=click.Path(exists=True, dir_okay=False)) -@click.argument("out_path", type=click.Path()) -@click.option( - "--load-in-4bit", - is_flag=True, - type=bool, - default=False, - help="Load model in 4bit for computing hidden states", -) -@click.option( - "--load-in-8bit", - is_flag=True, - type=bool, - default=False, - help="Load model in 8bit for computing hidden states", -) -@click.option( - "--device", - type=str, - default="auto", - help="Device to use to compute embeddings", - show_default=True, -) -@click.option( - "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging" -) -@click.option( - "--i-understand-this-is-not-useful-without-training", - type=bool, - default=False, - is_flag=True, - help="Really make the questionable model you want.", -) -@add_merge_options -def main( - config_path: str, - out_path: str, - load_in_4bit: bool, - load_in_8bit: bool, - device: str, - merge_options: MergeOptions, - verbose: bool, - i_understand_this_is_not_useful_without_training: bool, -): - logging.basicConfig(level=logging.INFO if verbose else logging.WARNING) - - if merge_options.cuda: - logging.warning( - '--cuda is a no-op for mergekit-moe, use "--device cuda" instead' - ) - - with open(config_path, "r", encoding="utf-8") as file: - config_source = file.read() - - config = MistralMOEConfig.model_validate(yaml.safe_load(config_source)) - build( - config, - out_path=out_path, - merge_options=merge_options, - load_in_4bit=load_in_4bit, - load_in_8bit=load_in_8bit, - device=device, - 
allow_all_same=i_understand_this_is_not_useful_without_training, - ) - - if merge_options.write_model_card: - # TODO: generate a README.md as well - with open( - os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8" - ) as fp: - fp.write(config_source) - - -if __name__ == "__main__": - main() diff --git a/mergekit/scripts/moe.py b/mergekit/scripts/moe.py new file mode 100644 index 00000000..fa0c11f7 --- /dev/null +++ b/mergekit/scripts/moe.py @@ -0,0 +1,231 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +import logging +import os +import sys +from typing import List + +import click +import transformers +import yaml + +from mergekit.merge import MergeOptions +from mergekit.moe import ALL_OUTPUT_ARCHITECTURES, MoEOutputArchitecture +from mergekit.moe.config import MoEMergeConfig, is_bad_config +from mergekit.moe.router import get_gate_params, warn_degenerate_gates +from mergekit.options import add_merge_options + + +def build( + config: MoEMergeConfig, + out_path: str, + merge_options: MergeOptions, + load_in_4bit: bool = False, + load_in_8bit: bool = False, + device: str = "auto", + allow_all_same: bool = False, + verbose: bool = False, +): + if is_bad_config(config, allow_all_same=allow_all_same): + sys.exit(1) + + base_model = config.base_model + out_arch = select_output_arch(config, merge_options, verbose=verbose) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + base_model.model.path, revision=base_model.model.revision + ) + tokenizer.padding_side = "left" + tokenizer.pad_token_id = tokenizer.bos_token_id + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + logging.info("Getting gate parameters...") + need_gates = list(config.experts) + if config.shared_experts: + has_prompts = any(e.positive_prompts for e in config.shared_experts) + assert all( + bool(e.positive_prompts) == has_prompts for e in config.shared_experts + ), "Must specify prompts for all shared experts or none, not a mix" + if has_prompts: + need_gates.extend(config.shared_experts) + + gate_vecs = get_gate_params( + base_model, + tokenizer, + need_gates, + mode=config.gate_mode, + load_in_4bit=load_in_4bit, + load_in_8bit=load_in_8bit, + lazy_unpickle=merge_options.lazy_unpickle, + trust_remote_code=merge_options.trust_remote_code, + device=device, + ) + # gate_vecs: (num_layers, num_experts, hidden_size) + router_weights = gate_vecs[:, : len(config.experts), :] + shared_router_weights = gate_vecs[:, len(config.experts) :, :] + warn_degenerate_gates(gate_vecs) + + out_arch.write_model( + out_path, + config, + merge_options, + router_weights=[router_weights[i, ...] for i in range(router_weights.shape[0])], + shared_router_weights=[ + shared_router_weights[i, ...] 
for i in range(router_weights.shape[0]) + ], + ) + + if merge_options.copy_tokenizer: + logging.info("Saving tokenizer...") + tokenizer.save_pretrained(out_path, safe_serialization=True) + + logging.info("Done.") + + +def select_output_arch( + config: MoEMergeConfig, + merge_options: MergeOptions, + verbose: bool = False, +) -> MoEOutputArchitecture: + candidates_in = ALL_OUTPUT_ARCHITECTURES + if config.architecture: + candidates_in = [ + a + for a in candidates_in + if a.name().lower().startswith(config.architecture.lower()) + ] + if not candidates_in: + logging.error( + f"No output architecture found that matches the given architecture: {config.architecture}" + ) + logging.error("All supported output architectures:") + for arch in ALL_OUTPUT_ARCHITECTURES: + logging.error(f" * {arch.name()}") + sys.exit(1) + + candidates: List[MoEOutputArchitecture] = [] + for arch in candidates_in: + if arch.supports_config( + config, explain=verbose, trust_remote_code=merge_options.trust_remote_code + ): + candidates.append(arch) + else: + logging.info(f"Output architecture {arch.name()} does not support config.") + + if not candidates: + logging.error( + "No output architecture found that is compatible with the given models." + ) + + logging.error("All supported output architectures:") + for arch in ALL_OUTPUT_ARCHITECTURES: + logging.error(f" * {arch.name()}") + sys.exit(1) + + # for compatibility with older configs, default to Mixtral if available + for arch in candidates: + if arch.name() == "Mixtral": + return arch + + if len(candidates) > 1: + logging.warning( + "Multiple output architectures found that are compatible with the given models." + ) + logging.warning(f"Defaulting to {candidates[0].name()}") + else: + logging.info(f"Selected output architecture: {candidates[0].name()}") + return candidates[0] + + +@click.command("mergekit-moe") +@click.argument("config_path", type=click.Path(exists=True, dir_okay=False)) +@click.argument("out_path", type=click.Path()) +@click.option( + "--load-in-4bit", + is_flag=True, + type=bool, + default=False, + help="Load model in 4bit for computing hidden states", +) +@click.option( + "--load-in-8bit", + is_flag=True, + type=bool, + default=False, + help="Load model in 8bit for computing hidden states", +) +@click.option( + "--device", + type=str, + default="auto", + help="Device to use to compute embeddings", + show_default=True, +) +@click.option( + "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging" +) +@click.option( + "--i-understand-this-is-not-useful-without-training", + type=bool, + default=False, + is_flag=True, + help="Really make the questionable model you want.", +) +@add_merge_options +def main( + config_path: str, + out_path: str, + load_in_4bit: bool, + load_in_8bit: bool, + device: str, + merge_options: MergeOptions, + verbose: bool, + i_understand_this_is_not_useful_without_training: bool, +): + """Create a Mixture of Experts model by combining the pretrained weights of multiple models.""" + logging.basicConfig(level=logging.INFO if verbose else logging.WARNING) + + if merge_options.cuda: + logging.warning( + '--cuda is a no-op for mergekit-moe, use "--device cuda" instead' + ) + + with open(config_path, "r", encoding="utf-8") as file: + config_source = file.read() + + config = MoEMergeConfig.model_validate(yaml.safe_load(config_source)) + build( + config, + out_path=out_path, + merge_options=merge_options, + load_in_4bit=load_in_4bit, + load_in_8bit=load_in_8bit, + device=device, + 
allow_all_same=i_understand_this_is_not_useful_without_training, + verbose=verbose, + ) + + if merge_options.write_model_card: + # TODO: generate a README.md as well + with open( + os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8" + ) as fp: + fp.write(config_source) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index fc79d339..9cba5dfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ mergekit-mega = "mergekit.scripts.megamerge:main" mergekit-legacy = "mergekit.scripts.legacy:main" mergekit-layershuffle = "mergekit.scripts.layershuffle:main" bakllama = "mergekit.scripts.bakllama:main" -mergekit-moe = "mergekit.scripts.mixtral_moe:main" +mergekit-moe = "mergekit.scripts.moe:main" mergekit-tokensurgeon = "mergekit.scripts.tokensurgeon:main" mergekit-extract-lora = "mergekit.scripts.extract_lora:main" @@ -48,6 +48,7 @@ packages = [ "mergekit", "mergekit.io", "mergekit.merge_methods", + "mergekit.moe", "mergekit.scripts", "mergekit._data", "mergekit._data.architectures",
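
The diff above introduces a pluggable output-architecture layer under `mergekit.moe`. As a minimal sketch of how those pieces fit together (not part of the diff; it assumes this branch of mergekit is installed, that `config.yml` is a hypothetical `mergekit-moe` config file, and that the referenced model configs can be fetched), the same compatibility check performed by `select_output_arch` in `mergekit/scripts/moe.py` can be run by hand:

```python
# Sketch: load a mergekit-moe YAML config and ask each registered output
# architecture (Mixtral, DeepSeek MoE, and Qwen MoE when the installed
# transformers version provides Qwen2MoeConfig) whether it supports it.
# "config.yml" is a hypothetical path.
import yaml

from mergekit.moe import ALL_OUTPUT_ARCHITECTURES
from mergekit.moe.config import MoEMergeConfig

with open("config.yml", "r", encoding="utf-8") as f:
    config = MoEMergeConfig.model_validate(yaml.safe_load(f))

for arch in ALL_OUTPUT_ARCHITECTURES:
    # explain=True makes unsupported architectures log the reason,
    # e.g. Mixtral rejecting configs that define shared_experts.
    ok = arch.supports_config(config, explain=True, trust_remote_code=False)
    print(f"{arch.name()}: {'supported' if ok else 'unsupported'}")
```

`mergekit-moe ./config.yml ./output-model` then chooses among the supported candidates the same way: preferring Mixtral when it is compatible (for backwards compatibility with older configs) and otherwise falling back to the first supported candidate, unless `architecture:` is set in the config.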