From f457fe7a8fbf14093243885c7eca808f3b3b4e81 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 12 Jan 2024 14:33:06 -0800
Subject: [PATCH] Migrate TextGenerator Java API to C Wrapper

PiperOrigin-RevId: 597956904
---
 mediapipe/framework/deps/file_helpers.cc      |   5 +-
 .../com/google/mediapipe/tasks/core/BUILD     |   1 +
 .../mediapipe/tasks/core/LlmTaskRunner.java   |   8 +-
 .../mediapipe/tasks/core/OutputHandler.java   |  16 ++-
 .../com/google/mediapipe/tasks/core/jni/BUILD |   7 +-
 .../tasks/core/jni/llm_inference_engine.h     | 128 ++++++++++++++++++
 6 files changed, 155 insertions(+), 10 deletions(-)
 create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm_inference_engine.h

diff --git a/mediapipe/framework/deps/file_helpers.cc b/mediapipe/framework/deps/file_helpers.cc
index b356709814..24e522d348 100644
--- a/mediapipe/framework/deps/file_helpers.cc
+++ b/mediapipe/framework/deps/file_helpers.cc
@@ -14,6 +14,8 @@
 
 #include "mediapipe/framework/deps/file_helpers.h"
 
+#include "absl/strings/str_cat.h"
+
 #ifdef _WIN32
 #include
 #include
@@ -249,7 +251,8 @@ absl::Status Exists(absl::string_view file_name) {
     case EACCES:
       return mediapipe::PermissionDeniedError("Insufficient permissions.");
     default:
-      return absl::NotFoundError("The path does not exist.");
+      return absl::NotFoundError(
+          absl::StrCat("The path does not exist: ", file_name));
   }
 }
 
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
index f288e774cd..29ca37251f 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
@@ -76,6 +76,7 @@ android_library(
     ],
     manifest = "AndroidManifest.xml",
     deps = [
+        ":core_java",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_java_proto_lite",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_java_proto_lite",
         "//third_party/java/protobuf:protobuf_lite",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java
index efc9d7e29f..20a72bece7 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java
@@ -14,6 +14,7 @@
 
 package com.google.mediapipe.tasks.core;
 
+import com.google.mediapipe.tasks.core.OutputHandler.ValueListener;
 import com.google.mediapipe.tasks.core.jni.LlmOptionsProto.LlmModelParameters;
 import com.google.mediapipe.tasks.core.jni.LlmOptionsProto.LlmSessionConfig;
 import com.google.mediapipe.tasks.core.jni.LlmResponseContextProto.LlmResponseContext;
@@ -21,7 +22,6 @@
 import com.google.protobuf.InvalidProtocolBufferException;
 import java.util.List;
 import java.util.Optional;
-import java.util.function.Function;
 
 /**
  * Internal Task Runner class for all LLM Tasks.
@@ -30,13 +30,13 @@
  */
 public final class LlmTaskRunner implements AutoCloseable {
   private final long sessionHandle;
-  private final Optional<Function<List<String>, Void>> resultListener;
+  private final Optional<ValueListener<List<String>>> resultListener;
   private final long callbackHandle;
 
   public LlmTaskRunner(
       LlmModelParameters modelParameters,
       LlmSessionConfig sessionConfig,
-      Optional<Function<List<String>, Void>> resultListener) {
+      Optional<ValueListener<List<String>>> resultListener) {
     this.sessionHandle =
         nativeCreateSession(modelParameters.toByteArray(), sessionConfig.toByteArray());
 
@@ -73,7 +73,7 @@ private List<String> parseResponse(byte[] reponse) {
   }
 
   private void onAsyncResponse(byte[] responseBytes) {
-    resultListener.get().apply(parseResponse(responseBytes));
+    resultListener.get().run(parseResponse(responseBytes));
   }
 
   @Override
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java
index dba9184834..1c204da824 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java
@@ -32,21 +32,29 @@ public interface OutputPacketConverter<OutputT extends TaskResult, InputT> {
   }
 
   /**
-   * Interface for the customizable MediaPipe task result listener that can reteive both task result
-   * objects and the corresponding input data.
+   * Interface for the customizable MediaPipe task result listener that can retrieve both task
+   * result objects and the corresponding input data.
    */
   public interface ResultListener<OutputT extends TaskResult, InputT> {
     void run(OutputT result, InputT input);
   }
 
   /**
-   * Interface for the customizable MediaPipe task result listener that can only reteive task result
-   * objects.
+   * Interface for the customizable MediaPipe task result listener that can only retrieve task
+   * result objects.
    */
   public interface PureResultListener<OutputT extends TaskResult> {
     void run(OutputT result);
   }
 
+  /**
+   * Interface for the customizable MediaPipe task result listener that only receives a task's
+   * output value.
+   */
+  public interface ValueListener<OutputT> {
+    void run(OutputT result);
+  }
+
   private static final String TAG = "OutputHandler";
   // A task-specific graph output packet converter that should be implemented per task.
   private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD
index aebf013467..3c755ce02c 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD
@@ -40,16 +40,21 @@ cc_library_with_tflite(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "llm_inference_engine_hdr",
+    hdrs = ["llm_inference_engine.h"],
+)
+
 cc_library(
     name = "llm",
     srcs = ["llm.cc"],
     hdrs = ["llm.h"],
     deps = [
+        ":llm_inference_engine_hdr",
         "//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_cc_proto",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_cc_proto",
         "//third_party/odml/infra/genai/inference/c:libllm_inference_engine",
-        "//third_party/odml/infra/genai/inference/c:libllm_inference_engine_deps",
        "@com_google_absl//absl/status",
     ] + select({
         "//mediapipe:android": [],
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm_inference_engine.h b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm_inference_engine.h
new file mode 100644
index 0000000000..15f8dbbe55
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm_inference_engine.h
@@ -0,0 +1,128 @@
+#ifndef MEDIAPIPE_TASKS_JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_LLM_INFERENCE_ENGINE_H_
+#define MEDIAPIPE_TASKS_JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_LLM_INFERENCE_ENGINE_H_
+
+#include <stdbool.h>
+#include <stddef.h>
+
+#ifndef ODML_EXPORT
+#define ODML_EXPORT __attribute__((visibility("default")))
+#endif  // ODML_EXPORT
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef void LlmInferenceEngine_Session;
+
+// Supported model types.
+enum LlmModelType {
+  // Unknown model type.
+  kUNKNOWN_MODEL_TYPE,
+
+  // Falcon with 1B parameters.
+  kFalcon1B,
+
+  // GMini with 2B parameters.
+  kGMini2B,
+};
+
+// Attention types.
+enum LlmAttentionType {
+  // Multi-head Attention.
+  kMHA,
+
+  // Multi-query Attention.
+  kMQA,
+};
+
+// Backend to execute the large language model.
+enum LlmBackend {
+  // CPU
+  kCPU,
+
+  // GPU
+  kGPU,
+};
+
+// LlmModelParameters should accurately describe the model used.
+typedef struct {
+  // The supported model type.
+  enum LlmModelType model_type;
+
+  // Path to the directory that contains spm.model and the weight directory.
+  const char* model_path;
+
+  // MHA or MQA.
+  enum LlmAttentionType attention_type;
+
+  // Start token id that is added to the query before it is fed into the
+  // model.
+  int start_token_id;
+
+  // Stop tokens/words that indicate the response is complete.
+  const char** stop_tokens;
+
+  // Number of stop tokens.
+  size_t stop_tokens_size;
+} LlmModelParameters;
+
+// LlmSessionConfig configures how to execute the model.
+typedef struct {
+  // Select a supported backend.
+  enum LlmBackend backend;
+
+  // Sequence batch size for encoding.
+  size_t sequence_batch_size;
+
+  // Output batch size for decoding (GPU only).
+  size_t num_decode_tokens;
+
+  // Maximum sequence length, i.e. the total number of input and output
+  // tokens.
+  size_t max_sequence_length;
+
+  // Use fake weights instead of loading from file.
+  bool use_fake_weights;
+} LlmSessionConfig;
+
+// LlmResponseContext is the return type for
+// LlmInferenceEngine_Session_PredictSync.
+typedef struct {
+  // An array of strings; the size of the array depends on the number of
+  // responses.
+  char** response_array;
+
+  // Number of responses.
+  int response_count;
+} LlmResponseContext;
+
+// Frees all content within the LlmResponseContext, including itself.
+ODML_EXPORT void LlmInferenceEngine_CloseResponseContext(
+    LlmResponseContext response_context);
+
+// Creates an LlmInferenceEngine session for executing a query.
+ODML_EXPORT LlmInferenceEngine_Session* LlmInferenceEngine_CreateSession(
+    const LlmModelParameters* model_parameters,
+    const LlmSessionConfig* session_config);
+
+// Frees the session; blocks until the graph is done executing.
+ODML_EXPORT void LlmInferenceEngine_Session_Delete(
+    LlmInferenceEngine_Session* session);
+
+// Returns the generated output in sync mode.
+ODML_EXPORT LlmResponseContext LlmInferenceEngine_Session_PredictSync(
+    LlmInferenceEngine_Session* session, const char* input);
+
+// Runs the callback function in async mode.
+// The callback context can be a pointer to any user-defined data structure,
+// as it is passed to the callback unmodified.
+ODML_EXPORT void LlmInferenceEngine_Session_PredictAsync(
+    LlmInferenceEngine_Session* session, void* callback_context,
+    const char* input,
+    void (*callback)(void* callback_context,
+                     const LlmResponseContext response_context));
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif  // MEDIAPIPE_TASKS_JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_LLM_INFERENCE_ENGINE_H_
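
Usage note (not part of the patch): the sketch below shows how a C caller
might drive the new llm_inference_engine.h API end to end in sync mode. It
assumes only the header added above; the model path, start token id, stop
token, and batch sizes are illustrative placeholders, not values taken from
the commit.

#include <stdio.h>

#include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm_inference_engine.h"

int main(void) {
  // Placeholder stop token; real values depend on the model's vocabulary.
  const char* stop_tokens[] = {"<eos>"};

  LlmModelParameters model_parameters = {
      .model_type = kFalcon1B,
      .model_path = "/data/local/tmp/llm",  // dir with spm.model + weights
      .attention_type = kMHA,
      .start_token_id = 0,
      .stop_tokens = stop_tokens,
      .stop_tokens_size = 1,
  };

  LlmSessionConfig session_config = {
      .backend = kCPU,
      .sequence_batch_size = 16,
      .num_decode_tokens = 1,
      .max_sequence_length = 512,
      .use_fake_weights = false,
  };

  LlmInferenceEngine_Session* session =
      LlmInferenceEngine_CreateSession(&model_parameters, &session_config);

  // PredictSync returns the context by value; the caller releases it with
  // LlmInferenceEngine_CloseResponseContext when done.
  LlmResponseContext response =
      LlmInferenceEngine_Session_PredictSync(session, "Tell me a story");
  for (int i = 0; i < response.response_count; ++i) {
    printf("response[%d]: %s\n", i, response.response_array[i]);
  }
  LlmInferenceEngine_CloseResponseContext(response);

  // Per the header comment, deletion waits until the graph is done executing.
  LlmInferenceEngine_Session_Delete(session);
  return 0;
}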
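
A second hypothetical sketch covers the async path, which is presumably how
the Java layer's onAsyncResponse is wired up through JNI: PredictAsync
forwards the caller's callback_context pointer to the callback unmodified, so
it can carry any user-defined state. The header does not specify who frees
the response context handed to the callback, so this sketch only reads from
it.

#include <stdio.h>

#include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm_inference_engine.h"

// Invoked by the engine with each response; callback_context is the exact
// pointer the caller passed to PredictAsync.
static void OnResponse(void* callback_context,
                       const LlmResponseContext response_context) {
  int* num_callbacks = (int*)callback_context;
  ++*num_callbacks;
  for (int i = 0; i < response_context.response_count; ++i) {
    printf("async response: %s\n", response_context.response_array[i]);
  }
}

void RunAsyncQuery(LlmInferenceEngine_Session* session) {
  int num_callbacks = 0;
  LlmInferenceEngine_Session_PredictAsync(session, &num_callbacks,
                                          "Tell me a story", OnResponse);
  // Session deletion blocks until the graph finishes, so num_callbacks stays
  // alive for the full duration of the query.
  LlmInferenceEngine_Session_Delete(session);
  printf("callbacks received: %d\n", num_callbacks);
}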