Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631415620
  • Loading branch information
MediaPipe Team authored and copybara-github committed May 7, 2024
1 parent c2e6427 commit 99fc736
Show file tree
Hide file tree
Showing 23 changed files with 1,308 additions and 51 deletions.
62 changes: 62 additions & 0 deletions mediapipe/calculators/tensor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,66 @@ cc_test(
],
)

# Manages "feedback tensor" links: model output tensors whose contents are fed
# back into linked model input tensors on the next inference invocation,
# giving stateless TfLite models a notion of temporal state.
cc_library_with_tflite(
    name = "inference_feedback_manager",
    srcs = ["inference_feedback_manager.cc"],
    hdrs = ["inference_feedback_manager.h"],
    # Deps that must be built against the TfLite-in-Google vs OSS variants.
    tflite_deps = [
        ":inference_io_mapper",
        "//mediapipe/util/tflite:utils",
        "//mediapipe/util/tflite:tflite_signature_reader",
        "@org_tensorflow//tensorflow/lite:framework_stable",
        "@org_tensorflow//tensorflow/lite:namespace",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
    ],
    deps = [
        ":inference_calculator_cc_proto",
        ":inference_calculator_utils",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
    ],
)

# Unit tests for :inference_feedback_manager, exercised end-to-end through the
# CPU/XNNPACK inference calculators with small test models that contain
# feedback (state) tensors.
cc_test(
    name = "inference_feedback_manager_test",
    data = [
        ":testdata/feedback_tensor_test_model.tflite",
        ":testdata/feedback_tensor_with_state_copy_model.tflite",
    ],
    srcs = ["inference_feedback_manager_test.cc"],
    deps = [
        ":inference_calculator",
        ":inference_calculator_cc_proto",
        ":inference_calculator_cpu",
        ":inference_calculator_interface",
        ":inference_feedback_manager",
        ":inference_io_mapper",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:packet",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status_matchers",
        "//mediapipe/framework/tool:sink",
        "//mediapipe/util/tflite:tflite_model_loader",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/lite:framework_stable",
        "@org_tensorflow//tensorflow/lite/core:framework",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
        "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
    ],
)

cc_library(
name = "inference_calculator_gl",
srcs = ["inference_calculator_gl.cc"],
Expand Down Expand Up @@ -645,6 +705,7 @@ cc_library_with_tflite(
":inference_runner",
":tflite_delegate_ptr",
":inference_io_mapper",
":inference_feedback_manager",
"//mediapipe/util/tflite:tflite_model_loader",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
Expand All @@ -658,6 +719,7 @@ cc_library_with_tflite(
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite:string_util",
Expand Down
12 changes: 12 additions & 0 deletions mediapipe/calculators/tensor/inference_calculator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,17 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
}

// Logs a warning when the calculator options request feedback tensor links on
// a backend that does not support them; the option is then ignored.
void InferenceCalculator::WarnFeedbackTensorsUnsupported(
    CalculatorContract* cc) {
  const auto& calculator_options =
      cc->Options<mediapipe::InferenceCalculatorOptions>();
  if (!calculator_options.has_input_output_config()) return;
  if (calculator_options.input_output_config()
          .feedback_tensor_links()
          .empty()) {
    return;
  }
  ABSL_LOG(WARNING)
      << "Feedback tensor support is only available for CPU and "
      << "XNNPACK inference. Ignoring "
         "input_output_config.feedback_tensor_links option.";
}

} // namespace api2
} // namespace mediapipe
3 changes: 3 additions & 0 deletions mediapipe/calculators/tensor/inference_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ class InferenceCalculator : public NodeIntf {

static absl::StatusOr<Packet<tflite::OpResolver>> GetOpResolverAsPacket(
CalculatorContext* cc);

// Checks if feedback tensor support is available and warns otherwise.
static void WarnFeedbackTensorsUnsupported(CalculatorContract* cc);
};

struct InferenceCalculatorSelector : public InferenceCalculator {
Expand Down
20 changes: 20 additions & 0 deletions mediapipe/calculators/tensor/inference_calculator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,26 @@ message InferenceCalculatorOptions {
TensorIndicesMap output_tensor_indices_map = 2;
TensorNamesMap output_tensor_names_map = 4;
}

// Feedback tensor links are pairs of model input / output tensors where
// the output should be set as inputs in the next model invocation. This
// allows to manage a notion of temporal state by continuously feeding from
// the model's output to the model's input during each inference step. Note
// that these feedback tensors must be excluded from the input/output
// tensor maps above as they are not used as regular inputs/outputs of the
// inference calculator.
// A single feedback link: one model output tensor whose value is copied into
// one model input tensor before the next inference invocation.
message FeedbackTensorLink {
  // TfLite output tensor name from default TfLite signature to use as
  // source.
  optional string from_output_tensor_name = 1;
  // TfLite tensor name from default TfLite signature to pass input
  // tensor to.
  optional string to_input_tensor_name = 2;
}

// Defines a mapping between output tensors that should be
// used as input tensors during the next inference invocation.
repeated FeedbackTensorLink feedback_tensor_links = 5;
}

