Update cpp example to use multi-prefill signature.
PiperOrigin-RevId: 702020720
hheydary authored and copybara-github committed Dec 4, 2024
1 parent 2df5bd2 commit 97424fe
Showing 3 changed files with 139 additions and 82 deletions.
10 changes: 10 additions & 0 deletions ai_edge_torch/generative/examples/cpp/BUILD
@@ -22,6 +22,14 @@ package(
default_visibility = ["//visibility:public"],
)

cc_library(
name = "utils",
hdrs = ["utils.h"],
deps = [
"@org_tensorflow//tensorflow/lite:util",
],
)

cc_binary(
name = "text_generator_main",
srcs = [
@@ -35,8 +43,10 @@ cc_binary(
"//conditions:default": [],
}),
deps = [
":utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/strings",
"@com_google_sentencepiece//:sentencepiece_processor",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:util",
167 changes: 85 additions & 82 deletions ai_edge_torch/generative/examples/cpp/text_generator_main.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
@@ -28,6 +29,8 @@ limitations under the License.

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/strings/match.h"
#include "ai_edge_torch/generative/examples/cpp/utils.h"
#include "src/sentencepiece_processor.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/experimental/genai/genai_ops.h"
@@ -36,7 +39,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/signature_runner.h"
#include "tensorflow/lite/util.h"

// This is a simplified example of using TFLite to generate text.
// Please note that this is only a starting point and the user is expected
@@ -52,12 +54,6 @@
// --stop_token="<eos>" \
// --num_threads=4

#define TFLITE_MINIMAL_CHECK(x) \
if (!(x)) { \
fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
exit(1); \
}

ABSL_FLAG(std::string, tflite_model, "",
"Two-signature tflite model prepared for text generation using ODML "
"tools.");
@@ -78,12 +74,13 @@ ABSL_FLAG(std::string, weight_cache_path, "",

namespace {

// Prepare helpers
using ai_edge_torch::examples::AlignedAllocator;

std::unique_ptr<tflite::FlatBufferModel> LoadModel() {
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(
absl::GetFlag(FLAGS_tflite_model).c_str());
TFLITE_MINIMAL_CHECK(model != nullptr);
MINIMAL_CHECK(model != nullptr);
return model;
}

@@ -97,12 +94,12 @@ void ApplyXNNPACKWeightCaching(tflite::Interpreter* interpreter) {
delegate_options.flags |=
TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;

TFLITE_MINIMAL_CHECK(interpreter->ModifyGraphWithDelegate(
tflite::Interpreter::TfLiteDelegatePtr(
TfLiteXNNPackDelegateCreate(&delegate_options),
[](TfLiteDelegate* delegate) {
TfLiteXNNPackDelegateDelete(delegate);
})) == kTfLiteOk);
MINIMAL_CHECK(interpreter->ModifyGraphWithDelegate(
tflite::Interpreter::TfLiteDelegatePtr(
TfLiteXNNPackDelegateCreate(&delegate_options),
[](TfLiteDelegate* delegate) {
TfLiteXNNPackDelegateDelete(delegate);
})) == kTfLiteOk);
}
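
The cache file persists XNNPACK's packed weights: the first run pays the packing cost and serializes the result, and later runs load the packed form directly. Relative to the usage shown at the top of the file, enabling it is one extra flag (invocation shape and path here are illustrative, not from this commit):

//   bazel run :text_generator_main -- ... --weight_cache_path=/tmp/model.xnnpack_cache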

std::unique_ptr<tflite::Interpreter> BuildInterpreter(
@@ -112,10 +109,10 @@ std::unique_ptr<tflite::Interpreter> BuildInterpreter(
// Scaled Dot Product Attention (SDPA).
tflite::ops::custom::GenAIOpsRegisterer(&resolver);
tflite::InterpreterBuilder builder(*model, resolver);
TFLITE_MINIMAL_CHECK(builder.SetNumThreads(num_threads) == kTfLiteOk);
MINIMAL_CHECK(builder.SetNumThreads(num_threads) == kTfLiteOk);
std::unique_ptr<tflite::Interpreter> interpreter;
builder(&interpreter);
TFLITE_MINIMAL_CHECK(interpreter != nullptr);
MINIMAL_CHECK(interpreter != nullptr);

if (!absl::GetFlag(FLAGS_weight_cache_path).empty()) {
// optionally use xnnpack with weight caching
@@ -125,34 +122,9 @@ std::unique_ptr<tflite::Interpreter> BuildInterpreter(
return interpreter;
}

// TF Lite requires all buffers (including external buffers used for KV cache
// here) be `tflite::kDefaultTensorAlignment` aligned. To ensure that, we use
// this custom allocator. Please use with caution as different platforms may
// have different alignment requirements.
template <typename T>
class AlignedAllocator {
public:
using value_type = T;

T* allocate(std::size_t n) {
void* ptr;
std::size_t size = n * sizeof(T);
std::size_t padding = tflite::kDefaultTensorAlignment -
(size % tflite::kDefaultTensorAlignment);
size += padding;
int ret = posix_memalign(&ptr, tflite::kDefaultTensorAlignment, size);
if (ret != 0) {
return nullptr;
}
return static_cast<T*>(ptr);
}

void deallocate(T* p, std::size_t n) { free(p); }
};

std::map<std::string, std::vector<float, AlignedAllocator<float>>> BuildKVCache(
tflite::Interpreter* interpreter) {
tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("prefill");
tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode");
if (runner == nullptr) {
return {};
}
@@ -178,25 +150,57 @@ std::map<std::string, std::vector<float, AlignedAllocator<float>>> BuildKVCache(
return kv_cache;
}

tflite::SignatureRunner* GetSignatureRunner(
tflite::Interpreter* interpreter, const std::string& signature_name,
void PrepareRunner(
tflite::SignatureRunner* runner,
std::map<std::string, std::vector<float, AlignedAllocator<float>>>&
kv_cache) {
tflite::SignatureRunner* runner =
interpreter->GetSignatureRunner(signature_name.c_str());
for (auto& [name, cache] : kv_cache) {
TfLiteCustomAllocation allocation = {
.data = static_cast<void*>(cache.data()),
.bytes = cache.size() * sizeof(float)};
// Both input and output tensors are set to the same buffer. Not all
// delegates support this in-place update. For those cases, we need to do
// a ping-pong buffer and update the pointers between inference calls.
TFLITE_MINIMAL_CHECK(runner->SetCustomAllocationForInputTensor(
name.c_str(), allocation) == kTfLiteOk);
TFLITE_MINIMAL_CHECK(runner->SetCustomAllocationForOutputTensor(
name.c_str(), allocation) == kTfLiteOk);
MINIMAL_CHECK(runner->SetCustomAllocationForInputTensor(
name.c_str(), allocation) == kTfLiteOk);
MINIMAL_CHECK(runner->SetCustomAllocationForOutputTensor(
name.c_str(), allocation) == kTfLiteOk);
}
TFLITE_MINIMAL_CHECK(runner->AllocateTensors() == kTfLiteOk);
MINIMAL_CHECK(runner->AllocateTensors() == kTfLiteOk);
}
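
The comment in PrepareRunner notes that delegates without in-place support need a ping-pong buffer instead of aliased input/output allocations. Below is a minimal sketch of that alternative; it is not part of this commit, the helper and struct names are hypothetical, and it assumes the delegate tolerates re-binding custom allocations between invocations:

// Two aligned copies of one KV tensor: the model reads from whichever buffer
// is currently bound as input, writes into the other, and the roles swap
// after every step.
struct PingPongBuffer {
  std::vector<float, AlignedAllocator<float>> a, b;
  bool a_is_input = true;
};

void BindPingPong(tflite::SignatureRunner* runner, const char* name,
                  PingPongBuffer& buf) {
  auto& in = buf.a_is_input ? buf.a : buf.b;
  auto& out = buf.a_is_input ? buf.b : buf.a;
  TfLiteCustomAllocation in_alloc = {.data = static_cast<void*>(in.data()),
                                     .bytes = in.size() * sizeof(float)};
  TfLiteCustomAllocation out_alloc = {.data = static_cast<void*>(out.data()),
                                      .bytes = out.size() * sizeof(float)};
  MINIMAL_CHECK(runner->SetCustomAllocationForInputTensor(name, in_alloc) ==
                kTfLiteOk);
  MINIMAL_CHECK(runner->SetCustomAllocationForOutputTensor(name, out_alloc) ==
                kTfLiteOk);
  MINIMAL_CHECK(runner->AllocateTensors() == kTfLiteOk);
  buf.a_is_input = !buf.a_is_input;  // Swap roles before the next Invoke().
}

Each decode step would then re-bind every KV entry with BindPingPong before calling runner->Invoke().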

tflite::SignatureRunner* GetPrefillRunner(
tflite::Interpreter* interpreter, std::size_t num_input_tokens,
std::map<std::string, std::vector<float, AlignedAllocator<float>>>&
kv_cache) {
// Find the prefill signature that best matches the input token size.
tflite::SignatureRunner* runner = nullptr;
int delta = std::numeric_limits<int>::max();
for (const std::string* key : interpreter->signature_keys()) {
if (!absl::StrContains(*key, "prefill")) {
continue;
}
TfLiteTensor* input_pos = interpreter->GetSignatureRunner(key->c_str())
->input_tensor("input_pos");
// The expected shape for input position is [Seq].
int seq_size = input_pos->dims->data[0];
if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) {
runner = interpreter->GetSignatureRunner(key->c_str());
delta = seq_size - num_input_tokens;
}
}
MINIMAL_CHECK(runner != nullptr);
PrepareRunner(runner, kv_cache);
return runner;
}
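
As a concrete illustration of the best-fit scan above: if a model exported prefill signatures with sequence lengths 32, 64, 128, and 256 (hypothetical sizes), a 40-token prompt selects the 64-token runner. The same selection on plain integers, using nothing beyond what this file already pulls in:

// Mirrors GetPrefillRunner's scan: pick the smallest signature length that
// still fits the prompt; returns -1 if none fits.
int BestFitSeqSize(const std::vector<int>& seq_sizes, std::size_t num_tokens) {
  int best = -1;
  int delta = std::numeric_limits<int>::max();
  for (int seq_size : seq_sizes) {
    int d = seq_size - static_cast<int>(num_tokens);
    if (d >= 0 && d < delta) {
      best = seq_size;
      delta = d;
    }
  }
  return best;  // BestFitSeqSize({32, 64, 128, 256}, 40) == 64
}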

tflite::SignatureRunner* GetDecodeRunner(
tflite::Interpreter* interpreter,
std::map<std::string, std::vector<float, AlignedAllocator<float>>>&
kv_cache) {
tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode");
MINIMAL_CHECK(runner != nullptr);
PrepareRunner(runner, kv_cache);
return runner;
}

@@ -207,8 +211,7 @@ LoadSentencePieceProcessor() {
std::string serialized_proto = std::string(
std::istreambuf_iterator<char>(input), std::istreambuf_iterator<char>());
auto processor = std::make_unique<sentencepiece::SentencePieceProcessor>();
TFLITE_MINIMAL_CHECK(
processor->LoadFromSerializedProto(serialized_proto).ok());
MINIMAL_CHECK(processor->LoadFromSerializedProto(serialized_proto).ok());
return processor;
}

@@ -239,16 +242,32 @@ int main(int argc, char* argv[]) {
LoadSentencePieceProcessor();
std::map<std::string, std::vector<float, AlignedAllocator<float>>> kv_cache =
BuildKVCache(interpreter.get());
TFLITE_MINIMAL_CHECK(!kv_cache.empty())
MINIMAL_CHECK(!kv_cache.empty())

// Get prefill and decode signature runners and allocate tensors per
// signature.
tflite::SignatureRunner* prefill_runner =
GetSignatureRunner(interpreter.get(), "prefill", kv_cache);
TFLITE_MINIMAL_CHECK(prefill_runner != nullptr);
// Tokenize the input prompt.
std::string prompt = absl::GetFlag(FLAGS_prompt);
std::vector<int> prompt_tokens;
MINIMAL_CHECK(sp_processor->Encode(prompt, &prompt_tokens).ok());

std::string start_token = absl::GetFlag(FLAGS_start_token);
if (!start_token.empty()) {
prompt_tokens.insert(prompt_tokens.begin(),
sp_processor->PieceToId((start_token)));
}
std::string stop_token = absl::GetFlag(FLAGS_stop_token);
int stop_token_id = -1;
if (!stop_token.empty()) {
stop_token_id = sp_processor->PieceToId((stop_token));
}
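
// For illustration only (ids are hypothetical): encoding "Hello world" might
// yield [17534, 2134]; with --start_token="<bos>" mapping to id 2,
// prompt_tokens becomes [2, 17534, 2134], and stop_token_id holds the id
// that ends decoding below.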

// Get prefill and decode signature runners.
std::size_t effective_prefill_token_size = prompt_tokens.size() - 1;
tflite::SignatureRunner* prefill_runner = GetPrefillRunner(
interpreter.get(), effective_prefill_token_size, kv_cache);
MINIMAL_CHECK(prefill_runner != nullptr);
tflite::SignatureRunner* decode_runner =
GetSignatureRunner(interpreter.get(), "decode", kv_cache);
TFLITE_MINIMAL_CHECK(decode_runner != nullptr);
GetDecodeRunner(interpreter.get(), kv_cache);
MINIMAL_CHECK(decode_runner != nullptr);

// Get Input Tensors for each of the runners.
// Shape: [Batch, Seq], Dtype: int32
@@ -265,22 +284,6 @@ int main(int argc, char* argv[]) {
int max_seq_size = prefill_input->dims->data[1];
int kv_cache_max_size = kv_cache_k_0->dims->data[1];
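// For example (hypothetical shapes): a prefill input of [1, 128] gives
// max_seq_size = 128, and a KV cache of [1, 1024, ...] gives
// kv_cache_max_size = 1024.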

// Tokenize the input prompt.
std::string prompt = absl::GetFlag(FLAGS_prompt);
std::vector<int> prompt_tokens;
TFLITE_MINIMAL_CHECK(sp_processor->Encode(prompt, &prompt_tokens).ok());

std::string start_token = absl::GetFlag(FLAGS_start_token);
if (!start_token.empty()) {
prompt_tokens.insert(prompt_tokens.begin(),
sp_processor->PieceToId((start_token)));
}
std::string stop_token = absl::GetFlag(FLAGS_stop_token);
int stop_token_id = -1;
if (!stop_token.empty()) {
stop_token_id = sp_processor->PieceToId((stop_token));
}

// Fill in the inputs (assuming one batch).
// NOTE: We skip the last token and use that during decode.
int prefill_seq_size =
@@ -291,7 +294,7 @@ int main(int argc, char* argv[]) {
prefill_input->data.i32[i] = prompt_tokens[i];
prefill_input_pos->data.i32[i] = i;
}
TFLITE_MINIMAL_CHECK(prefill_runner->Invoke() == kTfLiteOk);
MINIMAL_CHECK(prefill_runner->Invoke() == kTfLiteOk);

// Decode until max kv-cache size or user defined step limit, whichever is
// smaller.
@@ -300,7 +303,7 @@ int main(int argc, char* argv[]) {
: absl::GetFlag(FLAGS_max_decode_steps);
int decode_steps =
std::min(max_decode_steps, kv_cache_max_size - prefill_seq_size);
TFLITE_MINIMAL_CHECK(decode_steps > 0);
MINIMAL_CHECK(decode_steps > 0);
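// Continuing the hypothetical shapes above: kv_cache_max_size = 1024 and a
// 40-token prompt (prefill_seq_size = 39, since the last token is held back)
// leave room for up to 985 decode steps, further capped by --max_decode_steps.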

std::vector<int> output_tokens;
output_tokens.reserve(decode_steps);
@@ -309,7 +312,7 @@ int main(int argc, char* argv[]) {
for (int i = 0; i < decode_steps; ++i) {
decode_input->data.i32[0] = next_token;
decode_input_pos->data.i32[0] = next_position;
TFLITE_MINIMAL_CHECK(decode_runner->Invoke() == kTfLiteOk);
MINIMAL_CHECK(decode_runner->Invoke() == kTfLiteOk);
next_token = GreedySampler(decode_runner->output_tensor("logits"));
output_tokens.push_back(next_token);
next_position += 1;
@@ -320,7 +323,7 @@ int main(int argc, char* argv[]) {

// Detokenize the generated output.
std::string output_text;
TFLITE_MINIMAL_CHECK(sp_processor->Decode(output_tokens, &output_text).ok());
MINIMAL_CHECK(sp_processor->Decode(output_tokens, &output_text).ok());

printf("Prompt:\n%s\nOutput text:\n%s\n", prompt.c_str(),
output_text.c_str());
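
GreedySampler is defined in a region collapsed out of this view; a plausible argmax over the final logits, consistent with the call sites above, is sketched below (assumes logits of shape [1, 1, vocab] with float32 output — this is an illustration, not the committed implementation):

int GreedySampler(const TfLiteTensor* logits) {
  float max_value = -std::numeric_limits<float>::infinity();
  int max_index = 0;
  // Scan the last (vocab) dimension for the highest logit.
  int vocab_size = logits->dims->data[logits->dims->size - 1];
  for (int i = 0; i < vocab_size; ++i) {
    if (logits->data.f[i] > max_value) {
      max_value = logits->data.f[i];
      max_index = i;
    }
  }
  return max_index;
}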
44 changes: 44 additions & 0 deletions ai_edge_torch/generative/examples/cpp/utils.h
@@ -0,0 +1,44 @@
#ifndef THIRD_PARTY_PY_AI_EDGE_TORCH_GENERATIVE_EXAMPLES_CPP_UTILS_H_
#define THIRD_PARTY_PY_AI_EDGE_TORCH_GENERATIVE_EXAMPLES_CPP_UTILS_H_

#include <cstddef>

#include "tensorflow/lite/util.h"

namespace ai_edge_torch::examples {

// A minimal check macro.
#define MINIMAL_CHECK(x) \
if (!(x)) { \
fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
exit(1); \
}

// TF Lite requires all buffers (including external buffers used for KV cache
// here) be `tflite::kDefaultTensorAlignment` aligned. To ensure that, we use
// this custom allocator. Please use with caution as different platforms may
// have different alignment requirements.
template <typename T>
class AlignedAllocator {
public:
using value_type = T;

T* allocate(std::size_t n) {
void* ptr;
std::size_t size = n * sizeof(T);
std::size_t padding = tflite::kDefaultTensorAlignment -
(size % tflite::kDefaultTensorAlignment);
size += padding;
int ret = posix_memalign(&ptr, tflite::kDefaultTensorAlignment, size);
if (ret != 0) {
return nullptr;
}
return static_cast<T*>(ptr);
};

void deallocate(T* ptr, std::size_t n) { free(ptr); }
};

} // namespace ai_edge_torch::examples

#endif // THIRD_PARTY_PY_AI_EDGE_TORCH_GENERATIVE_EXAMPLES_CPP_UTILS_H_
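
A minimal, self-contained usage sketch for the allocator (POSIX-only, since allocate() relies on posix_memalign; nothing below is part of the commit):

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

#include "ai_edge_torch/generative/examples/cpp/utils.h"

int main() {
  // Backing store for a KV tensor: data() comes back aligned to
  // tflite::kDefaultTensorAlignment, as TFLite custom allocations require.
  std::vector<float, ai_edge_torch::examples::AlignedAllocator<float>> buf(
      1024);
  MINIMAL_CHECK(reinterpret_cast<std::uintptr_t>(buf.data()) %
                    tflite::kDefaultTensorAlignment ==
                0);
  return 0;
}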
