diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD
index ce84f4b18c..c336c13d98 100644
--- a/mediapipe/calculators/tensor/BUILD
+++ b/mediapipe/calculators/tensor/BUILD
@@ -633,6 +633,7 @@ cc_library(
         ":inference_io_mapper",
         ":inference_on_disk_cache_helper",
         ":tensor_span",
+        "//mediapipe/calculators/tensor:inference_runner",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:mediapipe_profiling",
         "//mediapipe/framework/api2:packet",
diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
index 8661e742c4..07b23af0c1 100644
--- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
+++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
@@ -25,6 +25,7 @@
 #include "mediapipe/calculators/tensor/inference_calculator.h"
 #include "mediapipe/calculators/tensor/inference_io_mapper.h"
 #include "mediapipe/calculators/tensor/inference_on_disk_cache_helper.h"
+#include "mediapipe/calculators/tensor/inference_runner.h"
 #include "mediapipe/calculators/tensor/tensor_span.h"
 #include "mediapipe/framework/api2/packet.h"
 #include "mediapipe/framework/calculator_framework.h"
@@ -63,17 +64,17 @@ class InferenceCalculatorGlAdvancedImpl
 
  private:
   // Helper class that wraps everything related to GPU inference acceleration.
-  class GpuInferenceRunner {
+  class GpuInferenceRunner : public InferenceRunner {
    public:
     ~GpuInferenceRunner();
 
     absl::Status Init(CalculatorContext* cc,
                       std::shared_ptr<GlContext> gl_context);
 
-    absl::StatusOr<std::vector<Tensor>> Process(
-        CalculatorContext* cc, const TensorSpan& input_tensors);
+    absl::StatusOr<std::vector<Tensor>> Run(
+        CalculatorContext* cc, const TensorSpan& input_tensors) override;
 
-    const InputOutputTensorNames& GetInputOutputTensorNames() const;
+    const InputOutputTensorNames& GetInputOutputTensorNames() const override;
 
    private:
     absl::Status InitTFLiteGPURunner(
@@ -99,7 +100,7 @@ class InferenceCalculatorGlAdvancedImpl
   absl::StatusOr<std::unique_ptr<GpuInferenceRunner>> CreateInferenceRunner(
       CalculatorContext* cc);
 
-  std::unique_ptr<GpuInferenceRunner> gpu_inference_runner_;
+  std::unique_ptr<InferenceRunner> inference_runner_;
 
   mediapipe::GlCalculatorHelper gpu_helper_;
 };
@@ -141,7 +142,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(
 }
 
 absl::StatusOr<std::vector<Tensor>>
-InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
+InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Run(
     CalculatorContext* cc, const TensorSpan& input_tensors) {
   std::vector<Tensor> output_tensors;
   for (int i = 0; i < input_tensors.size(); ++i) {
@@ -267,11 +268,10 @@ absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract(
 
 absl::Status InferenceCalculatorGlAdvancedImpl::Open(CalculatorContext* cc) {
   MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
-  gpu_inference_runner_ = std::make_unique<GpuInferenceRunner>();
-  MP_RETURN_IF_ERROR(
-      gpu_inference_runner_->Init(cc, gpu_helper_.GetSharedGlContext()));
+
+  MP_ASSIGN_OR_RETURN(inference_runner_, CreateInferenceRunner(cc));
   return InferenceCalculatorNodeImpl::UpdateIoMapping(
-      cc, gpu_inference_runner_->GetInputOutputTensorNames());
+      cc, inference_runner_->GetInputOutputTensorNames());
 }
 
 absl::StatusOr<std::vector<Tensor>> InferenceCalculatorGlAdvancedImpl::Process(
@@ -279,14 +279,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceCalculatorGlAdvancedImpl::Process(
   std::vector<Tensor> output_tensors;
   MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([&]() -> absl::Status {
     MP_ASSIGN_OR_RETURN(output_tensors,
-                        gpu_inference_runner_->Process(cc, tensor_span));
+                        inference_runner_->Run(cc, tensor_span));
     return absl::OkStatus();
   }));
   return output_tensors;
 }
 
 absl::Status InferenceCalculatorGlAdvancedImpl::Close(CalculatorContext* cc) {
-  gpu_inference_runner_.reset();
+  inference_runner_.reset();
   return absl::OkStatus();
 }
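
For context, this change only compiles if inference_runner.h declares virtual methods matching the signatures GpuInferenceRunner now overrides. Below is a minimal sketch of what that interface presumably looks like; it is inferred from the Run() and GetInputOutputTensorNames() overrides in this diff, not copied from the actual header, so details may differ.

    // Assumed shape of the InferenceRunner interface, inferred from the
    // overrides added in this diff; the real header may differ.
    class InferenceRunner {
     public:
      virtual ~InferenceRunner() = default;

      // Runs inference on the input tensors and returns the output tensors.
      virtual absl::StatusOr<std::vector<Tensor>> Run(
          CalculatorContext* cc, const TensorSpan& input_tensors) = 0;

      // Exposes the model's tensor names for input/output stream mapping.
      virtual const InputOutputTensorNames& GetInputOutputTensorNames() const = 0;
    };

With such a base class in place, Open() can hold the runner behind std::unique_ptr<InferenceRunner>, so Process() and Close() no longer depend on the concrete GPU implementation.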