Change all input token_id / input_pos types from int64 to int.
PiperOrigin-RevId: 674485096
haozha111 authored and copybara-github committed Sep 14, 2024
1 parent 0634d3a commit 28ce35b
Show file tree
Hide file tree
Showing 24 changed files with 102 additions and 102 deletions.
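
Every hunk below applies the same pattern: the example tensors used to trace the model graphs switch from torch.long / torch.int64 to torch.int (int32), so the exported TFLite signatures take int32 token ids and positions. A minimal sketch of the before/after pattern, assuming a stand-in prefill_seq_len:

import torch

prefill_seq_len = 512  # stand-in value for illustration

# Before this commit: 64-bit trace inputs (torch.arange defaults to int64).
old_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
old_input_pos = torch.arange(0, prefill_seq_len)

# After this commit: 32-bit trace inputs.
new_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
new_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)

assert new_tokens.dtype == torch.int32 and new_input_pos.dtype == torch.int32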
8 changes: 4 additions & 4 deletions ai_edge_torch/generative/examples/cpp/text_generator_main.cc
@@ -221,13 +221,13 @@ int main(int argc, char* argv[]) {
TFLITE_MINIMAL_CHECK(decode_runner != nullptr);

// Get Input Tensors for each of the runners.
-// Shape: [Batch, Seq], Dtype: int64
+// Shape: [Batch, Seq], Dtype: int32
TfLiteTensor* prefill_input = prefill_runner->input_tensor("tokens");
-// Shape: [Seq], Dtype: int64
+// Shape: [Seq], Dtype: int32
TfLiteTensor* prefill_input_pos = prefill_runner->input_tensor("input_pos");
-// Shape: [Batch, Seq], Dtype: int64
+// Shape: [Batch, Seq], Dtype: int32
TfLiteTensor* decode_input = decode_runner->input_tensor("tokens");
-// Shape: [Seq], Dtype: int64
+// Shape: [Seq], Dtype: int32
TfLiteTensor* decode_input_pos = decode_runner->input_tensor("input_pos");
int max_seq_size = prefill_input->dims->data[1];

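Since the exported signatures now carry int32 tensors, runtime callers must fill int32 rather than int64 buffers. A sketch of driving the same signatures from Python; the model path and the signature names "prefill" and "decode" are assumptions for illustration, while the tensor names "tokens" and "input_pos" match the C++ above:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")  # hypothetical path
prefill_runner = interpreter.get_signature_runner("prefill")  # assumed signature name

seq_len = 512  # stand-in; real code reads it from the input tensor shape
tokens = np.zeros((1, seq_len), dtype=np.int32)   # np.int64 before this commit
input_pos = np.arange(seq_len, dtype=np.int32)    # np.int64 before this commit
outputs = prefill_runner(tokens=tokens, input_pos=input_pos)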
@@ -47,10 +47,10 @@ def convert_gemma2_to_tflite(
checkpoint_path, kv_cache_max_len=kv_cache_max_len
)
# Tensors used to trace the model graph during conversion.
-prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-prefill_input_pos = torch.arange(0, prefill_seq_len)
-decode_token = torch.tensor([[0]], dtype=torch.long)
-decode_input_pos = torch.tensor([0], dtype=torch.int64)
+prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+decode_token = torch.tensor([[0]], dtype=torch.int)
+decode_input_pos = torch.tensor([0], dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
8 changes: 4 additions & 4 deletions ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -47,10 +47,10 @@ def convert_gemma_to_tflite(
checkpoint_path, kv_cache_max_len=kv_cache_max_len
)
# Tensors used to trace the model graph during conversion.
-prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-prefill_input_pos = torch.arange(0, prefill_seq_len)
-decode_token = torch.tensor([[0]], dtype=torch.long)
-decode_input_pos = torch.tensor([0], dtype=torch.int64)
+prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+decode_token = torch.tensor([[0]], dtype=torch.int)
+decode_input_pos = torch.tensor([0], dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/gemma/gemma.py
@@ -203,9 +203,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
kv_cache_max_len = 1024
model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
-input_pos = torch.arange(0, kv_cache_max_len)
+input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
print("comparing with goldens..")
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -280,9 +280,9 @@ def define_and_run_2b(checkpoint_path: str) -> None:
toks = torch.from_numpy(
np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
)
-tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :9] = toks
-input_pos = torch.arange(0, kv_cache_max_len)
+input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
out = model.forward(tokens, input_pos, kv)
out_final = out["logits"][0, 8, :]
@@ -47,10 +47,10 @@ def convert_openelm_to_tflite(
checkpoint_path, kv_cache_max_len=kv_cache_max_len
)
# Tensors used to trace the model graph during conversion.
-prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-prefill_input_pos = torch.arange(0, prefill_seq_len)
-decode_token = torch.tensor([[0]], dtype=torch.long)
-decode_input_pos = torch.tensor([0], dtype=torch.int64)
+prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+decode_token = torch.tensor([[0]], dtype=torch.int)
+decode_input_pos = torch.tensor([0], dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/openelm/openelm.py
@@ -220,9 +220,9 @@ def define_and_run(checkpoint_path: str) -> None:
kv_cache_max_len = 1024
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
-input_pos = torch.arange(0, kv_cache_max_len)
+input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
assert torch.allclose(
8 changes: 4 additions & 4 deletions ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -47,10 +47,10 @@ def convert_phi2_to_tflite(
checkpoint_path, kv_cache_max_len=kv_cache_max_len
)
# Tensors used to trace the model graph during conversion.
-prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-prefill_input_pos = torch.arange(0, prefill_seq_len)
-decode_token = torch.tensor([[0]], dtype=torch.long)
-decode_input_pos = torch.tensor([0], dtype=torch.int64)
+prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+decode_token = torch.tensor([[0]], dtype=torch.int)
+decode_input_pos = torch.tensor([0], dtype=torch.int)
kv = kv_cache.KVCache.from_model_config(pytorch_model.config)

quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/phi/phi2.py
@@ -192,9 +192,9 @@ def define_and_run(checkpoint_path: str) -> None:
kv_cache_max_len = 1024
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
-input_pos = torch.arange(0, kv_cache_max_len)
+input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
print("comparing with goldens..")
8 changes: 4 additions & 4 deletions ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -47,10 +47,10 @@ def convert_smollm_to_tflite(
checkpoint_path, kv_cache_max_len=kv_cache_max_len
)
# Tensors used to trace the model graph during conversion.
-prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-prefill_input_pos = torch.arange(0, prefill_seq_len)
-decode_token = torch.tensor([[0]], dtype=torch.long)
-decode_input_pos = torch.tensor([0], dtype=torch.int64)
+prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+decode_token = torch.tensor([[0]], dtype=torch.int)
+decode_input_pos = torch.tensor([0], dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/smollm/smollm.py
@@ -114,9 +114,9 @@ def define_and_run(checkpoint_path: str) -> None:
kv_cache_max_len = 1024
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
-input_pos = torch.arange(0, kv_cache_max_len)
+input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
assert torch.allclose(
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -76,7 +76,7 @@ def __init__(self, config: cfg.ModelConfig):

@torch.inference_mode
def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-tokens = tokens.type(torch.long)
+tokens = tokens.type(torch.int)

state = self.tok_embedding(tokens) + self.tok_embedding_position
for layer in self.transformer_blocks:
@@ -94,7 +94,7 @@ def convert_stable_diffusion_to_tflite(
n_tokens = 77
timestamp = 0
len_prompt = 1
-prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
+prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
input_image = torch.full(
(1, 3, image_height, image_width), 0, dtype=torch.float32
)
40 changes: 20 additions & 20 deletions ai_edge_torch/generative/examples/t5/convert_to_tflite.py
@@ -29,24 +29,24 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):

# encoder
seq_len = 512
-prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
prompt_e_token = [1, 2, 3, 4, 5, 6]
prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-prompt_e_token, dtype=torch.long
+prompt_e_token, dtype=torch.int
)
-prefill_e_input_pos = torch.arange(0, seq_len)
-prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
prompt_d_token = [1, 2, 3, 4, 5, 6]
prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-prompt_d_token, dtype=torch.long
+prompt_d_token, dtype=torch.int
)
-prefill_d_input_pos = torch.arange(0, seq_len)
+prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)

# decoder
-decode_token = torch.tensor([[1]], dtype=torch.long)
-decode_input_pos = torch.tensor([0], dtype=torch.int64)
-decode_d_token = torch.tensor([[1]], dtype=torch.long)
-decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+decode_token = torch.tensor([[1]], dtype=torch.int)
+decode_input_pos = torch.tensor([0], dtype=torch.int)
+decode_d_token = torch.tensor([[1]], dtype=torch.int)
+decode_d_input_pos = torch.tensor([0], dtype=torch.int)

# Pad mask for self attention only on "real" tokens.
# Pad with `-inf` for any tokens indices that aren't desired.
@@ -81,24 +81,24 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):

# encoder
seq_len = 512
-prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
prompt_e_token = [1, 2, 3, 4, 5, 6]
prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-prompt_e_token, dtype=torch.long
+prompt_e_token, dtype=torch.int
)
-prefill_e_input_pos = torch.arange(0, seq_len)
-prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
prompt_d_token = [1, 2, 3, 4, 5, 6]
prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-prompt_d_token, dtype=torch.long
+prompt_d_token, dtype=torch.int
)
-prefill_d_input_pos = torch.arange(0, seq_len)
+prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)

# decoder
-decode_token = torch.tensor([[1]], dtype=torch.long)
-decode_input_pos = torch.tensor([0], dtype=torch.int64)
-decode_d_token = torch.tensor([[1]], dtype=torch.long)
-decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+decode_token = torch.tensor([[1]], dtype=torch.int)
+decode_input_pos = torch.tensor([0], dtype=torch.int)
+decode_d_token = torch.tensor([[1]], dtype=torch.int)
+decode_d_input_pos = torch.tensor([0], dtype=torch.int)

# Pad mask for self attention only on "real" tokens.
# Pad with `-inf` for any tokens indices that aren't desired.
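The pad mask referenced in the comments above is zeros over real token positions and -inf everywhere else, so padded positions are suppressed by the softmax in attention. A minimal sketch with illustrative names, mirroring the construction in t5.py below:

import torch

kv_cache_max = 512     # stand-in cache length
num_real_tokens = 6    # e.g. len(prompt_e_token)

pad_mask = torch.zeros([kv_cache_max], dtype=torch.float32)
pad_mask[num_real_tokens:] = float("-inf")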
16 changes: 8 additions & 8 deletions ai_edge_torch/generative/examples/t5/t5.py
@@ -601,12 +601,12 @@ def define_and_run_t5(checkpoint_path: str) -> None:
model = build_t5_model(checkpoint_path)

idx = get_sample_encoder_input_ids()
-tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
tokens[0, :77] = idx
-input_pos = torch.arange(0, 512)
+input_pos = torch.arange(0, 512, dtype=torch.int)

-decode_d_token = torch.tensor([[0]], dtype=torch.int64)
-decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+decode_d_token = torch.tensor([[0]], dtype=torch.int)
+decode_d_input_pos = torch.tensor([0], dtype=torch.int)
pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
pad_mask[77:] = float("-inf")
lm_logits = model.forward(
@@ -633,12 +633,12 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
)
idx = get_sample_encoder_input_ids()

-tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
tokens[0, :77] = idx
-input_pos = torch.arange(0, 512)
+input_pos = torch.arange(0, 512, dtype=torch.int)

-decode_d_token = torch.tensor([[0]], dtype=torch.int64)
-decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+decode_d_token = torch.tensor([[0]], dtype=torch.int)
+decode_d_input_pos = torch.tensor([0], dtype=torch.int)
pad_mask = torch.zeros(
[t5_encoder_model.config.kv_cache_max], dtype=torch.float32
)
20 changes: 10 additions & 10 deletions ai_edge_torch/generative/examples/t5/t5_conversion_colab.ipynb
@@ -122,24 +122,24 @@
"\n",
" # encoder\n",
" seq_len = 512\n",
" prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)\n",
" prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)\n",
" prompt_e_token = [1, 2, 3, 4, 5, 6]\n",
" prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(\n",
" prompt_e_token, dtype=torch.long\n",
" prompt_e_token, dtype=torch.int\n",
" )\n",
" prefill_e_input_pos = torch.arange(0, seq_len)\n",
" prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)\n",
" prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)\n",
" prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)\n",
" prompt_d_token = [1, 2, 3, 4, 5, 6]\n",
" prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(\n",
" prompt_d_token, dtype=torch.long\n",
" prompt_d_token, dtype=torch.int\n",
" )\n",
" prefill_d_input_pos = torch.arange(0, seq_len)\n",
" prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)\n",
"\n",
" # decoder\n",
" decode_token = torch.tensor([[1]], dtype=torch.long)\n",
" decode_input_pos = torch.tensor([0], dtype=torch.int64)\n",
" decode_d_token = torch.tensor([[1]], dtype=torch.long)\n",
" decode_d_input_pos = torch.tensor([0], dtype=torch.int64)\n",
" decode_token = torch.tensor([[1]], dtype=torch.int)\n",
" decode_input_pos = torch.tensor([0], dtype=torch.int)\n",
" decode_d_token = torch.tensor([[1]], dtype=torch.int)\n",
" decode_d_input_pos = torch.tensor([0], dtype=torch.int)\n",
"\n",
" # Pad mask for self attention only on \"real\" tokens.\n",
" # Pad with `-inf` for any tokens indices that aren't desired.\n",
@@ -124,13 +124,13 @@ def get_model_config() -> cfg.ModelConfig:


def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-tokens = torch.unsqueeze(torch.arange(0, 100), 0)
-input_pos = torch.arange(0, 100)
+tokens = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+input_pos = torch.arange(0, 100, dtype=torch.int)
return tokens, input_pos


def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-tokens = torch.tensor([[1]], dtype=torch.long)
+tokens = torch.tensor([[1]], dtype=torch.int)
input_pos = torch.tensor([10])
return tokens, input_pos

@@ -47,10 +47,10 @@ def convert_tiny_llama_to_tflite(
checkpoint_path, kv_cache_max_len=kv_cache_max_len
)
# Tensors used to trace the model graph during conversion.
-prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-prefill_input_pos = torch.arange(0, prefill_seq_len)
-decode_token = torch.tensor([[0]], dtype=torch.long)
-decode_input_pos = torch.tensor([0], dtype=torch.int64)
+prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+decode_token = torch.tensor([[0]], dtype=torch.int)
+decode_input_pos = torch.tensor([0], dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -189,9 +189,9 @@ def define_and_run(checkpoint_path: str) -> None:
kv_cache_max_len = 1024
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
-input_pos = torch.arange(0, kv_cache_max_len)
+input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
assert torch.allclose(
8 changes: 4 additions & 4 deletions ai_edge_torch/generative/layers/kv_cache.py
@@ -172,8 +172,8 @@ def _update_kv_base_impl(
v_slice: torch.Tensor,
) -> KVCacheEntry:
"""Update the cache buffer without High Level Function Boundary annotation."""
-k = cache.k_cache.index_copy(1, input_pos, k_slice)
-v = cache.v_cache.index_copy(1, input_pos, v_slice)
+k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
updated_cache = KVCacheEntry(k, v)
return updated_cache

@@ -189,7 +189,7 @@ def _update_kv_hlfb_impl(
k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
)
-k = k_cache.index_copy(1, input_pos, k_slice)
-v = v_cache.index_copy(1, input_pos, v_slice)
+k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
k, v = builder.mark_outputs(k, v)
return KVCacheEntry(k, v)
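
The new .to(torch.long) casts exist because Tensor.index_copy expects its index argument as a 64-bit LongTensor; with input_pos now traced as int32, positions are widened back to long only for the cache scatter. A minimal sketch with illustrative shapes:

import torch

k_cache = torch.zeros(1, 8, 2, 4)   # [batch, kv_len, heads, head_dim]
k_slice = torch.ones(1, 2, 2, 4)    # new keys for two positions
input_pos = torch.tensor([3, 4], dtype=torch.int)  # int32, as traced

# index_copy wants int64 indices, hence the explicit widening.
k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
assert torch.equal(k[:, 3:5], k_slice)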
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/quantize/example.py
@@ -25,9 +25,9 @@ def main():
config = gemma.get_fake_model_config()
model = gemma.Gemma(config)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
-input_pos = torch.arange(0, 10)
+input_pos = torch.arange(0, 10, dtype=torch.int)

# Create a quantization recipe to be applied to the model
quant_config = quant_recipes.full_int8_dynamic_recipe()
(Diffs for the remaining changed files are not shown.)
