Skip to content

Commit

Permalink
Add unit tests for C layer for the input types of Text Classifier
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569553038
  • Loading branch information
schmidt-sebastian authored and copybara-github committed Sep 29, 2023
1 parent 6915a79 commit 96fa10b
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 11 deletions.
12 changes: 12 additions & 0 deletions mediapipe/tasks/c/components/processors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,15 @@ cc_library(
"//mediapipe/tasks/cc/components/processors:classifier_options",
],
)

cc_test(
name = "classifier_options_converter_test",
srcs = ["classifier_options_converter_test.cc"],
deps = [
":classifier_options",
":classifier_options_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"@com_google_googletest//:gtest_main",
],
)
6 changes: 3 additions & 3 deletions mediapipe/tasks/c/components/processors/classifier_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ extern "C" {
struct ClassifierOptions {
// The locale to use for display names specified through the TFLite Model
// Metadata, if any. Defaults to English.
char* display_names_locale;
const char* display_names_locale;

// The maximum number of top-scored classification results to return. If < 0,
// all available results will be returned. If 0, an invalid argument error is
Expand All @@ -40,14 +40,14 @@ struct ClassifierOptions {
// The allowlist of category names. If non-empty, detection results whose
// category name is not in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_denylist.
char** category_allowlist;
const char** category_allowlist;
// The number of elements in the category allowlist.
uint32_t category_allowlist_count;

// The denylist of category names. If non-empty, detection results whose
// category name is in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_allowlist.
char** category_denylist;
const char** category_denylist;
// The number of elements in the category denylist.
uint32_t category_denylist_count;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void CppConvertToClassifierOptions(
const ClassifierOptions& in,
mediapipe::tasks::components::processors::ClassifierOptions* out) {
out->display_names_locale =
in.display_names_locale ? std::string(in.display_names_locale) : "";
in.display_names_locale ? std::string(in.display_names_locale) : "en";
out->max_results = in.max_results;
out->score_threshold = in.score_threshold;
out->category_allowlist =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"

#include <string>
#include <vector>

#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"

namespace mediapipe::tasks::c::components::processors {

constexpr char kCategoryAllowlist[] = "fruit";
constexpr char kCategoryDenylist[] = "veggies";
constexpr char kDisplayNamesLocaleGerman[] = "de";

TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsCustomValues) {
std::vector<const char*> category_allowlist = {kCategoryAllowlist};
std::vector<const char*> category_denylist = {kCategoryDenylist};

ClassifierOptions c_classifier_options = {
/* display_names_locale= */ kDisplayNamesLocaleGerman,
/* max_results= */ 1,
/* score_threshold= */ 0.1,
/* category_allowlist= */ category_allowlist.data(),
/* category_allowlist_count= */ 1,
/* category_denylist= */ category_denylist.data(),
/* category_denylist_count= */ 1};

mediapipe::tasks::components::processors::ClassifierOptions
cpp_classifier_options = {};

CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
EXPECT_EQ(cpp_classifier_options.display_names_locale, "de");
EXPECT_EQ(cpp_classifier_options.max_results, 1);
EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.1);
EXPECT_EQ(cpp_classifier_options.category_allowlist,
std::vector<std::string>{"fruit"});
EXPECT_EQ(cpp_classifier_options.category_denylist,
std::vector<std::string>{"veggies"});
}

TEST(ClassifierOptionsConverterTest, ConvertsClassifierOptionsDefaultValues) {
std::vector<const char*> category_allowlist = {kCategoryAllowlist};
std::vector<const char*> category_denylist = {kCategoryDenylist};

ClassifierOptions c_classifier_options = {/* display_names_locale= */ nullptr,
/* max_results= */ -1,
/* score_threshold= */ 0.0,
/* category_allowlist= */ nullptr,
/* category_allowlist_count= */ 0,
/* category_denylist= */ nullptr,
/* category_denylist_count= */ 0};

mediapipe::tasks::components::processors::ClassifierOptions
cpp_classifier_options = {};

CppConvertToClassifierOptions(c_classifier_options, &cpp_classifier_options);
EXPECT_EQ(cpp_classifier_options.display_names_locale, "en");
EXPECT_EQ(cpp_classifier_options.max_results, -1);
EXPECT_FLOAT_EQ(cpp_classifier_options.score_threshold, 0.0);
EXPECT_EQ(cpp_classifier_options.category_allowlist,
std::vector<std::string>{});
EXPECT_EQ(cpp_classifier_options.category_denylist,
std::vector<std::string>{});
}

} // namespace mediapipe::tasks::c::components::processors
12 changes: 12 additions & 0 deletions mediapipe/tasks/c/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,15 @@ cc_library(
"//mediapipe/tasks/cc/core:base_options",
],
)

cc_test(
name = "base_options_converter_test",
srcs = ["base_options_converter_test.cc"],
deps = [
":base_options",
":base_options_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/core:base_options",
"@com_google_googletest//:gtest_main",
],
)
4 changes: 2 additions & 2 deletions mediapipe/tasks/c/core/base_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ extern "C" {
// Base options for MediaPipe C Tasks.
struct BaseOptions {
// The model asset file contents as a string.
char* model_asset_buffer;
const char* model_asset_buffer;

// The path to the model asset to open and mmap in memory.
char* model_asset_path;
const char* model_asset_path;
};

#ifdef __cplusplus
Expand Down
4 changes: 2 additions & 2 deletions mediapipe/tasks/c/core/base_options_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ limitations under the License.
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"

namespace mediapipe::tasks::c::components::containers {
namespace mediapipe::tasks::c::core {

void CppConvertToBaseOptions(const BaseOptions& in,
mediapipe::tasks::core::BaseOptions* out) {
Expand All @@ -33,4 +33,4 @@ void CppConvertToBaseOptions(const BaseOptions& in,
in.model_asset_path ? std::string(in.model_asset_path) : "";
}

} // namespace mediapipe::tasks::c::components::containers
} // namespace mediapipe::tasks::c::core
4 changes: 2 additions & 2 deletions mediapipe/tasks/c/core/base_options_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ limitations under the License.
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"

namespace mediapipe::tasks::c::components::containers {
namespace mediapipe::tasks::c::core {

void CppConvertToBaseOptions(const BaseOptions& in,
mediapipe::tasks::core::BaseOptions* out);

} // namespace mediapipe::tasks::c::components::containers
} // namespace mediapipe::tasks::c::core

#endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
51 changes: 51 additions & 0 deletions mediapipe/tasks/c/core/base_options_converter_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/c/core/base_options_converter.h"

#include <string>

#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"

namespace mediapipe::tasks::c::core {

constexpr char kAssetBuffer[] = "abc";
constexpr char kModelAssetPath[] = "abc.tflite";

TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer,
/* model_asset_path= */ nullptr};

mediapipe::tasks::core::BaseOptions cpp_base_options = {};

CppConvertToBaseOptions(c_base_options, &cpp_base_options);
EXPECT_EQ(*cpp_base_options.model_asset_buffer, std::string{kAssetBuffer});
EXPECT_EQ(cpp_base_options.model_asset_path, "");
}

TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) {
BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr,
/* model_asset_path= */ kModelAssetPath};

mediapipe::tasks::core::BaseOptions cpp_base_options = {};

CppConvertToBaseOptions(c_base_options, &cpp_base_options);
EXPECT_EQ(cpp_base_options.model_asset_buffer.get(), nullptr);
EXPECT_EQ(cpp_base_options.model_asset_path, std::string{kModelAssetPath});
}

} // namespace mediapipe::tasks::c::core
2 changes: 1 addition & 1 deletion mediapipe/tasks/c/text/text_classifier/text_classifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ namespace mediapipe::tasks::c::text::text_classifier {

namespace {

using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions;
using ::mediapipe::tasks::c::components::containers::
CppConvertToClassificationResult;
using ::mediapipe::tasks::c::components::processors::
CppConvertToClassifierOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::text::text_classifier::TextClassifier;
} // namespace

Expand Down

0 comments on commit 96fa10b

Please sign in to comment.