From adc16aeb81ac9907baa020b524b1f5c8eba770d0 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Wed, 11 Dec 2024 10:37:17 -0800
Subject: [PATCH] Add vision modality to the C API

PiperOrigin-RevId: 705157449
---
 .../genai/inference/c/llm_inference_engine.h  | 18 ++++++++++++++++++
 .../inference/c/llm_inference_engine_cpu.cc   |  7 +++++++
 .../sources/LlmInference+Session.swift        |  4 +++-
 .../genai/inference/sources/LlmInference.swift |  2 ++
 .../com/google/mediapipe/tasks/core/jni/llm.cc |  4 ++++
 5 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
index b4d33cb618..91f9a08c70 100644
--- a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
+++ b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
@@ -58,6 +58,12 @@ typedef struct {
   // Path to the model artifact.
   const char* model_path;
 
+  // Path to the vision encoder to use for vision modality. Optional.
+  const char* vision_encoder_path;
+
+  // Path to the vision adapter to use for vision modality. Optional.
+  const char* vision_adapter_path;
+
   // Directory path for storing model related tokenizer and cache weights. the
   // user is responsible for providing the directory that can be writable by the
   // program.
@@ -121,6 +127,13 @@ typedef struct {
   // Path to the LoRA tflite flatbuffer file. Optional.
   // This is only compatible with GPU models.
   const char* lora_path;
+
+  // Whether to configure the graph to include the token cost calculator,
+  // which allows users to only compute the cost of a prompt.
+  bool include_token_cost_calculator;
+
+  // Whether to configure the graph to include the vision modality.
+  bool enable_vision_modality;
 } LlmSessionConfig;
 
 // LlmResponseContext is the return type for
@@ -166,6 +179,11 @@ ODML_EXPORT void LlmInferenceEngine_Session_Delete(
 ODML_EXPORT int LlmInferenceEngine_Session_AddQueryChunk(
     LlmInferenceEngine_Session* session, const char* input, char** error_msg);
 
+// Adds an SkBitmap to the session.
+ODML_EXPORT int LlmInferenceEngine_Session_AddImage(
+    LlmInferenceEngine_Session* session, const void* sk_bitmap,
+    char** error_msg);
+
 // Return the generated output based on the previously added query chunks in
 // sync mode.
 ODML_EXPORT LlmResponseContext
diff --git a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc
index 40524aee95..7c1e5bd233 100644
--- a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc
+++ b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine_cpu.cc
@@ -584,6 +584,13 @@ int LlmInferenceEngine_Session_AddQueryChunk(
   return 0;
 }
 
+ODML_EXPORT int LlmInferenceEngine_Session_AddImage(
+    LlmInferenceEngine_Session* session, const void* sk_bitmap,
+    char** error_msg) {
+  *error_msg = strdup("Not implemented");
+  return 12;
+}
+
 LlmResponseContext LlmInferenceEngine_Session_PredictSync(
     LlmInferenceEngine_Session* session) {
   LlmInferenceEngine_Session_PredictAsync(
diff --git a/mediapipe/tasks/ios/genai/inference/sources/LlmInference+Session.swift b/mediapipe/tasks/ios/genai/inference/sources/LlmInference+Session.swift
index 3fe740de2a..a01b80e20c 100644
--- a/mediapipe/tasks/ios/genai/inference/sources/LlmInference+Session.swift
+++ b/mediapipe/tasks/ios/genai/inference/sources/LlmInference+Session.swift
@@ -54,7 +54,9 @@ extension LlmInference {
         topp: options.topp,
         temperature: options.temperature,
         random_seed: options.randomSeed,
-        lora_path: nil)
+        lora_path: nil,
+        include_token_cost_calculator: true,
+        enable_vision_modality: false)
 
       /// If `loraPath` is != nil, modify session config with the corresponding C string and invoke
       /// the method to create session runner within the scope where the C String of the `loraPath`
diff --git a/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift b/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift
index 312bb23b2f..f5d7cd940c 100644
--- a/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift
+++ b/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift
@@ -65,6 +65,8 @@ import MediaPipeTasksGenAIC
     try options.supportedLoraRanks.withUnsafeMutableBufferPointer { supportedLoraRanks in
       let modelSetting = LlmModelSettings(
         model_path: modelPath,
+        vision_encoder_path: nil,
+        vision_adapter_path: nil,
         cache_dir: cacheDirectory,
         max_num_tokens: options.maxTokens,
         num_decode_steps_per_sync: numberOfDecodeStepsPerSync,
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc
index 18d5ceb2c8..e0f3db7d84 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc
@@ -43,6 +43,8 @@ LlmModelSettings ParseModelSettings(void* bytes, int size) {
   LlmModelSettings output;
 
   output.model_path = strdup(input.model_path().c_str());
+  output.vision_encoder_path = nullptr;
+  output.vision_adapter_path = nullptr;
   output.cache_dir = strdup(input.cache_dir().c_str());
   output.sequence_batch_size = input.sequence_batch_size();
   output.num_decode_steps_per_sync = input.num_decode_steps_per_sync();
@@ -74,6 +76,8 @@ LlmSessionConfig ParseSessionConfig(void* bytes, int size) {
   if (input.has_lora_path()) {
     output.lora_path = strdup(input.lora_path().c_str());
   }
+  output.include_token_cost_calculator = true;
+  output.enable_vision_modality = false;
 
   return output;
 }
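
Note: for context, below is a minimal sketch of how a caller might exercise the new image path through this C API. It is an illustration under assumptions, not part of the patch: engine and session creation are elided, `session` is assumed to come from an LlmSessionConfig with enable_vision_modality set to true on a backend that supports it (the CPU stub above only returns "Not implemented"), and `bitmap` is a hypothetical stand-in for an SkBitmap supplied by the embedding application. Only the entry points declared in this patch are used.

    #include <stdio.h>
    #include <stdlib.h>

    #include "mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h"

    // Hypothetical caller: `session` and `bitmap` are assumed to have been
    // set up elsewhere (session created with enable_vision_modality = true).
    static int DescribeImage(LlmInferenceEngine_Session* session,
                             const void* bitmap) {
      char* error_msg = NULL;

      // Text and image chunks are appended to the same session before decoding.
      if (LlmInferenceEngine_Session_AddQueryChunk(
              session, "Describe the following image:", &error_msg) != 0) {
        fprintf(stderr, "AddQueryChunk failed: %s\n", error_msg);
        free(error_msg);
        return 1;
      }

      // New in this patch: attach the image. On the CPU engine this currently
      // fails with "Not implemented" and a non-zero status.
      if (LlmInferenceEngine_Session_AddImage(session, bitmap, &error_msg) != 0) {
        fprintf(stderr, "AddImage failed: %s\n", error_msg);
        free(error_msg);
        return 1;
      }

      // Decode synchronously over the accumulated chunks.
      LlmResponseContext response =
          LlmInferenceEngine_Session_PredictSync(session);
      // ... consume `response` (field layout omitted here), then clean up.
      (void)response;
      return 0;
    }

The CPU stub's status value of 12 appears to correspond to absl::StatusCode::kUnimplemented; in any case, callers should treat a non-zero return as failure and free the strdup-allocated error_msg, as the sketch does.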