
Add a converter from mamba_ssm -> huggingface mamba #29705

Merged · 18 commits into huggingface:main · Apr 4, 2024

Conversation

@byi8220 (Contributor) commented Mar 17, 2024

What does this PR do?

Fixes #29631

This PR adds a new command-line script, convert_mamba_ssm_checkpoint_to_pytorch.py, which converts model checkpoints created by the state-spaces/mamba repo into a Hugging Face MambaForCausalLM model.
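At a high level, the conversion has roughly this shape. The sketch below is illustrative only, not the script's actual code: the state-spaces config keys (d_model, n_layer, vocab_size), the MambaConfig field mapping, and the direct load_state_dict call (the commit history notes a weight rename was removed as unnecessary) are all assumptions.

import json

import torch
from transformers import MambaConfig, MambaForCausalLM

# Map the state-spaces ssm_config.json onto transformers' MambaConfig.
# The right-hand keys are assumptions about the state-spaces config layout.
with open("/path/to/ssm_config.json") as f:
    ssm_config = json.load(f)

hf_config = MambaConfig(
    hidden_size=ssm_config["d_model"],
    num_hidden_layers=ssm_config["n_layer"],
    vocab_size=ssm_config["vocab_size"],
)

# Load the original weights into the HF model and export it.
hf_model = MambaForCausalLM(hf_config)
state_dict = torch.load("/path/to/pytorch_model.bin", map_location="cpu")
hf_model.load_state_dict(state_dict)
hf_model.save_pretrained("/path/to/out/dir")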

The intended usage of this script is:

python src/transformers/models/mamba/convert_mamba_ssm_checkpoint_to_pytorch.py \
--mamba_checkpoint_file=/path/to/pytorch_model.bin \
--config_json_file=/path/to/ssm_config.json \
--output_dir=/path/to/out/dir

This script has a dependency on the mamba_ssm package.
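Once converted, the output directory can be consumed like any other transformers checkpoint. A minimal sketch, assuming the paths from the command above; the tokenizer choice is an assumption (the original Mamba checkpoints pair with the EleutherAI/gpt-neox-20b tokenizer):

from transformers import AutoTokenizer, MambaForCausalLM

# Load the converted model from the conversion script's output directory.
model = MambaForCausalLM.from_pretrained("/path/to/out/dir")
# Assumed tokenizer: the original Mamba checkpoints use gpt-neox-20b's.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

inputs = tokenizer("Mamba is a state space model that", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))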

Testing

A validation pass, comparing the converted model's output logits against the original mamba_ssm model's, is performed before exporting the model.
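The commit history ("Added validation", "only compare logits") indicates the check compares output logits between the two models. A rough sketch of what such a pass could look like, assuming the mamba_ssm package and a CUDA device (its kernels require one); the checkpoint names, input shape, and tolerance are illustrative, not the script's actual code:

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import MambaForCausalLM

# Illustrative: load the original mamba_ssm model and the converted HF model.
original_model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda")
converted_model = MambaForCausalLM.from_pretrained("/path/to/out/dir").to("cuda")

# Compare logits on a random batch of token ids.
input_ids = torch.randint(0, converted_model.config.vocab_size, (1, 16), device="cuda")
with torch.no_grad():
    original_logits = original_model(input_ids).logits
    converted_logits = converted_model(input_ids).logits

# atol is an assumed tolerance for this sketch.
if not torch.allclose(original_logits, converted_logits, atol=1e-3):
    raise ValueError("Converted model logits do not match the original model")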

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts (Collaborator) commented:

Thanks for opening this PR @byi8220 ! Let us know when it's ready for review 🤗

@byi8220 (Contributor, Author) commented Mar 18, 2024

Thanks, I think it should be reviewable now?

@amyeroberts (Collaborator) left a review comment:

Thanks for working on this!

Just a comment about having validation in the script rather than in a new test.

Review comment on tests/models/mamba/test_modeling_mamba.py (outdated, resolved)
@amyeroberts (Collaborator) left a review comment:

Thanks for adding this!

Looks good to me, just some small nits. cc @ArthurZucker for reference

@byi8220 (Contributor, Author) commented Apr 3, 2024

Addressed nits.

Since there's no unit test for this anymore, I just did one last sanity test. Works on my machine:

  1. The script I use to spot-test the conversion (https://gist.github.com/byi8220/9b44a5f6c6c2c7533704801478f1760a) passes for the 130m and 790m models. My GPU doesn't have enough memory to run the larger ones.
  2. Downloading the weights and config of a 130m model and running the command python src/transformers/models/mamba/convert_mamba_ssm_checkpoint_to_pytorch.py --mamba_checkpoint_file=/tmp/scratch/pytorch_model.bin --config_json_file=/tmp/scratch/config.json --output_dir=/tmp/scratch/out works end to end.

@amyeroberts (Collaborator) commented:

Awesome work - thanks again for adding this and for running some sanity checks!

@amyeroberts amyeroberts merged commit 4e6c5eb into huggingface:main Apr 4, 2024
7 checks passed
ArthurZucker pushed a commit that referenced this pull request Apr 22, 2024
* implement convert_mamba_ssm_checkpoint_to_pytorch

* Add test test_model_from_mamba_ssm_conversion

* moved convert_ssm_config_to_hf_config to inside mamba_ssm_available check

* fix skipif clause

* moved skips to inside test since skipif decorator isn't working for some reason

* Added validation

* removed test

* fixup

* only compare logits

* remove weight rename

* Update src/transformers/models/mamba/convert_mamba_ssm_checkpoint_to_pytorch.py

Co-authored-by: amyeroberts <[email protected]>

* nits

---------

Co-authored-by: amyeroberts <[email protected]>
itazap pushed a commit that referenced this pull request May 14, 2024 (same commit message as above)
Linked issue: Conversion Script for Mamba checkpoints (mamba_ssm -> transformers) #29631