
Commit 2f133e4

Change tokens from i64 to i32 (#68)

BUG=b/348016917

paulinesho authored Jun 28, 2024 (1 parent: b5c7314)
Showing 16 changed files with 42 additions and 42 deletions.
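
For context on the commit message: torch.long is PyTorch's 64-bit integer dtype (an alias for torch.int64), while torch.int is 32-bit (torch.int32). Narrowing the trace-time token tensors to i32 means the exported TFLite signatures declare int32 token inputs, the width TFLite runtimes commonly use for token ids. A minimal before/after sketch (values illustrative, not from this commit):

    import torch

    # torch.long aliases torch.int64; torch.int aliases torch.int32.
    assert torch.long == torch.int64
    assert torch.int == torch.int32

    prefill_seq_len = 8  # illustrative length

    tokens_i64 = torch.full((1, prefill_seq_len), 0, dtype=torch.long)  # before
    tokens_i32 = torch.full((1, prefill_seq_len), 0, dtype=torch.int)   # after

    print(tokens_i64.dtype, tokens_i32.dtype)  # torch.int64 torch.int32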
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -45,9 +45,9 @@ def convert_gemma_to_tflite(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
  )
  # Tensors used to trace the model graph during conversion.
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
- decode_token = torch.tensor([[0]], dtype=torch.long)
+ decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)

  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None

2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/gemma/gemma.py
@@ -163,7 +163,7 @@ def define_and_run_2b() -> None:
  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
  tokens[0, :4] = idx
  input_pos = torch.arange(0, kv_cache_max_len)
  print("running an inference")

4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
@@ -43,9 +43,9 @@ def convert_phi2_to_tflite(
  """
  pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
  # Tensors used to trace the model graph during conversion.
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
- decode_token = torch.tensor([[0]], dtype=torch.long)
+ decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)

  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None

2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/phi2/phi2.py
@@ -153,7 +153,7 @@ def define_and_run() -> None:
  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
  tokens[0, :4] = idx
  input_pos = torch.arange(0, kv_cache_max_len)
  print("running an inference")

4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -60,8 +60,8 @@ def __init__(self, config: cfg.ModelConfig):
  )

  @torch.inference_mode
- def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-   tokens = tokens.type(torch.long)
+ def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
+   tokens = tokens.type(torch.int)

  state = self.tok_embedding(tokens) + self.tok_embedding_position
  for layer in self.transformer_blocks:
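
In this file the change is API-facing: forward's annotation moves from torch.LongTensor to torch.IntTensor, while the defensive tokens.type(torch.int) cast keeps callers that still hold i64 ids working. A hedged usage sketch (clip_model stands in for a constructed instance of this encoder; it is not a name from the diff):

    import torch

    # clip_model: an instance of the CLIP text encoder above (illustrative name).
    tokens = torch.full((1, 77), 0, dtype=torch.long)  # caller-side i64 ids still work
    embeddings = clip_model(tokens)  # forward() casts to torch.int before embedding
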
@@ -61,7 +61,7 @@ def convert_stable_diffusion_to_tflite(
  n_tokens = 77
  timestamp = 0
  len_prompt = 1
- prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
+ prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
  input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
  noise = torch.full(
      (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32

24 changes: 12 additions & 12 deletions ai_edge_torch/generative/examples/t5/convert_to_tflite.py
@@ -30,23 +30,23 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):

  # encoder
  seq_len = 512
- prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+ prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
  prompt_e_token = [1, 2, 3, 4, 5, 6]
  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-     prompt_e_token, dtype=torch.long
+     prompt_e_token, dtype=torch.int
  )
  prefill_e_input_pos = torch.arange(0, seq_len)
- prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+ prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
  prompt_d_token = [1, 2, 3, 4, 5, 6]
  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-     prompt_d_token, dtype=torch.long
+     prompt_d_token, dtype=torch.int
  )
  prefill_d_input_pos = torch.arange(0, seq_len)

  # decoder
- decode_token = torch.tensor([[1]], dtype=torch.long)
+ decode_token = torch.tensor([[1]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
- decode_d_token = torch.tensor([[1]], dtype=torch.long)
+ decode_d_token = torch.tensor([[1]], dtype=torch.int)
  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)

  # Pad mask for self attention only on "real" tokens.

@@ -78,23 +78,23 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):

  # encoder
  seq_len = 512
- prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+ prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
  prompt_e_token = [1, 2, 3, 4, 5, 6]
  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-     prompt_e_token, dtype=torch.long
+     prompt_e_token, dtype=torch.int
  )
  prefill_e_input_pos = torch.arange(0, seq_len)
- prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+ prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
  prompt_d_token = [1, 2, 3, 4, 5, 6]
  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-     prompt_d_token, dtype=torch.long
+     prompt_d_token, dtype=torch.int
  )
  prefill_d_input_pos = torch.arange(0, seq_len)

  # decoder
- decode_token = torch.tensor([[1]], dtype=torch.long)
+ decode_token = torch.tensor([[1]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
- decode_d_token = torch.tensor([[1]], dtype=torch.long)
+ decode_d_token = torch.tensor([[1]], dtype=torch.int)
  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)

  # Pad mask for self attention only on "real" tokens.

4 changes: 2 additions & 2 deletions ai_edge_torch/generative/examples/t5/t5.py
@@ -562,7 +562,7 @@ def define_and_run_t5(checkpoint_path: str) -> None:
  model = build_t5_model(checkpoint_path)

  idx = get_sample_encoder_input_ids()
- tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+ tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
  tokens[0, :77] = idx
  input_pos = torch.arange(0, 512)

@@ -586,7 +586,7 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
  t5_decoder_model = build_t5_decoder_model(config, embedding_layer, checkpoint_path)
  idx = get_sample_encoder_input_ids()

- tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+ tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
  tokens[0, :77] = idx
  input_pos = torch.arange(0, 512)

2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -93,7 +93,7 @@ def define_and_run() -> None:
  )

  model = ToySingleLayerModel(config)
- idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+ idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN, dtype=torch.int), 0)
  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
  print('running an inference')
  print(
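
One subtlety this hunk highlights: torch.arange with integer arguments defaults to int64, so the explicit dtype=torch.int is what actually narrows the traced input. A quick check:

    import torch

    print(torch.arange(0, 4).dtype)                   # torch.int64 (PyTorch's default)
    print(torch.arange(0, 4, dtype=torch.int).dtype)  # torch.int32
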
@@ -115,13 +115,13 @@ def get_model_config() -> cfg.ModelConfig:


  def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-   idx = torch.unsqueeze(torch.arange(0, 100), 0)
+   idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
    input_pos = torch.arange(0, 100)
    return idx, input_pos


  def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-   idx = torch.tensor([[1]], dtype=torch.long)
+   idx = torch.tensor([[1]], dtype=torch.int)
    input_pos = torch.tensor([10])
    return idx, input_pos

@@ -103,13 +103,13 @@ def get_model_config() -> cfg.ModelConfig:


  def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-   idx = torch.unsqueeze(torch.arange(0, 100), 0)
+   idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
    input_pos = torch.arange(0, 100)
    return idx, input_pos


  def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-   idx = torch.tensor([[1]], dtype=torch.long)
+   idx = torch.tensor([[1]], dtype=torch.int)
    input_pos = torch.tensor([10], dtype=torch.int64)
    return idx, input_pos

@@ -45,9 +45,9 @@ def convert_tiny_llama_to_tflite(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
  )
  # Tensors used to trace the model graph during conversion.
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
- decode_token = torch.tensor([[0]], dtype=torch.long)
+ decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)

  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None

2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -153,7 +153,7 @@ def define_and_run() -> None:
  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
  tokens[0, :4] = idx
  input_pos = torch.arange(0, kv_cache_max_len)
  print("running an inference")

2 changes: 1 addition & 1 deletion ai_edge_torch/generative/quantize/example.py
@@ -26,7 +26,7 @@ def main():
  config = gemma.get_fake_model_config_2b_for_test()
  model = gemma.Gemma(config)
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
- tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+ tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
  tokens[0, :4] = idx
  input_pos = torch.arange(0, 10)

16 changes: 8 additions & 8 deletions ai_edge_torch/generative/test/test_model_conversion.py
@@ -35,7 +35,7 @@ class TestModelConversion(unittest.TestCase):
  def test_toy_model_with_kv_cache(self):
    config = toy_model_with_kv_cache.get_model_config()
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-   idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+   idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
        [10], dtype=torch.int64
    )

@@ -59,7 +59,7 @@ def test_toy_model_with_multi_batches(self):
    config = toy_model_with_kv_cache.get_model_config()
    config.batch_size = 2
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-   idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
+   idx, input_pos = torch.tensor([[1], [2]], dtype=torch.int), torch.tensor(
        [10], dtype=torch.int64
    )

@@ -83,7 +83,7 @@ def test_toy_model_with_kv_cache_with_hlfb(self):
    config = toy_model_with_kv_cache.get_model_config()
    config.enable_hlfb = True
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-   idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+   idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
        [10], dtype=torch.int64
    )

@@ -109,7 +109,7 @@ def test_tiny_llama(self):
    pytorch_model = tiny_llama.TinyLLamma(config)

    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+   tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, 10)

@@ -135,13 +135,13 @@ def test_tiny_llama_multisig(self):

    # prefill
    seq_len = 10
-   prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
+   prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
    prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
    prefill_tokens[0, : len(prompt_token)] = prompt_token
    prefill_input_pos = torch.arange(0, seq_len)

    # decode
-   decode_token = torch.tensor([[1]], dtype=torch.long)
+   decode_token = torch.tensor([[1]], dtype=torch.int)
    decode_input_pos = torch.tensor([5], dtype=torch.int64)

    edge_model = (

@@ -183,7 +183,7 @@ def test_gemma(self):
    model = gemma.Gemma(config)

    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+   tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, 10)

@@ -210,7 +210,7 @@ def test_phi2(self):
    pytorch_model = phi2.Phi2(config)

    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+   tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, 10)
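
One way to confirm the change end to end is to inspect a converted model's input signature with the TFLite interpreter. A sketch, assuming a converted model has already been written to model.tflite (the path is illustrative, not from this commit):

    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path="model.tflite")
    for detail in interpreter.get_input_details():
        # After this change, token inputs should report numpy.int32 rather than int64.
        print(detail["name"], detail["dtype"])
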
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/test/test_quantize.py
@@ -119,7 +119,7 @@ def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
    self.skipTest("b/346896669")
    config = toy_model_with_kv_cache.get_model_config()
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-   idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+   idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
        [10], dtype=torch.int64
    )

@@ -137,7 +137,7 @@ def test_quantize_convert_compare_toy(self):
    self.skipTest("b/338288901")
    config = toy_model_with_kv_cache.get_model_config()
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-   idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+   idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
        [10], dtype=torch.int64
    )
