diff --git a/mediapipe/tasks/ios/genai/core/BUILD b/mediapipe/tasks/ios/genai/core/BUILD index 5caba7098e..7c9e83b64d 100644 --- a/mediapipe/tasks/ios/genai/core/BUILD +++ b/mediapipe/tasks/ios/genai/core/BUILD @@ -18,14 +18,16 @@ licenses(["notice"]) package(default_visibility = ["//mediapipe/tasks:internal"]) exports_files([ - "sources/LlmTaskRunner.swift", "sources/GenAiInferenceError.swift", + "sources/LlmSessionRunner.swift", + "sources/LlmTaskRunner.swift", ]) swift_library( name = "LlmTaskRunner", srcs = [ "sources/GenAiInferenceError.swift", + "sources/LlmSessionRunner.swift", "sources/LlmTaskRunner.swift", ], # This ensures the compiler does not complain about MediaPipeTasksGenAIC being built separately. diff --git a/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift b/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift index f4ebd1941b..db6e535392 100644 --- a/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift +++ b/mediapipe/tasks/ios/genai/core/sources/GenAiInferenceError.swift @@ -22,6 +22,7 @@ public enum GenAiInferenceError: Error { case failedToInitializeSession(String?) case failedToInitializeEngine(String?) case failedToAddQueryToSession(String, String?) + case failedToCloneSession(String?) } extension GenAiInferenceError: LocalizedError { @@ -31,19 +32,27 @@ extension GenAiInferenceError: LocalizedError { case .invalidResponse: return "The response returned by the model is invalid." case .illegalMethodCall: - return "Response generation is already in progress." + return + """ + Response generation is already in progress. The request in progress may have been \ + initiated on the current session or on one of the sessions created from the `LlmInference` \ + that was used to create the current session. + """ case .failedToComputeSizeInTokens(let message): - let explanation = message.flatMap { $0 } ?? "An internal error occured." + let explanation = message.flatMap { $0 } ?? "An internal error occurred." return "Failed to compute size of text in tokens: \(explanation)" case .failedToInitializeSession(let message): - let explanation = message.flatMap { $0 } ?? "An internal error occured." + let explanation = message.flatMap { $0 } ?? "An internal error occurred." return "Failed to initialize LlmInference session: \(explanation)" case .failedToInitializeEngine(let message): - let explanation = message.flatMap { $0 } ?? "An internal error occured." + let explanation = message.flatMap { $0 } ?? "An internal error occurred." return "Failed to initialize LlmInference engine: \(explanation)" case .failedToAddQueryToSession(let query, let message): - let explanation = message.flatMap { $0 } ?? "An internal error occured." + let explanation = message.flatMap { $0 } ?? "An internal error occurred." return "Failed to add query: \(query) to LlmInference session: \(explanation)" + case .failedToCloneSession(let message): + let explanation = message.flatMap { $0 } ?? "An internal error occurred." + return "Failed to clone LlmInference session: \(explanation)" } } } @@ -68,6 +77,8 @@ extension GenAiInferenceError: CustomNSError { return 4 case .failedToAddQueryToSession: return 5 + case .failedToCloneSession: + return 6 } } } diff --git a/mediapipe/tasks/ios/genai/core/sources/LlmSessionRunner.swift b/mediapipe/tasks/ios/genai/core/sources/LlmSessionRunner.swift new file mode 100644 index 0000000000..b46e2b750a --- /dev/null +++ b/mediapipe/tasks/ios/genai/core/sources/LlmSessionRunner.swift @@ -0,0 +1,222 @@ +// Copyright 2024 The MediaPipe Authors.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import MediaPipeTasksGenAIC + +/// This class is used to create and call appropriate methods on the C `LlmInferenceEngine_Session` +/// to initialize, execute and terminate any MediaPipe `LlmInference.Session`. +final class LlmSessionRunner { + typealias CLlmSession = UnsafeMutableRawPointer + + /// The underlying C LLM session managed by this `LlmSessionRunner`. + private var cLlmSession: CLlmSession? + + /// Creates a new instance of `LlmSessionRunner` with the given C LLM session. + /// + /// - Parameters: + /// - cLlmSession: A session created by a C LLM engine. + init(cLlmSession: UnsafeMutableRawPointer) { + self.cLlmSession = cLlmSession + } + + /// Adds query chunk to the C LLM session. This can be called multiple times to add multiple query + /// chunks before calling `predict` or `predictAsync`. The query chunks will be processed in the + /// order they are added, similar to a concatenated prompt, but able to be processed in chunks. + /// + /// - Parameters: + /// - inputText: Query chunk to be added to the C session. + /// - Throws: An error if query chunk could not be added successfully. + func addQueryChunk(inputText: String) throws { + var cErrorMessage: UnsafeMutablePointer? = nil + + guard + (inputText.withCString { cInputText in + LlmInferenceEngine_Session_AddQueryChunk(cLlmSession, cInputText, &cErrorMessage) + }) == StatusCode.success.rawValue + else { + throw GenAiInferenceError.failedToAddQueryToSession( + inputText, String(allocatedCErrorMessage: cErrorMessage)) + } + } + + /// Invokes the C LLM session with the previously added query chunks synchronously to generate an + /// array of `String` responses from the LLM. + /// + /// - Returns: Array of `String` responses from the LLM. + /// - Throws: An error if the LLM's response is invalid. + func predict() throws -> [String] { + /// No safe guards for the call since the C++ APIs only throw fatal errors. + /// `LlmInferenceEngine_Session_PredictSync()` will always return a `LlmResponseContext` if the + /// call completes. + var responseContext = LlmInferenceEngine_Session_PredictSync(cLlmSession) + + defer { + withUnsafeMutablePointer(to: &responseContext) { + LlmInferenceEngine_CloseResponseContext($0) + } + } + + /// Throw an error if response is invalid. + guard let responseStrings = LlmSessionRunner.responseStrings(from: responseContext) else { + throw GenAiInferenceError.invalidResponse + } + + return responseStrings + } + + /// Invokes the C LLM session with the previously added query chunks asynchronously to generate an + /// array of `String` responses from the LLM. The `progress` callback returns the partial + /// responses from the LLM or any errors. `completion` callback is invoked once the LLM is done + /// generating responses. + /// + /// - Parameters: + /// - progress: A callback invoked when a partial response is available from the C LLM Session. 
+ /// - completion: A callback invoked when the C LLM Session finishes response generation. + /// - Throws: An error if the LLM's response is invalid. + func predictAsync( + progress: @escaping (_ partialResult: [String]?, _ error: Error?) -> Void, + completion: @escaping (() -> Void) + ) { + let callbackInfo = CallbackInfo(progress: progress, completion: completion) + let callbackContext = UnsafeMutableRawPointer(Unmanaged.passRetained(callbackInfo).toOpaque()) + + LlmInferenceEngine_Session_PredictAsync(cLlmSession, callbackContext) { + context, responseContext in + guard let cContext = context else { + return + } + + guard let cResponse = responseContext?.pointee else { + /// This failure is unlikely to happen. But throwing an error for the sake of completeness. + /// + /// If `responseContext` is nil, we have no way of knowing whether this was the last + /// response. The code assumes that this was not the last response and lets the context + /// continue in memory by taking an unretained value for it. This is to ensure that the + /// context pointer returned by C in the subsequent callbacks is not dangling, thereby + /// avoiding a seg fault. This has the downside that the context would continue indefinitely + /// in memory if it was indeed the last response. The context would never get cleaned up. + /// This will only be a problem if the failure happens on too many calls to `predictAsync` + /// and leads to an out of memory error. + let cCallbackInfo = Unmanaged.fromOpaque(cContext).takeUnretainedValue() + cCallbackInfo.progress(nil, GenAiInferenceError.invalidResponse) + return + } + + /// `takeRetainedValue()` decrements the reference count incremented by `passRetained()`. Only + /// take a retained value if the LLM has finished generating responses to prevent the context + /// from being deallocated in between response generation. + let cCallbackInfo = + cResponse.done + ? Unmanaged.fromOpaque(cContext).takeRetainedValue() + : Unmanaged.fromOpaque(cContext).takeUnretainedValue() + + if let responseStrings = LlmSessionRunner.responseStrings(from: cResponse) { + cCallbackInfo.progress(responseStrings, nil) + } else { + cCallbackInfo.progress(nil, GenAiInferenceError.invalidResponse) + } + + LlmInferenceEngine_CloseResponseContext(responseContext) + + /// Call completion callback if LLM has generated its last response. + if cResponse.done { + cCallbackInfo.completion() + } + } + } + + /// Invokes the C LLM session to tokenize an input prompt using a pre-existing processor and + /// returns its length in tokens. + /// + /// - Parameters: + /// - text: An input prompt. + /// - Returns: Length of the input prompt in tokens. + /// - Throws: An error if the number of tokens in the input prompt cannot be calculated. + func sizeInTokens(text: String) throws -> Int { + var cErrorMessage: UnsafeMutablePointer? + + let sizeInTokens = text.withCString { cText in + LlmInferenceEngine_Session_SizeInTokens(cLlmSession, cText, &cErrorMessage) + } + + guard sizeInTokens > -1 else { + throw GenAiInferenceError.failedToComputeSizeInTokens( + String(allocatedCErrorMessage: cErrorMessage)) + } + + return Int(sizeInTokens) + } + + /// Creates a clone of the current instance of `LlmSessionRunner` by cloning the underlying C + /// LLM session. + /// + /// - Returns: Cloned `LlmSessionRunner`. + /// - Throws: An error if the underlying C LLM session could not be cloned. + func clone() throws -> LlmSessionRunner { + var clonedCLlmSession: UnsafeMutableRawPointer? + var cErrorMessage: UnsafeMutablePointer? 
= nil + guard + LlmInferenceEngine_Session_Clone(cLlmSession, &clonedCLlmSession, &cErrorMessage) + == StatusCode.success.rawValue, + let clonedCLlmSession + else { + throw GenAiInferenceError.failedToCloneSession(String(allocatedCErrorMessage: cErrorMessage)) + } + + return LlmSessionRunner(cLlmSession: clonedCLlmSession) + } + + deinit { + LlmInferenceEngine_Session_Delete(cLlmSession) + } +} + +extension LlmSessionRunner { + /// A wrapper class whose object will be used as the C++ callback context. + /// The progress and completion callbacks cannot be invoked without a context. + class CallbackInfo { + typealias ProgressCallback = (_ partialResult: [String]?, _ error: Error?) -> Void + typealias CompletionCallback = () -> Void + + let progress: ProgressCallback + let completion: CompletionCallback + + init( + progress: @escaping (ProgressCallback), + completion: @escaping (CompletionCallback) + ) { + self.progress = progress + self.completion = completion + } + } +} + +extension LlmSessionRunner { + private class func responseStrings(from responseContext: LlmResponseContext) -> [String]? { + guard let cResponseArray = responseContext.response_array else { + return nil + } + + var responseStrings: [String] = [] + for responseIndex in 0..? = nil - let returnCodeCreateEngine = withUnsafePointer(to: modelSettings) { - LlmInferenceEngine_CreateEngine($0, &self.cLlmEngine, &cErrorMessage) - } - if returnCodeCreateEngine != 0 { - let errorMessage = cErrorMessage.flatMap { String(cString: $0) } - throw GenAiInferenceError.failedToInitializeEngine(errorMessage) - } - cErrorMessage = nil - let returnCodeCreateSession = withUnsafePointer(to: sessionConfig) { - LlmInferenceEngine_CreateSession(self.cLlmEngine, $0, &self.cLlmSession, &cErrorMessage) - } - if returnCodeCreateSession != 0 { - let errorMessage = cErrorMessage.flatMap { String(cString: $0) } - throw GenAiInferenceError.failedToInitializeSession(errorMessage) + guard + (withUnsafePointer(to: modelSettings) { + LlmInferenceEngine_CreateEngine($0, &self.cLlmEngine, &cErrorMessage) + }) == StatusCode.success.rawValue + else { + throw GenAiInferenceError.failedToInitializeEngine( + String(allocatedCErrorMessage: cErrorMessage)) } } - /// Invokes the C inference engine with the given input text to generate an array of `String` - /// responses from the LLM. + /// Creates a new C LLM session from the current C engine and returns an `LlmSessionRunner` + /// that wraps around the newly created C session. The session runner is responsible for managing + /// its underlying C session. + /// + /// Note: On each invocation, this method returns a new instance of the session runner configured + /// to the values provided in the session config. Thus, if you provide the session config of a + /// currently active LLM session, this method will create and return a duplicate session runner + /// configured to the same values. The task runner does not keep track of the currently active + /// session runners. /// /// - Parameters: - /// - inputText: A `String` that is used to query the LLM. - /// - Throws: An error if the LLM's response is invalid. - func predict(inputText: String) throws -> [String] { - /// No safe guards for the call since the C++ APIs only throw fatal errors. - /// `LlmInferenceEngine_Session_PredictSync()` will always return a `LlmResponseContext` if the - /// call completes. - var cErrorMessage: UnsafeMutablePointer? 
= nil - var returnCode = inputText.withCString { cInputText in - LlmInferenceEngine_Session_AddQueryChunk(cLlmSession, cInputText, &cErrorMessage) - } - if returnCode != 0 { - let errorMessage = cErrorMessage.flatMap { String(cString: $0) } - throw GenAiInferenceError.failedToAddQueryToSession(inputText, errorMessage) - } - var responseContext = LlmInferenceEngine_Session_PredictSync(cLlmSession) - - defer { - withUnsafeMutablePointer(to: &responseContext) { - LlmInferenceEngine_CloseResponseContext($0) - } - } - - /// Throw an error if response is invalid. - guard let responseStrings = LlmTaskRunner.responseStrings(from: responseContext) else { - throw GenAiInferenceError.invalidResponse - } - - return responseStrings - } - - func predict( - inputText: String, progress: @escaping (_ partialResult: [String]?, _ error: Error?) -> Void, - completion: @escaping (() -> Void) - ) throws { - var cErrorMessage: UnsafeMutablePointer? = nil - var returnCode = inputText.withCString { cInputText in - LlmInferenceEngine_Session_AddQueryChunk(cLlmSession, cInputText, &cErrorMessage) - } - if returnCode != 0 { - let errorMessage = cErrorMessage.flatMap { String(cString: $0) } - throw GenAiInferenceError.failedToAddQueryToSession(inputText, errorMessage) - } - - /// `strdup(inputText)` prevents input text from being deallocated as long as callbacks are - /// being invoked. `CallbackInfo` takes care of freeing the memory of `inputText` when it is - /// deallocated. - let callbackInfo = CallbackInfo( - inputText: strdup(inputText), progress: progress, completion: completion) - let callbackContext = UnsafeMutableRawPointer(Unmanaged.passRetained(callbackInfo).toOpaque()) - - LlmInferenceEngine_Session_PredictAsync(cLlmSession, callbackContext) { - context, responseContext in - guard let cContext = context else { - return - } - guard let cResponse = responseContext?.pointee else { - return - } - - /// `takeRetainedValue()` decrements the reference count incremented by `passRetained()`. Only - /// take a retained value if the LLM has finished generating responses to prevent the context - /// from being deallocated in between response generation. - let cCallbackInfo = - cResponse.done - ? Unmanaged.fromOpaque(cContext).takeRetainedValue() - : Unmanaged.fromOpaque(cContext).takeUnretainedValue() - - if let responseStrings = LlmTaskRunner.responseStrings(from: cResponse) { - cCallbackInfo.progress(responseStrings, nil) - } else { - cCallbackInfo.progress(nil, GenAiInferenceError.invalidResponse) - } - - LlmInferenceEngine_CloseResponseContext(responseContext) - - /// Call completion callback if LLM has generated its last response. - if cResponse.done { - cCallbackInfo.completion() - } - } - } - - func sizeInTokens(text: String) throws -> Int { + /// - sessionConfig: C session config of type `LlmSessionConfig` that configures how to execute + /// the model. + /// - Returns: A new instance of `LlmSessionRunner`. + /// - Throws: An error if the engine could not be initialized. + func createSessionRunner(sessionConfig: LlmSessionConfig) throws -> LlmSessionRunner { var cErrorMessage: UnsafeMutablePointer? + var cLlmSession: UnsafeMutableRawPointer? - let sizeInTokens = text.withCString { cText in - LlmInferenceEngine_Session_SizeInTokens(cLlmSession, cText, &cErrorMessage) - } - - guard sizeInTokens > -1 else { - var errorMessage: String? 
- if let cErrorMessage { - errorMessage = String(cString: cErrorMessage) - free(cErrorMessage) - } - - throw GenAiInferenceError.failedToComputeSizeInTokens(errorMessage) + guard + (withUnsafePointer(to: sessionConfig) { + LlmInferenceEngine_CreateSession(cLlmEngine, $0, &cLlmSession, &cErrorMessage) + }) == StatusCode.success.rawValue, + let cLlmSession + else { + throw GenAiInferenceError.failedToInitializeSession( + String(allocatedCErrorMessage: cErrorMessage)) } - return Int(sizeInTokens) + let llmSessionRunner = LlmSessionRunner(cLlmSession: cLlmSession) + return llmSessionRunner } deinit { - LlmInferenceEngine_Session_Delete(cLlmSession) LlmInferenceEngine_Engine_Delete(cLlmEngine) } } -extension LlmTaskRunner { - /// A wrapper class whose object will be used as the C++ callback context. - /// The progress and completion callbacks cannot be invoked without a context. - class CallbackInfo { - typealias ProgressCallback = (_ partialResult: [String]?, _ error: Error?) -> Void - typealias CompletionCallback = () -> Void - - let inputText: UnsafeMutablePointer? - let progress: ProgressCallback - let completion: CompletionCallback - - init( - inputText: UnsafeMutablePointer?, progress: @escaping (ProgressCallback), - completion: @escaping (CompletionCallback) - ) { - self.inputText = inputText - self.progress = progress - self.completion = completion +extension String { + init?(allocatedCErrorMessage: UnsafeMutablePointer?) { + guard let allocatedCErrorMessage else { + return nil } - deinit { - free(inputText) - } + self.init(cString: allocatedCErrorMessage) + free(allocatedCErrorMessage) } } -extension LlmTaskRunner { - private class func responseStrings(from responseContext: LlmResponseContext) -> [String]? { - guard let cResponseArray = responseContext.response_array else { - return nil - } - - var responseStrings: [String] = [] - for responseIndex in 0.. String { + /// Disallow response generation if another response generation call initiated by any + /// `LlmInference` used to create the current session is already in progress. + /// + /// TODO: If simultaneous response generations on multiple sessions or the same session + /// are allowed to happen it leads to a crash. Investigate if this can be handled by C++. + try llmInference.shouldContinueWithResponseGeneration() + + defer { + llmInference.markResponseGenerationCompleted() + } + + let tokens = try llmSessionRunner.predict() + + guard let humanReadableLlmResponse = Session.humanReadableString(llmResponses: tokens) + else { + throw GenAiInferenceError.invalidResponse + } + + return humanReadableLlmResponse + } + + /// Generates a response based on the previously added query chunks asynchronously. The + /// `progress` callback returns the partial responses from the LLM or any errors. + /// `completion` callback is invoked once the LLM is done generating responses. + /// Use `addQueryChunk(inputText:)` to add at least one query chunk before calling this function. + /// Note: You cannot invoke simultaneous response generation calls on active sessions created + /// using the same `LlmInference`. You have to wait for the currently running response + /// generation call to complete before initiating another one. + /// + /// - Parameters: + /// - progress: A callback invoked when a partial response is available from the LLM. + /// - completion: A callback invoked when the LLM finishes response generation. 
+ /// - Throws: An error if the LLM's response is invalid or if a response generation is + /// currently in progress on any session initialized from the `LlmInference` used to create + /// this session. + @objc public func generateResponseAsync( + progress: @escaping (_ partialResponse: String?, _ error: Error?) -> Void, + completion: @escaping (() -> Void) + ) throws { + /// Disallow response generation if another response generation call initiated by any + /// `LlmInference` used to create the current session is already in progress. + /// + /// TODO: If simultaneous response generations on multiple sessions or the same session + /// are allowed to happen it leads to a crash. Investigate if this can be handled by C++. + try llmInference.shouldContinueWithResponseGeneration() + + /// Used to make a decision about whitespace stripping. + var receivedFirstToken = true + + llmSessionRunner.predictAsync( + progress: { partialResponseStrings, error in + guard let responseStrings = partialResponseStrings, + let humanReadableLlmResponse = Session.humanReadableString( + llmResponses: responseStrings, stripLeadingWhitespaces: receivedFirstToken) + else { + progress(nil, GenAiInferenceError.invalidResponse) + return + } + + /// Reset state after first response is processed. + receivedFirstToken = false + + progress(humanReadableLlmResponse, nil) + }, + completion: { [weak self] in + self?.llmInference.markResponseGenerationCompleted() + completion() + }) + } + + /// Generates a response based on the previously added query chunks asynchronously. + /// Use `addQueryChunk(inputText:)` to add at least one query chunk before calling this + /// function. + /// Note: You cannot invoke simultaneous response generation calls on active sessions created + /// using the same `LlmInference`. You have to wait for the currently running response + /// generation call to complete before initiating another one. + /// + /// + /// - Returns: An async throwing stream that contains the partial responses from the LLM. + /// If a response generation is currently in progress on any session initialized from the + /// `LlmInference` used to create this session, the async throwing stream finishes by + /// throwing an error. + @available(iOS 13, macOS 10.15, tvOS 13, watchOS 6, *) + public func generateResponseAsync() -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + do { + try generateResponseAsync( + progress: { partialResponse, error in + if let error { + continuation.finish(throwing: error) + } else if let partialResponse { + continuation.yield(partialResponse) + } + }, + completion: { + continuation.finish() + }) + } catch { + continuation.finish(throwing: error) + } + } + } + + /// Returns the size in tokens of the provided text. + /// You may use this function to verify the size before submitting the prompt to ensure it + /// doesn't exceed the configured maximum token size. + /// + /// - Parameters: + /// - text: The input text whose size in tokens is to be counted. + /// - Returns: The size in tokens of the provided text. + /// - Throws: An error if calculating the size in tokens of the provided text fails. + public func sizeInTokens(text: String) throws -> Int { + return try llmSessionRunner.sizeInTokens(text: text) + } + + /// Clones the current session. + /// You can continue prompting the LLM from where you left off using the cloned session. + /// + /// - Returns: A new instance of `Session` which is cloned from the current session. + /// - Throws: An error if cloning the current session fails. 
+ public func clone() throws -> Session { + let clonedSessionRunner = try llmSessionRunner.clone() + return Session(llmSessionRunner: clonedSessionRunner, llmInference: self.llmInference) + } + + private static func humanReadableString( + llmResponses: [String], stripLeadingWhitespaces: Bool = true + ) -> String? { + if llmResponses.isEmpty { + return "" + } + guard let llmResponse = llmResponses.first else { + return nil + } + return llmResponse.humanReadableString(stripLeadingWhitespaces: stripLeadingWhitespaces) + } + + } +} + +// Extension to `LlmInference.Session` for defining `LlmInference.Session.Options` +extension LlmInference.Session { + /// Options for setting up a `LlmInference.Session`. + /// + /// Note: Inherits from `NSObject` for Objective-C interoperability. + @objc(MPPLLMInferenceSessionOptions) public final class Options: NSObject { + /// The top K number of tokens to be sampled from for each decoding step. A value of 1 means + /// greedy decoding. Defaults to 40. + @objc public var topk: Int = 40 + + /// Maximum cumulative probability over the tokens to sample from in each decoding step for + /// top-p / nucleus sampling. + @objc public var topp: Float = 1.0 + + /// The randomness when decoding the next token. A value of 0.0f means greedy decoding. Defaults + /// to 0.8. + @objc public var temperature: Float = 0.8 + + /// The random seed for sampling tokens. + @objc public var randomSeed: Int = 0 + + /// The optional absolute path to the LoRA model asset bundle stored locally on the device. + /// This is only compatible with GPU models. + @objc public var loraPath: String? + } +} + +/// An extension to `String` to add some utility functions. +extension String { + private static let tokenSplitter = "▁" + /// Note this is NOT an underscore: ▁(U+2581) + private static let newLine = "<0x0A>" + private static let eod = "\\[eod\\]" + + fileprivate func humanReadableString(stripLeadingWhitespaces: Bool = true) -> String? { + var humanReadableString = self.replacingOccurrences(of: String.tokenSplitter, with: " ") + .replacingOccurrences(of: String.newLine, with: "\n") + humanReadableString = + stripLeadingWhitespaces + ? humanReadableString.trimmingCharacters(in: .whitespaces) : humanReadableString + return humanReadableString.components(separatedBy: String.eod).first + } +} diff --git a/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift b/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift index 8f429c28fc..55ac6ddfcf 100644 --- a/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift +++ b/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift @@ -17,6 +17,10 @@ import MediaPipeTasksGenAIC /// A MediaPipe task that performs inference using a given Large Language Model. /// +/// An instance of `LlmInference` will only be deallocated after all sessions created from it are +/// destroyed. This means that an LLM inference can stay in memory even if a reference to it goes +/// out of scope if at least one of its sessions outlives its scope. +/// /// Note: Inherits from `NSObject` for Objective C interoperability. @objc(MPPLLMInference) public final class LlmInference: NSObject { private static let numberOfDecodeStepsPerSync = 3 @@ -26,73 +30,56 @@ import MediaPipeTasksGenAIC private let llmTaskRunner: LlmTaskRunner + /// Serial queue that reads and updates `responseGenerationInProgress` to restrict simultaneous + /// execution of response generation functions across sessions created from this + /// `LlmInference`. 
private let responseGenerationInProgressQueue = DispatchQueue( - label: LlmInference.responseGenerationInProgressQueueName, - attributes: .concurrent) + label: LlmInference.responseGenerationInProgressQueueName) /// Tracks whether a response generation is in progress. - /// Readers writers lock to prevent race condition as this variable can be accessed from multiple - /// threads. - private var responseGenerationInProgressInternal = false - private var responseGenerationInProgress: Bool { - get { - responseGenerationInProgressQueue.sync { - return self.responseGenerationInProgressInternal - } - } - set { - responseGenerationInProgressQueue.async(flags: .barrier) { - self.responseGenerationInProgressInternal = newValue - } - } - } - private var supportedLoraRanks: UnsafeMutableBufferPointer? + private var responseGenerationInProgress = false /// Creates a new instance of `LlmInference` with the given options. + /// An instance of `LlmInference` will only be deallocated after all sessions created from it are + /// destroyed. This means that an LLM inference can stay in memory even if the reference to it + /// goes out of scope if at least one of its sessions outlives its scope. /// /// - Parameters: - /// - options: The options of type `LlmInference.Options` to use for configuring the `LlmInference`. + /// - options: The options of type `LlmInference.Options` to use for configuring the + /// `LlmInference`. /// - Throws: An error if `LlmInference` instance could not be initialized. @objc public init(options: Options) throws { - let modelPath = strdup(options.modelPath) - let cacheDirectory = strdup(FileManager.default.temporaryDirectory.path) - let loraPath = strdup(options.loraPath == nil ? "" : options.loraPath!) - - defer { - free(modelPath) - free(cacheDirectory) - free(loraPath) - } - var loraRanks: UnsafeMutableBufferPointer? 
- options.supportedLoraRanks.withUnsafeMutableBufferPointer { pointer in - loraRanks = pointer + let cacheDirectory = FileManager.default.temporaryDirectory.path + + let sequenceBatchSize = LlmInference.sequenceBatchSize + let numberOfDecodeStepsPerSync = LlmInference.numberOfDecodeStepsPerSync + llmTaskRunner = try options.modelPath.withCString { modelPath in + try cacheDirectory.withCString { cacheDirectory in + try options.supportedLoraRanks.withUnsafeMutableBufferPointer { supportedLoraRanks in + let modelSetting = LlmModelSettings( + model_path: modelPath, + cache_dir: cacheDirectory, + max_num_tokens: options.maxTokens, + num_decode_steps_per_sync: numberOfDecodeStepsPerSync, + sequence_batch_size: sequenceBatchSize, + number_of_supported_lora_ranks: supportedLoraRanks.count, + supported_lora_ranks: supportedLoraRanks.baseAddress, + max_top_k: options.maxTopk, + llm_activation_data_type: options.activationDataType.activationDataTypeC, + num_draft_tokens: 0) + return try LlmTaskRunner(modelSettings: modelSetting) + } + } } - supportedLoraRanks = loraRanks - - let modelSetting = LlmModelSettings( - model_path: modelPath, - cache_dir: cacheDirectory, - max_num_tokens: options.maxTokens, - num_decode_steps_per_sync: LlmInference.numberOfDecodeStepsPerSync, - sequence_batch_size: LlmInference.sequenceBatchSize, - number_of_supported_lora_ranks: options.supportedLoraRanks.count, - supported_lora_ranks: supportedLoraRanks?.baseAddress, - max_top_k: options.topk, - llm_activation_data_type: options.activationDataType.activationDataTypeC, - num_draft_tokens: 0) - let sessionConfig = LlmSessionConfig( - topk: options.topk, - topp: 1.0, - temperature: options.temperature, - random_seed: options.randomSeed, - lora_path: loraPath) - try llmTaskRunner = LlmTaskRunner(modelSettings: modelSetting, sessionConfig: sessionConfig) super.init() } /// A convenience initializer that creates a new instance of `LlmInference` from an absolute path /// to a model asset bundle stored locally on the device and the default `LlmInference.Options`. + /// An instance of `LlmInference` will only be deallocated after all sessions created from it are + /// destroyed. This means that an LLM inference can stay in memory even if the reference to it + /// goes out of scope if at least one of its sessions outlives its scope. /// /// - Parameters: /// - modelPath: The absolute path to a model asset bundle stored locally on the device. @@ -102,31 +89,37 @@ import MediaPipeTasksGenAIC try self.init(options: options) } - /// Generates a response based on the input text. + /// Creates and returns a session runner that wraps around a new session created by the underlying + /// LLM engine. + /// + /// - Parameters: + /// - sessionConfig: The C config of type `LlmSessionConfig` that configures how to execute the + /// model. + /// - Returns: + /// - An `LlmSessionRunner` that wraps around a new session. + /// - Throws: An error if the underlying engine could not create a session. + func createSessionRunner(sessionConfig: LlmSessionConfig) throws -> LlmSessionRunner { + return try llmTaskRunner.createSessionRunner(sessionConfig: sessionConfig) + } + + /// Generates a response based on the input text. This function creates a new session for each + /// call. If you want to have a stateful inference, use `LlmInference.Session`'s + /// `generateResponse()` instead. /// /// - Parameters: /// - inputText: A `String` that is used to query the LLM. /// - Throws: An error if the LLM's response is invalid. 
@objc public func generateResponse(inputText: String) throws -> String { - - /// Disallow response generation if another response generation call is already in progress. - try shouldContinueWithResponseGeneration() - - let tokens = try llmTaskRunner.predict(inputText: inputText) - - responseGenerationInProgress = false - - guard let humanReadableLlmResponse = LlmInference.humanReadableString(llmResponses: tokens) - else { - throw GenAiInferenceError.invalidResponse - } - - return humanReadableLlmResponse + let session = try LlmInference.Session(llmInference: self) + try session.addQueryChunk(inputText: inputText) + return try session.generateResponse() } /// Generates a response based on the input text asynchronously. The `progress` callback returns /// the partial responses from the LLM or any errors. `completion` callback is invoked once the - /// LLM is done generating responses. + /// LLM is done generating responses. This function creates a new session for each call. + /// If you want to have a stateful inference, use `LlmInference.Session`'s + /// `generateResponseAsync(progress: completion:) throws` instead. /// /// - Parameters: /// - progress: A callback invoked when a partial response is available from the LLM. @@ -137,36 +130,14 @@ import MediaPipeTasksGenAIC progress: @escaping (_ partialResponse: String?, _ error: Error?) -> Void, completion: @escaping (() -> Void) ) throws { - /// Disallow response generation if another response generation call is already in progress. - try shouldContinueWithResponseGeneration() - - /// Used to make a decision about whitespace stripping. - var receivedFirstToken = true - - try llmTaskRunner.predict( - inputText: inputText, - progress: { partialResponseStrings, error in - - guard let responseStrings = partialResponseStrings, - let humanReadableLlmResponse = LlmInference.humanReadableString( - llmResponses: responseStrings, stripLeadingWhitespaces: receivedFirstToken) - else { - progress(nil, GenAiInferenceError.invalidResponse) - return - } - - /// Reset state after first response is processed. - receivedFirstToken = false - - progress(humanReadableLlmResponse, nil) - }, - completion: { [weak self] in - self?.responseGenerationInProgress = false - completion() - }) + let session = try LlmInference.Session(llmInference: self) + try session.addQueryChunk(inputText: inputText) + try session.generateResponseAsync(progress: progress, completion: completion) } - /// Generates a response based on the input text asynchronously. + /// Generates a response based on the input text asynchronously. This function creates a new + /// session for each call. If you want to have a stateful inference, use `LlmInference.Session`'s + /// `generateResponseAsync() -> AsyncThrowingStream` instead. /// /// - Parameters: /// - inputText: The prompt used to query the LLM. @@ -175,8 +146,9 @@ import MediaPipeTasksGenAIC public func generateResponseAsync(inputText: String) -> AsyncThrowingStream { AsyncThrowingStream { continuation in do { - try generateResponseAsync( - inputText: inputText, + let session = try LlmInference.Session(llmInference: self) + try session.addQueryChunk(inputText: inputText) + try session.generateResponseAsync( progress: { partialResponse, error in if let error { continuation.finish(throwing: error) @@ -193,32 +165,38 @@ import MediaPipeTasksGenAIC } } - /// Returns the size in tokens of the provided text. 
+ /// If no response generation using any session created from this `LlmInference` is currently in + /// progress, this function updates the response generation state to `true` and returns + /// successfully thereby granting access to its caller to execute response generation. + /// If this function throws an error, the invoking session must abort the response generation + /// call. This function must be called before invoking the response generation function on the + /// underlying `LlmSessionRunner`. /// - /// You may use this function to verify this size before submitting the prompt to ensure it - /// doesn't exceed the configured maximum token size. - public func sizeInTokens(text: String) throws -> Int { - return try llmTaskRunner.sizeInTokens(text: text) - } - - /// Throw error if response generation is in progress or update response generation state. - private func shouldContinueWithResponseGeneration() throws { - if responseGenerationInProgress { - throw GenAiInferenceError.illegalMethodCall + /// - Throws: An error if response generation is already in progress. + func shouldContinueWithResponseGeneration() throws { + /// `responseGenerationInProgressQueue` is a serial queue. Executing a sync block on a serial + /// queue ensures that at any time only one call to this function tests and writes the current + /// state of response generation. All other calls are blocked until the state is + /// updated. If the state indicates that response generation is currently in progress, the + /// block throws an error. Since it is a synchronous block that blocks execution until it is + /// complete, the error is in turn propagated as an error thrown by the function. + try responseGenerationInProgressQueue.sync { + if !responseGenerationInProgress { + responseGenerationInProgress = true + } else { + throw GenAiInferenceError.illegalMethodCall + } } - - responseGenerationInProgress = true } - private static func humanReadableString( - llmResponses: [String], stripLeadingWhitespaces: Bool = true - ) -> String? { - guard let llmResponse = llmResponses.first else { - return "" + /// Marks response generation as complete by updating the state to `false`. Any session created + /// using this `LlmInference` must use this function to indicate the completion of response + /// generation using the underlying `LlmSessionRunner`. + func markResponseGenerationCompleted() { + responseGenerationInProgressQueue.sync { + responseGenerationInProgress = false } - return llmResponse.humanReadableString(stripLeadingWhitespaces: stripLeadingWhitespaces) } - } // Extension to `LlmInference` for defining `LlmInference.Options` @@ -234,29 +212,20 @@ extension LlmInference { /// tokens the model needs to handle. @objc public var maxTokens: Int = 512 - /// The top K number of tokens to be sampled from for each decoding step. A value of 1 means - /// greedy decoding. Defaults to 40. - @objc public var topk: Int = 40 - - /// The randomness when decoding the next token. A value of 0.0f means greedy decoding. Defaults - /// to 0.8. - @objc public var temperature: Float = 0.8 - - /// The random seed for sampling tokens. - @objc public var randomSeed: Int = 0 + /// Maximum top k, which is the max Top-K value supported for all sessions created with the + /// `LlmInference`, used by GPU only. If a session with Top-K value larger than this is being + /// asked to be created, it will be rejected(throw error). A value of 1 means only greedy + // decoding is supported for any sessions created with this `LlmInference`. 
Default value is 40. + @objc public var maxTopk: Int = 40 /// The supported lora ranks for the base model. Used by GPU only. @objc public var supportedLoraRanks: [Int] = [] - /// The absolute path to the LoRA model asset bundle stored locally on the device. Optional. - /// This is only compatible with GPU models. - @objc public var loraPath: String? - /// The activation data type for the model. @objc public var activationDataType: ActivationDataType = .default - /// Creates a new instance of `Options` with the modelPath and default values of - /// `maxTokens`, `topK``, `temperature` and `randomSeed`. + /// Creates a new instance of `Options` with the given `modelPath` and default values of + /// `maxTokens`, `maxTopk`, `supportedLoraRanks` and `activationDataType`. /// This function is only intended to be used from Objective C. /// /// - Parameters: @@ -295,20 +264,3 @@ extension LlmInference.ActivationDataType { } } } - -/// An extension to `String` to add some utility functions. -extension String { - private static let tokenSplitter = "▁" - /// Note this is NOT an underscore: ▁(U+2581) - private static let newLine = "<0x0A>" - private static let eod = "\\[eod\\]" - - fileprivate func humanReadableString(stripLeadingWhitespaces: Bool = true) -> String? { - var humanReadableString = self.replacingOccurrences(of: String.tokenSplitter, with: " ") - .replacingOccurrences(of: String.newLine, with: "\n") - humanReadableString = - stripLeadingWhitespaces - ? humanReadableString.trimmingCharacters(in: .whitespaces) : humanReadableString - return humanReadableString.components(separatedBy: String.eod).first - } -}
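
A minimal usage sketch of the session-based API introduced in this change, for orientation only. It assumes the `MediaPipeTasksGenAI` module name and a public `LlmInference.Session(llmInference:)` initializer (the session class declaration falls outside the hunks shown above), and `path/to/model.bin` is a placeholder path.

import MediaPipeTasksGenAI

// Sketch: chunked prompting, token counting, synchronous generation, and
// continuing a conversation on a cloned session. Assumes the initializers
// noted above; adjust to the published API.
func runSessionExample() throws {
  // The engine stays alive until every session created from it is destroyed.
  let llmInference = try LlmInference(modelPath: "path/to/model.bin")

  // A session keeps conversation state across multiple query chunks.
  let session = try LlmInference.Session(llmInference: llmInference)
  let prompt = "List three benefits of on-device LLM inference."
  try session.addQueryChunk(inputText: prompt)

  // Optionally verify the prompt fits within the configured `maxTokens`.
  let promptTokens = try session.sizeInTokens(text: prompt)
  print("Prompt size: \(promptTokens) tokens")

  // Only one generation may be in flight across all sessions created from the
  // same `LlmInference`; a concurrent call throws `illegalMethodCall`.
  let response = try session.generateResponse()
  print(response)

  // Continue from the same conversation state on a clone of the session.
  let followUpSession = try session.clone()
  try followUpSession.addQueryChunk(inputText: "Now list three caveats.")
  print(try followUpSession.generateResponse())
}

// The same flow with the streaming variant, which yields partial responses.
@available(iOS 13, macOS 10.15, tvOS 13, watchOS 6, *)
func streamSessionExample(session: LlmInference.Session) async throws {
  for try await partialResponse in session.generateResponseAsync() {
    print(partialResponse, terminator: "")
  }
}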