From 96e3a694a1a5e82828bd2ca467e1fa07f9772550 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 12 Dec 2024 09:15:46 -0800 Subject: [PATCH] Add vision modality to the Java LLM API PiperOrigin-RevId: 705528226 --- .../com/google/mediapipe/tasks/core/BUILD | 1 + .../mediapipe/tasks/core/LlmTaskRunner.java | 186 ++++++++++++++++++ .../genai/llminference/GraphOptions.java | 42 ++++ .../genai/llminference/LlmInference.java | 26 ++- .../llminference/LlmInferenceSession.java | 27 +++ .../llminference/VisionModelOptions.java | 36 ++++ 6 files changed, 313 insertions(+), 5 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/GraphOptions.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/VisionModelOptions.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 8b7e18aff6..656f92c60e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -85,6 +85,7 @@ android_library( deps = [ ":core_java", ":logging", + "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_java_proto_lite", "//third_party/java/protobuf:protobuf_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java index c735969be1..0e83f21dfa 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java @@ -15,6 +15,14 @@ package com.google.mediapipe.tasks.core; import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.PixelFormat; 
+import android.media.Image; +import com.google.mediapipe.framework.image.BitmapExtractor; +import com.google.mediapipe.framework.image.ByteBufferExtractor; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.framework.image.MPImageProperties; +import com.google.mediapipe.framework.image.MediaImageExtractor; import com.google.mediapipe.tasks.core.OutputHandler.ProgressListener; import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmModelSettings; import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmSessionConfig; @@ -22,6 +30,7 @@ import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger; import com.google.mediapipe.tasks.core.logging.TasksStatsLogger; import com.google.protobuf.InvalidProtocolBufferException; +import java.nio.ByteBuffer; import java.util.List; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; @@ -38,6 +47,109 @@ public final class LlmTaskRunner implements AutoCloseable { private final TasksStatsLogger statsLogger; private final AtomicBoolean isProcessing; + /** + * Describes how pixel bits encode color. A pixel may be an alpha mask, a grayscale, RGB, or ARGB. + * + *

This matches the SkColorType enum in https://api.skia.org/SkColorType_8h.html. + */ + private enum SkColorType { + /** Uninitialized. */ + UNKNOWN(0), + /** Pixel with alpha in 8-bit byte. */ + ALPHA_8(1), + /** Pixel with 5 bits red, 6 bits green, 5 bits blue, in 16-bit word. */ + RGB_565(2), + /** Pixel with 4 bits for alpha, red, green, blue; in 16-bit word. */ + ARGB_4444(3), + /** Pixel with 8 bits for red, green, blue, alpha; in 32-bit word. */ + RGBA_8888(4), + /** Pixel with 8 bits each for red, green, blue; in 32-bit word. */ + RGB_888X(5), + /** Pixel with 8 bits for blue, green, red, alpha; in 32-bit word. */ + BGRA_8888(6), + /** 10 bits for red, green, blue; 2 bits for alpha; in 32-bit word. */ + RGBA_1010102(7), + /** 10 bits for blue, green, red; 2 bits for alpha; in 32-bit word. */ + BGRA_1010102(8), + /** Pixel with 10 bits each for red, green, blue; in 32-bit word. */ + RGB_101010X(9), + /** Pixel with 10 bits each for blue, green, red; in 32-bit word. */ + BGR101010X(10), + /** Pixel with 10 bits each for blue, green, red; in 32-bit word, extended range. */ + BGR_101010X_XR(11), + /** Pixel with 10 bits each for blue, green, red, alpha; in 64-bit word, extended range. */ + BGRA_10101010_XR(12), + /** + * Pixel with 10 used bits (most significant) followed by 6 unused bits for red, green, blue, + * alpha; in 64-bit word. + */ + RGBA_10X6(13), + /** Pixel with grayscale level in 8-bit byte. */ + GRAY_8(14), + /** Pixel with half floats in [0,1] for red, green, blue, alpha; in 64-bit word. */ + RGBA_F16NORM(15), + /** Pixel with half floats for red, green, blue, alpha; in 64-bit word. */ + RGBA_F16(16), + /** Pixel with half floats for red, green, blue; in 64-bit word. */ + RGB_F16F16F16X(17), + /** Pixel using C float for red, green, blue, alpha; in 128-bit word. */ + RGBA_F32(18), + /** Pixel with a uint8_t for red and green. */ + R8G8_UNORM(19), + /** Pixel with a half float for alpha. 
*/ + A16_FLOAT(20), + /** Pixel with a half float for red and green. */ + R16G16_FLOAT(21), + /** Pixel with a little endian uint16_t for alpha. */ + A16_UNORM(22), + /** Pixel with a little endian uint16_t for red and green. */ + R16G16_UNORM(23), + /** Pixel with a little endian uint16_t for red, green, blue and alpha. */ + R16G16B16A16_UNORM(24), + /** Pixel with 8 bits for red, green, blue, alpha; in 32-bit word, gamma encoded. */ + SRGBA_8888(25), + /** Pixel with a uint8_t for red. */ + R8_UNORM(26); + + private final int value; + + SkColorType(int value) { + this.value = value; + } + + /** Returns the integer value associated with this color type. */ + int getValue() { + return value; + } + } + + /** + * Describes how to interpret the alpha component of a pixel. A pixel may be opaque, or alpha, + * describing multiple levels of transparency. + * + *

This matches the SkAlphaType enum in https://api.skia.org/SkAlphaType_8h.html. + */ + private enum SkAlphaType { + UNINITIALIZED(0), + /** Pixel is opaque */ + OPAQUE(1), + /** Pixel components are premultiplied by alpha */ + PREMULTIPLIED(2), + /** Pixel components are independent of alpha */ + UNPREMULTIPLIED(3); + + private final int value; + + SkAlphaType(int value) { + this.value = value; + } + + /** Returns the integer value associated with this alpha type. */ + int getValue() { + return value; + } + }; + /** The session to use for LLM inference calls. */ public static final class LlmSession { private final long sessionHandle; @@ -79,6 +191,21 @@ public void addQueryChunk(LlmSession session, String input) { nativeAddQueryChunk(session.sessionHandle, input); } + /** Adds a new image to the session context. */ + public void addImage(LlmSession session, MPImage input) { + validateState(); + long imageHandle = createImage(input); + try { + // TODO: Remove this dummy chunk. + // Since AddImage cannot distinguish if start_id is being added, + // use a dummy chunk to make sure the start_id is being added properly. + nativeAddQueryChunk(session.sessionHandle, ""); + nativeAddImage(session.sessionHandle, imageHandle); + } finally { + nativeDeleteSkBitmap(imageHandle); + } + } + /** Invokes the LLM with the given session and waits for the result. 
*/ public List predictSync(LlmSession session) { validateState(); @@ -160,6 +287,58 @@ public void close() { } } + private long createImage(MPImage image) { + MPImageProperties properties = image.getContainedImageProperties().get(0); + + SkAlphaType skAlphaType = SkAlphaType.OPAQUE; + ByteBuffer buffer; + SkColorType skColorType; + + int width = image.getWidth(); + int height = image.getHeight(); + + if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) { + buffer = ByteBufferExtractor.extract(image); + + switch (properties.getImageFormat()) { + case MPImage.IMAGE_FORMAT_RGBA: + skColorType = SkColorType.RGBA_8888; + break; + case MPImage.IMAGE_FORMAT_RGB: + skColorType = SkColorType.RGB_888X; + break; + case MPImage.IMAGE_FORMAT_ALPHA: + skColorType = SkColorType.ALPHA_8; + break; + default: + throw new UnsupportedOperationException( + "Unsupported MediaPipe Image image format: " + properties.getImageFormat()); + } + } else if (properties.getStorageType() == MPImage.STORAGE_TYPE_BITMAP) { + Bitmap bitmap = BitmapExtractor.extract(image); + if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { + throw new UnsupportedOperationException("Bitmap must use ARGB_8888 config."); + } + skColorType = SkColorType.RGBA_8888; + + buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); + bitmap.copyPixelsToBuffer(buffer); + } else if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) { + Image mediaImage = MediaImageExtractor.extract(image); + if (mediaImage.getFormat() != PixelFormat.RGBA_8888) { + throw new UnsupportedOperationException("Android media image must use RGBA_8888 config."); + } + buffer = mediaImage.getPlanes()[0].getBuffer(); + skColorType = SkColorType.RGBA_8888; + } else { + throw new UnsupportedOperationException( + "Unsupported Image container type: " + properties.getStorageType()); + } + + return nativeCreateSkBitmap( + buffer, width, height, skColorType.getValue(), skAlphaType.getValue()); + } + private void validateState() { 
if (isProcessing.get()) { throw new IllegalStateException("Previous invocation still processing. Wait for done=true."); @@ -187,4 +366,11 @@ private void validateState() { private static native void nativePredictAsync(long sessionPointer, long callbackContextHandle); private static native int nativeSizeInTokens(long sessionPointer, String input); + + private static native long nativeCreateSkBitmap( + ByteBuffer buffer, int width, int height, int colorType, int alphaType); + + private static native void nativeAddImage(long sessionPointer, long imagePointer); + + private static native void nativeDeleteSkBitmap(long imagePointer); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/GraphOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/GraphOptions.java new file mode 100644 index 0000000000..c041917cae --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/GraphOptions.java @@ -0,0 +1,42 @@ +package com.google.mediapipe.tasks.genai.llminference; + +import com.google.auto.value.AutoValue; + +/** Configuration for the inference graph. */ +@AutoValue +public abstract class GraphOptions { + + /** + * Returns whether to configure the graph to include the token cost calculator, which allows users + * to only compute the cost of a prompt. + */ + public abstract boolean includeTokenCostCalculator(); + + /** Returns whether to configure the graph to include the vision modality. */ + public abstract boolean enableVisionModality(); + + /** Returns a new {@link Builder} instance. */ + public static Builder builder() { + return new AutoValue_GraphOptions.Builder() + .setIncludeTokenCostCalculator(true) + .setEnableVisionModality(false); + } + + /** Builder for {@link GraphOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets whether to configure the graph to include the token cost calculator. 
*/ + public abstract Builder setIncludeTokenCostCalculator(boolean includeTokenCostCalculator); + + /** Sets whether to configure the graph to include the vision modality. */ + public abstract Builder setEnableVisionModality(boolean enableVisionModality); + + /** AutoValue generated builder method. */ + abstract GraphOptions autoBuild(); + + /** Builds a new {@link GraphOptions} instance. */ + public GraphOptions build() { + return autoBuild(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java index dba1e5b955..194fe6bdba 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java @@ -35,7 +35,7 @@ public final class LlmInference implements AutoCloseable { /** Creates an LlmInference Task. */ public static LlmInference createFromOptions(Context context, LlmInferenceOptions options) { // Configure LLM model settings. 
- LlmModelSettings modelSettings = + LlmModelSettings.Builder modelSettings = LlmModelSettings.newBuilder() .setModelPath(options.modelPath()) .setCacheDir(context.getCacheDir().getAbsolutePath()) @@ -43,10 +43,20 @@ public static LlmInference createFromOptions(Context context, LlmInferenceOption .setMaxTokens(options.maxTokens()) .setMaxTopK(options.maxTopK()) .setNumberOfSupportedLoraRanks(options.supportedLoraRanks().size()) - .addAllSupportedLoraRanks(options.supportedLoraRanks()) - .build(); + .addAllSupportedLoraRanks(options.supportedLoraRanks()); - return new LlmInference(context, STATS_TAG, modelSettings, options.resultListener()); + if (options.visionModelOptions().isPresent()) { + VisionModelOptions visionModelOptions = options.visionModelOptions().get(); + + LlmModelSettings.VisionModelSettings.Builder visionModelSettings = + LlmModelSettings.VisionModelSettings.newBuilder(); + visionModelOptions.getEncoderPath().ifPresent(visionModelSettings::setEncoderPath); + visionModelOptions.getAdapterPath().ifPresent(visionModelSettings::setAdapterPath); + + modelSettings.setVisionModelSettings(visionModelSettings.build()); + } + + return new LlmInference(context, STATS_TAG, modelSettings.build(), options.resultListener()); } /** Constructor to initialize an {@link LlmInference}. */ @@ -196,9 +206,12 @@ public abstract static class Builder { */ public abstract Builder setMaxTopK(int maxTopK); - /** The supported lora ranks for the base model. Used by GPU only. */ + /** Sets the supported lora ranks for the base model. Used by GPU only. */ public abstract Builder setSupportedLoraRanks(List supportedLoraRanks); + /** Sets the model options to use for vision modality. */ + public abstract Builder setVisionModelOptions(VisionModelOptions visionModelOptions); + abstract LlmInferenceOptions autoBuild(); /** Validates and builds the {@link ImageGeneratorOptions} instance. 
*/ @@ -232,6 +245,9 @@ public final LlmInferenceOptions build() { /** The error listener to use for the {@link LlmInference#generateAsync} API. */ public abstract Optional errorListener(); + + /** The model options to use for vision modality. */ + public abstract Optional visionModelOptions(); + /** Returns a new builder with the same values as this instance. */ public abstract Builder toBuilder(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java index 2d070cd207..d862667293 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java @@ -1,6 +1,7 @@ package com.google.mediapipe.tasks.genai.llminference; import com.google.auto.value.AutoValue; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.LlmTaskRunner; import com.google.mediapipe.tasks.core.LlmTaskRunner.LlmSession; import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmSessionConfig; @@ -37,6 +38,16 @@ public static LlmInferenceSession createFromOptions( sessionConfig.setLoraPath(""); } + if (options.graphOptions().isPresent()) { + GraphOptions graphOptions = options.graphOptions().get(); + LlmSessionConfig.GraphConfig graphConfig = + LlmSessionConfig.GraphConfig.newBuilder() + .setIncludeTokenCostCalculator(graphOptions.includeTokenCostCalculator()) + .setEnableVisionModality(graphOptions.enableVisionModality()) + .build(); + sessionConfig.setGraphConfig(graphConfig); + } + LlmTaskRunner taskRunner = llmInference.getTaskRunner(); LlmSession session = taskRunner.createSession(sessionConfig.build()); return new LlmInferenceSession(taskRunner, session); @@ -60,6 +71,16 @@ public void addQueryChunk(String inputText) { taskRunner.addQueryChunk(session, inputText); } + /** + * 
Adds an image to the session. + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public void addImage(MPImage image) { + taskRunner.addImage(session, image); + } + /** * Generates a response based on the previously added query chunks synchronously. Use {@link * #addQueryChunk(String)} to add at least one query chunk before calling this function. @@ -170,6 +191,9 @@ public abstract static class Builder { */ public abstract Builder setLoraPath(String loraPath); + /** Sets the parameters to customize the graph. */ + public abstract Builder setGraphOptions(GraphOptions graphOptions); + abstract LlmInferenceSessionOptions autoBuild(); /** Validates and builds the {@link LlmInferenceSessionOptions} instance. */ @@ -196,6 +220,9 @@ public final LlmInferenceSessionOptions build() { */ public abstract Optional loraPath(); + /** Returns the parameters to customize the graph. */ + public abstract Optional graphOptions(); + /** Instantiates a new LlmInferenceOptions builder. */ public static Builder builder() { return new AutoValue_LlmInferenceSession_LlmInferenceSessionOptions.Builder() diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/VisionModelOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/VisionModelOptions.java new file mode 100644 index 0000000000..1effe75ab8 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/VisionModelOptions.java @@ -0,0 +1,36 @@ +package com.google.mediapipe.tasks.genai.llminference; + +import com.google.auto.value.AutoValue; +import java.util.Optional; + +/** Options for configuring the vision modality. */ +@AutoValue +public abstract class VisionModelOptions { + /** Returns the path to the vision encoder model file. */ + public abstract Optional getEncoderPath(); + + /** Returns the path to the vision adapter model file. 
*/ + public abstract Optional getAdapterPath(); + + /** Builder for {@link VisionModelOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the path to the vision encoder model file. */ + public abstract Builder setEncoderPath(String encoderPath); + + /** Sets the path to the vision adapter model file. */ + public abstract Builder setAdapterPath(String adapterPath); + + abstract VisionModelOptions autoBuild(); + + /** Validates and builds the {@link VisionModelOptions} instance. */ + public final VisionModelOptions build() { + return autoBuild(); + } + } + + /** Instantiates a new VisionModelOptions builder. */ + public static Builder builder() { + return new AutoValue_VisionModelOptions.Builder(); + } +}