Add a converter from mamba_ssm -> huggingface mamba #29705

Merged · 18 commits · Apr 4, 2024
Changes from 15 commits
@@ -0,0 +1,153 @@
# coding=utf-8
# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba_ssm` package to be installed."""

import argparse
import json
import math
from typing import Tuple

import torch

from transformers import AutoTokenizer, MambaConfig, MambaForCausalLM
from transformers.utils import logging
from transformers.utils.import_utils import is_mamba_ssm_available


if is_mamba_ssm_available():
    from mamba_ssm.models.config_mamba import MambaConfig as MambaConfig_ssm
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    def convert_ssm_config_to_hf_config(config_ssm: MambaConfig_ssm) -> MambaConfig:
        """Convert a MambaConfig from mamba_ssm to a MambaConfig from transformers."""
        hf_config = MambaConfig()
        # Set config hidden size, num hidden layers, and vocab size directly from the original config
        hf_config.hidden_size = config_ssm.d_model
        hf_config.intermediate_size = config_ssm.d_model * 2
        hf_config.time_step_rank = math.ceil(config_ssm.d_model / 16)

        hf_config.num_hidden_layers = config_ssm.n_layer
        vocab_size = config_ssm.vocab_size
        pad_vocab_size_multiple = config_ssm.pad_vocab_size_multiple
        if (vocab_size % pad_vocab_size_multiple) != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        hf_config.vocab_size = vocab_size
        return hf_config
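    # Worked example of the padding above (numbers are illustrative, not read from any checkpoint):
    # a mamba_ssm config with vocab_size=50277 and pad_vocab_size_multiple=8 has 50277 % 8 == 5,
    # so the converted config gets vocab_size = 50277 + (8 - 5) = 50280.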


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def convert_mamba_ssm_checkpoint_to_huggingface_model(
    original_state_dict: dict, original_ssm_config_dict: dict
) -> Tuple[MambaForCausalLM, AutoTokenizer]:
    if not is_mamba_ssm_available():
        raise ImportError(
            "Calling convert_mamba_ssm_checkpoint_to_huggingface_model requires the mamba_ssm library to be installed. Please install it with `pip install mamba_ssm`."
        )
    original_ssm_config = MambaConfig_ssm(**original_ssm_config_dict)

    # Convert mamba_ssm config to huggingface MambaConfig
    hf_config = convert_ssm_config_to_hf_config(original_ssm_config)

    # No weights need to be renamed between the two models.
    converted_state_dict = original_state_dict

    # Load the converted state dict into a huggingface model.
    hf_model = MambaForCausalLM(hf_config)
    # mamba_ssm checkpoints do not ship a tokenizer; the original mamba models use the GPT-NeoX-20B tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    hf_model.load_state_dict(converted_state_dict)
    return (hf_model, tokenizer)


def validate_converted_model(
    original_state_dict: dict, original_ssm_config_dict: dict, hf_model: MambaForCausalLM, tokenizer: AutoTokenizer
) -> None:
    """Validate the converted model returns the same output as the original model."""
    if not is_mamba_ssm_available():
        raise ImportError(
            "Calling validate_converted_model requires the mamba_ssm library to be installed. Please install it with `pip install mamba_ssm`."
        )
    if not torch.cuda.is_available():
        raise ValueError(
            "This script is to be run with a CUDA device, as the original mamba_ssm model does not support cpu."
        )
    torch_device = "cuda"

    original_config = MambaConfig_ssm(**original_ssm_config_dict)
    original_model = MambaLMHeadModel(original_config).to(torch_device)
    original_model.load_state_dict(original_state_dict)

    hf_model = hf_model.to(torch_device)
    input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(torch_device)
    # Assert model logits are close
    with torch.no_grad():
        original_model_logits = original_model(input_ids).logits
        hf_model_logits = hf_model(input_ids).logits
    if not torch.allclose(original_model_logits, hf_model_logits, atol=1e-3):
        raise ValueError("The converted model did not return the same logits as the original model.")

    logger.info("Model conversion validated successfully.")


def convert_mamba_checkpoint_file_to_huggingface_model_file(
    mamba_checkpoint_path: str, config_json_file: str, output_dir: str
) -> None:
    logger.info(f"Loading model from {mamba_checkpoint_path} based on config from {config_json_file}")
    # Load weights and config from paths
    original_state_dict = torch.load(mamba_checkpoint_path, map_location="cpu")
    with open(config_json_file, "r", encoding="utf-8") as json_file:
        original_ssm_config_dict = json.load(json_file)

    # Convert the model
    hf_model, tokenizer = convert_mamba_ssm_checkpoint_to_huggingface_model(
        original_state_dict, original_ssm_config_dict
    )

    # Validate the conversion
    validate_converted_model(original_state_dict, original_ssm_config_dict, hf_model, tokenizer)

    logger.info(f"Model converted successfully. Saving model to {output_dir}")

    # Save the converted model and tokenizer to the output directory
    hf_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--mamba_checkpoint_file",
        type=str,
        required=True,
        help="Path to a `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
    )
    parser.add_argument(
        "-c",
        "--config_json_file",
        type=str,
        required=True,
        help="Path to a `config.json` file corresponding to a MambaConfig of the original mamba_ssm model.",
    )
    parser.add_argument(
        "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
    )
    args = parser.parse_args()

    convert_mamba_checkpoint_file_to_huggingface_model_file(
        args.mamba_checkpoint_file, args.config_json_file, args.output_dir
    )