Do not free response in PredictAsync callback
This breaks Flutter, which runs all callbacks asynchronously.

This is a breaking change for this experimental API.

Tested with
`INPUT_PROMPT="Q: What is the tallest building in Paris? \
A:" bash third_party/odml/infra/genai/inference/c/run_llm_inference_api.sh`
PiperOrigin-RevId: 631827309
schmidt-sebastian authored and copybara-github committed May 8, 2024
1 parent 3c7bde9 commit 03d9901
Showing 5 changed files with 25 additions and 18 deletions.
5 changes: 3 additions & 2 deletions mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
@@ -107,14 +107,15 @@ ODML_EXPORT LlmResponseContext LlmInferenceEngine_Session_PredictSync(

// Run callback function in async mode.
// The callback will be invoked multiple times until `response_context.done`
-// is `true`.
+// is `true`. You need to invoke `LlmInferenceEngine_CloseResponseContext` after
+// each invocation to free memory.
// The callback context can be a pointer to any user defined data structure as
// it is passed to the callback unmodified.
ODML_EXPORT void LlmInferenceEngine_Session_PredictAsync(
    LlmInferenceEngine_Session* session, void* callback_context,
    const char* input,
    void (*callback)(void* callback_context,
-                     const LlmResponseContext response_context));
+                     LlmResponseContext* response_context));

// Tokenizes an input prompt using a pre-existing processor and returns its
// length in tokens. Returns -1 if tokenization fails.
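Under the updated contract, a conforming caller could look like the sketch below (illustrative only, not part of this commit: the callback name and call site are hypothetical, and the `LlmResponseContext` field types are assumed from how the diff uses them). It consumes each streamed chunk and then frees it itself, since the engine no longer closes the response on the caller's behalf.

#include <cstdio>

#include "mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h"

// Prints every chunk of a streamed response and then frees it. The engine
// invokes this callback repeatedly until `done` is true.
void OnPartialResponse(void* /*callback_context*/,
                       LlmResponseContext* response_context) {
  for (int i = 0; i < response_context->response_count; ++i) {
    std::printf("%s", response_context->response_array[i]);
  }
  const bool finished = response_context->done;
  LlmInferenceEngine_CloseResponseContext(response_context);  // Caller frees now.
  if (finished) {
    std::printf("\n");
  }
}

// Hypothetical call site, assuming a session created elsewhere:
//   LlmInferenceEngine_Session_PredictAsync(
//       session, /*callback_context=*/nullptr,
//       "Q: What is the tallest building in Paris? A:", OnPartialResponse);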
14 changes: 6 additions & 8 deletions mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc
@@ -361,7 +361,7 @@ void LlmInferenceEngine_Session_PredictAsync(
    LlmInferenceEngine_Session* session, void* callback_context,
    const char* input,
    void (*callback)(void* callback_context,
-                     const LlmResponseContext response_context)) {
+                     LlmResponseContext* response_context)) {
  auto cpu_session = reinterpret_cast<LlmInferenceEngineCpu_Session*>(session);

  cpu_session->cpu_callback = [=](std::string responses) -> void {
@@ -376,13 +376,11 @@ void LlmInferenceEngine_Session_PredictAsync(
    }

    snprintf(result[0], responses.size() + 1, "%s", responses.c_str());
-    LlmResponseContext response_context = {
-        .response_array = result,
-        .response_count = 1,
-        .done = cpu_session->early_stop,
-    };
-    callback(callback_context, response_context);
-    LlmInferenceEngine_CloseResponseContext(&response_context);
+    auto response_context = std::make_unique<LlmResponseContext>();
+    response_context->response_array = result,
+    response_context->response_count = 1,
+    response_context->done = cpu_session->early_stop;
+    callback(callback_context, response_context.release());
  };

  cpu_session->prompt = input;
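The removed lines above are the root of the Flutter breakage: the response lived on the stack and `LlmInferenceEngine_CloseResponseContext` ran as soon as the callback returned, so a consumer that only defers work from the callback would later read freed memory. A minimal sketch of such a deferring consumer under the new contract follows (illustrative only; the queue, function names, and field types are assumptions, and a single-threaded event loop is assumed so no locking is shown):

#include <cstdio>
#include <queue>

#include "mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h"

// Responses handed over by the engine, processed on a later event-loop turn.
static std::queue<LlmResponseContext*> pending_responses;

// New-style callback: takes ownership of the heap-allocated response and only
// enqueues it. Nothing is freed here, so the data stays valid until drained.
void EnqueueResponse(void* /*callback_context*/,
                     LlmResponseContext* response_context) {
  pending_responses.push(response_context);
}

// Runs later, e.g. from the host event loop.
void DrainResponses() {
  while (!pending_responses.empty()) {
    LlmResponseContext* response = pending_responses.front();
    pending_responses.pop();
    for (int i = 0; i < response->response_count; ++i) {
      std::printf("%s", response->response_array[i]);
    }
    LlmInferenceEngine_CloseResponseContext(response);  // Consumer frees now.
  }
}

Because the CPU backend now releases a `std::make_unique`-allocated context to the callback, the pointer stays valid until the consumer explicitly closes it.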
@@ -80,8 +80,9 @@ ABSL_FLAG(
namespace {

// Only cout the first response
-void async_callback_print(void*, const LlmResponseContext response_context) {
-  std::cout << response_context.response_array[0] << std::flush;
+void async_callback_print(void*, LlmResponseContext* response_context) {
+  std::cout << response_context->response_array[0] << std::flush;
+  LlmInferenceEngine_CloseResponseContext(response_context);
}

}  // namespace
13 changes: 9 additions & 4 deletions mediapipe/tasks/ios/genai/core/sources/LlmTaskRunner.swift
@@ -31,7 +31,7 @@ final class LlmTaskRunner {
    let returnCode = withUnsafePointer(to: sessionConfig) {
      LlmInferenceEngine_CreateSession($0, &self.cLlmSession, &cErrorMessage)
    }
-    if (returnCode != 0) {
+    if returnCode != 0 {
      let errorMessage: String? = cErrorMessage == nil ? nil : String(cString: cErrorMessage!)
      throw GenAiInferenceError.failedToInitializeSession(errorMessage)
    }
@@ -82,23 +82,28 @@ final class LlmTaskRunner {
      guard let cContext = context else {
        return
      }
+      guard let cResponse = responseContext?.pointee else {
+        return
+      }

      /// `takeRetainedValue()` decrements the reference count incremented by `passRetained()`. Only
      /// take a retained value if the LLM has finished generating responses to prevent the context
      /// from being deallocated in between response generation.
      let cCallbackInfo =
-        responseContext.done
+        cResponse.done
        ? Unmanaged<CallbackInfo>.fromOpaque(cContext).takeRetainedValue()
        : Unmanaged<CallbackInfo>.fromOpaque(cContext).takeUnretainedValue()

-      if let responseStrings = LlmTaskRunner.responseStrings(from: responseContext) {
+      if let responseStrings = LlmTaskRunner.responseStrings(from: cResponse) {
        cCallbackInfo.progress(responseStrings, nil)
      } else {
        cCallbackInfo.progress(nil, GenAiInferenceError.invalidResponse)
      }

+      LlmInferenceEngine_CloseResponseContext(responseContext)
+
      /// Call completion callback if LLM has generated its last response.
-      if responseContext.done {
+      if cResponse.done {
        cCallbackInfo.completion()
      }
    }
@@ -79,7 +79,7 @@ jbyteArray ToByteArray(JNIEnv* env, const LlmResponseContext& context) {
}

void ProcessAsyncResponse(void* callback_ref,
-                          const LlmResponseContext response_context) {
+                          LlmResponseContext* response_context) {
  jobject object_ref = reinterpret_cast<jobject>(callback_ref);
  JNIEnv* env = GetJNIEnv();
  if (env == nullptr) {
@@ -95,7 +95,9 @@ void ProcessAsyncResponse(void* callback_ref,
  jmethodID method_id =
      env->GetMethodID(class_ref, method_name.c_str(), "([B)V");

-  const jbyteArray response_context_bytes = ToByteArray(env, response_context);
+  const jbyteArray response_context_bytes = ToByteArray(env, *response_context);
+  LlmInferenceEngine_CloseResponseContext(response_context);
+
  env->CallVoidMethod(object_ref, method_id, response_context_bytes);
}

Expand Down
