
Merge pull request #5678 from priankakariatyml:ios-audio-embedder-updated-tests

PiperOrigin-RevId: 686969268
copybara-github committed Oct 17, 2024
2 parents ea80187 + 4e6a57f commit b5b99fe
Showing 5 changed files with 158 additions and 5 deletions.
10 changes: 10 additions & 0 deletions mediapipe/tasks/ios/BUILD
@@ -71,6 +71,7 @@ TENSORFLOW_LITE_C_DEPS = [
CALCULATORS_AND_GRAPHS = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
"//mediapipe/tasks/cc/audio/audio_embedder:audio_embedder_graph",
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
"//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
@@ -116,6 +117,9 @@ strip_api_include_path_prefix(
"//mediapipe/tasks/ios/audio/audio_classifier:sources/MPPAudioClassifier.h",
"//mediapipe/tasks/ios/audio/audio_classifier:sources/MPPAudioClassifierOptions.h",
"//mediapipe/tasks/ios/audio/audio_classifier:sources/MPPAudioClassifierResult.h",
"//mediapipe/tasks/ios/audio/audio_embedder:sources/MPPAudioEmbedder.h",
"//mediapipe/tasks/ios/audio/audio_embedder:sources/MPPAudioEmbedderOptions.h",
"//mediapipe/tasks/ios/audio/audio_embedder:sources/MPPAudioEmbedderResult.h",
"//mediapipe/tasks/ios/audio/core:sources/MPPAudioData.h",
"//mediapipe/tasks/ios/audio/core:sources/MPPAudioDataFormat.h",
"//mediapipe/tasks/ios/audio/core:sources/MPPAudioRecord.h",
@@ -194,6 +198,9 @@ apple_static_xcframework(
":MPPAudioClassifier.h",
":MPPAudioClassifierOptions.h",
":MPPAudioClassifierResult.h",
":MPPAudioEmbedder.h",
":MPPAudioEmbedderOptions.h",
":MPPAudioEmbedderResult.h",
":MPPAudioData.h",
":MPPAudioDataFormat.h",
":MPPAudioRecord.h",
@@ -202,12 +209,15 @@
":MPPCategory.h",
":MPPClassificationResult.h",
":MPPCommon.h",
":MPPEmbedding.h",
":MPPEmbeddingResult.h",
":MPPFloatBuffer.h",
":MPPTaskOptions.h",
":MPPTaskResult.h",
],
deps = [
"//mediapipe/tasks/ios/audio/audio_classifier:MPPAudioClassifier",
"//mediapipe/tasks/ios/audio/audio_embedder:MPPAudioEmbedder",
],
)
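For orientation, a consumption sketch (editorial, not part of this diff): once MPPAudioEmbedder is exported through the xcframework above, an app imports it via the stripped header paths. The umbrella import MediaPipeTasksAudio and the `initWithOptions:error:` initializer are assumptions modeled on MediaPipe's other iOS task bundles; `modelPath` is a placeholder.

#import <MediaPipeTasksAudio/MediaPipeTasksAudio.h>

// Configure an embedder in audio-clips mode from a bundled model.
MPPAudioEmbedderOptions *options = [[MPPAudioEmbedderOptions alloc] init];
options.baseOptions.modelAssetPath = modelPath;
options.runningMode = MPPAudioRunningModeAudioClips;

NSError *error = nil;
MPPAudioEmbedder *embedder = [[MPPAudioEmbedder alloc] initWithOptions:options
                                                                 error:&error];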

@@ -124,7 +124,7 @@ NS_SWIFT_NAME(AudioEmbedder)
- (BOOL)embedAsyncAudioBlock:(MPPAudioData *)audioBlock
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error
-                  NS_SWIFT_NAME(classifyAsync(audioBlock:timestampInMilliseconds:));
+                  NS_SWIFT_NAME(embedAsync(audioBlock:timestampInMilliseconds:));

/**
* Closes and cleans up the MediaPipe audio embedder.
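The NS_SWIFT_NAME fix above corrects a copy-paste slip from the audio classifier: Swift callers now see embedAsync(audioBlock:timestampInMilliseconds:) instead of classifyAsync(...). A stream-mode call sketch in Objective-C (editorial; assumes `embedder` was created with MPPAudioRunningModeAudioStream and an `audioEmbedderStreamDelegate`, as in the tests below):

NSError *error = nil;
BOOL queued = [embedder embedAsyncAudioBlock:audioBlock
                     timestampInMilliseconds:2000
                                       error:&error];
// Results arrive asynchronously on the delegate via
// -audioEmbedder:didFinishEmbeddingWithResult:timestampInMilliseconds:error:.
[embedder closeWithError:nil];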
@@ -43,7 +43,7 @@ NS_SWIFT_NAME(AudioEmbedderResult)
* @return An instance of `AudioEmbedderResult` initialized with the given array of
* `EmbeddingResult` objects and timestamp (in milliseconds).
*/
-- (instancetype)initWithEmbeddingResults:(nullable NSArray<MPPEmbeddingResult *> *)embeddingResults
+- (instancetype)initWithEmbeddingResults:(NSArray<MPPEmbeddingResult *> *)embeddingResults
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;

- (instancetype)init NS_UNAVAILABLE;
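With the nullability change above, the "no embeddings" case is represented by an empty array rather than nil, which the helper change below also reflects. A construction sketch with a hypothetical timestamp:

MPPAudioEmbedderResult *emptyResult =
    [[MPPAudioEmbedderResult alloc] initWithEmbeddingResults:@[]
                                     timestampInMilliseconds:1500];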
@@ -42,7 +42,7 @@ + (MPPAudioEmbedderResult *)audioEmbedderResultWithEmbeddingResultPacket:(const
cppEmbeddingResults = packet.Get<std::vector<EmbeddingResultProto>>();
} else {
// If packet does not contain protobuf of a type expected by the audio embedder.
-    return [[MPPAudioEmbedderResult alloc] initWithEmbeddingResults:nil
+    return [[MPPAudioEmbedderResult alloc] initWithEmbeddingResults:@[]
timestampInMilliseconds:timestampInMilliseconds];
}
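For context, the condition elided by this hunk typically guards on the packet's payload type. A sketch of the full shape, assuming MediaPipe's C++ `Packet::ValidateAsType` API (the actual guard is not shown in the diff):

std::vector<EmbeddingResultProto> cppEmbeddingResults;
if (packet.ValidateAsType<std::vector<EmbeddingResultProto>>().ok()) {
  // The packet carries the proto vector expected by the audio embedder.
  cppEmbeddingResults = packet.Get<std::vector<EmbeddingResultProto>>();
}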

147 changes: 145 additions & 2 deletions mediapipe/tasks/ios/test/audio/audio_embedder/MPPAudioEmbedderTests.mm
@@ -22,6 +22,8 @@
#import "mediapipe/tasks/ios/test/audio/core/utils/sources/MPPAudioData+TestUtils.h"
#import "mediapipe/tasks/ios/test/utils/sources/MPPFileInfo.h"

#include <array>

static MPPFileInfo *const kYamnetModelFileInfo =
[[MPPFileInfo alloc] initWithName:@"yamnet_embedding_metadata" type:@"tflite"];
static MPPFileInfo *const kSpeech16KHzMonoFileInfo =
@@ -62,7 +64,11 @@
XCTAssertEqual(embedding.floatEmbedding.count, expectedLength); \
}

-@interface MPPAudioEmbedderTests : XCTestCase <MPPAudioEmbedderStreamDelegate>
+@interface MPPAudioEmbedderTests : XCTestCase <MPPAudioEmbedderStreamDelegate> {
NSDictionary<NSString *, id> *_16kHZAudioStreamSucceedsTestDict;
NSDictionary<NSString *, id> *_48kHZAudioStreamSucceedsTestDict;
NSDictionary<NSString *, id> *_outOfOrderTimestampTestDict;
}
@end

@implementation MPPAudioEmbedderTests
@@ -185,6 +191,14 @@ - (void)testEmbedWithSilenceSucceeds {
isQuantized:options.quantize
expectedEmbeddingResultsCount:expectedEmbedderResultsCount
expectedEmbeddingLength:kExpectedEmbeddingLength];
const std::array<float, 3> expectedEmbeddingValuesSubset = {2.07613f, 0.392721f, 0.543622f};
const float valueDifferenceTolerance = 4e-6f;

NSArray<NSNumber *> *floatEmbedding = result.embeddingResults[0].embeddings[0].floatEmbedding;
for (int i = 0; i < expectedEmbeddingValuesSubset.size(); i++) {
XCTAssertEqualWithAccuracy(expectedEmbeddingValuesSubset[i], floatEmbedding[i].floatValue,
valueDifferenceTolerance);
}
}

- (void)testEmbedAfterCloseFailsInAudioClipsMode {
@@ -337,17 +351,120 @@ - (void)testCreateAudioRecordWithInvalidChannelCountFails {
AssertEqualErrors(error, expectedError);
}

- (void)testEmbedWithAudioStreamModeAndOutOfOrderTimestampsFails {
MPPAudioEmbedder *audioEmbedder =
[self audioEmbedderInStreamModeWithModelFileInfo:kYamnetModelFileInfo];
NSArray<MPPTimestampedAudioData *> *streamedAudioDataList =
[MPPAudioEmbedderTests streamedAudioDataListforYamnet];

XCTestExpectation *expectation =
[[XCTestExpectation alloc] initWithDescription:@"embedWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = 1;

_outOfOrderTimestampTestDict = @{
kAudioStreamTestsDictEmbedderKey : audioEmbedder,
kAudioStreamTestsDictExpectationKey : expectation
};

  // Can safely access indices 1 and 0 since the `streamedAudioDataList` count is already asserted.
XCTAssertTrue([audioEmbedder embedAsyncAudioBlock:streamedAudioDataList[1].audioData
timestampInMilliseconds:streamedAudioDataList[1].timestampInMilliseconds
error:nil]);

NSError *error;
XCTAssertFalse([audioEmbedder
embedAsyncAudioBlock:streamedAudioDataList[0].audioData
timestampInMilliseconds:streamedAudioDataList[0].timestampInMilliseconds
error:&error]);

NSError *expectedError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
}];
AssertEqualErrors(error, expectedError);

[audioEmbedder closeWithError:nil];
}

- (void)testEmbedWithAudioStreamModeSucceeds {
[self embedUsingYamnetAsyncAudioFileWithInfo:kSpeech16KHzMonoFileInfo
info:&_16kHZAudioStreamSucceedsTestDict];
[self embedUsingYamnetAsyncAudioFileWithInfo:kSpeech48KHzMonoFileInfo
info:&_48kHZAudioStreamSucceedsTestDict];
}

#pragma mark MPPAudioEmbedderStreamDelegate

- (void)audioEmbedder:(MPPAudioEmbedder *)audioEmbedder
didFinishEmbeddingWithResult:(MPPAudioEmbedderResult *)result
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError *)error {
// TODO: Add assertion for the result when stream mode inference tests are added.
  // Can safely assert Yamnet results before the `audioEmbedder` identity checks below, since only
  // the Yamnet model with 16kHz and 48kHz speech files is used for the async tests.
[MPPAudioEmbedderTests
assertNonQuantizedAudioEmbedderYamnetStreamModeResult:result
timestampInMilliseconds:timestampInMilliseconds
expectedEmbeddingLength:kExpectedEmbeddingLength];

if (audioEmbedder == _outOfOrderTimestampTestDict[kAudioStreamTestsDictEmbedderKey]) {
[_outOfOrderTimestampTestDict[kAudioStreamTestsDictExpectationKey] fulfill];
} else if (audioEmbedder == _16kHZAudioStreamSucceedsTestDict[kAudioStreamTestsDictEmbedderKey]) {
[_16kHZAudioStreamSucceedsTestDict[kAudioStreamTestsDictExpectationKey] fulfill];
} else if (audioEmbedder == _48kHZAudioStreamSucceedsTestDict[kAudioStreamTestsDictEmbedderKey]) {
[_48kHZAudioStreamSucceedsTestDict[kAudioStreamTestsDictExpectationKey] fulfill];
}
}

#pragma mark Audio Stream Mode Test Helpers

// `info` is `__strong` here since addresses of instance variables are passed to this method. By
// default, an `NSDictionary **` parameter is interpreted as `NSDictionary * __autoreleasing *`,
// which cannot accept the address of an instance variable.
- (void)embedUsingYamnetAsyncAudioFileWithInfo:(MPPFileInfo *)audioFileInfo
info:(NSDictionary<NSString *, id> *__strong *)info {
MPPAudioEmbedder *audioEmbedder =
[self audioEmbedderInStreamModeWithModelFileInfo:kYamnetModelFileInfo];

NSArray<MPPTimestampedAudioData *> *streamedAudioDataList =
[MPPAudioEmbedderTests streamedAudioDataListforYamnet];

XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:[NSString
stringWithFormat:@"embedWithStreamMode_%@", audioFileInfo.name]];
expectation.expectedFulfillmentCount = streamedAudioDataList.count;

*info = @{
kAudioStreamTestsDictEmbedderKey : audioEmbedder,
kAudioStreamTestsDictExpectationKey : expectation
};

for (MPPTimestampedAudioData *timestampedAudioData in streamedAudioDataList) {
XCTAssertTrue([audioEmbedder embedAsyncAudioBlock:timestampedAudioData.audioData
timestampInMilliseconds:timestampedAudioData.timestampInMilliseconds
error:nil]);
}

[audioEmbedder closeWithError:nil];

NSTimeInterval timeout = 1.0f;
[self waitForExpectations:@[ expectation ] timeout:timeout];
}
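A note on the `__strong` parameter above (editorial sketch): under ARC an indirect object parameter defaults to `__autoreleasing`, and only automatic (local) variables participate in pass-by-writeback, so the address of an instance variable cannot be passed to such a parameter. Declaring it `__strong` lets the method write straight through to the ivar's storage. The `fillInfo:` helper below is hypothetical:

// Hypothetical helper: `__strong` permits passing `&_someIvar` at the call site.
- (void)fillInfo:(NSDictionary<NSString *, id> *__strong *)info {
  *info = @{@"key" : @"value"};  // Assigns directly to the caller-supplied storage.
}
// Call site: [self fillInfo:&_16kHZAudioStreamSucceedsTestDict];
// With the default `__autoreleasing` qualifier, this would not compile.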

#pragma mark Audio Embedder Initializers

- (MPPAudioEmbedder *)audioEmbedderInStreamModeWithModelFileInfo:(MPPFileInfo *)fileInfo {
MPPAudioEmbedderOptions *options =
[MPPAudioEmbedderTests audioEmbedderOptionsWithModelFileInfo:fileInfo];
options.runningMode = MPPAudioRunningModeAudioStream;
options.audioEmbedderStreamDelegate = self;

MPPAudioEmbedder *audioEmbedder = [MPPAudioEmbedderTests audioEmbedderWithOptions:options];

return audioEmbedder;
}

+ (MPPAudioEmbedderOptions *)audioEmbedderOptionsWithModelFileInfo:(MPPFileInfo *)modelFileInfo {
MPPAudioEmbedderOptions *options = [[MPPAudioEmbedderOptions alloc] init];
options.baseOptions.modelAssetPath = modelFileInfo.path;
@@ -421,6 +538,21 @@ + (void)assertAudioEmbedderResult:(MPPAudioEmbedderResult *)result
}
}

+ (void)assertNonQuantizedAudioEmbedderYamnetStreamModeResult:(MPPAudioEmbedderResult *)result
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
expectedEmbeddingLength:(NSInteger)expectedEmbeddingLength {
// In stream mode, `result` will only have one embedding result corresponding to the timestamp at
// which it was sent for inference.
XCTAssertEqual(result.embeddingResults.count, 1);

MPPEmbeddingResult *embeddingResult = result.embeddingResults[0];
AssertEmbeddingResultHasOneEmbedding(embeddingResult);
AssertEmbeddingHasCorrectTypeAndDimension(embeddingResult.embeddings[0], NO,
expectedEmbeddingLength);

XCTAssertEqual(result.timestampInMilliseconds, timestampInMilliseconds);
}

+ (MPPAudioEmbedderResult *)embedAudioClipWithFileInfo:(MPPFileInfo *)fileInfo
usingAudioEmbedder:(MPPAudioEmbedder *)audioEmbedder {
MPPAudioData *audioData = [[MPPAudioData alloc] initWithFileInfo:fileInfo];
@@ -430,4 +562,15 @@ + (MPPAudioEmbedderResult *)embedAudioClipWithFileInfo:(MPPFileInfo *)fileInfo
return result;
}

+ (NSArray<MPPTimestampedAudioData *> *)streamedAudioDataListforYamnet {
NSArray<MPPTimestampedAudioData *> *streamedAudioDataList =
[AVAudioFile streamedAudioBlocksFromAudioFileWithInfo:kSpeech16KHzMonoFileInfo
modelSampleCount:kYamnetSampleCount
modelSampleRate:kYamnetSampleRate];

XCTAssertEqual(streamedAudioDataList.count, 5);

return streamedAudioDataList;
}

@end
