Skip to content

Commit

Permalink
Merge pull request #5006 from priankakariatyml:ios-interactive-segmen…
Browse files Browse the repository at this point in the history
…ter-containers

PiperOrigin-RevId: 597300604
  • Loading branch information
copybara-github committed Jan 10, 2024
2 parents dccbad0 + 4c03da5 commit dbcd3d0
Show file tree
Hide file tree
Showing 8 changed files with 779 additions and 10 deletions.
77 changes: 77 additions & 0 deletions mediapipe/tasks/ios/test/vision/interactive_segmenter/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
"//mediapipe/framework/tool:ios.bzl",
"MPP_TASK_MINIMUM_OS_VERSION",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner",
)
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
"apple",
]

# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
"noasan",
"nomsan",
"notsan",
]

objc_library(
name = "MPPInteractiveSegmenterObjcTestLibrary",
testonly = 1,
srcs = ["MPPInteractiveSegmenterTests.mm"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
"//mediapipe/tasks/testdata/vision:test_protos",
],
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
"//mediapipe/tasks/ios/test/vision/utils:MPPMaskTestUtils",
"//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenterResult",
"//mediapipe/tasks/ios/vision/interactive_segmenter:MPPInteractiveSegmenter",
] + select({
"//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//conditions:default": ["@ios_opencv//:OpencvFramework"],
}),
)

ios_unit_test(
name = "MPPInteractiveSegmenterObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPInteractiveSegmenterObjcTestLibrary",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <Foundation/Foundation.h>
#import <XCTest/XCTest.h>

#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPMask+TestUtils.h"
#import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenterResult.h"
#import "mediapipe/tasks/ios/vision/interactive_segmenter/sources/MPPInteractiveSegmenter.h"

#include <iostream>
#include <vector>

static MPPFileInfo *const kCatsAndDogsImageFileInfo =
[[MPPFileInfo alloc] initWithName:@"cats_and_dogs" type:@"jpg"];
static MPPFileInfo *const kCatsAndDogsMaskImage1FileInfo =
[[MPPFileInfo alloc] initWithName:@"cats_and_dogs_mask_dog1" type:@"png"];
static MPPFileInfo *const kCatsAndDogsMaskImage2FileInfo =
[[MPPFileInfo alloc] initWithName:@"cats_and_dogs_mask_dog2" type:@"png"];

static MPPFileInfo *const kDeepLabModelFileInfo =
[[MPPFileInfo alloc] initWithName:@"ptm_512_hdt_ptm_woid" type:@"tflite"];

static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";

constexpr NSInteger kMagnificationFactor = 255;

#define AssertEqualErrors(error, expectedError) \
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)

namespace {
double sum(const std::vector<float> &mask) {
double sum = 0.0;
for (const float &maskElement : mask) {
sum += maskElement;
}
return sum;
}

std::vector<float> multiply(const float *mask1, const float *mask2, size_t size) {
std::vector<float> multipliedMask;
multipliedMask.reserve(size);

for (int i = 0; i < size; i++) {
multipliedMask.push_back(mask1[i] * mask2[i]);
}

return multipliedMask;
}

double softIOU(const float *mask1, const float *mask2, size_t size) {
std::vector<float> interSectionVector = multiply(mask1, mask2, size);
double interSectionSum = sum(interSectionVector);

std::vector<float> m1m1Vector = multiply(mask1, mask1, size);
double m1m1 = sum(m1m1Vector);

std::vector<float> m2m2Vector = multiply(mask2, mask2, size);
double m2m2 = sum(m2m2Vector);

double unionSum = m1m1 + m2m2 - interSectionSum;

return unionSum > 0.0 ? interSectionSum / unionSum : 0.0;
}
} // namespace

@interface MPPInteractiveSegmenterTests : XCTestCase
@end

@implementation MPPInteractiveSegmenterTests

#pragma mark General Tests

- (void)setUp {
// When expected and actual mask sizes are not equal, iterating through mask data results in a
// segmentation fault. Setting this property to `NO`, prevents each test case from executing the
// remaining flow after a failure. Since expected and actual mask sizes are compared before
// iterating through them, this prevents any illegal memory access.
self.continueAfterFailure = NO;
}

+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
NSString *filePath =
[[NSBundle bundleForClass:[MPPInteractiveSegmenterTests class]] pathForResource:fileName
ofType:extension];
return filePath;
}

#pragma mark Image Mode Tests

