Enable test for model conversion. (#32)
* Enable test for model conversion.

* update.

* Update test_model_conversion.py

skip some tests

* Update test_model_conversion.py

disable some tests

* Update test_model_conversion.py

remove old comments.

* style fix.
haozha111 authored Jun 5, 2024
1 parent 01a4c3d commit 382a02c
Showing 1 changed file with 90 additions and 80 deletions.
170 changes: 90 additions & 80 deletions ai_edge_torch/generative/test/test_model_conversion.py
@@ -33,7 +33,6 @@ class TestModelConversion(unittest.TestCase):
   """Unit tests that check for model conversion and correctness."""
 
   def test_toy_model_with_kv_cache(self):
-    self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
     idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
@@ -42,19 +41,21 @@ def test_toy_model_with_kv_cache(self):
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (idx, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
   def test_toy_model_with_kv_cache_with_hlfb(self):
-    self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
@@ -64,16 +65,19 @@ def test_toy_model_with_kv_cache_with_hlfb(self):
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (idx, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
   def test_tiny_llama(self):
-    self.skipTest("b/338288901")
@@ -87,19 +91,21 @@ def test_tiny_llama(self):
 
     edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
   def test_tiny_llama_multisig(self):
-    self.skipTest("b/338288901")
     config = tiny_llama.get_fake_model_config_for_test()
     pytorch_model = tiny_llama.TinyLLamma(config)
 
@@ -122,32 +128,30 @@ def test_tiny_llama_multisig(self):
         .convert()
     )
 
-    # For the pytorch model, the KV cache is a persistent state internal to the model, and it
-    # will be shared for prefill and decode. However, for tflite, currently we can't share
-    # kv-cache between the two signatures. prefill will change the content in kv-cache,
-    # but it won't be readable by the decode tflite model. This means the output of running `decode` after
-    # running `prefill` in pytorch will be different from the output of running `decode` after `prefill` via ai_edge_torch.
-    copied_model = copy.deepcopy(pytorch_model)
-
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (prefill_tokens, prefill_input_pos),
-            signature_name="prefill",
-            num_valid_inputs=1,
-        )
-    )
-
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            copied_model,
-            (decode_token, decode_input_pos),
-            signature_name="decode",
-            num_valid_inputs=1,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      copied_model = copy.deepcopy(pytorch_model)
+
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (prefill_tokens, prefill_input_pos),
+              signature_name="prefill",
+              num_valid_inputs=1,
+          )
+      )
+
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              copied_model,
+              (decode_token, decode_input_pos),
+              signature_name="decode",
+              num_valid_inputs=1,
+          )
+      )
 
   def test_gemma(self):
-    self.skipTest("b/338288901")
@@ -161,17 +165,20 @@ def test_gemma(self):
 
     edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
 
-    # TODO(talumbau, haoliang): debug numerical diff.
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-2,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      # TODO(talumbau, haoliang): debug numerical diff.
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-2,
+              rtol=1e-5,
+          )
+      )
 
   def test_phi2(self):
-    self.skipTest("b/338288901")
@@ -185,16 +192,19 @@ def test_phi2(self):
 
     edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
 
 if __name__ == "__main__":
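
Note on the multisig hunk: the comment removed from test_tiny_llama_multisig records the one subtle point in this diff. The PyTorch model's KV cache is persistent internal state, so running prefill mutates it, while each converted TFLite signature starts from its own untouched cache. A minimal sketch of why the decode comparison needs a deepcopy taken before prefill; the toy class below is illustrative only, not part of this repository:

import copy

class ToyStatefulModel:
  """Illustrative stand-in for a model with an internal KV cache."""

  def __init__(self):
    self.kv_cache = 0  # stands in for the persistent KV-cache tensors

  def prefill(self, tokens):
    self.kv_cache += sum(tokens)  # prefill writes into the cache
    return self.kv_cache

  def decode(self, token):
    return self.kv_cache + token  # decode reads the cache

model = ToyStatefulModel()
fresh = copy.deepcopy(model)  # snapshot taken before prefill, as the test does

model.prefill([1, 2, 3])
assert model.decode(4) == 10  # PyTorch: decode sees the cache that prefill mutated
assert fresh.decode(4) == 4   # TFLite-like: decode starts from an untouched cache

This mirrors the test's copied_model = copy.deepcopy(pytorch_model): the original model is compared against the prefill signature, and the untouched copy against the decode signature.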
