Skip to content

Commit

Permalink
Add Session API support to LLM Java API
Browse files Browse the repository at this point in the history
This is a breaking API change as it moves some Session options to LlmInferenceSessionOptions. I am also moving the API dependency on Guava, as we should not require 3P users to use Guava on mobile.

PiperOrigin-RevId: 678926266
  • Loading branch information
schmidt-sebastian authored and copybara-github committed Sep 26, 2024
1 parent e64c471 commit 83fe5a4
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 132 deletions.
1 change: 0 additions & 1 deletion mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ android_library(
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_java_proto_lite",
"//third_party/java/protobuf:protobuf_lite",
"@maven//:androidx_annotation_annotation",
"@maven//:com_google_guava_guava",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package com.google.mediapipe.tasks.core;

import android.content.Context;
import androidx.annotation.Nullable;
import com.google.mediapipe.tasks.core.OutputHandler.ProgressListener;
import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmModelSettings;
import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmSessionConfig;
Expand Down Expand Up @@ -122,13 +121,19 @@ public int sizeInTokens(LlmSession session, String text) {
}
}

/** If provided, removes the session and frees up its context. */
public void deleteSession(@Nullable LlmSession session) {
/** Clones the current session. */
public LlmSession cloneSession(LlmSession session) {
validateState();
if (session != null) {
nativeDeleteSession(session.sessionHandle);
statsLogger.logSessionEnd();
}
long clonedSessionHandle = nativeCloneSession(session.sessionHandle);
statsLogger.logSessionClone();
return new LlmSession(clonedSessionHandle);
}

/** Removes the session and frees up its context. */
public void deleteSession(LlmSession session) {
validateState();
nativeDeleteSession(session.sessionHandle);
statsLogger.logSessionEnd();
}

private LlmResponseContext parseResponse(byte[] response) {
Expand All @@ -140,11 +145,11 @@ private LlmResponseContext parseResponse(byte[] response) {
}

private void onAsyncResponse(byte[] responseBytes) {
LlmResponseContext respone = parseResponse(responseBytes);
if (respone.getDone()) {
LlmResponseContext response = parseResponse(responseBytes);
if (response.getDone()) {
isProcessing.set(false);
}
resultListener.get().run(respone.getResponsesList(), respone.getDone());
resultListener.get().run(response.getResponsesList(), response.getDone());
}

@Override
Expand All @@ -167,6 +172,8 @@ private void validateState() {

private static native long nativeCreateSession(byte[] sessionConfig, long enginePointer);

private static native long nativeCloneSession(long sessionPointer);

private static native void nativeDeleteSession(long sessionPointer);

private static native void nativeAddQueryChunk(long sessionPointer, String input);
Expand Down
14 changes: 14 additions & 0 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ JNIEXPORT jlong JNICALL JNI_METHOD(nativeCreateSession)(
return reinterpret_cast<jlong>(session);
}

JNIEXPORT jlong JNICALL JNI_METHOD(nativeCloneSession)(JNIEnv* env, jclass thiz,
jlong session_handle) {
void* session = nullptr;
char* error_msg = nullptr;
int error_code = LlmInferenceEngine_Session_Clone(
reinterpret_cast<void*>(session_handle), &session, &error_msg);
if (error_code) {
ThrowIfError(env, absl::InternalError(absl::StrCat(
"Failed to clone session: %s", error_msg)));
free(error_msg);
}
return reinterpret_cast<jlong>(session);
}

JNIEXPORT void JNICALL JNI_METHOD(nativeDeleteSession)(JNIEnv* env, jclass thiz,
jlong session_handle) {
LlmInferenceEngine_Session_Delete(reinterpret_cast<void*>(session_handle));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ JNIEXPORT void JNICALL JNI_METHOD(nativeDeleteEngine)(JNIEnv *, jclass, jlong);
JNIEXPORT jlong JNICALL JNI_METHOD(nativeCreateSession)(JNIEnv *, jclass,
jbyteArray, jlong);

/*
* Class: com_google_mediapipe_tasks_core_LlmTaskRunner
* Method: nativeCloneSession
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL JNI_METHOD(nativeCloneSession)(JNIEnv *, jclass, jlong);

/*
* Class: com_google_mediapipe_tasks_core_LlmTaskRunner
* Method: nativeDeleteSession
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ private TasksStatsDummyLogger() {}
@Override
public void logSessionStart() {}

/** Logs the cloning of a MediaPipe Tasks API session. */
@Override
public void logSessionClone() {}

/**
* Records MediaPipe Tasks API receiving CPU input data.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ static StatsSnapshot createDefault() {
/** Logs the start of a MediaPipe Tasks API session. */
public void logSessionStart();

/** Logs the cloning of a MediaPipe Tasks API session. */
public void logSessionClone();

/**
* Records MediaPipe Tasks API receiving CPU input data.
*
Expand Down
Loading

0 comments on commit 83fe5a4

Please sign in to comment.