From 6c94a00d62bdc40b397d247fd591c0b3230d0b99 Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Fri, 13 Dec 2024 16:52:00 -0800 Subject: [PATCH] update tests wrt odml-torch as default changes PiperOrigin-RevId: 706043950 --- .../test/test_stablehlo_composite_builder.py | 286 ------------------ test/test_quantize.py | 4 - 2 files changed, 290 deletions(-) delete mode 100644 ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py diff --git a/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py b/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py deleted file mode 100644 index b808b828..00000000 --- a/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2024 The AI Edge Torch Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for StableHLOCompositeBuilder.""" - -import math - -import ai_edge_torch -from ai_edge_torch import hlfb -from ai_edge_torch import lowertools -from ai_edge_torch.hlfb import StableHLOCompositeBuilder -import torch -import torch.nn.functional as F - -from absl.testing import absltest as googletest - - -def _export_stablehlo_mlir(model, args): - ep = torch.export.export(model, args) - return lowertools.exported_program_to_mlir_text(ep) - -StableHLOCompositeBuilder = hlfb.StableHLOCompositeBuilder - - -@googletest.skipIf( - not ai_edge_torch.config.use_torch_xla, - reason="The odml_torch counter part is in odml_torch.", -) -class TestStableHLOCompositeBuilder(googletest.TestCase): - - def test_build_composite(self): - class SampleModel(torch.nn.Module): - - def forward(self, x): - builder = StableHLOCompositeBuilder(name="test.plus_two") - y = x + 1 - y = builder.mark_inputs(y) - z = y + 2 - z = builder.mark_outputs(z) - return z - - mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),)) - self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 1) - - def test_build_multiple_composites(self): - class SampleModel(torch.nn.Module): - - def plus_one(self, x: torch.Tensor): - builder = StableHLOCompositeBuilder("test.plus_one") - x = builder.mark_inputs(x) - y = x + 1 - y = builder.mark_outputs(y) - return y - - def plus_two(self, x: torch.Tensor): - builder = StableHLOCompositeBuilder("test.plus_two") - x = builder.mark_inputs(x) - y = x + 2 - y = builder.mark_outputs(y) - return y - - def forward(self, x): - x = self.plus_two(x) - x = x + 3 - x = self.plus_one(x) - x = x + 4 - x = self.plus_two(x) - return x - - mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),)) - self.assertEqual(mlir.count('stablehlo.composite "test.plus_one"'), 1) - self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 2) - - def test_build_composite_with_attr(self): - class SampleModel(torch.nn.Module): - - def __init__(self): - super().__init__() - - def log_softmax(self, x: torch.Tensor, dim: int): - builder = StableHLOCompositeBuilder( - name="test.log_softmax", attr={"dim": dim} - ) - x = builder.mark_inputs(x) - y = torch.nn.functional.log_softmax(x, dim=dim) - y = builder.mark_outputs(y) - return y - - def forward(self, x): - x = x + 1 - x = self.log_softmax(x, 0) - x = self.log_softmax(x, 1) - return x - - mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),)) - self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 2) - self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 1) - self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 1) - - def test_build_composite_with_mix_type_attrs(self): - class SampleModel(torch.nn.Module): - - def __init__(self): - super().__init__() - - def log_softmax(self, x: torch.Tensor, dim: int): - builder = StableHLOCompositeBuilder( - name="test.log_softmax", - attr={ - "dim": dim, - "source": "torch.nn", - "version": 1.0, - }, - ) - x = builder.mark_inputs(x) - y = torch.nn.functional.log_softmax(x, dim=dim) - y = builder.mark_outputs(y) - return y - - def forward(self, x): - x = x + 1 - x = self.log_softmax(x, 0) - return x - - mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),)) - self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1) - self.assertEqual( - mlir.count( - 'composite_attributes = {dim = 0 : i64, source = "torch.nn",' - " version = 1.000000e+00 : f32}" - ), - 1, - ) - - def test_sdpa_composite(self): - class SDPAModel(torch.nn.Module): - - def scaled_dot_product_attention( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - head_size: int, - mask: torch.Tensor, - ): - builder = StableHLOCompositeBuilder("test.scaled_dot_product_attention") - q, k, v, mask = builder.mark_inputs(q, k, v, mask) - - scale = 1.0 / math.sqrt(head_size) - - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - dropout_p=0.0, - is_causal=mask is None, - scale=scale, - ) - result = y.transpose(1, 2) - result = builder.mark_outputs(result) - return result - - def forward(self, q, k, v, mask): - x = self.scaled_dot_product_attention( - q, - k, - v, - 8, - mask, - ) - return x - - query = torch.rand(1, 1, 32, 4) - key = torch.rand(1, 500, 1, 4) - value = torch.rand(1, 500, 1, 4) - mask = torch.rand(1, 1, 1, 500) - - mlir = _export_stablehlo_mlir( - SDPAModel().eval(), - (query, key, value, mask), - ) - self.assertEqual( - mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 1 - ) - - def test_sdpa_composite_with_attr(self): - class SDPAModel(torch.nn.Module): - - def scaled_dot_product_attention( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - head_size: int, - include_captanh: bool, - ): - builder = StableHLOCompositeBuilder( - name="test.scaled_dot_product_attention", - attr={"include_captanh": include_captanh}, - ) - q, k, v = builder.mark_inputs(q, k, v) - - scale = 1.0 / math.sqrt(head_size) - - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - dropout_p=0.0, - is_causal=True, - scale=scale, - ) - result = y.transpose(1, 2) - result = builder.mark_outputs(result) - return result - - def forward(self, q, k, v): - x = self.scaled_dot_product_attention(q, k, v, 8, True) - y = self.scaled_dot_product_attention(q, k, v, 8, False) - return x + y - - query = torch.rand(1, 1, 32, 4) - key = torch.rand(1, 500, 1, 4) - value = torch.rand(1, 500, 1, 4) - mlir = _export_stablehlo_mlir( - SDPAModel().eval(), - (query, key, value), - ) - self.assertEqual( - mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2 - ) - self.assertEqual( - mlir.count("composite_attributes = {include_captanh = true}"), 1 - ) - self.assertEqual( - mlir.count("composite_attributes = {include_captanh = false}"), 1 - ) - - def test_build_composite_with_multiple_inputs_outputs(self): - class SampleModel(torch.nn.Module): - - def mimo_sample(self, a, b, c): - builder = StableHLOCompositeBuilder(name="test.mimo_sample") - - a, b, c = builder.mark_inputs(a, b, c) - x = a + b + c - y = (a - b) * x - z = (c + 1.0) * a - x, y, z = builder.mark_outputs(x, y, z) - - result = x + y * z - return result - - def forward(self, a, b, c): - x = self.mimo_sample(a, b, c) - x = self.mimo_sample(a, b, x) - x = self.mimo_sample(x, x, c) - return x - - mlir = _export_stablehlo_mlir( - SampleModel().eval(), (torch.rand(2), torch.rand(2), torch.rand(2)) - ) - self.assertEqual(mlir.count('stablehlo.composite "test.mimo_sample"'), 3) - - -if __name__ == "__main__": - googletest.main() diff --git a/test/test_quantize.py b/test/test_quantize.py index 99db733a..aeb45209 100644 --- a/test/test_quantize.py +++ b/test/test_quantize.py @@ -37,10 +37,6 @@ def setUp(self): super().setUp() torch.manual_seed(0) - @googletest.skipIf( - not ai_edge_torch.config.use_torch_xla, - reason="Only working with torch_xla at the moment.", - ) def test_quantizer_arg(self): """Compare the sizes of models.