Skip to content

Commit

Permalink
Enable xnnpack weight caching to text generator example.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668638006
  • Loading branch information
ai-edge-bot authored and copybara-github committed Aug 28, 2024
1 parent 2a3ad50 commit 846e2e9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
1 change: 1 addition & 0 deletions ai_edge_torch/generative/examples/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ cc_binary(
"@com_google_absl//absl/flags:parse",
"@com_google_sentencepiece//:sentencepiece_processor",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite/experimental/genai:genai_ops",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
Expand Down
27 changes: 27 additions & 0 deletions ai_edge_torch/generative/examples/cpp/text_generator_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "src/sentencepiece_processor.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/experimental/genai/genai_ops.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
Expand Down Expand Up @@ -68,6 +69,8 @@ ABSL_FLAG(std::string, stop_token, "",
"Stop token used to deterine end of decoding loop. If not provided "
"will decode until max_Seq_len or max_decode_steps.");
ABSL_FLAG(int, num_threads, 4, "Number of threads to use. Defaults to 4.");
ABSL_FLAG(std::string, weight_cache_path, "",
"XNNPACK weight caching path, e.g. /tmp/model.xnnpack_cache.");

namespace {

Expand All @@ -80,6 +83,24 @@ std::unique_ptr<tflite::FlatBufferModel> LoadModel() {
return model;
}

void ApplyXNNPACKWeightCaching(tflite::Interpreter* interpreter) {
auto delegate_options = TfLiteXNNPackDelegateOptionsDefault();
std::string weight_cache_path = absl::GetFlag(FLAGS_weight_cache_path);
delegate_options.weight_cache_file_path = weight_cache_path.c_str();
delegate_options.num_threads = absl::GetFlag(FLAGS_num_threads);
delegate_options.flags |=
TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING;
delegate_options.flags |=
TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;

TFLITE_MINIMAL_CHECK(interpreter->ModifyGraphWithDelegate(
tflite::Interpreter::TfLiteDelegatePtr(
TfLiteXNNPackDelegateCreate(&delegate_options),
[](TfLiteDelegate* delegate) {
TfLiteXNNPackDelegateDelete(delegate);
})) == kTfLiteOk);
}

std::unique_ptr<tflite::Interpreter> BuildInterpreter(
tflite::FlatBufferModel* model, int num_threads) {
tflite::ops::builtin::BuiltinOpResolver resolver;
Expand All @@ -91,6 +112,12 @@ std::unique_ptr<tflite::Interpreter> BuildInterpreter(
std::unique_ptr<tflite::Interpreter> interpreter;
builder(&interpreter);
TFLITE_MINIMAL_CHECK(interpreter != nullptr);

if (!absl::GetFlag(FLAGS_weight_cache_path).empty()) {
// optionally use xnnpack with weight caching
ApplyXNNPACKWeightCaching(interpreter.get());
}

return interpreter;
}

Expand Down

0 comments on commit 846e2e9

Please sign in to comment.