- (void)testSegmentWithCategoryMaskSucceeds {
MPPInteractiveSegmenterOptions *options =
[self interactiveSegmenterOptionsWithModelFileInfo:kDeepLabModelFileInfo];
options.shouldOutputConfidenceMasks = NO;
options.shouldOutputCategoryMask = YES;

MPPInteractiveSegmenter *interactiveSegmenter =
[self createInteractiveSegmenterWithOptionsSucceeds:options];

MPPRegionOfInterest *regionOfInterest = [[MPPRegionOfInterest alloc]
initWithNormalizedKeyPoint:[[MPPNormalizedKeypoint alloc]
initWithLocation:CGPointMake(0.44, 0.7)
label:nil
score:0.0f]];
[self assertResultsOfSegmentImageWithFileInfo:kCatsAndDogsImageFileInfo
regionOfInterest:regionOfInterest
usingInteractiveSegmenter:interactiveSegmenter
approximatelyEqualsExpectedCategoryMaskImageWithFileInfo:kCatsAndDogsMaskImage1FileInfo
withmaskSimilarityThreshold:0.84f
shouldHaveConfidenceMasks:NO];
}

- (void)testSegmentWithConfidenceMaskSucceeds {
MPPInteractiveSegmenterOptions *options =
[self interactiveSegmenterOptionsWithModelFileInfo:kDeepLabModelFileInfo];

MPPInteractiveSegmenter *interactiveSegmenter =
[self createInteractiveSegmenterWithOptionsSucceeds:options];

MPPRegionOfInterest *regionOfInterest = [[MPPRegionOfInterest alloc]
initWithNormalizedKeyPoint:[[MPPNormalizedKeypoint alloc]
initWithLocation:CGPointMake(0.44, 0.7)
label:nil
score:0.0f]];

[self assertResultsOfSegmentImageWithFileInfo:kCatsAndDogsImageFileInfo
regionOfInterest:regionOfInterest
usingInteractiveSegmenter:interactiveSegmenter
hasConfidenceMasksCount:2
approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:kCatsAndDogsMaskImage1FileInfo
atIndex:1
withmaskSimilarityThreshold:0.84f
shouldHaveCategoryMask:NO];
}

#pragma mark - Image Segmenter Initializers

- (MPPInteractiveSegmenterOptions *)interactiveSegmenterOptionsWithModelFileInfo:
(MPPFileInfo *)fileInfo {
MPPInteractiveSegmenterOptions *options = [[MPPInteractiveSegmenterOptions alloc] init];
options.baseOptions.modelAssetPath = fileInfo.path;
return options;
}

- (MPPInteractiveSegmenter *)createInteractiveSegmenterWithOptionsSucceeds:
(MPPInteractiveSegmenterOptions *)options {
NSError *error;
MPPInteractiveSegmenter *interactiveSegmenter =
[[MPPInteractiveSegmenter alloc] initWithOptions:options error:&error];
XCTAssertNotNil(interactiveSegmenter);
XCTAssertNil(error);

return interactiveSegmenter;
}

- (void)assertCreateInteractiveSegmenterWithOptions:(MPPInteractiveSegmenterOptions *)options
failsWithExpectedError:(NSError *)expectedError {
NSError *error = nil;
MPPInteractiveSegmenter *interactiveSegmenter =
[[MPPInteractiveSegmenter alloc] initWithOptions:options error:&error];

XCTAssertNil(interactiveSegmenter);
AssertEqualErrors(error, expectedError);
}

#pragma mark Assert Segmenter Results
- (void)assertResultsOfSegmentImageWithFileInfo:(MPPFileInfo *)imageFileInfo
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
usingInteractiveSegmenter:
(MPPInteractiveSegmenter *)interactiveSegmenter
approximatelyEqualsExpectedCategoryMaskImageWithFileInfo:
(MPPFileInfo *)expectedCategoryMaskFileInfo
withmaskSimilarityThreshold:(const float)maskSimilarityThreshold
shouldHaveConfidenceMasks:(BOOL)shouldHaveConfidenceMasks {
MPPImageSegmenterResult *result = [self segmentImageWithFileInfo:imageFileInfo
regionOfInterest:regionOfInterest
usingInteractiveSegmenter:interactiveSegmenter];

XCTAssertNotNil(result.categoryMask);

if (shouldHaveConfidenceMasks) {
XCTAssertNotNil(result.confidenceMasks);
} else {
XCTAssertNil(result.confidenceMasks);
}

[self assertCategoryMask:result.categoryMask
approximatelyEqualsExpectedCategoryMaskImageWithFileInfo:expectedCategoryMaskFileInfo
withmaskSimilarityThreshold:maskSimilarityThreshold];
}

