Skip to content

Commit

Permalink
Add interpolate nearest composite support (#41)
Browse files Browse the repository at this point in the history
* Add interpolate nearest composite support

* fix fmt
  • Loading branch information
vamsimanchala authored Jun 8, 2024
1 parent db77cc4 commit fe2afd5
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 8 deletions.
4 changes: 2 additions & 2 deletions ai_edge_torch/convert/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ai_edge_torch import model
from ai_edge_torch.convert import conversion_utils as cutils
from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA
from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA
from ai_edge_torch.convert.fx_passes import CanonicalizePass
from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
Expand All @@ -41,7 +41,7 @@ def _run_convert_passes(
return run_passes(
exported_program,
[
BuildUpsampleBilinear2DCompositePass(),
BuildInterpolateCompositePass(),
CanonicalizePass(),
OptimizeLayoutTransposesPass(),
CanonicalizePass(),
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/convert/fx_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult
from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA
from ai_edge_torch.convert.fx_passes.build_upsample_bilinear2d_composite_pass import BuildUpsampleBilinear2DCompositePass # NOQA
from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass # NOQA
from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,34 @@ def attr_builder(graph_module, pattern, internal_match):
return pattern


class BuildUpsampleBilinear2DCompositePass(FxPassBase):
@functools.cache
def _get_interpolate_nearest2d_pattern():
pattern = mark_pattern.Pattern(
"tfl.resize_nearest_neighbor",
lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
export_args=(torch.rand(1, 3, 100, 100),),
)

@pattern.register_attr_builder
def attr_builder(pattern, graph_module, internal_match):
output = internal_match.returning_nodes[0]
output_h, output_w = output.meta["val"].shape[-2:]
return {
"size": (int(output_h), int(output_w)),
"is_nchw_op": True,
}

return pattern


class BuildInterpolateCompositePass(FxPassBase):

def __init__(self):
super().__init__()
self._patterns = [
_get_upsample_bilinear2d_pattern(),
_get_upsample_bilinear2d_align_corners_pattern(),
_get_interpolate_nearest2d_pattern(),
]

def call(self, graph_module: torch.fx.GraphModule):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
import torch_xla

from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA
from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA
from ai_edge_torch.convert.fx_passes import run_passes


Expand All @@ -38,9 +38,7 @@ def forward(self, *args, **kwargs):
module = func

exported_program = torch.export.export(module, export_args)
exported_program = run_passes(
exported_program, [BuildUpsampleBilinear2DCompositePass()]
)
exported_program = run_passes(exported_program, [BuildInterpolateCompositePass()])

return torch_xla.stablehlo.exported_program_to_stablehlo(
exported_program
Expand Down Expand Up @@ -192,6 +190,36 @@ def test_nn_functional_interpolate_bilinear_size_align_corners(self):
1,
)

def test_nn_functional_interpolate_nearest(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda x: torch.nn.functional.interpolate(x, scale_factor=3.0, mode='nearest'),
(torch.rand(1, 3, 10, 10),),
)
self.assertTrue(
stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1
)
self.assertTrue(
stablehlo.count(
'composite_attributes = {is_nchw_op = true, size = dense<30> : tensor<2xi64>}'
),
1,
)

def test_nn_functional_interpolate_nearest_size(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda x: torch.nn.functional.interpolate(x, size=[15, 20], mode='nearest'),
(torch.rand(1, 3, 10, 10),),
)
self.assertTrue(
stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1
)
self.assertTrue(
stablehlo.count(
'composite_attributes = {is_nchw_op = true, size = dense<[15, 20]> : tensor<2xi64>}'
),
1,
)


if __name__ == '__main__':
unittest.main()

0 comments on commit fe2afd5

Please sign in to comment.