diff --git a/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py b/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
index 1a2c4925..a7e0c7fd 100644
--- a/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -45,9 +45,9 @@ def convert_gemma_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
   prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)

   quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
diff --git a/ai_edge_torch/generative/examples/gemma/gemma.py b/ai_edge_torch/generative/examples/gemma/gemma.py
index a6c7e850..7346a483 100644
--- a/ai_edge_torch/generative/examples/gemma/gemma.py
+++ b/ai_edge_torch/generative/examples/gemma/gemma.py
@@ -163,7 +163,7 @@ def define_and_run_2b() -> None:
   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
   print("running an inference")
diff --git a/ai_edge_torch/generative/examples/phi2/convert_to_tflite.py b/ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
index f6387554..30d45d51 100644
--- a/ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
@@ -43,9 +43,9 @@ def convert_phi2_to_tflite(
   """
   pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
   prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)

   quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
diff --git a/ai_edge_torch/generative/examples/phi2/phi2.py b/ai_edge_torch/generative/examples/phi2/phi2.py
index 8bb31ba2..5af33e1a 100644
--- a/ai_edge_torch/generative/examples/phi2/phi2.py
+++ b/ai_edge_torch/generative/examples/phi2/phi2.py
@@ -153,7 +153,7 @@ def define_and_run() -> None:
   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
   print("running an inference")
diff --git a/ai_edge_torch/generative/examples/stable_diffusion/clip.py b/ai_edge_torch/generative/examples/stable_diffusion/clip.py
index 17fe2f4a..c4a09b62 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/clip.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -60,8 +60,8 @@ def __init__(self, config: cfg.ModelConfig):
     )

   @torch.inference_mode
-  def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-    tokens = tokens.type(torch.long)
+  def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
+    tokens = tokens.type(torch.int)
     state = self.tok_embedding(tokens) + self.tok_embedding_position

     for layer in self.transformer_blocks:
diff --git a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
index 0a45f6bc..49fc5d0a 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -61,7 +61,7 @@ def convert_stable_diffusion_to_tflite(
   n_tokens = 77
   timestamp = 0
   len_prompt = 1
-  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
+  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
   input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
   noise = torch.full(
       (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
diff --git a/ai_edge_torch/generative/examples/t5/convert_to_tflite.py b/ai_edge_torch/generative/examples/t5/convert_to_tflite.py
index 3b49a7c3..f115dbab 100644
--- a/ai_edge_torch/generative/examples/t5/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/t5/convert_to_tflite.py
@@ -30,23 +30,23 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):

   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.long
+      prompt_e_token, dtype=torch.int
   )
   prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.long
+      prompt_d_token, dtype=torch.int
   )
   prefill_d_input_pos = torch.arange(0, seq_len)

   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.long)
+  decode_token = torch.tensor([[1]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  decode_d_token = torch.tensor([[1]], dtype=torch.long)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
   decode_d_input_pos = torch.tensor([0], dtype=torch.int64)

   # Pad mask for self attention only on "real" tokens.
@@ -78,23 +78,23 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):

   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.long
+      prompt_e_token, dtype=torch.int
   )
   prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.long
+      prompt_d_token, dtype=torch.int
   )
   prefill_d_input_pos = torch.arange(0, seq_len)

   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.long)
+  decode_token = torch.tensor([[1]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  decode_d_token = torch.tensor([[1]], dtype=torch.long)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
   decode_d_input_pos = torch.tensor([0], dtype=torch.int64)

   # Pad mask for self attention only on "real" tokens.
diff --git a/ai_edge_torch/generative/examples/t5/t5.py b/ai_edge_torch/generative/examples/t5/t5.py
index 9dfc503b..cb1bfe09 100644
--- a/ai_edge_torch/generative/examples/t5/t5.py
+++ b/ai_edge_torch/generative/examples/t5/t5.py
@@ -562,7 +562,7 @@ def define_and_run_t5(checkpoint_path: str) -> None:
   model = build_t5_model(checkpoint_path)

   idx = get_sample_encoder_input_ids()
-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
   input_pos = torch.arange(0, 512)

@@ -586,7 +586,7 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
   t5_decoder_model = build_t5_decoder_model(config, embedding_layer, checkpoint_path)

   idx = get_sample_encoder_input_ids()
-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
   input_pos = torch.arange(0, 512)

diff --git a/ai_edge_torch/generative/examples/test_models/toy_model.py b/ai_edge_torch/generative/examples/test_models/toy_model.py
index 6e1be6c2..3dd56a99 100644
--- a/ai_edge_torch/generative/examples/test_models/toy_model.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -93,7 +93,7 @@ def define_and_run() -> None:
   )

   model = ToySingleLayerModel(config)
-  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN, dtype=torch.int), 0)
   input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
   print('running an inference')
   print(
diff --git a/ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py b/ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py
index 5c56a534..a52dc5dc 100644
--- a/ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py
@@ -115,13 +115,13 @@ def get_model_config() -> cfg.ModelConfig:


 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
   input_pos = torch.arange(0, 100)
   return idx, input_pos


 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.tensor([[1]], dtype=torch.long)
+  idx = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return idx, input_pos

diff --git a/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py b/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
index 2ee42234..2c047289 100644
--- a/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -103,13 +103,13 @@ def get_model_config() -> cfg.ModelConfig:


 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
   input_pos = torch.arange(0, 100)
   return idx, input_pos


 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.tensor([[1]], dtype=torch.long)
+  idx = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10], dtype=torch.int64)
   return idx, input_pos

diff --git a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
index 21f1ae20..e607e0b2 100644
--- a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -45,9 +45,9 @@ def convert_tiny_llama_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
   prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)

   quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
diff --git a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
index 08314a1a..69cebed8 100644
--- a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -153,7 +153,7 @@ def define_and_run() -> None:
   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
   print("running an inference")
diff --git a/ai_edge_torch/generative/quantize/example.py b/ai_edge_torch/generative/quantize/example.py
index 24ca0a8d..1af859bc 100644
--- a/ai_edge_torch/generative/quantize/example.py
+++ b/ai_edge_torch/generative/quantize/example.py
@@ -26,7 +26,7 @@ def main():
   config = gemma.get_fake_model_config_2b_for_test()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, 10)

diff --git a/ai_edge_torch/generative/test/test_model_conversion.py b/ai_edge_torch/generative/test/test_model_conversion.py
index 66cbcd95..c8c69b83 100644
--- a/ai_edge_torch/generative/test/test_model_conversion.py
+++ b/ai_edge_torch/generative/test/test_model_conversion.py
@@ -35,7 +35,7 @@ class TestModelConversion(unittest.TestCase):
   def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )

@@ -59,7 +59,7 @@ def test_toy_model_with_multi_batches(self):
     config = toy_model_with_kv_cache.get_model_config()
     config.batch_size = 2
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1], [2]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )

@@ -83,7 +83,7 @@ def test_toy_model_with_kv_cache_with_hlfb(self):
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )

@@ -109,7 +109,7 @@ def test_tiny_llama(self):
     pytorch_model = tiny_llama.TinyLLamma(config)

     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)

@@ -135,13 +135,13 @@ def test_tiny_llama_multisig(self):

     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
+    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
     prefill_input_pos = torch.arange(0, seq_len)

     # decode
-    decode_token = torch.tensor([[1]], dtype=torch.long)
+    decode_token = torch.tensor([[1]], dtype=torch.int)
     decode_input_pos = torch.tensor([5], dtype=torch.int64)

     edge_model = (
@@ -183,7 +183,7 @@ def test_gemma(self):
     model = gemma.Gemma(config)

     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)

@@ -210,7 +210,7 @@ def test_phi2(self):
     pytorch_model = phi2.Phi2(config)

     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)

diff --git a/ai_edge_torch/generative/test/test_quantize.py b/ai_edge_torch/generative/test/test_quantize.py
index 8a97b874..1bb7e489 100644
--- a/ai_edge_torch/generative/test/test_quantize.py
+++ b/ai_edge_torch/generative/test/test_quantize.py
@@ -119,7 +119,7 @@ def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
       self.skipTest("b/346896669")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )

@@ -137,7 +137,7 @@ def test_quantize_convert_compare_toy(self):
      self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
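
The patch settles the example and test inputs on one dtype convention: token-id tensors use torch.int (an alias for torch.int32, replacing torch.long, i.e. int64), while position tensors keep the dtypes they already had (torch.int64 where the examples spelled it out, or PyTorch's default int64 from torch.arange). The sketch below is illustrative only, not part of the patch; it reuses the tensor names from the convert_to_tflite examples above to show the resulting tracing-input shapes and dtypes.

import torch

prefill_seq_len = 512  # illustrative; same default prefill length as the examples

# Token ids are int32 after this change; positions stay int64.
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
prefill_input_pos = torch.arange(0, prefill_seq_len)        # int64 by default
decode_token = torch.tensor([[0]], dtype=torch.int)
decode_input_pos = torch.tensor([0], dtype=torch.int64)

assert prefill_tokens.dtype == torch.int32 and decode_token.dtype == torch.int32
assert prefill_input_pos.dtype == torch.int64 and decode_input_pos.dtype == torch.int64

Since the exported signatures take their dtypes from these tracing tensors, switching the example tokens to int32 presumably keeps 64-bit integer token inputs out of the converted models; int32 comfortably covers the vocabulary sizes involved, and the embedding lookup accepts either integer type.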