Add [Mamba] model #28086

Closed
2 tasks done
JLTastet opened this issue Dec 15, 2023 · 4 comments · Fixed by #28094

@JLTastet

Model description

Mamba is a new architecture proposed in arXiv:2312.00752 by Albert Gu (CMU) and Tri Dao (Princeton).

It is inspired by structured state space models (SSMs), but adds a selection mechanism that allows it to combine the ability of transformers to perform content-based reasoning with the performance of SSMs on long sequences. Mamba can be trained efficiently in parallel while also enjoying efficient inference by running recurrently.
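To make the selection mechanism concrete, here is a minimal pure-Python sketch of a single-channel selective scan. It is a slow reference loop, not the fused kernel from the original repo: the function name `selective_scan` and the scalar weights `w_delta`, `w_b`, `w_c` are hypothetical stand-ins for the learned projections, chosen only for illustration.

```python
import math

def softplus(z):
    # Numerically stable softplus; keeps the step size positive.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def selective_scan(xs, a, w_delta, w_b, w_c):
    """Run a toy single-channel selective SSM over the sequence xs.

    a        : negative diagonal entries of the state matrix (length N)
    w_delta  : scalar weight producing the input-dependent step size (hypothetical)
    w_b, w_c : length-N weights producing input-dependent B_t and C_t (hypothetical)
    """
    h = [0.0] * len(a)
    ys = []
    for x in xs:
        delta = softplus(w_delta * x)   # selection: step size depends on the input
        b = [w * x for w in w_b]        # selection: B_t depends on the input
        c = [w * x for w in w_c]        # selection: C_t depends on the input
        # Discretised state update: h <- exp(delta * a) * h + delta * b * x
        h = [math.exp(delta * a_n) * h_n + delta * b_n * x
             for a_n, h_n, b_n in zip(a, h, b)]
        # Readout: y_t = <C_t, h_t>
        ys.append(sum(c_n * h_n for c_n, h_n in zip(c, h)))
    return ys
```

Because the step size, `b` and `c` are recomputed from each input, the state can retain or forget information depending on content, which a time-invariant SSM cannot do. The loop is sequential and therefore slow; it only illustrates the recurrence.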

The paper claims SoTA performance on various modalities, with performance tested up to 2.8B parameters. Crucially, the model cannot be implemented efficiently using only PyTorch operations; instead, it relies on optimised CUDA and Triton kernels.

The original implementation by the authors is available at https://github.com/state-spaces/mamba/tree/main under an Apache 2.0 license.

Starting from their implementation, I have started porting the model to 🤗 Transformers. This is work in progress 🚧, and can be found in my fork at https://github.com/JLTastet/transformers/tree/mamba.

I can open a PR, but in its current state my branch is not ready to be merged. I will also open an issue in the original repo to let the authors know about this, in case they want to chime in.

What I got working:

  • Forward and backward passes.
  • Loading checkpoints from the Hub using AutoModel.

What still needs some work:

  • Even though backprop itself works, I get some CUDA errors when using Trainer, and I still don’t understand what causes them.
  • Compiling the CUDA kernels takes ~1 hour. This does not happen with the original package, so I think they are using prebuilt binaries. I didn’t manage to port that part so far.
  • I don’t think there is any non-CUDA fallback path, so this model probably cannot run without CUDA in its current form.
  • When using generate, we should check that the optimised recurrent inference is used instead of the slower autoregressive inference.
  • Tests, tests and moar tests.
  • Most of the documentation needs to be written.
  • Add the relevant dependencies.
  • The code could certainly benefit from some cleanup (remove dead code, many TODOs, update copyright notices, ...).
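For the `generate` item above, one way to check that the cached recurrent path is actually used is to compare it against a full recompute and count the steps taken. The sketch below does this on a toy recurrent model; `step`, `generate_recompute`, `generate_cached` and the values in `A` are all hypothetical and unrelated to the real Transformers API, and a non-empty prompt is assumed.

```python
import math

A = (-0.5, -1.0)  # toy diagonal state matrix (hypothetical values)

def step(h, x):
    """Advance the toy recurrent state by one token; return (new_state, output)."""
    delta = math.log1p(math.exp(x))  # input-dependent step size (softplus)
    h = tuple(math.exp(delta * a) * h_n + delta * x for a, h_n in zip(A, h))
    return h, sum(h)

def generate_recompute(prompt, n_new):
    """Slow path: re-run the whole sequence from a zero state for every new token."""
    seq, calls = list(prompt), 0
    for _ in range(n_new):
        h = (0.0,) * len(A)
        for x in seq:
            h, y = step(h, x)
            calls += 1
        seq.append(math.tanh(y))  # toy stand-in for sampling the next token
    return seq, calls

def generate_cached(prompt, n_new):
    """Fast path: keep the state between calls, so each new token costs one step."""
    seq, calls = list(prompt), 0
    h = (0.0,) * len(A)
    for x in seq:
        h, y = step(h, x)
        calls += 1
    for _ in range(n_new):
        nxt = math.tanh(y)
        seq.append(nxt)
        h, y = step(h, nxt)
        calls += 1
    return seq, calls
```

Both functions produce the same sequence, but the number of `step` calls grows quadratically with length in the recompute path and linearly in the cached path. A test in the port could assert the same kind of equality between a cached and an uncached forward pass.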

I am opening this issue to avoid duplicating work, since I saw some mention of Mamba today by @ArthurZucker.

My main motivation for porting this model is to learn a bit more about it (and about the internals of 🤗 Transformers) and to run more evals. Some of you probably know this library much better than me, so feel free to write your own implementation if you can do it better or quicker. Otherwise, don’t hesitate to build on top of my fork.

Open source status

  • The model implementation is available
  • The model weights are available

Provide useful links for the implementation

@ArthurZucker
Collaborator

Thanks for opening this issue! Given the sensitivity of this model, the HF team will take it over. We'll have a look at your fork and add you as a co-author 🤗

@JLTastet
Author

Thanks a lot!

My fork is largely based on the original Mamba repo; the differences consist mostly of boilerplate code. So don’t hesitate to start from the upstream repo.

I (and the linter) have noticed a couple of bugs and pieces of dead code upstream (some of which remain in my fork). So keep an eye out for them!

@LegallyCoder

I did a similar study: https://github.com/LegallyCoder/mamba-hf
I'm working on this too.

@ankhzet

ankhzet commented Jan 16, 2024

I've seen a CPU-only implementation fork mentioned somewhere in the source repo's issues. The author of the fork removed the Triton and CUDA dependencies.

Found it: https://github.com/kroggen/mamba-cpu
Training is not working there, though. Maybe you can get in touch with the author.
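A CPU fallback along these lines could be selected at runtime. The sketch below only shows the dispatch decision; `pick_scan_backend` is a hypothetical helper, and it assumes the fused kernels ship in the `mamba_ssm` package from the original repo.

```python
import importlib.util

def pick_scan_backend(device_type):
    """Return "kernel" when the fused CUDA path looks usable, else "reference".

    device_type is a string such as "cuda" or "cpu"
    (for example tensor.device.type in PyTorch).
    """
    # Assumption: the optimised kernels are provided by the mamba_ssm package.
    has_kernels = importlib.util.find_spec("mamba_ssm") is not None
    if device_type == "cuda" and has_kernels:
        return "kernel"
    # Anything else falls back to a slow pure-PyTorch (or pure-Python) scan.
    return "reference"
```

Calling `pick_scan_backend("cpu")` always selects the reference path, so the model could at least load and run without CUDA, albeit slowly.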
