Skip to content

Commit

Permalink
fix fx pattern matcher for bilinear upsample
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713153511
  • Loading branch information
chunnienc authored and copybara-github committed Jan 8, 2025
1 parent 6d285bb commit fd1bde3
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 17 deletions.
26 changes: 19 additions & 7 deletions ai_edge_torch/hlfb/mark_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import uuid

from ai_edge_torch import lowertools
from ai_edge_torch.hlfb.mark_pattern import passes
from ai_edge_torch.hlfb.mark_pattern import fx_utils
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
import torch

Expand Down Expand Up @@ -87,7 +87,7 @@ def mark_pattern(
m.meta["ORIGINAL_NODE"] = n

# Sanitize graph_module to match in the same way as pattern's graph_module.
graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)
graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)

match_with_attrs = pattern.match(graph_module_to_match)

Expand All @@ -111,13 +111,25 @@ def mark_pattern(
is_input=True,
)

# Only replace input by the marker node for those nodes used in the pattern.
# Only replace input by the marker node for those nodes used in the
# pattern.
in_pattern_nodes = set(match.nodes_map.values())
for user in input_node.users.keys():
if user in in_pattern_nodes:
user.meta["ORIGINAL_NODE"].replace_input_with(
input_node.meta["ORIGINAL_NODE"], new_input_node
)
if user not in in_pattern_nodes:
continue

user.meta["ORIGINAL_NODE"].replace_input_with(
input_node.meta["ORIGINAL_NODE"], new_input_node
)
# Pattern matching graph sanitization may remove clone ops, which means
# the user's input in the original graph may be a clone op. When
# replacing the input with the marker node, we need to further try
# replacing the input of the clone op that connects to the user.
for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
if fx_utils.is_clone_op(original_user_input):
original_user_input.replace_input_with(
input_node.meta["ORIGINAL_NODE"], new_input_node
)

for i, pattern_output_node in enumerate(pattern.output_nodes):
output_node = match.nodes_map[pattern_output_node]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Passes to clean up the model graph for pattern matching."""
"""FX graph utilities for pattern matching clean ups."""

import torch


def is_clone_op(node: torch.fx.Node) -> bool:
"""Checks if the node is a clone op."""
return (
node.op == "call_function" and node.target == torch.ops.aten.clone.default
)


def remove_clone_ops(gm: torch.fx.GraphModule):
"""Removes clone ops from the graph.
Expand All @@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
The graph module with clone ops removed.
"""
for node in gm.graph.nodes:
if node.op == "call_function" and node.name.startswith("clone"):
if is_clone_op(node):
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)

Expand Down
17 changes: 9 additions & 8 deletions ai_edge_torch/hlfb/mark_pattern/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from typing import Any, Callable, Optional, Union

from ai_edge_torch import fx_pass_base
from ai_edge_torch.hlfb.mark_pattern import passes
from ai_edge_torch.hlfb.mark_pattern import fx_utils
import torch
from torch.export.graph_signature import TensorArgument
from torch.fx import Graph
from torch.fx import GraphModule
from torch.fx.passes.utils.matcher_utils import InternalMatch
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

Graph = torch.fx.Graph
GraphModule = torch.fx.GraphModule
TensorArgument = torch.export.graph_signature.TensorArgument
InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher


def _are_equal(x: Any, y: Any) -> bool:
Expand Down Expand Up @@ -219,8 +220,8 @@ def forward(self, *args, **kwargs):
# Sanitize graph_module for more precise pattern matching.
# The graph_module to match against this pattern should apply equivalent
# sanitization.
self.graph_module = passes.remove_clone_ops(self.graph_module)
self.graph_module = passes.remove_dangling_args(self.graph_module)
self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
self.graph_module = fx_utils.remove_dangling_args(self.graph_module)

# Builds list of ordered input and output nodes.
self.graph_nodes_map = {}
Expand Down
26 changes: 26 additions & 0 deletions ai_edge_torch/hlfb/test/test_mark_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,32 @@ def forward(self, x):
{"stablehlo.custom_call @mark_tensor": 6},
)

def test_mark_pattern_with_clone_inputs(self):

class TestModel(torch.nn.Module):

def forward(self, x):
return torch.ops.aten.clone.default(x * x) + x

pattern = pattern_module.Pattern(
"test.add",
lambda a, b: a + b,
export_args=(torch.rand(2, 2), torch.rand(2, 2)),
)

model = TestModel().eval()
args = (torch.rand(20, 20),)
exported_program = torch.export.export(model, args)
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
mlir = _export_stablehlo_mlir(exported_program)

lowertools.assert_string_count(
self,
mlir,
{'stablehlo.composite "test.add"': 1},
{"stablehlo.custom_call @mark_tensor": 3},
)

def test_mark_pattern_with_attr_builder(self):
class TestModel(torch.nn.Module):

Expand Down

0 comments on commit fd1bde3

Please sign in to comment.