diff --git a/ai_edge_torch/generative/quantize/example.py b/ai_edge_torch/generative/quantize/example.py
index 7f47667b..b7bccc92 100644
--- a/ai_edge_torch/generative/quantize/example.py
+++ b/ai_edge_torch/generative/quantize/example.py
@@ -15,7 +15,9 @@
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma1
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
+from ai_edge_torch.generative.utilities import model_builder
 import numpy as np
 import torch
 
@@ -23,11 +25,12 @@
 def main():
   # Build a PyTorch model as usual
   config = gemma1.get_fake_model_config()
-  model = gemma1.Gemma(config)
+  model = model_builder.DecoderOnlyModel(config).eval()
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, 10, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(config)
 
   # Create a quantization recipe to be applied to the model
   quant_config = quant_recipes.full_int8_dynamic_recipe()
@@ -35,7 +38,7 @@
 
   # Convert with quantization
   edge_model = ai_edge_torch.convert(
-      model, (tokens, input_pos), quant_config=quant_config
+      model, (tokens, input_pos, kv), quant_config=quant_config
   )
 
   edge_model.export("/tmp/gemma_2b_quantized.tflite")