Wrap GELU in composite and lower to tfl.gelu (#29)
majiddadashi authored Jun 4, 2024
1 parent 25c764a commit 475607a
Showing 4 changed files with 72 additions and 6 deletions.
30 changes: 30 additions & 0 deletions ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
@@ -99,6 +99,36 @@ def hardswish(self: torch.Tensor):
  node.target = hardswish


@_register_composite_builder(torch.ops.aten.gelu.default)
def _aten_gelu(gm: GraphModule, node: Node):
  op = node.target
  args_mapper = TorchOpArgumentsMapper(op)

  def gelu(*args, **kwargs):
    nonlocal op, args_mapper

    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)

    # TFLite supports exact and tanh approximate.
    if full_kwargs["approximate"] != "none" and full_kwargs["approximate"] != "tanh":
      return op(*args, **kwargs)

    builder = StableHLOCompositeBuilder(
        "aten.gelu.default",
        attr=_tree_map_to_composite_attr_values(
            {
                "approximate": full_kwargs["approximate"],
            }
        ),
    )
    full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
    output = op(full_kwargs["self"])
    output = builder.mark_outputs(output)
    return output

  node.target = gelu


@_register_composite_builder(torch.ops.aten.avg_pool2d.default)
def _aten_avg_pool2d(gm: GraphModule, node: Node):
  op = node.target
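As a rough end-to-end illustration of what this builder enables (the model, file path, and export call below are illustrative, not part of this commit), converting a module that contains GELU should now produce an exported graph in which the GELU is wrapped in an "aten.gelu.default" composite that the TFLite converter can lower to a single tfl.gelu op rather than a decomposed subgraph:

import torch
import ai_edge_torch

# Illustrative model containing a GELU; after this change the exported
# StableHLO should carry a "aten.gelu.default" composite around the op.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU()).eval()
sample_args = (torch.randn(1, 16),)

edge_model = ai_edge_torch.convert(model, sample_args)
edge_model.export("/tmp/gelu_example.tflite")  # inspect with a TFLite model viewer

Note that the builder only wraps GELU when approximate is "none" or "tanh"; any other value falls through to the plain aten op.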
28 changes: 23 additions & 5 deletions
@@ -58,13 +58,13 @@ def test_hardswish_layer(self):
    stablehlo = _export_to_stablehlo_with_composite(
        lambda x: torch.nn.Hardswish()(x), (torch.rand(10, 10),)
    )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1)
+    self.assertEqual(stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1)

  def test_hardswish_op(self):
    stablehlo = _export_to_stablehlo_with_composite(
        lambda x: torch.ops.aten.hardswish.default(x), (torch.rand(10, 10),)
    )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1)
+    self.assertEqual(stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1)

  def test_avg_pool2d_layer(self):
    stablehlo = _export_to_stablehlo_with_composite(
@@ -78,7 +78,9 @@ def test_avg_pool2d_layer(self):
        )(x),
        (torch.rand(1, 3, 6, 6),),
    )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1)
+    self.assertEqual(
+        stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
+    )

  def test_avg_pool2d_op(self):
    stablehlo = _export_to_stablehlo_with_composite(
@@ -93,7 +95,9 @@ def test_avg_pool2d_op(self):
        ),
        (torch.rand(1, 3, 6, 6),),
    )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1)
+    self.assertEqual(
+        stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
+    )

  def test_avg_pool2d_ceil_mode(self):
    stablehlo = _export_to_stablehlo_with_composite(
@@ -108,7 +112,21 @@ def test_avg_pool2d_ceil_mode(self):
        ),
        (torch.rand(1, 3, 6, 6),),
    )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1)
+    self.assertEqual(
+        stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
+    )

  def test_gelu_layer(self):
    stablehlo = _export_to_stablehlo_with_composite(
        lambda x: torch.nn.GELU()(x), (torch.rand(10, 10),)
    )
    self.assertEqual(stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1)

  def test_approximate_gelu_layer(self):
    stablehlo = _export_to_stablehlo_with_composite(
        lambda x: torch.nn.GELU('tanh')(x), (torch.rand(10, 10),)
    )
    self.assertEqual(stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1)


if __name__ == '__main__':
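For reference, the approximate attribute recorded on the composite distinguishes exact GELU from its tanh approximation. A small sanity check of what the two modes compute, using the standard formulas (not code from this commit):

import math
import torch

x = torch.randn(10, 10)

# Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Tanh approximation, used when approximate="tanh".
tanh_approx = 0.5 * x * (
    1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3)))
)

assert torch.allclose(torch.nn.GELU()(x), exact, atol=1e-6)
assert torch.allclose(torch.nn.GELU('tanh')(x), tanh_approx, atol=1e-6)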
18 changes: 18 additions & 0 deletions ai_edge_torch/convert/test/test_convert_composites.py
@@ -169,6 +169,24 @@ def test_convert_interpolate_bilinear_functional(self, input_size, kwargs):
        model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
    )

  def test_convert_gelu(self):
    """Tests conversion of a GELU module."""

    args = (torch.randn((5, 10)),)
    torch_module = torch.nn.GELU().eval()
    edge_model = ai_edge_torch.convert(torch_module, args)

    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))

  def test_convert_gelu_approximate(self):
    """Tests conversion of an Approximate GELU module."""

    args = (torch.randn((5, 10)),)
    torch_module = torch.nn.GELU('tanh').eval()
    edge_model = ai_edge_torch.convert(torch_module, args)

    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))


if __name__ == '__main__':
  unittest.main()
2 changes: 1 addition & 1 deletion requirements.txt
@@ -3,7 +3,7 @@ numpy
tabulate
safetensors
--pre
-tf-nightly>=2.17.0.dev20240529
+tf-nightly-cpu>=2.17.0.dev20240604
-f https://download.pytorch.org/whl/nightly/torch_nightly.html
torch==2.4.0.dev20240429+cpu
-f https://download.pytorch.org/whl/nightly/torch_nightly.html
