From 846e2e9cf7d0062bb6df0d488753b289c8e1fe73 Mon Sep 17 00:00:00 2001
From: Google AI Edge
Date: Wed, 28 Aug 2024 15:48:49 -0700
Subject: [PATCH] Enable xnnpack weight caching to text generator example.

PiperOrigin-RevId: 668638006
---
 ai_edge_torch/generative/examples/cpp/BUILD       |  1 +
 .../examples/cpp/text_generator_main.cc           | 27 +++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/ai_edge_torch/generative/examples/cpp/BUILD b/ai_edge_torch/generative/examples/cpp/BUILD
index bb17e270..5bb5a7ce 100644
--- a/ai_edge_torch/generative/examples/cpp/BUILD
+++ b/ai_edge_torch/generative/examples/cpp/BUILD
@@ -39,6 +39,7 @@ cc_binary(
         "@com_google_absl//absl/flags:parse",
         "@com_google_sentencepiece//:sentencepiece_processor",
         "@org_tensorflow//tensorflow/lite:framework",
+        "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
        "@org_tensorflow//tensorflow/lite/experimental/genai:genai_ops",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
     ],
diff --git a/ai_edge_torch/generative/examples/cpp/text_generator_main.cc b/ai_edge_torch/generative/examples/cpp/text_generator_main.cc
index 639162dc..c50068cd 100644
--- a/ai_edge_torch/generative/examples/cpp/text_generator_main.cc
+++ b/ai_edge_torch/generative/examples/cpp/text_generator_main.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "absl/flags/flag.h"
 #include "absl/flags/parse.h"
 #include "src/sentencepiece_processor.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
 #include "tensorflow/lite/experimental/genai/genai_ops.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/interpreter_builder.h"
@@ -68,6 +69,8 @@ ABSL_FLAG(std::string, stop_token, "",
           "Stop token used to deterine end of decoding loop. If not provided "
           "will decode until max_Seq_len or max_decode_steps.");
 ABSL_FLAG(int, num_threads, 4, "Number of threads to use. Defaults to 4.");
+ABSL_FLAG(std::string, weight_cache_path, "",
+          "XNNPACK weight caching path, e.g. /tmp/model.xnnpack_cache.");
 
 namespace {
 
@@ -80,6 +83,24 @@ std::unique_ptr<tflite::FlatBufferModel> LoadModel() {
   return model;
 }
 
+void ApplyXNNPACKWeightCaching(tflite::Interpreter* interpreter) {
+  auto delegate_options = TfLiteXNNPackDelegateOptionsDefault();
+  std::string weight_cache_path = absl::GetFlag(FLAGS_weight_cache_path);
+  delegate_options.weight_cache_file_path = weight_cache_path.c_str();
+  delegate_options.num_threads = absl::GetFlag(FLAGS_num_threads);
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING;
+  delegate_options.flags |=
+      TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS;
+
+  TFLITE_MINIMAL_CHECK(interpreter->ModifyGraphWithDelegate(
+                           tflite::Interpreter::TfLiteDelegatePtr(
+                               TfLiteXNNPackDelegateCreate(&delegate_options),
+                               [](TfLiteDelegate* delegate) {
+                                 TfLiteXNNPackDelegateDelete(delegate);
+                               })) == kTfLiteOk);
+}
+
 std::unique_ptr<tflite::Interpreter> BuildInterpreter(
     tflite::FlatBufferModel* model, int num_threads) {
   tflite::ops::builtin::BuiltinOpResolver resolver;
@@ -91,6 +112,12 @@ std::unique_ptr<tflite::Interpreter> BuildInterpreter(
   std::unique_ptr<tflite::Interpreter> interpreter;
   builder(&interpreter);
   TFLITE_MINIMAL_CHECK(interpreter != nullptr);
+
+  if (!absl::GetFlag(FLAGS_weight_cache_path).empty()) {
+    // optionally use xnnpack with weight caching
+    ApplyXNNPACKWeightCaching(interpreter.get());
+  }
+
   return interpreter;
 }
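
For context, the sketch below shows how the pieces added by this patch fit together in a minimal standalone C++ program: load a .tflite model, build an interpreter, and apply the XNNPACK delegate with a weight cache file. It is illustrative only, not part of the patch; the model path, cache path, and thread count are placeholders, it registers only the builtin ops (the real example also registers the GenAI ops), and it uses plain error returns instead of TFLITE_MINIMAL_CHECK.

// Minimal standalone sketch (illustrative, not part of the patch): apply the
// XNNPACK delegate with a weight cache to an arbitrary .tflite model.
#include <cstdio>
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

int main() {
  const std::string model_path = "/tmp/model.tflite";         // placeholder
  const std::string cache_path = "/tmp/model.xnnpack_cache";  // placeholder

  // Load the flatbuffer model from disk.
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
  if (model == nullptr) {
    fprintf(stderr, "Failed to load %s\n", model_path.c_str());
    return 1;
  }

  // Build an interpreter with the builtin op resolver.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
  builder(&interpreter);
  if (interpreter == nullptr) {
    fprintf(stderr, "Failed to build interpreter\n");
    return 1;
  }

  // Point the XNNPACK delegate at a weight cache file. The first run packs
  // the weights and writes the cache; later runs reuse it.
  TfLiteXNNPackDelegateOptions delegate_options =
      TfLiteXNNPackDelegateOptionsDefault();
  delegate_options.weight_cache_file_path = cache_path.c_str();
  delegate_options.num_threads = 4;  // placeholder

  tflite::Interpreter::TfLiteDelegatePtr delegate(
      TfLiteXNNPackDelegateCreate(&delegate_options),
      [](TfLiteDelegate* d) { TfLiteXNNPackDelegateDelete(d); });
  if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) != kTfLiteOk) {
    fprintf(stderr, "Failed to apply XNNPACK delegate\n");
    return 1;
  }

  if (interpreter->AllocateTensors() != kTfLiteOk) {
    fprintf(stderr, "Failed to allocate tensors\n");
    return 1;
  }
  // ... fill input tensors, call interpreter->Invoke(), read outputs ...
  return 0;
}

The point of the new --weight_cache_path flag is the first-run/later-run split: on the first run XNNPACK packs the model weights and writes them to the cache file, and on subsequent runs it reuses that file instead of re-packing, which should noticeably reduce startup cost for large generative models like the one this example loads.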