diff --git a/ai_edge_torch/convert/conversion.py b/ai_edge_torch/convert/conversion.py
index 2318b6d4..a6e6495c 100644
--- a/ai_edge_torch/convert/conversion.py
+++ b/ai_edge_torch/convert/conversion.py
@@ -30,6 +30,7 @@
 from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
 from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
 from ai_edge_torch.convert.fx_passes import run_passes
+from ai_edge_torch.generative.fx_passes import run_generative_passes
 from ai_edge_torch.quantize import quant_config as qcfg
 
 os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
@@ -38,6 +39,7 @@
 def _run_convert_passes(
     exported_program: ExportedProgram,
 ) -> ExportedProgram:
+  exported_program = run_generative_passes(exported_program)
   return run_passes(
       exported_program,
       [
diff --git a/ai_edge_torch/generative/fx_passes/__init__.py b/ai_edge_torch/generative/fx_passes/__init__.py
new file mode 100644
index 00000000..c343bbbc
--- /dev/null
+++ b/ai_edge_torch/generative/fx_passes/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import torch
+
+from ai_edge_torch.convert.fx_passes import CanonicalizePass
+from ai_edge_torch.convert.fx_passes import run_passes
+from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass  # NOQA
+
+
+def run_generative_passes(
+    exported_program: torch.export.ExportedProgram,
+) -> torch.export.ExportedProgram:
+  return run_passes(
+      exported_program,
+      [
+          RemoveSDPACompositeZeroMaskPass(),
+          CanonicalizePass(),
+      ],
+  )
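Not part of the patch — a minimal usage sketch of the new entry point. `MyModel` and the input shape are placeholders; the only real API used is `run_generative_passes`, which `_run_convert_passes` above now calls before the standard convert passes:

    import torch

    from ai_edge_torch.generative.fx_passes import run_generative_passes

    # Placeholder model: anything whose attention lowers through the
    # odml.scaled_dot_product_attention composite benefits from the pass.
    model = MyModel().eval()
    ep = torch.export.export(model, (torch.rand(1, 512, 64, 64),))
    ep = run_generative_passes(ep)  # zero-mask removal + canonicalize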
diff --git a/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py b/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py
new file mode 100644
index 00000000..5e9a2f42
--- /dev/null
+++ b/ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py
@@ -0,0 +1,47 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import torch
+
+from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
+from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+
+
+class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
+
+  def is_zero_tensor_node(self, node: torch.fx.Node):
+    return node.target == torch.ops.aten.zeros.default
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    graph = exported_program.graph_module.graph
+    for node in graph.nodes:
+      if not (
+          node.op == "call_function"
+          and node.target == torch.ops.xla.mark_tensor.default
+      ):
+        continue
+
+      source, name, io_position, id, is_input = node.args[:5]
+      # Composite info:
+      # - name: odml.scaled_dot_product_attention
+      # - inputs: q, k, v, mask
+      if name == "odml.scaled_dot_product_attention" and is_input and io_position == 3:
+        if self.is_zero_tensor_node(source):
+          # Remove the mark_tensor call on the mask input by
+          # replacing the target with an identity function.
+          node.target = lambda *args, **kwargs: args[0]
+
+    exported_program.graph_module.graph.lint()
+    exported_program.graph_module.recompile()
+    return ExportedProgramPassResult(exported_program, True)
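To make the mechanics concrete — a self-contained sketch (not part of the change) of the same node surgery on a plain `torch.fx` graph. Here `mark` is a hypothetical stand-in for `torch.ops.xla.mark_tensor`; the real pass additionally checks that the marked tensor is produced by `aten.zeros` before dropping it:

    import torch
    import torch.fx


    def mark(x, name):
      """Hypothetical stand-in for torch.ops.xla.mark_tensor."""
      return x


    torch.fx.wrap("mark")  # keep mark() as a call_function leaf when tracing


    class Toy(torch.nn.Module):

      def forward(self, x):
        return mark(x, "mask") + 1


    gm = torch.fx.symbolic_trace(Toy())
    for node in gm.graph.nodes:
      if node.op == "call_function" and node.target is mark:
        # Same trick as the pass: swapping the target for an identity makes
        # the marker vanish while the graph stays structurally valid.
        node.target = lambda *args, **kwargs: args[0]
    gm.graph.lint()
    gm.recompile()
    print(gm(torch.zeros(2)))  # tensor([1., 1.]) — mark() is gone

Rewriting `node.target` in place, rather than erasing the node, keeps every user of the node intact, which is why the pass can finish with just `lint()` and `recompile()`.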
diff --git a/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py b/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py
new file mode 100644
index 00000000..f3af58a7
--- /dev/null
+++ b/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py
@@ -0,0 +1,125 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import re
+from typing import Callable, Union
+import unittest
+
+import torch
+import torch_xla
+
+from ai_edge_torch.convert.fx_passes import CanonicalizePass
+from ai_edge_torch.convert.fx_passes import run_passes
+from ai_edge_torch.generative.fx_passes import RemoveSDPACompositeZeroMaskPass
+from ai_edge_torch.generative.layers.attention import SelfAttention
+import ai_edge_torch.generative.layers.model_config as layers_cfg
+import ai_edge_torch.generative.layers.unet.builder as unet_builder
+import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+
+
+def _export_to_stablehlo(func: Union[torch.nn.Module, Callable], export_args):
+  if not isinstance(func, torch.nn.Module):
+
+    class TestModule(torch.nn.Module):
+
+      def forward(self, *args, **kwargs):
+        return func(*args, **kwargs)
+
+    module = TestModule().eval()
+  else:
+    module = func
+
+  exported_program = torch.export.export(module, export_args)
+  exported_program = run_passes(
+      exported_program,
+      [
+          RemoveSDPACompositeZeroMaskPass(),
+          CanonicalizePass(),
+      ],
+  )
+
+  return torch_xla.stablehlo.exported_program_to_stablehlo(
+      exported_program
+  ).get_stablehlo_text()
+
+
+class TestRemoveSDPAZeroMaskPass(unittest.TestCase):
+
+  def test_self_attention_no_zero_mask_composite_input(self):
+    class SampleSdpaBlock(torch.nn.Module):
+      """Sample attention block with SDPA"""
+
+      def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
+        super().__init__()
+        self.config = config
+        self.attention = SelfAttention(
+            config.attention_batch_size,
+            config.dim,
+            config.attention_config,
+            0,
+            enable_hlfb=config.enable_hlfb,
+        )
+
+      def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        B, C, H, W = input_tensor.shape
+        x = input_tensor
+        x = input_tensor.view(B, C, H * W)
+        x = x.transpose(-1, -2)
+        # x = x.contiguous() # Prevent BATCH_MATMUL op in converted tflite.
+        x = self.attention(x)
+        x = x.transpose(-1, -2)
+        x = x.view(B, C, H, W)
+        return x
+
+    def get_model_config() -> unet_cfg.AttentionBlock2DConfig:
+      """Get configs for the Decoder of Stable Diffusion v1.5"""
+      in_channels = 3
+      latent_channels = 4
+      out_channels = 3
+      block_out_channels = [128, 256, 512, 512]
+      scaling_factor = 0.18215
+      layers_per_block = 3
+
+      norm_config = layers_cfg.NormalizationConfig(
+          layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+      )
+
+      return unet_cfg.AttentionBlock2DConfig(
+          dim=block_out_channels[-1],
+          normalization_config=norm_config,
+          attention_config=layers_cfg.AttentionConfig(
+              num_heads=1,
+              num_query_groups=1,
+              qkv_use_bias=True,
+              output_proj_use_bias=True,
+              enable_kv_cache=False,
+              qkv_transpose_before_split=True,
+              rotary_percentage=0.0,
+          ),
+      )
+
+    stablehlo = _export_to_stablehlo(
+        SampleSdpaBlock(get_model_config()).eval(), (torch.rand(1, 512, 64, 64),)
+    )
+    print(stablehlo)
+    self.assertTrue(
+        re.search(
+            r'stablehlo\.composite "odml\.scaled_dot_product_attention" %\d+, %\d+, %\d+ {',
+            stablehlo,
+        )
+    )
+
+
+if __name__ == '__main__':
+  unittest.main()
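What the assertion encodes, schematically: the regex requires exactly three SSA operands on the composite, i.e. the zero mask no longer crosses the composite boundary. Simplified (operand names invented, types and attributes elided), the emitted StableHLO goes from

    stablehlo.composite "odml.scaled_dot_product_attention" %q, %k, %v, %mask {...}

to

    stablehlo.composite "odml.scaled_dot_product_attention" %q, %k, %v {...}

so downstream lowering never has to reason about a mask that is provably all zeros.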