diff --git a/ai_edge_torch/generative/examples/stable_diffusion/clip.py b/ai_edge_torch/generative/examples/stable_diffusion/clip.py
index 787b0469..e929c701 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/clip.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -68,6 +68,7 @@ def __init__(self):
     self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
     self.layernorm = nn.LayerNorm(768)
 
+  @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
     tokens = tokens.type(torch.long)
 
diff --git a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
new file mode 100644
index 00000000..bb1b4108
--- /dev/null
+++ b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -0,0 +1,107 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP
+from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
+from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion  # NOQA
+from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
+import ai_edge_torch.generative.examples.stable_diffusion.util as util
+
+
+@torch.inference_mode
+def convert_stable_diffusion_to_tflite(
+    clip_ckpt_path: str,
+    encoder_ckpt_path: str,
+    diffusion_ckpt_path: str,
+    decoder_ckpt_path: str,
+    image_height: int = 512,
+    image_width: int = 512,
+):
+
+  clip = CLIP()
+  clip.load_state_dict(torch.load(clip_ckpt_path))
+
+  encoder = Encoder()
+  encoder.load_state_dict(torch.load(encoder_ckpt_path))
+
+  diffusion = Diffusion()
+  diffusion.load_state_dict(torch.load(diffusion_ckpt_path))
+
+  decoder = Decoder()
+  decoder.load_state_dict(torch.load(decoder_ckpt_path))
+
+  # Tensors used to trace the model graph during conversion.
+  n_tokens = 77
+  timestamp = 0
+  len_prompt = 1
+  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
+  input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
+  noise = torch.full(
+      (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
+  )
+
+  input_latents = encoder(input_image, noise)
+  context_cond = clip(prompt_tokens)
+  context_uncond = torch.zeros_like(context_cond)
+  context = torch.cat([context_cond, context_uncond], axis=0)
+  time_embedding = util.get_time_embedding(timestamp)
+
+  # CLIP text encoder
+  ai_edge_torch.signature('encode', clip, (prompt_tokens,)).convert().export(
+      '/tmp/stable_diffusion/clip.tflite'
+  )
+
+  # TODO(yichunk): convert to multi signature tflite model.
+  # Image encoder
+  ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+      '/tmp/stable_diffusion/encoder.tflite'
+  )
+
+  # Diffusion
+  ai_edge_torch.signature(
+      'diffusion',
+      diffusion,
+      (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
+  ).convert().export('/tmp/stable_diffusion/diffusion.tflite')
+
+  # Image decoder
+  ai_edge_torch.signature('decode', decoder, (input_latents,)).convert().export(
+      '/tmp/stable_diffusion/decoder.tflite'
+  )
+
+
+if __name__ == '__main__':
+  convert_stable_diffusion_to_tflite(
+      clip_ckpt_path=os.path.join(
+          Path.home(), 'Downloads/stable_diffusion_data/ckpt/clip.pt'
+      ),
+      encoder_ckpt_path=os.path.join(
+          Path.home(), 'Downloads/stable_diffusion_data/ckpt/encoder.pt'
+      ),
+      diffusion_ckpt_path=os.path.join(
+          Path.home(), 'Downloads/stable_diffusion_data/ckpt/diffusion.pt'
+      ),
+      decoder_ckpt_path=os.path.join(
+          Path.home(), 'Downloads/stable_diffusion_data/ckpt/decoder.pt'
+      ),
+      image_height=512,
+      image_width=512,
+  )
diff --git a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
index bea5c576..9f7f3f8d 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
+import torch
 from torch import nn
 from torch.nn import functional as F
 
@@ -104,6 +105,7 @@ def __init__(self):
         nn.Conv2d(128, 3, kernel_size=3, padding=1),
     )
 
+  @torch.inference_mode
   def forward(self, x):
     x = x / 0.18215
     for module in self:
diff --git a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
index d4786bdd..2992f3c3 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -429,6 +429,7 @@ def __init__(self):
     self.unet = UNet()
     self.final = FinalLayer(320, 4)
 
+  @torch.inference_mode
   def forward(self, latent, context, time):
     time = self.time_embedding(time)
     # print('time:')
diff --git a/ai_edge_torch/generative/examples/stable_diffusion/encoder.py b/ai_edge_torch/generative/examples/stable_diffusion/encoder.py
index 4ab6ccaa..6f8f2794 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/encoder.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/encoder.py
@@ -19,6 +19,7 @@
 from ai_edge_torch.generative.examples.stable_diffusion.decoder import AttentionBlock  # NOQA
 from ai_edge_torch.generative.examples.stable_diffusion.decoder import ResidualBlock  # NOQA
+import ai_edge_torch.generative.utilities.loader as loading_utils
 
 
 class Encoder(nn.Sequential):
 
@@ -46,6 +47,7 @@ def __init__(self):
         nn.Conv2d(8, 8, kernel_size=1, padding=0),
     )
 
+  @torch.inference_mode
   def forward(self, x, noise):
     for module in self:
       if getattr(module, 'stride', None) == (
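The conversion script above writes four `.tflite` files under `/tmp/stable_diffusion/`, one per exported signature. As a quick smoke test, an exported signature can be driven through the TensorFlow Lite Python interpreter. The sketch below is illustrative only: it assumes TensorFlow is installed and that `clip.tflite` was produced at the default path by the script, and it reads the signature's input name and shape at runtime rather than assuming them.

```python
import numpy as np
import tensorflow as tf

# Load the converted CLIP model and look up the 'encode' signature
# registered via ai_edge_torch.signature(...) during conversion.
interpreter = tf.lite.Interpreter(
    model_path='/tmp/stable_diffusion/clip.tflite'
)
encode = interpreter.get_signature_runner('encode')

# Read the input name, shape, and dtype from the signature instead of
# hard-coding them; the recorded tensor name may differ per export.
name, detail = next(iter(encode.get_input_details().items()))
tokens = np.zeros(detail['shape'], dtype=detail['dtype'])

# Run the signature with dummy token ids and report the output shapes.
outputs = encode(**{name: tokens})
print({k: v.shape for k, v in outputs.items()})
```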