diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc
index 1a99cb88c3..ab9cf213f0 100644
--- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc
+++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <functional>
 #include <...>
 #include <...>
 #include <...>
@@ -31,7 +32,9 @@
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/opencv_imgproc_inc.h"
+#include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
+#include "mediapipe/framework/port/status_macros.h"
 #include "mediapipe/gpu/gpu_origin.pb.h"
 #include "mediapipe/gpu/gpu_service.h"
 #include "mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.pb.h"
@@ -100,18 +103,30 @@ bool CanUseGpu() {
 // OUTPUT - mediapipe::Image.
 class TensorsToImageCalculator : public Node {
  public:
-  static constexpr Input<std::vector<Tensor>> kInputTensors{"TENSORS"};
+  static constexpr Input<std::vector<Tensor>>::Optional kInputTensors{
+      "TENSORS"};
+  static constexpr Input<Tensor>::Optional kInputTensor{"TENSOR"};
   static constexpr Output<Image> kOutputImage{"IMAGE"};
-  MEDIAPIPE_NODE_CONTRACT(kInputTensors, kOutputImage);
+  MEDIAPIPE_NODE_CONTRACT(kInputTensors, kInputTensor, kOutputImage);
 
   static absl::Status UpdateContract(CalculatorContract* cc);
 
-  absl::Status Open(CalculatorContext* cc);
-  absl::Status Process(CalculatorContext* cc);
-  absl::Status Close(CalculatorContext* cc);
+  absl::Status Open(CalculatorContext* cc) override;
+  absl::Status Process(CalculatorContext* cc) override;
+  absl::Status Close(CalculatorContext* cc) override;
 
  private:
   TensorsToImageCalculatorOptions options_;
+
+  // Returns true if kInputTensor or kInputTensors (whichever is connected) is
+  // empty.
+  bool IsInputTensorEmpty(CalculatorContext* cc);
+
+  // Retrieves the input tensor from kInputTensor or kInputTensors, whichever
+  // is connected. Returns an error if the requested tensor is unavailable.
+  absl::StatusOr<std::reference_wrapper<const Tensor>> GetInputTensor(
+      CalculatorContext* cc);
+
   absl::Status CpuProcess(CalculatorContext* cc);
 
   int tensor_position_;
@@ -142,6 +157,8 @@ class TensorsToImageCalculator : public Node {
 MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::TensorsToImageCalculator);
 
 absl::Status TensorsToImageCalculator::UpdateContract(CalculatorContract* cc) {
+  RET_CHECK(kInputTensors(cc).IsConnected() ^ kInputTensor(cc).IsConnected())
+      << "Either TENSORS or TENSOR must be specified";
 #if !MEDIAPIPE_DISABLE_GPU
 #if MEDIAPIPE_METAL_ENABLED
   MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
@@ -156,12 +173,14 @@ absl::Status TensorsToImageCalculator::UpdateContract(CalculatorContract* cc) {
 absl::Status TensorsToImageCalculator::Open(CalculatorContext* cc) {
   options_ = cc->Options<TensorsToImageCalculatorOptions>();
   if (!CanUseGpu()) {
-    ABSL_CHECK(options_.has_input_tensor_float_range() ^
-               options_.has_input_tensor_uint_range())
+    RET_CHECK(options_.has_input_tensor_float_range() ^
+              options_.has_input_tensor_uint_range())
         << "Must specify either `input_tensor_float_range` or "
            "`input_tensor_uint_range` in the calculator options";
   }
   tensor_position_ = options_.tensor_position();
+  RET_CHECK(!kInputTensor(cc).IsConnected() || tensor_position_ == 0)
+      << "The tensor_position option cannot be used with the TENSOR input";
   return absl::OkStatus();
 }
 
@@ -197,16 +216,30 @@ absl::Status TensorsToImageCalculator::Close(CalculatorContext* cc) {
   return absl::OkStatus();
 }
 
-absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) {
-  if (kInputTensors(cc).IsEmpty()) {
-    return absl::OkStatus();
+bool TensorsToImageCalculator::IsInputTensorEmpty(CalculatorContext* cc) {
+  if (kInputTensor(cc).IsConnected()) {
+    return kInputTensor(cc).IsEmpty();
   }
-  const auto& input_tensors = kInputTensors(cc).Get();
+  return kInputTensors(cc).IsEmpty();
+}
+
+absl::StatusOr<std::reference_wrapper<const Tensor>>
+TensorsToImageCalculator::GetInputTensor(CalculatorContext* cc) {
+  if (kInputTensor(cc).IsConnected()) {
+    return kInputTensor(cc).Get();
+  }
+
+  const std::vector<Tensor>& input_tensors = kInputTensors(cc).Get();
   RET_CHECK_GT(input_tensors.size(), tensor_position_)
       << "Expect input tensor at position " << tensor_position_
       << ", but have tensors of size " << input_tensors.size();
+  return input_tensors[tensor_position_];
+}
+
+absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) {
+  if (IsInputTensorEmpty(cc)) return absl::OkStatus();
+  MP_ASSIGN_OR_RETURN(const Tensor& input_tensor, GetInputTensor(cc));
 
-  const auto& input_tensor = input_tensors[tensor_position_];
   const int tensor_in_height = input_tensor.shape().dims[1];
   const int tensor_in_width = input_tensor.shape().dims[2];
   const int tensor_in_channels = input_tensor.shape().dims[3];
@@ -268,16 +301,12 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) {
     metal_initialized_ = true;
   }
 
-  if (kInputTensors(cc).IsEmpty()) {
-    return absl::OkStatus();
-  }
-  const auto& input_tensors = kInputTensors(cc).Get();
-  RET_CHECK_GT(input_tensors.size(), tensor_position_)
-      << "Expect input tensor at position " << tensor_position_
-      << ", but have tensors of size " << input_tensors.size();
-  const int tensor_width = input_tensors[tensor_position_].shape().dims[2];
-  const int tensor_height = input_tensors[tensor_position_].shape().dims[1];
-  const int tensor_channels = input_tensors[tensor_position_].shape().dims[3];
+  if (IsInputTensorEmpty(cc)) return absl::OkStatus();
+  MP_ASSIGN_OR_RETURN(const Tensor& input_tensor, GetInputTensor(cc));
+
+  const int tensor_width = input_tensor.shape().dims[2];
+  const int tensor_height = input_tensor.shape().dims[1];
+  const int tensor_channels = input_tensor.shape().dims[3];
   // TODO: Add 1 channel support.
   RET_CHECK(tensor_channels == 3);
 
@@ -289,8 +318,8 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) {
       [command_buffer computeCommandEncoder];
   [compute_encoder setComputePipelineState:to_buffer_program_];
 
-  auto input_view = mediapipe::MtlBufferView::GetReadView(
-      input_tensors[tensor_position_], command_buffer);
+  auto input_view =
+      mediapipe::MtlBufferView::GetReadView(input_tensor, command_buffer);
   [compute_encoder setBuffer:input_view.buffer() offset:0 atIndex:0];
 
   mediapipe::GpuBuffer output =
@@ -460,15 +489,9 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) {
     gl_initialized_ = true;
  }
 
-  if (kInputTensors(cc).IsEmpty()) {
-    return absl::OkStatus();
-  }
-  const auto& input_tensors = kInputTensors(cc).Get();
-  RET_CHECK_GT(input_tensors.size(), tensor_position_)
-      << "Expect input tensor at position " << tensor_position_
-      << ", but have tensors of size " << input_tensors.size();
+  if (IsInputTensorEmpty(cc)) return absl::OkStatus();
+  MP_ASSIGN_OR_RETURN(const Tensor& input_tensor, GetInputTensor(cc));
 
-  const auto& input_tensor = input_tensors[tensor_position_];
   const int tensor_width = input_tensor.shape().dims[2];
   const int tensor_height = input_tensor.shape().dims[1];
   const int tensor_in_channels = input_tensor.shape().dims[3];
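For context, a minimal, untested sketch of how a graph node could drive the calculator through the new single-tensor TENSOR input. The stream names ("stylized_tensor", "stylized_image") and the helper MakeTensorsToImageNode are illustrative placeholders, and calculator options such as input_tensor_float_range (required on the CPU path) are omitted for brevity:

#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Builds a node config that feeds one Tensor into the calculator via the new
// TENSOR tag; stream names below are placeholders, not from an existing graph.
mediapipe::CalculatorGraphConfig::Node MakeTensorsToImageNode() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
      R"pb(
        calculator: "mediapipe.tasks.TensorsToImageCalculator"
        input_stream: "TENSOR:stylized_tensor"
        output_stream: "IMAGE:stylized_image"
      )pb");
}

Graphs that already feed the TENSORS stream keep working unchanged: the updated contract only requires that exactly one of TENSORS or TENSOR is connected, and that tensor_position stays at 0 when TENSOR is used.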