Add vision modality to the Java LLM API
PiperOrigin-RevId: 705528226
schmidt-sebastian authored and copybara-github committed Dec 12, 2024
1 parent ad6c03e commit 96e3a69
Showing 6 changed files with 313 additions and 5 deletions.
1 change: 1 addition & 0 deletions mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD
@@ -85,6 +85,7 @@ android_library(
deps = [
":core_java",
":logging",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/proto:llm_response_context_java_proto_lite",
"//third_party/java/protobuf:protobuf_lite",
mediapipe/tasks/java/com/google/mediapipe/tasks/core/LlmTaskRunner.java
@@ -15,13 +15,22 @@
package com.google.mediapipe.tasks.core;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.PixelFormat;
import android.media.Image;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties;
import com.google.mediapipe.framework.image.MediaImageExtractor;
import com.google.mediapipe.tasks.core.OutputHandler.ProgressListener;
import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmModelSettings;
import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmSessionConfig;
import com.google.mediapipe.tasks.core.jni.proto.LlmResponseContextProto.LlmResponseContext;
import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger;
import com.google.mediapipe.tasks.core.logging.TasksStatsLogger;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -38,6 +47,109 @@ public final class LlmTaskRunner implements AutoCloseable {
private final TasksStatsLogger statsLogger;
private final AtomicBoolean isProcessing;

/**
* Describes how pixel bits encode color. A pixel may be an alpha mask, a grayscale, RGB, or ARGB.
*
* <p>This matches the SkColorType enum in https://api.skia.org/SkColorType_8h.html.
*/
private enum SkColorType {
/** Uninitialized. */
UNKNOWN(0),
/** Pixel with alpha in 8-bit byte. */
ALPHA_8(1),
/** Pixel with 5 bits red, 6 bits green, 5 bits blue, in 16-bit word. */
RGB_565(2),
/** Pixel with 4 bits for alpha, red, green, blue; in 16-bit word. */
ARGB_4444(3),
/** Pixel with 8 bits for red, green, blue, alpha; in 32-bit word. */
RGBA_8888(4),
/** Pixel with 8 bits each for red, green, blue; in 32-bit word. */
RGB_888X(5),
/** Pixel with 8 bits for blue, green, red, alpha; in 32-bit word. */
BGRA_8888(6),
/** 10 bits for red, green, blue; 2 bits for alpha; in 32-bit word. */
RGBA_1010102(7),
/** 10 bits for blue, green, red; 2 bits for alpha; in 32-bit word. */
BGRA_1010102(8),
/** Pixel with 10 bits each for red, green, blue; in 32-bit word. */
RGB_101010X(9),
/** Pixel with 10 bits each for blue, green, red; in 32-bit word. */
BGR_101010X(10),
/** Pixel with 10 bits each for blue, green, red; in 32-bit word, extended range. */
BGR_101010X_XR(11),
/** Pixel with 10 bits each for blue, green, red, alpha; in 64-bit word, extended range. */
BGRA_10101010_XR(12),
/**
* Pixel with 10 used bits (most significant) followed by 6 unused bits for red, green, blue,
* alpha; in 64-bit word.
*/
RGBA_10X6(13),
/** Pixel with grayscale level in 8-bit byte. */
GRAY_8(14),
/** Pixel with half floats in [0,1] for red, green, blue, alpha; in 64-bit word. */
RGBA_F16NORM(15),
/** Pixel with half floats for red, green, blue, alpha; in 64-bit word. */
RGBA_F16(16),
/** Pixel with half floats for red, green, blue; in 64-bit word. */
RGB_F16F16F16X(17),
/** Pixel using C float for red, green, blue, alpha; in 128-bit word. */
RGBA_F32(18),
/** Pixel with a uint8_t for red and green. */
R8G8_UNORM(19),
/** Pixel with a half float for alpha. */
A16_FLOAT(20),
/** Pixel with a half float for red and green. */
R16G16_FLOAT(21),
/** Pixel with a little endian uint16_t for alpha. */
A16_UNORM(22),
/** Pixel with a little endian uint16_t for red and green. */
R16G16_UNORM(23),
/** Pixel with a little endian uint16_t for red, green, blue and alpha. */
R16G16B16A16_UNORM(24),
/** Pixel with 8 bits for red, green, blue, alpha; in 32-bit word, gamma encoded. */
SRGBA_8888(25),
/** Pixel with a uint8_t for red. */
R8_UNORM(26);

private final int value;

SkColorType(int value) {
this.value = value;
}

/** Returns the integer value associated with this color type. */
int getValue() {
return value;
}
}

/**
* Describes how to interpret the alpha component of a pixel. A pixel may be opaque or carry an
* alpha value that describes multiple levels of transparency.
*
* <p>This matches the SkAlphaType enum in https://api.skia.org/SkAlphaType_8h.html.
*/
private enum SkAlphaType {
/** Uninitialized. */
UNINITIALIZED(0),
/** Pixel is opaque. */
OPAQUE(1),
/** Pixel components are premultiplied by alpha. */
PREMULTIPLIED(2),
/** Pixel components are independent of alpha. */
UNPREMULTIPLIED(3);

private final int value;

SkAlphaType(int value) {
this.value = value;
}

/** Returns the integer value associated with this alpha type. */
int getValue() {
return value;
}
}

/** The session to use for LLM inference calls. */
public static final class LlmSession {
private final long sessionHandle;
@@ -79,6 +191,21 @@ public void addQueryChunk(LlmSession session, String input) {
nativeAddQueryChunk(session.sessionHandle, input);
}

/** Adds a new image to the session context. */
public void addImage(LlmSession session, MPImage input) {
validateState();
long imageHandle = createImage(input);
try {
// TODO: Remove this dummy chunk.
// Since AddImage cannot tell whether start_id has already been added,
// add an empty query chunk first to ensure start_id is inserted properly.
nativeAddQueryChunk(session.sessionHandle, "");
nativeAddImage(session.sessionHandle, imageHandle);
} finally {
nativeDeleteSkBitmap(imageHandle);
}
}

/** Invokes the LLM with the given session and waits for the result. */
public List<String> predictSync(LlmSession session) {
validateState();
@@ -160,6 +287,58 @@ public void close() {
}
}

private long createImage(MPImage image) {
MPImageProperties properties = image.getContainedImageProperties().get(0);

SkAlphaType skAlphaType = SkAlphaType.OPAQUE;
ByteBuffer buffer;
SkColorType skColorType;

int width = image.getWidth();
int height = image.getHeight();

if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
buffer = ByteBufferExtractor.extract(image);

switch (properties.getImageFormat()) {
case MPImage.IMAGE_FORMAT_RGBA:
skColorType = SkColorType.RGBA_8888;
break;
case MPImage.IMAGE_FORMAT_RGB:
skColorType = SkColorType.RGB_888X;
break;
case MPImage.IMAGE_FORMAT_ALPHA:
skColorType = SkColorType.ALPHA_8;
break;
default:
throw new UnsupportedOperationException(
"Unsupported MediaPipe image format: " + properties.getImageFormat());
}
} else if (properties.getStorageType() == MPImage.STORAGE_TYPE_BITMAP) {
Bitmap bitmap = BitmapExtractor.extract(image);
if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
throw new UnsupportedOperationException("Bitmap must use ARGB_8888 config.");
}
skColorType = SkColorType.RGBA_8888;

buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
bitmap.copyPixelsToBuffer(buffer);
} else if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) {
Image mediaImage = MediaImageExtractor.extract(image);
if (mediaImage.getFormat() != PixelFormat.RGBA_8888) {
throw new UnsupportedOperationException("Android media image must use RGBA_8888 config.");
}
buffer = mediaImage.getPlanes()[0].getBuffer();
skColorType = SkColorType.RGBA_8888;
} else {
throw new UnsupportedOperationException(
"Unsupported Image container type: " + properties.getStorageType());
}

return nativeCreateSkBitmap(
buffer, width, height, skColorType.getValue(), skAlphaType.getValue());
}

private void validateState() {
if (isProcessing.get()) {
throw new IllegalStateException("Previous invocation still processing. Wait for done=true.");
@@ -187,4 +366,11 @@ private void validateState() {
private static native void nativePredictAsync(long sessionPointer, long callbackContextHandle);

private static native int nativeSizeInTokens(long sessionPointer, String input);

private static native long nativeCreateSkBitmap(
ByteBuffer buffer, int width, int height, int colorType, int alphaType);

private static native void nativeAddImage(long sessionPointer, long imagePointer);

private static native void nativeDeleteSkBitmap(long imagePointer);
}
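
A note on inputs: createImage() above accepts three MPImage container types: a direct ByteBuffer in RGBA, RGB, or alpha-only format; an ARGB_8888 Bitmap; or an RGBA_8888 android.media.Image. Below is a minimal sketch of producing a compliant MPImage from a decoded Bitmap via the framework's BitmapImageBuilder; the helper class and file path are illustrative, not part of this change.

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;

final class VisionInputs {
  private VisionInputs() {}

  /** Decodes a file into the ARGB_8888 config that createImage() requires for Bitmaps. */
  static MPImage fromFile(String path) {
    BitmapFactory.Options decodeOptions = new BitmapFactory.Options();
    decodeOptions.inPreferredConfig = Bitmap.Config.ARGB_8888;
    Bitmap bitmap = BitmapFactory.decodeFile(path, decodeOptions);
    // The runner copies these pixels into a direct ByteBuffer and hands them to
    // nativeCreateSkBitmap as RGBA_8888 with an opaque alpha type.
    return new BitmapImageBuilder(bitmap).build();
  }
}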
mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/GraphOptions.java
@@ -0,0 +1,42 @@
package com.google.mediapipe.tasks.genai.llminference;

import com.google.auto.value.AutoValue;

/** Configuration for the inference graph. */
@AutoValue
public abstract class GraphOptions {

/**
* Returns whether to configure the graph to include the token cost calculator, which allows
* users to compute only the cost of a prompt.
*/
public abstract boolean includeTokenCostCalculator();

/** Returns whether to configure the graph to include the vision modality. */
public abstract boolean enableVisionModality();

/** Returns a new {@link Builder} instance. */
public static Builder builder() {
return new AutoValue_GraphOptions.Builder()
.setIncludeTokenCostCalculator(true)
.setEnableVisionModality(false);
}

/** Builder for {@link GraphOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets whether to configure the graph to include the token cost calculator. */
public abstract Builder setIncludeTokenCostCalculator(boolean includeTokenCostCalculator);

/** Sets whether to configure the graph to include the vision modality. */
public abstract Builder setEnableVisionModality(boolean enableVisionModality);

/** AutoValue generated builder method. */
abstract GraphOptions autoBuild();

/** Builds a new {@link GraphOptions} instance. */
public GraphOptions build() {
return autoBuild();
}
}
}
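
The defaults set in builder() keep the token cost calculator on and vision off, so vision must be enabled explicitly per session. A short sketch:

import com.google.mediapipe.tasks.genai.llminference.GraphOptions;

final class GraphOptionsDemo {
  private GraphOptionsDemo() {}

  static GraphOptions visionEnabled() {
    return GraphOptions.builder()
        .setEnableVisionModality(true) // defaults to false
        .build(); // includeTokenCostCalculator keeps its default of true
  }
}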
mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInference.java
@@ -35,18 +35,28 @@ public final class LlmInference implements AutoCloseable {
/** Creates an LlmInference Task. */
public static LlmInference createFromOptions(Context context, LlmInferenceOptions options) {
// Configure LLM model settings.
LlmModelSettings modelSettings =
LlmModelSettings.Builder modelSettings =
LlmModelSettings.newBuilder()
.setModelPath(options.modelPath())
.setCacheDir(context.getCacheDir().getAbsolutePath())
.setNumDecodeStepsPerSync(NUM_DECODE_STEPS_PER_SYNC)
.setMaxTokens(options.maxTokens())
.setMaxTopK(options.maxTopK())
.setNumberOfSupportedLoraRanks(options.supportedLoraRanks().size())
.addAllSupportedLoraRanks(options.supportedLoraRanks())
.build();
.addAllSupportedLoraRanks(options.supportedLoraRanks());

return new LlmInference(context, STATS_TAG, modelSettings, options.resultListener());
if (options.visionModelOptions().isPresent()) {
VisionModelOptions visionModelOptions = options.visionModelOptions().get();

LlmModelSettings.VisionModelSettings.Builder visionModelSettings =
LlmModelSettings.VisionModelSettings.newBuilder();
visionModelOptions.getEncoderPath().ifPresent(visionModelSettings::setEncoderPath);
visionModelOptions.getAdapterPath().ifPresent(visionModelSettings::setAdapterPath);

modelSettings.setVisionModelSettings(visionModelSettings.build());
}

return new LlmInference(context, STATS_TAG, modelSettings.build(), options.resultListener());
}

/** Constructor to initialize an {@link LlmInference}. */
@@ -196,9 +206,12 @@ public abstract static class Builder {
*/
public abstract Builder setMaxTopK(int maxTopK);

/** The supported lora ranks for the base model. Used by GPU only. */
/** Sets the supported lora ranks for the base model. Used by GPU only. */
public abstract Builder setSupportedLoraRanks(List<Integer> supportedLoraRanks);

/** Sets the model options to use for vision modality. */
public abstract Builder setVisionModelOptions(VisionModelOptions visionModelOptions);

abstract LlmInferenceOptions autoBuild();

/** Validates and builds the {@link LlmInferenceOptions} instance. */
@@ -232,6 +245,9 @@ public final LlmInferenceOptions build() {
/** The error listener to use for the {@link LlmInference#generateAsync} API. */
public abstract Optional<ErrorListener> errorListener();

/** The model options to use for vision modality. */
public abstract Optional<VisionModelOptions> visionModelOptions();

/** Returns a new builder with the same values as this instance. */
public abstract Builder toBuilder();
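
Wiring the new vision settings through LlmInferenceOptions might look like the sketch below. VisionModelOptions is among the six changed files but its diff is not rendered on this page, so the builder methods setEncoderPath and setAdapterPath are assumptions inferred from the getEncoderPath()/getAdapterPath() getters used in createFromOptions(); all model paths are placeholders.

import android.content.Context;
import com.google.mediapipe.tasks.genai.llminference.LlmInference;
import com.google.mediapipe.tasks.genai.llminference.LlmInference.LlmInferenceOptions;
import com.google.mediapipe.tasks.genai.llminference.VisionModelOptions;

final class VisionEngineFactory {
  private VisionEngineFactory() {}

  static LlmInference create(Context context) {
    // Assumed AutoValue builder; VisionModelOptions is not shown in this diff.
    VisionModelOptions visionModelOptions =
        VisionModelOptions.builder()
            .setEncoderPath("/data/local/tmp/vision_encoder.tflite") // placeholder path
            .setAdapterPath("/data/local/tmp/vision_adapter.tflite") // placeholder path
            .build();

    LlmInferenceOptions options =
        LlmInferenceOptions.builder()
            .setModelPath("/data/local/tmp/llm_model.bin") // placeholder path
            .setVisionModelOptions(visionModelOptions)
            .build();

    // createFromOptions() copies the encoder/adapter paths into
    // LlmModelSettings.VisionModelSettings before creating the engine.
    return LlmInference.createFromOptions(context, options);
  }
}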

mediapipe/tasks/java/com/google/mediapipe/tasks/genai/llminference/LlmInferenceSession.java
@@ -1,6 +1,7 @@
package com.google.mediapipe.tasks.genai.llminference;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.LlmTaskRunner;
import com.google.mediapipe.tasks.core.LlmTaskRunner.LlmSession;
import com.google.mediapipe.tasks.core.jni.proto.LlmOptionsProto.LlmSessionConfig;
@@ -37,6 +38,16 @@ public static LlmInferenceSession createFromOptions(
sessionConfig.setLoraPath("");
}

if (options.graphOptions().isPresent()) {
GraphOptions graphOptions = options.graphOptions().get();
LlmSessionConfig.GraphConfig graphConfig =
LlmSessionConfig.GraphConfig.newBuilder()
.setIncludeTokenCostCalculator(graphOptions.includeTokenCostCalculator())
.setEnableVisionModality(graphOptions.enableVisionModality())
.build();
sessionConfig.setGraphConfig(graphConfig);
}

LlmTaskRunner taskRunner = llmInference.getTaskRunner();
LlmSession session = taskRunner.createSession(sessionConfig.build());
return new LlmInferenceSession(taskRunner, session);
@@ -60,6 +71,16 @@ public void addQueryChunk(String inputText) {
taskRunner.addQueryChunk(session, inputText);
}

/**
* Adds an image to the session.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public void addImage(MPImage image) {
taskRunner.addImage(session, image);
}

/**
* Generates a response based on the previously added query chunks synchronously. Use {@link
* #addQueryChunk(String)} to add at least one query chunk before calling this function.
@@ -170,6 +191,9 @@ public abstract static class Builder {
*/
public abstract Builder setLoraPath(String loraPath);

/** Sets the parameters to customize the graph. */
public abstract Builder setGraphOptions(GraphOptions graphOptions);

abstract LlmInferenceSessionOptions autoBuild();

/** Validates and builds the {@link LlmInferenceSessionOptions} instance. */
@@ -196,6 +220,9 @@ public final LlmInferenceOptions build() {
*/
public abstract Optional<String> loraPath();

/** Returns the parameters to customize the graph. */
public abstract Optional<GraphOptions> graphOptions();

/** Instantiates a new LlmInferenceOptions builder. */
public static Builder builder() {
return new AutoValue_LlmInferenceSession_LlmInferenceSessionOptions.Builder()
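
Putting the pieces together, a session-level round trip with the new vision modality could look like the sketch below. It assumes LlmInferenceSession implements AutoCloseable (its task runner does) and that generateResponse() is the synchronous generation entry point described in the javadoc above; the prompt text is illustrative.

import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.genai.llminference.GraphOptions;
import com.google.mediapipe.tasks.genai.llminference.LlmInference;
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession;
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession.LlmInferenceSessionOptions;

final class VisionSessionDemo {
  private VisionSessionDemo() {}

  static String describeImage(LlmInference engine, MPImage image) {
    LlmInferenceSessionOptions sessionOptions =
        LlmInferenceSessionOptions.builder()
            .setGraphOptions(
                GraphOptions.builder().setEnableVisionModality(true).build())
            .build();

    try (LlmInferenceSession session =
        LlmInferenceSession.createFromOptions(engine, sessionOptions)) {
      session.addQueryChunk("Describe the following image:"); // text prompt first
      session.addImage(image); // adds an empty chunk, then the SkBitmap (see LlmTaskRunner)
      return session.generateResponse(); // blocks until the full response is available
    }
  }
}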