diff --git a/mediapipe/tasks/cc/vision/image_generator/BUILD b/mediapipe/tasks/cc/vision/image_generator/BUILD index 71b8230ae7..7defb5a5bf 100644 --- a/mediapipe/tasks/cc/vision/image_generator/BUILD +++ b/mediapipe/tasks/cc/vision/image_generator/BUILD @@ -123,6 +123,7 @@ cc_library( "//mediapipe/tasks/cc/vision/face_landmarker", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_cc_proto", "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto", @@ -131,6 +132,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", ], ) diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator.cc b/mediapipe/tasks/cc/vision/image_generator/image_generator.cc index e10ccf0b19..d780fdb568 100644 --- a/mediapipe/tasks/cc/vision/image_generator/image_generator.cc +++ b/mediapipe/tasks/cc/vision/image_generator/image_generator.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "mediapipe/framework/api2/builder.h" @@ -31,6 +32,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h" #include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h" @@ -226,6 +228,18 @@ ConvertImageGeneratorGraphOptionsProto( auto& options_proto = *options_proto_and_condition_index.options_proto; options_proto.set_text2image_model_directory( image_generator_options->text2image_model_directory); + options_proto.mutable_stable_diffusion_iterate_options()->set_file_folder( + image_generator_options->text2image_model_directory); + switch (image_generator_options->model_type) { + case ImageGeneratorOptions::ModelType::SD_1: + options_proto.mutable_stable_diffusion_iterate_options()->set_model_type( + mediapipe::StableDiffusionIterateCalculatorOptions::SD_1); + break; + default: + return absl::InvalidArgumentError( + absl::StrFormat("Unsupported ImageGenerator model type: %d", + image_generator_options->model_type)); + } if (image_generator_options->lora_weights_file_path.has_value()) { options_proto.mutable_lora_weights_file()->set_file_name( *image_generator_options->lora_weights_file_path); diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator.h b/mediapipe/tasks/cc/vision/image_generator/image_generator.h index 52599c02fd..52cfb0c0b4 100644 --- a/mediapipe/tasks/cc/vision/image_generator/image_generator.h +++ b/mediapipe/tasks/cc/vision/image_generator/image_generator.h @@ -91,6 +91,12 @@ struct ImageGeneratorOptions { // The text to image model directory storing the model weights. std::string text2image_model_directory; + enum ModelType { + SD_1 = 1, // Stable Diffusion v1 models, including SD 1.4 and 1.5. + } + + model_type = ModelType::SD_1; + // The path to LoRA weights file. std::optional lora_weights_file_path; }; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java index d206190755..718b4c8f03 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java @@ -534,6 +534,11 @@ protected abstract static class ConditionImageResult implements TaskResult { @AutoValue public abstract static class ImageGeneratorOptions extends TaskOptions { + /** The supported model types. */ + public enum ModelType { + SD_1, // Stable Diffusion v1 models, including SD 1.4 and 1.5. + } + /** Builder for {@link ImageGeneratorOptions}. */ @AutoValue.Builder public abstract static class Builder { @@ -544,6 +549,9 @@ public abstract static class Builder { /** Sets the path to LoRA weights file. */ public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath); + /** Sets the model type. */ + public abstract Builder setModelType(ModelType modelType); + /** Sets an optional {@link ErrorListener}}. */ public abstract Builder setErrorListener(ErrorListener value); @@ -559,13 +567,27 @@ public final ImageGeneratorOptions build() { abstract Optional loraWeightsFilePath(); + abstract ModelType modelType(); + abstract Optional errorListener(); private Optional conditionOptions; + private StableDiffusionIterateCalculatorOptionsProto.StableDiffusionIterateCalculatorOptions + .ModelType + convertModelTypeToProto(ModelType modelType) { + switch (modelType) { + case SD_1: + return StableDiffusionIterateCalculatorOptionsProto + .StableDiffusionIterateCalculatorOptions.ModelType.SD_1; + } + throw new IllegalArgumentException("Unsupported model type: " + modelType.name()); + } + public static Builder builder() { return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder() - .setImageGeneratorModelDirectory(""); + .setImageGeneratorModelDirectory("") + .setModelType(ModelType.SD_1); } /** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */ @@ -594,6 +616,7 @@ public Any convertToAnyProto() { .newBuilder() .setBaseSeed(0) .setFileFolder(imageGeneratorModelDirectory()) + .setModelType(convertModelTypeToProto(modelType())) .setOutputImageWidth(GENERATED_IMAGE_WIDTH) .setOutputImageHeight(GENERATED_IMAGE_HEIGHT) .setEmitEmptyPacket(true) @@ -685,17 +708,17 @@ public Any convertToAnyProto() { .build()); } if (depthConditionOptions().isPresent()) { - taskOptionsBuilder.addControlPluginGraphsOptions( - ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() - .setBaseOptions( - convertBaseOptionsToProto( - depthConditionOptions().get().pluginModelBaseOptions())) - .setConditionedImageGraphOptions( - ConditionedImageGraphOptions.newBuilder() - .setDepthConditionTypeOptions( - depthConditionOptions().get().convertToProto()) - .build()) - .build()); + taskOptionsBuilder.addControlPluginGraphsOptions( + ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() + .setBaseOptions( + convertBaseOptionsToProto( + depthConditionOptions().get().pluginModelBaseOptions())) + .setConditionedImageGraphOptions( + ConditionedImageGraphOptions.newBuilder() + .setDepthConditionTypeOptions( + depthConditionOptions().get().convertToProto()) + .build()) + .build()); } return Any.newBuilder() .setTypeUrl(