Skip to content

Commit

Permalink
Add a few changes of int64->int32.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675633068
  • Loading branch information
haozha111 authored and copybara-github committed Sep 17, 2024
1 parent e50f7a7 commit 0d481c4
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ai_edge_torch/generative/examples/cpp/text_generator_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ int main(int argc, char* argv[]) {
int prefill_seq_size =
std::min(static_cast<int>(prompt_tokens.size()), max_seq_size);
for (int i = 0; i < prefill_seq_size - 1; ++i) {
prefill_input->data.i64[i] = prompt_tokens[i];
prefill_input_pos->data.i64[i] = i;
prefill_input->data.i32[i] = prompt_tokens[i];
prefill_input_pos->data.i32[i] = i;
}
TFLITE_MINIMAL_CHECK(prefill_runner->Invoke() == kTfLiteOk);

Expand All @@ -274,8 +274,8 @@ int main(int argc, char* argv[]) {
int next_token = prompt_tokens[prefill_seq_size - 1];
int next_position = prefill_seq_size - 1;
for (int i = 0; i < decode_steps; ++i) {
decode_input->data.i64[0] = next_token;
decode_input_pos->data.i64[0] = next_position;
decode_input->data.i32[0] = next_token;
decode_input_pos->data.i32[0] = next_position;
TFLITE_MINIMAL_CHECK(decode_runner->Invoke() == kTfLiteOk);
next_token = GreedySampler(decode_runner->output_tensor("logits"));
output_tokens.push_back(next_token);
Expand Down

0 comments on commit 0d481c4

Please sign in to comment.