Conversion Script for Mamba checkpoints (mamba_ssm -> transformers)
#29631
Comments
Hi @haileyschoelkopf, thanks for opening this feature request! Normally we have conversion scripts under each model's folder (would be here). I think adding a conversion script sounds like a good idea! @ArthurZucker contributed the model. He's off at the moment, but back soon - I'll let him reply in case there's a good reason we didn't add it alongside the model originally |
Thank you! I'll share here if I get the time to implement one myself! |
If it's okay/not too complicated, could I try and give this a shot (as a new outside contributor)? Admittedly very new to ML stuff, but at a very high level, would this entail implementing conversion scripts similar to what's found in other model dirs such as https://github.com/huggingface/transformers/tree/b340d90738fa14bd6f81b65e4148173cbec62ff6/src/transformers/models/bert ? I.e. just 2 files for the forward and backward pass? |
yep! Basically just load checkpoint file -> convert to a format loadable by the other library, e.g. reshaping or renaming weights as needed -> run |
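The "load, rename, save" flow described above can be sketched roughly as follows. This is only an illustration, not the script that was eventually merged: `RENAME_PREFIXES` and `convert_state_dict` are hypothetical names, the single rename entry is an assumption based on the "embedding" vs "embeddings" mismatch reported further down in this thread, and plain lists stand in for tensors so the snippet runs without any ML library installed.

```python
from typing import Any, Dict, Mapping

# Hypothetical rename table: source key prefix -> target key prefix.
# The single entry is an assumption, not an exhaustive mapping.
RENAME_PREFIXES: Dict[str, str] = {
    "backbone.embedding.": "backbone.embeddings.",
}


def convert_state_dict(
    state_dict: Mapping[str, Any],
    rename_prefixes: Mapping[str, str] = RENAME_PREFIXES,
) -> Dict[str, Any]:
    """Return a new dict with renamed keys; values are passed through untouched."""
    converted: Dict[str, Any] = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in rename_prefixes.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        if new_key in converted:
            # Two source keys mapping to one target would silently drop weights.
            raise KeyError(f"rename collision on {new_key!r}")
        converted[new_key] = value
    return converted


if __name__ == "__main__":
    # Stand-in for a checkpoint loaded from disk; real values would be tensors.
    original = {"backbone.embedding.weight": [0.1, 0.2], "lm_head.weight": [0.1, 0.2]}
    print(sorted(convert_state_dict(original)))
```

Because the function only touches keys, the same approach should work on any mapping of names to tensors, with loading and saving left to whichever library owns the checkpoint format.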
I've made this if it's helpful for anyone. |
Thanks, that's very helpful. I'll try to get a PR out soon modeled off that. |
hey! The reason I did not add one is because the original checkpoints are compatible. This can still be added but only the config should be inferred / updated! |
Hm, when I try to run a conversion, I get an error suggesting there needs to be a rename:
The unexpected key contains "embedding" (with no s at the end), while the missing key contains "embeddings" (with an s at the end). I've attempted to create a PR which both converts the config and does the above renaming for the forward pass: #29705 (Huge thanks to @SrGonao, my PR does pretty much the same thing as his script, except on local files instead of interacting with the hub) |
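For anyone trying to reproduce this kind of mismatch, a small key-set comparison can make it visible before loading anything. The helper below is a hypothetical sketch (the name `diff_keys` is not from either library), and the key names in the example are assumptions taken from the description above.

```python
from typing import Iterable, List, Tuple


def diff_keys(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, relative to the target model."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    missing = sorted(model - ckpt)      # the model wants these, the checkpoint lacks them
    unexpected = sorted(ckpt - model)   # the checkpoint has these, the model does not use them
    return missing, unexpected


if __name__ == "__main__":
    missing, unexpected = diff_keys(
        checkpoint_keys=["backbone.embedding.weight", "lm_head.weight"],
        model_keys=["backbone.embeddings.weight", "lm_head.weight"],
    )
    print("missing:", missing)
    print("unexpected:", unexpected)
```

A missing key and an unexpected key that differ by a single character, as here, are usually a sign that a rename is needed and not that weights are truly absent.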
Oh no 😅 |
It might have, but I couldn't get it to work without the rename. A quick printout of the original ssm model suggests at least the current version of mamba_ssm works with
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
config = MambaConfig(d_model=64, n_layer=8)
model = MambaLMHeadModel(config)
print(model)
Outputs:
|
You are right 😢 I have no idea why I did not get any warnings / maybe because the weights are tied it used the Just saw your PR to remove the tie_weights that were forced before in Arf that is really annoying |
I'm confused by what you mean here. I thought that the problem was due to a difference in naming conventions between the two packages, where the transformers library and the mamba_ssm model library just chose to name their embedding layer differently.
Are you referring to @haileyschoelkopf's PR in state-spaces/mamba#211? |
I'll add a loading hook just this once! IMO should be the cleanest way to fix this |
I mean that on my side, when implementing mamba in transformers I did not have a warning about the weights. I suppose that this is because the weights are by default tied.
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size=50280)
does not produce any warning.
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size=50280, tie_word_embeddings=False)
Some weights of MambaForCausalLM were not initialized from the model checkpoint at state-spaces/mamba-130m and are newly initialized: ['backbone.embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
does. And:
>>> from transformers import AutoModel
>>> model = AutoModel.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size=50280)
Some weights of MambaModel were not initialized from the model checkpoint at state-spaces/mamba-130m and are newly initialized: ['backbone.embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. |
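One way to picture the supposition above (tied weights hiding the missing key) is with a toy filter over the list of missing keys. This is a simplified model of the suspected behavior, not the actual transformers loading code; `reported_missing` and the tied-group layout are hypothetical.

```python
from typing import Iterable, List, Sequence


def reported_missing(
    missing: Iterable[str], tied_groups: Sequence[Sequence[str]]
) -> List[str]:
    """Drop missing keys that share storage with at least one loaded key."""
    remaining = list(missing)
    for group in tied_groups:
        missing_in_group = [k for k in remaining if k in group]
        # Some, but not all, of the group is missing: the rest was loaded,
        # so the shared tensor already has values and nothing is reported.
        if 0 < len(missing_in_group) < len(group):
            remaining = [k for k in remaining if k not in missing_in_group]
    return remaining


if __name__ == "__main__":
    missing = ["backbone.embeddings.weight"]
    tied = [["backbone.embeddings.weight", "lm_head.weight"]]
    print("with tying:   ", reported_missing(missing, tied))
    print("without tying:", reported_missing(missing, []))
```

Under this toy model, the embedding key disappears from the report as soon as it is tied to a loaded `lm_head.weight`, which matches the three observations in the comment: no warning with tying, a warning with `tie_word_embeddings=False`, and a warning for the headless model.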
That's a really silent bug, and I got tricked by it... |
cc @amyeroberts core issue ! |
Hm, just to make sure I understand:
|
3. This issue may exist in other models, since it's due to behavior in weight tying and not something Mamba-specific? |
|
About 2. I mean avoid having to explicitly convert if you know the config. Config can be initialized first with from_pretrained! But let's still have the conversion script! It will be beneficial to have a mapping between the names and the config explicitly! |
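An explicit name mapping for the config could look something like the sketch below. The field correspondence is an assumption pieced together from the snippets in this thread (`d_model` and `n_layer` on the mamba_ssm side, `num_hidden_layers` and `vocab_size` on the transformers side) plus `hidden_size`, which is assumed to be the transformers counterpart of `d_model`; `CONFIG_KEY_MAP` and `convert_config` are hypothetical names, and a real script would need to cover more fields.

```python
from typing import Any, Dict, List, Mapping, Tuple

# Assumed field correspondence (mamba_ssm name -> transformers name).
CONFIG_KEY_MAP: Dict[str, str] = {
    "d_model": "hidden_size",
    "n_layer": "num_hidden_layers",
    "vocab_size": "vocab_size",
}


def convert_config(ssm_config: Mapping[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
    """Translate known fields and report the ones with no mapping."""
    converted: Dict[str, Any] = {}
    unmapped: List[str] = []
    for key, value in ssm_config.items():
        if key in CONFIG_KEY_MAP:
            converted[CONFIG_KEY_MAP[key]] = value
        else:
            # Surface unknown fields so they are not dropped silently.
            unmapped.append(key)
    return converted, unmapped


if __name__ == "__main__":
    hf_kwargs, leftover = convert_config({"d_model": 64, "n_layer": 8, "vocab_size": 50280})
    print(hf_kwargs, leftover)
```

Returning the unmapped fields separately keeps the conversion honest: anything the table does not know about is visible to the caller, who can then decide whether to extend the mapping.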
For 1., yes, a warning should indeed be issued. Sorry, we raise an error for mismatched sizes! |
Sg, I removed the weight rename from my PR (although now my PR won't actually work until yours is checked in) |
Just a quick plug: I think your PR fixes the checkpointing issue, but PR #29705 is still open for config->config conversion. |
Indeed opening again! |
Feature request
Thanks very much for the Mamba support (#28094), this interoperability is fantastic!
I wanted to ask if there were any utility (doesn't have to be clean, just functional) for converting checkpoints provided for use in the mamba_ssm library into the format provided in transformers. This would be very helpful if it exists! Thanks 🤗
Motivation
I'd like to be able to convert novel trained mamba models from the state-spaces/mamba repo into HF transformers without rewriting a conversion script myself if need be.
Your contribution
I could write a utility for this if none exists but would probably not have the bandwidth to upstream it.