From aa428db08da56f5f64b601ff4aee80ce8fc1bea8 Mon Sep 17 00:00:00 2001 From: Stax124 Date: Sun, 26 Nov 2023 13:59:48 +0100 Subject: [PATCH] Fix diffusers samplers with ControlNet --- core/inference/pytorch/pipeline.py | 15 +++++++++++---- tests/inference/test_pytorch.py | 4 +--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/core/inference/pytorch/pipeline.py b/core/inference/pytorch/pipeline.py index e85b6c88..45e8f126 100644 --- a/core/inference/pytorch/pipeline.py +++ b/core/inference/pytorch/pipeline.py @@ -7,7 +7,7 @@ import torch from diffusers.models.adapter import MultiAdapter from diffusers.models.autoencoder_kl import AutoencoderKL -from diffusers.models.controlnet import ControlNetModel +from diffusers.models.controlnet import ControlNetModel, ControlNetOutput from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_output import ( StableDiffusionPipelineOutput, @@ -586,7 +586,6 @@ def do_denoise( controlnet_cond=image, conditioning_scale=cond_scale, guess_mode=guess_mode, - return_dict=False, ) if guess_mode and do_classifier_free_guidance: @@ -721,12 +720,20 @@ def _call(*args, **kwargs): args = args[:2] if kwargs.get("cond", None) is not None: encoder_hidden_states = kwargs.pop("cond") - return s( + out = s( *args, encoder_hidden_states=encoder_hidden_states, # type: ignore return_dict=True, **kwargs, - )[0] + ) + + if isinstance(out, ControlNetOutput): + down_block_res_samples = out.down_block_res_samples + mid_block_res_sample = out.mid_block_res_sample + + return down_block_res_samples, mid_block_res_sample + else: + return out[0] for i, t in enumerate(tqdm(timesteps, desc="PyTorch")): latents = do_denoise(latents, t, _call, change) # type: ignore diff --git a/tests/inference/test_pytorch.py b/tests/inference/test_pytorch.py index 3477f6aa..c687bee3 100644 --- a/tests/inference/test_pytorch.py +++ b/tests/inference/test_pytorch.py @@ -187,9 +187,7 @@ def test_inpaint(pipe: PyTorchStableDiffusion): pipe.generate(job) -@pytest.mark.parametrize( - "scheduler", [KarrasDiffusionSchedulers.UniPCMultistepScheduler, "dpmpp_2m"] -) +@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers) + ["dpmpp_2m"]) def test_controlnet(pipe: PyTorchStableDiffusion, scheduler): "Generate an image with ControlNet Image to Image"