From 0d481c4d1ceba4fbbb10a3c8a53494ca56eb6caf Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Tue, 17 Sep 2024 11:02:00 -0700 Subject: [PATCH] Add a few changes of int64->int32. PiperOrigin-RevId: 675633068 --- .../generative/examples/cpp/text_generator_main.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ai_edge_torch/generative/examples/cpp/text_generator_main.cc b/ai_edge_torch/generative/examples/cpp/text_generator_main.cc index d57f3afa..b4b979b9 100644 --- a/ai_edge_torch/generative/examples/cpp/text_generator_main.cc +++ b/ai_edge_torch/generative/examples/cpp/text_generator_main.cc @@ -252,8 +252,8 @@ int main(int argc, char* argv[]) { int prefill_seq_size = std::min(static_cast(prompt_tokens.size()), max_seq_size); for (int i = 0; i < prefill_seq_size - 1; ++i) { - prefill_input->data.i64[i] = prompt_tokens[i]; - prefill_input_pos->data.i64[i] = i; + prefill_input->data.i32[i] = prompt_tokens[i]; + prefill_input_pos->data.i32[i] = i; } TFLITE_MINIMAL_CHECK(prefill_runner->Invoke() == kTfLiteOk); @@ -274,8 +274,8 @@ int main(int argc, char* argv[]) { int next_token = prompt_tokens[prefill_seq_size - 1]; int next_position = prefill_seq_size - 1; for (int i = 0; i < decode_steps; ++i) { - decode_input->data.i64[0] = next_token; - decode_input_pos->data.i64[0] = next_position; + decode_input->data.i32[0] = next_token; + decode_input_pos->data.i32[0] = next_position; TFLITE_MINIMAL_CHECK(decode_runner->Invoke() == kTfLiteOk); next_token = GreedySampler(decode_runner->output_tensor("logits")); output_tokens.push_back(next_token);