Skip to content

Commit

Permalink
Add is_nchw_op to the pattern of build_interpolate_composite_pass. Ch…
Browse files Browse the repository at this point in the history
…ange existing pytorch composites to unify the upsample-bilinear composites from JAX and PyTorch.

PiperOrigin-RevId: 696259412
  • Loading branch information
vamsimanchala authored and copybara-github committed Nov 13, 2024
1 parent 956cfc9 commit d7c9e8d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def attr_builder(pattern, graph_module, internal_match):
return {
"output": (int(output_h), int(output_w)),
"align_corners": False,
"is_nchw_op": True,
}

return pattern
Expand All @@ -74,6 +75,7 @@ def attr_builder(graph_module, pattern, internal_match):
return {
"output": (int(output_h), int(output_w)),
"align_corners": True,
"is_nchw_op": True,
}

return pattern
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def test_nn_functional_upsample_bilinear(self):
stablehlo,
{
'stablehlo.composite "odml.upsample_bilinear2d"': 1,
'composite_attributes = {align_corners = false, output = dense<30> : tensor<2xi64>}': (
'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': (
1
),
},
{'stablehlo.custom_call @mark_tensor': 2},
{'{"output": [30, 30], "align_corners": false}': 1},
{'{"output": [30, 30], "align_corners": false, "is_nchw_op": true}': 1},
)

def test_nn_functional_upsample_bilinear_align_corners(self):
Expand All @@ -84,12 +84,12 @@ def test_nn_functional_upsample_bilinear_align_corners(self):
stablehlo,
{
'stablehlo.composite "odml.upsample_bilinear2d"': 1,
'composite_attributes = {align_corners = true, output = dense<30> : tensor<2xi64>}': (
'composite_attributes = {align_corners = true, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': (
1
),
},
{'stablehlo.custom_call @mark_tensor': 2},
{'{"output": [30, 30], "align_corners": true}': 1},
{'{"output": [30, 30], "align_corners": true, "is_nchw_op": true}': 1},
)

def test_nn_functional_upsample_bilinear_size(self):
Expand All @@ -105,12 +105,12 @@ def test_nn_functional_upsample_bilinear_size(self):
stablehlo,
{
'stablehlo.composite "odml.upsample_bilinear2d"': 1,
'composite_attributes = {align_corners = false, output = dense<[15, 20]> : tensor<2xi64>}': (
'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<[15, 20]> : tensor<2xi64>}': (
1
),
},
{'stablehlo.custom_call @mark_tensor': 2},
{'{"output": [15, 20], "align_corners": false}': 1},
{'{"output": [15, 20], "align_corners": false, "is_nchw_op": true}': 1},
)

def test_nn_functional_upsample_bilinear_size_align_corners(self):
Expand All @@ -125,12 +125,12 @@ def test_nn_functional_upsample_bilinear_size_align_corners(self):
stablehlo,
{
'stablehlo.composite "odml.upsample_bilinear2d"': 1,
'composite_attributes = {align_corners = true, output = dense<[15, 20]> : tensor<2xi64>}': (
'composite_attributes = {align_corners = true, is_nchw_op = true, output = dense<[15, 20]> : tensor<2xi64>}': (
1
),
},
{'stablehlo.custom_call @mark_tensor': 2},
{'{"output": [15, 20], "align_corners": true}': 1},
{'{"output": [15, 20], "align_corners": true, "is_nchw_op": true}': 1},
)

def test_nn_upsample_bilinear(self):
Expand All @@ -143,12 +143,12 @@ def test_nn_upsample_bilinear(self):
stablehlo,
{
'stablehlo.composite "odml.upsample_bilinear2d"': 1,
'composite_attributes = {align_corners = false, output = dense<30> : tensor<2xi64>}': (
'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': (
1
),
},
{'stablehlo.custom_call @mark_tensor': 2},
{'{"output": [30, 30], "align_corners": false}': 1},
{'{"output": [30, 30], "align_corners": false, "is_nchw_op": true}': 1},
)

def test_nn_functional_interpolate_bilinear(self):
Expand All @@ -163,12 +163,12 @@ def test_nn_functional_interpolate_bilinear(self):
stablehlo,
{
'stablehlo.composite "odml.upsample_bilinear2d"': 1,
'composite_attributes = {align_corners = false, output = dense<30> : tensor<2xi64>}': (
'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': (
1
),
},
{'stablehlo.custom_call @mark_tensor': 2},
{'{"output": [30, 30], "align_corners": false}': 1},
{'{"output": [30, 30], "align_corners": false, "is_nchw_op": true}': 1},
)

def test_nn_functional_interpolate_bilinear_align_corners(self):
Expand All @@ -183,12 +183,12 @@ def test_nn_functional_interpolate_bilinear_align_corners(self):
stablehlo,
{
'stablehlo.composite "odml.upsample_bilinear2d"': 1,
'composite_attributes = {align_corners = true, output = dense<30> : tensor<2xi64>}': (
'composite_attributes = {align_corners = true, is_nchw_op = true, output = dense<30> : tensor<2xi64>}': (
1
),
},
{'stablehlo.custom_call @mark_tensor': 2},
{'{"output": [30, 30], "align_corners": true}': 1},
{'{"output": [30, 30], "align_corners": true, "is_nchw_op": true}': 1},
)

def test_nn_functional_interpolate_bilinear_size(self):
Expand All @@ -203,12 +203,12 @@ def test_nn_functional_interpolate_bilinear_size(self):
stablehlo,
{
'stablehlo.composite "odml.upsample_bilinear2d"': 1,
'composite_attributes = {align_corners = false, output = dense<[15, 20]> : tensor<2xi64>}': (
'composite_attributes = {align_corners = false, is_nchw_op = true, output = dense<[15, 20]> : tensor<2xi64>}': (
1
),
},
{'stablehlo.custom_call @mark_tensor': 2},
{'{"output": [15, 20], "align_corners": false}': 1},
{'{"output": [15, 20], "align_corners": false, "is_nchw_op": true}': 1},
)

def test_nn_functional_interpolate_bilinear_size_align_corners(self):
Expand All @@ -223,12 +223,12 @@ def test_nn_functional_interpolate_bilinear_size_align_corners(self):
stablehlo,
{
'stablehlo.composite "odml.upsample_bilinear2d"': 1,
'composite_attributes = {align_corners = true, output = dense<[15, 20]> : tensor<2xi64>}': (
'composite_attributes = {align_corners = true, is_nchw_op = true, output = dense<[15, 20]> : tensor<2xi64>}': (
1
),
},
{'stablehlo.custom_call @mark_tensor': 2},
{'{"output": [15, 20], "align_corners": true}': 1},
{'{"output": [15, 20], "align_corners": true, "is_nchw_op": true}': 1},
)

def test_nn_functional_interpolate_nearest(self):
Expand Down

0 comments on commit d7c9e8d

Please sign in to comment.