// Optionally remaps input and output tensors to align with TfLite model and
Expand Down
4 changes: 3 additions & 1 deletion mediapipe/calculators/tensor/inference_calculator_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ absl::StatusOr<std::unique_ptr<InferenceRunner>>
InferenceCalculatorCpuImpl::CreateInferenceRunner(CalculatorContext* cc) {
MP_ASSIGN_OR_RETURN(auto model_packet, GetModelAsPacket(cc));
MP_ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& options = cc->Options<mediapipe::InferenceCalculatorOptions>();
const int interpreter_num_threads =
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread();
MP_ASSIGN_OR_RETURN(TfLiteDelegatePtr delegate, MaybeCreateDelegate(cc));
return CreateInferenceInterpreterDelegateRunner(
std::move(model_packet), std::move(op_resolver_packet),
std::move(delegate), interpreter_num_threads);
std::move(delegate), interpreter_num_threads,
&options.input_output_config());
}

absl::StatusOr<TfLiteDelegatePtr>
Expand Down
1 change: 1 addition & 0 deletions mediapipe/calculators/tensor/inference_calculator_gl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
<< "Either model as side packet or model path in options is required.";

WarnFeedbackTensorsUnsupported(cc);
return mediapipe::GlCalculatorHelper::UpdateContract(cc);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract(
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
<< "Either model as side packet or model path in options is required.";

WarnFeedbackTensorsUnsupported(cc);
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
return absl::OkStatus();
}
Expand Down
1 change: 1 addition & 0 deletions mediapipe/calculators/tensor/inference_calculator_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ absl::Status InferenceCalculatorMetalImpl::UpdateContract(
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
<< "Either model as side packet or model path in options is required.";

WarnFeedbackTensorsUnsupported(cc);
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
return absl::OkStatus();
}
Expand Down
7 changes: 4 additions & 3 deletions mediapipe/calculators/tensor/inference_calculator_xnnpack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,13 @@ absl::StatusOr<std::unique_ptr<InferenceRunner>>
InferenceCalculatorXnnpackImpl::CreateInferenceRunner(CalculatorContext* cc) {
MP_ASSIGN_OR_RETURN(auto model_packet, GetModelAsPacket(cc));
MP_ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const int interpreter_num_threads =
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread();
const auto& options = cc->Options<mediapipe::InferenceCalculatorOptions>();
const int interpreter_num_threads = options.cpu_num_thread();
MP_ASSIGN_OR_RETURN(TfLiteDelegatePtr delegate, CreateDelegate(cc));
return CreateInferenceInterpreterDelegateRunner(
std::move(model_packet), std::move(op_resolver_packet),
std::move(delegate), interpreter_num_threads);
std::move(delegate), interpreter_num_threads,
&options.input_output_config());
}

absl::StatusOr<TfLiteDelegatePtr>
Expand Down
191 changes: 191 additions & 0 deletions mediapipe/calculators/tensor/inference_feedback_manager.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
#include "mediapipe/calculators/tensor/inference_feedback_manager.h"

#include <cstring>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/calculators/tensor/inference_io_mapper.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/tflite/tflite_signature_reader.h"
#include "mediapipe/util/tflite/utils.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"

