[feat] Support Mamba 2 blocks #68

Open
tscholak opened this issue Nov 25, 2024 · 0 comments
Labels
enhancement New feature or request

Comments

@tscholak
Collaborator

tscholak commented Nov 25, 2024

🧐 Problem Description

Fast-LLM currently lacks support for models built on Mamba 2 blocks, which are a cornerstone of modern state-space-model (SSM) architectures such as Zamba 2 and NVIDIA's Hymba. These architectures interleave Mamba 2 blocks with transformer layers, offering faster training, lower inference latency, and a reduced compute/memory footprint compared to traditional transformers. Without Mamba 2 support, Fast-LLM is incompatible with these cutting-edge hybrid architectures and cannot adopt or continually pretrain models like Zamba 2.

💡 Proposed Solution

Implement Mamba 2 blocks in Fast-LLM in a phased approach:

  1. Initial PyTorch Integration:

    • Port a reference implementation of Mamba 2 blocks (e.g., from Zamba 2) into Fast-LLM.
    • Focus on enabling basic training functionality for SSM+transformer hybrids.
    • Prioritize correctness over performance in early iterations.
  2. Support Hybrid Architectures:

    • Enable pretraining and fine-tuning of architectures like Zamba 2, which interleave Mamba 2 with shared transformer blocks.
    • Add functionality to load pretrained Zamba 2 models into Fast-LLM for continual pretraining.
  3. Optimization for Performance:

    • Address performance bottlenecks by replacing generic PyTorch implementations with optimized CUDA or Triton kernels.
    • Reuse existing Mamba 2 CUDA kernels where feasible, or reimplement them in Triton for better scalability and flexibility.
  4. Extend to Other Architectures:

    • Evaluate Hymba and other hybrid architectures to determine additional requirements for their support.

Progressing through these stages lets us add Mamba 2 support incrementally while ensuring stability, performance, and alignment with Fast-LLM's broader architecture.
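To make steps 1 and 2 concrete, here is a minimal PyTorch sketch of the interleaving pattern: Mamba 2 blocks for most layers, with one shared transformer block reused periodically, loosely in the spirit of Zamba 2. This is not Fast-LLM code: `HybridBlock`, `HybridStack`, and all hyperparameters are hypothetical placeholders, and it assumes the reference `mamba_ssm` package (state-spaces/mamba, v2.x) for the Mamba 2 mixer, whose fused kernels require a CUDA device.

```python
# Minimal sketch only: HybridBlock/HybridStack are hypothetical names, not Fast-LLM APIs.
# Assumes the reference `mamba_ssm` package (state-spaces/mamba >= 2.0) is installed.
import torch
import torch.nn as nn
from mamba_ssm import Mamba2  # reference Mamba 2 mixer with fused kernels


class HybridBlock(nn.Module):
    """Pre-norm residual wrapper around a token mixer (here, a Mamba 2 block)."""

    def __init__(self, d_model: int, mixer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mixer = mixer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mixer(self.norm(x))


class HybridStack(nn.Module):
    """Interleaves Mamba 2 blocks with a single shared transformer block,
    inserted every `attn_every` layers (a simplified version of the Zamba 2
    pattern of reusing one set of attention weights across the model depth)."""

    def __init__(self, d_model: int = 256, n_layers: int = 12, attn_every: int = 6):
        super().__init__()
        # NB: a real causal LM would also pass a causal attention mask; omitted for brevity.
        shared_attn = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, dim_feedforward=4 * d_model,
            batch_first=True, norm_first=True,
        )
        layers = []
        for i in range(n_layers):
            if (i + 1) % attn_every == 0:
                layers.append(shared_attn)  # same module object => shared weights
            else:
                layers.append(HybridBlock(d_model, Mamba2(d_model=d_model)))
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inputs and outputs are (batch, seq_len, d_model) throughout.
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == "__main__":
    model = HybridStack().cuda()
    tokens = torch.randn(2, 128, 256, device="cuda")
    print(model(tokens).shape)  # torch.Size([2, 128, 256])
```

In Fast-LLM this would presumably map onto the existing block/layer abstractions rather than plain `nn.Module`s, and per step 3 the generic modules would later be backed by optimized CUDA or Triton kernels, but the sketch captures the data flow a hybrid SSM+transformer stack needs to support.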

🔄 Alternatives Considered

Training these models outside of Fast-LLM is an option, but it would undermine the framework's role as our go-to tool for pretraining in the Foundation Models Lab. While we may reuse existing implementations if they integrate smoothly, it is more likely that we will eventually develop our own Mamba 2 implementation, much as we did for dropless MoE.

📈 Potential Benefits

Mamba 2 blocks bring computational efficiency, enabling faster training and inference than traditional transformers. Supporting them keeps Fast-LLM competitive with state-of-the-art hybrid architectures and unlocks new research opportunities.

📝 Additional Context

@tscholak added the enhancement (New feature or request) label on Nov 25, 2024