Fix formatting and style in pipeline.py.
PiperOrigin-RevId: 668058319
majiddadashi authored and copybara-github committed Aug 27, 2024
1 parent 46c4a3c commit 3bf9ea9
Showing 1 changed file with 43 additions and 59 deletions.
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py (102 changes: 43 additions & 59 deletions)
@@ -15,16 +15,16 @@

import argparse
import os
from pathlib import Path
from typing import Dict, Optional
import pathlib
from typing import Optional

import ai_edge_torch.generative.examples.stable_diffusion.samplers as samplers
from ai_edge_torch.generative.examples.stable_diffusion.tokenizer import Tokenizer # NOQA
import ai_edge_torch.generative.examples.stable_diffusion.util as util
from ai_edge_torch.model import TfLiteModel
import ai_edge_torch
from ai_edge_torch.generative.examples.stable_diffusion import samplers
from ai_edge_torch.generative.examples.stable_diffusion import tokenizer
from ai_edge_torch.generative.examples.stable_diffusion import util
import numpy as np
from PIL import Image
from tqdm import tqdm
import tqdm

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
@@ -104,12 +104,12 @@ def __init__(
diffusion_ckpt: str,
decoder_ckpt: str
):
self.tokenizer = Tokenizer(tokenizer_vocab_dir)
self.clip = TfLiteModel.load(clip_ckpt)
self.decoder = TfLiteModel.load(decoder_ckpt)
self.diffusion = TfLiteModel.load(diffusion_ckpt)
self.tokenizer = tokenizer.Tokenizer(tokenizer_vocab_dir)
self.clip = ai_edge_torch.model.TfLiteModel.load(clip_ckpt)
self.decoder = ai_edge_torch.model.TfLiteModel.load(decoder_ckpt)
self.diffusion = ai_edge_torch.model.TfLiteModel.load(diffusion_ckpt)
if encoder_ckpt is not None:
self.encoder = TfLiteModel.load(encoder_ckpt)
self.encoder = ai_edge_torch.model.TfLiteModel.load(encoder_ckpt)
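
The refactored __init__ above loads each TFLite checkpoint through the fully qualified ai_edge_torch.model.TfLiteModel.load. A minimal construction sketch, assuming the enclosing class is this file's StableDiffusion wrapper; all checkpoint paths below are hypothetical placeholders:

from ai_edge_torch.generative.examples.stable_diffusion import pipeline

# Hypothetical paths to exported TFLite artifacts.
model = pipeline.StableDiffusion(
    tokenizer_vocab_dir='/tmp/sd/tokenizer',    # directory with vocab files
    clip_ckpt='/tmp/sd/clip.tflite',            # text encoder
    encoder_ckpt='/tmp/sd/encoder.tflite',      # image encoder; may be None
    diffusion_ckpt='/tmp/sd/diffusion.tflite',  # denoising model
    decoder_ckpt='/tmp/sd/decoder.tflite',      # latent decoder
)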


def run_tflite_pipeline(
@@ -128,48 +128,32 @@ def run_tflite_pipeline(
):
"""Run stable diffusion pipeline with tflite model.
model:
StableDiffsuion model.
prompt:
The prompt to guide the image generation.
output_path:
The path to the generated output image.
uncond_prompt:
The prompt not to guide the image generation.
cfg_scale:
Guidance scale of classifier-free guidance. Higher guidance scale encourages
to generate
images that are closely linked to the text `prompt`, usually at the expense
of lower
image quality.
height:
The height in pixels of the generated image.
width:
The width in pixels of the generated image.
sampler:
A sampler to be used to denoise the encoded image latents. Can be one of
`k_lms, `k_euler`,
or `k_euler_ancestral`.
n_inference_steps:
The number of denoising steps. More denoising steps usually lead to a higher
quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
seed:
A seed to make generation deterministic.
strength:
Conceptually, indicates how much to transform the reference `input_image`.
Must be between 0 and 1.
`input_image` will be used as a starting point, adding more noise to it the
larger the `strength`.
The number of denoising steps depends on the amount of noise initially
added. When `strength` is 1,
added noise will be maximum and the denoising process will run for the full
number of iterations
specified in `n_inference_steps`. A value of 1, therefore, essentially
ignores `input_image`.
input_image:
Image which is served as the starting point for the image generation.
Args:
model: StableDiffusion model.
prompt: The prompt to guide the image generation.
output_path: The path to the generated output image.
uncond_prompt: The prompt not to guide the image generation.
cfg_scale: Guidance scale of classifier-free guidance. A higher guidance
scale encourages the model to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
height: The height in pixels of the generated image.
width: The width in pixels of the generated image.
sampler: A sampler to be used to denoise the encoded image latents. Can be
one of `k_lms`, `k_euler`, or `k_euler_ancestral`.
n_inference_steps: The number of denoising steps. More denoising steps
usually lead to a higher quality image at the expense of slower inference.
This parameter will be modulated by `strength`.
seed: A seed to make generation deterministic.
strength: Conceptually, indicates how much to transform the reference
`input_image`. Must be between 0 and 1. `input_image` will be used as a
starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added.
When `strength` is 1, added noise will be maximum and the denoising
process will run for the full number of iterations specified in
`n_inference_steps`. A value of 1, therefore, essentially ignores
`input_image`.
input_image: Image which serves as the starting point for the image
generation.
"""
if not 0 < strength < 1:
raise ValueError('strength must be between 0 and 1')
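
Taken together, the Args section above maps directly onto a call like the following minimal sketch (not part of the commit; the prompt, output path, and argument values are illustrative assumptions):

from ai_edge_torch.generative.examples.stable_diffusion import pipeline

# Assumes `model` was constructed as sketched after __init__ above.
pipeline.run_tflite_pipeline(
    model,
    prompt='a photo of an astronaut riding a horse',
    output_path='/tmp/sd/output.jpg',
    uncond_prompt='',       # empty negative prompt
    cfg_scale=7.5,
    height=512,
    width=512,
    sampler='k_euler',      # or 'k_lms', 'k_euler_ancestral'
    n_inference_steps=20,
    seed=100,
    strength=0.8,           # only meaningful with input_image
    input_image=None,       # text-to-image when None
)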
@@ -202,15 +186,15 @@ def run_tflite_pipeline(
context = np.concatenate([cond_context, uncond_context], axis=0)
noise_shape = (1, 4, height // 8, width // 8)

# Initialization starts from input_image if any, otherwise, starts from a random sampling.
# Initialization starts from input_image if any, otherwise, starts from a
# random sampling.
if input_image:
if not hasattr(model, 'encoder'):
raise AttributeError(
'Stable Diffusion must be initialized with encoder to accept'
' input_image.'
)
input_image = input_image.resize((width, height))
input_image_np = np.array(input_image).astype(np.float32)
input_image_np = util.rescale(input_image, (0, 255), (-1, 1))
input_image_np = util.move_channel(input_image_np, to='first')
encoder_noise = np.random.normal(size=noise_shape).astype(np.float32)
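
In this img2img branch the PIL image is resized, converted to float32, linearly rescaled from the pixel range (0, 255) into the model's (-1, 1) range, and moved into channel-first layout before encoding. A self-contained sketch of what the two util helpers plausibly do (the real implementations live in util.py and may differ):

import numpy as np

def rescale(x, in_range, out_range, clamp=False):
  # Linear map from in_range to out_range, e.g. (0, 255) -> (-1, 1).
  in_min, in_max = in_range
  out_min, out_max = out_range
  y = (np.asarray(x, dtype=np.float32) - in_min) / (in_max - in_min)
  y = y * (out_max - out_min) + out_min
  return np.clip(y, out_min, out_max) if clamp else y

def move_channel(x, to):
  # Swap between channel-last (H, W, C) and channel-first (C, H, W),
  # with or without a leading batch axis.
  axes = {'first': (0, 3, 1, 2), 'last': (0, 2, 3, 1)}[to]
  if x.ndim == 3:
    axes = tuple(a - 1 for a in axes[1:])
  return np.transpose(x, axes)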
@@ -223,8 +207,8 @@
latents *= sampler.initial_scale

# Diffusion process.
timesteps = tqdm(sampler.timesteps)
for i, timestep in enumerate(timesteps):
timesteps = tqdm.tqdm(sampler.timesteps)
for _, timestep in enumerate(timesteps):
time_embedding = util.get_time_embedding(timestep)

input_latents = latents * sampler.get_input_scale()
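
The elided rest of the loop body runs the diffusion model on the (conditional, unconditional) context pair built earlier and blends the two noise predictions with classifier-free guidance before the sampler update. A hedged continuation sketch; the diffusion call and sampler.step signature are assumptions, not the file's confirmed API:

# Duplicate the latents so one diffusion call covers both contexts.
batched_latents = np.repeat(input_latents, 2, axis=0)
output = model.diffusion(batched_latents, context, time_embedding)  # assumed call
cond_output, uncond_output = np.split(output, 2, axis=0)
# Classifier-free guidance: push the prediction toward the conditional one.
guided = uncond_output + cfg_scale * (cond_output - uncond_output)
latents = sampler.step(latents, guided)  # assumed sampler API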
@@ -242,7 +226,7 @@ def run_tflite_pipeline(
images = util.rescale(images, (-1, 1), (0, 255), clamp=True)
images = util.move_channel(images, to='last')
if not os.path.exists(output_path):
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
Image.fromarray(images[0].astype(np.uint8)).save(output_path)
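
One nuance in the final lines: pathlib's mkdir(parents=True, exist_ok=True) is already safe to call when the directory exists, so the os.path.exists guard on output_path (which tests the output file, not its parent directory) is not strictly needed. A minimal equivalent sketch with a hypothetical path:

import pathlib

import numpy as np
from PIL import Image

out_path = pathlib.Path('/tmp/sd/output.jpg')        # hypothetical path
out_path.parent.mkdir(parents=True, exist_ok=True)   # no-op if the dir exists
Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)).save(out_path)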

