Skip to content

Commit

Permalink
Modify build dependencies for quantize example.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 690781145
  • Loading branch information
haozha111 authored and copybara-github committed Oct 28, 2024
1 parent 570ac33 commit 61800ad
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions ai_edge_torch/generative/quantize/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,30 @@

import ai_edge_torch
from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.quantize import quant_recipes
from ai_edge_torch.generative.utilities import model_builder
import numpy as np
import torch


def main():
    """Convert a decoder-only model to a quantized TFLite file.

    Builds a `DecoderOnlyModel` from the fake Gemma model config, prepares
    sample inputs (tokens, input positions and a KV cache), converts the
    model with the `full_int8_dynamic_recipe` quantization recipe, and
    exports the result to `/tmp/gemma_2b_quantized.tflite`.
    """
    # Build a PyTorch model as usual. The model is put in eval mode before
    # it is handed to the converter.
    config = gemma1.get_fake_model_config()
    model = model_builder.DecoderOnlyModel(config).eval()

    # Sample inputs: a (1, 10) int tensor of zeros whose first four entries
    # are the token ids 1..4, plus the positions 0..9.
    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, 10, dtype=torch.int)
    # The KV cache is built from the same config as the model and is passed
    # as the third sample input below.
    kv = kv_utils.KVCache.from_model_config(config)

    # Create a quantization recipe to be applied to the model
    quant_config = quant_recipes.full_int8_dynamic_recipe()
    print(quant_config)

    # Convert with quantization
    edge_model = ai_edge_torch.convert(
        model, (tokens, input_pos, kv), quant_config=quant_config
    )
    edge_model.export("/tmp/gemma_2b_quantized.tflite")

Expand Down

0 comments on commit 61800ad

Please sign in to comment.