Skip to content

Commit

Permalink
Remove the MLIR passes that are Migrated to tflite.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663611553
  • Loading branch information
majiddadashi authored and copybara-github committed Aug 16, 2024
1 parent 7e6023e commit a54d10d
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,24 @@ def test_hardswish_layer(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda x: torch.nn.Hardswish()(x), (torch.rand(10, 10),) # pylint: disable=unnecessary-lambda
)
self.assertEqual(
stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1

lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "aten.hardswish.default"': 1},
{'stablehlo.custom_call @mark_tensor': 2},
)

def test_hardswish_op(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda x: torch.ops.aten.hardswish.default(x), (torch.rand(10, 10),) # pylint: disable=unnecessary-lambda
)
self.assertEqual(
stablehlo.count('stablehlo.composite "aten.hardswish.default"'), 1

lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "aten.hardswish.default"': 1},
{'stablehlo.custom_call @mark_tensor': 2},
)

def test_avg_pool2d_layer(self):
Expand All @@ -90,8 +98,11 @@ def test_avg_pool2d_layer(self):
)(x),
(torch.rand(1, 3, 6, 6),),
)
self.assertEqual(
stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "aten.avg_pool2d.default"': 1},
{'stablehlo.custom_call @mark_tensor': 2},
)

def test_avg_pool2d_op(self):
Expand All @@ -108,8 +119,11 @@ def test_avg_pool2d_op(self):
),
(torch.rand(1, 3, 6, 6),),
)
self.assertEqual(
stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "aten.avg_pool2d.default"': 1},
{'stablehlo.custom_call @mark_tensor': 2},
)

def test_avg_pool2d_ceil_mode(self):
Expand All @@ -126,41 +140,56 @@ def test_avg_pool2d_ceil_mode(self):
),
(torch.rand(1, 3, 6, 6),),
)
self.assertEqual(
stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "aten.avg_pool2d.default"': 1},
{'stablehlo.custom_call @mark_tensor': 2},
)

def test_gelu_layer(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda x: torch.nn.GELU()(x), (torch.rand(10, 10),) # pylint: disable=unnecessary-lambda
)
self.assertEqual(
stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1
lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "aten.gelu.default"': 1},
{'stablehlo.custom_call @mark_tensor': 2},
)

def test_approximate_gelu_layer(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda x: torch.nn.GELU('tanh')(x), (torch.rand(10, 10),) # pylint: disable=unnecessary-lambda
)
self.assertEqual(
stablehlo.count('stablehlo.composite "aten.gelu.default"'), 1
lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "aten.gelu.default"': 1},
{'stablehlo.custom_call @mark_tensor': 2},
)

def test_embedding_lookup_layer(self):
stablehlo = _export_to_stablehlo_with_composite(
torch.nn.Embedding(10, 10), (torch.full((1, 10), 0, dtype=torch.long),)
)
self.assertEqual(
stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1
lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "odml.embedding_lookup"': 1},
{'stablehlo.custom_call @mark_tensor': 3},
)

def test_embedding_lookup_op(self):
stablehlo = _export_to_stablehlo_with_composite(
lambda *x: torch.ops.aten.embedding.default(*x),
(torch.rand(10, 10), torch.full((1, 10), 0, dtype=torch.long)),
)
self.assertEqual(
stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1
lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "odml.embedding_lookup"': 1},
{'stablehlo.custom_call @mark_tensor': 3},
)

def test_embedding_lookup_functional(self):
Expand All @@ -171,8 +200,11 @@ def test_embedding_lookup_functional(self):
torch.rand(10, 10),
),
)
self.assertEqual(
stablehlo.count('stablehlo.composite "odml.embedding_lookup"'), 1
lowertools.assert_string_count(
self,
stablehlo,
{'stablehlo.composite "odml.embedding_lookup"': 1},
{'stablehlo.custom_call @mark_tensor': 3},
)


Expand Down
Loading

0 comments on commit a54d10d

Please sign in to comment.