AMD Nitro Diffusion


This repository provides training recipes for the AMD Nitro models, a series of efficient text-to-image generation models that are distilled from popular diffusion models on AMD Instinct GPUs.

⚡️ It contains an implementation of the core idea of Latent Adversarial Diffusion Distillation (LADD), the method behind the popular Stable Diffusion 3 Turbo model. Since the original authors did not release training code, we publish our re-implementation to help advance research in the field.
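
The sketch below illustrates the adversarial half of this idea in plain PyTorch: a one-step student maps noise directly to latents, and a discriminator on latents provides the training signal. It is a deliberately simplified toy (tiny stand-in networks, plain hinge GAN loss, no distillation loss); the full LADD recipe additionally re-noises latents at sampled timesteps and builds the discriminator on frozen teacher features.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Tiny stand-ins for the one-step student UNet and the latent discriminator.
student = nn.Conv2d(4, 4, 3, padding=1)
disc = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1), nn.Flatten(), nn.Linear(8 * 8 * 8, 1))
opt_g = torch.optim.Adam(student.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-4)

# "Real" latents come from images generated by the teacher model.
real_latents = torch.randn(2, 4, 8, 8)

# The one-step student maps pure noise (t = T) directly to a clean latent.
fake_latents = student(torch.randn_like(real_latents))

# Discriminator update: hinge loss separating teacher latents from student outputs.
d_loss = (F.relu(1 - disc(real_latents)).mean()
          + F.relu(1 + disc(fake_latents.detach())).mean())
opt_d.zero_grad(); d_loss.backward(); opt_d.step()

# Student update: adversarial loss pushing its outputs toward the teacher's distribution.
g_loss = -disc(fake_latents).mean()
opt_g.zero_grad(); g_loss.backward(); opt_g.step()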

The distilled models, Stable Diffusion 2.1 Nitro and PixArt-Sigma Nitro, can be found on Hugging Face.

Compared to the Stable Diffusion 2.1 base model, we achieve a 95.9% reduction in FLOPs at the cost of just a 2.5% lower CLIP score and a 2.2% higher FID.

| Model | FID ↓ | CLIP ↑ | FLOPs | Latency on AMD Instinct MI250 (sec) |
|-------|-------|--------|-------|-------------------------------------|
| Stable Diffusion 2.1 base, 50 steps (cfg=7.5) | 25.47 | 0.3286 | 83.04 | 4.94 |
| Stable Diffusion 2.1 Nitro, 1 step | 26.04 | 0.3204 | 3.36 | 0.18 |

Compared to PixArt-Sigma, our high-resolution model achieves a 90.9% reduction in FLOPs at the cost of just a 3.7% lower CLIP score and a 10.5% higher FID.

| Model | FID ↓ | CLIP ↑ | FLOPs | Latency on AMD Instinct MI250 (sec) |
|-------|-------|--------|-------|-------------------------------------|
| PixArt-Sigma, 20 steps | 34.14 | 0.3289 | 187.96 | 7.46 |
| PixArt-Sigma Nitro, 1 step | 37.75 | 0.3167 | 17.04 | 0.53 |
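
(These percentages follow directly from the tables: 1 - 3.36/83.04 ≈ 95.9% fewer FLOPs for Stable Diffusion 2.1 Nitro, and 1 - 17.04/187.96 ≈ 90.9% for PixArt-Sigma Nitro.)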

Environment

The codebase is implemented in PyTorch. Follow the official instructions to install it in your compute environment.

Docker image

When running on AMD Instinct GPUs, the easiest way to get started is with the prebuilt Docker image. Pull the following image from Docker Hub:

docker pull rocm/pytorch:rocm6.1.3_ubuntu22.04_py3.10_pytorch_release-2.1.2 

Dependencies

Install the core Python libraries with:

pip install diffusers==0.29.2 transformers accelerate wandb torchmetrics[image] pycocotools open-clip-torch

Synthetic data generation

Our models are distilled on synthetic data generated by the base models, using prompts from DiffusionDB. Follow the instructions in their repo to extract prompts from the dataset and prepare a .txt file with one prompt per line.

We provide a sample prompt list at data/sample_prompts.txt.
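
Each line of the file is a single prompt, for example (these prompts are illustrative, not taken from DiffusionDB):

a watercolor painting of a red fox in a snowy forest
cinematic photo of a city street at night, neon reflections on wet asphalt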

Generating data from Stable Diffusion 2.1 base

bash scripts/run_gen_data.sh

Generating data from PixArt-Sigma

bash scripts/run_gen_data_pixart.sh

Please remember to set PROMPT_PATH and OUT_FOLDER correctly in both scripts.

Train models

Use the following bash script to perform distillation:

bash scripts/run_train.sh

You will need to set:

  • MODEL_NAME: the base model from which to distill an efficient model
  • DATA_ROOT: the folder containing the data generated in the previous step
  • Hugging Face Accelerate parameters matching your training setup, so that the correct number of GPUs and batch size are used; refer to the Accelerate CLI documentation for details.
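
Note that the effective global batch size is the per-device batch size multiplied by the number of processes (and by any gradient accumulation steps), so adjust these together when changing the number of GPUs.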

Generate images

The distilled models generated by the training script are saved in Diffusers format. Use the following code snippets to perform inference with them:

Stable Diffusion 2.1 Nitro

from diffusers import DDPMScheduler, DiffusionPipeline
import torch

# Load the base pipeline and replace its UNet weights with the distilled checkpoint.
scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler)

ckpt_path = '<path to distilled checkpoint>'
unet_state_dict = torch.load(ckpt_path, map_location="cpu")  # load weights on CPU, then move the whole pipeline
pipe.unet.load_state_dict(unet_state_dict)
pipe = pipe.to("cuda")

# Single-step inference: one denoising step from the last timestep (999),
# with classifier-free guidance disabled (guidance_scale=0).
image = pipe(prompt='a photo of a cat',
             num_inference_steps=1,
             guidance_scale=0,
             timesteps=[999]).images[0]
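
The pipeline returns a standard PIL image, so the result can be saved with image.save("cat.png").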

PixArt-Sigma Nitro

from diffusers import PixArtSigmaPipeline
import torch
from safetensors.torch import load_file

# Load the base PixArt-Sigma pipeline and swap in the distilled transformer weights.
pipe = PixArtSigmaPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS")

ckpt_path = '<path to distilled checkpoint>'
transformer_state_dict = load_file(ckpt_path)  # distilled checkpoint is stored as safetensors
pipe.transformer.load_state_dict(transformer_state_dict)
pipe = pipe.to("cuda")

# Single-step inference from timestep 400, with classifier-free guidance disabled.
image = pipe(prompt='a photo of a cat',
             num_inference_steps=1,
             guidance_scale=0,
             timesteps=[400]).images[0]

Evaluation

COCO dataset

Download the COCO val2017 images and annotations from the official COCO website (https://cocodataset.org/#download).

Create a root folder called coco and unzip the two files into it. The resulting folder structure looks like:

coco/
├── val2017/
└── annotations/

To evaluate the model, run:

bash scripts/run_eval.sh

Please set the variables in this script correctly, including COCO_ROOT, CKPT_PATH, MODEL, etc. The script generates 5k images from 5k unique prompts in the COCO val2017 dataset and computes FID and CLIP scores on the generated images.
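
As a rough sketch of what the metric computation involves, FID and CLIP score can be computed with torchmetrics (installed in the dependencies above); the evaluation script may implement this differently, and the random tensors below merely stand in for real image batches:

import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.multimodal.clip_score import CLIPScore

fid = FrechetInceptionDistance(feature=2048)
clip_score = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")

# Placeholders for real batches: uint8 images of shape (N, 3, H, W) in [0, 255].
# In the actual evaluation, N is the 5k COCO images / generations.
real_images = torch.randint(0, 256, (8, 3, 256, 256), dtype=torch.uint8)
fake_images = torch.randint(0, 256, (8, 3, 256, 256), dtype=torch.uint8)
prompts = ["a photo of a cat"] * 8

fid.update(real_images, real=True)    # COCO val2017 reference images
fid.update(fake_images, real=False)   # model generations
clip_score.update(fake_images, prompts)

print(f"FID: {fid.compute():.2f}")
# Note: torchmetrics reports CLIP score on a 0-100 scale (100 x cosine similarity).
print(f"CLIP: {clip_score.compute():.2f}")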

License

Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
