diff --git a/mediapipe/tasks/ios/BUILD b/mediapipe/tasks/ios/BUILD index 009431eb1a..127e9a6569 100644 --- a/mediapipe/tasks/ios/BUILD +++ b/mediapipe/tasks/ios/BUILD @@ -83,6 +83,9 @@ strip_api_include_path_prefix( "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedder.h", "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderOptions.h", "//mediapipe/tasks/ios/text/text_embedder:sources/MPPTextEmbedderResult.h", + "//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetector.h", + "//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorOptions.h", + "//mediapipe/tasks/ios/text/language_detector:sources/MPPLanguageDetectorResult.h", "//mediapipe/tasks/ios/vision/core:sources/MPPRunningMode.h", "//mediapipe/tasks/ios/vision/core:sources/MPPImage.h", "//mediapipe/tasks/ios/vision/core:sources/MPPMask.h", @@ -147,8 +150,12 @@ apple_static_xcframework( ":MPPTextEmbedder.h", ":MPPTextEmbedderOptions.h", ":MPPTextEmbedderResult.h", + ":MPPLanguageDetector.h", + ":MPPLanguageDetectorOptions.h", + ":MPPLanguageDetectorResult.h", ], deps = [ + "//mediapipe/tasks/ios/text/language_detector:MPPLanguageDetector", "//mediapipe/tasks/ios/text/text_classifier:MPPTextClassifier", "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder", ], diff --git a/mediapipe/tasks/ios/test/text/language_detector/MPPLanguageDetectorTests.mm b/mediapipe/tasks/ios/test/text/language_detector/MPPLanguageDetectorTests.mm index 28d2ea5c03..1b738a124e 100644 --- a/mediapipe/tasks/ios/test/text/language_detector/MPPLanguageDetectorTests.mm +++ b/mediapipe/tasks/ios/test/text/language_detector/MPPLanguageDetectorTests.mm @@ -116,6 +116,82 @@ - (void)testClassifyWithL2CModelSucceeds { [self assertResultsOfDetectLanguageOfText:ruText usingLanguageDetector:languageDetector approximatelyEqualsExpectedLanguagePredictions:expectedRuLanguagePredictions]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f], + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f] + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; +} + +- (void)testClassifyWithMaxResultsSucceeds { + MPPLanguageDetectorOptions *options = + [self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo]; + options.maxResults = 1; + MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f], + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; +} + +- (void)testClassifyWithScoreThresholdSucceeds { + MPPLanguageDetectorOptions *options = + [self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo]; + options.scoreThreshold = 0.5f; + MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f], + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; +} + +- (void)testClassifyWithCategoryAllowListSucceeds { + MPPLanguageDetectorOptions *options = + [self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo]; + options.categoryAllowlist = @[ @"zh" ]; + + MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"zh" probability:0.505424f], + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; +} + +- (void)testClassifyWithCategoryDenyListSucceeds { + MPPLanguageDetectorOptions *options = + [self languageDetectorOptionsWithModelFileInfo:kLanguageDetectorModelFileInfo]; + options.categoryDenylist = @[ @"zh" ]; + + MPPLanguageDetector *languageDetector = [self createLanguageDetectorWithOptionsSucceeds:options]; + + NSString *zhText = @"分久必合合久必分"; + NSArray *expectedZhLanguagePredictions = @[ + [[MPPLanguagePrediction alloc] initWithLanguageCode:@"ja" probability:0.481617f], + ]; + + [self assertResultsOfDetectLanguageOfText:zhText + usingLanguageDetector:languageDetector + approximatelyEqualsExpectedLanguagePredictions:expectedZhLanguagePredictions]; } #pragma mark Assert Segmenter Results @@ -125,6 +201,8 @@ - (void)assertResultsOfDetectLanguageOfText:(NSString *)text (NSArray *)expectedLanguagePredictions { MPPLanguageDetectorResult *result = [languageDetector detectText:text error:nil]; XCTAssertNotNil(result); + + XCTAssertEqual(result.languagePredictions.count, expectedLanguagePredictions.count); XCTAssertEqualWithAccuracy(result.languagePredictions[0].probability, expectedLanguagePredictions[0].probability, 1e-3); XCTAssertEqualObjects(result.languagePredictions[0].languageCode, diff --git a/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.m b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.m index df36493ef0..9113a9a47b 100644 --- a/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.m +++ b/mediapipe/tasks/ios/text/language_detector/sources/MPPLanguageDetectorOptions.m @@ -20,7 +20,7 @@ - (instancetype)init { self = [super init]; if (self) { _maxResults = -1; - _scoreThreshold = 0; + _scoreThreshold = -1.0f; } return self; } diff --git a/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.mm b/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.mm index 9d75105b4f..0f1db0484c 100644 --- a/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.mm +++ b/mediapipe/tasks/ios/text/language_detector/utils/sources/MPPLanguageDetectorOptions+Helpers.mm @@ -42,7 +42,10 @@ - (void)copyToProto:(CalculatorOptionsProto *)optionsProto { } classifierOptionsProto->set_max_results((int)self.maxResults); - classifierOptionsProto->set_score_threshold(self.scoreThreshold); + + if (self.scoreThreshold >= 0) { + classifierOptionsProto->set_score_threshold(self.scoreThreshold); + } for (NSString *category in self.categoryAllowlist) { classifierOptionsProto->add_category_allowlist(category.cppString);