Skip to content

Commit

Permalink
Add model type to ImageGeneratorOptions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 632990951
  • Loading branch information
MediaPipe Team authored and copybara-github committed May 12, 2024
1 parent 9fcc392 commit 4204859
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 12 deletions.
2 changes: 2 additions & 0 deletions mediapipe/tasks/cc/vision/image_generator/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_landmarker",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto",
Expand All @@ -131,6 +132,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
],
)
14 changes: 14 additions & 0 deletions mediapipe/tasks/cc/vision/image_generator/image_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "mediapipe/framework/api2/builder.h"
Expand All @@ -31,6 +32,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h"
Expand Down Expand Up @@ -226,6 +228,18 @@ ConvertImageGeneratorGraphOptionsProto(
auto& options_proto = *options_proto_and_condition_index.options_proto;
options_proto.set_text2image_model_directory(
image_generator_options->text2image_model_directory);
options_proto.mutable_stable_diffusion_iterate_options()->set_file_folder(
image_generator_options->text2image_model_directory);
switch (image_generator_options->model_type) {
case ImageGeneratorOptions::ModelType::SD_1:
options_proto.mutable_stable_diffusion_iterate_options()->set_model_type(
mediapipe::StableDiffusionIterateCalculatorOptions::SD_1);
break;
default:
return absl::InvalidArgumentError(
absl::StrFormat("Unsupported ImageGenerator model type: %d",
image_generator_options->model_type));
}
if (image_generator_options->lora_weights_file_path.has_value()) {
options_proto.mutable_lora_weights_file()->set_file_name(
*image_generator_options->lora_weights_file_path);
Expand Down
6 changes: 6 additions & 0 deletions mediapipe/tasks/cc/vision/image_generator/image_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ struct ImageGeneratorOptions {
// The text to image model directory storing the model weights.
std::string text2image_model_directory;

enum ModelType {
SD_1 = 1, // Stable Diffusion v1 models, including SD 1.4 and 1.5.
}

model_type = ModelType::SD_1;

// The path to LoRA weights file.
std::optional<std::string> lora_weights_file_path;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,11 @@ protected abstract static class ConditionImageResult implements TaskResult {
@AutoValue
public abstract static class ImageGeneratorOptions extends TaskOptions {

/** The supported model types. */
public enum ModelType {
SD_1, // Stable Diffusion v1 models, including SD 1.4 and 1.5.
}

/** Builder for {@link ImageGeneratorOptions}. */
@AutoValue.Builder
public abstract static class Builder {
Expand All @@ -544,6 +549,9 @@ public abstract static class Builder {
/** Sets the path to LoRA weights file. */
public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath);

/** Sets the model type. */
public abstract Builder setModelType(ModelType modelType);

/** Sets an optional {@link ErrorListener}}. */
public abstract Builder setErrorListener(ErrorListener value);

Expand All @@ -559,13 +567,27 @@ public final ImageGeneratorOptions build() {

abstract Optional<String> loraWeightsFilePath();

abstract ModelType modelType();

abstract Optional<ErrorListener> errorListener();

private Optional<ConditionOptions> conditionOptions;

private StableDiffusionIterateCalculatorOptionsProto.StableDiffusionIterateCalculatorOptions
.ModelType
convertModelTypeToProto(ModelType modelType) {
switch (modelType) {
case SD_1:
return StableDiffusionIterateCalculatorOptionsProto
.StableDiffusionIterateCalculatorOptions.ModelType.SD_1;
}
throw new IllegalArgumentException("Unsupported model type: " + modelType.name());
}

public static Builder builder() {
return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder()
.setImageGeneratorModelDirectory("");
.setImageGeneratorModelDirectory("")
.setModelType(ModelType.SD_1);
}

/** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */
Expand Down Expand Up @@ -594,6 +616,7 @@ public Any convertToAnyProto() {
.newBuilder()
.setBaseSeed(0)
.setFileFolder(imageGeneratorModelDirectory())
.setModelType(convertModelTypeToProto(modelType()))
.setOutputImageWidth(GENERATED_IMAGE_WIDTH)
.setOutputImageHeight(GENERATED_IMAGE_HEIGHT)
.setEmitEmptyPacket(true)
Expand Down Expand Up @@ -685,17 +708,17 @@ public Any convertToAnyProto() {
.build());
}
if (depthConditionOptions().isPresent()) {
taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions(
convertBaseOptionsToProto(
depthConditionOptions().get().pluginModelBaseOptions()))
.setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder()
.setDepthConditionTypeOptions(
depthConditionOptions().get().convertToProto())
.build())
.build());
taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions(
convertBaseOptionsToProto(
depthConditionOptions().get().pluginModelBaseOptions()))
.setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder()
.setDepthConditionTypeOptions(
depthConditionOptions().get().convertToProto())
.build())
.build());
}
return Any.newBuilder()
.setTypeUrl(
Expand Down

0 comments on commit 4204859

Please sign in to comment.