From 382a02c990eca7bd55ad81788d1a2e12a5310cd2 Mon Sep 17 00:00:00 2001
From: Haoliang Zhang
Date: Wed, 5 Jun 2024 13:36:23 -0700
Subject: [PATCH] Enable test for model conversion. (#32)

* Enable test for model conversion.

* update.

* Update test_model_conversion.py

skip some tests

* Update test_model_conversion.py

disable some tests

* Update test_model_conversion.py

remove old comments.

* style fix.
---
 .../generative/test/test_model_conversion.py | 170 +++++++++---------
 1 file changed, 90 insertions(+), 80 deletions(-)

diff --git a/ai_edge_torch/generative/test/test_model_conversion.py b/ai_edge_torch/generative/test/test_model_conversion.py
index 07ab485c..00dcb779 100644
--- a/ai_edge_torch/generative/test/test_model_conversion.py
+++ b/ai_edge_torch/generative/test/test_model_conversion.py
@@ -33,7 +33,6 @@ class TestModelConversion(unittest.TestCase):
   """Unit tests that check for model conversion and correctness."""
 
   def test_toy_model_with_kv_cache(self):
-    self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
     idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
@@ -42,19 +41,21 @@
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (idx, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
   def test_toy_model_with_kv_cache_with_hlfb(self):
-    self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
@@ -64,16 +65,19 @@
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (idx, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
   def test_tiny_llama(self):
     self.skipTest("b/338288901")
@@ -87,19 +91,21 @@
 
     edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
   def test_tiny_llama_multisig(self):
-    self.skipTest("b/338288901")
     config = tiny_llama.get_fake_model_config_for_test()
     pytorch_model = tiny_llama.TinyLLamma(config)
 
@@ -122,32 +128,30 @@
         .convert()
     )
 
-    # For the pytorch model, the KV cache is a persistent state internal to the model, and it
-    # will be shared for prefill and decode. However, for tflite, currently we can't share
-    # kv-cache between the two signatures. prefill will change the content in kv-cache,
-    # but it won't be readable by the decode tflite model. This means the output of running `decode` after
-    # running `prefill` in pytorch will be different from the output of running `decode` after `prefill` via ai_edge_torch.
-    copied_model = copy.deepcopy(pytorch_model)
-
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (prefill_tokens, prefill_input_pos),
-            signature_name="prefill",
-            num_valid_inputs=1,
-        )
-    )
-
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            copied_model,
-            (decode_token, decode_input_pos),
-            signature_name="decode",
-            num_valid_inputs=1,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      copied_model = copy.deepcopy(pytorch_model)
+
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (prefill_tokens, prefill_input_pos),
+              signature_name="prefill",
+              num_valid_inputs=1,
+          )
+      )
+
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              copied_model,
+              (decode_token, decode_input_pos),
+              signature_name="decode",
+              num_valid_inputs=1,
+          )
+      )
 
   def test_gemma(self):
     self.skipTest("b/338288901")
@@ -161,17 +165,20 @@
 
     edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
 
-    # TODO(talumbau, haoliang): debug numerical diff.
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-2,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      # TODO(talumbau, haoliang): debug numerical diff.
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-2,
+              rtol=1e-5,
+          )
+      )
 
   def test_phi2(self):
     self.skipTest("b/338288901")
@@ -185,16 +192,19 @@
 
     edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
 
 if __name__ == "__main__":
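
A note on the comment this patch deletes from test_tiny_llama_multisig: the PyTorch model keeps its KV cache as persistent internal state, so running prefill mutates state that a later decode call observes, while the converted TFLite signatures cannot share a cache between them. The sketch below illustrates that effect with a hypothetical ToyCache module (an assumption for illustration only, not part of ai_edge_torch or this patch); it shows why the test captures a deep copy of the model before any comparison runs.

# ToyCache is a hypothetical stand-in, not an ai_edge_torch API.
import copy

import torch
import torch.nn as nn


class ToyCache(nn.Module):
  """Module whose forward() mutates a persistent KV-style buffer."""

  def __init__(self):
    super().__init__()
    self.register_buffer("kv", torch.zeros(4))

  def forward(self, pos, val):
    self.kv[pos] = val  # in-place write; the state persists across calls
    return self.kv.sum()


model = ToyCache()
fresh_model = copy.deepcopy(model)  # snapshot before any state is written

model(torch.tensor(0), torch.tensor(1.0))               # "prefill" fills the cache
print(model(torch.tensor(1), torch.tensor(2.0)))        # tensor(3.): sees prefill's write
print(fresh_model(torch.tensor(1), torch.tensor(2.0)))  # tensor(2.): starts from an empty cache

This is why the test compares the "decode" signature against copied_model rather than the pytorch_model that the "prefill" comparison has already mutated: the TFLite side of each signature starts from its own cache state.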