diff --git a/ai_edge_torch/generative/examples/stable_diffusion/BUILD b/ai_edge_torch/generative/examples/stable_diffusion/BUILD index 1aac3c50..f23c5344 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/BUILD +++ b/ai_edge_torch/generative/examples/stable_diffusion/BUILD @@ -19,6 +19,9 @@ package( default_applicable_licenses = [ "//third_party/py/ai_edge_torch:license", ], + default_visibility = [ + "//third_party/py/ai_edge_torch:__subpackages__", + ], ) py_binary( diff --git a/ai_edge_torch/generative/examples/stable_diffusion/clip.py b/ai_edge_torch/generative/examples/stable_diffusion/clip.py index 0821d54b..5c0bd3eb 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/clip.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/clip.py @@ -48,7 +48,7 @@ class CLIP(nn.Module): - """CLIP text encoder + """CLIP text encoder. For details, see https://arxiv.org/abs/2103.00020 """ @@ -86,6 +86,7 @@ def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor: def get_model_config() -> cfg.ModelConfig: + """Get configs for the CLIP of Stable Diffusion v1.5.""" max_seq_len = 77 vocab_size = 49408 num_layers = 12 @@ -132,3 +133,53 @@ def get_model_config() -> cfg.ModelConfig: ) return config + + +def get_fake_model_config() -> cfg.ModelConfig: + """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing.""" + max_seq_len = 6 + vocab_size = 100 + num_layers = 2 + num_heads = 12 + num_query_groups = 12 + embedding_dim = 24 + + attn_config = cfg.AttentionConfig( + num_heads=num_heads, + head_dim=embedding_dim // num_heads, + num_query_groups=num_query_groups, + rotary_percentage=0.0, + qkv_use_bias=True, + qkv_transpose_before_split=True, + qkv_fused_interleaved=False, + output_proj_use_bias=True, + enable_kv_cache=False, + ) + + ff_config = cfg.FeedForwardConfig( + type=cfg.FeedForwardType.SEQUENTIAL, + activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK), + intermediate_size=embedding_dim * 4, + use_bias=True, + ) + 
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM) + + block_config = cfg.TransformerBlockConfig( + attn_config=attn_config, + ff_config=ff_config, + pre_attention_norm_config=norm_config, + post_attention_norm_config=norm_config, + ) + + config = cfg.ModelConfig( + vocab_size=vocab_size, + num_layers=num_layers, + max_seq_len=max_seq_len, + embedding_dim=embedding_dim, + block_configs=block_config, + final_norm_config=norm_config, + enable_hlfb=True, + ) + + return config diff --git a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py index d0e51ddc..44e2fa4a 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/decoder.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py @@ -324,3 +324,59 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig: mid_block_config=mid_block_config, ) return config + + +def get_fake_model_config() -> unet_cfg.AutoEncoderConfig: + """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing.""" + in_channels = 3 + latent_channels = 4 + out_channels = 3 + block_out_channels = [2, 4] + scaling_factor = 0.18215 + layers_per_block = 2 + + norm_config = layers_cfg.NormalizationConfig( + layers_cfg.NormalizationType.GROUP_NORM, group_num=2 + ) + + att_config = unet_cfg.AttentionBlock2DConfig( + dim=block_out_channels[-1], + normalization_config=norm_config, + attention_config=layers_cfg.AttentionConfig( + num_heads=1, + head_dim=block_out_channels[-1], + num_query_groups=1, + qkv_use_bias=True, + output_proj_use_bias=True, + enable_kv_cache=False, + qkv_transpose_before_split=True, + qkv_fused_interleaved=False, + rotary_percentage=0.0, + ), + enable_hlfb=False, + ) + + mid_block_config = unet_cfg.MidBlock2DConfig( + in_channels=block_out_channels[-1], + normalization_config=norm_config, + activation_config=layers_cfg.ActivationConfig( + layers_cfg.ActivationType.SILU + ), + num_layers=1, + 
 attention_block_config=att_config, + ) + + config = unet_cfg.AutoEncoderConfig( + in_channels=in_channels, + latent_channels=latent_channels, + out_channels=out_channels, + activation_config=layers_cfg.ActivationConfig( + layers_cfg.ActivationType.SILU + ), + block_out_channels=block_out_channels, + scaling_factor=scaling_factor, + layers_per_block=layers_per_block, + normalization_config=norm_config, + mid_block_config=mid_block_config, + ) + return config diff --git a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py index f88f8fa3..3724af32 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py @@ -603,7 +603,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig: # Transformer configs. transformer_num_attention_heads = 8 transformer_batch_size = batch_size - transformer_cross_attention_dim = 768 # Embedding fomr CLIP model + transformer_cross_attention_dim = 768 # Embedding from CLIP model transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig( layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32 ) @@ -645,3 +645,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig: final_norm_config=final_norm_config, final_activation_type=final_activation_type, ) + + +def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig: + """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing. + + Args: + batch_size (int): the batch size of input. + + Returns: + The configuration of diffusion model of Stable Diffusion v1.5. + """ + in_channels = 4 + out_channels = 4 + block_out_channels = [2, 4, 8, 8] + layers_per_block = 1 + downsample_padding = 1 + + # Residual configs. 
