From 3bf9ea9ec4b419b74d4dbf8deae069f12b5aeea1 Mon Sep 17 00:00:00 2001 From: Majid Dadashi Date: Tue, 27 Aug 2024 11:03:19 -0700 Subject: [PATCH] Fix formatting and style in pipeline.py. PiperOrigin-RevId: 668058319 --- .../examples/stable_diffusion/pipeline.py | 102 ++++++++---------- 1 file changed, 43 insertions(+), 59 deletions(-) diff --git a/ai_edge_torch/generative/examples/stable_diffusion/pipeline.py b/ai_edge_torch/generative/examples/stable_diffusion/pipeline.py index df7293e1..fb3b26ef 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/pipeline.py @@ -15,16 +15,16 @@ import argparse import os -from pathlib import Path -from typing import Dict, Optional +import pathlib +from typing import Optional -import ai_edge_torch.generative.examples.stable_diffusion.samplers as samplers -from ai_edge_torch.generative.examples.stable_diffusion.tokenizer import Tokenizer # NOQA -import ai_edge_torch.generative.examples.stable_diffusion.util as util -from ai_edge_torch.model import TfLiteModel +import ai_edge_torch +from ai_edge_torch.generative.examples.stable_diffusion import samplers +from ai_edge_torch.generative.examples.stable_diffusion import tokenizer +from ai_edge_torch.generative.examples.stable_diffusion import util import numpy as np from PIL import Image -from tqdm import tqdm +import tqdm arg_parser = argparse.ArgumentParser() arg_parser.add_argument( @@ -104,12 +104,12 @@ def __init__( diffusion_ckpt: str, decoder_ckpt: str ): - self.tokenizer = Tokenizer(tokenizer_vocab_dir) - self.clip = TfLiteModel.load(clip_ckpt) - self.decoder = TfLiteModel.load(decoder_ckpt) - self.diffusion = TfLiteModel.load(diffusion_ckpt) + self.tokenizer = tokenizer.Tokenizer(tokenizer_vocab_dir) + self.clip = ai_edge_torch.model.TfLiteModel.load(clip_ckpt) + self.decoder = ai_edge_torch.model.TfLiteModel.load(decoder_ckpt) + self.diffusion = ai_edge_torch.model.TfLiteModel.load(diffusion_ckpt) if encoder_ckpt is not None: - self.encoder = TfLiteModel.load(encoder_ckpt) + self.encoder = ai_edge_torch.model.TfLiteModel.load(encoder_ckpt) def run_tflite_pipeline( @@ -128,48 +128,32 @@ def run_tflite_pipeline( ): """Run stable diffusion pipeline with tflite model. - model: - - StableDiffsuion model. - prompt: - The prompt to guide the image generation. - output_path: - The path to the generated output image. - uncond_prompt: - The prompt not to guide the image generation. - cfg_scale: - Guidance scale of classifier-free guidance. Higher guidance scale encourages - to generate - images that are closely linked to the text `prompt`, usually at the expense - of lower - image quality. - height: - The height in pixels of the generated image. - width: - The width in pixels of the generated image. - sampler: - A sampler to be used to denoise the encoded image latents. Can be one of - `k_lms, `k_euler`, - or `k_euler_ancestral`. - n_inference_steps: - The number of denoising steps. More denoising steps usually lead to a higher - quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - seed: - A seed to make generation deterministic. - strength: - Conceptually, indicates how much to transform the reference `input_image`. - Must be between 0 and 1. - `input_image` will be used as a starting point, adding more noise to it the - larger the `strength`. - The number of denoising steps depends on the amount of noise initially - added. When `strength` is 1, - added noise will be maximum and the denoising process will run for the full - number of iterations - specified in `n_inference_steps`. A value of 1, therefore, essentially - ignores `input_image`. - input_image: - Image which is served as the starting point for the image generation. + Args: + model: StableDiffsuion model. + prompt: The prompt to guide the image generation. + output_path: The path to the generated output image. + uncond_prompt: The prompt not to guide the image generation. + cfg_scale: Guidance scale of classifier-free guidance. Higher guidance scale + encourages to generate images that are closely linked to the text + `prompt`, usually at the expense of lower image quality. + height: The height in pixels of the generated image. + width: The width in pixels of the generated image. + sampler: A sampler to be used to denoise the encoded image latents. Can be + one of `k_lms, `k_euler`, or `k_euler_ancestral`. + n_inference_steps: The number of denoising steps. More denoising steps + usually lead to a higher quality image at the expense of slower inference. + This parameter will be modulated by `strength`. + seed: A seed to make generation deterministic. + strength: Conceptually, indicates how much to transform the reference + `input_image`. Must be between 0 and 1. `input_image` will be used as a + starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. + When `strength` is 1, added noise will be maximum and the denoising + process will run for the full number of iterations specified in + `n_inference_steps`. A value of 1, therefore, essentially ignores + `input_image`. + input_image: Image which is served as the starting point for the image + generation. """ if not 0 < strength < 1: raise ValueError('strength must be between 0 and 1') @@ -202,7 +186,8 @@ def run_tflite_pipeline( context = np.concatenate([cond_context, uncond_context], axis=0) noise_shape = (1, 4, height // 8, width // 8) - # Initialization starts from input_image if any, otherwise, starts from a random sampling. + # Initialization starts from input_image if any, otherwise, starts from a + # random sampling. if input_image: if not hasattr(model, 'encoder'): raise AttributeError( @@ -210,7 +195,6 @@ def run_tflite_pipeline( ' input_image.' ) input_image = input_image.resize((width, height)) - input_image_np = np.array(input_image).astype(np.float32) input_image_np = util.rescale(input_image, (0, 255), (-1, 1)) input_image_np = util.move_channel(input_image_np, to='first') encoder_noise = np.random.normal(size=noise_shape).astype(np.float32) @@ -223,8 +207,8 @@ def run_tflite_pipeline( latents *= sampler.initial_scale # Diffusion process. - timesteps = tqdm(sampler.timesteps) - for i, timestep in enumerate(timesteps): + timesteps = tqdm.tqdm(sampler.timesteps) + for _, timestep in enumerate(timesteps): time_embedding = util.get_time_embedding(timestep) input_latents = latents * sampler.get_input_scale() @@ -242,7 +226,7 @@ def run_tflite_pipeline( images = util.rescale(images, (-1, 1), (0, 255), clamp=True) images = util.move_channel(images, to='last') if not os.path.exists(output_path): - Path(output_path).parent.mkdir(parents=True, exist_ok=True) + pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True) Image.fromarray(images[0].astype(np.uint8)).save(output_path)