Skip to content

Commit

Permalink
Merge pull request #1 from jobunk/main
Browse files Browse the repository at this point in the history
fixing SDXL Sampler V3 to work with newest version of ComfyUI
  • Loading branch information
Danamir authored Apr 8, 2024
2 parents 2eb934b + 151de56 commit 0324265
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 26 deletions.
45 changes: 21 additions & 24 deletions modules/custom_sdxl_ksampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@

import comfy.sample
import comfy.samplers
import comfy.sampler_helpers

import comfy.utils
import latent_preview

from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.model_management import get_torch_device, load_models_gpu

from comfy.model_management import get_torch_device
from .utils import slerp_latents
from .utils import bilateral_blur

Expand Down Expand Up @@ -99,7 +100,7 @@ def sdxl_sample(base_model, refiner_model, noise, base_steps, refiner_steps, cfg
device = get_torch_device()

if noise_mask is not None:
noise_mask = comfy.sample.prepare_mask(noise_mask, noise.shape, device)
noise_mask = comfy.sampler_helpers.prepare_mask(noise_mask, noise.shape, device)

steps = base_steps + refiner_steps

Expand Down Expand Up @@ -167,32 +168,30 @@ def base_tonemap_reinhard(args):
base_model.set_model_sampler_cfg_function(base_rescale_cfg)
elif cfg_method == CfgMethods.TONEMAP and dynamic_base_cfg > 0.0:
base_model.set_model_sampler_cfg_function(base_tonemap_reinhard)

base_models, inference_memory = comfy.sample.get_additional_models(base_positive, base_negative,

temp_base_conds = {"positive": base_positive, "negative": base_negative}
base_conds = {}
for k in temp_base_conds:
base_conds[k] = comfy.sampler_helpers.convert_cond(temp_base_conds[k])
base_models, inference_memory = comfy.sampler_helpers.get_additional_models(base_conds,
base_model.model_dtype())

memory_required = base_model.memory_required(noise.shape) + inference_memory
load_models_gpu([base_model] + base_models, memory_required)

real_base_model = base_model.model

original_latent = latent_image

noise = noise.to(device)
latent_image = latent_image.to(device)

pos_base_copy = comfy.sample.convert_cond(base_positive)
neg_base_copy = comfy.sample.convert_cond(base_negative)

base_sampler = comfy.samplers.KSampler(real_base_model, steps=steps, device=device, sampler=sampler_name,
base_sampler = comfy.samplers.KSampler(base_model, steps=steps, device=device, sampler=sampler_name,
scheduler=scheduler, denoise=denoise, model_options=base_model.model_options)

base_samples = base_sampler.sample(noise, pos_base_copy, neg_base_copy, cfg=cfg, latent_image=latent_image,
base_samples = base_sampler.sample(noise, base_positive, base_negative, cfg=cfg, latent_image=latent_image,
start_step=start_step, last_step=base_steps, force_full_denoise=False,
denoise_mask=noise_mask, sigmas=sigmas, callback=base_callback,
disable_pbar=disable_pbar, seed=seed)

comfy.sample.cleanup_additional_models(base_models)
comfy.sampler_helpers.cleanup_additional_models(base_models)

noise = torch.zeros(base_samples.size(), dtype=base_samples.dtype, layout=base_samples.layout, device=device)

Expand Down Expand Up @@ -279,30 +278,28 @@ def refiner_tonemap_reinhard(args):
elif cfg_method == CfgMethods.TONEMAP and dynamic_refiner_cfg > 0.0:
refiner_model.set_model_sampler_cfg_function(refiner_tonemap_reinhard)

refiner_models, inference_memory = comfy.sample.get_additional_models(refiner_positive, refiner_negative,
temp_refiner_conds = {"positive": refiner_positive, "negative": refiner_negative}
refiner_conds = {}
for k in temp_refiner_conds:
refiner_conds[k] = comfy.sampler_helpers.convert_cond(temp_refiner_conds[k])
refiner_models, inference_memory = comfy.sampler_helpers.get_additional_models(refiner_conds,
refiner_model.model_dtype())

memory_required = refiner_model.memory_required(noise.shape) + inference_memory
load_models_gpu([refiner_model] + refiner_models, memory_required)

real_refiner_model = refiner_model.model

pos_refiner_copy = comfy.sample.convert_cond(refiner_positive)
neg_refiner_copy = comfy.sample.convert_cond(refiner_negative)

refiner_sampler = comfy.samplers.KSampler(real_refiner_model, steps=steps, device=device, sampler=sampler_name,
refiner_sampler = comfy.samplers.KSampler(refiner_model, steps=steps, device=device, sampler=sampler_name,
scheduler=scheduler, denoise=denoise,
model_options=refiner_model.model_options)

refiner_samples = refiner_sampler.sample(noise, pos_refiner_copy, neg_refiner_copy, cfg=cfg,
refiner_samples = refiner_sampler.sample(noise, refiner_positive, refiner_negative, cfg=cfg,
latent_image=latent_from_base, start_step=base_steps, last_step=last_step,
force_full_denoise=force_full_denoise,
denoise_mask=noise_mask, sigmas=sigmas, callback=refiner_callback,
disable_pbar=disable_pbar, seed=seed)

refiner_samples = refiner_samples.cpu()

comfy.sample.cleanup_additional_models(refiner_models)
comfy.sampler_helpers.cleanup_additional_models(refiner_models)

return refiner_samples

Expand Down
6 changes: 4 additions & 2 deletions modules/stage_latent_detailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
SOFTWARE.
"""

from comfy.sample import prepare_mask
try:
from comfy.sampler_helpers import prepare_mask
except ImportError:
from comfy.sample import prepare_mask

from .data_utils import retrieve_parameter
from .mb_pipeline import PipelineAccess
Expand Down

0 comments on commit 0324265

Please sign in to comment.