Skip to content

Commit

Permalink
Merge pull request #5076 from priankakariatyml:ios-interactive-segmen…
Browse files Browse the repository at this point in the history
…ter-result

PiperOrigin-RevId: 599273039
  • Loading branch information
copybara-github committed Jan 17, 2024
2 parents 6fc47c8 + 13e8cbf commit b62093b
Show file tree
Hide file tree
Showing 12 changed files with 308 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ NS_SWIFT_NAME(RegionOfInterest)
* @return An instance of `RegionOfInterest` initialized with the given normalized key points that
* make up scribbles over the object that the user wants to segment.
*/
- (instancetype)initWitScribbles:(NSArray<MPPNormalizedKeypoint *> *)scribbles
- (instancetype)initWithScribbles:(NSArray<MPPNormalizedKeypoint *> *)scribbles
NS_DESIGNATED_INITIALIZER;

- (instancetype)init NS_UNAVAILABLE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ - (instancetype)initWithNormalizedKeyPoint:(MPPNormalizedKeypoint *)normalizedKe
return self;
}

- (instancetype)initWitScribbles:(NSArray<MPPNormalizedKeypoint *> *)scribbles {
- (instancetype)initWithScribbles:(NSArray<MPPNormalizedKeypoint *> *)scribbles {
self = [super init];
if (self) {
_scribbles = scribbles;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ objc_library(
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
"//mediapipe/tasks/ios/test/vision/utils:MPPMaskTestUtils",
"//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenterResult",
"//mediapipe/tasks/ios/vision/interactive_segmenter:MPPInteractiveSegmenter",
"//mediapipe/tasks/ios/vision/interactive_segmenter:MPPInteractiveSegmenterResult",
] + select({
"//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPMask+TestUtils.h"
#import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenterResult.h"
#import "mediapipe/tasks/ios/vision/interactive_segmenter/sources/MPPInteractiveSegmenter.h"
#import "mediapipe/tasks/ios/vision/interactive_segmenter/sources/MPPInteractiveSegmenterResult.h"

#include <iostream>
#include <vector>
Expand Down Expand Up @@ -188,9 +188,9 @@ - (void)assertResultsOfSegmentImageWithFileInfo:(MPPFileInfo *)imageFileInfo
(MPPFileInfo *)expectedCategoryMaskFileInfo
withmaskSimilarityThreshold:(const float)maskSimilarityThreshold
shouldHaveConfidenceMasks:(BOOL)shouldHaveConfidenceMasks {
MPPImageSegmenterResult *result = [self segmentImageWithFileInfo:imageFileInfo
regionOfInterest:regionOfInterest
usingInteractiveSegmenter:interactiveSegmenter];
MPPInteractiveSegmenterResult *result = [self segmentImageWithFileInfo:imageFileInfo
regionOfInterest:regionOfInterest
usingInteractiveSegmenter:interactiveSegmenter];

XCTAssertNotNil(result.categoryMask);

Expand All @@ -217,9 +217,9 @@ - (void)assertResultsOfSegmentImageWithFileInfo:(MPPFileInfo *)imageFileInfo
atIndex:(NSInteger)index
withmaskSimilarityThreshold:(const float)maskSimilarityThreshold
shouldHaveCategoryMask:(BOOL)shouldHaveCategoryMask {
MPPImageSegmenterResult *result = [self segmentImageWithFileInfo:imageFileInfo
regionOfInterest:regionOfInterest
usingInteractiveSegmenter:interactiveSegmenter];
MPPInteractiveSegmenterResult *result = [self segmentImageWithFileInfo:imageFileInfo
regionOfInterest:regionOfInterest
usingInteractiveSegmenter:interactiveSegmenter];

[self assertInteractiveSegmenterResult:result
hasConfidenceMasksCount:expectedConfidenceMasksCount
Expand All @@ -229,7 +229,7 @@ - (void)assertResultsOfSegmentImageWithFileInfo:(MPPFileInfo *)imageFileInfo
shouldHaveCategoryMask:shouldHaveCategoryMask];
}

- (void)assertInteractiveSegmenterResult:(MPPImageSegmenterResult *)result
- (void)assertInteractiveSegmenterResult:(MPPInteractiveSegmenterResult *)result
hasConfidenceMasksCount:
(NSUInteger)expectedConfidenceMasksCount
approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:
Expand All @@ -254,18 +254,18 @@ - (void)assertInteractiveSegmenterResult:(MPPImageSegmenterResult *)result
withmaskSimilarityThreshold:maskSimilarityThreshold];
}

- (MPPImageSegmenterResult *)segmentImageWithFileInfo:(MPPFileInfo *)fileInfo
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
usingInteractiveSegmenter:
(MPPInteractiveSegmenter *)interactiveSegmenter {
- (MPPInteractiveSegmenterResult *)segmentImageWithFileInfo:(MPPFileInfo *)fileInfo
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
usingInteractiveSegmenter:
(MPPInteractiveSegmenter *)interactiveSegmenter {
MPPImage *image = [MPPImage imageWithFileInfo:fileInfo];
XCTAssertNotNil(image);

NSError *error;

MPPImageSegmenterResult *result = [interactiveSegmenter segmentImage:image
regionOfInterest:regionOfInterest
error:&error];
MPPInteractiveSegmenterResult *result = [interactiveSegmenter segmentImage:image
regionOfInterest:regionOfInterest
error:&error];

XCTAssertNil(error);
XCTAssertNotNil(result);
Expand Down
14 changes: 12 additions & 2 deletions mediapipe/tasks/ios/vision/interactive_segmenter/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ objc_library(
deps = ["//mediapipe/tasks/ios/core:MPPTaskOptions"],
)

objc_library(
name = "MPPInteractiveSegmenterResult",
srcs = ["sources/MPPInteractiveSegmenterResult.m"],
hdrs = ["sources/MPPInteractiveSegmenterResult.h"],
deps = [
"//mediapipe/tasks/ios/core:MPPTaskResult",
"//mediapipe/tasks/ios/vision/core:MPPMask",
],
)

objc_library(
name = "MPPInteractiveSegmenter",
srcs = ["sources/MPPInteractiveSegmenter.mm"],
Expand All @@ -35,6 +45,7 @@ objc_library(
module_name = "MPPInteractiveSegmenter",
deps = [
":MPPInteractiveSegmenterOptions",
":MPPInteractiveSegmenterResult",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph",
Expand All @@ -46,9 +57,8 @@ objc_library(
"//mediapipe/tasks/ios/vision/core:MPPImage",
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
"//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
"//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenterResult",
"//mediapipe/tasks/ios/vision/image_segmenter/utils:MPPImageSegmenterResultHelpers",
"//mediapipe/tasks/ios/vision/interactive_segmenter/utils:MPPInteractiveSegmenterOptionsHelpers",
"//mediapipe/tasks/ios/vision/interactive_segmenter/utils:MPPInteractiveSegmenterResultHelpers",
"//mediapipe/util:label_map_cc_proto",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

#import "mediapipe/tasks/ios/components/containers/sources/MPPRegionOfInterest.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
#import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenterResult.h"
#import "mediapipe/tasks/ios/vision/interactive_segmenter/sources/MPPInteractiveSegmenterOptions.h"
#import "mediapipe/tasks/ios/vision/interactive_segmenter/sources/MPPInteractiveSegmenterResult.h"

NS_ASSUME_NONNULL_BEGIN

Expand Down Expand Up @@ -93,11 +93,11 @@ NS_SWIFT_NAME(InteractiveSegmenter)
*
* @param image The `MPImage` on which segmentation is to be performed.
*
* @return An `ImageSegmenterResult` that contains the segmented masks.
* @return An `InteractiveSegmenterResult` that contains the segmented masks.
*/
- (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
error:(NSError **)error
- (nullable MPPInteractiveSegmenterResult *)segmentImage:(MPPImage *)image
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
error:(NSError **)error
NS_SWIFT_NAME(segment(image:regionOfInterest:));

/**
Expand All @@ -117,13 +117,13 @@ NS_SWIFT_NAME(InteractiveSegmenter)
*
* @param image The `MPImage` on which segmentation is to be performed.
* @param completionHandler A block to be invoked with the results of performing segmentation on the
* image. The block takes two arguments, the optional `ImageSegmenterResult` that contains the
* image. The block takes two arguments, the optional `InteractiveSegmenterResult` that contains the
* segmented masks if the segmentation was successful and an optional error populated upon failure.
* The lifetime of the returned masks is only guaranteed for the duration of the block.
*/
- (void)segmentImage:(MPPImage *)image
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
withCompletionHandler:(void (^)(MPPImageSegmenterResult *_Nullable result,
withCompletionHandler:(void (^)(MPPInteractiveSegmenterResult *_Nullable result,
NSError *_Nullable error))completionHandler
NS_SWIFT_NAME(segment(image:regionOfInterest:completion:));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
#import "mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.h"
#import "mediapipe/tasks/ios/vision/interactive_segmenter/utils/sources/MPPInteractiveSegmenterOptions+Helpers.h"
#import "mediapipe/tasks/ios/vision/interactive_segmenter/utils/sources/MPPInteractiveSegmenterResult+Helpers.h"

#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/util/label_map.pb.h"
Expand Down Expand Up @@ -136,9 +136,9 @@ - (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error
return [self initWithOptions:options error:error];
}

- (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
error:(NSError **)error {
- (nullable MPPInteractiveSegmenterResult *)segmentImage:(MPPImage *)image
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
error:(NSError **)error {
return [self segmentImage:image
regionOfInterest:regionOfInterest
shouldCopyOutputMaskPacketData:YES
Expand All @@ -147,13 +147,13 @@ - (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image

- (void)segmentImage:(MPPImage *)image
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
withCompletionHandler:(void (^)(MPPImageSegmenterResult *_Nullable result,
withCompletionHandler:(void (^)(MPPInteractiveSegmenterResult *_Nullable result,
NSError *_Nullable error))completionHandler {
NSError *error = nil;
MPPImageSegmenterResult *result = [self segmentImage:image
regionOfInterest:regionOfInterest
shouldCopyOutputMaskPacketData:NO
error:&error];
MPPInteractiveSegmenterResult *result = [self segmentImage:image
regionOfInterest:regionOfInterest
shouldCopyOutputMaskPacketData:NO
error:&error];
completionHandler(result, error);
}

Expand Down Expand Up @@ -197,10 +197,10 @@ - (void)segmentImage:(MPPImage *)image
return labels;
}

- (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
shouldCopyOutputMaskPacketData:(BOOL)shouldCopyMaskPacketData
error:(NSError **)error {
- (nullable MPPInteractiveSegmenterResult *)segmentImage:(MPPImage *)image
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
shouldCopyOutputMaskPacketData:(BOOL)shouldCopyMaskPacketData
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [_visionTaskRunner inputPacketMapWithMPPImage:image
regionOfInterest:CGRectZero
error:error];
Expand All @@ -222,32 +222,33 @@ - (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image
[_visionTaskRunner processPacketMap:inputPacketMap.value() error:error];

return [MPPInteractiveSegmenter
imageSegmenterResultWithOptionalOutputPacketMap:outputPacketMap
shouldCopyMaskPacketData:shouldCopyMaskPacketData];
interactiveSegmenterResultWithOptionalOutputPacketMap:outputPacketMap
shouldCopyMaskPacketData:shouldCopyMaskPacketData];
}

+ (nullable MPPImageSegmenterResult *)
imageSegmenterResultWithOptionalOutputPacketMap:(std::optional<PacketMap> &)outputPacketMap
shouldCopyMaskPacketData:(BOOL)shouldCopyMaskPacketData {
+ (nullable MPPInteractiveSegmenterResult *)
interactiveSegmenterResultWithOptionalOutputPacketMap:
(std::optional<PacketMap> &)outputPacketMap
shouldCopyMaskPacketData:(BOOL)shouldCopyMaskPacketData {
if (!outputPacketMap.has_value()) {
return nil;
}

PacketMap &outputPacketMapValue = outputPacketMap.value();

return [MPPImageSegmenterResult
imageSegmenterResultWithConfidenceMasksPacket:outputPacketMapValue[kConfidenceMasksStreamName
.cppString]
categoryMaskPacket:outputPacketMapValue[kCategoryMaskStreamName
.cppString]
qualityScoresPacket:outputPacketMapValue[kQualityScoresStreamName
.cppString]
timestampInMilliseconds:outputPacketMapValue[kImageOutStreamName
.cppString]
.Timestamp()
.Value() /
kMicrosecondsPerMillisecond
shouldCopyMaskPacketData:shouldCopyMaskPacketData];
return [MPPInteractiveSegmenterResult
interactiveSegmenterResultWithConfidenceMasksPacket:outputPacketMapValue
[kConfidenceMasksStreamName.cppString]
categoryMaskPacket:outputPacketMapValue
[kCategoryMaskStreamName.cppString]
qualityScoresPacket:outputPacketMapValue
[kQualityScoresStreamName.cppString]
timestampInMilliseconds:outputPacketMapValue[kImageOutStreamName
.cppString]
.Timestamp()
.Value() /
kMicrosecondsPerMillisecond
shouldCopyMaskPacketData:shouldCopyMaskPacketData];
}

@end
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"

NS_ASSUME_NONNULL_BEGIN

/** Represents the segmentation results generated by `ImageSegmenter`. */
NS_SWIFT_NAME(InteractiveSegmenterResult)
@interface MPPInteractiveSegmenterResult : MPPTaskResult

/**
* An optional array of `Mask` objects. Each `Mask` in the array holds a 32 bit float array of size
* `image width` * `image height` which represents the confidence mask for each category. Each
* element of the float array represents the confidence with which the model predicted that the
* corresponding pixel belongs to the category that the mask represents, usually in the range [0,1].
*/
@property(nonatomic, readonly, nullable) NSArray<MPPMask *> *confidenceMasks;

/**
* An optional `Mask` that holds a`UInt8` array of size `image width` * `image height`. Each element
* of this array represents the class to which the pixel in the original image was predicted to
* belong to.
*/
@property(nonatomic, readonly, nullable) MPPMask *categoryMask;

/**
* The quality scores of the result masks, in the range of [0, 1]. Defaults to `1` if the model
* doesn't output quality scores. Each element corresponds to the score of the category in the model
* outputs.
*/
@property(nonatomic, readonly, nullable) NSArray<NSNumber *> *qualityScores;

/**
* Initializes a new `ImageSegmenterResult` with the given array of confidence masks, category mask,
* quality scores and timestamp (in milliseconds).
*
* @param confidenceMasks An optional array of `Mask` objects. Each `Mask` in the array must
* be of type `float32`.
* @param categoryMask An optional `Mask` object of type `uInt8`.
* @param qualityScores The quality scores of the result masks of type NSArray<NSNumber *> *. Each
* `NSNumber` in the array holds a `float`.
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
*
* @return An instance of `ImageSegmenterResult` initialized with the given array of confidence
* masks, category mask, quality scores and timestamp (in milliseconds).
*/
- (instancetype)initWithConfidenceMasks:(nullable NSArray<MPPMask *> *)confidenceMasks
categoryMask:(nullable MPPMask *)categoryMask
qualityScores:(nullable NSArray<NSNumber *> *)qualityScores
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;

@end

NS_ASSUME_NONNULL_END
Loading

0 comments on commit b62093b

Please sign in to comment.