diff --git a/core/inference/ait/pipeline.py b/core/inference/ait/pipeline.py index b1450f0f4..7befb498a 100644 --- a/core/inference/ait/pipeline.py +++ b/core/inference/ait/pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import inspect import logging import math from pathlib import Path @@ -136,6 +137,7 @@ def unet_inference( # pylint: disable=dangerous-default-value width, down_block: list = [None], mid_block=None, + **kwargs, ): "Execute AIT#UNet module" exe_module = self.unet_ait_exe @@ -415,14 +417,33 @@ def do_denoise(x, t, call: Callable) -> torch.Tensor: return x if isinstance(self.scheduler, KdiffusionSchedulerAdapter): - latents = self.scheduler.do_inference( - latents, # type: ignore - generator=generator, - call=self.unet_inference, - apply_model=do_denoise, - callback=callback, - callback_steps=1, - ) + func_param_keys = inspect.signature( + self.scheduler.do_inference + ).parameters.keys() + + if ( + "optional_device" in func_param_keys + and "optional_dtype" in func_param_keys + ): + latents = self.scheduler.do_inference( + latents, # type: ignore + generator=generator, + call=self.unet_inference, + apply_model=do_denoise, + callback=callback, + callback_steps=1, + optional_device=self.device, # type: ignore + optional_dtype=latents.dtype, # type: ignore + ) + else: + latents = self.scheduler.do_inference( + latents, # type: ignore + generator=generator, + call=self.unet_inference, + apply_model=do_denoise, + callback=callback, + callback_steps=1, + ) else: def _call(*args, **kwargs): diff --git a/core/scheduling/adapter/unipc_adapter.py b/core/scheduling/adapter/unipc_adapter.py index 9655698e2..4cfbef800 100644 --- a/core/scheduling/adapter/unipc_adapter.py +++ b/core/scheduling/adapter/unipc_adapter.py @@ -113,20 +113,25 @@ def do_inference( generator: Union[PhiloxGenerator, torch.Generator], callback, callback_steps, + optional_device: Optional[torch.device] = None, + optional_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + device = optional_device or call.device + dtype = optional_dtype or call.dtype + def noise_pred_fn(x, t_continuous, cond=None, **model_kwargs): # Was originally get_model_input_time(t_continous) # but "schedule" is ALWAYS "discrete," so we can skip it :) t_input = (t_continuous - 1.0 / self.scheduler.total_N) * 1000 if cond is None: output = call( - x.to(device=call.device, dtype=call.dtype), - t_input.to(device=call.device, dtype=call.dtype), + x.to(device=device, dtype=dtype), + t_input.to(device=device, dtype=dtype), return_dict=True, **model_kwargs, )[0] else: - output = call(x.to(device=call.device, dtype=call.dtype), t_input.to(device=call.device, dtype=call.dtype), return_dict=True, encoder_hidden_states=cond, **model_kwargs)[0] # type: ignore + output = call(x.to(device=device, dtype=dtype), t_input.to(device=device, dtype=dtype), return_dict=True, encoder_hidden_states=cond, **model_kwargs)[0] # type: ignore if self.model_type == "noise": return output elif self.model_type == "x_start": diff --git a/tests/const.py b/tests/const.py new file mode 100644 index 000000000..31ca2e9c6 --- /dev/null +++ b/tests/const.py @@ -0,0 +1,18 @@ +KDIFF_SAMPLERS = [ + "euler_a", + "euler", + "lms", + "heun", + "dpm_fast", + "dpm_adaptive", + "dpm2", + "dpm2_a", + "dpmpp_2s_a", + "dpmpp_2m", + "dpmpp_2m_sharp", + "dpmpp_sde", + "dpmpp_2m_sde", + "dpmpp_3m_sde", + "unipc_multistep", + "restart", +] diff --git a/tests/inference/test_ait.py b/tests/inference/test_ait.py index 5ef34e656..306a18b4e 100644 --- a/tests/inference/test_ait.py +++ b/tests/inference/test_ait.py @@ -1,3 +1,5 @@ +from copy import deepcopy + import pytest from diffusers.schedulers import KarrasDiffusionSchedulers @@ -9,6 +11,7 @@ Txt2imgData, Txt2ImgQueueEntry, ) +from tests.const import KDIFF_SAMPLERS from tests.functions import generate_random_image_base64 try: @@ -20,6 +23,8 @@ from core.inference.ait import AITemplateStableDiffusion model = "Azher--Anything-v4.5-vae-fp16-diffuser__512-1024x512-1024x1-1" +MODIFIED_KDIFF_SAMPLERS = deepcopy(KDIFF_SAMPLERS) +MODIFIED_KDIFF_SAMPLERS.remove("unipc_multistep") @pytest.fixture(name="pipe") @@ -29,7 +34,9 @@ def pipe_fixture(): ) -@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers)) +@pytest.mark.parametrize( + "scheduler", list(KarrasDiffusionSchedulers) + MODIFIED_KDIFF_SAMPLERS +) def test_aitemplate_txt2img( pipe: AITemplateStableDiffusion, scheduler: KarrasDiffusionSchedulers ): @@ -45,7 +52,9 @@ def test_aitemplate_txt2img( pipe.generate(job) -@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers)) +@pytest.mark.parametrize( + "scheduler", list(KarrasDiffusionSchedulers) + MODIFIED_KDIFF_SAMPLERS +) def test_aitemplate_img2img( pipe: AITemplateStableDiffusion, scheduler: KarrasDiffusionSchedulers ): @@ -62,7 +71,9 @@ def test_aitemplate_img2img( pipe.generate(job) -@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers)) +@pytest.mark.parametrize( + "scheduler", list(KarrasDiffusionSchedulers) + MODIFIED_KDIFF_SAMPLERS +) def test_aitemplate_controlnet( pipe: AITemplateStableDiffusion, scheduler: KarrasDiffusionSchedulers ): diff --git a/tests/inference/test_pytorch.py b/tests/inference/test_pytorch.py index 55909f62c..8af083c33 100644 --- a/tests/inference/test_pytorch.py +++ b/tests/inference/test_pytorch.py @@ -15,27 +15,9 @@ Txt2ImgQueueEntry, ) from core.utils import convert_image_to_base64, unwrap_enum +from tests.const import KDIFF_SAMPLERS from tests.functions import generate_random_image, generate_random_image_base64 -kdiff_samplers = [ - "euler_a", - "euler", - "lms", - "heun", - "dpm_fast", - "dpm_adaptive", - "dpm2", - "dpm2_a", - "dpmpp_2s_a", - "dpmpp_2m", - "dpmpp_2m_sharp", - "dpmpp_sde", - "dpmpp_2m_sde", - "dpmpp_3m_sde", - "unipc_multistep", - "restart", -] - @pytest.fixture(name="pipe") def pipe_fixture(): @@ -44,7 +26,7 @@ def pipe_fixture(): return PyTorchStableDiffusion("Azher/Anything-v4.5-vae-fp16-diffuser") -@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers) + kdiff_samplers) +@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers) + KDIFF_SAMPLERS) def test_txt2img_scheduler_sweep( pipe: PyTorchStableDiffusion, scheduler: KarrasDiffusionSchedulers ):