Enable Latent Consistency models ONNX export (#1469)
* Enable export latent consistency model

* add pipeline

* format

* fix docstring

* modify regex pattern

* remove constraint diffusers version

* fix pipeline

* fix infered task

* add test

* fix style

* add documentation

* format
echarlaix authored Oct 30, 2023
1 parent e164827 commit 01dd5c3
Showing 12 changed files with 363 additions and 15 deletions.
5 changes: 5 additions & 0 deletions docs/source/onnxruntime/package_reference/modeling_ort.mdx
@@ -143,3 +143,8 @@ The following ORT classes are available for the following custom tasks.

[[autodoc]] onnxruntime.ORTStableDiffusionXLImg2ImgPipeline
- __call__

#### ORTLatentConsistencyModelPipeline

[[autodoc]] onnxruntime.ORTLatentConsistencyModelPipeline
- __call__
16 changes: 16 additions & 0 deletions docs/source/onnxruntime/usage_guides/models.mdx
@@ -248,3 +248,19 @@ image = refiner(prompt=prompt, image=image[None, :]).images[0]
image.save("sailing_ship.png")
```



## Latent Consistency Models

### Text-to-Image

Here is an example of how you can load a Latent Consistency Model (LCM) from [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) and run inference using ONNX Runtime:

```python
from optimum.onnxruntime import ORTLatentConsistencyModelPipeline

model_id = "SimianLuo/LCM_Dreamshaper_v7"
pipeline = ORTLatentConsistencyModelPipeline.from_pretrained(model_id, export=True)
prompt = "sailing ship in storm by Leonardo da Vinci"
images = pipeline(prompt, num_inference_steps=4, guidance_scale=8.0).images
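# LCMs are distilled to give good results in very few denoising steps, hence num_inference_steps=4.
# Saving the output is optional; the filename below is just an example.
images[0].save("sailing_ship_lcm.png")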
```
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/base.py
@@ -435,7 +435,7 @@ def ordered_inputs(self, model: Union["PreTrainedModel", "TFPreTrainedModel"]) -
sig = inspect.signature(model.call)

for param in sig.parameters:
param_regex = re.compile(rf"{param}(\.\d*)?")
param_regex = re.compile(rf"{param}(\..*)?$")
to_insert = []
for name, dynamic_axes in inputs.items():
if re.match(param_regex, name):
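Why the anchor matters: `re.match` only anchors at the start of the string, so with the old pattern the `timestep` parameter would also capture the new `timestep_cond` input. A minimal standalone sketch of the difference (stdlib `re` only):

```python
import re

param = "timestep"
old_pattern = re.compile(rf"{param}(\.\d*)?")  # unanchored: any name starting with "timestep" matches
new_pattern = re.compile(rf"{param}(\..*)?$")  # anchored: the exact name or a dotted suffix only

for name in ("timestep", "timestep.1", "timestep_cond"):
    print(name, bool(old_pattern.match(name)), bool(new_pattern.match(name)))
# timestep      True True
# timestep.1    True True
# timestep_cond True False  <- the old pattern wrongly swallowed the new UNet input
```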
2 changes: 2 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -928,6 +928,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs["text_embeds"] = {0: "batch_size"}
common_inputs["time_ids"] = {0: "batch_size"}

if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
common_inputs["timestep_cond"] = {0: "batch_size"}
return common_inputs

@property
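Only LCM-distilled UNets set `time_cond_proj_dim`, so plain Stable Diffusion exports keep their current inputs. A quick way to check a checkpoint (a sketch assuming `diffusers` is installed; uses the LCM model from the docs above):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", subfolder="unet"
)
# Non-None (e.g. 256) for LCM-distilled UNets, None for vanilla SD UNets,
# so the extra "timestep_cond" ONNX input is only added for LCMs.
print(unet.config.time_cond_proj_dim)
```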
7 changes: 2 additions & 5 deletions optimum/exporters/tasks.py
@@ -1408,11 +1408,8 @@ def _infer_task_from_model_name_or_path(
)
model_info = huggingface_hub.model_info(model_name_or_path, revision=revision)
if getattr(model_info, "library_name", None) == "diffusers":
# TODO : getattr(model_info, "model_index") defining auto_model_class_name currently set to None
for task in ("stable-diffusion-xl", "stable-diffusion"):
if task in model_info.tags:
inferred_task_name = task
break
class_name = model_info.config["diffusers"]["class_name"]
inferred_task_name = "stable-diffusion-xl" if "StableDiffusionXL" in class_name else "stable-diffusion"
elif getattr(model_info, "library_name", None) == "timm":
inferred_task_name = "image-classification"
else:
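The task is now inferred from the pipeline class name in the Hub metadata rather than from repository tags. A minimal sketch of the new lookup (assumes `huggingface_hub` is installed; the `model_info(...).config["diffusers"]["class_name"]` layout is the one the diff above relies on):

```python
import huggingface_hub

info = huggingface_hub.model_info("SimianLuo/LCM_Dreamshaper_v7")
class_name = info.config["diffusers"]["class_name"]  # expected: "LatentConsistencyModelPipeline"
task = "stable-diffusion-xl" if "StableDiffusionXL" in class_name else "stable-diffusion"
print(class_name, "->", task)  # LCM checkpoints map onto the "stable-diffusion" task
```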
4 changes: 4 additions & 0 deletions optimum/onnxruntime/__init__.py
@@ -78,6 +78,7 @@
"ORTStableDiffusionInpaintPipeline",
"ORTStableDiffusionXLPipeline",
"ORTStableDiffusionXLImg2ImgPipeline",
"ORTLatentConsistencyModelPipeline",
]
else:
_import_structure["modeling_diffusion"] = [
@@ -86,6 +87,7 @@
"ORTStableDiffusionInpaintPipeline",
"ORTStableDiffusionXLPipeline",
"ORTStableDiffusionXLImg2ImgPipeline",
"ORTLatentConsistencyModelPipeline",
]


@@ -135,6 +137,7 @@
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_diffusers_objects import (
ORTLatentConsistencyModelPipeline,
ORTStableDiffusionImg2ImgPipeline,
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionPipeline,
@@ -143,6 +146,7 @@
)
else:
from .modeling_diffusion import (
ORTLatentConsistencyModelPipeline,
ORTStableDiffusionImg2ImgPipeline,
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionPipeline,
14 changes: 13 additions & 1 deletion optimum/onnxruntime/modeling_diffusion.py
@@ -40,6 +40,7 @@

from ..exporters.onnx import main_export
from ..onnx.utils import _get_external_data_paths
from ..pipelines.diffusers.pipeline_latent_consistency import LatentConsistencyPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin
@@ -501,6 +502,7 @@ def forward(
encoder_hidden_states: np.ndarray,
text_embeds: Optional[np.ndarray] = None,
time_ids: Optional[np.ndarray] = None,
timestep_cond: Optional[np.ndarray] = None,
):
onnx_inputs = {
"sample": sample,
@@ -512,7 +514,8 @@
onnx_inputs["text_embeds"] = text_embeds
if time_ids is not None:
onnx_inputs["time_ids"] = time_ids

if timestep_cond is not None:
onnx_inputs["timestep_cond"] = timestep_cond
outputs = self.session.run(None, onnx_inputs)
return outputs

@@ -562,6 +565,15 @@ class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDi
__call__ = StableDiffusionInpaintPipelineMixin.__call__


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTLatentConsistencyModelPipeline(ORTStableDiffusionPipelineBase, LatentConsistencyPipelineMixin):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline).
"""

__call__ = LatentConsistencyPipelineMixin.__call__


class ORTStableDiffusionXLPipelineBase(ORTStableDiffusionPipelineBase):
auto_model_class = StableDiffusionXLImg2ImgPipeline

230 changes: 230 additions & 0 deletions optimum/pipelines/diffusers/pipeline_latent_consistency.py
@@ -0,0 +1,230 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

from .pipeline_stable_diffusion import StableDiffusionPipelineMixin


logger = logging.getLogger(__name__)


class LatentConsistencyPipelineMixin(StableDiffusionPipelineMixin):
# Adapted from https://github.com/huggingface/diffusers/blob/v0.22.0/src/diffusers/pipelines/latent_consistency/pipeline_latent_consistency.py#L264
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 4,
original_inference_steps: Optional[int] = None,
guidance_scale: float = 8.5,
num_images_per_prompt: int = 1,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`Optional[Union[str, List[str]]]`, defaults to `None`):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`Optional[int]`, defaults to None):
The height in pixels of the generated image.
width (`Optional[int]`, defaults to None):
The width in pixels of the generated image.
num_inference_steps (`int`, defaults to 4):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
original_inference_steps (`Optional[int]`, defaults to `None`):
The original number of inference steps used to generate the LCM scheduler's timestep schedule, from
which `num_inference_steps` evenly spaced timesteps are drawn. If `None`, the scheduler's default is used.
guidance_scale (`float`, defaults to 8.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, defaults to 1):
The number of images to generate per prompt.
generator (`Optional[np.random.RandomState]`, defaults to `None`):
An `np.random.RandomState` used to make generation deterministic.
latents (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
output_type (`str`, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Optional[Callable]`, defaults to `None`):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
callback_steps (`int`, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
height = height or self.unet.config["sample_size"] * self.vae_scale_factor
width = width or self.unet.config["sample_size"] * self.vae_scale_factor

# Don't need to get negative prompts due to LCM guided distillation
negative_prompt = None
negative_prompt_embeds = None

# check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)

# define call parameters
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]

if generator is None:
generator = np.random

prompt_embeds = self._encode_prompt(
prompt,
num_images_per_prompt,
False,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)

# set timesteps
self.scheduler.set_timesteps(num_inference_steps, original_inference_steps=original_inference_steps)
timesteps = self.scheduler.timesteps

latents = self.prepare_latents(
batch_size * num_images_per_prompt,
self.unet.config["in_channels"],
height,
width,
prompt_embeds.dtype,
generator,
latents,
)

bs = batch_size * num_images_per_prompt
# get Guidance Scale Embedding
w = np.full(bs, guidance_scale - 1, dtype=prompt_embeds.dtype)
w_embedding = self.get_guidance_scale_embedding(
w, embedding_dim=self.unet.config["time_cond_proj_dim"], dtype=prompt_embeds.dtype
)

# Adapted from diffusers to extend it for other runtimes than ORT
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)

num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
for i, t in enumerate(self.progress_bar(timesteps)):
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latents,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
timestep_cond=w_embedding,
)[0]

# compute the previous noisy sample x_t -> x_t-1
latents, denoised = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), return_dict=False
)
latents, denoised = latents.numpy(), denoised.numpy()

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type == "latent":
image = denoised
has_nsfw_concept = None
else:
denoised /= self.vae_decoder.config["scaling_factor"]
# decode one sample at a time: the half-precision VAE decoder gives strange results for batch size > 1
image = np.concatenate(
[self.vae_decoder(latent_sample=denoised[i : i + 1])[0] for i in range(denoised.shape[0])]
)
image, has_nsfw_concept = self.run_safety_checker(image)

if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

# Adapted from https://github.com/huggingface/diffusers/blob/v0.22.0/src/diffusers/pipelines/latent_consistency/pipeline_latent_consistency.py#L264
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=None):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
w (`np.ndarray`):
Guidance scale values at which to generate the embedding vectors.
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings
Returns:
`np.ndarray`: Embedding vectors with shape `(len(w), embedding_dim)`
"""
w = w * 1000
half_dim = embedding_dim // 2
emb = np.log(10000.0) / (half_dim - 1)
emb = np.exp(np.arange(half_dim, dtype=dtype) * -emb)
emb = w[:, None] * emb[None, :]
emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=1)

if embedding_dim % 2 == 1: # zero pad
emb = np.pad(emb, [(0, 0), (0, 1)])

assert emb.shape == (w.shape[0], embedding_dim)
return emb
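The embedding is the usual sinusoidal timestep embedding applied to `w * 1000`, where the pipeline sets `w = guidance_scale - 1`. A standalone NumPy sanity check of the shape contract (a copy of the math above, not a new API):

```python
import numpy as np

def guidance_scale_embedding(w, embedding_dim=512, dtype=np.float32):
    # Same math as LatentConsistencyPipelineMixin.get_guidance_scale_embedding
    w = w * 1000
    half_dim = embedding_dim // 2
    emb = np.log(10000.0) / (half_dim - 1)
    emb = np.exp(np.arange(half_dim, dtype=dtype) * -emb)
    emb = w[:, None] * emb[None, :]
    emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=1)
    if embedding_dim % 2 == 1:  # zero pad odd dimensions
        emb = np.pad(emb, [(0, 0), (0, 1)])
    return emb

w = np.full(2, 8.0 - 1.0, dtype=np.float32)  # guidance_scale=8.0, batch of 2
print(guidance_scale_embedding(w, embedding_dim=256).shape)  # (2, 256)
```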
11 changes: 11 additions & 0 deletions optimum/utils/dummy_diffusers_objects.py
@@ -68,3 +68,14 @@ def __init__(self, *args, **kwargs):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["diffusers"])


class ORTLatentConsistencyModelPipeline(metaclass=DummyObject):
_backends = ["diffusers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["diffusers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["diffusers"])