namespace mediapipe {

namespace {

// Returns true iff two tensors agree on element type, shape, quantization
// parameters, allocation type, and total byte size — i.e. their contents can
// be exchanged safely.
bool TfLiteTensorSpecEqual(const TfLiteTensor& a, const TfLiteTensor& b) {
  if (a.type != b.type) return false;
  if (!TfLiteIntArrayEqual(a.dims, b.dims)) return false;
  if (a.params.scale != b.params.scale) return false;
  if (a.params.zero_point != b.params.zero_point) return false;
  if (a.allocation_type != b.allocation_type) return false;
  return a.bytes == b.bytes;
}

// Builds a lookup table from tensor name to its position within `names`.
// If a name occurs more than once, the last occurrence wins.
absl::flat_hash_map<std::string, int> CreateNameToIndexMap(
    const std::vector<std::string>& names) {
  absl::flat_hash_map<std::string, int> result;
  int index = 0;
  for (const std::string& name : names) {
    result[name] = index;
    ++index;
  }
  return result;
}

} // namespace

// Initializes the manager for `interpreter`: resolves the configured feedback
// links from signature tensor names to model input/output indices, validates
// the linked tensors (unique, non-dynamic, matching specs), zero-fills the
// feedback input tensors, and records which interpreter inputs are regular
// (non-feedback) inputs. `interpreter` must outlive this manager.
absl::Status InferenceFeedbackManager::Init(
    const InferenceCalculatorOptions::InputOutputConfig& io_config,
    const InputOutputTensorNames& input_output_tensor_names,
    tflite::Interpreter* interpreter) {
  interpreter_ = interpreter;
  MP_ASSIGN_OR_RETURN(feedback_tensor_indices_links_,
                      ConvertSignatureTensorNamesToModelIndices(
                          io_config, input_output_tensor_names));

  for (const auto& link : feedback_tensor_indices_links_) {
    // Each output tensor may appear in at most one feedback link.
    const auto [output_unused_iter, output_was_inserted] =
        feedback_output_indices_.insert(link.from_idx);
    RET_CHECK(output_was_inserted) << "Feedback output tensors must be unique.";
    TfLiteTensor* from_tensor =
        interpreter_->tensor(interpreter->outputs()[link.from_idx]);
    // Dynamic tensors can be reallocated between invocations, which would
    // invalidate the swap-based feedback mechanism.
    RET_CHECK(!util::tflite::IsDynamicTensor(*from_tensor))
        << "Feedback output tensors must not be dynamic.";
    // Each input tensor may likewise appear in at most one feedback link.
    const auto [input_unused_iter, input_was_inserted] =
        feedback_input_indices_.insert(link.to_idx);
    RET_CHECK(input_was_inserted) << "Feedback input tensors must be unique.";
    TfLiteTensor* to_tensor =
        interpreter_->tensor(interpreter->inputs()[link.to_idx]);
    RET_CHECK(!util::tflite::IsDynamicTensor(*to_tensor))
        << "Feedback input tensors must not be dynamic.";
    // Both ends of a link must share type/shape/quantization/size so their
    // contents can be exchanged.
    RET_CHECK(TfLiteTensorSpecEqual(*from_tensor, *to_tensor))
        << "Feedback tensors must have the same spec.";
    // Since the TfLite API isn't specific about the initialization of newly
    // allocated Tensor memory, we initialize the input to_tensor tensor with
    // zeros.
    memset(to_tensor->data.raw, 0, to_tensor->bytes);
  }

  // Populate input_tensor_to_model_indices_ which maps InferenceRunner input
  // tensors indices to the model input indices.
  input_tensor_to_model_indices_.reserve(interpreter_->inputs().size());
  for (int i = 0; i < interpreter_->inputs().size(); ++i) {
    if (!feedback_input_indices_.contains(i)) {
      input_tensor_to_model_indices_.push_back(i);
    }
  }
  return absl::OkStatus();
}

void InferenceFeedbackManager::SwapFeedbackTensors() {
for (const auto& link : feedback_tensor_indices_links_) {
TfLiteTensor* from_tensor =
interpreter_->tensor(interpreter_->outputs()[link.from_idx]);
TfLiteTensor* to_tensor =
interpreter_->tensor(interpreter_->inputs()[link.to_idx]);
{
using std::swap;
// TODO b/338023494 - Use TfLite CustomAllocator to manage memory of
// feedback tensors (replace std::swap)
swap(*from_tensor, *to_tensor);
}
}
}

// static
//
// Translates the feedback links in `io_config` from signature tensor *names*
// into input/output *indices* of the model's single signature. Returns an
// empty list (with a warning) when the model does not expose exactly one
// signature; fails when a linked tensor is also used for regular I/O mapping
// or cannot be found among the signature's tensors.
absl::StatusOr<std::vector<InferenceFeedbackManager::TensorFeedbackIndicesLink>>
InferenceFeedbackManager::ConvertSignatureTensorNamesToModelIndices(
    const InferenceCalculatorOptions::InputOutputConfig& io_config,
    const InputOutputTensorNames& input_output_tensor_names_map) {
  std::vector<TensorFeedbackIndicesLink> indices_links;
  if (input_output_tensor_names_map.empty() ||
      input_output_tensor_names_map.size() > 1) {
    // Fail gracefully by returning an empty TensorFeedbackIndicesLink list if
    // SignatureDef is not available or not supported.
    ABSL_LOG(WARNING)
        << "Feedback manager requires a model with a single signature "
           "inference. Disabling support for feedback tensors.";
    return indices_links;
  }
  // Obtain reference to single-signature in input_output_tensor_names_map.
  const auto& input_output_tensor_names =
      input_output_tensor_names_map.begin()->second;

  const auto input_name_to_index_map =
      CreateNameToIndexMap(input_output_tensor_names.input_tensor_names);
  const auto output_name_to_index_map =
      CreateNameToIndexMap(input_output_tensor_names.output_tensor_names);

  // Create a set of all input/output tensor names used for InferenceCalculator
  // I/O mapping.
  absl::flat_hash_set<std::string> input_output_mapping_tensor_names;
  for (const auto& name : io_config.input_tensor_names_map().tensor_names()) {
    input_output_mapping_tensor_names.insert(name);
  }
  for (const auto& name : io_config.output_tensor_names_map().tensor_names()) {
    input_output_mapping_tensor_names.insert(name);
  }

  for (const auto& link : io_config.feedback_tensor_links()) {
    // Feedback tensors are consumed internally and must not double as regular
    // mapped inputs/outputs of the calculator.
    RET_CHECK(!input_output_mapping_tensor_names.contains(
        link.from_output_tensor_name()))
        << absl::StrFormat(
               "Feedback output tensor [%s] cannot be used for input/output "
               "mapping. Input/output mapping tensor names: [%s]",
               link.from_output_tensor_name(),
               absl::StrJoin(input_output_mapping_tensor_names, ", "));
    RET_CHECK(!input_output_mapping_tensor_names.contains(
        link.to_input_tensor_name()))
        << absl::StrFormat(
               "Feedback input tensor [%s] cannot be used for input/output "
               "mapping. Input/output mapping tensor names: [%s]",
               link.to_input_tensor_name(),
               absl::StrJoin(input_output_mapping_tensor_names, ", "));
    TensorFeedbackIndicesLink indices_link;
    // Resolve both endpoint names to signature tensor indices.
    auto from_it =
        output_name_to_index_map.find(link.from_output_tensor_name());
    RET_CHECK(from_it != output_name_to_index_map.end())
        << "Output tensor name not found: " << link.from_output_tensor_name();
    auto to_it = input_name_to_index_map.find(link.to_input_tensor_name());
    RET_CHECK(to_it != input_name_to_index_map.end())
        << "Input tensor name not found: " << link.to_input_tensor_name();
    indices_link.from_idx = from_it->second;
    indices_link.to_idx = to_it->second;
    indices_links.push_back(indices_link);
  }
  return indices_links;
}

// Returns whether model input index `idx` is the destination of a feedback
// link (and therefore not a regular calculator input).
bool InferenceFeedbackManager::IsFeedbackInputTensorAtIndex(int idx) const {
  return feedback_input_indices_.find(idx) != feedback_input_indices_.end();
}

// Returns whether model output index `idx` is the source of a feedback link
// (and therefore not a regular calculator output).
bool InferenceFeedbackManager::IsFeedbackOutputTensorAtIndex(int idx) const {
  return feedback_output_indices_.find(idx) != feedback_output_indices_.end();
}

// Maps an InferenceRunner input tensor index (which counts only non-feedback
// inputs) to the corresponding model input index.
//
// Returns an error status when `input_idx` is out of range.
absl::StatusOr<int> InferenceFeedbackManager::MapInputTensorToModelIndex(
    int input_idx) const {
  // The upper bound must be exclusive: input_idx == size() would read one
  // past the end of input_tensor_to_model_indices_ (the original `<=` check
  // permitted exactly that out-of-bounds access).
  RET_CHECK(input_idx >= 0 &&
            input_idx < static_cast<int>(input_tensor_to_model_indices_.size()))
      << "Invalid input tensor index: " << input_idx;
  return input_tensor_to_model_indices_[input_idx];
}

// Number of model inputs that are regular (non-feedback) calculator inputs.
int InferenceFeedbackManager::GetNumberOfNonFeedbackInputTensors() const {
  return static_cast<int>(input_tensor_to_model_indices_.size());
}

// Number of configured feedback (output -> input) tensor links.
int InferenceFeedbackManager::GetNumberOfFeedbackTensors() const {
  return static_cast<int>(feedback_tensor_indices_links_.size());
}
} // namespace mediapipe
Loading

0 comments on commit 99fc736

Please sign in to comment.