From 6e96542d51491cf6daf32e76e857f3f3c9d27cb4 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 24 Sep 2024 01:29:40 -0700
Subject: [PATCH] Add a flag to force CPU inference in the inference
 calculator

PiperOrigin-RevId: 678131539
---
 mediapipe/calculators/tensor/BUILD                   | 12 ++++++++++++
 mediapipe/calculators/tensor/inference_calculator.cc |  5 ++++-
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD
index 8b91ed990a..6dfb8b3785 100644
--- a/mediapipe/calculators/tensor/BUILD
+++ b/mediapipe/calculators/tensor/BUILD
@@ -408,6 +408,14 @@ mediapipe_proto_library(
     ],
 )
 
+config_setting(
+    name = "force_cpu_inference",
+    values = {
+        "define": "MEDIAPIPE_FORCE_CPU_INFERENCE=true",
+    },
+    visibility = ["//visibility:public"],
+)
+
 # This target defines the "InferenceCalculator" component, which looks for the available concrete
 # implementations linked into the current binary and picks the one to use.
 # You can depend on :inference_calculator instead if you want to automatically include a default
@@ -420,6 +428,10 @@ cc_library_with_tflite(
     name = "inference_calculator_interface",
     srcs = ["inference_calculator.cc"],
    hdrs = ["inference_calculator.h"],
+    local_defines = select({
+        ":force_cpu_inference": ["MEDIAPIPE_FORCE_CPU_INFERENCE=1"],
+        "//conditions:default": [],
+    }),
     tflite_deps = [
         ":inference_runner",
         ":inference_io_mapper",
diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc
index 330c03ca5e..5f20775378 100644
--- a/mediapipe/calculators/tensor/inference_calculator.cc
+++ b/mediapipe/calculators/tensor/inference_calculator.cc
@@ -50,6 +50,8 @@ class InferenceCalculatorSelectorImpl
             subgraph_node);
 
     std::vector<absl::string_view> impls;
+#if !defined(MEDIAPIPE_FORCE_CPU_INFERENCE) || !MEDIAPIPE_FORCE_CPU_INFERENCE
+
     const bool should_use_gpu =
         !options.has_delegate() ||  // Use GPU delegate if not specified
         (options.has_delegate() && options.delegate().has_gpu());
@@ -59,7 +61,6 @@ class InferenceCalculatorSelectorImpl
 #if MEDIAPIPE_METAL_ENABLED
       impls.emplace_back("Metal");
 #endif
-
       const bool prefer_gl_advanced =
           options.delegate().gpu().use_advanced_gpu_api() &&
           (api == Gpu::ANY || api == Gpu::OPENGL || api == Gpu::OPENCL);
@@ -71,6 +72,8 @@ class InferenceCalculatorSelectorImpl
         impls.emplace_back("GlAdvanced");
       }
     }
+#endif  // !defined(MEDIAPIPE_FORCE_CPU_INFERENCE) ||
+        // !MEDIAPIPE_FORCE_CPU_INFERENCE
     impls.emplace_back("Cpu");
     impls.emplace_back("Xnnpack");
     for (const auto& suffix : impls) {
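
Usage sketch: the config_setting above keys off Bazel's --define flag, so
forcing CPU inference should look roughly like the following, where
//your/app:target is a hypothetical placeholder for a binary that links in
the inference calculator:

  # Compiles out GPU/Metal selection; the selector considers only Cpu/Xnnpack.
  bazel build --define=MEDIAPIPE_FORCE_CPU_INFERENCE=true //your/app:target

With the define set, local_defines adds MEDIAPIPE_FORCE_CPU_INFERENCE=1 to
the compilation of inference_calculator.cc, and the #if guard removes the
GPU branch of the selector, leaving "Cpu" and "Xnnpack" as the only candidate
implementations. Using local_defines rather than defines scopes the macro to
this target instead of propagating it to everything that depends on it.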