
Migrate TextGenerator Java API to C Wrapper
PiperOrigin-RevId: 597956904
schmidt-sebastian authored and copybara-github committed Jan 12, 2024
1 parent c460c2e commit f457fe7
Showing 6 changed files with 155 additions and 10 deletions.
5 changes: 4 additions & 1 deletion mediapipe/framework/deps/file_helpers.cc
@@ -14,6 +14,8 @@

#include "mediapipe/framework/deps/file_helpers.h"

#include "absl/strings/str_cat.h"

#ifdef _WIN32
#include <Windows.h>
#include <direct.h>
@@ -249,7 +251,8 @@ absl::Status Exists(absl::string_view file_name) {
case EACCES:
return mediapipe::PermissionDeniedError("Insufficient permissions.");
default:
return absl::NotFoundError("The path does not exist.");
return absl::NotFoundError(
absl::StrCat("The path does not exist: ", file_name));
}
}

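The practical effect is that a failed existence check now names the offending path. A small sketch of the new behavior (not part of this commit; it assumes the mediapipe::file namespace that file_helpers.h declares):

    #include "mediapipe/framework/deps/file_helpers.h"

    absl::Status status = mediapipe::file::Exists("/no/such/file");
    // Before this change: "The path does not exist."
    // After:              "The path does not exist: /no/such/file"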
1 change: 1 addition & 0 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
@@ -76,6 +76,7 @@ android_library(
],
manifest = "AndroidManifest.xml",
deps = [
":core_java",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_java_proto_lite",
"//third_party/java/protobuf:protobuf_lite",
8 changes: 4 additions & 4 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java
@@ -14,14 +14,14 @@

package com.google.mediapipe.tasks.core;

import com.google.mediapipe.tasks.core.OutputHandler.ValueListener;
import com.google.mediapipe.tasks.core.jni.LlmOptionsProto.LlmModelParameters;
import com.google.mediapipe.tasks.core.jni.LlmOptionsProto.LlmSessionConfig;
import com.google.mediapipe.tasks.core.jni.LlmResponseContextProto.LlmResponseContext;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

/**
* Internal Task Runner class for all LLM Tasks.
@@ -30,13 +30,13 @@
*/
public final class LlmTaskRunner implements AutoCloseable {
private final long sessionHandle;
private final Optional<Function<List<String>, Void>> resultListener;
private final Optional<ValueListener<List<String>>> resultListener;
private final long callbackHandle;

public LlmTaskRunner(
LlmModelParameters modelParameters,
LlmSessionConfig sessionConfig,
Optional<Function<List<String>, Void>> resultListener) {
Optional<ValueListener<List<String>>> resultListener) {
this.sessionHandle =
nativeCreateSession(modelParameters.toByteArray(), sessionConfig.toByteArray());

@@ -73,7 +73,7 @@ private List<String> parseResponse(byte[] response) {
}

private void onAsyncResponse(byte[] responseBytes) {
resultListener.get().apply(parseResponse(responseBytes));
resultListener.get().run(parseResponse(responseBytes));
}

@Override
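For context, a minimal sketch (not part of this commit) of how a caller would supply the new ValueListener type. LlmTaskRunnerDemo is a hypothetical class name, and the default proto instances are placeholders for real model settings:

    import com.google.mediapipe.tasks.core.LlmTaskRunner;
    import com.google.mediapipe.tasks.core.OutputHandler.ValueListener;
    import com.google.mediapipe.tasks.core.jni.LlmOptionsProto.LlmModelParameters;
    import com.google.mediapipe.tasks.core.jni.LlmOptionsProto.LlmSessionConfig;
    import java.util.List;
    import java.util.Optional;

    final class LlmTaskRunnerDemo {
      public static void main(String[] args) throws Exception {
        // ValueListener has a single run() method, so a lambda suffices.
        ValueListener<List<String>> printResponses =
            responses -> responses.forEach(System.out::println);

        // Default proto instances stand in for real model settings here.
        try (LlmTaskRunner runner =
            new LlmTaskRunner(
                LlmModelParameters.getDefaultInstance(),
                LlmSessionConfig.getDefaultInstance(),
                Optional.of(printResponses))) {
          // Async responses are delivered to printResponses via run().
        }
      }
    }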
16 changes: 12 additions & 4 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java
@@ -32,21 +32,29 @@ public interface OutputPacketConverter<OutputT extends TaskResult, InputT> {
}

/**
* Interface for the customizable MediaPipe task result listener that can reteive both task result
* objects and the corresponding input data.
* Interface for the customizable MediaPipe task result listener that can retrieve both task
* result objects and the corresponding input data.
*/
public interface ResultListener<OutputT extends TaskResult, InputT> {
void run(OutputT result, InputT input);
}

/**
* Interface for the customizable MediaPipe task result listener that can only reteive task result
* objects.
* Interface for the customizable MediaPipe task result listener that can only retrieve task
* result objects.
*/
public interface PureResultListener<OutputT extends TaskResult> {
void run(OutputT result);
}

/**
* Interface for the customizable MediaPipe task result listener that only receives a task's
* output value.
*/
public interface ValueListener<OutputT> {
void run(OutputT result);
}

private static final String TAG = "OutputHandler";
// A task-specific graph output packet converter that should be implemented per task.
private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
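The three listener interfaces now differ only in what they receive. A short sketch of the contrast (not from this commit; MyTaskResult is a hypothetical TaskResult implementation and String a stand-in input type):

    // ResultListener receives both the task result and the input it came from.
    OutputHandler.ResultListener<MyTaskResult, String> both =
        (result, input) -> System.out.println(input + " -> " + result);

    // PureResultListener receives only the task result.
    OutputHandler.PureResultListener<MyTaskResult> resultOnly =
        result -> System.out.println(result);

    // ValueListener is not tied to TaskResult; any value type works, e.g. the
    // List<String> responses that LlmTaskRunner produces.
    OutputHandler.ValueListener<java.util.List<String>> valueOnly =
        responses -> responses.forEach(System.out::println);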
7 changes: 6 additions & 1 deletion mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD
@@ -40,16 +40,21 @@ cc_library_with_tflite(
alwayslink = 1,
)

cc_library(
name = "llm_inference_engine_hdr",
hdrs = ["llm_inference_engine.h"],
)

cc_library(
name = "llm",
srcs = ["llm.cc"],
hdrs = ["llm.h"],
deps = [
":llm_inference_engine_hdr",
"//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_cc_proto",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_cc_proto",
"//third_party/odml/infra/genai/inference/c:libllm_inference_engine",
"//third_party/odml/infra/genai/inference/c:libllm_inference_engine_deps",
"@com_google_absl//absl/status",
] + select({
"//mediapipe:android": [],
128 changes: 128 additions & 0 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm_inference_engine.h
@@ -0,0 +1,128 @@
#ifndef MEDIAPIPE_TASKS_JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_LLM_INFERENCE_ENGINE_H_
#define MEDIAPIPE_TASKS_JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_LLM_INFERENCE_ENGINE_H_

#include <cstddef>
#include <cstdint>

#ifndef ODML_EXPORT
#define ODML_EXPORT __attribute__((visibility("default")))
#endif // ODML_EXPORT

#ifdef __cplusplus
extern "C" {
#endif

typedef void LlmInferenceEngine_Session;

// Supported model types.
enum LlmModelType {
// Unknown
kUNKNOWN_MODEL_TYPE,

// Falcon with 1B parameters.
kFalcon1B,

// GMini with 2B parameters.
kGMini2B,
};

// Attention types.
enum LlmAttentionType {
// Multi-head Attention.
kMHA,

// Multi-query Attention.
kMQA,
};

// Backend to execute the large language model.
enum LlmBackend {
// CPU
kCPU,

// GPU
kGPU,
};

// LlmModelParameters should accurately describe the model used.
typedef struct {
// One of the supported model types above.
enum LlmModelType model_type;

// Path to the directory that contains spm.model and the weight directory.
const char* model_path;

// MHA or MQA.
enum LlmAttentionType attention_type;

// Start token id to append to the query before feeding it into the model.
int start_token_id;

// Stop tokens/words that indicate the response is complete.
const char** stop_tokens;

// Number of stop tokens.
size_t stop_tokens_size;
} LlmModelParameters;

// LlmSessionConfig configures how to execute the model.
typedef struct {
// Select a supported backend.
enum LlmBackend backend;

// Sequence batch size for encoding.
size_t sequence_batch_size;

// Output batch size for decoding (GPU only).
size_t num_decode_tokens;

// Maximum sequence length, i.e. the total number of input and output tokens.
size_t max_sequence_length;

// Use fake weights instead of loading from file.
bool use_fake_weights;
} LlmSessionConfig;

// LlmResponseContext is the return type for
// LlmInferenceEngine_Session_PredictSync.
typedef struct {
// An array of strings; its size depends on the number of responses.
char** response_array;

// Number of responses.
int response_count;
} LlmResponseContext;

// Frees all memory held within the LlmResponseContext, including the
// response array itself.
ODML_EXPORT void LlmInferenceEngine_CloseResponseContext(
LlmResponseContext response_context);

// Creates an LlmInferenceEngine session for executing a query.
ODML_EXPORT LlmInferenceEngine_Session* LlmInferenceEngine_CreateSession(
const LlmModelParameters* model_parameters,
const LlmSessionConfig* session_config);

// Frees the session; blocks until the graph has finished executing.
ODML_EXPORT void LlmInferenceEngine_Session_Delete(
LlmInferenceEngine_Session* session);

// Returns the generated output in sync mode.
ODML_EXPORT LlmResponseContext LlmInferenceEngine_Session_PredictSync(
LlmInferenceEngine_Session* session, const char* input);

// Runs the callback function in async mode.
// The callback context can be a pointer to any user-defined data structure,
// as it is passed to the callback unmodified.
ODML_EXPORT void LlmInferenceEngine_Session_PredictAsync(
LlmInferenceEngine_Session* session, void* callback_context,
const char* input,
void (*callback)(void* callback_context,
const LlmResponseContext response_context));

#ifdef __cplusplus
}  // extern "C"
#endif

#endif // MEDIAPIPE_TASKS_JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_LLM_INFERENCE_ENGINE_H_
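
A minimal end-to-end sketch of driving this API (not part of this commit; it would be compiled as C++ given the header's <cstddef>/<cstdint> includes, and the model path, token values, and batch sizes are made-up placeholders):

    #include <cstdio>

    #include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm_inference_engine.h"

    // Async callback; the callback_context pointer arrives unmodified.
    static void OnResponse(void* callback_context,
                           const LlmResponseContext response_context) {
      for (int i = 0; i < response_context.response_count; ++i) {
        std::printf("async response: %s\n", response_context.response_array[i]);
      }
    }

    int main() {
      const char* stop_tokens[] = {"</s>"};  // placeholder stop token

      LlmModelParameters model_params = {};
      model_params.model_type = kFalcon1B;
      model_params.model_path = "/path/to/model_dir";  // placeholder path
      model_params.attention_type = kMHA;
      model_params.start_token_id = 1;  // placeholder token id
      model_params.stop_tokens = stop_tokens;
      model_params.stop_tokens_size = 1;

      LlmSessionConfig session_config = {};
      session_config.backend = kCPU;
      session_config.sequence_batch_size = 16;  // placeholder sizes
      session_config.num_decode_tokens = 1;
      session_config.max_sequence_length = 512;
      session_config.use_fake_weights = false;

      LlmInferenceEngine_Session* session =
          LlmInferenceEngine_CreateSession(&model_params, &session_config);

      // Sync mode: the caller owns the returned context and must close it.
      LlmResponseContext response =
          LlmInferenceEngine_Session_PredictSync(session, "Tell me a story.");
      for (int i = 0; i < response.response_count; ++i) {
        std::printf("sync response: %s\n", response.response_array[i]);
      }
      LlmInferenceEngine_CloseResponseContext(response);

      // Async mode: responses are delivered through OnResponse.
      LlmInferenceEngine_Session_PredictAsync(
          session, /*callback_context=*/nullptr, "Tell me another.",
          &OnResponse);

      // Deleting the session waits until the graph is done executing.
      LlmInferenceEngine_Session_Delete(session);
      return 0;
    }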
