Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interpolate nearest composite support #41

Merged
merged 2 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ai_edge_torch/convert/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ai_edge_torch import model
from ai_edge_torch.convert import conversion_utils as cutils
from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA
from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA
from ai_edge_torch.convert.fx_passes import CanonicalizePass
from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
Expand All @@ -41,7 +41,7 @@ def _run_convert_passes(
return run_passes(
exported_program,
[
BuildUpsampleBilinear2DCompositePass(),
BuildInterpolateCompositePass(),
CanonicalizePass(),
OptimizeLayoutTransposesPass(),
CanonicalizePass(),
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/convert/fx_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult
from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA
from ai_edge_torch.convert.fx_passes.build_upsample_bilinear2d_composite_pass import BuildUpsampleBilinear2DCompositePass # NOQA
from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass # NOQA
from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,34 @@ def attr_builder(graph_module, pattern, internal_match):
return pattern


class BuildUpsampleBilinear2DCompositePass(FxPassBase):
@functools.cache
def _get_interpolate_nearest2d_pattern():
pattern = mark_pattern.Pattern(
"tfl.resize_nearest_neighbor",
lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
export_args=(torch.rand(1, 3, 100, 100),),
)

@pattern.register_attr_builder
def attr_builder(pattern, graph_module, internal_match):
output = internal_match.returning_nodes[0]
output_h, output_w = output.meta["val"].shape[-2:]
return {
"size": (int(output_h), int(output_w)),
"is_nchw_op": True,
}

return pattern


class BuildInterpolateCompositePass(FxPassBase):

def __init__(self):
super().__init__()
self._patterns = [
_get_upsample_bilinear2d_pattern(),
_get_upsample_bilinear2d_align_corners_pattern(),
_get_interpolate_nearest2d_pattern(),
]

def call(self, graph_module: torch.fx.GraphModule):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
import torch_xla

from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA
from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA
from ai_edge_torch.convert.fx_passes import run_passes


Expand All @@ -38,9 +38,7 @@ def forward(self, *args, **kwargs):
module = func

exported_program = torch.export.export(module, export_args)
exported_program = run_passes(
exported_program, [BuildUpsampleBilinear2DCompositePass()]
)
exported_program = run_passes(exported_program, [BuildInterpolateCompositePass()])

return torch_xla.stablehlo.exported_program_to_stablehlo(
exported_program
Expand Down Expand Up @@ -192,6 +190,36 @@ def test_nn_functional_interpolate_bilinear_size_align_corners(self):
1,
)

def test_nn_functional_interpolate_nearest(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda x: torch.nn.functional.interpolate(x, scale_factor=3.0, mode='nearest'),
(torch.rand(1, 3, 10, 10),),
)
self.assertTrue(
stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1
)
self.assertTrue(
stablehlo.count(
'composite_attributes = {is_nchw_op = true, size = dense<30> : tensor<2xi64>}'
),
1,
)

def test_nn_functional_interpolate_nearest_size(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda x: torch.nn.functional.interpolate(x, size=[15, 20], mode='nearest'),
(torch.rand(1, 3, 10, 10),),
)
self.assertTrue(
stablehlo.count('stablehlo.composite "tfl.resize_nearest_neighbor"'), 1
)
self.assertTrue(
stablehlo.count(
'composite_attributes = {is_nchw_op = true, size = dense<[15, 20]> : tensor<2xi64>}'
),
1,
)


if __name__ == '__main__':
unittest.main()
Loading