Support passing attention mask as optional input in text_generator_main.cc.

PiperOrigin-RevId: 706832231
haozha111 authored and copybara-github committed Dec 16, 2024
1 parent 0704751 commit b2342c2
Showing 2 changed files with 40 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/cpp/BUILD
@@ -49,7 +49,7 @@ cc_binary(
         "@com_google_absl//absl/strings",
         "@com_google_sentencepiece//:sentencepiece_processor",
         "@org_tensorflow//tensorflow/lite:framework",
-        "@org_tensorflow//tensorflow/lite:util",
+        "@org_tensorflow//tensorflow/lite/c:common",
         "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
         "@org_tensorflow//tensorflow/lite/experimental/genai:genai_ops",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
43 changes: 39 additions & 4 deletions ai_edge_torch/generative/examples/cpp/text_generator_main.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "ai_edge_torch/generative/examples/cpp/utils.h"
 #include "src/sentencepiece_processor.h"
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
 #include "tensorflow/lite/experimental/genai/genai_ops.h"
 #include "tensorflow/lite/interpreter.h"
@@ -182,8 +183,11 @@ tflite::SignatureRunner* GetPrefillRunner(
     }
     TfLiteTensor* input_pos = interpreter->GetSignatureRunner(key->c_str())
                                   ->input_tensor("input_pos");
-    // The expected shape for input position is [Seq].
-    int seq_size = input_pos->dims->data[0];
+    // The expected shape for input position is [Seq] (from ai_edge_torch) or
+    // [Batch, Seq] (from ai_edge_jax).
+    MINIMAL_CHECK(input_pos->dims->size == 1 || input_pos->dims->size == 2);
+    int seq_size = input_pos->dims->size == 1 ? input_pos->dims->data[0]
+                                              : input_pos->dims->data[1];
     if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) {
       runner = interpreter->GetSignatureRunner(key->c_str());
       delta = seq_size - num_input_tokens;
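
The branch on dims->size above restates a simple rule: the sequence length always sits in the last dimension of input_pos, whichever exporter produced the model. A minimal sketch of that rule as a standalone helper (GetSeqSize is a hypothetical name, not part of the commit):

#include "tensorflow/lite/c/common.h"

// Hypothetical helper: returns the sequence length for either layout,
// [Seq] (rank 1, ai_edge_torch) or [Batch, Seq] (rank 2, ai_edge_jax),
// by reading the last dimension.
int GetSeqSize(const TfLiteTensor* input_pos) {
  return input_pos->dims->data[input_pos->dims->size - 1];
}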
@@ -229,6 +233,17 @@ int GreedySampler(const TfLiteTensor* logits) {
   return max_index;
 }
 
+// Scans through the input tensor names to check if the attention mask is
+// passed as an input tensor.
+bool AttentionMaskInInput(tflite::SignatureRunner* runner) {
+  for (int i = 0; i < runner->input_names().size(); ++i) {
+    if (strcmp(runner->input_names()[i], "attention_mask") == 0) {
+      return true;
+    }
+  }
+  return false;
+}
+
 }  // namespace
 
 int main(int argc, char* argv[]) {
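
For reference, the linear scan in AttentionMaskInInput can also be written with std::any_of. A behavior-preserving sketch (HasAttentionMaskInput is a hypothetical name; it assumes the same TFLite includes as the file above, and that SignatureRunner::input_names() returns a container of C strings, as the loop above implies):

#include <algorithm>
#include <cstring>

// Sketch: same check as AttentionMaskInInput, expressed with std::any_of.
bool HasAttentionMaskInput(tflite::SignatureRunner* runner) {
  const auto& names = runner->input_names();
  return std::any_of(names.begin(), names.end(), [](const char* name) {
    return std::strcmp(name, "attention_mask") == 0;
  });
}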
@@ -269,15 +284,25 @@ int main(int argc, char* argv[]) {
       GetDecodeRunner(interpreter.get(), kv_cache);
   MINIMAL_CHECK(decode_runner != nullptr);
 
+  // Check if the attention mask is passed as an input tensor.
+  bool attention_mask_as_input = AttentionMaskInInput(prefill_runner);
   // Get Input Tensors for each of the runners.
   // Shape: [Batch, Seq], Dtype: int32
   TfLiteTensor* prefill_input = prefill_runner->input_tensor("tokens");
-  // Shape: [Seq], Dtype: int32
+  // Shape: [Seq] or [Batch, Seq], Dtype: int32
   TfLiteTensor* prefill_input_pos = prefill_runner->input_tensor("input_pos");
+  // Shape: [Batch, 1, Seq], Dtype: bool
+  TfLiteTensor* prefill_input_mask =
+      attention_mask_as_input ? prefill_runner->input_tensor("attention_mask")
+                              : nullptr;
   // Shape: [Batch, Seq], Dtype: int32
   TfLiteTensor* decode_input = decode_runner->input_tensor("tokens");
   // Shape: [Seq], Dtype: int32
   TfLiteTensor* decode_input_pos = decode_runner->input_tensor("input_pos");
+  // Shape: [Batch, 1, Seq], Dtype: bool
+  TfLiteTensor* decode_input_mask =
+      attention_mask_as_input ? decode_runner->input_tensor("attention_mask")
+                              : nullptr;
   // shape: [Batch, kv_cache_max, num_query_groups, head_dim]
   TfLiteTensor* kv_cache_k_0 = decode_runner->input_tensor("kv_cache_k_0");

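Because the mask tensors are optional, a caller could also verify their type and rank before writing into them. A defensive sketch, not part of the commit (kTfLiteBool comes from tensorflow/lite/c/common.h; MINIMAL_CHECK is the macro this file already uses; the fragment slots into main right after the lookups above):

  // Sketch: guard against a model whose attention_mask has an unexpected
  // layout. Assumes [Batch, 1, Seq] with a bool buffer, as commented above.
  if (prefill_input_mask != nullptr) {
    MINIMAL_CHECK(prefill_input_mask->type == kTfLiteBool);
    MINIMAL_CHECK(prefill_input_mask->dims->size == 3);
  }
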
@@ -290,9 +315,12 @@ int main(int argc, char* argv[]) {
       std::min(static_cast<int>(prompt_tokens.size()), max_seq_size);
   std::memset(prefill_input->data.i32, 0, prefill_input->bytes);
   std::memset(prefill_input_pos->data.i32, 0, prefill_input_pos->bytes);
-  for (int i = 0; i < prefill_seq_size - 1; ++i) {
+  if (prefill_input_mask)
+    std::memset(prefill_input_mask->data.b, 0, prefill_input_mask->bytes);
+  for (int i = 0; i < prefill_seq_size; ++i) {
     prefill_input->data.i32[i] = prompt_tokens[i];
     prefill_input_pos->data.i32[i] = i;
+    if (prefill_input_mask) prefill_input_mask->data.b[i] = true;
   }
   MINIMAL_CHECK(prefill_runner->Invoke() == kTfLiteOk);

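To make the prefill state concrete, here is what the input buffers hold after the loop for a hypothetical four-token prompt (prefill_seq_size == 4), writing 1 for true:

// prefill_input->data.i32     = {t0, t1, t2, t3, 0, 0, ...}  // token ids
// prefill_input_pos->data.i32 = { 0,  1,  2,  3, 0, 0, ...}  // positions
// prefill_input_mask->data.b  = { 1,  1,  1,  1, 0, 0, ...}  // valid prefix
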
@@ -305,13 +333,20 @@ int main(int argc, char* argv[]) {
       std::min(max_decode_steps, kv_cache_max_size - prefill_seq_size);
   MINIMAL_CHECK(decode_steps > 0);
 
+  if (decode_input_mask) {
+    std::memset(decode_input_mask->data.b, 0, decode_input_mask->bytes);
+    std::memcpy(decode_input_mask->data.b, prefill_input_mask->data.b,
+                sizeof(bool) * prefill_seq_size);
+  }
   std::vector<int> output_tokens;
   output_tokens.reserve(decode_steps);
   int next_token = prompt_tokens[prefill_seq_size - 1];
   int next_position = prefill_seq_size - 1;
   for (int i = 0; i < decode_steps; ++i) {
     decode_input->data.i32[0] = next_token;
     decode_input_pos->data.i32[0] = next_position;
+    if (decode_input_mask) decode_input_mask->data.b[next_position] = true;
 
     MINIMAL_CHECK(decode_runner->Invoke() == kTfLiteOk);
     next_token = GreedySampler(decode_runner->output_tensor("logits"));
     output_tokens.push_back(next_token);

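Taken together, the pattern this commit introduces is: probe the signature for an "attention_mask" input once, keep a nullptr handle when it is absent, and guard every mask access so models exported without a mask keep working unchanged. A condensed sketch of that pattern (runner and position are placeholders, not variables from the code above):

// Condensed pattern: the mask input is optional, so every access is guarded.
TfLiteTensor* mask = AttentionMaskInInput(runner)
                         ? runner->input_tensor("attention_mask")
                         : nullptr;
if (mask) mask->data.b[position] = true;  // mark this position as visible
MINIMAL_CHECK(runner->Invoke() == kTfLiteOk);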