Skip to content

Commit

Permalink
Merge pull request #5640 from priankakariatyml:audio-classifier-impl
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678331552
  • Loading branch information
copybara-github committed Sep 24, 2024
2 parents 6e96542 + 63bb10d commit fae654f
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,22 @@ NS_SWIFT_NAME(AudioClassifier)
error:(NSError **)error
NS_SWIFT_NAME(classifyAsync(audioBlock:timestampInMilliseconds:));

/**
* Closes and cleans up the MediaPipe audio classifier.
*
* For audio classifiers initialized with `.audioStream` mode, ensure that this method is called
* after all audio blocks in an audio stream are sent for inference using
* `classifyAsync(audioBlock:timestampInMilliseconds:)`. Otherwise, the audio classifier will not
* process the last audio block (of type `AudioData`) in the stream if its `bufferLength` is shorter
* than the model's input length. Once an audio classifier is closed, you cannot send any inference
* requests to it. You must create a new instance of `AudioClassifier` to send any pending requests.
* Ensure that you are ready to dispose of the audio classifier before this method is invoked.
*
* @param error An optional error out-parameter; presumably populated with the reason for failure
* when the task could not be closed.
*
* @return `YES` if the task was closed successfully. Otherwise `NO` (surfaced as a thrown error in
* Swift), with the reason for failure reported through `error`.
*/
- (BOOL)closeWithError:(NSError **)error;

/** Unavailable: `NS_UNAVAILABLE` makes any call to the plain `init` a compile-time error. */
- (instancetype)init NS_UNAVAILABLE;

/**
Expand All @@ -169,7 +185,8 @@ NS_SWIFT_NAME(AudioClassifier)
+ (MPPAudioRecord *)createAudioRecordWithChannelCount:(NSUInteger)channelCount
sampleRate:(double)sampleRate
bufferLength:(NSUInteger)bufferLength
error:(NSError **)error;
error:(NSError **)error
NS_SWIFT_NAME(createAudioRecord(channelCount:sampleRate:bufferLength:));

/** Unavailable: `NS_UNAVAILABLE` makes any call to `+new` a compile-time error. */
+ (instancetype)new NS_UNAVAILABLE;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ + (MPPAudioRecord *)createAudioRecordWithChannelCount:(NSUInteger)channelCount
error:error];
}

/**
 * Closes the audio classifier by forwarding the request to `_audioTaskRunner`.
 *
 * @param error Error out-parameter, passed through unchanged to `_audioTaskRunner`.
 *
 * @return The result reported by `-[_audioTaskRunner closeWithError:]`.
 */
- (BOOL)closeWithError:(NSError **)error {
  BOOL didClose = [_audioTaskRunner closeWithError:error];
  return didClose;
}

#pragma mark - Private

