Add a flag to force CPU inference in the inference calculator
PiperOrigin-RevId: 678131539
MediaPipe Team authored and copybara-github committed Sep 24, 2024
1 parent 5508d9b commit 6e96542
Showing 2 changed files with 16 additions and 1 deletion.
mediapipe/calculators/tensor/BUILD (12 additions, 0 deletions)
@@ -408,6 +408,14 @@ mediapipe_proto_library(
     ],
 )
 
+config_setting(
+    name = "force_cpu_inference",
+    values = {
+        "define": "MEDIAPIPE_FORCE_CPU_INFERENCE=true",
+    },
+    visibility = ["//visibility:public"],
+)
+
 # This target defines the "InferenceCalculator" component, which looks for the available concrete
 # implementations linked into the current binary and picks the one to use.
 # You can depend on :inference_calculator instead if you want to automatically include a default
@@ -420,6 +428,10 @@ cc_library_with_tflite(
     name = "inference_calculator_interface",
     srcs = ["inference_calculator.cc"],
     hdrs = ["inference_calculator.h"],
+    local_defines = select({
+        ":force_cpu_inference": ["MEDIAPIPE_FORCE_CPU_INFERENCE=1"],
+        "//conditions:default": [],
+    }),
     tflite_deps = [
         ":inference_runner",
         ":inference_io_mapper",
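
Usage note (not part of the commit): because config_setting "force_cpu_inference" matches the Bazel define MEDIAPIPE_FORCE_CPU_INFERENCE=true, the flag would typically be enabled at build time with `bazel build --define MEDIAPIPE_FORCE_CPU_INFERENCE=true <target>`. And since the setting is publicly visible, a downstream BUILD file could branch on it with the same select() pattern the diff uses. A minimal sketch, in which the target, source, and define names are invented for illustration:

# Hypothetical downstream rule (not from this commit): reuse the public
# config_setting to compile a client CPU-only under the same build flag.
cc_library(
    name = "my_inference_client",  # assumed example target
    srcs = ["my_inference_client.cc"],
    local_defines = select({
        # Matches builds with --define MEDIAPIPE_FORCE_CPU_INFERENCE=true.
        "//mediapipe/calculators/tensor:force_cpu_inference": [
            "MY_CLIENT_CPU_ONLY=1",  # assumed example define
        ],
        "//conditions:default": [],
    }),
)
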
mediapipe/calculators/tensor/inference_calculator.cc (4 additions, 1 deletion)
@@ -50,6 +50,8 @@ class InferenceCalculatorSelectorImpl
             subgraph_node);
     std::vector<absl::string_view> impls;
 
+#if !defined(MEDIAPIPE_FORCE_CPU_INFERENCE) || !MEDIAPIPE_FORCE_CPU_INFERENCE
+
     const bool should_use_gpu =
         !options.has_delegate() ||  // Use GPU delegate if not specified
         (options.has_delegate() && options.delegate().has_gpu());
@@ -59,7 +61,6 @@ class InferenceCalculatorSelectorImpl
 #if MEDIAPIPE_METAL_ENABLED
       impls.emplace_back("Metal");
 #endif
-
       const bool prefer_gl_advanced =
           options.delegate().gpu().use_advanced_gpu_api() &&
           (api == Gpu::ANY || api == Gpu::OPENGL || api == Gpu::OPENCL);
@@ -71,6 +72,8 @@ class InferenceCalculatorSelectorImpl
         impls.emplace_back("GlAdvanced");
       }
     }
+#endif  // !defined(MEDIAPIPE_FORCE_CPU_INFERENCE) ||
+        // !MEDIAPIPE_FORCE_CPU_INFERENCE
     impls.emplace_back("Cpu");
     impls.emplace_back("Xnnpack");
     for (const auto& suffix : impls) {
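
Usage note (not part of the commit): the guard `#if !defined(MEDIAPIPE_FORCE_CPU_INFERENCE) || !MEDIAPIPE_FORCE_CPU_INFERENCE` compiles the GPU selection out only when the macro is both defined and nonzero, so a normal build (macro undefined) or an explicit MEDIAPIPE_FORCE_CPU_INFERENCE=0 still considers GPU backends, while the Bazel define above reduces the candidate list to "Cpu" and "Xnnpack". A standalone sketch of that pattern, simplified from the real selector for illustration:

// Illustrative only: mimics the guard added in this commit; the real code
// builds a list of InferenceCalculator implementation suffixes.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> PickImpls() {
  std::vector<std::string> impls;
// Skipped only when the macro is defined AND truthy, i.e. when built with
// --define MEDIAPIPE_FORCE_CPU_INFERENCE=true (mapped to the =1 local define).
#if !defined(MEDIAPIPE_FORCE_CPU_INFERENCE) || !MEDIAPIPE_FORCE_CPU_INFERENCE
  impls.push_back("Metal");  // stand-ins for the GPU branches
  impls.push_back("GlAdvanced");
#endif
  impls.push_back("Cpu");
  impls.push_back("Xnnpack");
  return impls;
}

int main() {
  for (const std::string& impl : PickImpls()) std::cout << impl << "\n";
  return 0;
}
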
