diff --git a/mediapipe/tasks/ios/BUILD b/mediapipe/tasks/ios/BUILD
index 806a52a729..360097ea73 100644
--- a/mediapipe/tasks/ios/BUILD
+++ b/mediapipe/tasks/ios/BUILD
@@ -71,6 +71,7 @@ TENSORFLOW_LITE_C_DEPS = [
 CALCULATORS_AND_GRAPHS = [
     "//mediapipe/calculators/core:flow_limiter_calculator",
     "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
+    "//mediapipe/tasks/cc/audio/audio_embedder:audio_embedder_graph",
     "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
     "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
     "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
@@ -116,6 +117,9 @@ strip_api_include_path_prefix(
         "//mediapipe/tasks/ios/audio/audio_classifier:sources/MPPAudioClassifier.h",
         "//mediapipe/tasks/ios/audio/audio_classifier:sources/MPPAudioClassifierOptions.h",
         "//mediapipe/tasks/ios/audio/audio_classifier:sources/MPPAudioClassifierResult.h",
+        "//mediapipe/tasks/ios/audio/audio_embedder:sources/MPPAudioEmbedder.h",
+        "//mediapipe/tasks/ios/audio/audio_embedder:sources/MPPAudioEmbedderOptions.h",
+        "//mediapipe/tasks/ios/audio/audio_embedder:sources/MPPAudioEmbedderResult.h",
         "//mediapipe/tasks/ios/audio/core:sources/MPPAudioData.h",
         "//mediapipe/tasks/ios/audio/core:sources/MPPAudioDataFormat.h",
         "//mediapipe/tasks/ios/audio/core:sources/MPPAudioRecord.h",
@@ -194,6 +198,9 @@ apple_static_xcframework(
         ":MPPAudioClassifier.h",
         ":MPPAudioClassifierOptions.h",
         ":MPPAudioClassifierResult.h",
+        ":MPPAudioEmbedder.h",
+        ":MPPAudioEmbedderOptions.h",
+        ":MPPAudioEmbedderResult.h",
         ":MPPAudioData.h",
         ":MPPAudioDataFormat.h",
         ":MPPAudioRecord.h",
@@ -202,12 +209,15 @@ apple_static_xcframework(
         ":MPPCategory.h",
         ":MPPClassificationResult.h",
         ":MPPCommon.h",
+        ":MPPEmbedding.h",
+        ":MPPEmbeddingResult.h",
         ":MPPFloatBuffer.h",
         ":MPPTaskOptions.h",
         ":MPPTaskResult.h",
     ],
     deps = [
         "//mediapipe/tasks/ios/audio/audio_classifier:MPPAudioClassifier",
+        "//mediapipe/tasks/ios/audio/audio_embedder:MPPAudioEmbedder",
     ],
 )
diff --git a/mediapipe/tasks/ios/audio/audio_embedder/sources/MPPAudioEmbedder.h b/mediapipe/tasks/ios/audio/audio_embedder/sources/MPPAudioEmbedder.h
index 5ba7aa8d67..ec1cacc317 100644
--- a/mediapipe/tasks/ios/audio/audio_embedder/sources/MPPAudioEmbedder.h
+++ b/mediapipe/tasks/ios/audio/audio_embedder/sources/MPPAudioEmbedder.h
@@ -124,7 +124,7 @@ NS_SWIFT_NAME(AudioEmbedder)
 - (BOOL)embedAsyncAudioBlock:(MPPAudioData *)audioBlock
      timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                        error:(NSError **)error
-    NS_SWIFT_NAME(classifyAsync(audioBlock:timestampInMilliseconds:));
+    NS_SWIFT_NAME(embedAsync(audioBlock:timestampInMilliseconds:));
 
 /**
  * Closes and cleans up the MediaPipe audio embedder.
diff --git a/mediapipe/tasks/ios/audio/audio_embedder/sources/MPPAudioEmbedderResult.h b/mediapipe/tasks/ios/audio/audio_embedder/sources/MPPAudioEmbedderResult.h
index 497705225a..6fa433ebf4 100644
--- a/mediapipe/tasks/ios/audio/audio_embedder/sources/MPPAudioEmbedderResult.h
+++ b/mediapipe/tasks/ios/audio/audio_embedder/sources/MPPAudioEmbedderResult.h
@@ -43,7 +43,7 @@ NS_SWIFT_NAME(AudioEmbedderResult)
  * @return An instance of `AudioEmbedderResult` initialized with the given array of
  * `EmbeddingResult` objects and timestamp (in milliseconds).
  */
-- (instancetype)initWithEmbeddingResults:(nullable NSArray<MPPEmbeddingResult *> *)embeddingResults
+- (instancetype)initWithEmbeddingResults:(NSArray<MPPEmbeddingResult *> *)embeddingResults
                  timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
 
 - (instancetype)init NS_UNAVAILABLE;
diff --git a/mediapipe/tasks/ios/audio/audio_embedder/utils/sources/MPPAudioEmbedderResult+Helpers.mm b/mediapipe/tasks/ios/audio/audio_embedder/utils/sources/MPPAudioEmbedderResult+Helpers.mm
index 7cf565ec02..92d0f7f9f7 100644
--- a/mediapipe/tasks/ios/audio/audio_embedder/utils/sources/MPPAudioEmbedderResult+Helpers.mm
+++ b/mediapipe/tasks/ios/audio/audio_embedder/utils/sources/MPPAudioEmbedderResult+Helpers.mm
@@ -42,7 +42,7 @@ + (MPPAudioEmbedderResult *)audioEmbedderResultWithEmbeddingResultPacket:(const
     cppEmbeddingResults = packet.Get<std::vector<EmbeddingResultProto>>();
   } else {
     // If packet does not contain protobuf of a type expected by the audio embedder.
-    return [[MPPAudioEmbedderResult alloc] initWithEmbeddingResults:nil
+    return [[MPPAudioEmbedderResult alloc] initWithEmbeddingResults:@[]
                                             timestampInMilliseconds:timestampInMilliseconds];
   }
 
diff --git a/mediapipe/tasks/ios/test/audio/audio_embedder/MPPAudioEmbedderTests.mm b/mediapipe/tasks/ios/test/audio/audio_embedder/MPPAudioEmbedderTests.mm
index d5e5612832..d2bba80ab6 100644
--- a/mediapipe/tasks/ios/test/audio/audio_embedder/MPPAudioEmbedderTests.mm
+++ b/mediapipe/tasks/ios/test/audio/audio_embedder/MPPAudioEmbedderTests.mm
@@ -22,6 +22,8 @@
 #import "mediapipe/tasks/ios/test/audio/core/utils/sources/MPPAudioData+TestUtils.h"
 #import "mediapipe/tasks/ios/test/utils/sources/MPPFileInfo.h"
 
+#include <array>
+
 static MPPFileInfo *const kYamnetModelFileInfo =
     [[MPPFileInfo alloc] initWithName:@"yamnet_embedding_metadata" type:@"tflite"];
 static MPPFileInfo *const kSpeech16KHzMonoFileInfo =
@@ -62,7 +64,11 @@
   XCTAssertEqual(embedding.floatEmbedding.count, expectedLength); \
 }
 
-@interface MPPAudioEmbedderTests : XCTestCase <MPPAudioEmbedderStreamDelegate>
+@interface MPPAudioEmbedderTests : XCTestCase <MPPAudioEmbedderStreamDelegate> {
+  NSDictionary *_16kHZAudioStreamSucceedsTestDict;
+  NSDictionary *_48kHZAudioStreamSucceedsTestDict;
+  NSDictionary *_outOfOrderTimestampTestDict;
+}
 @end
 
 @implementation MPPAudioEmbedderTests
@@ -185,6 +191,14 @@ - (void)testEmbedWithSilenceSucceeds {
                      isQuantized:options.quantize
    expectedEmbeddingResultsCount:expectedEmbedderResultsCount
          expectedEmbeddingLength:kExpectedEmbeddingLength];
+  const std::array<float, 3> expectedEmbeddingValuesSubset = {2.07613f, 0.392721f, 0.543622f};
+  const float valueDifferenceTolerance = 4e-6f;
+
+  NSArray<NSNumber *> *floatEmbedding = result.embeddingResults[0].embeddings[0].floatEmbedding;
+  for (int i = 0; i < expectedEmbeddingValuesSubset.size(); i++) {
+    XCTAssertEqualWithAccuracy(expectedEmbeddingValuesSubset[i], floatEmbedding[i].floatValue,
+                               valueDifferenceTolerance);
+  }
 }
 
 - (void)testEmbedAfterCloseFailsInAudioClipsMode {
@@ -337,17 +351,120 @@ - (void)testCreateAudioRecordWithInvalidChannelCountFails {
   AssertEqualErrors(error, expectedError);
 }
 
+- (void)testEmbedWithAudioStreamModeAndOutOfOrderTimestampsFails {
+  MPPAudioEmbedder *audioEmbedder =
+      [self audioEmbedderInStreamModeWithModelFileInfo:kYamnetModelFileInfo];
+  NSArray<MPPTimestampedAudioData *> *streamedAudioDataList =
+      [MPPAudioEmbedderTests streamedAudioDataListforYamnet];
+
+  XCTestExpectation *expectation =
+      [[XCTestExpectation alloc] initWithDescription:@"embedWithOutOfOrderTimestampsAndLiveStream"];
+  expectation.expectedFulfillmentCount = 1;
+
+  _outOfOrderTimestampTestDict = @{
+    kAudioStreamTestsDictEmbedderKey : audioEmbedder,
+    kAudioStreamTestsDictExpectationKey : expectation
+  };
+
+  // Can safely access indices 1 and 0 since the `streamedAudioDataList` count is already asserted.
+  XCTAssertTrue([audioEmbedder embedAsyncAudioBlock:streamedAudioDataList[1].audioData
+                            timestampInMilliseconds:streamedAudioDataList[1].timestampInMilliseconds
+                                              error:nil]);
+
+  NSError *error;
+  XCTAssertFalse([audioEmbedder
+           embedAsyncAudioBlock:streamedAudioDataList[0].audioData
+        timestampInMilliseconds:streamedAudioDataList[0].timestampInMilliseconds
+                          error:&error]);
+
+  NSError *expectedError =
+      [NSError errorWithDomain:kExpectedErrorDomain
+                          code:MPPTasksErrorCodeInvalidArgumentError
+                      userInfo:@{
+                        NSLocalizedDescriptionKey :
+                            @"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
+                      }];
+  AssertEqualErrors(error, expectedError);
+
+  [audioEmbedder closeWithError:nil];
+}
+
+- (void)testEmbedWithAudioStreamModeSucceeds {
+  [self embedUsingYamnetAsyncAudioFileWithInfo:kSpeech16KHzMonoFileInfo
+                                          info:&_16kHZAudioStreamSucceedsTestDict];
+  [self embedUsingYamnetAsyncAudioFileWithInfo:kSpeech48KHzMonoFileInfo
+                                          info:&_48kHZAudioStreamSucceedsTestDict];
+}
+
 #pragma mark MPPAudioEmbedderStreamDelegate
 
 - (void)audioEmbedder:(MPPAudioEmbedder *)audioEmbedder
     didFinishEmbeddingWithResult:(MPPAudioEmbedderResult *)result
         timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                           error:(NSError *)error {
-  // TODO: Add assertion for the result when stream mode inference tests are added.
+  // Yamnet results can safely be asserted before identifying the `audioEmbedder` object, since
+  // only the Yamnet model with 16kHz and 48kHz speech files is used for the async tests.
+  [MPPAudioEmbedderTests
+      assertNonQuantizedAudioEmbedderYamnetStreamModeResult:result
+                                    timestampInMilliseconds:timestampInMilliseconds
+                                     expectedEmbeddingLength:kExpectedEmbeddingLength];
+
+  if (audioEmbedder == _outOfOrderTimestampTestDict[kAudioStreamTestsDictEmbedderKey]) {
+    [_outOfOrderTimestampTestDict[kAudioStreamTestsDictExpectationKey] fulfill];
+  } else if (audioEmbedder == _16kHZAudioStreamSucceedsTestDict[kAudioStreamTestsDictEmbedderKey]) {
+    [_16kHZAudioStreamSucceedsTestDict[kAudioStreamTestsDictExpectationKey] fulfill];
+  } else if (audioEmbedder == _48kHZAudioStreamSucceedsTestDict[kAudioStreamTestsDictEmbedderKey]) {
+    [_48kHZAudioStreamSucceedsTestDict[kAudioStreamTestsDictExpectationKey] fulfill];
+  }
+}
+
+#pragma mark Audio Stream Mode Test Helpers
+
+// `info` is strong here since the addresses of instance variables will be passed to this function.
+// By default, `NSDictionary **` will be `NSDictionary * __autoreleasing *`.
+- (void)embedUsingYamnetAsyncAudioFileWithInfo:(MPPFileInfo *)audioFileInfo
+                                           info:(NSDictionary *__strong *)info {
+  MPPAudioEmbedder *audioEmbedder =
+      [self audioEmbedderInStreamModeWithModelFileInfo:kYamnetModelFileInfo];
+
+  NSArray<MPPTimestampedAudioData *> *streamedAudioDataList =
+      [MPPAudioEmbedderTests streamedAudioDataListforYamnet];
+
+  XCTestExpectation *expectation = [[XCTestExpectation alloc]
+      initWithDescription:[NSString
+                              stringWithFormat:@"embedWithStreamMode_%@", audioFileInfo.name]];
+  expectation.expectedFulfillmentCount = streamedAudioDataList.count;
+
+  *info = @{
+    kAudioStreamTestsDictEmbedderKey : audioEmbedder,
+    kAudioStreamTestsDictExpectationKey : expectation
+  };
+
+  for (MPPTimestampedAudioData *timestampedAudioData in streamedAudioDataList) {
+    XCTAssertTrue([audioEmbedder embedAsyncAudioBlock:timestampedAudioData.audioData
+                              timestampInMilliseconds:timestampedAudioData.timestampInMilliseconds
+                                                error:nil]);
+  }
+
+  [audioEmbedder closeWithError:nil];
+
+  NSTimeInterval timeout = 1.0f;
+  [self waitForExpectations:@[ expectation ] timeout:timeout];
 }
 
 #pragma mark Audio Embedder Initializers
 
+- (MPPAudioEmbedder *)audioEmbedderInStreamModeWithModelFileInfo:(MPPFileInfo *)fileInfo {
+  MPPAudioEmbedderOptions *options =
+      [MPPAudioEmbedderTests audioEmbedderOptionsWithModelFileInfo:fileInfo];
+  options.runningMode = MPPAudioRunningModeAudioStream;
+  options.audioEmbedderStreamDelegate = self;
+
+  MPPAudioEmbedder *audioEmbedder = [MPPAudioEmbedderTests audioEmbedderWithOptions:options];
+
+  return audioEmbedder;
+}
+
 + (MPPAudioEmbedderOptions *)audioEmbedderOptionsWithModelFileInfo:(MPPFileInfo *)modelFileInfo {
   MPPAudioEmbedderOptions *options = [[MPPAudioEmbedderOptions alloc] init];
   options.baseOptions.modelAssetPath = modelFileInfo.path;
@@ -421,6 +538,21 @@ + (void)assertAudioEmbedderResult:(MPPAudioEmbedderResult *)result
   }
 }
 
++ (void)assertNonQuantizedAudioEmbedderYamnetStreamModeResult:(MPPAudioEmbedderResult *)result
+                                      timestampInMilliseconds:(NSInteger)timestampInMilliseconds
+                                       expectedEmbeddingLength:(NSInteger)expectedEmbeddingLength {
+  // In stream mode, `result` will only have one embedding result corresponding to the timestamp at
+  // which it was sent for inference.
+  XCTAssertEqual(result.embeddingResults.count, 1);
+
+  MPPEmbeddingResult *embeddingResult = result.embeddingResults[0];
+  AssertEmbeddingResultHasOneEmbedding(embeddingResult);
+  AssertEmbeddingHasCorrectTypeAndDimension(embeddingResult.embeddings[0], NO,
+                                            expectedEmbeddingLength);
+
+  XCTAssertEqual(result.timestampInMilliseconds, timestampInMilliseconds);
+}
+
 + (MPPAudioEmbedderResult *)embedAudioClipWithFileInfo:(MPPFileInfo *)fileInfo
                                     usingAudioEmbedder:(MPPAudioEmbedder *)audioEmbedder {
   MPPAudioData *audioData = [[MPPAudioData alloc] initWithFileInfo:fileInfo];
@@ -430,4 +562,15 @@ + (MPPAudioEmbedderResult *)embedAudioClipWithFileInfo:(MPPFileInfo *)fileInfo
   return result;
 }
 
++ (NSArray<MPPTimestampedAudioData *> *)streamedAudioDataListforYamnet {
+  NSArray<MPPTimestampedAudioData *> *streamedAudioDataList =
+      [AVAudioFile streamedAudioBlocksFromAudioFileWithInfo:kSpeech16KHzMonoFileInfo
+                                           modelSampleCount:kYamnetSampleCount
+                                            modelSampleRate:kYamnetSampleRate];
+
+  XCTAssertEqual(streamedAudioDataList.count, 5);
+
+  return streamedAudioDataList;
+}
+
 @end
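
Note (not part of the patch): a minimal stream-mode usage sketch for the new MPPAudioEmbedder API, based on the selectors exercised by the tests above. The `initWithOptions:error:` initializer, the delegate class, the helper function, the model path, and the timestamp increment are illustrative assumptions, not confirmed by this diff.

// Hypothetical sketch; names marked "assumed" are not taken from the patch.
@interface MyEmbedderDelegate : NSObject <MPPAudioEmbedderStreamDelegate>
@end

@implementation MyEmbedderDelegate
- (void)audioEmbedder:(MPPAudioEmbedder *)audioEmbedder
    didFinishEmbeddingWithResult:(MPPAudioEmbedderResult *)result
         timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                           error:(NSError *)error {
  // In stream mode each callback carries one embedding result for the block's timestamp.
  if (!error && result.embeddingResults.count > 0) {
    NSLog(@"Embedding length at %ld ms: %lu", (long)timestampInMilliseconds,
          (unsigned long)result.embeddingResults[0].embeddings[0].floatEmbedding.count);
  }
}
@end

// Assumed helper: sends pre-cut audio blocks to the embedder in stream mode.
static void EmbedStream(NSArray<MPPAudioData *> *audioBlocks, MyEmbedderDelegate *delegate) {
  MPPAudioEmbedderOptions *options = [[MPPAudioEmbedderOptions alloc] init];
  options.baseOptions.modelAssetPath = @"yamnet_embedding_metadata.tflite";  // assumed path
  options.runningMode = MPPAudioRunningModeAudioStream;
  options.audioEmbedderStreamDelegate = delegate;

  NSError *error = nil;
  // `initWithOptions:error:` is assumed by analogy with other MediaPipe iOS tasks.
  MPPAudioEmbedder *embedder = [[MPPAudioEmbedder alloc] initWithOptions:options error:&error];

  // Timestamps must be monotonically increasing, as the out-of-order test above verifies.
  NSInteger timestampInMilliseconds = 0;
  for (MPPAudioData *audioBlock in audioBlocks) {
    [embedder embedAsyncAudioBlock:audioBlock
           timestampInMilliseconds:timestampInMilliseconds
                             error:nil];
    timestampInMilliseconds += 1000;  // illustrative fixed hop between blocks
  }
  [embedder closeWithError:nil];
}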