- (void)assertResultsOfSegmentImageWithFileInfo:(MPPFileInfo *)imageFileInfo
regionOfInterest:
(MPPRegionOfInterest *)regionOfInterest
usingInteractiveSegmenter:
(MPPInteractiveSegmenter *)interactiveSegmenter
hasConfidenceMasksCount:
(NSUInteger)expectedConfidenceMasksCount
approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:
(MPPFileInfo *)expectedConfidenceMaskFileInfo
atIndex:(NSInteger)index
withmaskSimilarityThreshold:(const float)maskSimilarityThreshold
shouldHaveCategoryMask:(BOOL)shouldHaveCategoryMask {
MPPImageSegmenterResult *result = [self segmentImageWithFileInfo:imageFileInfo
regionOfInterest:regionOfInterest
usingInteractiveSegmenter:interactiveSegmenter];

[self assertInteractiveSegmenterResult:result
hasConfidenceMasksCount:expectedConfidenceMasksCount
approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:expectedConfidenceMaskFileInfo
atIndex:index
withmaskSimilarityThreshold:maskSimilarityThreshold
shouldHaveCategoryMask:shouldHaveCategoryMask];
}

- (void)assertInteractiveSegmenterResult:(MPPImageSegmenterResult *)result
hasConfidenceMasksCount:
(NSUInteger)expectedConfidenceMasksCount
approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:
(MPPFileInfo *)expectedConfidenceMaskFileInfo
atIndex:(NSInteger)index
withmaskSimilarityThreshold:(const float)maskSimilarityThreshold
shouldHaveCategoryMask:(BOOL)shouldHaveCategoryMask {
XCTAssertNotNil(result.confidenceMasks);

XCTAssertEqual(result.confidenceMasks.count, expectedConfidenceMasksCount);

if (shouldHaveCategoryMask) {
XCTAssertNotNil(result.categoryMask);
} else {
XCTAssertNil(result.categoryMask);
}

XCTAssertLessThan(index, result.confidenceMasks.count);

[self assertConfidenceMask:result.confidenceMasks[index]
approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:expectedConfidenceMaskFileInfo
withmaskSimilarityThreshold:maskSimilarityThreshold];
}

- (MPPImageSegmenterResult *)segmentImageWithFileInfo:(MPPFileInfo *)fileInfo
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
usingInteractiveSegmenter:
(MPPInteractiveSegmenter *)interactiveSegmenter {
MPPImage *image = [MPPImage imageWithFileInfo:fileInfo];
XCTAssertNotNil(image);

NSError *error;

MPPImageSegmenterResult *result = [interactiveSegmenter segmentImage:image
regionOfInterest:regionOfInterest
error:&error];

XCTAssertNil(error);
XCTAssertNotNil(result);

return result;
}

- (void)assertCategoryMask:(MPPMask *)categoryMask
approximatelyEqualsExpectedCategoryMaskImageWithFileInfo:
(MPPFileInfo *)expectedCategoryMaskImageFileInfo
withmaskSimilarityThreshold:(const float)maskSimilarityThreshold {
MPPMask *expectedCategoryMask =
[[MPPMask alloc] initWithImageFileInfo:expectedCategoryMaskImageFileInfo];

XCTAssertEqual(categoryMask.width, expectedCategoryMask.width);
XCTAssertEqual(categoryMask.height, expectedCategoryMask.height);

size_t maskSize = categoryMask.width * categoryMask.height;

const UInt8 *categoryMaskPixelData = categoryMask.uint8Data;
const UInt8 *expectedCategoryMaskPixelData = expectedCategoryMask.uint8Data;

NSInteger consistentPixels = 0;

for (int i = 0; i < maskSize; i++) {
consistentPixels +=
categoryMaskPixelData[i] * kMagnificationFactor == expectedCategoryMaskPixelData[i] ? 1 : 0;
}

XCTAssertGreaterThan((float)consistentPixels / (float)maskSize, maskSimilarityThreshold);
}

- (void)assertConfidenceMask:(MPPMask *)confidenceMask
approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:
(MPPFileInfo *)expectedConfidenceMaskImageFileInfo
withmaskSimilarityThreshold:
(const float)maskSimilarityThreshold {
MPPMask *expectedConfidenceMask =
[[MPPMask alloc] initWithImageFileInfo:expectedConfidenceMaskImageFileInfo];

XCTAssertEqual(confidenceMask.width, expectedConfidenceMask.width);
XCTAssertEqual(confidenceMask.height, expectedConfidenceMask.height);

size_t maskSize = confidenceMask.width * confidenceMask.height;

XCTAssertGreaterThan(
softIOU(confidenceMask.float32Data, expectedConfidenceMask.float32Data, maskSize),
maskSimilarityThreshold);
}

@end
Loading

0 comments on commit dbcd3d0

Please sign in to comment.