diff --git a/ai_edge_torch/generative/examples/test_models/convert_toy_model.py b/ai_edge_torch/generative/examples/test_models/convert_toy_model.py
new file mode 100644
index 00000000..7769d021
--- /dev/null
+++ b/ai_edge_torch/generative/examples/test_models/convert_toy_model.py
@@ -0,0 +1,105 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# A toy example which has a single-layer transformer block.
+from absl import app
+import ai_edge_torch
+from ai_edge_torch import lowertools
+from ai_edge_torch.generative.examples.test_models import toy_model
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import torch
+
+KV_CACHE_MAX_LEN = 100
+
+
+def convert_toy_model(_) -> None:
+  """Converts a toy model to tflite."""
+  model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
+  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+  print('running an inference')
+  print(
+      model.forward(
+          idx,
+          input_pos,
+      )
+  )
+
+  # Convert model to tflite.
+  print('converting model to tflite')
+  edge_model = ai_edge_torch.convert(
+      model,
+      (
+          idx,
+          input_pos,
+      ),
+  )
+  edge_model.export('/tmp/toy_model.tflite')
+
+
+def _export_stablehlo_mlir(model, args):
+  ep = torch.export.export(model, args)
+  return lowertools.exported_program_to_mlir_text(ep)
+
+
+def convert_toy_model_with_kv_cache(_) -> None:
+  """Converts a toy model with kv cache to tflite."""
+  dump_mlir = False
+
+  config = toy_model_with_kv_cache.get_model_config()
+  model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
+  model.eval()
+  print('running an inference')
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
+  decode_token, decode_input_pos = (
+      toy_model_with_kv_cache.get_sample_decode_inputs()
+  )
+  print(model.forward(tokens, input_pos, kv))
+
+  if dump_mlir:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+      f.write(mlir_text)
+
+  # Convert model to tflite with 2 signatures (prefill + decode).
+  print('converting toy model to tflite with 2 signatures (prefill + decode)')
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert()
+  )
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+if __name__ == '__main__':
+  app.run(convert_toy_model)
diff --git a/ai_edge_torch/generative/examples/test_models/toy_model.py b/ai_edge_torch/generative/examples/test_models/toy_model.py
index 5dadafa4..d31bf82c 100644
--- a/ai_edge_torch/generative/examples/test_models/toy_model.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -15,13 +15,12 @@
 # A toy example which has a single-layer transformer block.
 from typing import Tuple
 
-import ai_edge_torch
+from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-import torch.nn as nn
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KV_CACHE_MAX_LEN = 100
@@ -149,31 +148,3 @@ def get_model_config() -> cfg.ModelConfig:
       final_norm_config=norm_config,
   )
   return config
-
-
-def define_and_run() -> None:
-  model = ToySingleLayerModel(get_model_config())
-  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
-  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
-  print('running an inference')
-  print(
-      model.forward(
-          idx,
-          input_pos,
-      )
-  )
-
-  # Convert model to tflite.
-  print('converting model to tflite')
-  edge_model = ai_edge_torch.convert(
-      model,
-      (
-          idx,
-          input_pos,
-      ),
-  )
-  edge_model.export('/tmp/toy_model.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
diff --git a/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py b/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
index 87d6eabb..79f48132 100644
--- a/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -17,15 +17,14 @@
 
 from typing import Tuple
 
-import ai_edge_torch
-from ai_edge_torch import lowertools
+from absl import app
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-import torch.nn as nn
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -87,11 +86,6 @@ def forward(
     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
-def _export_stablehlo_mlir(model, args):
-  ep = torch.export.export(model, args)
-  return lowertools.exported_program_to_mlir_text(ep)
-
-
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
@@ -133,51 +127,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
   tokens = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return tokens, input_pos
-
-
-def define_and_run() -> None:
-  dump_mlir = False
-
-  config = get_model_config()
-  model = ToyModelWithExternalKV(config)
-  model.eval()
-  print('running an inference')
-  kv = kv_utils.KVCache.from_model_config(config)
-
-  tokens, input_pos = get_sample_prefill_inputs()
-  decode_token, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(tokens, input_pos, kv))
-
-  if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-      f.write(mlir_text)
-
-  # Convert model to tflite with 2 signatures (prefill + decode).
-  print('converting toy model to tflite with 2 signatures (prefill + decode)')
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          model,
-          sample_kwargs={
-              'tokens': tokens,
-              'input_pos': input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert()
-  )
-  edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
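
Usage note: the new `__main__` block only wires up `convert_toy_model`, so `/tmp/toy_external_kv_cache.tflite` is produced only when `convert_toy_model_with_kv_cache` is invoked explicitly. Below is a minimal sketch, not part of the patch, for smoke-testing the two-signature export with the stock TFLite interpreter; it assumes the `tensorflow` package is installed and that the converter has already run. The flattened kv-cache input names are converter-generated, so read them off `get_signature_list()` rather than hard-coding them.

# Sketch only; assumes convert_toy_model_with_kv_cache() has already run
# and written /tmp/toy_external_kv_cache.tflite.
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path='/tmp/toy_external_kv_cache.tflite'
)
# Expect two signatures, 'prefill' and 'decode', each listing its
# flattened inputs: tokens, input_pos, and the per-layer kv tensors.
print(interpreter.get_signature_list())

# Signature runners dispatch to the corresponding converted entry point.
prefill_runner = interpreter.get_signature_runner('prefill')
decode_runner = interpreter.get_signature_runner('decode')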