diff --git a/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py b/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
index 712a87e2..1e20f7cb 100644
--- a/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
+++ b/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
@@ -99,6 +99,36 @@ def hardswish(self: torch.Tensor):
   node.target = hardswish
 
 
+@_register_composite_builder(torch.ops.aten.gelu.default)
+def _aten_gelu(gm: GraphModule, node: Node):
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+
+  def gelu(*args, **kwargs):
+    nonlocal op, args_mapper
+
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    # TFLite supports exact and tanh approximate.
+    if full_kwargs["approximate"] != "none" and full_kwargs["approximate"] != "tanh":
+      return op(*args, **kwargs)
+
+    builder = StableHLOCompositeBuilder(
+        "aten.gelu.default",
+        attr=_tree_map_to_composite_attr_values(
+            {
+                "approximate": full_kwargs["approximate"],
+            }
+        ),
+    )
+    full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
+    output = op(full_kwargs["self"])
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = gelu
+
+
 @_register_composite_builder(torch.ops.aten.avg_pool2d.default)
 def _aten_avg_pool2d(gm: GraphModule, node: Node):
   op = node.target
diff --git a/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py b/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py
index 90041f5c..c809d735 100644
--- a/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py
+++ b/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py
@@ -58,13 +58,13 @@ def test_hardswish_layer(self):
     stablehlo = _export_to_stablehlo_with_composite(
         lambda x: torch.nn.Hardswish()(x), (torch.rand(10, 10),)
     )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1)
+    self.assertEqual(stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1)
 
   def test_hardswish_op(self):
     stablehlo = _export_to_stablehlo_with_composite(
         lambda x: torch.ops.aten.hardswish.default(x), (torch.rand(10, 10),)
     )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1)
+    self.assertEqual(stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1)
 
   def test_avg_pool2d_layer(self):
     stablehlo = _export_to_stablehlo_with_composite(
@@ -78,7 +78,9 @@ def test_avg_pool2d_layer(self):
         )(x),
         (torch.rand(1, 3, 6, 6),),
     )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1)
+    self.assertEqual(
+        stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
+    )
 
   def test_avg_pool2d_op(self):
     stablehlo = _export_to_stablehlo_with_composite(
@@ -93,7 +95,9 @@ def test_avg_pool2d_op(self):
         ),
         (torch.rand(1, 3, 6, 6),),
     )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1)
+    self.assertEqual(
+        stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
+    )
 
   def test_avg_pool2d_ceil_mode(self):
     stablehlo = _export_to_stablehlo_with_composite(
@@ -108,7 +112,21 @@ def test_avg_pool2d_ceil_mode(self):
         ),
         (torch.rand(1, 3, 6, 6),),
     )
-    self.assertTrue(stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1)
+    self.assertEqual(
+        stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
+    )
+
+  def test_gelu_layer(self):
+    stablehlo = _export_to_stablehlo_with_composite(
+        lambda x: torch.nn.GELU()(x), (torch.rand(10, 10),)
+    )
+    self.assertEqual(stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1)
+
+  def test_approximate_gelu_layer(self):
+    stablehlo = _export_to_stablehlo_with_composite(
+        lambda x: torch.nn.GELU('tanh')(x), (torch.rand(10, 10),)
+    )
+    self.assertEqual(stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1)
 
 
 if __name__ == '__main__':
diff --git a/ai_edge_torch/convert/test/test_convert_composites.py b/ai_edge_torch/convert/test/test_convert_composites.py
index 08456d5e..796d3640 100644
--- a/ai_edge_torch/convert/test/test_convert_composites.py
+++ b/ai_edge_torch/convert/test/test_convert_composites.py
@@ -169,6 +169,24 @@ def test_convert_interpolate_bilinear_functional(self, input_size, kwargs):
         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
     )
 
+  def test_convert_gelu(self):
+    """Tests conversion of a GELU module."""
+
+    args = (torch.randn((5, 10)),)
+    torch_module = torch.nn.GELU().eval()
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+  def test_convert_gelu_approximate(self):
+    """Tests conversion of an approximate (tanh) GELU module."""
+
+    args = (torch.randn((5, 10)),)
+    torch_module = torch.nn.GELU('tanh').eval()
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/requirements.txt b/requirements.txt
index 11ce23d6..b925ce67 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ numpy
 tabulate
 safetensors
 --pre
-tf-nightly>=2.17.0.dev20240529
+tf-nightly-cpu>=2.17.0.dev20240604
 -f https://download.pytorch.org/whl/nightly/torch_nightly.html
 torch==2.4.0.dev20240429+cpu
 -f https://download.pytorch.org/whl/nightly/torch_nightly.html