Skip to content

Commit

Permalink
Tests for KDiffusion, partial fix for Unipc on AIT
Browse files Browse the repository at this point in the history
  • Loading branch information
Stax124 committed Oct 13, 2023
1 parent 8ea0987 commit 27ee5b5
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 34 deletions.
37 changes: 29 additions & 8 deletions core/inference/ait/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import logging
import math
from pathlib import Path
Expand Down Expand Up @@ -136,6 +137,7 @@ def unet_inference( # pylint: disable=dangerous-default-value
width,
down_block: list = [None],
mid_block=None,
**kwargs,
):
"Execute AIT#UNet module"
exe_module = self.unet_ait_exe
Expand Down Expand Up @@ -415,14 +417,33 @@ def do_denoise(x, t, call: Callable) -> torch.Tensor:
return x

if isinstance(self.scheduler, KdiffusionSchedulerAdapter):
latents = self.scheduler.do_inference(
latents, # type: ignore
generator=generator,
call=self.unet_inference,
apply_model=do_denoise,
callback=callback,
callback_steps=1,
)
func_param_keys = inspect.signature(
self.scheduler.do_inference
).parameters.keys()

if (
"optional_device" in func_param_keys
and "optional_dtype" in func_param_keys
):
latents = self.scheduler.do_inference(
latents, # type: ignore
generator=generator,
call=self.unet_inference,
apply_model=do_denoise,
callback=callback,
callback_steps=1,
optional_device=self.device, # type: ignore
optional_dtype=latents.dtype, # type: ignore
)
else:
latents = self.scheduler.do_inference(
latents, # type: ignore
generator=generator,
call=self.unet_inference,
apply_model=do_denoise,
callback=callback,
callback_steps=1,
)
else:

def _call(*args, **kwargs):
Expand Down
11 changes: 8 additions & 3 deletions core/scheduling/adapter/unipc_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,25 @@ def do_inference(
generator: Union[PhiloxGenerator, torch.Generator],
callback,
callback_steps,
optional_device: Optional[torch.device] = None,
optional_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
device = optional_device or call.device
dtype = optional_dtype or call.dtype

def noise_pred_fn(x, t_continuous, cond=None, **model_kwargs):
# Was originally get_model_input_time(t_continous)
# but "schedule" is ALWAYS "discrete," so we can skip it :)
t_input = (t_continuous - 1.0 / self.scheduler.total_N) * 1000
if cond is None:
output = call(
x.to(device=call.device, dtype=call.dtype),
t_input.to(device=call.device, dtype=call.dtype),
x.to(device=device, dtype=dtype),
t_input.to(device=device, dtype=dtype),
return_dict=True,
**model_kwargs,
)[0]
else:
output = call(x.to(device=call.device, dtype=call.dtype), t_input.to(device=call.device, dtype=call.dtype), return_dict=True, encoder_hidden_states=cond, **model_kwargs)[0] # type: ignore
output = call(x.to(device=device, dtype=dtype), t_input.to(device=device, dtype=dtype), return_dict=True, encoder_hidden_states=cond, **model_kwargs)[0] # type: ignore
if self.model_type == "noise":
return output
elif self.model_type == "x_start":
Expand Down
18 changes: 18 additions & 0 deletions tests/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
KDIFF_SAMPLERS = [
"euler_a",
"euler",
"lms",
"heun",
"dpm_fast",
"dpm_adaptive",
"dpm2",
"dpm2_a",
"dpmpp_2s_a",
"dpmpp_2m",
"dpmpp_2m_sharp",
"dpmpp_sde",
"dpmpp_2m_sde",
"dpmpp_3m_sde",
"unipc_multistep",
"restart",
]
17 changes: 14 additions & 3 deletions tests/inference/test_ait.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from copy import deepcopy

import pytest
from diffusers.schedulers import KarrasDiffusionSchedulers

Expand All @@ -9,6 +11,7 @@
Txt2imgData,
Txt2ImgQueueEntry,
)
from tests.const import KDIFF_SAMPLERS
from tests.functions import generate_random_image_base64

try:
Expand All @@ -20,6 +23,8 @@
from core.inference.ait import AITemplateStableDiffusion

model = "Azher--Anything-v4.5-vae-fp16-diffuser__512-1024x512-1024x1-1"
MODIFIED_KDIFF_SAMPLERS = deepcopy(KDIFF_SAMPLERS)
MODIFIED_KDIFF_SAMPLERS.remove("unipc_multistep")


@pytest.fixture(name="pipe")
Expand All @@ -29,7 +34,9 @@ def pipe_fixture():
)


@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers))
@pytest.mark.parametrize(
"scheduler", list(KarrasDiffusionSchedulers) + MODIFIED_KDIFF_SAMPLERS
)
def test_aitemplate_txt2img(
pipe: AITemplateStableDiffusion, scheduler: KarrasDiffusionSchedulers
):
Expand All @@ -45,7 +52,9 @@ def test_aitemplate_txt2img(
pipe.generate(job)


@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers))
@pytest.mark.parametrize(
"scheduler", list(KarrasDiffusionSchedulers) + MODIFIED_KDIFF_SAMPLERS
)
def test_aitemplate_img2img(
pipe: AITemplateStableDiffusion, scheduler: KarrasDiffusionSchedulers
):
Expand All @@ -62,7 +71,9 @@ def test_aitemplate_img2img(
pipe.generate(job)


@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers))
@pytest.mark.parametrize(
"scheduler", list(KarrasDiffusionSchedulers) + MODIFIED_KDIFF_SAMPLERS
)
def test_aitemplate_controlnet(
pipe: AITemplateStableDiffusion, scheduler: KarrasDiffusionSchedulers
):
Expand Down
22 changes: 2 additions & 20 deletions tests/inference/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,9 @@
Txt2ImgQueueEntry,
)
from core.utils import convert_image_to_base64, unwrap_enum
from tests.const import KDIFF_SAMPLERS
from tests.functions import generate_random_image, generate_random_image_base64

kdiff_samplers = [
"euler_a",
"euler",
"lms",
"heun",
"dpm_fast",
"dpm_adaptive",
"dpm2",
"dpm2_a",
"dpmpp_2s_a",
"dpmpp_2m",
"dpmpp_2m_sharp",
"dpmpp_sde",
"dpmpp_2m_sde",
"dpmpp_3m_sde",
"unipc_multistep",
"restart",
]


@pytest.fixture(name="pipe")
def pipe_fixture():
Expand All @@ -44,7 +26,7 @@ def pipe_fixture():
return PyTorchStableDiffusion("Azher/Anything-v4.5-vae-fp16-diffuser")


@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers) + kdiff_samplers)
@pytest.mark.parametrize("scheduler", list(KarrasDiffusionSchedulers) + KDIFF_SAMPLERS)
def test_txt2img_scheduler_sweep(
pipe: PyTorchStableDiffusion, scheduler: KarrasDiffusionSchedulers
):
Expand Down

0 comments on commit 27ee5b5

Please sign in to comment.