Skip to content

Commit

Permalink
Merge pull request #5038 from priankakariatyml:ios-language-detector-…
Browse files Browse the repository at this point in the history
…tests

PiperOrigin-RevId: 595832527
  • Loading branch information
copybara-github committed Jan 5, 2024
2 parents 234a550 + a0c482f commit 2d374a6
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 2 deletions.
7 changes: 7 additions & 0 deletions mediapipe/tasks/ios/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ strip_api_include_path_prefix(
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedder.h",
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderOptions.h",
"//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h",
"//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetector.h",
"//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorOptions.h",
"//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorResult.h",
"//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h",
"//mediapipe/tasks/ios/vision/core:sources/MPPImage.h",
"//mediapipe/tasks/ios/vision/core:sources/MPPMask.h",
Expand Down Expand Up @@ -147,8 +150,12 @@ apple_static_xcframework(
":MPPTextEmbedder.h",
":MPPTextEmbedderOptions.h",
":MPPTextEmbedderResult.h",
":MPPLanguageDetector.h",
":MPPLanguageDetectorOptions.h",
":MPPLanguageDetectorResult.h",
],
deps = [
"//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetector",
"//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier",
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,82 @@ - (void)testClassifyWithL2CModelSucceeds {
[self assertResultsOfDetectLanguageOfText:ruText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions];

NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f]
];

[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}

- (void)testClassifyWithMaxResultsSucceeds {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.maxResults = 1;
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];

NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
];

[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}

- (void)testClassifyWithScoreThresholdSucceeds {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.scoreThreshold = 0.5f;
MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];

NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
];

[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}

- (void)testClassifyWithCategoryAllowListSucceeds {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.categoryAllowlist = @[ @"zh" ];

MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];

NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f],
];

[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}

- (void)testClassifyWithCategoryDenyListSucceeds {
MPPLanguageDetectorOptions *options =
[self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo];
options.categoryDenylist = @[ @"zh" ];

MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options];

NSString *zhText = @"分久必合合久必分";
NSArray<MPPLanguagePrediction *> *expectedZhLanguagePredictions = @[
[[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f],
];

[self assertResultsOfDetectLanguageOfText:zhText
usingLanguageDetector:languageDetector
approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions];
}

#pragma mark Assert Segmenter Results
Expand All @@ -125,6 +201,8 @@ - (void)assertResultsOfDetectLanguageOfText:(NSString *)text
(NSArray<MPPLanguagePrediction *> *)expectedLanguagePredictions {
MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil];
XCTAssertNotNil(result);

XCTAssertEqual(result.languagePredictions.count, expectedLanguagePredictions.count);
XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability,
expectedLanguagePredictions[0].probability, 1e-3);
XCTAssertEqualObjects(result.languagePredictions[0].languageCode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ - (instancetype)init {
self = [super init];
if (self) {
_maxResults = -1;
_scoreThreshold = 0;
_scoreThreshold = -1.0f;
}
return self;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ - (void)copyToProto:(CalculatorOptionsProto *)optionsProto {
}

classifierOptionsProto->set_max_results((int)self.maxResults);
classifierOptionsProto->set_score_threshold(self.scoreThreshold);

if (self.scoreThreshold >= 0) {
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
}

for (NSString *category in self.categoryAllowlist) {
classifierOptionsProto->add_category_allowlist(category.cppString);
Expand Down

0 comments on commit 2d374a6

Please sign in to comment.