Skip to content

Commit

Permalink
Fix diffusers samplers with ControlNet
Browse files Browse the repository at this point in the history
  • Loading branch information
Stax124 committed Nov 26, 2023
1 parent f43bc33 commit aa428db
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
15 changes: 11 additions & 4 deletions core/inference/pytorch/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from diffusers.models.adapter import MultiAdapter
from diffusers.models.autoencoder_kl import AutoencoderKL
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_output import (
StableDiffusionPipelineOutput,
Expand Down Expand Up @@ -586,7 +586,6 @@ def do_denoise(
controlnet_cond=image,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)

if guess_mode and do_classifier_free_guidance:
Expand Down Expand Up @@ -721,12 +720,20 @@ def _call(*args, **kwargs):
args = args[:2]
if kwargs.get("cond", None) is not None:
encoder_hidden_states = kwargs.pop("cond")
return s(
out = s(
*args,
encoder_hidden_states=encoder_hidden_states, # type: ignore
return_dict=True,
**kwargs,
)[0]
)

if isinstance(out, ControlNetOutput):
down_block_res_samples = out.down_block_res_samples
mid_block_res_sample = out.mid_block_res_sample

return down_block_res_samples, mid_block_res_sample
else:
return out[0]

for i, t in enumerate(tqdm(timesteps, desc="PyTorch")):
latents = do_denoise(latents, t, _call, change) # type: ignore
Expand Down
4 changes: 1 addition & 3 deletions tests/inference/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,7 @@ def test_inpaint(pipe: PyTorchStableDiffusion):
pipe.generate(job)


@pytest.mark.parametrize(
"scheduler", [KarrasDiffusionSchedulers.UniPCMultistepScheduler, "dpmpp_2m"]
)
@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers) + ["dpmpp_2m"])
def test_controlnet(pipe: PyTorchStableDiffusion, scheduler):
"Generate an image with ControlNet Image to Image"

Expand Down

0 comments on commit aa428db

Please sign in to comment.