Skip to content

Implementation of Diffusion Transformers and Rectified Flow in Jax

Notifications You must be signed in to change notification settings

brianfitzgerald/jax-mmdit

Repository files navigation

MMDiT in Jax

Inference with a trained MMDiT model on the Oxford Flowers dataset.

This is a reproduction of the Diffusion Transformer architecture, along with the Rectified Flow sampling method, in Jax.

Features

  • Training configs for a number of popular reference datasets - CIFAR-10, ImageNet, MNIST, etc. It should be trivial to add additional datasets from HuggingFace.
  • DDP support for multi-GPU training (FSDP planned in the near future)
  • Mixed precision training, checkpointing, profiler, and TensorBoard logging.
  • Minimal dependencies - only Jax, Flax, HuggingFace Datasets, and some utility libraries are required.

Results

I was able to achieve comparable validation loss with MNIST and CIFAR-10 to other DiT implementations. I plan to train on ImageNet soon, and will update this README with the results.

Inference on MNIST

Inference on the CIFAR-10 dataset.

Why was this project written?

I wanted to brush up on my Jax knowledge, and also hadn't implemented a full MMDiT from scratch before. So I figured I'd try to do both at once! :)

Setup and Usage

pip install -r requirements.txt
python train.py

train.py accepts the following CLI args:

n_epochs: Number of epochs to train for.
batch_size: Batch size for training.
learning_rate: Learning rate for the optimizer.
eval_save_steps: Number of steps between evaluation runs and checkpoint saves.
n_eval_batches: Number of batches to evaluate on.
sample_every_n: Number of epochs between sampling runs.
dataset_name: Name of the dataset config to select, valid options are in DATASET_CONFIGS.
profile: Run a single train and eval step, and print out the cost analysis, then exit.
half_precision: case the model to fp16 for training.

TensorBoard is used for logging. Samples will be logged to the samples directory, with the X dimension representing batch and Y dimension representing each iteration of the sampling loop.

Codebase structure

model.py contains the DiT implementation. I made the following changes to the original implementation:

  • Use a GroupNorm layer in the PatchEmbed block.
  • Replace SiLU with swish activation (which are functionally equivalent).
  • I added an absolute position embedding - this seems to signficiantly improve training stability with smaller latent / image sizes. Can be easily disabled.

train.py contains the training loop, and the main entrypoint for the project. Call the main() function to run the training loop; additionally, the Trainer class can be used to load and train / inference the model independently of this loop.

I decided to not use either Tensorflow or PyTorch data loading, to keep external dependencies to a minimum. Instead datasets are loaded using Datasets, and processed with the process_batch function. To add a new dataset, simply add a new entry to the DATASET_CONFIGS dictionary in train.py.

ImageNet Training

ImageNet images are resized to 224x224, and then encoded with the SDXL VAE. You can use the convert_imagenet.py script to encode the images, and then use the imagenet dataset config to train on the dataset.

Acknowledgements

I heavily relied on Simo Ryu's minRF repo for the original implementation, especially the sampling code for rectified flow; as well as the original DiT repo.

I also learned a lot from Rafi Witten's great High Performance LLMs in Jax course, especially multi-device training in Jax.

Todo

  • Use Grain or similar for faster data loading with multiple workers.
  • Implement a VAE to allow training on larger image sizes.
  • Add support for FSDP, and fp8 training.
  • Stable Diffusion 3 support, when the model is released.

Disclaimer: I work at Stability AI, but this project is not affiliated with the company in any way, and the implementation here is fairly different from SD3's architecture.

About

Implementation of Diffusion Transformers and Rectified Flow in Jax

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages