Commit
Add a converter from mamba_ssm -> huggingface mamba (#29705)
* implement convert_mamba_ssm_checkpoint_to_pytorch
* Add test test_model_from_mamba_ssm_conversion
* moved convert_ssm_config_to_hf_config to inside mamba_ssm_available check
* fix skipif clause
* moved skips to inside test since skipif decorator isn't working for some reason
* Added validation
* removed test
* fixup
* only compare logits
* remove weight rename
* Update src/transformers/models/mamba/convert_mamba_ssm_checkpoint_to_pytorch.py
  Co-authored-by: amyeroberts <[email protected]>
* nits

Co-authored-by: amyeroberts <[email protected]>
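For context, the script added in this commit is meant to be run from the command line. A hypothetical invocation, consistent with the argparse flags defined at the bottom of the file (all three paths are placeholders, and both an installed `mamba_ssm` package and a CUDA device are required, since the script raises otherwise):

python src/transformers/models/mamba/convert_mamba_ssm_checkpoint_to_pytorch.py \
    -i /path/to/mamba_ssm/pytorch_model.bin \
    -c /path/to/mamba_ssm/config.json \
    -o /path/to/converted/output_dir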
1 parent 2fcb56a · commit 3e83213
Showing 1 changed file with 153 additions and 0 deletions.
src/transformers/models/mamba/convert_mamba_ssm_checkpoint_to_pytorch.py
153 changes: 153 additions & 0 deletions

@@ -0,0 +1,153 @@
# coding=utf-8
# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba_ssm` package to be installed."""

import argparse
import json
import math
from typing import Tuple

import torch

from transformers import AutoTokenizer, MambaConfig, MambaForCausalLM
from transformers.utils import logging
from transformers.utils.import_utils import is_mamba_ssm_available


if is_mamba_ssm_available():
    from mamba_ssm.models.config_mamba import MambaConfig as MambaConfigSSM
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    def convert_ssm_config_to_hf_config(config_ssm: MambaConfigSSM) -> MambaConfig:
        """Convert a MambaConfig from mamba_ssm to a MambaConfig from transformers."""
        hf_config = MambaConfig()
        # Set config hidden size, num hidden layers, and vocab size directly from the original config
        hf_config.hidden_size = config_ssm.d_model
        hf_config.intermediate_size = config_ssm.d_model * 2
        hf_config.time_step_rank = math.ceil(config_ssm.d_model / 16)

        hf_config.num_hidden_layers = config_ssm.n_layer
        vocab_size = config_ssm.vocab_size
        pad_vocab_size_multiple = config_ssm.pad_vocab_size_multiple
        if (vocab_size % pad_vocab_size_multiple) != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        hf_config.vocab_size = vocab_size
        return hf_config


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def convert_mamba_ssm_checkpoint_to_huggingface_model(
    original_state_dict: dict, original_ssm_config_dict: dict
) -> Tuple[MambaForCausalLM, AutoTokenizer]:
    if not is_mamba_ssm_available():
        raise ImportError(
            "Calling convert_mamba_ssm_checkpoint_to_huggingface_model requires the mamba_ssm library to be installed. Please install it with `pip install mamba_ssm`."
        )
    original_ssm_config = MambaConfigSSM(**original_ssm_config_dict)

    # Convert mamba_ssm config to huggingface MambaConfig
    hf_config = convert_ssm_config_to_hf_config(original_ssm_config)

    # No weights need to be renamed between the two models.
    converted_state_dict = original_state_dict

    # Load reshaped state dict into a huggingface model.
    hf_model = MambaForCausalLM(hf_config)
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    hf_model.load_state_dict(converted_state_dict)
    return (hf_model, tokenizer)


def validate_converted_model(
    original_state_dict: dict, original_ssm_config_dict: dict, hf_model: MambaForCausalLM, tokenizer: AutoTokenizer
) -> None:
    """Validate the converted model returns the same output as the original model."""
    torch_device = "cuda"

    original_config = MambaConfigSSM(**original_ssm_config_dict)
    original_model = MambaLMHeadModel(original_config).to(torch_device)
    original_model.load_state_dict(original_state_dict)

    hf_model = hf_model.to(torch_device)
    input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(torch_device)
    # Assert model logits are close
    with torch.no_grad():
        original_model_logits = original_model(input_ids).logits
        hf_model_logits = hf_model(input_ids).logits
    if not torch.allclose(original_model_logits, hf_model_logits, atol=1e-3):
        raise ValueError("The converted model did not return the same logits as the original model.")

    logger.info("Model conversion validated successfully.")


def convert_mamba_checkpoint_file_to_huggingface_model_file(
    mamba_checkpoint_path: str, config_json_file: str, output_dir: str
) -> None:
    if not is_mamba_ssm_available():
        raise ImportError(
            "Calling convert_mamba_checkpoint_file_to_huggingface_model_file requires the mamba_ssm library to be installed. Please install it with `pip install mamba_ssm`."
        )
    if not torch.cuda.is_available():
        raise ValueError(
            "This script is to be run with a CUDA device, as the original mamba_ssm model does not support cpu."
        )
    logger.info(f"Loading model from {mamba_checkpoint_path} based on config from {config_json_file}")
    # Load weights and config from paths
    original_state_dict = torch.load(mamba_checkpoint_path, map_location="cpu")
    with open(config_json_file, "r", encoding="utf-8") as json_file:
        original_ssm_config_dict = json.load(json_file)

    # Convert the model
    hf_model, tokenizer = convert_mamba_ssm_checkpoint_to_huggingface_model(
        original_state_dict, original_ssm_config_dict
    )

    # Validate the conversion
    validate_converted_model(original_state_dict, original_ssm_config_dict, hf_model, tokenizer)

    logger.info(f"Model converted successfully. Saving model to {output_dir}")

    # Save new model to pytorch_dump_path
    hf_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--mamba_checkpoint_file",
        type=str,
        required=True,
        help="Path to a `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
    )
    parser.add_argument(
        "-c",
        "--config_json_file",
        type=str,
        required=True,
        help="Path to a `config.json` file corresponding to a MambaConfig of the original mamba_ssm model.",
    )
    parser.add_argument(
        "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
    )
    args = parser.parse_args()

    convert_mamba_checkpoint_file_to_huggingface_model_file(
        args.mamba_checkpoint_file, args.config_json_file, args.output_dir
    )