Skip to content

Commit

Permalink
Wrap embedding tables with HLFB (#61)
Browse files Browse the repository at this point in the history
* Wrap embedding tables with HLFB

BUG=b/332759630

* formatting

* reshape in fx pass

* Add e2e test and fix test inputs

* Explicitly cast i64 tokens to i32
  • Loading branch information
paulinesho authored Jul 15, 2024
1 parent 9fa0401 commit 26162a2
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 0 deletions.
29 changes: 29 additions & 0 deletions ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,35 @@ def is_valid_padding(padding: list[int]):
node.target = avg_pool2d


@_register_composite_builder(torch.ops.aten.embedding.default)
def _aten_embedding(gm: GraphModule, node: Node):
op = node.target
args_mapper = TorchOpArgumentsMapper(op)

def embedding(*args, **kwargs):
nonlocal op, args_mapper
full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
_, embedding_dim = full_kwargs["weight"].size()
idx = full_kwargs["indices"]
idx = idx.type(torch.int)
B, T = idx.size()

idx = torch.reshape(idx, (B * T,))

builder = StableHLOCompositeBuilder("odml.embedding_lookup")
full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
idx,
full_kwargs["weight"],
)
output = op(**full_kwargs)
output = builder.mark_outputs(output)

output = torch.reshape(output, (B, T, embedding_dim))
return output

node.target = embedding


class BuildAtenCompositePass(PassBase):

def call(self, graph_module: GraphModule):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,29 @@ def test_approximate_gelu_layer(self):
)
self.assertEqual(stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1)

def test_embedding_lookup_layer(self):
stablehlo = _export_to_stablehlo_with_composite(
torch.nn.Embedding(10, 10), (torch.full((1, 10), 0, dtype=torch.long),)
)
self.assertEqual(stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1)

def test_embedding_lookup_op(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda *x: torch.ops.aten.embedding.default(*x),
(torch.rand(10, 10), torch.full((1, 10), 0, dtype=torch.long)),
)
self.assertEqual(stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1)

def test_embedding_lookup_functional(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda *x: torch.nn.functional.embedding(*x),
(
torch.full((1, 10), 0, dtype=torch.long),
torch.rand(10, 10),
),
)
self.assertEqual(stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1)


if __name__ == '__main__':
unittest.main()
9 changes: 9 additions & 0 deletions ai_edge_torch/convert/test/test_convert_composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,15 @@ def test_convert_gelu_approximate(self):

self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))

def test_convert_embedding_lookup(self):
"""Tests conversion of an Embedding module."""

args = (torch.full((1, 10), 0, dtype=torch.long),)
torch_module = torch.nn.Embedding(10, 10)
edge_model = ai_edge_torch.convert(torch_module, args)

self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))


if __name__ == '__main__':
unittest.main()

0 comments on commit 26162a2

Please sign in to comment.