diff --git a/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py b/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
index 1e20f7cb..b72ebadc 100644
--- a/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
+++ b/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
@@ -213,6 +213,35 @@ def is_valid_padding(padding: list[int]):
   node.target = avg_pool2d
 
 
+@_register_composite_builder(torch.ops.aten.embedding.default)
+def _aten_embedding(gm: GraphModule, node: Node):
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+
+  def embedding(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+    _, embedding_dim = full_kwargs["weight"].size()
+    idx = full_kwargs["indices"]
+    idx = idx.type(torch.int)
+    B, T = idx.size()
+
+    idx = torch.reshape(idx, (B * T,))
+
+    builder = StableHLOCompositeBuilder("odml.embedding_lookup")
+    full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
+        idx,
+        full_kwargs["weight"],
+    )
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+
+    output = torch.reshape(output, (B, T, embedding_dim))
+    return output
+
+  node.target = embedding
+
+
 class BuildAtenCompositePass(PassBase):
 
   def call(self, graph_module: GraphModule):
diff --git a/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py b/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py
index c809d735..3ecee2ef 100644
--- a/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py
+++ b/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py
@@ -128,6 +128,29 @@ def test_approximate_gelu_layer(self):
     )
     self.assertEqual(stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1)
 
+  def test_embedding_lookup_layer(self):
+    stablehlo = _export_to_stablehlo_with_composite(
+        torch.nn.Embedding(10, 10), (torch.full((1, 10), 0, dtype=torch.long),)
+    )
+    self.assertEqual(stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1)
+
+  def test_embedding_lookup_op(self):
+    stablehlo = _export_to_stablehlo_with_composite(
+        lambda *x: torch.ops.aten.embedding.default(*x),
+        (torch.rand(10, 10), torch.full((1, 10), 0, dtype=torch.long)),
+    )
+    self.assertEqual(stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1)
+
+  def test_embedding_lookup_functional(self):
+    stablehlo = _export_to_stablehlo_with_composite(
+        lambda *x: torch.nn.functional.embedding(*x),
+        (
+            torch.full((1, 10), 0, dtype=torch.long),
+            torch.rand(10, 10),
+        ),
+    )
+    self.assertEqual(stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/ai_edge_torch/convert/test/test_convert_composites.py b/ai_edge_torch/convert/test/test_convert_composites.py
index 796d3640..3e6e3e3c 100644
--- a/ai_edge_torch/convert/test/test_convert_composites.py
+++ b/ai_edge_torch/convert/test/test_convert_composites.py
@@ -187,6 +187,15 @@ def test_convert_gelu_approximate(self):
 
     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
 
+  def test_convert_embedding_lookup(self):
+    """Tests conversion of an Embedding module."""
+
+    args = (torch.full((1, 10), 0, dtype=torch.long),)
+    torch_module = torch.nn.Embedding(10, 10)
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
 
 if __name__ == '__main__':
   unittest.main()