From d7c9e8d80a56612760db605ea2011f78743eb034 Mon Sep 17 00:00:00 2001 From: Vamsi Manchala Date: Wed, 13 Nov 2024 13:53:04 -0800 Subject: [PATCH] Add is_nchw_op to the pattern of build_interpolate_composite_pass. Change existing pytorch composites to unify the upsample-bilinear composites from JAX and PyTorch. PiperOrigin-RevId: 696259412 --- .../build_interpolate_composite_pass.py | 2 ++ ...uild_upsample_bilinear2d_composite_pass.py | 36 +++++++++---------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py b/ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py index 09e748ea..fb790e08 100644 --- a/ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +++ b/ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py @@ -51,6 +51,7 @@ def attr_builder(pattern, graph_module, internal_match): return { "output": (int(output_h), int(output_w)), "align_corners": False, + "is_nchw_op": True, } return pattern @@ -74,6 +75,7 @@ def attr_builder(graph_module, pattern, internal_match): return { "output": (int(output_h), int(output_w)), "align_corners": True, + "is_nchw_op": True, } return pattern diff --git a/ai_edge_torch/_convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py b/ai_edge_torch/_convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py index 9dd3a89e..41b3e503 100644 --- a/ai_edge_torch/_convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py +++ b/ai_edge_torch/_convert/fx_passes/test/test_build_upsample_bilinear2d_composite_pass.py @@ -63,12 +63,12 @@ def test_nn_functional_upsample_bilinear(self): stablehlo, { 'stablehlo.composite "odml.upsample_bilinear2d"': 1, - 'composite_attributes = {align_corners = false, output = dense<30> : tensor<2xi64>}': ( + 'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': ( 1 ), }, {'stablehlo.custom_call @mark_tensor': 2}, - {'{"output": [30, 30], "align_corners": false}': 1}, + {'{"output": [30, 30], "align_corners": false, "is_nchw_op": true}': 1}, ) def test_nn_functional_upsample_bilinear_align_corners(self): @@ -84,12 +84,12 @@ def test_nn_functional_upsample_bilinear_align_corners(self): stablehlo, { 'stablehlo.composite "odml.upsample_bilinear2d"': 1, - 'composite_attributes = {align_corners = true, output = dense<30> : tensor<2xi64>}': ( + 'composite_attributes = {align_corners = true, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': ( 1 ), }, {'stablehlo.custom_call @mark_tensor': 2}, - {'{"output": [30, 30], "align_corners": true}': 1}, + {'{"output": [30, 30], "align_corners": true, "is_nchw_op": true}': 1}, ) def test_nn_functional_upsample_bilinear_size(self): @@ -105,12 +105,12 @@ def test_nn_functional_upsample_bilinear_size(self): stablehlo, { 'stablehlo.composite "odml.upsample_bilinear2d"': 1, - 'composite_attributes = {align_corners = false, output = dense<[15, 20]> : tensor<2xi64>}': ( + 'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<[15, 20]> : tensor<2xi64>}': ( 1 ), }, {'stablehlo.custom_call @mark_tensor': 2}, - {'{"output": [15, 20], "align_corners": false}': 1}, + {'{"output": [15, 20], "align_corners": false, "is_nchw_op": true}': 1}, ) def test_nn_functional_upsample_bilinear_size_align_corners(self): @@ -125,12 +125,12 @@ def test_nn_functional_upsample_bilinear_size_align_corners(self): stablehlo, { 'stablehlo.composite "odml.upsample_bilinear2d"': 1, - 'composite_attributes = {align_corners = true, output = dense<[15, 20]> : tensor<2xi64>}': ( + 'composite_attributes = {align_corners = true, is_nchw_op = true, output = dense<[15, 20]> : tensor<2xi64>}': ( 1 ), }, {'stablehlo.custom_call @mark_tensor': 2}, - {'{"output": [15, 20], "align_corners": true}': 1}, + {'{"output": [15, 20], "align_corners": true, "is_nchw_op": true}': 1}, ) def test_nn_upsample_bilinear(self): @@ -143,12 +143,12 @@ def test_nn_upsample_bilinear(self): stablehlo, { 'stablehlo.composite "odml.upsample_bilinear2d"': 1, - 'composite_attributes = {align_corners = false, output = dense<30> : tensor<2xi64>}': ( + 'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': ( 1 ), }, {'stablehlo.custom_call @mark_tensor': 2}, - {'{"output": [30, 30], "align_corners": false}': 1}, + {'{"output": [30, 30], "align_corners": false, "is_nchw_op": true}': 1}, ) def test_nn_functional_interpolate_bilinear(self): @@ -163,12 +163,12 @@ def test_nn_functional_interpolate_bilinear(self): stablehlo, { 'stablehlo.composite "odml.upsample_bilinear2d"': 1, - 'composite_attributes = {align_corners = false, output = dense<30> : tensor<2xi64>}': ( + 'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': ( 1 ), }, {'stablehlo.custom_call @mark_tensor': 2}, - {'{"output": [30, 30], "align_corners": false}': 1}, + {'{"output": [30, 30], "align_corners": false, "is_nchw_op": true}': 1}, ) def test_nn_functional_interpolate_bilinear_align_corners(self): @@ -183,12 +183,12 @@ def test_nn_functional_interpolate_bilinear_align_corners(self): stablehlo, { 'stablehlo.composite "odml.upsample_bilinear2d"': 1, - 'composite_attributes = {align_corners = true, output = dense<30> : tensor<2xi64>}': ( + 'composite_attributes = {align_corners = true, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': ( 1 ), }, {'stablehlo.custom_call @mark_tensor': 2}, - {'{"output": [30, 30], "align_corners": true}': 1}, + {'{"output": [30, 30], "align_corners": true, "is_nchw_op": true}': 1}, ) def test_nn_functional_interpolate_bilinear_size(self): @@ -203,12 +203,12 @@ def test_nn_functional_interpolate_bilinear_size(self): stablehlo, { 'stablehlo.composite "odml.upsample_bilinear2d"': 1, - 'composite_attributes = {align_corners = false, output = dense<[15, 20]> : tensor<2xi64>}': ( + 'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<[15, 20]> : tensor<2xi64>}': ( 1 ), }, {'stablehlo.custom_call @mark_tensor': 2}, - {'{"output": [15, 20], "align_corners": false}': 1}, + {'{"output": [15, 20], "align_corners": false, "is_nchw_op": true}': 1}, ) def test_nn_functional_interpolate_bilinear_size_align_corners(self): @@ -223,12 +223,12 @@ def test_nn_functional_interpolate_bilinear_size_align_corners(self): stablehlo, { 'stablehlo.composite "odml.upsample_bilinear2d"': 1, - 'composite_attributes = {align_corners = true, output = dense<[15, 20]> : tensor<2xi64>}': ( + 'composite_attributes = {align_corners = true, is_nchw_op = true, output = dense<[15, 20]> : tensor<2xi64>}': ( 1 ), }, {'stablehlo.custom_call @mark_tensor': 2}, - {'{"output": [15, 20], "align_corners": true}': 1}, + {'{"output": [15, 20], "align_corners": true, "is_nchw_op": true}': 1}, ) def test_nn_functional_interpolate_nearest(self):