Skip to content

Commit

Permalink
Add tests for C API containers
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569526282
  • Loading branch information
schmidt-sebastian authored and copybara-github committed Sep 29, 2023
1 parent d4561fb commit 6915a79
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 8 deletions.
24 changes: 24 additions & 0 deletions mediapipe/tasks/c/components/containers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ cc_library(
],
)

cc_test(
name = "category_converter_test",
srcs = ["category_converter_test.cc"],
deps = [
":category",
":category_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/containers:category",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "classification_result",
hdrs = ["classification_result.h"],
Expand All @@ -47,3 +59,15 @@ cc_library(
"//mediapipe/tasks/cc/components/containers:classification_result",
],
)

cc_test(
name = "classification_result_converter_test",
srcs = ["classification_result_converter_test.cc"],
deps = [
":classification_result",
":classification_result_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/containers:classification_result",
"@com_google_googletest//:gtest_main",
],
)
4 changes: 2 additions & 2 deletions mediapipe/tasks/c/components/containers/category.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ struct Category {

// The optional ID for the category, read from the label map packed in the
// TFLite Model Metadata if present. Not necessarily human-readable.
const char* category_name;
char* category_name;

// The optional human-readable name for the category, read from the label map
// packed in the TFLite Model Metadata if present.
const char* display_name;
char* display_name;
};

#ifdef __cplusplus
Expand Down
63 changes: 63 additions & 0 deletions mediapipe/tasks/c/components/containers/category_converter_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/components/containers/category_converter.h"

#include <cstdlib>
#include <optional>
#include <string>

#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/category.h"

namespace mediapipe::tasks::c::components::containers {

TEST(CategoryConverterTest, ConvertsCategoryCustomValues) {
mediapipe::tasks::components::containers::Category cpp_category = {
/* index= */ 1,
/* score= */ 0.1,
/* category_name= */ "category_name",
/* display_name= */ "display_name",
};

Category c_category;
CppConvertToCategory(cpp_category, &c_category);
EXPECT_EQ(c_category.index, 1);
EXPECT_FLOAT_EQ(c_category.score, 0.1);
EXPECT_EQ(std::string{c_category.category_name}, "category_name");
EXPECT_EQ(std::string{c_category.display_name}, "display_name");

free(c_category.category_name);
free(c_category.display_name);
}

TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
mediapipe::tasks::components::containers::Category cpp_category = {
/* index= */ 1,
/* score= */ 0.1,
/* category_name= */ std::nullopt,
/* display_name= */ std::nullopt,
};

Category c_category;
CppConvertToCategory(cpp_category, &c_category);
EXPECT_EQ(c_category.index, 1);
EXPECT_FLOAT_EQ(c_category.score, 0.1);
EXPECT_EQ(c_category.category_name, nullptr);
EXPECT_EQ(c_category.display_name, nullptr);
}

} // namespace mediapipe::tasks::c::components::containers
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct Classifications {
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
const char* head_name;
char* head_name;
};

// Defines classification results of a model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/components/containers/category_converter.h"
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"

namespace mediapipe::tasks::c::components::containers {
Expand All @@ -27,20 +28,22 @@ void CppConvertToClassificationResult(
const mediapipe::tasks::components::containers::ClassificationResult& in,
ClassificationResult* out) {
out->has_timestamp_ms = in.timestamp_ms.has_value();
if (out->has_timestamp_ms) {
out->timestamp_ms = in.timestamp_ms.value();
}
out->timestamp_ms = out->has_timestamp_ms ? in.timestamp_ms.value() : 0;

out->classifications_count = in.classifications.size();
out->classifications = new Classifications[out->classifications_count];
out->classifications = out->classifications_count
? new Classifications[out->classifications_count]
: nullptr;

for (uint32_t i = 0; i < out->classifications_count; ++i) {
auto classification_in = in.classifications[i];
auto& classification_out = out->classifications[i];

classification_out.categories_count = classification_in.categories.size();
classification_out.categories =
new Category[classification_out.categories_count];
classification_out.categories_count
? new Category[classification_out.categories_count]
: nullptr;
for (uint32_t j = 0; j < classification_out.categories_count; ++j) {
CppConvertToCategory(classification_in.categories[j],
&(classification_out.categories[j]));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/components/containers/classification_result_converter.h"

#include <cstdlib>
#include <optional>
#include <string>

#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"

namespace mediapipe::tasks::c::components::containers {

TEST(ClassificationResultConverterTest,
ConvertsClassificationResulCustomCategory) {
mediapipe::tasks::components::containers::ClassificationResult
cpp_classification_result = {
/* classifications= */ {{/* categories= */ {{
/* index= */ 1,
/* score= */ 0.1,
/* category_name= */ std::nullopt,
/* display_name= */ std::nullopt,
}},
/* head_index= */ 0,
/* head_name= */ "foo"}},
/* timestamp_ms= */ 42,
};

ClassificationResult c_classification_result;
CppConvertToClassificationResult(cpp_classification_result,
&c_classification_result);
EXPECT_NE(c_classification_result.classifications, nullptr);
EXPECT_EQ(c_classification_result.classifications_count, 1);
EXPECT_NE(c_classification_result.classifications[0].categories, nullptr);
EXPECT_EQ(c_classification_result.classifications[0].categories_count, 1);
EXPECT_EQ(c_classification_result.classifications[0].head_index, 0);
EXPECT_EQ(std::string(c_classification_result.classifications[0].head_name),
"foo");
EXPECT_EQ(c_classification_result.timestamp_ms, 42);
EXPECT_EQ(c_classification_result.has_timestamp_ms, true);

free(c_classification_result.classifications[0].categories);
free(c_classification_result.classifications[0].head_name);
free(c_classification_result.classifications);
}

TEST(ClassificationResultConverterTest,
ConvertsClassificationResulEmptyCategory) {
mediapipe::tasks::components::containers::ClassificationResult
cpp_classification_result = {
/* classifications= */ {{/* categories= */ {}, /* head_index= */ 0,
/* head_name= */ std::nullopt}},
/* timestamp_ms= */ std::nullopt,
};

ClassificationResult c_classification_result;
CppConvertToClassificationResult(cpp_classification_result,
&c_classification_result);
EXPECT_NE(c_classification_result.classifications, nullptr);
EXPECT_EQ(c_classification_result.classifications_count, 1);
EXPECT_EQ(c_classification_result.classifications[0].categories, nullptr);
EXPECT_EQ(c_classification_result.classifications[0].categories_count, 0);
EXPECT_EQ(c_classification_result.classifications[0].head_index, 0);
EXPECT_EQ(c_classification_result.classifications[0].head_name, nullptr);
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);

free(c_classification_result.classifications);
}

TEST(ClassificationResultConverterTest,
ConvertsClassificationResultNoCategory) {
mediapipe::tasks::components::containers::ClassificationResult
cpp_classification_result = {
/* classifications= */ {},
/* timestamp_ms= */ std::nullopt,
};

ClassificationResult c_classification_result;
CppConvertToClassificationResult(cpp_classification_result,
&c_classification_result);
EXPECT_EQ(c_classification_result.classifications, nullptr);
EXPECT_EQ(c_classification_result.classifications_count, 0);
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
}

} // namespace mediapipe::tasks::c::components::containers

0 comments on commit 6915a79

Please sign in to comment.