From 08020a48b25e48b8ff663ccd11a0b4ea71a42e47 Mon Sep 17 00:00:00 2001 From: vamsimanchala Date: Fri, 7 Jun 2024 23:35:44 +0000 Subject: [PATCH 1/2] Add interpolate nearest composite support --- ai_edge_torch/convert/conversion.py | 4 +-- ai_edge_torch/convert/fx_passes/__init__.py | 2 +- ...py => build_interpolate_composite_pass.py} | 23 ++++++++++++- ...uild_upsample_bilinear2d_composite_pass.py | 33 +++++++++++++++++-- 4 files changed, 56 insertions(+), 6 deletions(-) rename ai_edge_torch/convert/fx_passes/{build_upsample_bilinear2d_composite_pass.py => build_interpolate_composite_pass.py} (78%) diff --git a/ai_edge_torch/convert/conversion.py b/ai_edge_torch/convert/conversion.py index 79618219..2318b6d4 100644 --- a/ai_edge_torch/convert/conversion.py +++ b/ai_edge_torch/convert/conversion.py @@ -25,7 +25,7 @@ from ai_edge_torch import model from ai_edge_torch.convert import conversion_utils as cutils from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass -from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA +from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA from ai_edge_torch.convert.fx_passes import CanonicalizePass from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass @@ -41,7 +41,7 @@ def _run_convert_passes( return run_passes( exported_program, [ - BuildUpsampleBilinear2DCompositePass(), + BuildInterpolateCompositePass(), CanonicalizePass(), OptimizeLayoutTransposesPass(), CanonicalizePass(), diff --git a/ai_edge_torch/convert/fx_passes/__init__.py b/ai_edge_torch/convert/fx_passes/__init__.py index 31d40a24..faa060b1 100644 --- a/ai_edge_torch/convert/fx_passes/__init__.py +++ b/ai_edge_torch/convert/fx_passes/__init__.py @@ -24,7 +24,7 @@ from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA -from ai_edge_torch.convert.fx_passes.build_upsample_bilinear2d_composite_pass import BuildUpsampleBilinear2DCompositePass # NOQA +from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass # NOQA from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA diff --git a/ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py b/ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py similarity index 78% rename from ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py rename to ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py index f812f882..057501f2 100644 --- a/ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +++ b/ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py @@ -65,14 +65,35 @@ def attr_builder(graph_module, pattern, internal_match): return pattern +@functools.cache +def _get_interpolate_nearest2d_pattern(): + pattern = mark_pattern.Pattern( + "tfl.resize_nearest_neighbor", + lambda x: torch.nn.functional.interpolate( + x, scale_factor=2, mode="nearest" + ), + export_args=(torch.rand(1, 3, 100, 100),), + ) + + @pattern.register_attr_builder + def attr_builder(pattern, graph_module, internal_match): + output = internal_match.returning_nodes[0] + output_h, output_w = output.meta["val"].shape[-2:] + return { + "size": (int(output_h), int(output_w)), + "is_nchw_op": True, + } + + return pattern -class BuildUpsampleBilinear2DCompositePass(FxPassBase): +class BuildInterpolateCompositePass(FxPassBase): def __init__(self): super().__init__() self._patterns = [ _get_upsample_bilinear2d_pattern(), _get_upsample_bilinear2d_align_corners_pattern(), + _get_interpolate_nearest2d_pattern(), ] def call(self, graph_module: torch.fx.GraphModule): diff --git a/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py b/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py index 17185141..cd7b1d04 100644 --- a/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py +++ b/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py @@ -19,7 +19,7 @@ import torch import torch_xla -from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA +from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA from ai_edge_torch.convert.fx_passes import run_passes @@ -39,7 +39,7 @@ def forward(self, *args, **kwargs): exported_program = torch.export.export(module, export_args) exported_program = run_passes( - exported_program, [BuildUpsampleBilinear2DCompositePass()] + exported_program, [BuildInterpolateCompositePass()] ) return torch_xla.stablehlo.exported_program_to_stablehlo( @@ -192,6 +192,35 @@ def test_nn_functional_interpolate_bilinear_size_align_corners(self): 1, ) + def test_nn_functional_interpolate_nearest(self): + stablehlo = _export_to_stablehlo_with_composite( + lambda x: torch.nn.functional.interpolate(x, scale_factor=3.0, mode='nearest'), + (torch.rand(1, 3, 10, 10),), + ) + self.assertTrue( + stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1 + ) + self.assertTrue( + stablehlo.count( + 'composite_attributes = {is_nchw_op = true, size = dense<30> : tensor<2xi64>}' + ), + 1, + ) + + def test_nn_functional_interpolate_nearest_size(self): + stablehlo = _export_to_stablehlo_with_composite( + lambda x: torch.nn.functional.interpolate(x, size=[15, 20], mode='nearest'), + (torch.rand(1, 3, 10, 10),), + ) + self.assertTrue( + stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1 + ) + self.assertTrue( + stablehlo.count( + 'composite_attributes = {is_nchw_op = true, size = dense<[15, 20]> : tensor<2xi64>}' + ), + 1, + ) if __name__ == '__main__': unittest.main() From eaf0abd615548fb7af6fbce42952cc1af7734509 Mon Sep 17 00:00:00 2001 From: vamsimanchala Date: Fri, 7 Jun 2024 23:40:40 +0000 Subject: [PATCH 2/2] fix fmt --- .../convert/fx_passes/build_interpolate_composite_pass.py | 6 +++--- .../test/test_build_upsample_bilinear2d_composite_pass.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py b/ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py index 057501f2..3f1afde7 100644 --- a/ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +++ b/ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py @@ -65,13 +65,12 @@ def attr_builder(graph_module, pattern, internal_match): return pattern + @functools.cache def _get_interpolate_nearest2d_pattern(): pattern = mark_pattern.Pattern( "tfl.resize_nearest_neighbor", - lambda x: torch.nn.functional.interpolate( - x, scale_factor=2, mode="nearest" - ), + lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"), export_args=(torch.rand(1, 3, 100, 100),), ) @@ -86,6 +85,7 @@ def attr_builder(pattern, graph_module, internal_match): return pattern + class BuildInterpolateCompositePass(FxPassBase): def __init__(self): diff --git a/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py b/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py index cd7b1d04..4b0eee47 100644 --- a/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py +++ b/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py @@ -38,9 +38,7 @@ def forward(self, *args, **kwargs): module = func exported_program = torch.export.export(module, export_args) - exported_program = run_passes( - exported_program, [BuildInterpolateCompositePass()] - ) + exported_program = run_passes(exported_program, [BuildInterpolateCompositePass()]) return torch_xla.stablehlo.exported_program_to_stablehlo( exported_program @@ -222,5 +220,6 @@ def test_nn_functional_interpolate_nearest_size(self): 1, ) + if __name__ == '__main__': unittest.main()