diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
index 8b7e18aff6..656f92c60e 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
@@ -85,6 +85,7 @@ android_library(
deps = [
":core_java",
":logging",
+ "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_java_proto_lite",
"//third_party/java/protobuf:protobuf_lite",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java
index c735969be1..0e83f21dfa 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java
@@ -15,6 +15,14 @@
package com.google.mediapipe.tasks.core;
import android.content.Context;
+import android.graphics.Bitmap;
+import android.graphics.PixelFormat;
+import android.media.Image;
+import com.google.mediapipe.framework.image.BitmapExtractor;
+import com.google.mediapipe.framework.image.ByteBufferExtractor;
+import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.framework.image.MPImageProperties;
+import com.google.mediapipe.framework.image.MediaImageExtractor;
import com.google.mediapipe.tasks.core.OutputHandler.ProgressListener;
import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmModelSettings;
import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmSessionConfig;
@@ -22,6 +30,7 @@
import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger;
import com.google.mediapipe.tasks.core.logging.TasksStatsLogger;
import com.google.protobuf.InvalidProtocolBufferException;
+import java.nio.ByteBuffer;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -38,6 +47,109 @@ public final class LlmTaskRunner implements AutoCloseable {
private final TasksStatsLogger statsLogger;
private final AtomicBoolean isProcessing;
+ /**
+ * Describes how pixel bits encode color. A pixel may be an alpha mask, a grayscale, RGB, or ARGB.
+ *
+ * <p>This matches the SkColorType enum in https://api.skia.org/SkColorType_8h.html.
+ */
+ private enum SkColorType {
+ /** Uninitialized. */
+ UNKNOWN(0),
+ /** Pixel with alpha in 8-bit byte. */
+ ALPHA_8(1),
+ /** Pixel with 5 bits red, 6 bits green, 5 bits blue, in 16-bit word. */
+ RGB_565(2),
+ /** Pixel with 4 bits for alpha, red, green, blue; in 16-bit word. */
+ ARGB_4444(3),
+ /** Pixel with 8 bits for red, green, blue, alpha; in 32-bit word. */
+ RGBA_8888(4),
+ /** Pixel with 8 bits each for red, green, blue; in 32-bit word. */
+ RGB_888X(5),
+ /** Pixel with 8 bits for blue, green, red, alpha; in 32-bit word. */
+ BGRA_8888(6),
+ /** 10 bits for red, green, blue; 2 bits for alpha; in 32-bit word. */
+ RGBA_1010102(7),
+ /** 10 bits for blue, green, red; 2 bits for alpha; in 32-bit word. */
+ BGRA_1010102(8),
+ /** Pixel with 10 bits each for red, green, blue; in 32-bit word. */
+ RGB_101010X(9),
+ /** Pixel with 10 bits each for blue, green, red; in 32-bit word. */
+ BGR_101010X(10),
+ /** Pixel with 10 bits each for blue, green, red; in 32-bit word, extended range. */
+ BGR_101010X_XR(11),
+ /** Pixel with 10 bits each for blue, green, red, alpha; in 64-bit word, extended range. */
+ BGRA_10101010_XR(12),
+ /**
+ * Pixel with 10 used bits (most significant) followed by 6 unused bits for red, green, blue,
+ * alpha; in 64-bit word.
+ */
+ RGBA_10X6(13),
+ /** Pixel with grayscale level in 8-bit byte. */
+ GRAY_8(14),
+ /** Pixel with half floats in [0,1] for red, green, blue, alpha; in 64-bit word. */
+ RGBA_F16NORM(15),
+ /** Pixel with half floats for red, green, blue, alpha; in 64-bit word. */
+ RGBA_F16(16),
+ /** Pixel with half floats for red, green, blue; in 64-bit word. */
+ RGB_F16F16F16X(17),
+ /** Pixel using C float for red, green, blue, alpha; in 128-bit word. */
+ RGBA_F32(18),
+ /** Pixel with a uint8_t for red and green. */
+ R8G8_UNORM(19),
+ /** Pixel with a half float for alpha. */
+ A16_FLOAT(20),
+ /** Pixel with a half float for red and green. */
+ R16G16_FLOAT(21),
+ /** Pixel with a little endian uint16_t for alpha. */
+ A16_UNORM(22),
+ /** Pixel with a little endian uint16_t for red and green. */
+ R16G16_UNORM(23),
+ /** Pixel with a little endian uint16_t for red, green, blue and alpha. */
+ R16G16B16A16_UNORM(24),
+ /** Pixel with 8 bits for red, green, blue, alpha; in 32-bit word, gamma encoded. */
+ SRGBA_8888(25),
+ /** Pixel with a uint8_t for red. */
+ R8_UNORM(26);
+
+ private final int value;
+
+ SkColorType(int value) {
+ this.value = value;
+ }
+
+ /** Returns the integer value associated with this color type. */
+ int getValue() {
+ return value;
+ }
+ }
+
+ /**
+ * Describes how to interpret the alpha component of a pixel. A pixel may be opaque, or it may
+ * carry an alpha channel that describes multiple levels of transparency.
+ *
+ * <p>This matches the SkAlphaType enum in https://api.skia.org/SkAlphaType_8h.html.
+ */
+ private enum SkAlphaType {
+ /** Uninitialized. */
+ UNINITIALIZED(0),
+ /** Pixel is opaque. */
+ OPAQUE(1),
+ /** Pixel components are premultiplied by alpha. */
+ PREMULTIPLIED(2),
+ /** Pixel components are independent of alpha. */
+ UNPREMULTIPLIED(3);
+
+ private final int value;
+
+ SkAlphaType(int value) {
+ this.value = value;
+ }
+
+ /** Returns the integer value associated with this alpha type. */
+ int getValue() {
+ return value;
+ }
+ }
+
/** The session to use for LLM inference calls. */
public static final class LlmSession {
private final long sessionHandle;
@@ -79,6 +191,21 @@ public void addQueryChunk(LlmSession session, String input) {
nativeAddQueryChunk(session.sessionHandle, input);
}
+ /** Adds a new image to the session context. */
+ public void addImage(LlmSession session, MPImage input) {
+ validateState();
+ long imageHandle = createImage(input);
+ try {
+ // TODO: Remove this dummy chunk.
+ // Since AddImage cannot tell whether the start_id has already been added,
+ // add an empty dummy chunk first to make sure the start_id is added properly.
+ nativeAddQueryChunk(session.sessionHandle, "");
+ nativeAddImage(session.sessionHandle, imageHandle);
+ } finally {
+ nativeDeleteSkBitmap(imageHandle);
+ }
+ }
+
/** Invokes the LLM with the given session and waits for the result. */
public List<String> predictSync(LlmSession session) {
validateState();
@@ -160,6 +287,58 @@ public void close() {
}
}
+ private long createImage(MPImage image) {
+ MPImageProperties properties = image.getContainedImageProperties().get(0);
+
+ SkAlphaType skAlphaType = SkAlphaType.OPAQUE;
+ ByteBuffer buffer;
+ SkColorType skColorType;
+
+ int width = image.getWidth();
+ int height = image.getHeight();
+
+ if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
+ buffer = ByteBufferExtractor.extract(image);
+
+ switch (properties.getImageFormat()) {
+ case MPImage.IMAGE_FORMAT_RGBA:
+ skColorType = SkColorType.RGBA_8888;
+ break;
+ case MPImage.IMAGE_FORMAT_RGB:
+ skColorType = SkColorType.RGB_888X;
+ break;
+ case MPImage.IMAGE_FORMAT_ALPHA:
+ skColorType = SkColorType.ALPHA_8;
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ "Unsupported MediaPipe Image image format: " + properties.getImageFormat());
+ }
+ } else if (properties.getStorageType() == MPImage.STORAGE_TYPE_BITMAP) {
+ Bitmap bitmap = BitmapExtractor.extract(image);
+ if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
+ throw new UnsupportedOperationException("Bitmap must use ARGB_8888 config.");
+ }
+ skColorType = SkColorType.RGBA_8888;
+
+ buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
+ bitmap.copyPixelsToBuffer(buffer);
+ } else if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) {
+ Image mediaImage = MediaImageExtractor.extract(image);
+ if (mediaImage.getFormat() != PixelFormat.RGBA_8888) {
+ throw new UnsupportedOperationException("Android media image must use RGBA_8888 config.");
+ }
+ buffer = mediaImage.getPlanes()[0].getBuffer();
+ skColorType = SkColorType.RGBA_8888;
+ } else {
+ throw new UnsupportedOperationException(
+ "Unsupported Image container type: " + properties.getStorageType());
+ }
+
+ return nativeCreateSkBitmap(
+ buffer, width, height, skColorType.getValue(), skAlphaType.getValue());
+ }
+
private void validateState() {
if (isProcessing.get()) {
throw new IllegalStateException("Previous invocation still processing. Wait for done=true.");
@@ -187,4 +366,11 @@ private void validateState() {
private static native void nativePredictAsync(long sessionPointer, long callbackContextHandle);
private static native int nativeSizeInTokens(long sessionPointer, String input);
+
+ private static native long nativeCreateSkBitmap(
+ ByteBuffer buffer, int width, int height, int colorType, int alphaType);
+
+ private static native void nativeAddImage(long sessionPointer, long imagePointer);
+
+ private static native void nativeDeleteSkBitmap(long imagePointer);
}
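
Usage sketch (annotation, not part of the patch): one plausible way an image reaches the new
addImage() path, assuming a taskRunner and session obtained from createSession() as above, and
using the existing BitmapImageBuilder from com.google.mediapipe.framework.image. The helper name
is hypothetical; the conversion mirrors what createImage() accepts.

    /** Sketch: wraps a decoded bitmap in an MPImage and adds it to a session. */
    void addBitmapToSession(
        LlmTaskRunner taskRunner, LlmTaskRunner.LlmSession session, Bitmap bitmap) {
      if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
        // createImage() above only accepts ARGB_8888 bitmaps.
        bitmap = bitmap.copy(Bitmap.Config.ARGB_8888, /* isMutable= */ false);
      }
      MPImage image = new BitmapImageBuilder(bitmap).build();
      taskRunner.addImage(session, image);
    }
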
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/GraphOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/GraphOptions.java
new file mode 100644
index 0000000000..c041917cae
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/GraphOptions.java
@@ -0,0 +1,42 @@
+package com.google.mediapipe.tasks.genai.llminference;
+
+import com.google.auto.value.AutoValue;
+
+/** Configuration for the inference graph. */
+@AutoValue
+public abstract class GraphOptions {
+
+ /**
+ * Returns whether to configure the graph to include the token cost calculator, which allows
+ * users to compute only the token cost of a prompt.
+ */
+ public abstract boolean includeTokenCostCalculator();
+
+ /** Returns whether to configure the graph to include the vision modality. */
+ public abstract boolean enableVisionModality();
+
+ /** Returns a new {@link Builder} instance. */
+ public static Builder builder() {
+ return new AutoValue_GraphOptions.Builder()
+ .setIncludeTokenCostCalculator(true)
+ .setEnableVisionModality(false);
+ }
+
+ /** Builder for {@link GraphOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets whether to configure the graph to include the token cost calculator. */
+ public abstract Builder setIncludeTokenCostCalculator(boolean includeTokenCostCalculator);
+
+ /** Sets whether to configure the graph to include the vision modality. */
+ public abstract Builder setEnableVisionModality(boolean enableVisionModality);
+
+ /** AutoValue generated builder method. */
+ abstract GraphOptions autoBuild();
+
+ /** Builds a new {@link GraphOptions} instance. */
+ public GraphOptions build() {
+ return autoBuild();
+ }
+ }
+}
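
Usage sketch (annotation, not part of the patch): composing these options. builder() seeds the
defaults, so for image input only the vision flag needs to be flipped:

    GraphOptions graphOptions =
        GraphOptions.builder()
            .setEnableVisionModality(true) // default is false
            .build();
    // includeTokenCostCalculator() keeps its default of true.
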
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java
index dba1e5b955..194fe6bdba 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java
@@ -35,7 +35,7 @@ public final class LlmInference implements AutoCloseable {
/** Creates an LlmInference Task. */
public static LlmInference createFromOptions(Context context, LlmInferenceOptions options) {
// Configure LLM model settings.
- LlmModelSettings modelSettings =
+ LlmModelSettings.Builder modelSettings =
LlmModelSettings.newBuilder()
.setModelPath(options.modelPath())
.setCacheDir(context.getCacheDir().getAbsolutePath())
@@ -43,10 +43,20 @@ public static LlmInference createFromOptions(Context context, LlmInferenceOption
.setMaxTokens(options.maxTokens())
.setMaxTopK(options.maxTopK())
.setNumberOfSupportedLoraRanks(options.supportedLoraRanks().size())
- .addAllSupportedLoraRanks(options.supportedLoraRanks())
- .build();
+ .addAllSupportedLoraRanks(options.supportedLoraRanks());
- return new LlmInference(context, STATS_TAG, modelSettings, options.resultListener());
+ if (options.visionModelOptions().isPresent()) {
+ VisionModelOptions visionModelOptions = options.visionModelOptions().get();
+
+ LlmModelSettings.VisionModelSettings.Builder visionModelSettings =
+ LlmModelSettings.VisionModelSettings.newBuilder();
+ visionModelOptions.getEncoderPath().ifPresent(visionModelSettings::setEncoderPath);
+ visionModelOptions.getAdapterPath().ifPresent(visionModelSettings::setAdapterPath);
+
+ modelSettings.setVisionModelSettings(visionModelSettings.build());
+ }
+
+ return new LlmInference(context, STATS_TAG, modelSettings.build(), options.resultListener());
}
/** Constructor to initialize an {@link LlmInference}. */
@@ -196,9 +206,12 @@ public abstract static class Builder {
*/
public abstract Builder setMaxTopK(int maxTopK);
- /** The supported lora ranks for the base model. Used by GPU only. */
+ /** Sets the supported lora ranks for the base model. Used by GPU only. */
public abstract Builder setSupportedLoraRanks(List<Integer> supportedLoraRanks);
+ /** Sets the model options to use for vision modality. */
+ public abstract Builder setVisionModelOptions(VisionModelOptions visionModelOptions);
+
abstract LlmInferenceOptions autoBuild();
/** Validates and builds the {@link LlmInferenceOptions} instance. */
@@ -232,6 +245,9 @@ public final LlmInferenceOptions build() {
/** The error listener to use for the {@link LlmInference#generateAsync} API. */
public abstract Optional<ErrorListener> errorListener();
+ /** The model options to use for vision modality. */
+ public abstract Optional<VisionModelOptions> visionModelOptions();
+
/** Returns a new builder with the same values as this instance. */
public abstract Builder toBuilder();
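
Usage sketch (annotation, not part of the patch): wiring vision model options into engine
creation. All file paths are hypothetical placeholders:

    VisionModelOptions visionModelOptions =
        VisionModelOptions.builder()
            .setEncoderPath("/data/local/tmp/vision_encoder.tflite") // hypothetical path
            .setAdapterPath("/data/local/tmp/vision_adapter.tflite") // hypothetical path
            .build();

    LlmInference.LlmInferenceOptions options =
        LlmInference.LlmInferenceOptions.builder()
            .setModelPath("/data/local/tmp/model.task") // hypothetical path
            .setVisionModelOptions(visionModelOptions)
            .build();

    LlmInference llmInference = LlmInference.createFromOptions(context, options);
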
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java
index 2d070cd207..d862667293 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java
@@ -1,6 +1,7 @@
package com.google.mediapipe.tasks.genai.llminference;
import com.google.auto.value.AutoValue;
+import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.LlmTaskRunner;
import com.google.mediapipe.tasks.core.LlmTaskRunner.LlmSession;
import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmSessionConfig;
@@ -37,6 +38,16 @@ public static LlmInferenceSession createFromOptions(
sessionConfig.setLoraPath("");
}
+ if (options.graphOptions().isPresent()) {
+ GraphOptions graphOptions = options.graphOptions().get();
+ LlmSessionConfig.GraphConfig graphConfig =
+ LlmSessionConfig.GraphConfig.newBuilder()
+ .setIncludeTokenCostCalculator(graphOptions.includeTokenCostCalculator())
+ .setEnableVisionModality(graphOptions.enableVisionModality())
+ .build();
+ sessionConfig.setGraphConfig(graphConfig);
+ }
+
LlmTaskRunner taskRunner = llmInference.getTaskRunner();
LlmSession session = taskRunner.createSession(sessionConfig.build());
return new LlmInferenceSession(taskRunner, session);
@@ -60,6 +71,16 @@ public void addQueryChunk(String inputText) {
taskRunner.addQueryChunk(session, inputText);
}
+ /**
+ * Adds an image to the session.
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public void addImage(MPImage image) {
+ taskRunner.addImage(session, image);
+ }
+
/**
* Generates a response based on the previously added query chunks synchronously. Use {@link
* #addQueryChunk(String)} to add at least one query chunk before calling this function.
@@ -170,6 +191,9 @@ public abstract static class Builder {
*/
public abstract Builder setLoraPath(String loraPath);
+ /** Sets the parameters to customize the graph. */
+ public abstract Builder setGraphOptions(GraphOptions graphOptions);
+
abstract LlmInferenceSessionOptions autoBuild();
/** Validates and builds the {@link LlmInferenceSessionOptions} instance. */
@@ -196,6 +220,9 @@ public final LlmInferenceSessionOptions build() {
*/
public abstract Optional<String> loraPath();
+ /** Returns the parameters to customize the graph. */
+ public abstract Optional<GraphOptions> graphOptions();
+
/** Instantiates a new LlmInferenceOptions builder. */
public static Builder builder() {
return new AutoValue_LlmInferenceSession_LlmInferenceSessionOptions.Builder()
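
Usage sketch (annotation, not part of the patch): a session that mixes text and image input,
assuming an llmInference created with vision model options and an MPImage built as in the
LlmTaskRunner sketch above. The session presumably needs GraphOptions with the vision modality
enabled for addImage() to succeed:

    LlmInferenceSession.LlmInferenceSessionOptions sessionOptions =
        LlmInferenceSession.LlmInferenceSessionOptions.builder()
            .setGraphOptions(GraphOptions.builder().setEnableVisionModality(true).build())
            .build();
    try (LlmInferenceSession session =
        LlmInferenceSession.createFromOptions(llmInference, sessionOptions)) {
      session.addQueryChunk("Describe this image:");
      session.addImage(image);
      String response = session.generateResponse();
    }
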
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/VisionModelOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/VisionModelOptions.java
new file mode 100644
index 0000000000..1effe75ab8
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/VisionModelOptions.java
@@ -0,0 +1,36 @@
+package com.google.mediapipe.tasks.genai.llminference;
+
+import com.google.auto.value.AutoValue;
+import java.util.Optional;
+
+/** Options for configuring the vision modality. */
+@AutoValue
+public abstract class VisionModelOptions {
+ /** Returns the path to the vision encoder model file. */
+ public abstract Optional<String> getEncoderPath();
+
+ /** Returns the path to the vision adapter model file. */
+ public abstract Optional<String> getAdapterPath();
+
+ /** Builder for {@link VisionModelOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the path to the vision encoder model file. */
+ public abstract Builder setEncoderPath(String encoderPath);
+
+ /** Sets the path to the vision adapter model file. */
+ public abstract Builder setAdapterPath(String adapterPath);
+
+ abstract VisionModelOptions autoBuild();
+
+ /** Validates and builds the {@link VisionModelOptions} instance. */
+ public final VisionModelOptions build() {
+ return autoBuild();
+ }
+ }
+
+ /** Instantiates a new VisionModelOptions builder. */
+ public static Builder builder() {
+ return new AutoValue_VisionModelOptions.Builder();
+ }
+}