Add conversion tests for Stable Diffusion v1.5
PiperOrigin-RevId: 678361445
ai-edge-bot authored and copybara-github committed Sep 24, 2024
1 parent 3763b9e commit 68904bc
Showing 6 changed files with 289 additions and 4 deletions.
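
Each of the three new tests follows the same pattern: build a model from a fake (tiny) config, run it in PyTorch, convert it with a named signature, then compare the TFLite output against the PyTorch output. The snippet below is a condensed paraphrase of the CLIP test added further down in this commit, shown for orientation only; the interpreter-builder plumbing provided by the test class is omitted.

import ai_edge_torch
import numpy as np
import torch

from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip

# A tiny config keeps the conversion fast enough for a unit test.
config = sd_clip.get_fake_model_config()
pytorch_model = sd_clip.CLIP(config).eval()

prompt_tokens = torch.from_numpy(np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32))
torch_output = pytorch_model(prompt_tokens)

# Convert with a named signature and invoke the edge model through it.
edge_model = ai_edge_torch.signature(
    "encode", pytorch_model, (prompt_tokens,)
).convert()
edge_output = edge_model(prompt_tokens.numpy(), signature_name="encode")

assert np.allclose(edge_output, torch_output.detach().numpy(), atol=1e-4, rtol=1e-5)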
3 changes: 3 additions & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/BUILD
@@ -19,6 +19,9 @@ package(
    default_applicable_licenses = [
        "//third_party/py/ai_edge_torch:license",
    ],
    default_visibility = [
        "//third_party/py/ai_edge_torch:__subpackages__",
    ],
)

py_binary(
53 changes: 52 additions & 1 deletion ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -48,7 +48,7 @@


class CLIP(nn.Module):
-  """CLIP text encoder
+  """CLIP text encoder.

  For details, see https://arxiv.org/abs/2103.00020
  """
@@ -86,6 +86,7 @@ def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:


def get_model_config() -> cfg.ModelConfig:
  """Get configs for the CLIP of Stable Diffusion v1.5."""
  max_seq_len = 77
  vocab_size = 49408
  num_layers = 12
@@ -132,3 +133,53 @@ def get_model_config() -> cfg.ModelConfig:
  )

  return config


def get_fake_model_config() -> cfg.ModelConfig:
  """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
  max_seq_len = 6
  vocab_size = 100
  num_layers = 2
  num_heads = 12
  num_query_groups = 12
  embedding_dim = 24

  attn_config = cfg.AttentionConfig(
      num_heads=num_heads,
      head_dim=embedding_dim // num_heads,
      num_query_groups=num_query_groups,
      rotary_percentage=0.0,
      qkv_use_bias=True,
      qkv_transpose_before_split=True,
      qkv_fused_interleaved=False,
      output_proj_use_bias=True,
      enable_kv_cache=False,
  )

  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.SEQUENTIAL,
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
      intermediate_size=embedding_dim * 4,
      use_bias=True,
  )

  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)

  block_config = cfg.TransformerBlockConfig(
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
  )

  config = cfg.ModelConfig(
      vocab_size=vocab_size,
      num_layers=num_layers,
      max_seq_len=max_seq_len,
      embedding_dim=embedding_dim,
      block_configs=block_config,
      final_norm_config=norm_config,
      enable_hlfb=True,
  )

  return config
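
Compared with get_model_config() above (77-token context, 49408-token vocabulary, 12 layers, and the 768-dim embedding the diffusion model expects from CLIP), every dimension here is shrunk so conversion finishes quickly in tests. A minimal usage sketch follows; it is illustrative only, and the expected output shape assumes the encoder returns one embedding per token position.

import numpy as np
import torch

# Assumes this file's module path, as used by the new conversion test.
from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip

config = sd_clip.get_fake_model_config()
model = sd_clip.CLIP(config).eval()

# Six token ids -- exactly max_seq_len for the fake config.
tokens = torch.from_numpy(np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32))
with torch.no_grad():
  embeddings = model(tokens)

# Expected (assumption): (batch=1, max_seq_len=6, embedding_dim=24).
print(embeddings.shape)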
56 changes: 56 additions & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -324,3 +324,59 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
      mid_block_config=mid_block_config,
  )
  return config


def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
  """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
  in_channels = 3
  latent_channels = 4
  out_channels = 3
  block_out_channels = [2, 4]
  scaling_factor = 0.18215
  layers_per_block = 2

  norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
  )

  att_config = unet_cfg.AttentionBlock2DConfig(
      dim=block_out_channels[-1],
      normalization_config=norm_config,
      attention_config=layers_cfg.AttentionConfig(
          num_heads=1,
          head_dim=block_out_channels[-1],
          num_query_groups=1,
          qkv_use_bias=True,
          output_proj_use_bias=True,
          enable_kv_cache=False,
          qkv_transpose_before_split=True,
          qkv_fused_interleaved=False,
          rotary_percentage=0.0,
      ),
      enable_hlfb=False,
  )

  mid_block_config = unet_cfg.MidBlock2DConfig(
      in_channels=block_out_channels[-1],
      normalization_config=norm_config,
      activation_config=layers_cfg.ActivationConfig(
          layers_cfg.ActivationType.SILU
      ),
      num_layers=1,
      attention_block_config=att_config,
  )

  config = unet_cfg.AutoEncoderConfig(
      in_channels=in_channels,
      latent_channels=latent_channels,
      out_channels=out_channels,
      activation_config=layers_cfg.ActivationConfig(
          layers_cfg.ActivationType.SILU
      ),
      block_out_channels=block_out_channels,
      scaling_factor=scaling_factor,
      layers_per_block=layers_per_block,
      normalization_config=norm_config,
      mid_block_config=mid_block_config,
  )
  return config
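
A minimal sketch of exercising this fake decoder config (an illustration, not part of the commit): the conversion test below feeds latents with latent_channels=4 channels, and the decoded output should have out_channels=3 channels; the output spatial size depends on how many upsampling stages block_out_channels implies, so it is only printed here rather than asserted.

import numpy as np
import torch

# Assumes this file's module path, as used by the new conversion test.
from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder

config = sd_decoder.get_fake_model_config()
model = sd_decoder.Decoder(config).eval()

# (batch, latent_channels=4, height, width); the 64x64 spatial size mirrors the test.
latents = torch.from_numpy(np.random.normal(size=(1, 4, 64, 64)).astype(np.float32))
with torch.no_grad():
  image = model(latents)

print(image.shape)  # Channel dim should equal out_channels=3.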
70 changes: 69 additions & 1 deletion ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -603,7 +603,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
  # Transformer configs.
  transformer_num_attention_heads = 8
  transformer_batch_size = batch_size
-  transformer_cross_attention_dim = 768  # Embedding fomr CLIP model
+  transformer_cross_attention_dim = 768  # Embedding from CLIP model
  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
  )
@@ -645,3 +645,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
      final_norm_config=final_norm_config,
      final_activation_type=final_activation_type,
  )


def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
  """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.

  Args:
    batch_size (int): the batch size of input.

  Returns:
    The configuration of the diffusion model of Stable Diffusion v1.5.
  """
  in_channels = 4
  out_channels = 4
  block_out_channels = [2, 4, 8, 8]
  layers_per_block = 1
  downsample_padding = 1

  # Residual configs.
  residual_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
  )
  residual_activation_type = layers_cfg.ActivationType.SILU

  # Transformer configs.
  transformer_num_attention_heads = 1
  transformer_batch_size = batch_size
  transformer_cross_attention_dim = 4  # Embedding from CLIP model
  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
  )
  transformer_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.LAYER_NORM
  )
  transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU

  # Time embedding configs.
  time_embedding_dim = 2
  time_embedding_blocks_dim = 4

  # Mid block configs.
  mid_block_layers = 1

  # Final layer configs.
  final_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
  )
  final_activation_type = layers_cfg.ActivationType.SILU

  return unet_cfg.DiffusionModelConfig(
      in_channels=in_channels,
      out_channels=out_channels,
      block_out_channels=block_out_channels,
      layers_per_block=layers_per_block,
      downsample_padding=downsample_padding,
      residual_norm_config=residual_norm_config,
      residual_activation_type=residual_activation_type,
      transformer_batch_size=transformer_batch_size,
      transformer_num_attention_heads=transformer_num_attention_heads,
      transformer_cross_attention_dim=transformer_cross_attention_dim,
      transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
      transformer_norm_config=transformer_norm_config,
      transformer_ff_activation_type=transformer_ff_activation_type,
      mid_block_layers=mid_block_layers,
      time_embedding_dim=time_embedding_dim,
      time_embedding_blocks_dim=time_embedding_blocks_dim,
      final_norm_config=final_norm_config,
      final_activation_type=final_activation_type,
  )
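
The input shapes used by the new test_stable_diffusion_diffusion test below follow directly from these constants. A brief sketch of how they line up (illustrative only, mirroring the test inputs):

import numpy as np
import torch

# batch_size=2 is what the test passes to get_fake_model_config(2).

# Noisy latents: (batch, in_channels=4, height, width); the 8x8 spatial size is a
# small value chosen by the test.
latents = torch.from_numpy(np.random.normal(size=(2, 4, 8, 8)).astype(np.float32))

# Text-conditioning context: last dim matches transformer_cross_attention_dim=4.
context = torch.from_numpy(np.random.normal(size=(2, 4, 4)).astype(np.float32))

# Timestep embedding: last dim matches time_embedding_dim=2.
time_embedding = torch.from_numpy(np.random.normal(size=(2, 2)).astype(np.float32))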
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

from ai_edge_torch.generative.layers.attention import CrossAttention
from ai_edge_torch.generative.layers.attention import SelfAttention
@@ -416,7 +416,7 @@ def forward(
      time_emb: Optional[torch.Tensor] = None,
      context_tensor: Optional[torch.Tensor] = None,
      output_hidden_states: bool = False,
-  ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
    """Forward function of the DownEncoderBlock2D.

    Args:
Expand Down
107 changes: 107 additions & 0 deletions ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -23,6 +23,9 @@
from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.examples.phi import phi3
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.test import utils as test_utils
import numpy as np
@@ -139,6 +142,110 @@ def test_openelm(self):
    pytorch_model = openelm.OpenELM(config).eval()
    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

  @googletest.skipIf(
      ai_edge_config.Config.use_torch_xla,
      reason="tests with custom ops are not supported on oss",
  )
  def test_stable_diffusion_clip(self):
    config = sd_clip.get_fake_model_config()
    prompt_tokens = torch.from_numpy(
        np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
    )

    pytorch_model = sd_clip.CLIP(config).eval()
    torch_output = pytorch_model(prompt_tokens)

    edge_model = ai_edge_torch.signature(
        "encode", pytorch_model, (prompt_tokens,)
    ).convert()
    edge_model.set_interpreter_builder(
        self._interpreter_builder(edge_model.tflite_model())
    )
    edge_output = edge_model(
        prompt_tokens.numpy(),
        signature_name="encode",
    )
    self.assertTrue(
        np.allclose(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
            rtol=1e-5,
        )
    )

  @googletest.skipIf(
      ai_edge_config.Config.use_torch_xla,
      reason="tests with custom ops are not supported on oss",
  )
  def test_stable_diffusion_diffusion(self):
    config = sd_diffusion.get_fake_model_config(2)
    latents = torch.from_numpy(
        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
    )
    context = torch.from_numpy(
        np.random.normal(size=(2, 4, 4)).astype(np.float32)
    )
    time_embedding = torch.from_numpy(
        np.random.normal(size=(2, 2)).astype(np.float32)
    )

    pytorch_model = sd_diffusion.Diffusion(config).eval()
    torch_output = pytorch_model(latents, context, time_embedding)

    edge_model = ai_edge_torch.signature(
        "diffusion", pytorch_model, (latents, context, time_embedding)
    ).convert()
    edge_model.set_interpreter_builder(
        self._interpreter_builder(edge_model.tflite_model())
    )
    edge_output = edge_model(
        latents.numpy(),
        context.numpy(),
        time_embedding.numpy(),
        signature_name="diffusion",
    )
    self.assertTrue(
        np.allclose(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
            rtol=1e-5,
        )
    )

  @googletest.skipIf(
      ai_edge_config.Config.use_torch_xla,
      reason="tests with custom ops are not supported on oss",
  )
  def test_stable_diffusion_decoder(self):
    config = sd_decoder.get_fake_model_config()
    latents = torch.from_numpy(
        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
    )

    pytorch_model = sd_decoder.Decoder(config).eval()
    torch_output = pytorch_model(latents)

    edge_model = ai_edge_torch.signature(
        "decode", pytorch_model, (latents,)
    ).convert()
    edge_model.set_interpreter_builder(
        self._interpreter_builder(edge_model.tflite_model())
    )
    edge_output = edge_model(
        latents.numpy(),
        signature_name="decode",
    )
    self.assertTrue(
        np.allclose(
            edge_output,
            torch_output.detach().numpy(),
            atol=1e-4,
            rtol=1e-5,
        )
    )


if __name__ == "__main__":
  googletest.main()