+ residual_norm_config = layers_cfg.NormalizationConfig( + layers_cfg.NormalizationType.GROUP_NORM, group_num=2 + ) + residual_activation_type = layers_cfg.ActivationType.SILU + + # Transformer configs. + transformer_num_attention_heads = 1 + transformer_batch_size = batch_size + transformer_cross_attention_dim = 4 # Embedding from CLIP model + transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig( + layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2 + ) + transformer_norm_config = layers_cfg.NormalizationConfig( + layers_cfg.NormalizationType.LAYER_NORM + ) + transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU + + # Time embedding configs. + time_embedding_dim = 2 + time_embedding_blocks_dim = 4 + + # Mid block configs. + mid_block_layers = 1 + + # Final layer configs. + final_norm_config = layers_cfg.NormalizationConfig( + layers_cfg.NormalizationType.GROUP_NORM, group_num=2 + ) + final_activation_type = layers_cfg.ActivationType.SILU + + return unet_cfg.DiffusionModelConfig( + in_channels=in_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + downsample_padding=downsample_padding, + residual_norm_config=residual_norm_config, + residual_activation_type=residual_activation_type, + transformer_batch_size=transformer_batch_size, + transformer_num_attention_heads=transformer_num_attention_heads, + transformer_cross_attention_dim=transformer_cross_attention_dim, + transformer_pre_conv_norm_config=transformer_pre_conv_norm_config, + transformer_norm_config=transformer_norm_config, + transformer_ff_activation_type=transformer_ff_activation_type, + mid_block_layers=mid_block_layers, + time_embedding_dim=time_embedding_dim, + time_embedding_blocks_dim=time_embedding_blocks_dim, + final_norm_config=final_norm_config, + final_activation_type=final_activation_type, + ) diff --git a/ai_edge_torch/generative/layers/unet/blocks_2d.py 
b/ai_edge_torch/generative/layers/unet/blocks_2d.py index 115c3a00..690e04b6 100644 --- a/ai_edge_torch/generative/layers/unet/blocks_2d.py +++ b/ai_edge_torch/generative/layers/unet/blocks_2d.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from ai_edge_torch.generative.layers.attention import CrossAttention from ai_edge_torch.generative.layers.attention import SelfAttention @@ -416,7 +416,7 @@ def forward( time_emb: Optional[torch.Tensor] = None, context_tensor: Optional[torch.Tensor] = None, output_hidden_states: bool = False, - ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: """Forward function of the DownEncoderBlock2D. Args: diff --git a/ai_edge_torch/generative/test/test_model_conversion_large.py b/ai_edge_torch/generative/test/test_model_conversion_large.py index 491f889d..d3c6eb4b 100644 --- a/ai_edge_torch/generative/test/test_model_conversion_large.py +++ b/ai_edge_torch/generative/test/test_model_conversion_large.py @@ -23,6 +23,9 @@ from ai_edge_torch.generative.examples.phi import phi2 from ai_edge_torch.generative.examples.phi import phi3 from ai_edge_torch.generative.examples.smollm import smollm +from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip +from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder +from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion from ai_edge_torch.generative.layers import kv_cache from ai_edge_torch.generative.test import utils as test_utils import numpy as np @@ -139,6 +142,110 @@ def test_openelm(self): pytorch_model = openelm.OpenELM(config).eval() self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5) + @googletest.skipIf( + 
ai_edge_config.Config.use_torch_xla, + reason="tests with custom ops are not supported on oss", + ) + def test_stable_diffusion_clip(self): + config = sd_clip.get_fake_model_config() + prompt_tokens = torch.from_numpy( + np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32) + ) + + pytorch_model = sd_clip.CLIP(config).eval() + torch_output = pytorch_model(prompt_tokens) + + edge_model = ai_edge_torch.signature( + "encode", pytorch_model, (prompt_tokens,) + ).convert() + edge_model.set_interpreter_builder( + self._interpreter_builder(edge_model.tflite_model()) + ) + edge_output = edge_model( + prompt_tokens.numpy(), + signature_name="encode", + ) + self.assertTrue( + np.allclose( + edge_output, + torch_output.detach().numpy(), + atol=1e-4, + rtol=1e-5, + ) + ) + + @googletest.skipIf( + ai_edge_config.Config.use_torch_xla, + reason="tests with custom ops are not supported on oss", + ) + def test_stable_diffusion_diffusion(self): + config = sd_diffusion.get_fake_model_config(2) + latents = torch.from_numpy( + np.random.normal(size=(2, 4, 8, 8)).astype(np.float32) + ) + context = torch.from_numpy( + np.random.normal(size=(2, 4, 4)).astype(np.float32) + ) + time_embedding = torch.from_numpy( + np.random.normal(size=(2, 2)).astype(np.float32) + ) + + pytorch_model = sd_diffusion.Diffusion(config).eval() + torch_output = pytorch_model(latents, context, time_embedding) + + edge_model = ai_edge_torch.signature( + "diffusion", pytorch_model, (latents, context, time_embedding) + ).convert() + edge_model.set_interpreter_builder( + self._interpreter_builder(edge_model.tflite_model()) + ) + edge_output = edge_model( + latents.numpy(), + context.numpy(), + time_embedding.numpy(), + signature_name="diffusion", + ) + self.assertTrue( + np.allclose( + edge_output, + torch_output.detach().numpy(), + atol=1e-4, + rtol=1e-5, + ) + ) + + @googletest.skipIf( + ai_edge_config.Config.use_torch_xla, + reason="tests with custom ops are not supported on oss", + ) + def 
test_stable_diffusion_decoder(self): + config = sd_decoder.get_fake_model_config() + latents = torch.from_numpy( + np.random.normal(size=(1, 4, 64, 64)).astype(np.float32) + ) + + pytorch_model = sd_decoder.Decoder(config).eval() + torch_output = pytorch_model(latents) + + edge_model = ai_edge_torch.signature( + "decode", pytorch_model, (latents,) + ).convert() + edge_model.set_interpreter_builder( + self._interpreter_builder(edge_model.tflite_model()) + ) + edge_output = edge_model( + latents.numpy(), + signature_name="decode", + ) + self.assertTrue( + np.allclose( + edge_output, + torch_output.detach().numpy(), + atol=1e-4, + rtol=1e-5, + ) + ) + if __name__ == "__main__": googletest.main()