Skip to content

Commit

Permalink
Add Vision Modality to the MediaPipe LLM JNI Layer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705521252
  • Loading branch information
schmidt-sebastian authored and copybara-github committed Dec 12, 2024
1 parent 911eb32 commit ad6c03e
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ cc_library(
"//mediapipe/java/com/google/mediapipe/framework/jni:class_registry",
"//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
"//mediapipe/tasks/cc/genai/inference/c:llm_inference_engine_hdr", # needed with ENABLE_ODML_MAVEN_BUILD
"//mediapipe/tasks/cc/genai/inference/proto:tflite_delegate_options_cc_proto",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_cc_proto",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_cc_proto",
"//third_party/skia/HEAD:core",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
Expand Down
76 changes: 72 additions & 4 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include <jni.h>

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>

#include "absl/log/absl_log.h"
Expand All @@ -25,8 +27,13 @@
#include "mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h"
#include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h"
#include "mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h"
#include "mediapipe/tasks/cc/genai/inference/proto/tflite_delegate_options.pb.h"
#include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto/llm_options.pb.h"
#include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto/llm_response_context.pb.h"
#include "third_party/skia/HEAD/include/core/SkAlphaType.h"
#include "third_party/skia/HEAD/include/core/SkBitmap.h"
#include "third_party/skia/HEAD/include/core/SkImage.h"
#include "third_party/skia/HEAD/include/core/SkImageInfo.h"

namespace {

Expand All @@ -37,14 +44,22 @@ using mediapipe::android::JStringToStdString;
using mediapipe::android::ThrowIfError;
using mediapipe::java::GetJNIEnv;

const bool kDefaultIncludeTokenCostCalculator = true;

LlmModelSettings ParseModelSettings(void* bytes, int size) {
LlmModelSettingsProto input;
input.ParseFromArray(bytes, size);

LlmModelSettings output;
output.model_path = strdup(input.model_path().c_str());
output.vision_encoder_path = nullptr;
output.vision_adapter_path = nullptr;
output.vision_encoder_path =
input.vision_model_settings().has_encoder_path()
? strdup(input.vision_model_settings().encoder_path().c_str())
: nullptr;
output.vision_adapter_path =
input.vision_model_settings().has_adapter_path()
? strdup(input.vision_model_settings().adapter_path().c_str())
: nullptr;
output.cache_dir = strdup(input.cache_dir().c_str());
output.sequence_batch_size = input.sequence_batch_size();
output.num_decode_steps_per_sync = input.num_decode_steps_per_sync();
Expand All @@ -57,6 +72,8 @@ LlmModelSettings ParseModelSettings(void* bytes, int size) {
for (int i = 0; i < input.supported_lora_ranks_size(); ++i) {
output.supported_lora_ranks[i] = input.supported_lora_ranks(i);
}
} else {
output.supported_lora_ranks = nullptr;
}
output.llm_activation_data_type = kLlmActivationDataTypeDefault;
output.num_draft_tokens = 0;
Expand All @@ -76,14 +93,20 @@ LlmSessionConfig ParseSessionConfig(void* bytes, int size) {
if (input.has_lora_path()) {
output.lora_path = strdup(input.lora_path().c_str());
}
output.include_token_cost_calculator = true;
output.enable_vision_modality = false;
output.include_token_cost_calculator =
input.graph_config().has_include_token_cost_calculator()
? input.graph_config().include_token_cost_calculator()
: kDefaultIncludeTokenCostCalculator;
output.enable_vision_modality = input.graph_config().enable_vision_modality();
return output;
}

// Releases every heap-allocated field of `model_settings` that
// ParseModelSettings populated, and nulls the pointers so a repeated call is
// harmless.
void FreeModelSettings(LlmModelSettings* model_settings) {
  // The string fields are allocated with strdup() (i.e. malloc) in
  // ParseModelSettings, so they must be released with free(); releasing
  // malloc'd memory with `delete` is undefined behavior.
  free(const_cast<char*>(model_settings->model_path));
  free(const_cast<char*>(model_settings->vision_adapter_path));
  free(const_cast<char*>(model_settings->vision_encoder_path));
  free(const_cast<char*>(model_settings->cache_dir));
  // `supported_lora_ranks` keeps delete[], which must match its allocation
  // (presumably new[] — the allocating lines are not visible in this chunk).
  delete[] model_settings->supported_lora_ranks;
  model_settings->model_path = nullptr;
  model_settings->vision_adapter_path = nullptr;
  model_settings->vision_encoder_path = nullptr;
  model_settings->cache_dir = nullptr;
  model_settings->supported_lora_ranks = nullptr;
}
Expand Down Expand Up @@ -211,6 +234,20 @@ JNIEXPORT void JNICALL JNI_METHOD(nativeAddQueryChunk)(JNIEnv* env, jclass thiz,
}
}

// JNI entry point that attaches a previously created image (see
// nativeCreateSkBitmap) to the given LLM session as a vision-modality input.
// Throws a Java exception if the underlying C API reports an error.
JNIEXPORT void JNICALL JNI_METHOD(nativeAddImage)(JNIEnv* env, jclass thiz,
                                                  jlong session_handle,
                                                  jlong image_handle) {
  char* error_msg = nullptr;
  int error_code = LlmInferenceEngine_Session_AddImage(
      reinterpret_cast<void*>(session_handle),
      reinterpret_cast<void*>(image_handle), &error_msg);
  if (error_code) {
    // absl::StrCat performs plain concatenation, not printf-style formatting,
    // so the message must not contain a "%s" placeholder.
    ThrowIfError(env, absl::InternalError(
                          absl::StrCat("Failed to add image: ", error_msg)));
    // The C API allocates error_msg with malloc(), so release it with free().
    free(error_msg);
  }
}

JNIEXPORT jbyteArray JNICALL
JNI_METHOD(nativePredictSync)(JNIEnv* env, jclass thiz, jlong session_handle) {
LlmResponseContext response_context = LlmInferenceEngine_Session_PredictSync(
Expand Down Expand Up @@ -259,3 +296,34 @@ JNIEXPORT jint JNICALL JNI_METHOD(nativeSizeInTokens)(JNIEnv* env, jclass thiz,
}
return size;
}

// Wraps the contents of a direct java.nio.ByteBuffer in a heap-allocated
// SkBitmap and returns an opaque handle to it (0 on failure, with a Java
// exception pending). installPixels does not copy the pixels, so the Java
// caller must keep the buffer alive for the bitmap's lifetime and release the
// bitmap via nativeDeleteSkBitmap.
JNIEXPORT jlong JNICALL JNI_METHOD(nativeCreateSkBitmap)(
    JNIEnv* env, jclass thiz, jobject byte_buffer, jint width, jint height,
    jint color_type, jint alpha_type) {
  const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
  void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
  if (buffer_data == nullptr || buffer_size < 0) {
    ThrowIfError(env, absl::InternalError("Cannot get direct access to the "
                                          "input buffer. It should be created "
                                          "using allocateDirect."));
    // ThrowIfError only schedules the Java exception; native execution
    // continues, so bail out explicitly instead of dereferencing nullptr.
    return 0;
  }

  // `color_type`/`alpha_type` are integer values of the corresponding Skia
  // enums supplied by the Java caller; the casts are value-preserving.
  SkColorType sk_color_type = static_cast<SkColorType>(color_type);
  SkAlphaType sk_alpha_type = static_cast<SkAlphaType>(alpha_type);
  SkImageInfo imageInfo =
      SkImageInfo::Make(width, height, sk_color_type, sk_alpha_type);

  // installPixels cannot check the buffer size itself, so verify the buffer
  // is large enough before handing it over to avoid out-of-bounds reads.
  if (static_cast<uint64_t>(buffer_size) < imageInfo.computeMinByteSize()) {
    ThrowIfError(env, absl::InternalError("Input buffer is too small for the "
                                          "requested image dimensions."));
    return 0;
  }

  auto bitmap = std::make_unique<SkBitmap>();
  bool success =
      bitmap->installPixels(imageInfo, buffer_data, imageInfo.minRowBytes());
  if (!success) {
    ThrowIfError(env, absl::InternalError("Cannot initialize SkBitmap."));
    // Return a null handle rather than a handle to a pixel-less bitmap.
    return 0;
  }

  return reinterpret_cast<jlong>(bitmap.release());
}

// Destroys an SkBitmap previously created by nativeCreateSkBitmap. A null
// (zero) handle is a no-op, matching `delete` semantics.
JNIEXPORT void JNICALL JNI_METHOD(nativeDeleteSkBitmap)(JNIEnv*, jclass,
                                                        jlong bitmap_handle) {
  auto* bitmap = reinterpret_cast<SkBitmap*>(bitmap_handle);
  delete bitmap;
}
25 changes: 25 additions & 0 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ JNIEXPORT void JNICALL JNI_METHOD(nativeDeleteSession)(JNIEnv *, jclass, jlong);
JNIEXPORT void JNICALL JNI_METHOD(nativeAddQueryChunk)(JNIEnv *, jclass, jlong,
jstring);

/*
* Class: com_google_mediapipe_tasks_core_LlmTaskRunner
* Method: nativeAddImage
 * Signature: (JJ)V
*/
JNIEXPORT void JNICALL JNI_METHOD(nativeAddImage)(JNIEnv *, jclass, jlong,
jlong);

/*
* Class: com_google_mediapipe_tasks_core_LlmTaskRunner
* Method: nativePredictSync
Expand Down Expand Up @@ -109,6 +117,23 @@ JNIEXPORT void JNICALL JNI_METHOD(nativePredictAsync)(JNIEnv *, jclass, jlong,
JNIEXPORT jint JNICALL JNI_METHOD(nativeSizeInTokens)(JNIEnv *, jclass, jlong,
jstring);

/*
* Class: com_google_mediapipe_tasks_core_LlmTaskRunner
* Method: nativeCreateSkBitmap
* Signature: (Ljava/nio/ByteBuffer;IIII)J
*/
JNIEXPORT jlong JNICALL JNI_METHOD(nativeCreateSkBitmap)(JNIEnv *, jclass,
jobject, jint, jint,
jint, jint);

/*
* Class: com_google_mediapipe_tasks_core_LlmTaskRunner
* Method: nativeDeleteSkBitmap
* Signature: (J)V
*/
JNIEXPORT void JNICALL JNI_METHOD(nativeDeleteSkBitmap)(JNIEnv *, jclass,
jlong);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ message LlmSessionConfig {
// The absolute path to the LoRA model asset bundle stored locally on the
// device. This is only compatible with GPU models.
optional string lora_path = 4;

// Parameters to customize the graph.
message GraphConfig {
// Whether to configure the graph to include the token cost calculator,
// which allows users to only compute the cost of a prompt.
optional bool include_token_cost_calculator = 1;

// Whether to configure the graph to include the vision modality. Only one
// of enable_vision_modality or enable_audio_modality can be true currently.
optional bool enable_vision_modality = 2;
}

// Parameters to customize the graph.
optional GraphConfig graph_config = 5;
}

// Configurable model parameters for creating an LLM inference engine.
Expand Down Expand Up @@ -72,4 +86,15 @@ message LlmModelSettings {
// means only greedy decoding is supported for any sessions created with this
// engine.
uint32 max_top_k = 8;

// A container for vision model related settings.
message VisionModelSettings {
// Path to the vision encoder model file.
optional string encoder_path = 1;

// Path to the vision adapter model file.
optional string adapter_path = 2;
}

optional VisionModelSettings vision_model_settings = 9;
}

0 comments on commit ad6c03e

Please sign in to comment.