diff --git a/ai_edge_torch/_convert/fx_passes/test/test_build_aten_composite_pass.py b/ai_edge_torch/_convert/fx_passes/test/test_build_aten_composite_pass.py
index 61968435..4228a232 100644
--- a/ai_edge_torch/_convert/fx_passes/test/test_build_aten_composite_pass.py
+++ b/ai_edge_torch/_convert/fx_passes/test/test_build_aten_composite_pass.py
@@ -65,16 +65,24 @@ def test_hardswish_layer(self):
     stablehlo = _export_to_stablehlo_with_composite(
         lambda x: torch.nn.Hardswish()(x), (torch.rand(10, 10),)  # pylint: disable=unnecessary-lambda
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1
+
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "aten.hardswish.default"': 1},
+        {'stablehlo.custom_call @mark_tensor': 2},
     )

   def test_hardswish_op(self):
     stablehlo = _export_to_stablehlo_with_composite(
         lambda x: torch.ops.aten.hardswish.default(x), (torch.rand(10, 10),)  # pylint: disable=unnecessary-lambda
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1
+
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "aten.hardswish.default"': 1},
+        {'stablehlo.custom_call @mark_tensor': 2},
     )

   def test_avg_pool2d_layer(self):
@@ -90,8 +98,11 @@ def test_avg_pool2d_layer(self):
         )(x),
         (torch.rand(1, 3, 6, 6),),
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "aten.avg_pool2d.default"': 1},
+        {'stablehlo.custom_call @mark_tensor': 2},
     )

   def test_avg_pool2d_op(self):
@@ -108,8 +119,11 @@ def test_avg_pool2d_op(self):
         ),
         (torch.rand(1, 3, 6, 6),),
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "aten.avg_pool2d.default"': 1},
+        {'stablehlo.custom_call @mark_tensor': 2},
     )

   def test_avg_pool2d_ceil_mode(self):
@@ -126,32 +140,44 @@ def test_avg_pool2d_ceil_mode(self):
         ),
         (torch.rand(1, 3, 6, 6),),
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "aten.avg_pool2d.default"': 1},
+        {'stablehlo.custom_call @mark_tensor': 2},
     )

   def test_gelu_layer(self):
     stablehlo = _export_to_stablehlo_with_composite(
         lambda x: torch.nn.GELU()(x), (torch.rand(10, 10),)  # pylint: disable=unnecessary-lambda
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "aten.gelu.default"': 1},
+        {'stablehlo.custom_call @mark_tensor': 2},
     )

   def test_approximate_gelu_layer(self):
     stablehlo = _export_to_stablehlo_with_composite(
         lambda x: torch.nn.GELU('tanh')(x), (torch.rand(10, 10),)  # pylint: disable=unnecessary-lambda
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "aten.gelu.default"': 1},
+        {'stablehlo.custom_call @mark_tensor': 2},
     )

   def test_embedding_lookup_layer(self):
     stablehlo = _export_to_stablehlo_with_composite(
         torch.nn.Embedding(10, 10), (torch.full((1, 10), 0, dtype=torch.long),)
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "odml.embedding_lookup"': 1},
+        {'stablehlo.custom_call @mark_tensor': 3},
     )

   def test_embedding_lookup_op(self):
@@ -159,8 +185,11 @@
         lambda *x: torch.ops.aten.embedding.default(*x),
         (torch.rand(10, 10), torch.full((1, 10), 0, dtype=torch.long)),
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "odml.embedding_lookup"': 1},
+        {'stablehlo.custom_call @mark_tensor': 3},
     )

   def test_embedding_lookup_functional(self):
@@ -171,8 +200,11 @@
             torch.rand(10, 10),
         ),
     )
-    self.assertEqual(
-        stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {'stablehlo.composite "odml.embedding_lookup"': 1},
+        {'stablehlo.custom_call @mark_tensor': 3},
     )

diff --git a/ai_edge_torch/_convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py b/ai_edge_torch/_convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py
index 978063ec..fb667902 100644
--- a/ai_edge_torch/_convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py
+++ b/ai_edge_torch/_convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py
@@ -56,15 +56,18 @@ def test_nn_functional_upsample_bilinear(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "odml.upsample_bilinear2d"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {align_corners = false, output = dense<30>'
-            ' : tensor<2xi64>}'
-        ),
-        1,
+
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "odml.upsample_bilinear2d"': 1,
+            'composite_attributes = {align_corners = false, output = dense<30> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"output": [30, 30], "align_corners": false}': 1},
     )

   def test_nn_functional_upsample_bilinear_align_corners(self):
@@ -74,15 +77,18 @@ def test_nn_functional_upsample_bilinear_align_corners(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "odml.upsample_bilinear2d"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {align_corners = true, output = dense<30> :'
-            ' tensor<2xi64>}'
-        ),
-        1,
+
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "odml.upsample_bilinear2d"': 1,
+            'composite_attributes = {align_corners = true, output = dense<30> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"output": [30, 30], "align_corners": true}': 1},
     )

   def test_nn_functional_upsample_bilinear_size(self):
@@ -92,15 +98,18 @@ def test_nn_functional_upsample_bilinear_size(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "odml.upsample_bilinear2d"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {align_corners = false, output = dense<[15,'
-            ' 20]> : tensor<2xi64>}'
-        ),
-        1,
+
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "odml.upsample_bilinear2d"': 1,
+            'composite_attributes = {align_corners = false, output = dense<[15, 20]> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"output": [15, 20], "align_corners": false}': 1},
     )

   def test_nn_functional_upsample_bilinear_size_align_corners(self):
@@ -110,15 +119,17 @@ def test_nn_functional_upsample_bilinear_size_align_corners(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "odml.upsample_bilinear2d"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {align_corners = true, output = dense<[15,'
-            ' 20]> : tensor<2xi64>}'
-        ),
-        1,
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "odml.upsample_bilinear2d"': 1,
+            'composite_attributes = {align_corners = true, output = dense<[15, 20]> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"output": [15, 20], "align_corners": true}': 1},
     )

   def test_nn_upsample_bilinear(self):
@@ -126,15 +137,17 @@ def test_nn_upsample_bilinear(self):
         torch.nn.Upsample(scale_factor=3.0, mode='bilinear').eval(),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "odml.upsample_bilinear2d"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {align_corners = false, output = dense<30>'
-            ' : tensor<2xi64>}'
-        ),
-        1,
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "odml.upsample_bilinear2d"': 1,
+            'composite_attributes = {align_corners = false, output = dense<30> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"output": [30, 30], "align_corners": false}': 1},
     )

   def test_nn_functional_interpolate_bilinear(self):
@@ -144,15 +157,17 @@ def test_nn_functional_interpolate_bilinear(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "odml.upsample_bilinear2d"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {align_corners = false, output = dense<30>'
-            ' : tensor<2xi64>}'
-        ),
-        1,
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "odml.upsample_bilinear2d"': 1,
+            'composite_attributes = {align_corners = false, output = dense<30> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"output": [30, 30], "align_corners": false}': 1},
     )

   def test_nn_functional_interpolate_bilinear_align_corners(self):
@@ -162,15 +177,17 @@ def test_nn_functional_interpolate_bilinear_align_corners(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "odml.upsample_bilinear2d"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {align_corners = true, output = dense<30> :'
-            ' tensor<2xi64>}'
-        ),
-        1,
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "odml.upsample_bilinear2d"': 1,
+            'composite_attributes = {align_corners = true, output = dense<30> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"output": [30, 30], "align_corners": true}': 1},
     )

   def test_nn_functional_interpolate_bilinear_size(self):
@@ -180,15 +197,17 @@ def test_nn_functional_interpolate_bilinear_size(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "odml.upsample_bilinear2d"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {align_corners = false, output = dense<[15,'
-            ' 20]> : tensor<2xi64>}'
-        ),
-        1,
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "odml.upsample_bilinear2d"': 1,
+            'composite_attributes = {align_corners = false, output = dense<[15, 20]> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"output": [15, 20], "align_corners": false}': 1},
     )

   def test_nn_functional_interpolate_bilinear_size_align_corners(self):
@@ -198,15 +217,17 @@ def test_nn_functional_interpolate_bilinear_size_align_corners(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "odml.upsample_bilinear2d"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {align_corners = true, output = dense<[15,'
-            ' 20]> : tensor<2xi64>}'
-        ),
-        1,
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "odml.upsample_bilinear2d"': 1,
+            'composite_attributes = {align_corners = true, output = dense<[15, 20]> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"output": [15, 20], "align_corners": true}': 1},
     )

   def test_nn_functional_interpolate_nearest(self):
@@ -216,15 +237,17 @@ def test_nn_functional_interpolate_nearest(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {is_nchw_op = true, size = dense<30> :'
-            ' tensor<2xi64>}'
-        ),
-        1,
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "tfl.resize_nearest_neighbor"': 1,
+            'composite_attributes = {is_nchw_op = true, size = dense<30> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"size": [30, 30], "is_nchw_op": true}': 1},
     )

   def test_nn_functional_interpolate_nearest_size(self):
@@ -234,15 +257,17 @@ def test_nn_functional_interpolate_nearest_size(self):
         ),
         (torch.rand(1, 3, 10, 10),),
     )
-    self.assertTrue(
-        stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1
-    )
-    self.assertTrue(
-        stablehlo.count(
-            'composite_attributes = {is_nchw_op = true, size = dense<[15, 20]>'
-            ' : tensor<2xi64>}'
-        ),
-        1,
+    lowertools.assert_string_count(
+        self,
+        stablehlo,
+        {
+            'stablehlo.composite "tfl.resize_nearest_neighbor"': 1,
+            'composite_attributes = {is_nchw_op = true, size = dense<[15, 20]> : tensor<2xi64>}': (
+                1
+            ),
+        },
+        {'stablehlo.custom_call @mark_tensor': 2},
+        {'{"size": [15, 20], "is_nchw_op": true}': 1},
     )

diff --git a/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py b/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py
index 7444644f..662a88c1 100644
--- a/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py
+++ b/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py
@@ -112,13 +112,17 @@ def get_model_config() -> unet_cfg.AttentionBlock2DConfig:
         SampleSdpaBlock(get_model_config()).eval(),
         (torch.rand(1, 512, 64, 64),),
     )
-    self.assertTrue(
-        re.search(
-            'stablehlo\.composite "odml\.scaled_dot_product_attention" %\d+,'
-            ' %\d+, %\d+ {',
-            stablehlo,
-        )
-    )
+
+    if config.Config.use_torch_xla:
+      self.assertTrue(
+          re.search(
+              'stablehlo\.composite "odml\.scaled_dot_product_attention" %\d+,'
+              ' %\d+, %\d+ {',
+              stablehlo,
+          )
+      )
+    else:
+      self.assertEqual(stablehlo.count('stablehlo.custom_call @mark_tensor'), 4)


 if __name__ == '__main__':
diff --git a/ai_edge_torch/hlfb/test/test_mark_pattern.py b/ai_edge_torch/hlfb/test/test_mark_pattern.py
index d8a88dbe..67ce90a1 100644
--- a/ai_edge_torch/hlfb/test/test_mark_pattern.py
+++ b/ai_edge_torch/hlfb/test/test_mark_pattern.py
@@ -51,7 +51,12 @@ def forward(self, x):
     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
     mlir = _export_stablehlo_mlir(exported_program)
-    self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {'stablehlo.composite "test.add"': 2},
+        {"stablehlo.custom_call @mark_tensor": 6},
+    )

   def test_mark_pattern_with_attr_builder(self):
     class TestModel(torch.nn.Module):
@@ -72,9 +77,15 @@ def forward(self, x):
     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
     mlir = _export_stablehlo_mlir(exported_program)
-    self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
-    self.assertEqual(
-        mlir.count('composite_attributes = {alias = "test.test_add"}'), 2
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {
+            'stablehlo.composite "test.add"': 2,
+            'composite_attributes = {alias = "test.test_add"}': 2,
+        },
+        {"stablehlo.custom_call @mark_tensor": 6},
+        {'{"alias": "test.test_add"}': 2},
     )

   def test_mark_pattern_with_scalar_attr_tracker(self):
@@ -104,9 +115,17 @@ def forward(self, x):
     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
     mlir = _export_stablehlo_mlir(exported_program)
-    self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 5)
-    self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 3)
-    self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 2)
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {
+            'stablehlo.composite "test.log_softmax"': 5,
+            "composite_attributes = {dim = 0 : i64}": 3,
+            "composite_attributes = {dim = 1 : i64}": 2,
+        },
+        {"stablehlo.custom_call @mark_tensor": 10},
+        {'{"dim": 0}': 3, '{"dim": 1}': 2},
+    )

   def test_mark_tangent_model_and_pattern_input(self):
     class TestModel(torch.nn.Module):
@@ -128,7 +147,12 @@ def forward(self, x, y):
     mark_pattern.mark_pattern(exported_program.graph_module, pattern)
     mlir = _export_stablehlo_mlir(exported_program)
-    self.assertEqual(mlir.count('stablehlo.composite "test.relu'), 1)
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {'stablehlo.composite "test.relu"': 1},
+        {"stablehlo.custom_call @mark_tensor": 2},
+    )


 if __name__ == "__main__":
diff --git a/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py b/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py
index e403f4a2..d9f713f9 100644
--- a/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py
+++ b/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py
@@ -16,6 +16,7 @@

 import math

+from ai_edge_torch import config
 from ai_edge_torch import lowertools
 from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
@@ -29,6 +30,10 @@ def _export_stablehlo_mlir(model, args):
   return lowertools.exported_program_to_mlir_text(ep)


+@googletest.skipIf(
+    not config.Config.use_torch_xla,
+    reason="The odml_torch counter part is in odml_torch.",
+)
 class TestStableHLOCompositeBuilder(googletest.TestCase):

   def test_build_composite(self):
diff --git a/ai_edge_torch/lowertools/__init__.py b/ai_edge_torch/lowertools/__init__.py
index 0de96c4c..007f5d26 100644
--- a/ai_edge_torch/lowertools/__init__.py
+++ b/ai_edge_torch/lowertools/__init__.py
@@ -14,3 +14,4 @@
 # ==============================================================================

 from ._shim import *
+from .test_utils import *
diff --git a/ai_edge_torch/lowertools/odml_torch_utils.py b/ai_edge_torch/lowertools/odml_torch_utils.py
index ece1b349..1aeb920c 100644
--- a/ai_edge_torch/lowertools/odml_torch_utils.py
+++ b/ai_edge_torch/lowertools/odml_torch_utils.py
@@ -28,6 +28,7 @@
 import torch

 from tensorflow.compiler.tf2xla.python import xla as tfxla
+from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb

 MlirBundle = odml_torch.export.MlirLowered

@@ -162,7 +163,9 @@ def merged_bundle_to_tfl_model(
   )

   converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
+  converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
   converter._experimental_enable_composite_direct_lowering = True
+  converter.model_origin_framework = "PYTORCH"

   conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)

diff --git a/ai_edge_torch/lowertools/test_utils.py b/ai_edge_torch/lowertools/test_utils.py
new file mode 100644
index 00000000..2558a2b7
--- /dev/null
+++ b/ai_edge_torch/lowertools/test_utils.py
@@ -0,0 +1,60 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import re
+from typing import Optional
+from ai_edge_torch import config
+from tensorflow.python.platform import googletest
+
+
+def _extract_backend_configs(mlir):
+  mlir = mlir.replace("\\22", '"')
+  configs = []
+  for match in re.finditer(r"backend_config\s*=\s*\"(\{.*\})\"", mlir):
+    configs.append(match.group(1))
+  return "\n".join(configs)
+
+
+def assert_string_count(
+    test_case: googletest.TestCase,
+    mlir: str,
+    torch_xla_pattern_counter: dict[str, int],
+    odml_torch_pattern_counter: dict[str, int],
+    odml_torch_attr_counter: Optional[dict[str, int]] = None,
+):
+
+  if odml_torch_attr_counter is None:
+    odml_torch_attr_counter = {}
+
+  if config.Config.use_torch_xla:
+    for key in torch_xla_pattern_counter:
+      test_case.assertEqual(
+          mlir.count(key),
+          torch_xla_pattern_counter[key],
+      )
+  else:
+    for key in odml_torch_pattern_counter:
+      test_case.assertEqual(
+          mlir.count(key),
+          odml_torch_pattern_counter[key],
+      )
+    backend_configs = _extract_backend_configs(mlir)
+    print("backend_configs:")
+    print(backend_configs)
+    for key in odml_torch_attr_counter:
+      test_case.assertEqual(
+          backend_configs.count(key),
+          odml_torch_attr_counter[key],
+      )
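Note on the new helper (not part of the patch above): lowertools.assert_string_count branches on config.Config.use_torch_xla. Under torch_xla it counts the expected 'stablehlo.composite' strings directly in the MLIR text; under odml_torch it counts the 'stablehlo.custom_call @mark_tensor' markers and, optionally, the JSON attributes recovered from backend_config. A minimal usage sketch follows; the test class and the literal MLIR strings are illustrative stand-ins, not code from this change.

# Sketch: exercising lowertools.assert_string_count from lowertools/test_utils.py.
# The MLIR text below is a toy stand-in; the real tests pass the dump produced
# by _export_to_stablehlo_with_composite or _export_stablehlo_mlir.
from ai_edge_torch import config
from ai_edge_torch import lowertools
from tensorflow.python.platform import googletest


class AssertStringCountSketch(googletest.TestCase):

  def test_counts_match_active_backend(self):
    if config.Config.use_torch_xla:
      mlir = 'stablehlo.composite "aten.gelu.default" ...'
    else:
      mlir = (
          'stablehlo.custom_call @mark_tensor ...\n'
          'stablehlo.custom_call @mark_tensor ...'
      )
    # One call carries the expectations for both backends; only the branch
    # matching config.Config.use_torch_xla is actually asserted.
    lowertools.assert_string_count(
        self,
        mlir,
        {'stablehlo.composite "aten.gelu.default"': 1},  # torch_xla expectation
        {'stablehlo.custom_call @mark_tensor': 2},  # odml_torch expectation
    )


if __name__ == '__main__':
  googletest.main()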