Add safetensors SD 1.5 tensor names mapping (#84)
* Add safetensors SD 1.5 tensor names mapping

* Update README

* Remove PyTorch model support of third party implementation

* format
yichunk authored Jul 10, 2024
1 parent 70de3be commit 9140d68
Showing 11 changed files with 337 additions and 207 deletions.
26 changes: 23 additions & 3 deletions ai_edge_torch/generative/examples/stable_diffusion/README.md
@@ -2,13 +2,33 @@
This example shows how to use the Edge Generative API to convert a PyTorch Stable Diffusion v1.5 model to a TFLite model and run image generation inference.

## Convert PyTorch to TFLite model
1. Download the PyTorch Stable Diffusion model [stable-diffusion-pytorch](https://github.com/kjsman/stable-diffusion-pytorch)
1. Unzip the downloaded model weights into `$HOME/Downloads/stable_diffusion_data`
1. Run `convert_to_tflite.py`. This will convert the PyTorch models into TFLite models. The Stable Diffusion model has four components: CLIP (text embedding), encoder, diffusion, and decoder models. Each component is converted to a single TFLite model file.
The example provides a source checkpoint mapping from the original HuggingFace repo to the reauthored PyTorch model, as sketched below.
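
A mapping is declared as a `loading_utils.ModelLoader.TensorNames` instance that pairs each tensor in the reauthored model with its name in the source checkpoint. The abridged sketch below is taken from the `clip.py` change in this commit; only a subset of the fields is shown, and `{}` is filled in with the layer index at load time.

```python
# Abridged sketch of the CLIP tensor-name mapping from this commit (clip.py);
# only a subset of the fields is shown here.
import ai_edge_torch.generative.utilities.loader as loading_utils

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
    attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
    attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
    attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
    final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
)
```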

### SafeTensors model checkpoints from original HuggingFace repo
1. Clone the original HuggingFace Stable Diffusion repo [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
2. Run `convert_to_tflite.py` with `v1-5-pruned-emaonly.safetensors` as the source checkpoint for the conversion script. The Stable Diffusion model has four components: CLIP (text embedding), encoder, diffusion, and decoder models. Each component is converted to a single TFLite model file. Note that conversion of the encoder model is not supported yet. A short snippet for listing the checkpoint's tensor names follows the conversion command below.
```bash
python ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py \
--clip_ckpt=$HOME/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors \
--diffusion_ckpt=$HOME/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors \
--decoder_ckpt=$HOME/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors \
--output_dir=/tmp/stable_diffusion_safetensors/ \
--ckpt_format=safetensors
```
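
To see which source tensor names the mapping refers to, you can list the keys stored in the safetensors checkpoint. This is a minimal sketch using the `safetensors` Python package (not part of this example's scripts); the filter prefix matches the CLIP entries in `clip.py`, and the path assumes the clone location above.

```python
# Minimal sketch: list the CLIP tensor names stored in the SD 1.5 checkpoint.
# Requires the `safetensors` package; the path assumes the clone location above.
import os
from safetensors import safe_open

ckpt = os.path.expanduser("~/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors")
with safe_open(ckpt, framework="pt", device="cpu") as f:
    for name in sorted(f.keys()):
        if name.startswith("cond_stage_model."):  # CLIP text-encoder tensors
            print(name)
```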

## Run Stable Diffusion pipeline
1. Use the `run_tflite_pipeline` method in `pipeline.py` to trigger the end-to-end Stable Diffusion pipeline with the TFLite models. See `pipeline.py` for example usage as a script; a rough Python sketch also follows the command below.

```bash
python ai_edge_torch/generative/examples/stable_diffusion/pipeline.py \
--tokenizer_vocab_dir=$HOME/stable-diffusion-v1-5/tokenizer/ \
--clip_ckpt=/tmp/stable_diffusion_safetensors/clip.tflite \
--diffusion_ckpt=/tmp/stable_diffusion_safetensors/diffusion.tflite \
--decoder_ckpt=/tmp/stable_diffusion_safetensors/decoder.tflite \
--output_path=/tmp/sd_result_tflite.jpg \
--n_inference_steps=20
```
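
If you prefer to drive the pipeline from Python rather than the command line, a rough sketch is below. The `StableDiffusion` constructor and the keyword names are assumptions that simply mirror the CLI flags above, not the verified API; consult `pipeline.py` for the actual signature.

```python
# Rough sketch only: the constructor and keyword names below mirror the CLI
# flags above and are assumptions; check pipeline.py for the authoritative API.
from ai_edge_torch.generative.examples.stable_diffusion import pipeline

sd_model = pipeline.StableDiffusion(  # assumed constructor name and arguments
    tokenizer_vocab_dir='/path/to/stable-diffusion-v1-5/tokenizer/',
    clip_ckpt='/tmp/stable_diffusion_safetensors/clip.tflite',
    diffusion_ckpt='/tmp/stable_diffusion_safetensors/diffusion.tflite',
    decoder_ckpt='/tmp/stable_diffusion_safetensors/decoder.tflite',
)
pipeline.run_tflite_pipeline(
    sd_model,
    prompt='a photograph of an astronaut riding a horse',
    output_path='/tmp/sd_result_tflite.jpg',
    n_inference_steps=20,
)
```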

Here is an example generated image.

Prompt: "a photograph of an astronaut riding a horse"
22 changes: 12 additions & 10 deletions ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -23,16 +23,17 @@
import ai_edge_torch.generative.utilities.loader as loading_utils

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
ff_up_proj="layers.{}.linear_1",
ff_down_proj="layers.{}.linear_2",
ff_gate_proj="layers.{}.linear_1",
attn_fused_qkv_proj="layers.{}.attention.in_proj",
attn_output_proj="layers.{}.attention.out_proj",
pre_attn_norm="layers.{}.layernorm_1",
pre_ff_norm="layers.{}.layernorm_2",
embedding="embedding.token_embedding",
embedding_position="embedding.position_value",
final_norm="layernorm",
ff_up_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1",
ff_down_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2",
attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
pre_attn_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1",
pre_ff_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2",
embedding="cond_stage_model.transformer.text_model.embeddings.token_embedding",
embedding_position="cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
lm_head=None,
)

@@ -84,6 +85,7 @@ def get_model_config() -> cfg.ModelConfig:
rotary_percentage=0.0,
qkv_use_bias=True,
qkv_transpose_before_split=True,
qkv_fused_interleaved=False,
output_proj_use_bias=True,
enable_kv_cache=False,
)
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -13,8 +13,10 @@
# limitations under the License.
# ==============================================================================

import argparse
import os
from pathlib import Path
from typing import Optional

import torch

@@ -24,38 +26,65 @@
import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
import ai_edge_torch.generative.examples.stable_diffusion.util as util
import ai_edge_torch.generative.utilities.loader as loading_utils
import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
'--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
)
arg_parser.add_argument(
'--diffusion_ckpt',
type=str,
help='Path to source diffusion model checkpoint',
required=True,
)
arg_parser.add_argument(
'--decoder_ckpt',
type=str,
help='Path to source image decoder model checkpoint',
required=True,
)
arg_parser.add_argument(
'--output_dir',
type=str,
help='Path to the converted TF Lite directory.',
required=True,
)


@torch.inference_mode
def convert_stable_diffusion_to_tflite(
output_dir: str,
clip_ckpt_path: str,
encoder_ckpt_path: str,
diffusion_ckpt_path: str,
decoder_ckpt_path: str,
image_height: int = 512,
image_width: int = 512,
):

clip_model = clip.CLIP(clip.get_model_config())
loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
loader = stable_diffusion_loader.ClipModelLoader(
clip_ckpt_path,
clip.TENSOR_NAMES,
)
loader.load(clip_model, strict=False)

encoder = Encoder()
encoder.load_state_dict(torch.load(encoder_ckpt_path))

diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
diffusion_ckpt_path, diffusion.TENSORS_NAMES
diffusion_ckpt_path, diffusion.TENSOR_NAMES
)
diffusion_loader.load(diffusion_model)
diffusion_loader.load(diffusion_model, strict=False)

decoder_model = decoder.Decoder(decoder.get_model_config())
decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
decoder_ckpt_path, decoder.TENSORS_NAMES
decoder_ckpt_path, decoder.TENSOR_NAMES
)
decoder_loader.load(decoder_model)
decoder_loader.load(decoder_model, strict=False)

# TODO(yichunk): enable image encoder conversion
# if encoder_ckpt_path is not None:
# encoder = Encoder()
# encoder.load_state_dict(torch.load(encoder_ckpt_path))

# Tensors used to trace the model graph during conversion.
n_tokens = 77
@@ -67,50 +96,47 @@ def convert_stable_diffusion_to_tflite(
(len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
)

input_latents = encoder(input_image, noise)
input_latents = torch.zeros_like(noise)
context_cond = clip_model(prompt_tokens)
context_uncond = torch.zeros_like(context_cond)
context = torch.cat([context_cond, context_uncond], axis=0)
time_embedding = util.get_time_embedding(timestamp)

if not os.path.exists(output_dir):
Path(output_dir).mkdir(parents=True, exist_ok=True)

# TODO(yichunk): convert to multi signature tflite model.
# CLIP text encoder
ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
'/tmp/stable_diffusion/clip.tflite'
f'{output_dir}/clip.tflite'
)

# TODO(yichunk): convert to multi signature tflite model.
# TODO(yichunk): enable image encoder conversion
# Image encoder
ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
'/tmp/stable_diffusion/encoder.tflite'
)
# ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
# f'{output_dir}/encoder.tflite'
# )

# Diffusion
ai_edge_torch.signature(
'diffusion',
diffusion_model,
(torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
).convert().export('/tmp/stable_diffusion/diffusion.tflite')
).convert().export(f'{output_dir}/diffusion.tflite')

# Image decoder
ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
'/tmp/stable_diffusion/decoder.tflite'
f'{output_dir}/decoder.tflite'
)


if __name__ == '__main__':
args = arg_parser.parse_args()
convert_stable_diffusion_to_tflite(
clip_ckpt_path=os.path.join(
Path.home(), 'Downloads/stable_diffusion_data/ckpt/clip.pt'
),
encoder_ckpt_path=os.path.join(
Path.home(), 'Downloads/stable_diffusion_data/ckpt/encoder.pt'
),
diffusion_ckpt_path=os.path.join(
Path.home(), 'Downloads/stable_diffusion_data/ckpt/diffusion.pt'
),
decoder_ckpt_path=os.path.join(
Path.home(), 'Downloads/stable_diffusion_data/ckpt/decoder.pt'
),
output_dir=args.output_dir,
clip_ckpt_path=args.clip_ckpt,
diffusion_ckpt_path=args.diffusion_ckpt,
decoder_ckpt_path=args.decoder_ckpt,
image_height=512,
image_width=512,
)