diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD
index dc5aca48a1..e5a3f65faa 100644
--- a/mediapipe/tasks/cc/components/processors/BUILD
+++ b/mediapipe/tasks/cc/components/processors/BUILD
@@ -169,6 +169,7 @@ cc_library(
     deps = [
         "//mediapipe/calculators/core:split_vector_calculator",
         "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
+        "//mediapipe/calculators/tensor:tensors_dequantization_calculator",
         "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
         "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
         "//mediapipe/calculators/tflite:ssd_anchors_calculator",
diff --git a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc
index fe24070a2f..a320e2a266 100644
--- a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc
+++ b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc
@@ -710,6 +710,57 @@ absl::StatusOr<Source<std::vector<Tensor>>> CalibrateScores(
   return model_output_tensors;
 }
 
+// Identifies whether or not the model has quantized outputs, and performs
+// sanity checks.
+absl::StatusOr<bool> HasQuantizedOutputs(
+    const core::ModelResources& model_resources) {
+  const tflite::Model& model = *model_resources.GetTfLiteModel();
+  // Model is checked to have single subgraph before.
+  const auto* primary_subgraph = (*model.subgraphs())[0];
+  int num_output_tensors = primary_subgraph->outputs()->size();
+  // Sanity check tensor types and check if model outputs are quantized or not.
+  int num_quantized_tensors = 0;
+  for (int i = 0; i < num_output_tensors; ++i) {
+    const auto* tensor =
+        primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
+    if (tensor->type() != tflite::TensorType_FLOAT32 &&
+        tensor->type() != tflite::TensorType_UINT8) {
+      return CreateStatusWithPayload(
+          absl::StatusCode::kInvalidArgument,
+          absl::StrFormat("Expected output tensor at index %d to have type "
+                          "UINT8 or FLOAT32, found %s instead.",
+                          i, tflite::EnumNameTensorType(tensor->type())),
+          MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
+    }
+    if (tensor->type() == tflite::TensorType_UINT8) {
+      num_quantized_tensors++;
+    }
+  }
+  if (num_quantized_tensors != num_output_tensors &&
+      num_quantized_tensors != 0) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat(
+            "Expected either all or none of the output tensors to be "
+            "quantized, but found %d quantized outputs for %d total outputs.",
+            num_quantized_tensors, num_output_tensors),
+        MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
+  }
+  // Check if metadata is consistent with model topology.
+  const auto* output_tensors_metadata =
+      model_resources.GetMetadataExtractor()->GetOutputTensorMetadata();
+  if (output_tensors_metadata != nullptr &&
+      num_output_tensors != output_tensors_metadata->size()) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat("Mismatch between number of output tensors (%d) and "
+                        "output tensors metadata (%d).",
+                        num_output_tensors, output_tensors_metadata->size()),
+        MediaPipeTasksStatus::kMetadataInconsistencyError);
+  }
+  return num_quantized_tensors > 0;
+}
+
 }  // namespace
 
 absl::Status ConfigureDetectionPostprocessingGraph(
@@ -738,7 +789,9 @@ absl::Status ConfigureDetectionPostprocessingGraph(
                         model.subgraphs()->Get(0)->outputs()->size()),
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
-
+  MP_ASSIGN_OR_RETURN(bool has_quantized_outputs,
+                      HasQuantizedOutputs(model_resources));
+  options.set_has_quantized_outputs(has_quantized_outputs);
   const ModelMetadataExtractor* metadata_extractor =
       model_resources.GetMetadataExtractor();
   if (in_model_nms) {
@@ -820,12 +873,20 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
   BuildDetectionPostprocessing(
       proto::DetectionPostprocessingGraphOptions& graph_options,
       Source<std::vector<Tensor>> tensors_in, Graph& graph) {
+    Source<std::vector<Tensor>> tensors = tensors_in;
+    if (graph_options.has_quantized_outputs()) {
+      auto& tensors_dequantization_node =
+          graph.AddNode("TensorsDequantizationCalculator");
+      tensors_in >> tensors_dequantization_node.In(kTensorsTag);
+      tensors = tensors_dequantization_node.Out(kTensorsTag)
+                    .Cast<std::vector<Tensor>>();
+    }
     std::optional<Source<std::vector<Detection>>> detections;
     if (!graph_options.has_non_max_suppression_options()) {
       // Calculators to perform score calibration, if specified in the options.
       if (graph_options.has_score_calibration_options()) {
-        MP_ASSIGN_OR_RETURN(tensors_in,
-                            CalibrateScores(tensors_in, graph_options, graph));
+        MP_ASSIGN_OR_RETURN(tensors,
+                            CalibrateScores(tensors, graph_options, graph));
       }
       // Calculator to convert output tensors to a detection proto vector.
       auto& tensors_to_detections =
@@ -833,7 +894,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
       tensors_to_detections
           .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
           .Swap(graph_options.mutable_tensors_to_detections_options());
-      tensors_in >> tensors_to_detections.In(kTensorsTag);
+      tensors >> tensors_to_detections.In(kTensorsTag);
       detections = tensors_to_detections.Out(kDetectionsTag)
                        .Cast<std::vector<Detection>>();
     } else {
@@ -850,7 +911,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
           .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
           .Swap(graph_options.mutable_tensors_to_detections_options());
       anchors >> tensors_to_detections.SideIn(kAnchorsTag);
-      tensors_in >> tensors_to_detections.In(kTensorsTag);
+      tensors >> tensors_to_detections.In(kTensorsTag);
       detections = tensors_to_detections.Out(kDetectionsTag)
                        .Cast<std::vector<Detection>>();
       // Non maximum suppression removes redundant object detections.
diff --git a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc
index 5475182b73..cadf3de30d 100644
--- a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc
+++ b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc
@@ -213,6 +213,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
         }
         box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
       }
+      has_quantized_outputs: false
     )pb"))));
   EXPECT_THAT(
       options_out.detection_label_ids_to_text_options().label_items_size(),
       90);
@@ -244,6 +245,7 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
        }
        box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
      }
+     has_quantized_outputs: false
   )pb")));
 }
 
@@ -273,6 +275,7 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
        }
        box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
      }
+     has_quantized_outputs: false
   )pb")));
 }
 
@@ -311,6 +314,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
        score_transformation: IDENTITY
        default_score: 0.5
      }
+     has_quantized_outputs: false
   )pb")));
 }
 
diff --git a/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto
index ec11df2b47..ce0edd1602 100644
--- a/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto
+++ b/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto
@@ -46,4 +46,7 @@ message DetectionPostprocessingGraphOptions {
   // Optional detection label id to text calculator options.
   optional mediapipe.DetectionLabelIdToTextCalculatorOptions
       detection_label_ids_to_text_options = 5;
+
+  // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
+  optional bool has_quantized_outputs = 6;
 }
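
Usage note (illustrative only, not part of the patch): the sketch below shows how a calling task graph would see the new behavior once its postprocessing options have been filled in via ConfigureDetectionPostprocessingGraph, which now also sets has_quantized_outputs from the model's output tensor types. It assumes the usual mediapipe::api2::builder setup; `graph`, `inference`, and `detections` are placeholder names, and the registered subgraph name follows the existing MediaPipe tasks naming convention.

// Hedged sketch of downstream wiring; options are assumed to be configured
// beforehand with ConfigureDetectionPostprocessingGraph(model_resources, ...).
auto& postprocessing = graph.AddNode(
    "mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
// For a model with UINT8 outputs, has_quantized_outputs is now true, so the
// subgraph inserts TensorsDequantizationCalculator ahead of
// TensorsToDetectionsCalculator; float models are wired exactly as before.
inference.Out("TENSORS") >> postprocessing.In("TENSORS");
auto detections =
    postprocessing.Out("DETECTIONS").Cast<std::vector<mediapipe::Detection>>();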