Add convert_to_tflite script for stable diffusion 1.5 model (#18)
* Migrate stable diffusion example to ai-edge-torch
Refactoring will follow to move the reused modules into the layers directory.

* Add convert_to_tflite script for stable diffusion 1.5 model.

* Add a todo for converting SD to multi signature tflite model.
yichunk authored May 29, 2024
1 parent 7f52f70 commit 3a169a0
Showing 5 changed files with 113 additions and 0 deletions.
1 change: 1 addition & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -68,6 +68,7 @@ def __init__(self):
    self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
    self.layernorm = nn.LayerNorm(768)

  @torch.inference_mode
  def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
    tokens = tokens.type(torch.long)

107 changes: 107 additions & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -0,0 +1,107 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
from pathlib import Path

import torch

import ai_edge_torch
from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP
from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion # NOQA
from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
import ai_edge_torch.generative.examples.stable_diffusion.util as util


@torch.inference_mode
def convert_stable_diffusion_to_tflite(
    clip_ckpt_path: str,
    encoder_ckpt_path: str,
    diffusion_ckpt_path: str,
    decoder_ckpt_path: str,
    image_height: int = 512,
    image_width: int = 512,
):

  clip = CLIP()
  clip.load_state_dict(torch.load(clip_ckpt_path))

  encoder = Encoder()
  encoder.load_state_dict(torch.load(encoder_ckpt_path))

  diffusion = Diffusion()
  diffusion.load_state_dict(torch.load(diffusion_ckpt_path))

  decoder = Decoder()
  decoder.load_state_dict(torch.load(decoder_ckpt_path))

  # Tensors used to trace the model graph during conversion.
  n_tokens = 77
  timestamp = 0
  len_prompt = 1
  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
  input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
  noise = torch.full(
      (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
  )

  input_latents = encoder(input_image, noise)
  context_cond = clip(prompt_tokens)
  context_uncond = torch.zeros_like(context_cond)
  context = torch.cat([context_cond, context_uncond], axis=0)
  time_embedding = util.get_time_embedding(timestamp)

  # CLIP text encoder
  ai_edge_torch.signature('encode', clip, (prompt_tokens,)).convert().export(
      '/tmp/stable_diffusion/clip.tflite'
  )

  # TODO(yichunk): convert to multi signature tflite model.
  # Image encoder
  ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
      '/tmp/stable_diffusion/encoder.tflite'
  )

  # Diffusion
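  # The latents are repeated along the batch dimension so they pair up with the
  # conditional and unconditional text contexts concatenated above
  # (classifier-free guidance).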
  ai_edge_torch.signature(
      'diffusion',
      diffusion,
      (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
  ).convert().export('/tmp/stable_diffusion/diffusion.tflite')

  # Image decoder
  ai_edge_torch.signature('decode', decoder, (input_latents,)).convert().export(
      '/tmp/stable_diffusion/decoder.tflite'
  )


if __name__ == '__main__':
  convert_stable_diffusion_to_tflite(
      clip_ckpt_path=os.path.join(
          Path.home(), 'Downloads/stable_diffusion_data/ckpt/clip.pt'
      ),
      encoder_ckpt_path=os.path.join(
          Path.home(), 'Downloads/stable_diffusion_data/ckpt/encoder.pt'
      ),
      diffusion_ckpt_path=os.path.join(
          Path.home(), 'Downloads/stable_diffusion_data/ckpt/diffusion.pt'
      ),
      decoder_ckpt_path=os.path.join(
          Path.home(), 'Downloads/stable_diffusion_data/ckpt/decoder.pt'
      ),
      image_height=512,
      image_width=512,
  )
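
The TODO above asks for a single multi-signature TFLite file instead of four separate ones. A minimal sketch of how the four per-model conversions inside convert_stable_diffusion_to_tflite might collapse into one chained conversion, assuming the chained signature API accepts a different module per signature; this commit does not implement it, and the output file name is illustrative:

  # Hedged sketch only, not part of this commit.
  ai_edge_torch.signature('encode_text', clip, (prompt_tokens,)).signature(
      'encode_image', encoder, (input_image, noise)
  ).signature(
      'diffusion',
      diffusion,
      (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
  ).signature('decode', decoder, (input_latents,)).convert().export(
      '/tmp/stable_diffusion/stable_diffusion.tflite'
  )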
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

import torch
from torch import nn
from torch.nn import functional as F

@@ -104,6 +105,7 @@ def __init__(self):
        nn.Conv2d(128, 3, kernel_size=3, padding=1),
    )

  @torch.inference_mode
  def forward(self, x):
    x = x / 0.18215
    for module in self:
1 change: 1 addition & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -429,6 +429,7 @@ def __init__(self):
    self.unet = UNet()
    self.final = FinalLayer(320, 4)

  @torch.inference_mode
  def forward(self, latent, context, time):
    time = self.time_embedding(time)
    # print('time:')
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/encoder.py
@@ -19,6 +19,7 @@

from ai_edge_torch.generative.examples.stable_diffusion.decoder import AttentionBlock # NOQA
from ai_edge_torch.generative.examples.stable_diffusion.decoder import ResidualBlock # NOQA
import ai_edge_torch.generative.utilities.loader as loading_utils


class Encoder(nn.Sequential):
@@ -46,6 +47,7 @@ def __init__(self):
        nn.Conv2d(8, 8, kernel_size=1, padding=0),
    )

  @torch.inference_mode
  def forward(self, x, noise):
    for module in self:
      if getattr(module, 'stride', None) == (
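
Beyond the diff itself, a converted signature can be spot-checked against its PyTorch source before deployment. A minimal sketch, assuming the converted edge model is directly callable on the same sample inputs (as the ai-edge-torch README describes for converted models); the checkpoint path mirrors the one used in convert_to_tflite.py and the tolerance is purely illustrative:

import os
from pathlib import Path

import numpy as np
import torch

import ai_edge_torch
from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP

# Hedged sketch: load the same CLIP checkpoint the conversion script expects.
clip = CLIP()
clip.load_state_dict(
    torch.load(os.path.join(Path.home(), 'Downloads/stable_diffusion_data/ckpt/clip.pt'))
)

# Convert the 'encode' signature and compare outputs on the traced sample input.
prompt_tokens = torch.full((1, 77), 0, dtype=torch.long)
edge_clip = ai_edge_torch.signature('encode', clip, (prompt_tokens,)).convert()
torch_output = clip(prompt_tokens)
edge_output = edge_clip(prompt_tokens)
assert np.allclose(torch_output.detach().numpy(), edge_output, atol=1e-4)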
