Fix aten.native_group_norm NHWC rewriter (#100)
* fix

* fix
chunnienc authored Jul 23, 2024
1 parent 4cfd377 commit 53fa236
Showing 3 changed files with 86 additions and 2 deletions.
@@ -113,6 +113,10 @@ def is_4d(node: Node):
val = node.meta.get("val")
if val is None:
return False

if isinstance(val, (list, tuple)) and val:
val = val[0]

if not hasattr(val, "shape"):
return False
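For context (an aside, not part of the diff): multi-output ops such as aten.native_group_norm store a tuple of FakeTensors in node.meta["val"], so the added lines unwrap the first element before the 4-D shape check. A minimal standalone sketch of that unwrapping, using plain tensors as stand-ins for the values held in node.meta:

import torch

def first_output_is_4d(val):
    # val mimics node.meta.get("val"); multi-output ops store a tuple here.
    if val is None:
        return False
    if isinstance(val, (list, tuple)) and val:
        val = val[0]  # Check the primary output of a multi-output op.
    return hasattr(val, "shape") and len(val.shape) == 4

print(first_output_is_4d(torch.empty(1, 8, 4, 4)))                       # True
print(first_output_is_4d((torch.empty(1, 8, 4, 4), torch.empty(1, 4))))  # True
print(first_output_is_4d(torch.empty(8, 4)))                             # False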

@@ -168,14 +172,25 @@ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):


@nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
@nhwcable_node_checkers.register(aten.native_group_norm)
def _aten_norm_checker(node):
val = node.meta.get("val")
if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
return NHWCable(can_be=False, must_be=False)
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)


@nhwcable_node_checkers.register(aten.native_group_norm)
def _aten_native_group_norm_checker(node):
val = node.meta.get("val")
if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
return NHWCable(can_be=False, must_be=False)
if len(node.args) >= 3 and (node.args[1] is not None or node.args[2] is not None):
# Disable NHWC rewriter due to precision issue with weight and bias.
# TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
return NHWCable(can_be=False, must_be=False)
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
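For reference (an aside, not part of the diff): in the ATen schema native_group_norm(input, weight, bias, N, C, HxW, group, eps), args[1] and args[2] are the optional affine weight and bias, which is why the new checker forces NCHW whenever either is set. A standalone sketch of that decision, with NHWCable rebuilt as a namedtuple purely for illustration:

from collections import namedtuple

# Illustration only: the real NHWCable is defined elsewhere in this module.
NHWCable = namedtuple("NHWCable", ["can_be", "must_be"])

def group_norm_nhwcable(output_shape, weight, bias):
    # Mirrors the checker above: affine group norm stays NCHW for now
    # because of the precision issue referenced in the TODO.
    if weight is not None or bias is not None:
        return NHWCable(can_be=False, must_be=False)
    return NHWCable(can_be=len(output_shape) == 4, must_be=False)

print(group_norm_nhwcable((16, 640, 16, 16), None, None))  # can_be=True
print(group_norm_nhwcable((16, 640, 16, 16), 1.0, None))   # can_be=False (any non-None weight)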


# ==== Ops must be NCHW


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import operator
import os
from typing import Optional, Tuple, Union

@@ -274,6 +275,14 @@ def call(self, exported_program: torch.export.ExportedProgram):
graph_module = layout_partitioners.greedy.partition(graph_module)

graph = graph_module.graph
for node in list(graph.nodes):
if node.target == operator.getitem:
# force the layout mark of a getitem node to follow its producer.
if layout_mark.is_nchw_node(node.args[0]):
layout_mark.mark_as_nchw_node(node)
else:
layout_mark.mark_as_nhwc_node(node)

for node in list(graph.nodes):
if layout_mark.is_nhwc_node(node):
for input_node in layout_check.get_layout_sensitive_inputs(node):
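For context (an aside, not part of the diff): multi-output ops are consumed through operator.getitem nodes in the exported graph, so the new loop copies each producer's NCHW/NHWC mark onto its getitem nodes before transposes are materialized; without it, a layout boundary could fall between an op and the extraction of its own result. A minimal sketch showing where those getitem nodes come from:

import torch

class GroupNormOnly(torch.nn.Module):  # Hypothetical example module.
    def forward(self, x):
        # native_group_norm returns (output, mean, rstd); indexing with [0]
        # becomes an operator.getitem node in the traced graph.
        return torch.ops.aten.native_group_norm(x, None, None, 1, 8, 16, 4, 1e-6)[0]

ep = torch.export.export(GroupNormOnly(), (torch.rand(1, 8, 4, 4),))
for node in ep.graph.nodes:
    print(node.op, node.target)  # One call_function node targets getitem.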
@@ -90,7 +90,67 @@ def test_torchvision_resnet18(self):
exported_program = export_with_pass(model, forward_args())
self.assert_outputs_allclose(model, exported_program.module(), forward_args())

# TODO(cnchan): Add more tests.
def test_native_group_norm_no_weight_bias(self):
batch_size = 16
num_channels = 640
flattened_inner_size = 256
num_groups = 32
eps = 1e-6

class SampleModel(torch.nn.Module):

def forward(self, x):
x = torch.nn.AvgPool2d(2)(x)
x = torch.ops.aten.native_group_norm(
x,
None,
None,
batch_size,
num_channels,
flattened_inner_size,
num_groups,
eps,
)[0]
x = torch.nn.AvgPool2d(2)(x)
return x

model = SampleModel().eval()
forward_args = lambda: (torch.rand(16, 640, 32, 32) * 1000,)
exported_program = export_with_pass(model, forward_args())
self.assert_outputs_allclose(model, exported_program.module(), forward_args())

def test_native_group_norm_large_weight_bias(self):
batch_size = 16
num_channels = 640
flattened_inner_size = 256
num_groups = 32
eps = 1e-6

class SampleModel(torch.nn.Module):

def forward(self, x, weight, bias):
x = torch.nn.AvgPool2d(2)(x)
x = torch.ops.aten.native_group_norm(
x,
weight,
bias,
batch_size,
num_channels,
flattened_inner_size,
num_groups,
eps,
)[0]
x = torch.nn.AvgPool2d(2)(x)
return x

model = SampleModel().eval()
forward_args = lambda: (
torch.rand(16, 640, 32, 32) * 1000,
torch.rand([640]) * 1000,
torch.rand([640]) * 1000,
)
exported_program = export_with_pass(model, forward_args())
self.assert_outputs_allclose(model, exported_program.module(), forward_args())
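A side note on the tests (an aside, not part of the diff): the [0] index selects the normalized output from native_group_norm's (output, mean, rstd) tuple, and the op should agree with the functional front-end on the same arguments, roughly as in this sketch:

import torch

x = torch.rand(16, 640, 16, 16) * 1000
weight = torch.rand(640) * 1000
bias = torch.rand(640) * 1000

# torch.nn.functional.group_norm lowers to the same aten op, so the two
# results are expected to match.
ref = torch.nn.functional.group_norm(x, 32, weight, bias, eps=1e-6)
out = torch.ops.aten.native_group_norm(x, weight, bias, 16, 640, 256, 32, 1e-6)[0]
print(torch.allclose(ref, out))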


if __name__ == '__main__':
