Skip to content

Commit

Permalink
SD3 GPU fix (#1362)
Browse files Browse the repository at this point in the history
  • Loading branch information
likholat authored Dec 11, 2024
1 parent 5aa13b1 commit 15176a8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
17 changes: 12 additions & 5 deletions src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,25 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {

set_scheduler(Scheduler::from_config(root_dir / "scheduler/scheduler_config.json"));

// Temporary fix for GPU
ov::AnyMap updated_properties = properties;
if (device.find("GPU") != std::string::npos &&
updated_properties.find("INFERENCE_PRECISION_HINT") == updated_properties.end()) {
updated_properties["INFERENCE_PRECISION_HINT"] = ov::element::f32;
}

const std::string text_encoder = data["text_encoder"][1].get<std::string>();
if (text_encoder == "CLIPTextModelWithProjection") {
m_clip_text_encoder_1 =
std::make_shared<CLIPTextModelWithProjection>(root_dir / "text_encoder", device, properties);
std::make_shared<CLIPTextModelWithProjection>(root_dir / "text_encoder", device, updated_properties);
} else {
OPENVINO_THROW("Unsupported '", text_encoder, "' text encoder type");
}

const std::string text_encoder_2 = data["text_encoder_2"][1].get<std::string>();
if (text_encoder_2 == "CLIPTextModelWithProjection") {
m_clip_text_encoder_2 =
std::make_shared<CLIPTextModelWithProjection>(root_dir / "text_encoder_2", device, properties);
std::make_shared<CLIPTextModelWithProjection>(root_dir / "text_encoder_2", device, updated_properties);
} else {
OPENVINO_THROW("Unsupported '", text_encoder_2, "' text encoder type");
}
Expand All @@ -151,7 +158,7 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
if (!text_encoder_3_json.is_null()) {
const std::string text_encoder_3 = text_encoder_3_json.get<std::string>();
if (text_encoder_3 == "T5EncoderModel") {
m_t5_text_encoder = std::make_shared<T5EncoderModel>(root_dir / "text_encoder_3", device, properties);
m_t5_text_encoder = std::make_shared<T5EncoderModel>(root_dir / "text_encoder_3", device, updated_properties);
} else {
OPENVINO_THROW("Unsupported '", text_encoder_3, "' text encoder type");
}
Expand All @@ -167,9 +174,9 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
const std::string vae = data["vae"][1].get<std::string>();
if (vae == "AutoencoderKL") {
if (m_pipeline_type == PipelineType::TEXT_2_IMAGE)
m_vae = std::make_shared<AutoencoderKL>(root_dir / "vae_decoder", device, properties);
m_vae = std::make_shared<AutoencoderKL>(root_dir / "vae_decoder", device, updated_properties);
else if (m_pipeline_type == PipelineType::IMAGE_2_IMAGE || m_pipeline_type == PipelineType::INPAINTING) {
m_vae = std::make_shared<AutoencoderKL>(root_dir / "vae_encoder", root_dir / "vae_decoder", device, properties);
m_vae = std::make_shared<AutoencoderKL>(root_dir / "vae_encoder", root_dir / "vae_decoder", device, updated_properties);
} else {
OPENVINO_ASSERT("Unsupported pipeline type");
}
Expand Down
10 changes: 5 additions & 5 deletions src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,18 @@ class StableDiffusionXLPipeline : public StableDiffusionPipeline {
}

// Temporary fix for GPU
ov::AnyMap updated_roperties = properties;
ov::AnyMap updated_properties = properties;
if (device.find("GPU") != std::string::npos &&
updated_roperties.find("INFERENCE_PRECISION_HINT") == updated_roperties.end()) {
updated_roperties["INFERENCE_PRECISION_HINT"] = ov::element::f32;
updated_properties.find("INFERENCE_PRECISION_HINT") == updated_properties.end()) {
updated_properties["INFERENCE_PRECISION_HINT"] = ov::element::f32;
}

const std::string vae = data["vae"][1].get<std::string>();
if (vae == "AutoencoderKL") {
if (m_pipeline_type == PipelineType::TEXT_2_IMAGE)
m_vae = std::make_shared<AutoencoderKL>(root_dir / "vae_decoder", device, properties);
m_vae = std::make_shared<AutoencoderKL>(root_dir / "vae_decoder", device, updated_properties);
else if (m_pipeline_type == PipelineType::IMAGE_2_IMAGE || m_pipeline_type == PipelineType::INPAINTING) {
m_vae = std::make_shared<AutoencoderKL>(root_dir / "vae_encoder", root_dir / "vae_decoder", device, properties);
m_vae = std::make_shared<AutoencoderKL>(root_dir / "vae_encoder", root_dir / "vae_decoder", device, updated_properties);
} else {
OPENVINO_ASSERT("Unsupported pipeline type");
}
Expand Down

0 comments on commit 15176a8

Please sign in to comment.