From fe2afd569f947fefe2efa022ee06105bc1b1f534 Mon Sep 17 00:00:00 2001 From: Vamsi Krishna Manchala <110555921+vamsimanchala@users.noreply.github.com> Date: Sat, 8 Jun 2024 00:24:13 +0000 Subject: [PATCH] Add interpolate nearest composite support (#41) * Add interpolate nearest composite support * fix fmt --- ai_edge_torch/convert/conversion.py | 4 +-- ai_edge_torch/convert/fx_passes/__init__.py | 2 +- ...py => build_interpolate_composite_pass.py} | 23 +++++++++++- ...uild_upsample_bilinear2d_composite_pass.py | 36 ++++++++++++++++--- 4 files changed, 57 insertions(+), 8 deletions(-) rename ai_edge_torch/convert/fx_passes/{build_upsample_bilinear2d_composite_pass.py => build_interpolate_composite_pass.py} (79%) diff --git a/ai_edge_torch/convert/conversion.py b/ai_edge_torch/convert/conversion.py index 79618219..2318b6d4 100644 --- a/ai_edge_torch/convert/conversion.py +++ b/ai_edge_torch/convert/conversion.py @@ -25,7 +25,7 @@ from ai_edge_torch import model from ai_edge_torch.convert import conversion_utils as cutils from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass -from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA +from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA from ai_edge_torch.convert.fx_passes import CanonicalizePass from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass @@ -41,7 +41,7 @@ def _run_convert_passes( return run_passes( exported_program, [ - BuildUpsampleBilinear2DCompositePass(), + BuildInterpolateCompositePass(), CanonicalizePass(), OptimizeLayoutTransposesPass(), CanonicalizePass(), diff --git a/ai_edge_torch/convert/fx_passes/__init__.py b/ai_edge_torch/convert/fx_passes/__init__.py index 31d40a24..faa060b1 100644 --- a/ai_edge_torch/convert/fx_passes/__init__.py +++ b/ai_edge_torch/convert/fx_passes/__init__.py @@ -24,7 +24,7 @@ from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA -from ai_edge_torch.convert.fx_passes.build_upsample_bilinear2d_composite_pass import BuildUpsampleBilinear2DCompositePass # NOQA +from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass # NOQA from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA diff --git a/ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py b/ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py similarity index 79% rename from ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py rename to ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py index f812f882..3f1afde7 100644 --- a/ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +++ b/ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py @@ -66,13 +66,34 @@ def attr_builder(graph_module, pattern, internal_match): return pattern -class BuildUpsampleBilinear2DCompositePass(FxPassBase): +@functools.cache +def _get_interpolate_nearest2d_pattern(): + pattern = mark_pattern.Pattern( + "tfl.resize_nearest_neighbor", + lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"), + export_args=(torch.rand(1, 3, 100, 100),), + ) + + @pattern.register_attr_builder + def attr_builder(pattern, graph_module, internal_match): + output = internal_match.returning_nodes[0] + output_h, output_w = output.meta["val"].shape[-2:] + return { + "size": (int(output_h), int(output_w)), + "is_nchw_op": True, + } + + return pattern + + +class BuildInterpolateCompositePass(FxPassBase): def __init__(self): super().__init__() self._patterns = [ _get_upsample_bilinear2d_pattern(), _get_upsample_bilinear2d_align_corners_pattern(), + _get_interpolate_nearest2d_pattern(), ] def call(self, graph_module: torch.fx.GraphModule): diff --git a/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py b/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py index 17185141..4b0eee47 100644 --- a/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py +++ b/ai_edge_torch/convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py @@ -19,7 +19,7 @@ import torch import torch_xla -from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA +from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA from ai_edge_torch.convert.fx_passes import run_passes @@ -38,9 +38,7 @@ def forward(self, *args, **kwargs): module = func exported_program = torch.export.export(module, export_args) - exported_program = run_passes( - exported_program, [BuildUpsampleBilinear2DCompositePass()] - ) + exported_program = run_passes(exported_program, [BuildInterpolateCompositePass()]) return torch_xla.stablehlo.exported_program_to_stablehlo( exported_program @@ -192,6 +190,36 @@ def test_nn_functional_interpolate_bilinear_size_align_corners(self): 1, ) + def test_nn_functional_interpolate_nearest(self): + stablehlo = _export_to_stablehlo_with_composite( + lambda x: torch.nn.functional.interpolate(x, scale_factor=3.0, mode='nearest'), + (torch.rand(1, 3, 10, 10),), + ) + self.assertTrue( + stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1 + ) + self.assertTrue( + stablehlo.count( + 'composite_attributes = {is_nchw_op = true, size = dense<30> : tensor<2xi64>}' + ), + 1, + ) + + def test_nn_functional_interpolate_nearest_size(self): + stablehlo = _export_to_stablehlo_with_composite( + lambda x: torch.nn.functional.interpolate(x, size=[15, 20], mode='nearest'), + (torch.rand(1, 3, 10, 10),), + ) + self.assertTrue( + stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1 + ) + self.assertTrue( + stablehlo.count( + 'composite_attributes = {is_nchw_op = true, size = dense<[15, 20]> : tensor<2xi64>}' + ), + 1, + ) + if __name__ == '__main__': unittest.main()