
Commit 2d6fea3: fix sdxl

gabe56f committed Jan 2, 2024 (1 parent: e42b247)
Showing 4 changed files with 17 additions and 7 deletions.
core/config/api_settings.py (1 addition, 1 deletion)
@@ -40,7 +40,7 @@ class APIConfig:
     data_type: Literal[
         "float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2"
     ] = "float16"
-    use_minimal_sdxl_pipeline: bool = True  # slower, but works better
+    use_minimal_sdxl_pipeline: bool = False  # slower, but works better
 
     # VRAM optimizations
     # whether to run both parts of CFG>1 generations in one call. Increases VRAM usage during inference,
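The default for use_minimal_sdxl_pipeline flips from True to False, so checkpoint conversion now builds the full diffusers UNet unless the minimal pipeline is explicitly opted into. A minimal sketch of how the flag is consumed; the dataclass field and the branch condition come from the diffs in this commit, while the selector function itself is illustrative:

from dataclasses import dataclass

@dataclass
class APIConfig:
    use_minimal_sdxl_pipeline: bool = False  # slower, but works better

def pick_unet_class(api: APIConfig, model_type: str) -> str:
    # Mirrors the branch in download_from_original_stable_diffusion_ckpt:
    # the minimal SDXLUNet2D only when the flag is enabled, otherwise the
    # full diffusers UNet2DConditionModel (the new default path).
    if model_type in ["SDXL", "SDXL-Refiner"] and api.use_minimal_sdxl_pipeline:
        return "SDXLUNet2D"
    return "UNet2DConditionModel"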
core/inference/sdxl/pipeline.py (2 additions, 2 deletions)
@@ -646,8 +646,8 @@ def do_denoise(
             # perform guidance
             if do_classifier_free_guidance:
                 if not split_latents_into_two:
-                    if isinstance(noise_pred, list):
-                        noise_pred = noise_pred[0]
+                    # if isinstance(noise_pred, list):  # type: ignore
+                    #     noise_pred = noise_pred[0]
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # type: ignore
                     noise_pred = calculate_cfg(
                         j, noise_pred_text, noise_pred_uncond, guidance_scale, t, additional_pred=noise_pred_vanilla  # type: ignore
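With the list guard commented out, noise_pred is assumed to be a single tensor holding both halves of the CFG batch, which chunk(2) splits into the unconditional and text-conditioned predictions. calculate_cfg is project code not shown in this diff; a rough sketch of the standard classifier-free guidance combination it builds on:

import torch

def simple_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Both CFG halves were run in one UNet call: the first half is the
    # unconditional prediction, the second the text-conditioned one.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # Standard CFG: push the result away from the unconditional output.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)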
core/inference/utilities/convert_from_ckpt.py (4 additions, 2 deletions)
@@ -1148,7 +1148,6 @@ def download_from_original_stable_diffusion_ckpt(
         model_type = "SDXL-Refiner"
         if image_size is None:
             image_size = 1024
-    print(model_type)
 
     # Check if we have a SDXL or SD model and initialize default pipeline
     pipeline_class = StableDiffusionPipeline  # type: ignore
@@ -1237,7 +1236,10 @@ def download_from_original_stable_diffusion_ckpt(
     )
 
     with init_empty_weights():
-        if model_type in ["SDXL", "SDXL-Refiner"] and volta_config.api.use_minimal_sdxl_pipeline:
+        if (
+            model_type in ["SDXL", "SDXL-Refiner"]
+            and volta_config.api.use_minimal_sdxl_pipeline
+        ):
             unet = SDXLUNet2D()
         else:
             unet = UNet2DConditionModel(**unet_config)
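init_empty_weights comes from Hugging Face accelerate: modules created under it get their parameters on the meta device, so the UNet skeleton (SDXLUNet2D or UNet2DConditionModel) costs no memory before the checkpoint weights are loaded in. A small standalone sketch of the pattern:

import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)  # no real allocation happens here

print(layer.weight.device)  # "meta": weights must be loaded before use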
core/inference/utilities/latents.py (10 additions, 2 deletions)
@@ -1,3 +1,4 @@
+from contextlib import ExitStack
 import logging
 import math
 from time import time
@@ -18,6 +19,7 @@
 from core.inference.utilities.philox import PhiloxGenerator
 
 from .random import randn
+from core.optimizations.autocast_utils import autocast
 
 logger = logging.getLogger(__name__)
 
@@ -315,8 +317,14 @@ def prepare_latents(
         else:
             if image.shape[1] != 4:
                 image = pad_tensor(image, pipe.vae_scale_factor)
-            init_latent_dist = pipe.vae.encode(image.to(config.api.device, dtype=pipe.vae.dtype)).latent_dist  # type: ignore
-            init_latents = init_latent_dist.sample(generator=generator)
+            with ExitStack() as gs:
+                if pipe.vae.config["force_upcast"] or config.api.upcast_vae:
+                    gs.enter_context(autocast(dtype=torch.float32))
+                init_latent_dist = pipe.vae.encode(image.to(config.api.device, dtype=pipe.vae.dtype)).latent_dist  # type: ignore
+
+                if pipe.vae.config["force_upcast"] or config.api.upcast_vae:
+                    gs.enter_context(autocast(dtype=config.api.load_dtype))
+                init_latents = init_latent_dist.sample(generator=generator)  # type: ignore
             init_latents = 0.18215 * init_latents
             init_latents = torch.cat([init_latents] * batch_size, dim=0)  # type: ignore
     else:
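The ExitStack makes the autocast context conditional: it is only entered when the VAE is flagged for fp32 upcasting, and a second context restores the load dtype before sampling. Volta's autocast helper (core.optimizations.autocast_utils.autocast) is not part of this diff, so the sketch below substitutes a hypothetical dtype-swapping context manager to show the same enter_context pattern:

from contextlib import ExitStack, contextmanager
import torch

@contextmanager
def run_in_dtype(module: torch.nn.Module, dtype: torch.dtype):
    # Hypothetical stand-in for Volta's autocast helper: temporarily run
    # the module in `dtype`, then restore its original dtype on exit.
    original = next(module.parameters()).dtype
    module.to(dtype)
    try:
        yield
    finally:
        module.to(original)

def encode_maybe_upcast(vae, image: torch.Tensor, force_upcast: bool):
    # `vae` is assumed to be a diffusers AutoencoderKL-like module.
    with ExitStack() as stack:
        if force_upcast:
            # Enter the upcast context only when fp16 would overflow.
            stack.enter_context(run_in_dtype(vae, torch.float32))
        image = image.to(dtype=next(vae.parameters()).dtype)
        return vae.encode(image).latent_dist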
