Merge pull request #5566 from priankakariatyml:ios-llm-inference-refactor

PiperOrigin-RevId: 676182779
copybara-github committed Sep 18, 2024
2 parents e8805fc + e1a2141 commit 7946760
Showing 7 changed files with 689 additions and 314 deletions.
4 changes: 3 additions & 1 deletion mediapipe/tasks/ios/genai/core/BUILD
@@ -18,14 +18,16 @@ licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])

exports_files([
"sources/LlmTaskRunner.swift",
"sources/GenAiInferenceError.swift",
"sources/LlmSessionRunner.swift",
"sources/LlmTaskRunner.swift",
])

swift_library(
name = "LlmTaskRunner",
srcs = [
"sources/GenAiInferenceError.swift",
"sources/LlmSessionRunner.swift",
"sources/LlmTaskRunner.swift",
],
# This ensures the compiler does not complain about MediaPipeTasksGenAIC being built separately.
21 changes: 16 additions & 5 deletions mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift
@@ -22,6 +22,7 @@ public enum GenAiInferenceError: Error {
case failedToInitializeSession(String?)
case failedToInitializeEngine(String?)
case failedToAddQueryToSession(String, String?)
case failedToCloneSession(String?)
}

extension GenAiInferenceError: LocalizedError {
@@ -31,19 +32,27 @@ extension GenAiInferenceError: LocalizedError {
case .invalidResponse:
return "The response returned by the model is invalid."
case .illegalMethodCall:
return "Response generation is already in progress."
return
"""
Response generation is already in progress. The request in progress may have been \
initiated on the current session or on one of the sessions created from the `LlmInference` \
that was used to create the current session.
"""
case .failedToComputeSizeInTokens(let message):
let explanation = message.flatMap { $0 } ?? "An internal error occured."
let explanation = message.flatMap { $0 } ?? "An internal error occurred."
return "Failed to compute size of text in tokens: \(explanation)"
case .failedToInitializeSession(let message):
let explanation = message.flatMap { $0 } ?? "An internal error occured."
let explanation = message.flatMap { $0 } ?? "An internal error occurred."
return "Failed to initialize LlmInference session: \(explanation)"
case .failedToInitializeEngine(let message):
let explanation = message.flatMap { $0 } ?? "An internal error occured."
let explanation = message.flatMap { $0 } ?? "An internal error occurred."
return "Failed to initialize LlmInference engine: \(explanation)"
case .failedToAddQueryToSession(let query, let message):
let explanation = message.flatMap { $0 } ?? "An internal error occured."
let explanation = message.flatMap { $0 } ?? "An internal error occurred."
return "Failed to add query: \(query) to LlmInference session: \(explanation)"
case .failedToCloneSession(let message):
let explanation = message.flatMap { $0 } ?? "An internal error occurred."
return "Failed to clone LlmInference session: \(explanation)"
}
}
}
@@ -68,6 +77,8 @@ extension GenAiInferenceError: CustomNSError {
return 4
case .failedToAddQueryToSession:
return 5
case .failedToCloneSession:
return 6
}
}
}
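For context, a minimal sketch of how a caller might surface these errors, assuming a `runner` of type `LlmSessionRunner` (introduced in the next file), a file that imports Foundation for the `NSError` bridge, and nothing beyond the `LocalizedError` and `CustomNSError` conformances shown in this diff:

// Hypothetical call site; `runner` is assumed to be an `LlmSessionRunner`.
do {
  let tokenCount = try runner.sizeInTokens(text: "Hello world")
  print("Prompt length in tokens: \(tokenCount)")
} catch let error as GenAiInferenceError {
  // `localizedDescription` is backed by the `LocalizedError` conformance above.
  print(error.localizedDescription)
  // The bridged NSError code comes from `CustomNSError`, e.g. 6 for `.failedToCloneSession`.
  print("Error code: \((error as NSError).code)")
} catch {
  print("Unexpected error: \(error)")
}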
222 changes: 222 additions & 0 deletions mediapipe/tasks/ios/genai/core/sources/LlmSessionRunner.swift
@@ -0,0 +1,222 @@
// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import MediaPipeTasksGenAIC

/// This class wraps a C `LlmInferenceEngine_Session` and calls the appropriate C methods to
/// initialize, execute, and terminate a MediaPipe `LlmInference.Session`.
final class LlmSessionRunner {
typealias CLlmSession = UnsafeMutableRawPointer

/// The underlying C LLM session managed by this `LlmSessionRunner`.
private var cLlmSession: CLlmSession?

/// Creates a new instance of `LlmSessionRunner` with the given C LLM session.
///
/// - Parameters:
/// - cLlmSession: A session created by a C LLM engine.
init(cLlmSession: UnsafeMutableRawPointer) {
self.cLlmSession = cLlmSession
}

/// Adds a query chunk to the C LLM session. This can be called multiple times to add multiple
/// query chunks before calling `predict` or `predictAsync`. The query chunks are processed in the
/// order they are added, similar to a concatenated prompt, but processed incrementally in chunks.
///
/// - Parameters:
/// - inputText: Query chunk to be added to the C session.
/// - Throws: An error if the query chunk could not be added successfully.
func addQueryChunk(inputText: String) throws {
var cErrorMessage: UnsafeMutablePointer<CChar>? = nil

guard
(inputText.withCString { cInputText in
LlmInferenceEngine_Session_AddQueryChunk(cLlmSession, cInputText, &cErrorMessage)
}) == StatusCode.success.rawValue
else {
throw GenAiInferenceError.failedToAddQueryToSession(
inputText, String(allocatedCErrorMessage: cErrorMessage))
}
}

/// Invokes the C LLM session with the previously added query chunks synchronously to generate an
/// array of `String` responses from the LLM.
///
/// - Returns: Array of `String` responses from the LLM.
/// - Throws: An error if the LLM's response is invalid.
func predict() throws -> [String] {
/// No safeguards are needed for this call since the C++ APIs only throw fatal errors.
/// `LlmInferenceEngine_Session_PredictSync()` will always return a `LlmResponseContext` if the
/// call completes.
var responseContext = LlmInferenceEngine_Session_PredictSync(cLlmSession)

defer {
withUnsafeMutablePointer(to: &responseContext) {
LlmInferenceEngine_CloseResponseContext($0)
}
}

/// Throw an error if response is invalid.
guard let responseStrings = LlmSessionRunner.responseStrings(from: responseContext) else {
throw GenAiInferenceError.invalidResponse
}

return responseStrings
}

/// Invokes the C LLM session with the previously added query chunks asynchronously to generate an
/// array of `String` responses from the LLM. The `progress` callback returns the partial
/// responses from the LLM or any errors. `completion` callback is invoked once the LLM is done
/// generating responses.
///
/// - Parameters:
/// - progress: A callback invoked when a partial response is available from the C LLM Session.
/// - completion: A callback invoked when the C LLM Session finishes response generation.
/// - Throws: An error if the LLM's response is invalid.
func predictAsync(
progress: @escaping (_ partialResult: [String]?, _ error: Error?) -> Void,
completion: @escaping (() -> Void)
) {
let callbackInfo = CallbackInfo(progress: progress, completion: completion)
let callbackContext = UnsafeMutableRawPointer(Unmanaged.passRetained(callbackInfo).toOpaque())

LlmInferenceEngine_Session_PredictAsync(cLlmSession, callbackContext) {
context, responseContext in
guard let cContext = context else {
return
}

guard let cResponse = responseContext?.pointee else {
/// This failure is unlikely to happen. But throwing an error for the sake of completeness.
///
/// If `responseContext` is nil, we have no way of knowing whether this was the last
/// response. The code assumes that this was not the last response and lets the context
/// continue in memory by taking an unretained value for it. This is to ensure that the
/// context pointer returned by C in the subsequent callbacks is not dangling, thereby
/// avoiding a seg fault. This has the downside that the context would continue indefinitely
/// in memory if it was indeed the last response. The context would never get cleaned up.
/// This will only be a problem if the failure happens on too many calls to `predictAsync`
/// and leads to an out of memory error.
let cCallbackInfo = Unmanaged<CallbackInfo>.fromOpaque(cContext).takeUnretainedValue()
cCallbackInfo.progress(nil, GenAiInferenceError.invalidResponse)
return
}

/// `takeRetainedValue()` decrements the reference count incremented by `passRetained()`. Only
/// take a retained value if the LLM has finished generating responses to prevent the context
/// from being deallocated in between response generation.
let cCallbackInfo =
cResponse.done
? Unmanaged<CallbackInfo>.fromOpaque(cContext).takeRetainedValue()
: Unmanaged<CallbackInfo>.fromOpaque(cContext).takeUnretainedValue()

if let responseStrings = LlmSessionRunner.responseStrings(from: cResponse) {
cCallbackInfo.progress(responseStrings, nil)
} else {
cCallbackInfo.progress(nil, GenAiInferenceError.invalidResponse)
}

LlmInferenceEngine_CloseResponseContext(responseContext)

/// Call completion callback if LLM has generated its last response.
if cResponse.done {
cCallbackInfo.completion()
}
}
}

/// Invokes the C LLM session to tokenize an input prompt using a pre-existing processor and
/// returns its length in tokens.
///
/// - Parameters:
/// - text: An input prompt.
/// - Returns: Length of the input prompt in tokens.
/// - Throws: An error if the number of tokens in the input prompt cannot be calculated.
func sizeInTokens(text: String) throws -> Int {
var cErrorMessage: UnsafeMutablePointer<CChar>?

let sizeInTokens = text.withCString { cText in
LlmInferenceEngine_Session_SizeInTokens(cLlmSession, cText, &cErrorMessage)
}

guard sizeInTokens > -1 else {
throw GenAiInferenceError.failedToComputeSizeInTokens(
String(allocatedCErrorMessage: cErrorMessage))
}

return Int(sizeInTokens)
}

/// Creates a clone of the current instance of `LlmSessionRunner` by cloning the underlying C
/// LLM session.
///
/// - Returns: Cloned `LlmSessionRunner`.
/// - Throws: An error if the underlying C LLM session could not be cloned.
func clone() throws -> LlmSessionRunner {
var clonedCLlmSession: UnsafeMutableRawPointer?
var cErrorMessage: UnsafeMutablePointer<CChar>? = nil
guard
LlmInferenceEngine_Session_Clone(cLlmSession, &clonedCLlmSession, &cErrorMessage)
== StatusCode.success.rawValue,
let clonedCLlmSession
else {
throw GenAiInferenceError.failedToCloneSession(String(allocatedCErrorMessage: cErrorMessage))
}

return LlmSessionRunner(cLlmSession: clonedCLlmSession)
}

deinit {
LlmInferenceEngine_Session_Delete(cLlmSession)
}
}

extension LlmSessionRunner {
/// A wrapper class whose object will be used as the C++ callback context.
/// The progress and completion callbacks cannot be invoked without a context.
class CallbackInfo {
typealias ProgressCallback = (_ partialResult: [String]?, _ error: Error?) -> Void
typealias CompletionCallback = () -> Void

let progress: ProgressCallback
let completion: CompletionCallback

init(
progress: @escaping (ProgressCallback),
completion: @escaping (CompletionCallback)
) {
self.progress = progress
self.completion = completion
}
}
}

extension LlmSessionRunner {
private class func responseStrings(from responseContext: LlmResponseContext) -> [String]? {
guard let cResponseArray = responseContext.response_array else {
return nil
}

var responseStrings: [String] = []
for responseIndex in 0..<Int(responseContext.response_count) {
guard let cResponseString = cResponseArray[responseIndex] else {
return nil
}
responseStrings.append(String(cString: cResponseString))
}

return responseStrings
}
}
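To see how the new runner is meant to be driven, here is a minimal, hypothetical usage sketch. Creating the underlying C session from the engine is handled by `LlmTaskRunner` and is not part of this file, so the session pointer is simply assumed to be valid; the method names match the `LlmSessionRunner` defined above.

// Hypothetical driver; `cSession` is assumed to be a valid session pointer
// produced by the C LLM engine (session creation is outside this file).
func runPrompt(cSession: UnsafeMutableRawPointer) throws {
  let runner = LlmSessionRunner(cLlmSession: cSession)

  // Query chunks are processed in the order they are added.
  try runner.addQueryChunk(inputText: "Summarize the following text: ")
  try runner.addQueryChunk(inputText: "MediaPipe is a framework for on-device ML.")

  // Synchronous prediction returns the complete responses.
  let responses = try runner.predict()
  print(responses.joined(separator: "\n"))

  // Cloning yields a new runner backed by a clone of the C session.
  let clonedRunner = try runner.clone()
  try clonedRunner.addQueryChunk(inputText: "Now shorten the summary to one sentence.")

  // Asynchronous prediction streams partial results through `progress`
  // and signals the end of generation through `completion`.
  clonedRunner.predictAsync(
    progress: { partialResult, error in
      if let partialResult { print("partial:", partialResult) }
      if let error { print("error:", error) }
    },
    completion: { print("generation finished") })
}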

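The retain/release choreography in `predictAsync` is the subtle part of this change, so here is a stripped-down sketch of the same `Unmanaged` pattern in isolation. The callback-taking function below is invented purely for illustration and is not a MediaPipe API; the point is that a callback which captures nothing (mirroring a C function pointer) can only reach Swift state through the opaque context pointer, and the +1 retain from `passRetained` must be balanced by exactly one `takeRetainedValue` on the final invocation.

final class CallbackBox {
  let onEvent: (Int) -> Void
  init(onEvent: @escaping (Int) -> Void) { self.onEvent = onEvent }
}

// Stand-in for a C API that stores `context` and invokes `callback` repeatedly;
// `done == true` marks the final invocation.
func fakeStreamingCApi(
  context: UnsafeMutableRawPointer?,
  callback: (UnsafeMutableRawPointer?, Int32, Bool) -> Void
) {
  callback(context, 1, false)
  callback(context, 2, true)
}

let box = CallbackBox(onEvent: { print("event", $0) })
// +1 retain keeps `box` alive for as long as the C side holds the opaque pointer.
let context = Unmanaged.passRetained(box).toOpaque()

fakeStreamingCApi(context: context) { context, value, done in
  guard let context else { return }
  // Balance the retain only on the final callback, exactly as `predictAsync` does.
  let box = done
    ? Unmanaged<CallbackBox>.fromOpaque(context).takeRetainedValue()
    : Unmanaged<CallbackBox>.fromOpaque(context).takeUnretainedValue()
  box.onEvent(Int(value))
}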