From 97424fe2ed66abc884b9ba70c2be3b4b71a6bdc3 Mon Sep 17 00:00:00 2001
From: Mohammadreza Heydary
Date: Mon, 2 Dec 2024 10:54:21 -0800
Subject: [PATCH] Update cpp example to use multi-prefill signature.

PiperOrigin-RevId: 702020720
---
 ai_edge_torch/generative/examples/cpp/BUILD   |  10 ++
 .../examples/cpp/text_generator_main.cc       | 167 +++++++++---------
 ai_edge_torch/generative/examples/cpp/utils.h |  44 +++++
 3 files changed, 139 insertions(+), 82 deletions(-)
 create mode 100644 ai_edge_torch/generative/examples/cpp/utils.h

diff --git a/ai_edge_torch/generative/examples/cpp/BUILD b/ai_edge_torch/generative/examples/cpp/BUILD
index 9e77d59b..6c3c167e 100644
--- a/ai_edge_torch/generative/examples/cpp/BUILD
+++ b/ai_edge_torch/generative/examples/cpp/BUILD
@@ -22,6 +22,14 @@ package(
     default_visibility = ["//visibility:public"],
 )
 
+cc_library(
+    name = "utils",
+    hdrs = ["utils.h"],
+    deps = [
+        "@org_tensorflow//tensorflow/lite:util",
+    ],
+)
+
 cc_binary(
     name = "text_generator_main",
     srcs = [
@@ -35,8 +43,10 @@ cc_binary(
         "//conditions:default": [],
     }),
     deps = [
+        ":utils",
         "@com_google_absl//absl/flags:flag",
         "@com_google_absl//absl/flags:parse",
+        "@com_google_absl//absl/strings",
         "@com_google_sentencepiece//:sentencepiece_processor",
         "@org_tensorflow//tensorflow/lite:framework",
         "@org_tensorflow//tensorflow/lite:util",
diff --git a/ai_edge_torch/generative/examples/cpp/text_generator_main.cc b/ai_edge_torch/generative/examples/cpp/text_generator_main.cc
index 43f40cab..ea67ce40 100644
--- a/ai_edge_torch/generative/examples/cpp/text_generator_main.cc
+++ b/ai_edge_torch/generative/examples/cpp/text_generator_main.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include <algorithm>
+#include <cstddef>
 #include <cstdint>
 #include <cstdio>
 #include <cstdlib>
@@ -28,6 +29,8 @@ limitations under the License.
 
 #include "absl/flags/flag.h"
 #include "absl/flags/parse.h"
+#include "absl/strings/match.h"
+#include "ai_edge_torch/generative/examples/cpp/utils.h"
 #include "src/sentencepiece_processor.h"
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
 #include "tensorflow/lite/experimental/genai/genai_ops.h"
@@ -36,7 +39,6 @@ limitations under the License.
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model_builder.h"
 #include "tensorflow/lite/signature_runner.h"
-#include "tensorflow/lite/util.h"
 
 // This is a simplified example of using TFLite to generate text.
 // Please note that this is only a starting point and the user is expected
@@ -52,12 +54,6 @@ limitations under the License.
 //    --stop_token="<eos>" \
 //    --num_threads=4
 
-#define TFLITE_MINIMAL_CHECK(x)                              \
-  if (!(x)) {                                                \
-    fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
-    exit(1);                                                 \
-  }
-
 ABSL_FLAG(std::string, tflite_model, "",
           "Two-signature tflite model prepared for text generation using ODML "
           "tools.");
@@ -78,12 +74,13 @@ ABSL_FLAG(std::string, weight_cache_path, "",
 
 namespace {
 
-// Prepare helpers
+using ai_edge_torch::examples::AlignedAllocator;
+
 std::unique_ptr<tflite::FlatBufferModel> LoadModel() {
   std::unique_ptr<tflite::FlatBufferModel> model =
       tflite::FlatBufferModel::BuildFromFile(
           absl::GetFlag(FLAGS_tflite_model).c_str());
-  TFLITE_MINIMAL_CHECK(model != nullptr);
+  MINIMAL_CHECK(model != nullptr);
   return model;
 }
 
@@ -97,12 +94,12 @@ void ApplyXNNPACKWeightCaching(tflite::Interpreter* interpreter) {
   delegate_options.flags |=
       TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
 
-  TFLITE_MINIMAL_CHECK(interpreter->ModifyGraphWithDelegate(
-                           tflite::Interpreter::TfLiteDelegatePtr(
-                               TfLiteXNNPackDelegateCreate(&delegate_options),
-                               [](TfLiteDelegate* delegate) {
-                                 TfLiteXNNPackDelegateDelete(delegate);
-                               })) == kTfLiteOk);
+  MINIMAL_CHECK(interpreter->ModifyGraphWithDelegate(
+                    tflite::Interpreter::TfLiteDelegatePtr(
+                        TfLiteXNNPackDelegateCreate(&delegate_options),
+                        [](TfLiteDelegate* delegate) {
+                          TfLiteXNNPackDelegateDelete(delegate);
+                        })) == kTfLiteOk);
 }
 
 std::unique_ptr<tflite::Interpreter> BuildInterpreter(
@@ -112,10 +109,10 @@ std::unique_ptr<tflite::Interpreter> BuildInterpreter(
   // Scaled Dot Product Attention (SDPA).
   tflite::ops::custom::GenAIOpsRegisterer(&resolver);
   tflite::InterpreterBuilder builder(*model, resolver);
-  TFLITE_MINIMAL_CHECK(builder.SetNumThreads(num_threads) == kTfLiteOk);
+  MINIMAL_CHECK(builder.SetNumThreads(num_threads) == kTfLiteOk);
   std::unique_ptr<tflite::Interpreter> interpreter;
   builder(&interpreter);
-  TFLITE_MINIMAL_CHECK(interpreter != nullptr);
+  MINIMAL_CHECK(interpreter != nullptr);
 
   if (!absl::GetFlag(FLAGS_weight_cache_path).empty()) {
     // optionally use xnnpack with weight caching
@@ -125,34 +122,9 @@ std::unique_ptr<tflite::Interpreter> BuildInterpreter(
   return interpreter;
 }
 
-// TF Lite requires all buffers (including external buffers used for KV cache
-// here) be `tflite::kDefaultTensorAlignment` aligned. To ensure that, we use
-// this custom allocator. Please use with caution as different platforms may
-// have different alignment requirements.
-template <typename T>
-class AlignedAllocator {
- public:
-  using value_type = T;
-
-  T* allocate(std::size_t n) {
-    void* ptr;
-    std::size_t size = n * sizeof(T);
-    std::size_t padding = tflite::kDefaultTensorAlignment -
-                          (size % tflite::kDefaultTensorAlignment);
-    size += padding;
-    int ret = posix_memalign(&ptr, tflite::kDefaultTensorAlignment, size);
-    if (ret != 0) {
-      return nullptr;
-    }
-    return static_cast<T*>(ptr);
-  }
-
-  void deallocate(T* p, std::size_t n) { free(p); }
-};
-
 std::map<std::string, std::vector<float, AlignedAllocator<float>>> BuildKVCache(
     tflite::Interpreter* interpreter) {
-  tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("prefill");
+  tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode");
   if (runner == nullptr) {
     return {};
   }
@@ -178,12 +150,10 @@ std::map<std::string, std::vector<float, AlignedAllocator<float>>> BuildKVCache(
   return kv_cache;
 }
 
-tflite::SignatureRunner* GetSignatureRunner(
-    tflite::Interpreter* interpreter, const std::string& signature_name,
+void PrepareRunner(
+    tflite::SignatureRunner* runner,
     std::map<std::string, std::vector<float, AlignedAllocator<float>>>&
         kv_cache) {
-  tflite::SignatureRunner* runner =
-      interpreter->GetSignatureRunner(signature_name.c_str());
   for (auto& [name, cache] : kv_cache) {
     TfLiteCustomAllocation allocation = {
         .data = static_cast<void*>(cache.data()),
@@ -191,12 +161,46 @@ tflite::SignatureRunner* GetSignatureRunner(
         .bytes = cache.size() * sizeof(float)};
     // Both input and output tensors are set to the same buffer. Not all
     // delegates support this in-place update. For those cases, we need to do
    // a ping-pong buffer and update the pointers between inference calls.
-    TFLITE_MINIMAL_CHECK(runner->SetCustomAllocationForInputTensor(
-                             name.c_str(), allocation) == kTfLiteOk);
-    TFLITE_MINIMAL_CHECK(runner->SetCustomAllocationForOutputTensor(
-                             name.c_str(), allocation) == kTfLiteOk);
+    MINIMAL_CHECK(runner->SetCustomAllocationForInputTensor(
+                      name.c_str(), allocation) == kTfLiteOk);
+    MINIMAL_CHECK(runner->SetCustomAllocationForOutputTensor(
+                      name.c_str(), allocation) == kTfLiteOk);
   }
-  TFLITE_MINIMAL_CHECK(runner->AllocateTensors() == kTfLiteOk);
+  MINIMAL_CHECK(runner->AllocateTensors() == kTfLiteOk);
+}
+
+tflite::SignatureRunner* GetPrefillRunner(
+    tflite::Interpreter* interpreter, std::size_t num_input_tokens,
+    std::map<std::string, std::vector<float, AlignedAllocator<float>>>&
+        kv_cache) {
+  // Find the prefill signature that best matches the input token size.
+  tflite::SignatureRunner* runner = nullptr;
+  int delta = std::numeric_limits<int>::max();
+  for (const std::string* key : interpreter->signature_keys()) {
+    if (!absl::StrContains(*key, "prefill")) {
+      continue;
+    }
+    TfLiteTensor* input_pos = interpreter->GetSignatureRunner(key->c_str())
+                                  ->input_tensor("input_pos");
+    // The expected shape for input position is [Seq].
+    int seq_size = input_pos->dims->data[0];
+    if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) {
+      runner = interpreter->GetSignatureRunner(key->c_str());
+      delta = seq_size - num_input_tokens;
+    }
+  }
+  MINIMAL_CHECK(runner != nullptr);
+  PrepareRunner(runner, kv_cache);
+  return runner;
+}
+
+tflite::SignatureRunner* GetDecodeRunner(
+    tflite::Interpreter* interpreter,
+    std::map<std::string, std::vector<float, AlignedAllocator<float>>>&
+        kv_cache) {
+  tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode");
+  MINIMAL_CHECK(runner != nullptr);
+  PrepareRunner(runner, kv_cache);
   return runner;
 }
 
@@ -207,8 +211,7 @@ LoadSentencePieceProcessor() {
   std::string serialized_proto = std::string(
       std::istreambuf_iterator<char>(input), std::istreambuf_iterator<char>());
   auto processor = std::make_unique<sentencepiece::SentencePieceProcessor>();
-  TFLITE_MINIMAL_CHECK(
-      processor->LoadFromSerializedProto(serialized_proto).ok());
+  MINIMAL_CHECK(processor->LoadFromSerializedProto(serialized_proto).ok());
   return processor;
 }
 
@@ -239,16 +242,32 @@ int main(int argc, char* argv[]) {
       LoadSentencePieceProcessor();
   std::map<std::string, std::vector<float, AlignedAllocator<float>>> kv_cache =
       BuildKVCache(interpreter.get());
-  TFLITE_MINIMAL_CHECK(!kv_cache.empty())
+  MINIMAL_CHECK(!kv_cache.empty())
 
-  // Get prefill and decode signature runners and allocate tensors per
-  // signature.
-  tflite::SignatureRunner* prefill_runner =
-      GetSignatureRunner(interpreter.get(), "prefill", kv_cache);
-  TFLITE_MINIMAL_CHECK(prefill_runner != nullptr);
+  // Tokenize the input prompt.
+  std::string prompt = absl::GetFlag(FLAGS_prompt);
+  std::vector<int> prompt_tokens;
+  MINIMAL_CHECK(sp_processor->Encode(prompt, &prompt_tokens).ok());
+
+  std::string start_token = absl::GetFlag(FLAGS_start_token);
+  if (!start_token.empty()) {
+    prompt_tokens.insert(prompt_tokens.begin(),
+                         sp_processor->PieceToId((start_token)));
+  }
+  std::string stop_token = absl::GetFlag(FLAGS_stop_token);
+  int stop_token_id = -1;
+  if (!stop_token.empty()) {
+    stop_token_id = sp_processor->PieceToId((stop_token));
+  }
+
+  // Get prefill and decode signature runners.
+  std::size_t effective_prefill_token_size = prompt_tokens.size() - 1;
+  tflite::SignatureRunner* prefill_runner = GetPrefillRunner(
+      interpreter.get(), effective_prefill_token_size, kv_cache);
+  MINIMAL_CHECK(prefill_runner != nullptr);
   tflite::SignatureRunner* decode_runner =
-      GetSignatureRunner(interpreter.get(), "decode", kv_cache);
-  TFLITE_MINIMAL_CHECK(decode_runner != nullptr);
+      GetDecodeRunner(interpreter.get(), kv_cache);
+  MINIMAL_CHECK(decode_runner != nullptr);
 
   // Get Input Tensors for each of the runners.
   // Shape: [Batch, Seq], Dtype: int32
@@ -265,22 +284,6 @@ int main(int argc, char* argv[]) {
   int max_seq_size = prefill_input->dims->data[1];
   int kv_cache_max_size = kv_cache_k_0->dims->data[1];
 
-  // Tokenize the input prompt.
-  std::string prompt = absl::GetFlag(FLAGS_prompt);
-  std::vector<int> prompt_tokens;
-  TFLITE_MINIMAL_CHECK(sp_processor->Encode(prompt, &prompt_tokens).ok());
-
-  std::string start_token = absl::GetFlag(FLAGS_start_token);
-  if (!start_token.empty()) {
-    prompt_tokens.insert(prompt_tokens.begin(),
-                         sp_processor->PieceToId((start_token)));
-  }
-  std::string stop_token = absl::GetFlag(FLAGS_stop_token);
-  int stop_token_id = -1;
-  if (!stop_token.empty()) {
-    stop_token_id = sp_processor->PieceToId((stop_token));
-  }
-
   // Fill in the inputs (assuming one batch).
   // NOTE: We skip the last token and use that during decode.
   int prefill_seq_size =
@@ -291,7 +294,7 @@ int main(int argc, char* argv[]) {
     prefill_input->data.i32[i] = prompt_tokens[i];
     prefill_input_pos->data.i32[i] = i;
   }
-  TFLITE_MINIMAL_CHECK(prefill_runner->Invoke() == kTfLiteOk);
+  MINIMAL_CHECK(prefill_runner->Invoke() == kTfLiteOk);
 
   // Decode until max kv-cache size or user defined step limit, whichever is
   // smaller.
@@ -300,7 +303,7 @@ int main(int argc, char* argv[]) {
       : absl::GetFlag(FLAGS_max_decode_steps);
   int decode_steps =
       std::min(max_decode_steps, kv_cache_max_size - prefill_seq_size);
-  TFLITE_MINIMAL_CHECK(decode_steps > 0);
+  MINIMAL_CHECK(decode_steps > 0);
 
   std::vector<int> output_tokens;
   output_tokens.reserve(decode_steps);
@@ -309,7 +312,7 @@ int main(int argc, char* argv[]) {
   for (int i = 0; i < decode_steps; ++i) {
     decode_input->data.i32[0] = next_token;
     decode_input_pos->data.i32[0] = next_position;
-    TFLITE_MINIMAL_CHECK(decode_runner->Invoke() == kTfLiteOk);
+    MINIMAL_CHECK(decode_runner->Invoke() == kTfLiteOk);
     next_token = GreedySampler(decode_runner->output_tensor("logits"));
     output_tokens.push_back(next_token);
     next_position += 1;
@@ -320,7 +323,7 @@ int main(int argc, char* argv[]) {
 
   // Detokenize the generated output.
   std::string output_text;
-  TFLITE_MINIMAL_CHECK(sp_processor->Decode(output_tokens, &output_text).ok());
+  MINIMAL_CHECK(sp_processor->Decode(output_tokens, &output_text).ok());
 
   printf("Prompt:\n%s\nOutput text:\n%s\n", prompt.c_str(),
          output_text.c_str());
diff --git a/ai_edge_torch/generative/examples/cpp/utils.h b/ai_edge_torch/generative/examples/cpp/utils.h
new file mode 100644
index 00000000..f684bd55
--- /dev/null
+++ b/ai_edge_torch/generative/examples/cpp/utils.h
@@ -0,0 +1,44 @@
+#ifndef THIRD_PARTY_PY_AI_EDGE_TORCH_GENERATIVE_EXAMPLES_CPP_UTILS_H_
+#define THIRD_PARTY_PY_AI_EDGE_TORCH_GENERATIVE_EXAMPLES_CPP_UTILS_H_
+
+#include <cstdlib>
+
+#include "tensorflow/lite/util.h"
+
+namespace ai_edge_torch::examples {
+
+// A minimal check macro.
+#define MINIMAL_CHECK(x)                                      \
+  if (!(x)) {                                                 \
+    fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__);  \
+    exit(1);                                                  \
+  }
+
+// TF Lite requires all buffers (including external buffers used for KV cache
+// here) be `tflite::kDefaultTensorAlignment` aligned. To ensure that, we use
+// this custom allocator. Please use with caution as different platforms may
+// have different alignment requirements.
+template <typename T>
+class AlignedAllocator {
+ public:
+  using value_type = T;
+
+  T* allocate(std::size_t n) {
+    void* ptr;
+    std::size_t size = n * sizeof(T);
+    std::size_t padding = tflite::kDefaultTensorAlignment -
+                          (size % tflite::kDefaultTensorAlignment);
+    size += padding;
+    int ret = posix_memalign(&ptr, tflite::kDefaultTensorAlignment, size);
+    if (ret != 0) {
+      return nullptr;
+    }
+    return static_cast<T*>(ptr);
+  }
+
+  void deallocate(T* ptr, std::size_t n) { free(ptr); }
+};
+
+}  // namespace ai_edge_torch::examples
+
+#endif  // THIRD_PARTY_PY_AI_EDGE_TORCH_GENERATIVE_EXAMPLES_CPP_UTILS_H_
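
Background on the multi-prefill selection this patch introduces: models exported with several prefill signatures (for example "prefill_32", "prefill_64", "prefill_128") let the runtime pick the shortest sequence length that still fits the prompt, minimizing padding work. The standalone sketch below isolates that best-fit scan from GetPrefillRunner; PickPrefillSignature and the hard-coded candidate list are hypothetical stand-ins for interpreter->signature_keys() and the input_pos tensor shape, not code from the patch.

#include <cstddef>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for scanning interpreter->signature_keys() and
// reading input_pos->dims->data[0] on each "prefill*" runner.
const std::string* PickPrefillSignature(
    const std::vector<std::pair<std::string, int>>& candidates,
    std::size_t num_input_tokens) {
  const std::string* best = nullptr;
  int best_delta = std::numeric_limits<int>::max();
  for (const auto& candidate : candidates) {
    int delta = candidate.second - static_cast<int>(num_input_tokens);
    // Keep the smallest signature that still fits the whole prompt.
    if (delta >= 0 && delta < best_delta) {
      best = &candidate.first;
      best_delta = delta;
    }
  }
  return best;  // nullptr when every signature is too short.
}

int main() {
  std::vector<std::pair<std::string, int>> candidates = {
      {"prefill_32", 32}, {"prefill_64", 64}, {"prefill_128", 128}};
  // A 40-token prompt skips prefill_32 (too small) and prefers prefill_64
  // over prefill_128 (less leftover padding).
  const std::string* key = PickPrefillSignature(candidates, 40);
  std::cout << (key ? *key : "no match") << "\n";  // prints "prefill_64"
}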
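On the allocator moved into utils.h: buffers registered through SetCustomAllocationForInputTensor/SetCustomAllocationForOutputTensor must be tflite::kDefaultTensorAlignment-aligned, which is why the KV cache vectors use AlignedAllocator rather than the default std::allocator. A minimal usage sketch, assuming the utils.h header above is on the include path (the element count 1024 is an arbitrary placeholder, not a value from the patch):

#include <cassert>
#include <cstdint>
#include <vector>

#include "ai_edge_torch/generative/examples/cpp/utils.h"

int main() {
  // KV-cache-style buffer; AlignedAllocator rounds the byte size up to a
  // multiple of tflite::kDefaultTensorAlignment and aligns the start address.
  std::vector<float, ai_edge_torch::examples::AlignedAllocator<float>> cache(
      1024);
  assert(reinterpret_cast<std::uintptr_t>(cache.data()) %
             tflite::kDefaultTensorAlignment ==
         0);
  return 0;
}

The in-place caveat from PrepareRunner still applies to such buffers: the same allocation backs both the input and output tensors, so a delegate that cannot update tensors in place would need the ping-pong scheme mentioned in the comment there.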
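Finally, the decode loop calls GreedySampler on the "logits" output; that helper is unchanged by this patch and therefore absent from the diff. For readers without the full file, a greedy (argmax) sampler over a float32 logits tensor would look roughly like the sketch below, assuming the vocabulary is the last dimension; the real helper may differ.

#include <algorithm>

#include "tensorflow/lite/c/common.h"

// Sketch: pick the highest-scoring token id from a float32 logits tensor
// whose last dimension is the vocabulary, as in the decode signature output.
int GreedySamplerSketch(const TfLiteTensor* logits) {
  int vocab_size = logits->dims->data[logits->dims->size - 1];
  const float* scores = logits->data.f;
  return static_cast<int>(std::max_element(scores, scores + vocab_size) -
                          scores);
}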