- (void)processAudioStreamResult:(absl::StatusOr<PacketMap>)audioStreamResult {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,21 @@ + (nullable MPPAudioClassifierResult *)audioClassifierResultWithClassificationsP
NSInteger timestampInMilliseconds =
(NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);

if (!packet.ValidateAsType<std::vector<ClassificationResultProto>>().ok()) {
std::vector<ClassificationResultProto> cppClassificationResults;
if (packet.ValidateAsType<ClassificationResultProto>().ok()) {
// If `runningMode = .audioStream`, only a single `ClassificationResult` will be returned in the
// result packet.
cppClassificationResults.emplace_back(packet.Get<ClassificationResultProto>());
} else if (packet.ValidateAsType<std::vector<ClassificationResultProto>>().ok()) {
// If `runningMode = .audioClips`, a vector of timestamped `ClassificationResult`s will be
// returned in the result packet.
cppClassificationResults = packet.Get<std::vector<ClassificationResultProto>>();
} else {
// If packet does not contain protobuf of a type expected by the audio classifier.
return [[MPPAudioClassifierResult alloc] initWithClassificationResults:@[]
timestampInMilliseconds:timestampInMilliseconds];
}

std::vector<ClassificationResultProto> cppClassificationResults =
packet.Get<std::vector<ClassificationResultProto>>();

NSMutableArray<MPPClassificationResult *> *classificationResults =
[NSMutableArray arrayWithCapacity:cppClassificationResults.size()];

Expand Down
2 changes: 2 additions & 0 deletions mediapipe/tasks/ios/test/audio/audio_classifier/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ objc_library(
"//mediapipe/tasks/ios/audio/core:MPPAudioData",
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/test/audio/core/utils:AVAudioFileTestUtils",
"//mediapipe/tasks/ios/test/audio/core/utils:AVAudioPCMBufferTestUtils",
"//mediapipe/tasks/ios/test/audio/core/utils:MPPAudioDataTestUtils",
"//mediapipe/tasks/ios/test/utils:MPPFileInfo",
"//third_party/apple_frameworks:XCTest",
],
Expand Down

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions mediapipe/tasks/ios/test/audio/core/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,12 @@ objc_library(
"//third_party/apple_frameworks:Foundation",
],
)

# Helper category on `MPPAudioData` (see `MPPAudioData+TestUtils.h`) for the iOS audio tests.
objc_library(
name = "MPPAudioDataTestUtils",
srcs = ["sources/MPPAudioData+TestUtils.m"],
hdrs = ["sources/MPPAudioData+TestUtils.h"],
deps = [
"//mediapipe/tasks/ios/audio/core:MPPAudioData",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ @implementation AVAudioFile (TestUtils)
MPPAudioData *audioData =
[[MPPAudioData alloc] initWithFormat:audioDataFormat
sampleCount:lengthToBeLoaded / audioDataFormat.channelCount];

// Can safely access `floatChannelData[0]` since the input file is expected to have at least 1
// channel.
MPPFloatBuffer *floatBuffer =
[[MPPFloatBuffer alloc] initWithData:audioPCMBuffer.floatChannelData[currentPosition]
[[MPPFloatBuffer alloc] initWithData:audioPCMBuffer.floatChannelData[0] + currentPosition
length:lengthToBeLoaded];
[audioData loadBuffer:floatBuffer offset:0 length:floatBuffer.length error:nil];

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/audio/core/sources/MPPAudioData.h"

NS_ASSUME_NONNULL_BEGIN

/** Helper utility for initializing `MPPAudioData` for MediaPipe iOS audio library tests. */
@interface MPPAudioData (TestUtils)

/**
* Initializes an `MPPAudioData` from channel count, sample rate and sample count.
*
* This is a convenience wrapper: it builds an `MPPAudioDataFormat` from `channelCount` and
* `sampleRate` and forwards to `initWithFormat:sampleCount:`, so tests do not have to create the
* format object themselves.
*
* @param channelCount Number of channels, used to create the `MPPAudioDataFormat`.
* @param sampleRate Sample rate, used to create the `MPPAudioDataFormat`.
* @param sampleCount Sample count, passed through to `initWithFormat:sampleCount:`.
*
* @return The `MPPAudioData` object with the specified channel count, sample rate and sample count.
*/
- (instancetype)initWithChannelCount:(NSUInteger)channelCount
sampleRate:(double)sampleRate
sampleCount:(NSUInteger)sampleCount;

@end

NS_ASSUME_NONNULL_END
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/tasks/ios/test/audio/core/utils/sources/MPPAudioData+TestUtils.h"

@implementation MPPAudioData (TestUtils)

/**
 * Convenience initializer for tests: wraps `channelCount` and `sampleRate` in a new
 * `MPPAudioDataFormat` and hands it, together with `sampleCount`, to
 * `initWithFormat:sampleCount:`.
 */
- (instancetype)initWithChannelCount:(NSUInteger)channelCount
                          sampleRate:(double)sampleRate
                         sampleCount:(NSUInteger)sampleCount {
  return [self initWithFormat:[[MPPAudioDataFormat alloc] initWithChannelCount:channelCount
                                                                    sampleRate:sampleRate]
                  sampleCount:sampleCount];
}

@end

0 comments on commit fae654f

Please sign in to comment.