From cd6ca9df2987c000b28e13b19bd4eec3ef3c914b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 21 Nov 2024 13:02:31 +0530 Subject: [PATCH 1/8] Fix prepare latent image ids and vae sample generators for flux (#9981) * fix * update expected slice --- src/diffusers/pipelines/flux/pipeline_flux.py | 2 +- .../flux/pipeline_flux_controlnet.py | 20 ++++++++++++++++--- ...pipeline_flux_controlnet_image_to_image.py | 4 ++-- .../pipeline_flux_controlnet_inpainting.py | 4 ++-- .../controlnet_flux/test_controlnet_flux.py | 2 +- 5 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 12996f3f3e92..e0add1e60ce2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -513,7 +513,7 @@ def prepare_latents( shape = (batch_size, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents.to(device=device, dtype=dtype), latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 904173852ee4..654bc41af4d0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -97,6 +97,20 @@ def calculate_shift( return mu +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -512,7 +526,7 @@ def prepare_latents( shape = (batch_size, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents.to(device=device, dtype=dtype), latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: @@ -772,7 +786,7 @@ def __call__( controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True if self.controlnet.input_hint_block is None: # vae encode - control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor # pack @@ -810,7 +824,7 @@ def __call__( if self.controlnet.nets[0].input_hint_block is None: # vae encode - control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) 
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor # pack diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 5d65df0b768e..6ab34d8a9c08 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -801,7 +801,7 @@ def __call__( ) height, width = control_image.shape[-2:] - control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor height_control_image, width_control_image = control_image.shape[2:] @@ -832,7 +832,7 @@ def __call__( ) height, width = control_image_.shape[-2:] - control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor height_control_image, width_control_image = control_image_.shape[2:] diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 5d5c8f73762c..d81cffaca35b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -942,7 +942,7 @@ def __call__( controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True if self.controlnet.input_hint_block is None: # vae encode - control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor # pack @@ -979,7 +979,7 @@ def __call__( if self.controlnet.nets[0].input_hint_block is None: # vae encode - control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor # pack diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index ee3984dcd3e2..8202424e7f15 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -170,7 +170,7 @@ def test_controlnet_flux(self): assert image.shape == (1, 32, 32, 3) expected_slice = np.array( - [0.7348633, 0.41333008, 0.6621094, 0.5444336, 0.47607422, 0.5859375, 0.44677734, 0.4506836, 0.40454102] + [0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156] ) assert ( From 2e86a3f0235cb41b212417d84b9c2cd46d8c1297 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 22 Nov 2024 12:45:21 +0530 Subject: [PATCH 2/8] [Tests] skip nan lora tests on PyTorch 2.5.1 CPU. (#9975) * skip nan lora tests on PyTorch 2.5.1 CPU. 
* cog * use xfail * correct xfail * add condition * tests --- tests/lora/test_lora_layers_cogvideox.py | 7 +++++++ tests/lora/test_lora_layers_mochi.py | 7 +++++++ tests/lora/utils.py | 7 +++++++ 3 files changed, 21 insertions(+) diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index c141ebc96b3e..623b06621d66 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -16,6 +16,7 @@ import unittest import numpy as np +import pytest import torch from transformers import AutoTokenizer, T5EncoderModel @@ -29,6 +30,7 @@ from diffusers.utils.testing_utils import ( floats_tensor, is_peft_available, + is_torch_version, require_peft_backend, skip_mps, torch_device, @@ -126,6 +128,11 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs @skip_mps + @pytest.mark.xfail( + condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), + reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", + strict=True, + ) def test_lora_fuse_nan(self): for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index eb15124601c6..910b126c147b 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -16,6 +16,7 @@ import unittest import numpy as np +import pytest import torch from transformers import AutoTokenizer, T5EncoderModel @@ -23,6 +24,7 @@ from diffusers.utils.testing_utils import ( floats_tensor, is_peft_available, + is_torch_version, require_peft_backend, skip_mps, torch_device, @@ -105,6 +107,11 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs + @pytest.mark.xfail( + condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), + reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", + strict=True, + ) def test_lora_fuse_nan(self): for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 7cdb2d6f51d7..d8dc86d57007 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -19,6 +19,7 @@ from itertools import product import numpy as np +import pytest import torch from diffusers import ( @@ -32,6 +33,7 @@ from diffusers.utils.testing_utils import ( CaptureLogger, floats_tensor, + is_torch_version, require_peft_backend, require_peft_version_greater, require_transformers_version_greater, @@ -1510,6 +1512,11 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): ) @skip_mps + @pytest.mark.xfail( + condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), + reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", + strict=True, + ) def test_lora_fuse_nan(self): for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) From 64b3e0f5390728f62887be7820a5e2724d0fb419 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Fri, 22 Nov 2024 18:02:54 +0800 Subject: [PATCH 3/8] make `pipelines` tests device-agnostic (part1) (#9399) * enable on xpu * add 1 more * add one more * enable more * add 1 more * add more * enable 1 * enable more cases * enable *
enable * update comment * one more * enable 1 * add more cases * enable xpu * add one more case * add more cases * add 1 * add more * add more cases * add case * enable * add more * add more * add more * enable more * add more * update code * update test marker * add skip back * update comment * remove single files * remove * style * add * revert * reformat * update decorator * update * update * update * Update tests/pipelines/deepfloyd_if/test_if.py Co-authored-by: Dhruv Nair * Update src/diffusers/utils/testing_utils.py Co-authored-by: Dhruv Nair * Update tests/pipelines/animatediff/test_animatediff_controlnet.py Co-authored-by: Dhruv Nair * Update tests/pipelines/animatediff/test_animatediff.py Co-authored-by: Dhruv Nair * Update tests/pipelines/animatediff/test_animatediff_controlnet.py Co-authored-by: Dhruv Nair * update float16 * no unitest.skipt * update * apply style check * reapply format --------- Co-authored-by: Sayak Paul Co-authored-by: Dhruv Nair --- src/diffusers/utils/testing_utils.py | 8 +++ tests/pipelines/amused/test_amused.py | 4 +- tests/pipelines/amused/test_amused_img2img.py | 4 +- tests/pipelines/amused/test_amused_inpaint.py | 4 +- .../pipelines/animatediff/test_animatediff.py | 18 +++-- .../test_animatediff_controlnet.py | 12 ++-- .../animatediff/test_animatediff_sdxl.py | 12 ++-- .../test_animatediff_sparsectrl.py | 10 +-- .../test_animatediff_video2video.py | 12 ++-- ...test_animatediff_video2video_controlnet.py | 10 +-- .../controlnet/test_controlnet_sdxl.py | 2 +- .../controlnet_xs/test_controlnetxs.py | 11 +-- tests/pipelines/deepfloyd_if/test_if.py | 12 +++- .../pipelines/deepfloyd_if/test_if_img2img.py | 16 ++++- .../test_if_img2img_superresolution.py | 13 +++- .../deepfloyd_if/test_if_inpainting.py | 13 +++- .../test_if_inpainting_superresolution.py | 13 +++- .../deepfloyd_if/test_if_superresolution.py | 13 +++- .../test_latent_diffusion_superresolution.py | 3 +- tests/pipelines/pag/test_pag_animatediff.py | 12 ++-- tests/pipelines/pia/test_pia.py | 12 ++-- .../test_semantic_diffusion.py | 3 +- ...test_stable_diffusion_attend_and_excite.py | 6 +- .../test_stable_diffusion_depth.py | 17 ++--- .../test_stable_diffusion_upscale.py | 3 +- .../test_stable_diffusion_v_pred.py | 3 +- .../test_safe_diffusion.py | 4 +- .../test_stable_video_diffusion.py | 33 +++++---- tests/pipelines/test_pipelines_common.py | 68 +++++++++---------- .../test_text_to_video_zero_sdxl.py | 34 ++++++---- 30 files changed, 229 insertions(+), 156 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 03b9c3752922..b3e381f7d3fb 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -373,6 +373,14 @@ def require_note_seq(test_case): return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case) +def require_accelerator(test_case): + """ + Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when no + hardware accelerator is available. + """ + return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case) + + def require_torchsde(test_case): """ Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py index 32f3e13ad911..f28d8708d309 100644 --- a/tests/pipelines/amused/test_amused.py +++ b/tests/pipelines/amused/test_amused.py @@ -22,7 +22,7 @@ from diffusers import AmusedPipeline, AmusedScheduler, UVit2DModel, VQModel from diffusers.utils.testing_utils import ( enable_full_determinism, - require_torch_gpu, + require_torch_accelerator, slow, torch_device, ) @@ -129,7 +129,7 @@ def test_inference_batch_single_identical(self): @slow -@require_torch_gpu +@require_torch_accelerator class AmusedPipelineSlowTests(unittest.TestCase): def test_amused_256(self): pipe = AmusedPipeline.from_pretrained("amused/amused-256") diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py index c647a5aa304e..2699bbe7f56f 100644 --- a/tests/pipelines/amused/test_amused_img2img.py +++ b/tests/pipelines/amused/test_amused_img2img.py @@ -23,7 +23,7 @@ from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, - require_torch_gpu, + require_torch_accelerator, slow, torch_device, ) @@ -131,7 +131,7 @@ def test_inference_batch_single_identical(self): @slow -@require_torch_gpu +@require_torch_accelerator class AmusedImg2ImgPipelineSlowTests(unittest.TestCase): def test_amused_256(self): pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256") diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py index 4a8d501450bb..645379a7eab1 100644 --- a/tests/pipelines/amused/test_amused_inpaint.py +++ b/tests/pipelines/amused/test_amused_inpaint.py @@ -23,7 +23,7 @@ from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, - require_torch_gpu, + require_torch_accelerator, slow, torch_device, ) @@ -135,7 +135,7 @@ def test_inference_batch_single_identical(self): @slow -@require_torch_gpu +@require_torch_accelerator class AmusedInpaintPipelineSlowTests(unittest.TestCase): def test_amused_256(self): pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256") diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 54c83d6a1b68..c382bb5b7f30 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -19,7 +19,13 @@ ) from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available, logging -from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device +from diffusers.utils.testing_utils import ( + numpy_cosine_similarity_distance, + require_accelerator, + require_torch_gpu, + slow, + torch_device, +) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( @@ -272,7 +278,7 @@ def test_inference_batch_single_identical( max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() assert max_diff < expected_max_diff - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -288,14 +294,14 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ 
component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] - self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + output_device = pipe(**self.get_dummy_inputs(torch_device))[0] + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) def test_to_dtype(self): components = self.get_dummy_components() diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py index 519d848c6dc2..6fcf6fe44fb7 100644 --- a/tests/pipelines/animatediff/test_animatediff_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py @@ -21,7 +21,7 @@ from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import torch_device +from diffusers.utils.testing_utils import require_accelerator, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( @@ -281,7 +281,7 @@ def test_inference_batch_single_identical( max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() assert max_diff < expected_max_diff - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -297,14 +297,14 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] - self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + output_device = pipe(**self.get_dummy_inputs(torch_device))[0] + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) def test_to_dtype(self): components = self.get_dummy_components() diff --git a/tests/pipelines/animatediff/test_animatediff_sdxl.py b/tests/pipelines/animatediff/test_animatediff_sdxl.py index 2db0139154e9..45fa6bfc5c6d 100644 --- a/tests/pipelines/animatediff/test_animatediff_sdxl.py +++ b/tests/pipelines/animatediff/test_animatediff_sdxl.py @@ -14,7 +14,7 @@ UNetMotionModel, ) from diffusers.utils import is_xformers_available, logging -from diffusers.utils.testing_utils import torch_device +from diffusers.utils.testing_utils import require_accelerator, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( @@ -212,7 +212,7 @@ def test_inference_batch_single_identical( max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() assert max_diff < expected_max_diff - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -228,14 +228,14 @@ def test_to_device(self): 
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] - self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + output_device = pipe(**self.get_dummy_inputs(torch_device))[0] + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) def test_to_dtype(self): components = self.get_dummy_components() diff --git a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py index 189d6765de4f..21b59d0252b2 100644 --- a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py +++ b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py @@ -20,7 +20,7 @@ ) from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import torch_device +from diffusers.utils.testing_utils import require_accelerator, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( @@ -345,7 +345,7 @@ def test_inference_batch_single_identical_use_simplified_condition_embedding_tru max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() assert max_diff < expected_max_diff - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -361,13 +361,13 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] + output_cuda = pipe(**self.get_dummy_inputs(torch_device))[0] self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) def test_to_dtype(self): diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index c3fd4c73736a..bb1cb9882c69 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -19,7 +19,7 @@ ) from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available, logging -from diffusers.utils.testing_utils import torch_device +from diffusers.utils.testing_utils import require_accelerator, torch_device from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin @@ -258,7 +258,7 @@ def test_inference_batch_single_identical( max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() assert max_diff < expected_max_diff - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def 
test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -274,14 +274,14 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] - self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + output_device = pipe(**self.get_dummy_inputs(torch_device))[0] + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) def test_to_dtype(self): components = self.get_dummy_components() diff --git a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py index 5e598e67ec11..5a4b507aff50 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py @@ -20,7 +20,7 @@ ) from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available, logging -from diffusers.utils.testing_utils import torch_device +from diffusers.utils.testing_utils import require_accelerator, torch_device from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin @@ -274,7 +274,7 @@ def test_inference_batch_single_identical( max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() assert max_diff < expected_max_diff - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -290,13 +290,13 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] + output_cuda = pipe(**self.get_dummy_inputs(torch_device))[0] self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) def test_to_dtype(self): diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index c931391ac4d5..ea7fff5537a5 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -1019,7 +1019,7 @@ def test_conditioning_channels(self): ) controlnet = ControlNetModel.from_unet(unet, conditioning_channels=4) - assert type(controlnet.mid_block) == UNetMidBlock2D + assert type(controlnet.mid_block) is UNetMidBlock2D assert controlnet.conditioning_channels == 4 def get_dummy_components(self, time_cond_proj_dim=None): diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index bb0306741fdb..007a2b0e46d7 100644 --- 
a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -38,6 +38,7 @@ is_torch_compile, load_image, load_numpy, + require_accelerator, require_torch_2, require_torch_gpu, run_test_in_subprocess, @@ -306,7 +307,7 @@ def test_multi_vae(self): assert out_vae_np.shape == out_np.shape - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -322,14 +323,14 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] - self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + output_device = pipe(**self.get_dummy_inputs(torch_device))[0] + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) @slow diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py index 0818665ea113..13a05855f145 100644 --- a/tests/pipelines/deepfloyd_if/test_if.py +++ b/tests/pipelines/deepfloyd_if/test_if.py @@ -23,7 +23,14 @@ ) from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, skip_mps, slow, torch_device +from diffusers.utils.testing_utils import ( + load_numpy, + require_accelerator, + require_torch_gpu, + skip_mps, + slow, + torch_device, +) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference @@ -58,7 +65,8 @@ def get_dummy_inputs(self, device, seed=0): def test_save_load_optional_components(self): self._test_save_load_optional_components() - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder super().test_save_load_float16(expected_max_diff=1e-1) diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py index b71cb05e50ae..26ac42831b8b 100644 --- a/tests/pipelines/deepfloyd_if/test_if_img2img.py +++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py @@ -22,7 +22,15 @@ from diffusers import IFImg2ImgPipeline from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device +from diffusers.utils.testing_utils import ( + floats_tensor, + load_numpy, + require_accelerator, + require_torch_gpu, + skip_mps, + slow, + torch_device, +) from ..pipeline_params import ( TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, @@ -70,12 +78,14 @@ def test_save_load_optional_components(self): def test_xformers_attention_forwardGenerator_pass(self): 
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3) - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder super().test_save_load_float16(expected_max_diff=1e-1) - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_float16_inference(self): super().test_float16_inference(expected_max_diff=1e-1) diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py index dc0cf9826b62..1d1244c96c33 100644 --- a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py +++ b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py @@ -22,7 +22,15 @@ from diffusers import IFImg2ImgSuperResolutionPipeline from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device +from diffusers.utils.testing_utils import ( + floats_tensor, + load_numpy, + require_accelerator, + require_torch_gpu, + skip_mps, + slow, + torch_device, +) from ..pipeline_params import ( TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, @@ -72,7 +80,8 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_save_load_optional_components(self): self._test_save_load_optional_components() - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder super().test_save_load_float16(expected_max_diff=1e-1) diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting.py b/tests/pipelines/deepfloyd_if/test_if_inpainting.py index df0cecd8c307..1c4f27403332 100644 --- a/tests/pipelines/deepfloyd_if/test_if_inpainting.py +++ b/tests/pipelines/deepfloyd_if/test_if_inpainting.py @@ -22,7 +22,15 @@ from diffusers import IFInpaintingPipeline from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device +from diffusers.utils.testing_utils import ( + floats_tensor, + load_numpy, + require_accelerator, + require_torch_gpu, + skip_mps, + slow, + torch_device, +) from ..pipeline_params import ( TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, @@ -72,7 +80,8 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_save_load_optional_components(self): self._test_save_load_optional_components() - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder super().test_save_load_float16(expected_max_diff=1e-1) diff --git 
a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py index 2e9f64773289..fc1b04aacb9b 100644 --- a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py +++ b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py @@ -22,7 +22,15 @@ from diffusers import IFInpaintingSuperResolutionPipeline from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device +from diffusers.utils.testing_utils import ( + floats_tensor, + load_numpy, + require_accelerator, + require_torch_gpu, + skip_mps, + slow, + torch_device, +) from ..pipeline_params import ( TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, @@ -74,7 +82,8 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_save_load_optional_components(self): self._test_save_load_optional_components() - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder super().test_save_load_float16(expected_max_diff=1e-1) diff --git a/tests/pipelines/deepfloyd_if/test_if_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_superresolution.py index 2e3c8c6e0e15..bdb9f8a76d8a 100644 --- a/tests/pipelines/deepfloyd_if/test_if_superresolution.py +++ b/tests/pipelines/deepfloyd_if/test_if_superresolution.py @@ -22,7 +22,15 @@ from diffusers import IFSuperResolutionPipeline from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device +from diffusers.utils.testing_utils import ( + floats_tensor, + load_numpy, + require_accelerator, + require_torch_gpu, + skip_mps, + slow, + torch_device, +) from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference @@ -67,7 +75,8 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_save_load_optional_components(self): self._test_save_load_optional_components() - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder super().test_save_load_float16(expected_max_diff=1e-1) diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py index 9b9a8ef65572..38ac6a46ccca 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py @@ -26,6 +26,7 @@ floats_tensor, load_image, nightly, + require_accelerator, require_torch, torch_device, ) @@ -93,7 +94,7 @@ def test_inference_superresolution(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - @unittest.skipIf(torch_device 
!= "cuda", "This test requires a GPU") + @require_accelerator def test_inference_superresolution_fp16(self): unet = self.dummy_uncond_unet scheduler = DDIMScheduler() diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 7efe8002d17c..59ce9cc0a987 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -19,7 +19,7 @@ ) from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available -from diffusers.utils.testing_utils import torch_device +from diffusers.utils.testing_utils import require_accelerator, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( @@ -218,7 +218,7 @@ def test_dict_tuple_outputs_equivalent(self): expected_slice = np.array([0.5295, 0.3947, 0.5300, 0.4864, 0.4518, 0.5315, 0.5440, 0.4775, 0.5538]) return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice) - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -234,14 +234,14 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] - self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + output_device = pipe(**self.get_dummy_inputs(torch_device))[0] + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) def test_to_dtype(self): components = self.get_dummy_components() diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index ca558fbb83e5..e461860eff65 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -18,7 +18,7 @@ UNetMotionModel, ) from diffusers.utils import is_xformers_available, logging -from diffusers.utils.testing_utils import floats_tensor, torch_device +from diffusers.utils.testing_utils import floats_tensor, require_accelerator, torch_device from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin @@ -278,7 +278,7 @@ def test_inference_batch_single_identical( max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max() assert max_diff < expected_max_diff - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -294,14 +294,14 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] - self.assertTrue(np.isnan(to_np(output_cuda)).sum() 
== 0) + output_device = pipe(**self.get_dummy_inputs(torch_device))[0] + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) def test_to_dtype(self): components = self.get_dummy_components() diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py index 990c389a9c5f..6cd431f02d58 100644 --- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py +++ b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py @@ -28,6 +28,7 @@ enable_full_determinism, floats_tensor, nightly, + require_accelerator, require_torch_gpu, torch_device, ) @@ -237,7 +238,7 @@ def test_semantic_diffusion_no_safety_checker(self): image = pipe("example prompt", num_inference_steps=2).images[0] assert image is not None - @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + @require_accelerator def test_semantic_diffusion_fp16(self): """Test that stable diffusion works with fp16""" unet = self.dummy_cond_unet diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index 4c2b3a3c1e85..1caad9500b24 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -30,7 +30,7 @@ load_numpy, nightly, numpy_cosine_similarity_distance, - require_torch_gpu, + require_torch_accelerator, skip_mps, torch_device, ) @@ -205,7 +205,7 @@ def test_from_pipe_consistent_forward_pass_cpu_offload(self): super().test_from_pipe_consistent_forward_pass_cpu_offload(expected_max_diff=5e-3) -@require_torch_gpu +@require_torch_accelerator @nightly class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase): # Attend and excite requires being able to run a backward pass at @@ -237,7 +237,7 @@ def test_attend_and_excite_fp16(self): pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 ) - pipe.to("cuda") + pipe.to(torch_device) prompt = "a painting of an elephant with glasses" token_indices = [5, 7] diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 42eef061069e..01a0a3abe4ee 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -36,13 +36,14 @@ StableDiffusionDepth2ImgPipeline, UNet2DConditionModel, ) -from diffusers.utils import is_accelerate_available, is_accelerate_version from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, load_image, load_numpy, nightly, + require_accelerate_version_greater, + require_accelerator, require_torch_gpu, skip_mps, slow, @@ -194,7 +195,8 @@ def test_save_load_local(self): max_diff = np.abs(output - output_loaded).max() self.assertLess(max_diff, 1e-4) - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self): components = self.get_dummy_components() for name, module in components.items(): @@ -226,7 +228,8 @@ def test_save_load_float16(self): max_diff = np.abs(output - output_loaded).max() self.assertLess(max_diff, 2e-2, "The output of the fp16 
pipeline changed after saving and loading.") - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_float16_inference(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -246,10 +249,8 @@ def test_float16_inference(self): max_diff = np.abs(output - output_fp16).max() self.assertLess(max_diff, 1.3e-2, "The outputs of the fp16 and fp32 pipelines are too different.") - @unittest.skipIf( - torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), - reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher", - ) + @require_accelerator + @require_accelerate_version_greater("0.14.0") def test_cpu_offload_forward_pass(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -259,7 +260,7 @@ def test_cpu_offload_forward_pass(self): inputs = self.get_dummy_inputs(torch_device) output_without_offload = pipe(**inputs)[0] - pipe.enable_sequential_cpu_offload() + pipe.enable_sequential_cpu_offload(device=torch_device) inputs = self.get_dummy_inputs(torch_device) output_with_offload = pipe(**inputs)[0] diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py index c21da7af6d2c..4b04169a270b 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py @@ -29,6 +29,7 @@ floats_tensor, load_image, load_numpy, + require_accelerator, require_torch_gpu, slow, torch_device, @@ -289,7 +290,7 @@ def test_stable_diffusion_upscale_prompt_embeds(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_prompt_embeds_slice.flatten() - expected_slice).max() < 1e-2 - @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + @require_accelerator def test_stable_diffusion_upscale_fp16(self): """Test that stable diffusion upscale works with fp16""" unet = self.dummy_cond_unet_upscale diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index 703c3b7a39d8..d69d1c492548 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -34,6 +34,7 @@ enable_full_determinism, load_numpy, numpy_cosine_similarity_distance, + require_accelerator, require_torch_gpu, slow, torch_device, @@ -213,7 +214,7 @@ def test_stable_diffusion_v_pred_k_euler(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + @require_accelerator def test_stable_diffusion_v_pred_fp16(self): """Test that stable diffusion v-prediction works with fp16""" unet = self.dummy_cond_unet diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py index ccb20a1c218e..269677c08345 100644 --- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py +++ b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py @@ -24,7 +24,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, 
UNet2DConditionModel from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline -from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, torch_device +from diffusers.utils.testing_utils import floats_tensor, nightly, require_accelerator, require_torch_gpu, torch_device class SafeDiffusionPipelineFastTests(unittest.TestCase): @@ -228,7 +228,7 @@ def test_stable_diffusion_no_safety_checker(self): image = pipe("example prompt", num_inference_steps=2).images[0] assert image is not None - @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + @require_accelerator def test_stable_diffusion_fp16(self): """Test that stable diffusion works with fp16""" unet = self.dummy_cond_unet diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py index 60fc21e2027b..ac9acb26afd3 100644 --- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py +++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py @@ -18,13 +18,15 @@ StableVideoDiffusionPipeline, UNetSpatioTemporalConditionModel, ) -from diffusers.utils import is_accelerate_available, is_accelerate_version, load_image, logging +from diffusers.utils import load_image, logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( CaptureLogger, enable_full_determinism, floats_tensor, numpy_cosine_similarity_distance, + require_accelerate_version_greater, + require_accelerator, require_torch_gpu, slow, torch_device, @@ -250,7 +252,8 @@ def test_float16_inference(self, expected_max_diff=5e-2): max_diff = np.abs(to_np(output) - to_np(output_fp16)).max() self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.") - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self, expected_max_diff=1e-2): components = self.get_dummy_components() for name, module in components.items(): @@ -366,7 +369,7 @@ def test_save_load_local(self, expected_max_difference=9e-4): max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, expected_max_difference) - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -381,14 +384,14 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu")).frames[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda")).frames[0] - self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + output_device = pipe(**self.get_dummy_inputs(torch_device)).frames[0] + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) def test_to_dtype(self): components = self.get_dummy_components() @@ -402,10 +405,8 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in pipe.components.values() 
if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) - @unittest.skipIf( - torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), - reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher", - ) + @require_accelerator + @require_accelerate_version_greater("0.14.0") def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -419,7 +420,7 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): inputs = self.get_dummy_inputs(generator_device) output_without_offload = pipe(**inputs).frames[0] - pipe.enable_sequential_cpu_offload() + pipe.enable_sequential_cpu_offload(device=torch_device) inputs = self.get_dummy_inputs(generator_device) output_with_offload = pipe(**inputs).frames[0] @@ -427,10 +428,8 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results") - @unittest.skipIf( - torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"), - reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher", - ) + @require_accelerator + @require_accelerate_version_greater("0.17.0") def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): generator_device = "cpu" components = self.get_dummy_components() @@ -446,7 +445,7 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): inputs = self.get_dummy_inputs(generator_device) output_without_offload = pipe(**inputs).frames[0] - pipe.enable_model_cpu_offload() + pipe.enable_model_cpu_offload(device=torch_device) inputs = self.get_dummy_inputs(generator_device) output_with_offload = pipe(**inputs).frames[0] diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 12f31aec678b..7ec677558059 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -38,9 +38,11 @@ from diffusers.pipelines.pipeline_utils import StableDiffusionMixin from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging -from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available +from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( CaptureLogger, + require_accelerate_version_greater, + require_accelerator, require_torch, skip_mps, torch_device, @@ -770,17 +772,15 @@ def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3): type(proc) == AttnProcessor for proc in component.attn_processors.values() ), "`from_pipe` changed the attention processor in original pipeline." 
- @unittest.skipIf( - torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), - reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher", - ) + @require_accelerator + @require_accelerate_version_greater("0.14.0") def test_from_pipe_consistent_forward_pass_cpu_offload(self, expected_max_diff=1e-3): components = self.get_dummy_components() pipe = self.pipeline_class(**components) for component in pipe.components.values(): if hasattr(component, "set_default_attn_processor"): component.set_default_attn_processor() - pipe.enable_model_cpu_offload() + pipe.enable_model_cpu_offload(device=torch_device) pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs_pipe(torch_device) output = pipe(**inputs)[0] @@ -815,7 +815,7 @@ def test_from_pipe_consistent_forward_pass_cpu_offload(self, expected_max_diff=1 if hasattr(component, "set_default_attn_processor"): component.set_default_attn_processor() - pipe_from_original.enable_model_cpu_offload() + pipe_from_original.enable_model_cpu_offload(device=torch_device) pipe_from_original.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs_pipe(torch_device) output_from_original = pipe_from_original(**inputs)[0] @@ -1202,7 +1202,8 @@ def test_components_function(self): self.assertTrue(hasattr(pipe, "components")) self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_float16_inference(self, expected_max_diff=5e-2): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -1239,7 +1240,8 @@ def test_float16_inference(self, expected_max_diff=5e-2): max_diff = np.abs(to_np(output) - to_np(output_fp16)).max() self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.") - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self, expected_max_diff=1e-2): components = self.get_dummy_components() for name, module in components.items(): @@ -1320,7 +1322,7 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, expected_max_difference) - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -1333,11 +1335,11 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [component.device.type for component in components.values() if hasattr(component, "device")] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0] + output_cuda = pipe(**self.get_dummy_inputs(torch_device))[0] self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) def test_to_dtype(self): @@ -1394,10 +1396,8 @@ def _test_attention_slicing_forward_pass( 
assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0])) assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0])) - @unittest.skipIf( - torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), - reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher", - ) + @require_accelerator + @require_accelerate_version_greater("0.14.0") def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): import accelerate @@ -1413,8 +1413,8 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): inputs = self.get_dummy_inputs(generator_device) output_without_offload = pipe(**inputs)[0] - pipe.enable_sequential_cpu_offload() - assert pipe._execution_device.type == "cuda" + pipe.enable_sequential_cpu_offload(device=torch_device) + assert pipe._execution_device.type == torch_device inputs = self.get_dummy_inputs(generator_device) output_with_offload = pipe(**inputs)[0] @@ -1457,10 +1457,8 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", ) - @unittest.skipIf( - torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"), - reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher", - ) + @require_accelerator + @require_accelerate_version_greater("0.17.0") def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): import accelerate @@ -1478,8 +1476,8 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): inputs = self.get_dummy_inputs(generator_device) output_without_offload = pipe(**inputs)[0] - pipe.enable_model_cpu_offload() - assert pipe._execution_device.type == "cuda" + pipe.enable_model_cpu_offload(device=torch_device) + assert pipe._execution_device.type == torch_device inputs = self.get_dummy_inputs(generator_device) output_with_offload = pipe(**inputs)[0] @@ -1514,10 +1512,8 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", ) - @unittest.skipIf( - torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"), - reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher", - ) + @require_accelerator + @require_accelerate_version_greater("0.17.0") def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4): import accelerate @@ -1531,11 +1527,11 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4): pipe.set_progress_bar_config(disable=None) - pipe.enable_model_cpu_offload() + pipe.enable_model_cpu_offload(device=torch_device) inputs = self.get_dummy_inputs(generator_device) output_with_offload = pipe(**inputs)[0] - pipe.enable_model_cpu_offload() + pipe.enable_model_cpu_offload(device=torch_device) inputs = self.get_dummy_inputs(generator_device) output_with_offload_twice = pipe(**inputs)[0] @@ -1571,10 +1567,8 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4): f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", ) - @unittest.skipIf( - torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), - reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher", - ) + @require_accelerator + @require_accelerate_version_greater("0.14.0") def 
test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4): import accelerate @@ -1588,11 +1582,11 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4): pipe.set_progress_bar_config(disable=None) - pipe.enable_sequential_cpu_offload() + pipe.enable_sequential_cpu_offload(device=torch_device) inputs = self.get_dummy_inputs(generator_device) output_with_offload = pipe(**inputs)[0] - pipe.enable_sequential_cpu_offload() + pipe.enable_sequential_cpu_offload(device=torch_device) inputs = self.get_dummy_inputs(generator_device) output_with_offload_twice = pipe(**inputs)[0] diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py index 8bef0cede154..db24767b60fc 100644 --- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py +++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py @@ -23,8 +23,14 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoZeroSDXLPipeline, UNet2DConditionModel -from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version -from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device +from diffusers.utils.testing_utils import ( + enable_full_determinism, + nightly, + require_accelerate_version_greater, + require_accelerator, + require_torch_gpu, + torch_device, +) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin @@ -213,7 +219,8 @@ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4): max_diff = np.abs(to_np(output) - to_np(output_tuple)).max() self.assertLess(max_diff, expected_max_difference) - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_float16_inference(self, expected_max_diff=5e-2): components = self.get_dummy_components() for name, module in components.items(): @@ -255,10 +262,8 @@ def test_inference_batch_consistent(self): def test_inference_batch_single_identical(self): pass - @unittest.skipIf( - torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"), - reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher", - ) + @require_accelerator + @require_accelerate_version_greater("0.17.0") def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -268,7 +273,7 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): inputs = self.get_dummy_inputs(self.generator_device) output_without_offload = pipe(**inputs)[0] - pipe.enable_model_cpu_offload() + pipe.enable_model_cpu_offload(device=torch_device) inputs = self.get_dummy_inputs(self.generator_device) output_with_offload = pipe(**inputs)[0] @@ -279,7 +284,8 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): def test_pipeline_call_signature(self): pass - @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator 
def test_save_load_float16(self, expected_max_diff=1e-2): components = self.get_dummy_components() for name, module in components.items(): @@ -331,7 +337,7 @@ def test_save_load_optional_components(self): def test_sequential_cpu_offload_forward_pass(self): pass - @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") + @require_accelerator def test_to_device(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -344,12 +350,12 @@ def test_to_device(self): output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu self.assertTrue(np.isnan(output_cpu).sum() == 0) - pipe.to("cuda") + pipe.to(torch_device) model_devices = [component.device.type for component in components.values() if hasattr(component, "device")] - self.assertTrue(all(device == "cuda" for device in model_devices)) + self.assertTrue(all(device == torch_device for device in model_devices)) - output_cuda = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu - self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0) + output_device = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu + self.assertTrue(np.isnan(to_np(output_device)).sum() == 0) @unittest.skip( reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor." From b5fd6f13f5434d69d919cc8cedf0b11db664cf06 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 22 Nov 2024 12:22:52 +0000 Subject: [PATCH 4/8] ControlNet from_single_file when already converted (#9978) Co-authored-by: Dhruv Nair --- src/diffusers/loaders/single_file_utils.py | 27 +++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index d1bad8b5a7cd..9a460cb5d1ef 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -62,7 +62,14 @@ "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias", "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias", "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias", - "controlnet": "control_model.time_embed.0.weight", + "controlnet": [ + "control_model.time_embed.0.weight", + "controlnet_cond_embedding.conv_in.weight", + ], + # TODO: find non-Diffusers keys for controlnet_xl + "controlnet_xl": "add_embedding.linear_1.weight", + "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight", "playground-v2-5": "edm_mean", "inpainting": "model.diffusion_model.input_blocks.0.0.weight", "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", @@ -96,6 +103,9 @@ "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"}, "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"}, "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"}, + "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"}, + "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"}, + "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"}, "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"}, "v1": {"pretrained_model_name_or_path": 
"stable-diffusion-v1-5/stable-diffusion-v1-5"}, "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, @@ -481,8 +491,16 @@ def infer_diffusers_model_type(checkpoint): elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint: model_type = "upscale" - elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint: - model_type = "controlnet" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]): + if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint: + if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint: + model_type = "controlnet_xl_large" + elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint: + model_type = "controlnet_xl_mid" + else: + model_type = "controlnet_xl_small" + else: + model_type = "controlnet" elif ( CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint @@ -1072,6 +1090,9 @@ def convert_controlnet_checkpoint( config, **kwargs, ): + # Return checkpoint if it's already been converted + if "time_embedding.linear_1.weight" in checkpoint: + return checkpoint # Some controlnet ckpt files are distributed independently from the rest of the # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ if "time_embed.0.weight" in checkpoint: From 7ac6e286ee994270e737b70c904ea50049d53567 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 23 Nov 2024 17:11:25 +0530 Subject: [PATCH 5/8] Flux Fill, Canny, Depth, Redux (#9985) * update --------- Co-authored-by: yiyixuxu Co-authored-by: Sayak Paul --- docs/source/en/api/pipelines/flux.md | 153 ++- scripts/convert_flux_to_diffusers.py | 7 +- src/diffusers/__init__.py | 10 + .../models/transformers/transformer_flux.py | 7 +- src/diffusers/pipelines/__init__.py | 10 + src/diffusers/pipelines/flux/__init__.py | 12 +- src/diffusers/pipelines/flux/modeling_flux.py | 47 + .../pipelines/flux/pipeline_flux_control.py | 891 ++++++++++++++++ .../flux/pipeline_flux_control_img2img.py | 946 +++++++++++++++++ .../flux/pipeline_flux_controlnet.py | 1 + .../pipelines/flux/pipeline_flux_fill.py | 970 ++++++++++++++++++ .../flux/pipeline_flux_prior_redux.py | 405 ++++++++ .../pipelines/flux/pipeline_output.py | 16 + .../dummy_torch_and_transformers_objects.py | 75 ++ .../flux/test_pipeline_flux_control.py | 203 ++++ .../test_pipeline_flux_control_img2img.py | 168 +++ .../pipelines/flux/test_pipeline_flux_fill.py | 168 +++ .../flux/test_pipeline_flux_redux.py | 108 ++ 18 files changed, 4189 insertions(+), 8 deletions(-) create mode 100644 src/diffusers/pipelines/flux/modeling_flux.py create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_control.py create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_fill.py create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py create mode 100644 tests/pipelines/flux/test_pipeline_flux_control.py create mode 100644 tests/pipelines/flux/test_pipeline_flux_control_img2img.py create mode 100644 tests/pipelines/flux/test_pipeline_flux_fill.py create mode 100644 tests/pipelines/flux/test_pipeline_flux_redux.py diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index 255c69c854bc..011972bc59dd 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -22,12 +22,20 @@ Flux can be quite expensive to run on consumer hardware devices. 
However, you ca -Flux comes in two variants: +Flux comes in the following variants: -* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`) -* Guidance-distilled (`black-forest-labs/FLUX.1-dev`) +| model type | model id | +|:----------:|:--------:| +| Timestep-distilled | [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | +| Guidance-distilled | [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev) | +| Fill Inpainting/Outpainting (Guidance-distilled) | [`black-forest-labs/FLUX.1-Fill-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) | +| Canny Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Canny-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) | +| Depth Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) | +| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) | +| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) | +| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) | -Both checkpoints have slightly difference usage which we detail below. +All checkpoints have different usage which we detail below. ### Timestep-distilled @@ -77,7 +85,132 @@ out = pipe( out.save("image.png") ``` +### Fill Inpainting/Outpainting + +* Flux Fill pipeline does not require `strength` as an input like regular inpainting pipelines. +* It supports both inpainting and outpainting. + +```python +import torch +from diffusers import FluxFillPipeline +from diffusers.utils import load_image + +image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png") +mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png") + +repo_id = "black-forest-labs/FLUX.1-Fill-dev" +pipe = FluxFillPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda") + +image = pipe( + prompt="a white paper cup", + image=image, + mask_image=mask, + height=1632, + width=1232, + max_sequence_length=512, + generator=torch.Generator("cpu").manual_seed(0) +).images[0] +image.save(f"output.png") +``` + +### Canny Control + +**Note:** `black-forest-labs/Flux.1-Canny-dev` is _not_ a [`ControlNetModel`] model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Canny Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. + +```python +# !pip install -U controlnet-aux +import torch +from controlnet_aux import CannyDetector +from diffusers import FluxControlPipeline +from diffusers.utils import load_image + +pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16).to("cuda") + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." 
+control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = CannyDetector() +control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024) + +image = pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=30.0, +).images[0] +image.save("output.png") +``` + +### Depth Control + +**Note:** `black-forest-labs/Flux.1-Depth-dev` is _not_ a ControlNet model. [`ControlNetModel`] models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. + +```python +# !pip install git+https://github.com/asomoza/image_gen_aux.git +import torch +from diffusers import FluxControlPipeline, FluxTransformer2DModel +from diffusers.utils import load_image +from image_gen_aux import DepthPreprocessor + +pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16).to("cuda") + +prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." +control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + +processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") +control_image = processor(control_image)[0].convert("RGB") + +image = pipe( + prompt=prompt, + control_image=control_image, + height=1024, + width=1024, + num_inference_steps=30, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), +).images[0] +image.save("output.png") +``` + +### Redux + +* Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell, for image-to-image generation. +* You can first use the `FluxPriorReduxPipeline` to get the `prompt_embeds` and `pooled_prompt_embeds`, and then feed them into the `FluxPipeline` for image-to-image generation. +* When use `FluxPriorReduxPipeline` with a base pipeline, you can set `text_encoder=None` and `text_encoder_2=None` in the base pipeline, in order to save VRAM. + +```python +import torch +from diffusers import FluxPriorReduxPipeline, FluxPipeline +from diffusers.utils import load_image +device = "cuda" +dtype = torch.bfloat16 + + +repo_redux = "black-forest-labs/FLUX.1-Redux-dev" +repo_base = "black-forest-labs/FLUX.1-dev" +pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device) +pipe = FluxPipeline.from_pretrained( + repo_base, + text_encoder=None, + text_encoder_2=None, + torch_dtype=torch.bfloat16 +).to(device) + +image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png") +pipe_prior_output = pipe_prior_redux(image) +images = pipe( + guidance_scale=2.5, + num_inference_steps=50, + generator=torch.Generator("cpu").manual_seed(0), + **pipe_prior_output, +).images +images[0].save("flux-redux.png") +``` + ## Running FP16 inference + Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. 
The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details. FP16 inference code: @@ -188,3 +321,15 @@ image.save("flux-fp8-dev.png") [[autodoc]] FluxControlNetImg2ImgPipeline - all - __call__ + +## FluxControlPipeline + +[[autodoc]] FluxControlPipeline + - all + - __call__ + +## FluxControlImg2ImgPipeline + +[[autodoc]] FluxControlImg2ImgPipeline + - all + - __call__ diff --git a/scripts/convert_flux_to_diffusers.py b/scripts/convert_flux_to_diffusers.py index 05a1da256d33..33668fed8120 100644 --- a/scripts/convert_flux_to_diffusers.py +++ b/scripts/convert_flux_to_diffusers.py @@ -37,6 +37,8 @@ parser.add_argument("--original_state_dict_repo_id", default=None, type=str) parser.add_argument("--filename", default="flux.safetensors", type=str) parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--in_channels", type=int, default=64) +parser.add_argument("--out_channels", type=int, default=None) parser.add_argument("--vae", action="store_true") parser.add_argument("--transformer", action="store_true") parser.add_argument("--output_path", type=str) @@ -279,10 +281,13 @@ def main(args): num_single_layers = 38 inner_dim = 3072 mlp_ratio = 4.0 + converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers( original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio ) - transformer = FluxTransformer2DModel(guidance_embeds=has_guidance) + transformer = FluxTransformer2DModel( + in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance + ) transformer.load_state_dict(converted_transformer_state_dict, strict=True) print( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d9d7491e5c79..a4749af5f61b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -269,12 +269,16 @@ "CogVideoXVideoToVideoPipeline", "CogView3PlusPipeline", "CycleDiffusionPipeline", + "FluxControlImg2ImgPipeline", "FluxControlNetImg2ImgPipeline", "FluxControlNetInpaintPipeline", "FluxControlNetPipeline", + "FluxControlPipeline", + "FluxFillPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxPipeline", + "FluxPriorReduxPipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", @@ -321,6 +325,7 @@ "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", + "ReduxImageEncoder", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", @@ -734,12 +739,16 @@ CogVideoXVideoToVideoPipeline, CogView3PlusPipeline, CycleDiffusionPipeline, + FluxControlImg2ImgPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, + FluxControlPipeline, + FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, + FluxPriorReduxPipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, @@ -786,6 +795,7 @@ PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, + ReduxImageEncoder, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 0ad3be866019..18527e3c46c0 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ 
b/src/diffusers/models/transformers/transformer_flux.py @@ -238,6 +238,7 @@ def __init__( self, patch_size: int = 1, in_channels: int = 64, + out_channels: Optional[int] = None, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, @@ -248,7 +249,7 @@ def __init__( axes_dims_rope: Tuple[int] = (16, 56, 56), ): super().__init__() - self.out_channels = in_channels + self.out_channels = out_channels or in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) @@ -261,7 +262,7 @@ def __init__( ) self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) - self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) + self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ @@ -449,6 +450,7 @@ def forward( logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) + hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -456,6 +458,7 @@ def forward( guidance = guidance.to(hidden_states.dtype) * 1000 else: guidance = None + temb = ( self.time_text_embed(timestep, pooled_projections) if guidance is None diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 98574de1ad5f..5143b1114fd3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -127,12 +127,17 @@ "AnimateDiffVideoToVideoControlNetPipeline", ] _import_structure["flux"] = [ + "FluxControlPipeline", + "FluxControlImg2ImgPipeline", "FluxControlNetPipeline", "FluxControlNetImg2ImgPipeline", "FluxControlNetInpaintPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxPipeline", + "FluxFillPipeline", + "FluxPriorReduxPipeline", + "ReduxImageEncoder", ] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ @@ -521,12 +526,17 @@ VQDiffusionPipeline, ) from .flux import ( + FluxControlImg2ImgPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, + FluxControlPipeline, + FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, + FluxPriorReduxPipeline, + ReduxImageEncoder, ) from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 0ebf5ea6d78d..3570368a5ca1 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -12,7 +12,7 @@ _dummy_objects = {} _additional_imports = {} -_import_structure = {"pipeline_output": ["FluxPipelineOutput"]} +_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]} try: if not (is_transformers_available() and is_torch_available()): @@ -22,12 +22,17 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["modeling_flux"] = ["ReduxImageEncoder"] _import_structure["pipeline_flux"] = ["FluxPipeline"] + _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"] + _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] 
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] + _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"] _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] + _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -35,12 +40,17 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: + from .modeling_flux import ReduxImageEncoder from .pipeline_flux import FluxPipeline + from .pipeline_flux_control import FluxControlPipeline + from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline from .pipeline_flux_controlnet import FluxControlNetPipeline from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline + from .pipeline_flux_fill import FluxFillPipeline from .pipeline_flux_img2img import FluxImg2ImgPipeline from .pipeline_flux_inpaint import FluxInpaintPipeline + from .pipeline_flux_prior_redux import FluxPriorReduxPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/modeling_flux.py b/src/diffusers/pipelines/flux/modeling_flux.py new file mode 100644 index 000000000000..5ff60f774d19 --- /dev/null +++ b/src/diffusers/pipelines/flux/modeling_flux.py @@ -0,0 +1,47 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from ...utils import BaseOutput + + +@dataclass +class ReduxImageEncoderOutput(BaseOutput): + image_embeds: Optional[torch.Tensor] = None + + +class ReduxImageEncoder(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + redux_dim: int = 1152, + txt_in_features: int = 4096, + ) -> None: + super().__init__() + + self.redux_up = nn.Linear(redux_dim, txt_in_features * 3) + self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features) + + def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput: + projected_x = self.redux_down(nn.functional.silu(self.redux_up(x))) + + return ReduxImageEncoderOutput(image_embeds=projected_x) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py new file mode 100644 index 000000000000..04a93ba6351c --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -0,0 +1,891 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from controlnet_aux import CannyDetector + >>> from diffusers import FluxControlPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxControlPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." + >>> control_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png" + ... ) + + >>> processor = CannyDetector() + >>> control_image = processor( + ... control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024 + ... ) + + >>> image = pipe( + ... prompt=prompt, + ... control_image=control_image, + ... height=1024, + ... width=1024, + ... num_inference_steps=50, + ... guidance_scale=30.0, + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux pipeline for controllable text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 
+ tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_latent_channels = ( + self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.vae_latent_channels + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from 
diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. 
+ """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: 
Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+                will be used instead.
+            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 3.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 8 + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + if control_image.ndim == 4: + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents, control_image], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py new file mode 100644 index 000000000000..ef20ab98ee2e --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -0,0 +1,946 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from controlnet_aux import CannyDetector + >>> from diffusers import FluxControlImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxControlImg2ImgPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16 + ... 
).to("cuda") + + >>> prompt = "A robot made of exotic candies and chocolates of different kinds. Abstract background" + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/watercolor-painting.jpg" + ... ) + >>> control_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png" + ... ) + + >>> processor = CannyDetector() + >>> control_image = processor( + ... control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024 + ... ) + + >>> image = pipe( + ... prompt=prompt, + ... image=image, + ... control_image=control_image, + ... strength=0.8, + ... height=1024, + ... width=1024, + ... num_inference_steps=50, + ... guidance_scale=30.0, + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. 
This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = 
self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
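+
+ A minimal usage sketch (illustrative only; assumes `pipe` is an already-loaded instance of this
+ pipeline):
+
+ ```py
+ >>> prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
+ ... prompt="a robot made of candy", prompt_2=None, num_images_per_prompt=1
+ ... )
+ ```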
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % 
(self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
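+ # Illustrative example (assuming the default vae_scale_factor of 8): for a 1024x1024 image,
+ # height = 2 * (1024 // 16) = 128 and width = 128, so a packed tensor of shape
+ # (batch_size, 64 * 64, channels) is unpacked back to (batch_size, channels // 4, 128, 128).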
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. 
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The control image that provides structural guidance to the `transformer` during generation. If the
+ type is specified as `torch.Tensor`, it is used as is. `PIL.Image.Image` can also be accepted as an
+ image. The dimensions of the output image default to `image`'s dimensions. If height and/or width are
+ passed, `image` is resized accordingly.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 0.6):
+ Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. 
Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 8 + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + if control_image.ndim == 4: + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents, control_image], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 654bc41af4d0..ce7ea35c6cea 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -750,6 +750,7 @@ def __call__( device = self._execution_device dtype = self.transformer.dtype + # 3. Prepare text embeddings lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py new file mode 100644 index 000000000000..32b2bbefa709 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -0,0 +1,970 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxFillPipeline + >>> from diffusers.utils import load_image + + >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png") + >>> mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png") + + >>> pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() # save some VRAM by offloading the model to CPU + + >>> image = pipe( + ... prompt="a white paper cup", + ... image=image, + ... mask_image=mask, + ... height=1632, + ... width=1232, + ... guidance_scale=30, + ... num_inference_steps=50, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... ).images[0] + >>> image.save("flux_fill.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class FluxFillPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux Fill pipeline for image inpainting/outpainting. + + Reference: https://blackforestlabs.ai/flux-1-tools/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 
+ tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, 
num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # 1. calculate the height and width of the latents + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + # 2. encode the masked image + if masked_image.shape[1] == num_channels_latents: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + batch_size = batch_size * num_images_per_prompt + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." 
+ ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # 4. pack the masked_image_latents + # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2, num_channels_latents*4 + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + + # 5. resize mask to latents shape where we concatenate the mask to the latents + mask = mask[:, 0, :, :] # batch_size, 8 * height, 8 * width (mask has not been 8x compressed) + mask = mask.view( + batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor + ) # batch_size, height, 8, width, 8 + mask = mask.permute(0, 2, 4, 1, 3) # batch_size, 8, 8, height, width + mask = mask.reshape( + batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width + ) # batch_size, 8*8, height, width + + # 6. pack the mask: + # batch_size, 64, height, width -> batch_size, height//2 * width//2, 64*2*2 + mask = self._pack_latents( + mask, + batch_size, + self.vae_scale_factor * self.vae_scale_factor, + height, + width, + ) + mask = mask.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + image=None, + mask_image=None, + masked_image_latents=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + if image is not None and masked_image_latents is not None: + raise ValueError( + "Please provide either `image` or `masked_image_latents`, `masked_image_latents` should not be passed." + ) + + if image is not None and mask_image is None: + raise ValueError("Please provide `mask_image` when passing `image`.") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: Optional[torch.FloatTensor] = None, + mask_image: Optional[torch.FloatTensor] = None, + masked_image_latents: Optional[torch.FloatTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 30.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list + of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, or `(H, W)`. For a numpy array, it would be `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image`, generated by the VAE. If not provided, the mask + latents tensor will be generated from `mask_image`. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 30.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings.
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + image=image, + mask_image=mask_image, + masked_image_latents=masked_image_latents, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3.
Prepare prompt embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare mask and masked image latents + if masked_image_latents is not None: + masked_image_latents = masked_image_latents.to(latents.device) + else: + image = self.image_processor.preprocess(image, height=height, width=width) + mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) + + masked_image = image * (1 - mask_image) + masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype) + + height, width = image.shape[-2:] + mask, masked_image_latents = self.prepare_mask_latents( + mask_image, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=torch.cat((latents, masked_image_latents), dim=2), + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-process the image + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py new file mode 100644 index 000000000000..cf50e89ca5ae --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -0,0 +1,405 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Union + +import torch +from PIL import Image +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ..pipeline_utils import DiffusionPipeline +from .modeling_flux import ReduxImageEncoder +from .pipeline_output import FluxPriorReduxPipelineOutput + + +if is_torch_xla_available(): + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPriorReduxPipeline, FluxPipeline + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> dtype = torch.bfloat16 + + >>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev" + >>> repo_base = "black-forest-labs/FLUX.1-dev" + >>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device) + >>> pipe = FluxPipeline.from_pretrained( + ... repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16 + ... 
).to(device) + + >>> image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png" + ... ) + >>> pipe_prior_output = pipe_prior_redux(image) + >>> images = pipe( + ... guidance_scale=2.5, + ... num_inference_steps=50, + ... generator=torch.Generator("cpu").manual_seed(0), + ... **pipe_prior_output, + ... ).images + >>> images[0].save("flux-redux.png") + ``` +""" + + +class FluxPriorReduxPipeline(DiffusionPipeline): + r""" + The Flux Redux pipeline for image-to-image generation. + + Reference: https://blackforestlabs.ai/flux-1-tools/ + + Args: + image_encoder ([`SiglipVisionModel`]): + SIGLIP vision model to encode the input image. + feature_extractor ([`SiglipImageProcessor`]): + Image processor for preprocessing images for the SIGLIP model. + image_embedder ([`ReduxImageEncoder`]): + Redux image encoder to process the SIGLIP embeddings. + text_encoder ([`CLIPTextModel`], *optional*): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`], *optional*): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`, *optional*): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`, *optional*): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "image_encoder->image_embedder" + _optional_components = [ + "text_encoder", + "tokenizer", + "text_encoder_2", + "tokenizer_2", + ] + _callback_tensor_inputs = [] + + def __init__( + self, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + image_embedder: ReduxImageEncoder, + text_encoder: CLIPTextModel = None, + tokenizer: CLIPTokenizer = None, + text_encoder_2: T5EncoderModel = None, + tokenizer_2: T5TokenizerFast = None, + ): + super().__init__() + + self.register_modules( + image_encoder=image_encoder, + feature_extractor=feature_extractor, + image_embedder=image_embedder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + image = self.feature_extractor.preprocess( + images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True + ) + image = image.to(device=device, dtype=dtype) + + image_enc_hidden_states = self.image_encoder(**image).last_hidden_state + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + return image_enc_hidden_states + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or 
self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + 
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. 
For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list + of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`: + [`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the prompt embeddings and the second element is the pooled + prompt embeddings. + """ + + # 1. Define call parameters + if image is not None and isinstance(image, Image.Image): + batch_size = 1 + elif image is not None and isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + + # 2. Prepare image embeddings + image_latents = self.encode_image(image, device, 1) + + image_embeds = self.image_embedder(image_latents).image_embeds + image_embeds = image_embeds.to(device=device) + + # 3. Prepare (dummy) text embeddings + if hasattr(self, "text_encoder") and self.text_encoder is not None: + ( + prompt_embeds, + pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=[""] * batch_size, + prompt_2=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + device=device, + num_images_per_prompt=1, + max_sequence_length=512, + lora_scale=None, + ) + else: + # max_sequence_length is 512, t5 encoder hidden size is 4096 + prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype) + # pooled_prompt_embeds is 768, clip text encoder hidden size + pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) + + # Concatenate image and text embeddings + prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (prompt_embeds, pooled_prompt_embeds) + + return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds) diff --git a/src/diffusers/pipelines/flux/pipeline_output.py b/src/diffusers/pipelines/flux/pipeline_output.py index b5d98fb5bf60..388824e89f87 100644 --- a/src/diffusers/pipelines/flux/pipeline_output.py +++ b/src/diffusers/pipelines/flux/pipeline_output.py @@ -3,6 +3,7 @@ import numpy as np import PIL.Image +import torch from ...utils import BaseOutput @@ -19,3 +20,18 @@ class FluxPipelineOutput(BaseOutput): """ images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class FluxPriorReduxPipelineOutput(BaseOutput): + """ + Output class for Flux Prior Redux pipelines. + + Args: + prompt_embeds (`torch.Tensor`) + The image prompt embeddings (text embeddings concatenated with the encoded image embeddings) to pass + to a Flux base pipeline. + pooled_prompt_embeds (`torch.Tensor`) + The pooled prompt embeddings to pass to a Flux base pipeline.
+ """ + + prompt_embeds: torch.Tensor + pooled_prompt_embeds: torch.Tensor diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8b4b158efd0a..b76ea3824060 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -377,6 +377,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxControlImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxControlNetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -422,6 +437,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxControlPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxFillPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -467,6 +512,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxPriorReduxPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanDiTControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1157,6 +1217,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ReduxImageEncoder(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py new file mode 100644 index 000000000000..2bd511db3d65 --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -0,0 +1,203 @@ 
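+# Fast tests for FluxControlPipeline built from tiny dummy components (small transformer, VAE, and CLIP/T5 text encoders).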
+import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import torch_device + +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) + + +class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxControlPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + # there is no xformers processor for Flux + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=8, + out_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + control_image = Image.new("RGB", (16, 16), 0) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "control_image": control_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "output_type": "np", + } + return inputs + + def test_flux_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_prompt_embeds(self): + pipe = 
self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added + # to the pipeline level. + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." 
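+ # Fusing the Q/K/V projections only re-parameterizes the attention layers (three linear projections become one), + # so the fused and unfused paths are expected to produce numerically matching outputs, as the assertions above verify.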
+ + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) diff --git a/tests/pipelines/flux/test_pipeline_flux_control_img2img.py b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py new file mode 100644 index 000000000000..807013270eda --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py @@ -0,0 +1,168 @@ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxControlImg2ImgPipeline, + FluxTransformer2DModel, +) +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class FluxControlImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxControlImg2ImgPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=8, + out_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + image = Image.new("RGB", (16, 16), 0) + control_image = Image.new("RGB", (16, 16), 0) + + inputs = { + "prompt": "A painting of a squirrel eating 
a burger", + "image": image, + "control_image": control_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "strength": 0.8, + "output_type": "np", + } + return inputs + + def test_flux_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) diff --git a/tests/pipelines/flux/test_pipeline_flux_fill.py b/tests/pipelines/flux/test_pipeline_flux_fill.py new file mode 100644 index 000000000000..6c6ec138c781 --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_fill.py @@ -0,0 +1,168 @@ +import random +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxFillPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxFillPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=20, + out_channels=8, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + 
eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=2, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + mask_image = torch.ones((1, 1, 32, 32)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, + "output_type": "np", + } + return inputs + + def test_flux_fill_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_fill_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + 
assert (output_height, output_width) == (expected_height, expected_width) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=1e-3) diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py new file mode 100644 index 000000000000..39c83df1c143 --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_redux.py @@ -0,0 +1,108 @@ +import gc +import unittest + +import numpy as np +import pytest +import torch + +from diffusers import FluxPipeline, FluxPriorReduxPipeline +from diffusers.utils import load_image +from diffusers.utils.testing_utils import ( + numpy_cosine_similarity_distance, + require_big_gpu_with_torch_cuda, + slow, + torch_device, +) + + +@slow +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda +class FluxReduxSlowTests(unittest.TestCase): + pipeline_class = FluxPriorReduxPipeline + repo_id = "YiYiXu/yiyi-redux" # update to "black-forest-labs/FLUX.1-Redux-dev" once PR is merged + base_pipeline_class = FluxPipeline + base_repo_id = "black-forest-labs/FLUX.1-schnell" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0): + init_image = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png" + ) + return {"image": init_image} + + def get_base_pipeline_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + return { + "num_inference_steps": 2, + "guidance_scale": 2.0, + "output_type": "np", + "generator": generator, + } + + def test_flux_redux_inference(self): + pipe_redux = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) + pipe_base = self.base_pipeline_class.from_pretrained( + self.base_repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None + ) + pipe_redux.to(torch_device) + pipe_base.enable_model_cpu_offload() + + inputs = self.get_inputs(torch_device) + base_pipeline_inputs = self.get_base_pipeline_inputs(torch_device) + + redux_pipeline_output = pipe_redux(**inputs) + image = pipe_base(**base_pipeline_inputs, **redux_pipeline_output).images[0] + + image_slice = image[0, :10, :10] + expected_slice = np.array( + [ + 0.30078125, + 0.37890625, + 0.46875, + 0.28125, + 0.36914062, + 0.47851562, + 0.28515625, + 0.375, + 0.4765625, + 0.28125, + 0.375, + 0.48046875, + 0.27929688, + 0.37695312, + 0.47851562, + 0.27734375, + 0.38085938, + 0.4765625, + 0.2734375, + 0.38085938, + 0.47265625, + 0.27539062, + 0.37890625, + 0.47265625, + 0.27734375, + 0.37695312, + 0.47070312, + 0.27929688, + 0.37890625, + 0.47460938, + ], + dtype=np.float32, + ) + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) + + assert max_diff < 1e-4 From c4b5d2ff6b529ac0f895cedb04fef5b25e89c412 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Sun, 24 Nov 2024 18:51:06 +0200 Subject: [PATCH 6/8] [SD3 dreambooth lora] smol fix to checkpoint saving (#9993) * smol change to fix checkpoint saving & resuming (as done in train_dreambooth_sd3.py) * style * modify comment to explain reasoning behind hidden size check --- examples/dreambooth/train_dreambooth_lora_sd3.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 
deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index dcf093a94c5a..3f721e56addf 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1294,10 +1294,13 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder_one))): # or text_encoder_two + # both text encoders are of the same class, so we check hidden size to distinguish between the two + hidden_size = unwrap_model(model).config.hidden_size + if hidden_size == 768: + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + elif hidden_size == 1280: + text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") From 047bf492914ddc9393070b8f73bba5ad5823eb29 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 25 Nov 2024 15:57:59 +0530 Subject: [PATCH 7/8] [Docs] add: missing pipelines from the spec. (#10005) add: missing pipelines from the spec. --- docs/source/en/api/pipelines/flux.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index 011972bc59dd..94624264646f 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -333,3 +333,15 @@ image.save("flux-fp8-dev.png") [[autodoc]] FluxControlImg2ImgPipeline - all - __call__ + +## FluxPriorReduxPipeline + +[[autodoc]] FluxPriorReduxPipeline + - all + - __call__ + +## FluxFillPipeline + +[[autodoc]] FluxFillPipeline + - all + - __call__ From 074e12358bc17e7dbe111ea4f62f05dbae8a49d5 Mon Sep 17 00:00:00 2001 From: SkyCol <97716552+SkyCol@users.noreply.github.com> Date: Mon, 25 Nov 2024 21:12:06 +0800 Subject: [PATCH 8/8] Add prompt about wandb in examples/dreambooth/readme. (#10014) Add files via upload --- examples/dreambooth/README_flux.md | 2 +- examples/dreambooth/README_sd3.md | 2 +- examples/dreambooth/README_sdxl.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index a724ca53b927..c0802246e1f2 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -118,7 +118,7 @@ accelerate launch train_dreambooth_flux.py \ To better track our training experiments, we're using the following flags in the command above: -* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. +* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before. * `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. 
> [!NOTE]
diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md
index 89d87d65dd44..2ac7bf7101d8 100644
--- a/examples/dreambooth/README_sd3.md
+++ b/examples/dreambooth/README_sd3.md
@@ -105,7 +105,7 @@ accelerate launch train_dreambooth_sd3.py \
 To better track our training experiments, we're using the following flags in the command above:
-* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login` before training if you haven't done it before.
 * `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
 > [!NOTE]
diff --git a/examples/dreambooth/README_sdxl.md b/examples/dreambooth/README_sdxl.md
index 7a42bf8fffd8..565ff9a5dd33 100644
--- a/examples/dreambooth/README_sdxl.md
+++ b/examples/dreambooth/README_sdxl.md
@@ -99,7 +99,7 @@ accelerate launch train_dreambooth_lora_sdxl.py \
 To better track our training experiments, we're using the following flags in the command above:
-* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login` before training if you haven't done it before.
 * `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
 Our experiments were conducted on a single 40GB A100 GPU.
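
Editorial aside on the save-hook change in PATCH 6/8 above: both SD3 text encoders are instances of the same class, so an `isinstance` check alone cannot tell them apart, and the hook instead dispatches on `config.hidden_size` (768 for the first encoder, 1280 for the second). Below is a minimal, self-contained sketch of that dispatch idea; the dummy classes and placeholder dictionaries are stand-ins for the training script's real models and its `get_peft_model_state_dict` calls, not the script itself.

```python
# Illustrative sketch only -- not the training script. SD3's two CLIP text
# encoders share a class, so the save hook distinguishes them by
# config.hidden_size (768 for encoder one, 1280 for encoder two).
from dataclasses import dataclass


@dataclass
class DummyConfig:
    hidden_size: int


@dataclass
class DummyTextEncoder:
    config: DummyConfig
    name: str


def split_lora_layers_by_hidden_size(models):
    """Route each model to the right output slot, mirroring the hook's dispatch logic."""
    encoder_one_layers = None
    encoder_two_layers = None
    for model in models:
        hidden_size = model.config.hidden_size
        if hidden_size == 768:
            # stand-in for get_peft_model_state_dict(model)
            encoder_one_layers = {"source": model.name}
        elif hidden_size == 1280:
            # stand-in for get_peft_model_state_dict(model)
            encoder_two_layers = {"source": model.name}
        else:
            raise ValueError(f"unexpected text encoder hidden size: {hidden_size}")
    return encoder_one_layers, encoder_two_layers


if __name__ == "__main__":
    models = [
        DummyTextEncoder(DummyConfig(hidden_size=768), name="text_encoder_one"),
        DummyTextEncoder(DummyConfig(hidden_size=1280), name="text_encoder_two"),
    ]
    one, two = split_lora_layers_by_hidden_size(models)
    assert one == {"source": "text_encoder_one"}
    assert two == {"source": "text_encoder_two"}
```

The same pattern applies to any save hook that must route several same-class models to different checkpoint keys.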