Skip to content

Commit

Permalink
Update max decode steps in text generator example to use KV max size.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696523752
  • Loading branch information
hheydary authored and copybara-github committed Nov 14, 2024
1 parent d7c9e8d commit ff85c8d
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions ai_edge_torch/generative/examples/cpp/text_generator_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ ABSL_FLAG(std::string, tflite_model, "",
ABSL_FLAG(std::string, sentencepiece_model, "", "Path to sentencepiece model.");
ABSL_FLAG(std::string, prompt, "Write an email:", "Input prompt to the model.");
ABSL_FLAG(int, max_decode_steps, -1,
"The number of tokens to generate. Defaults to maximum Sequence size "
"The number of tokens to generate. Defaults to the KV cache size "
"defined during conversion.");
ABSL_FLAG(std::string, start_token, "",
"Start token is appended to the beginning of input prompt to "
"signify start of sentence.");
ABSL_FLAG(std::string, stop_token, "",
"Stop token used to deterine end of decoding loop. If not provided "
"will decode until max_Seq_len or max_decode_steps.");
"will decode until max_kv_cache_size or max_decode_steps.");
ABSL_FLAG(int, num_threads, 4, "Number of threads to use. Defaults to 4.");
ABSL_FLAG(std::string, weight_cache_path, "",
"XNNPACK weight caching path, e.g. /tmp/model.xnnpack_cache.");
Expand Down Expand Up @@ -258,7 +258,11 @@ int main(int argc, char* argv[]) {
TfLiteTensor* decode_input = decode_runner->input_tensor("tokens");
// Shape: [Seq], Dtype: int32
TfLiteTensor* decode_input_pos = decode_runner->input_tensor("input_pos");
// shape: [Batch, kv_cache_max, num_query_groups, head_dim]
TfLiteTensor* kv_cache_k_0 = decode_runner->input_tensor("kv_cache_k_0");

int max_seq_size = prefill_input->dims->data[1];
int kv_cache_max_size = kv_cache_k_0->dims->data[1];

// Tokenize the input prompt.
std::string prompt = absl::GetFlag(FLAGS_prompt);
Expand Down Expand Up @@ -286,16 +290,13 @@ int main(int argc, char* argv[]) {
}
TFLITE_MINIMAL_CHECK(prefill_runner->Invoke() == kTfLiteOk);

// Decode until max sequence size or user defined step limit, whichever is
// Decode until max kv-cache size or user defined step limit, whichever is
// smaller.
// NOTE: max kv-cache size is *not* necessarily the same size as the max
// sequence length. KV Cache buffer wraps around if exahusted before max
// sequence length or stopping criteria reach.
int max_decode_steps = absl::GetFlag(FLAGS_max_decode_steps) == -1
? max_seq_size
? kv_cache_max_size
: absl::GetFlag(FLAGS_max_decode_steps);
int decode_steps =
std::min(max_decode_steps, max_seq_size - prefill_seq_size);
std::min(max_decode_steps, kv_cache_max_size - prefill_seq_size);
TFLITE_MINIMAL_CHECK(decode_steps > 0);

std::vector<int> output_tokens;
Expand Down

0 comments on commit ff85c8d

Please sign in to comment.