Skip to content

Commit

Permalink
Initial tests & implementation for SDXL
Browse files Browse the repository at this point in the history
Signed-off-by: Alex-Brooks <[email protected]>
  • Loading branch information
alex-jw-brooks committed Jun 26, 2024
1 parent a329759 commit 2f9cd56
Show file tree
Hide file tree
Showing 25 changed files with 3,998 additions and 0 deletions.
1 change: 1 addition & 0 deletions caikit_computer_vision/modules/text_to_image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# limitations under the License.

# Local
from .sdxl import SDXL
from .tti_stub import TTIStub
266 changes: 266 additions & 0 deletions caikit_computer_vision/modules/text_to_image/sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for text to image via SDXL.
"""
# Standard
from typing import Any, Dict, Optional, Union
import os

# Third Party
from diffusers import AutoPipelineForText2Image
from PIL.Image import SAVE, init
import torch

# First Party
from caikit.core.exceptions import error_handler
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
from caikit.interfaces.vision import data_model as caikit_dm
import alog

# Local
from ...data_model import CaptionedImage
from ...data_model.tasks import TextToImageTask

log = alog.use_channel("SDXL")
error = error_handler.get(log)


@module(
id="28aa777c-1b13-21b0-11b3-bb9c3b0cbb56",
name="Text to Image via SDXL",
version="0.1.0",
task=TextToImageTask,
)
class SDXL(ModuleBase):
_DETECT_DEVICE = "__DETECT__"
_SDXL_PIPELINE_CONFIG_KEY = "sdxl_model"

def __init__(
self,
model_name: str,
pipeline: Any, # TODO: test this with SD, but usually a StableDiffusionXLPipeline
) -> "SDXL":
"""Initialize a wrapper around an SDXL text to image pipeline.
Args:
model_name: str
Name of the model being initialized.
pipeline: Any
Initialized pipeline to be wrapped.
"""
super().__init__()
self.model_name = model_name
self.pipeline = pipeline

@classmethod
def load(
cls, model_path: Union[str, "ModuleConfig"], device: str = _DETECT_DEVICE
) -> "SDXL":
"""Loads an instance of this class from an saved caikit module.
Args:
model_path: Union[str, "ModuleConfig"]
Path to the caikit model to be loaded.
device: str
The device to load the model onto (follows device name convention used by pytorch).
Returns:
SDXL
An instance of this class wrapping the model indicated by model_name on the correct
device.
"""
config = ModuleConfig.load(model_path)
pipeline_path = os.path.join(
config.model_path, config[SDXL._SDXL_PIPELINE_CONFIG_KEY]
)
return cls.bootstrap(pipeline_path, device)

@classmethod
def bootstrap(
cls, model_name: str, device: str = _DETECT_DEVICE, **pipeline_kwargs
) -> "SDXL":
"""Creates an instance of this class from an external model, i.e., local or
on HuggingFaceHub.
Args:
model_name: str
The model that we would like to bootstrap.
device: str
The device to load the model onto (follows device name convention used by pytorch).
**pipeline_kwargs
Additional kwargs to be passed to the pipeline creation, i.e.,
AutoPipelineForText2Image.from_pretrained, e.g., revision, etc.
Returns:
SDXL
An instance of this class wrapping the model indicated by model_name on the correct
device.
"""
pipeline = AutoPipelineForText2Image.from_pretrained(
model_name, **pipeline_kwargs
)
device = cls._get_device(device)
log.warning(f"Loading Text to image pipline on device [{device}]")
pipeline = pipeline.to(device)
return cls(model_name, pipeline)

def save(self, model_path: str):
"""Saves the pipeline model.
Args:
model_path: str
Path to the model we would like to save.
"""
saver = ModuleSaver(
self,
model_path=model_path,
)
with saver:
model_rel_path, model_abs_path = saver.add_dir(
self._SDXL_PIPELINE_CONFIG_KEY
)
saver.update_config(
{
"model_name": self.model_name,
self._SDXL_PIPELINE_CONFIG_KEY: model_rel_path,
}
)
self.pipeline.save_pretrained(model_abs_path)

def run(
self,
inputs: str,
height: int = 512,
width: int = 512,
num_steps: int = 1,
guidance_scale: float = 0.0,
negative_prompt: Optional[str] = None,
image_format: str = "png",
) -> CaptionedImage:
"""Generates an image matching the provided height and width.
NOTE: We currently expose guidance scale / negative prompt as args, but they should be
unset for SDXL turbo, since the model does not leverage them.
Args:
inputs: str
Text prompt to be used for the generation.
height: int
Height of the image to be generated.
width: int
Width of the image to be generated.
num_steps: int
Number of steps to be used in image generation.
guidance_scale: float
Guidance scale to be used; set this to 0.0 for SDXL turbo.
negative_prompt: Optional[str]
Negative prompt to be used; leave this unset for SDXL turbo.
image_format: str
Format to be used for the underlying PIL object, e.g., at serialization
time, etc.
Returns:
TextToImageResult
Object encapsulating the generated image.
"""
error.value_check(
"<CCV81444491E>",
height > 0 and width > 0,
"Height & width must be positive values",
)
error.value_check(
"<CCV14111912E>",
num_steps > 0,
"Number of steps must be a positive value",
)
SDXL._validate_image_format(image_format)
dims_dict = self._force_to_nearest_multiples_of_eight(
{
"height": height,
"width": width,
}
)
image = self.pipeline(
prompt=inputs,
guidance_scale=guidance_scale,
num_inference_steps=num_steps,
negative_prompt=negative_prompt,
**dims_dict,
).images[0]
# Update the image export format - this is used when converting the image to proto, etc.
image.format = image_format
return CaptionedImage(
output=caikit_dm.Image(image),
caption=inputs,
)

@staticmethod
def _force_to_nearest_multiples_of_eight(num_dict: Dict[str, int]):
"""Get the nearest multiple of eight for required params, e.g., height and width.
Args:
num_dict: dict[str, int]
A dictionary whose values must be multiples of 8 for diffuser inference.
returns:
dict[str, int]
A handle to the dictionary whose values are all multiples of 8.
"""
for num_name, num_val in num_dict.items():
# Force anything that is not a multiple of 8 to a multiple of 8, which is required
if num_val // 8 != num_val / 8:
log.warning(
f"Forcing inference param [{num_name}] to nearest multiple of 8"
)
num_dict[num_name] = round(num_val / 8) * 8
return num_dict

@classmethod
def _get_device(cls, device: Optional[str]) -> Union[str, None]:
"""Get the device which we expect to run our models on. Defaults to GPU
if one is available, otherwise falls back to None (cpu).
NOTE: This code is adapted from Caikit NLP.
Args:
device: Optional[Union[str, int]]
Device to be leveraged; if set to cls._DETECT_DEVICE, infers the device,
otherwise we simply echo the value, which generally indicates a user override.
Returns:
Optional[str]
Device string that we should move our models / tensors .to() at inference
time.
"""
if device == cls._DETECT_DEVICE:
device = "cuda" if torch.cuda.is_available() else None
return device

def _validate_image_format(image_format: str):
"""Validates that the provided image format is an allowed choice.
Args:
image_format: str
Image format to be used at serialization time for the data model.
"""
# Initialize PIL's save driver registry if it isn't already
if not SAVE:
init()
fmt = image_format.upper()
if fmt not in SAVE:
error(
"<CCV14828291E>",
KeyError(
f"Format {fmt} is unsupported! Supported formats: {list(SAVE.keys())}"
),
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"transformers>=4.27.1,<5",
"torch>=2.0,<3",
"timm>=0.9.5,<1",
"diffusers<1",
#"opencv-python>=4.9.0.80,<5",
#"pycocotools>=2.0.7,<3",
#"detectron2 @git+https://github.com/facebookresearch/detectron2.git@e70b9229d77aa39d85f8fa5266e6ea658e92eed3",
Expand Down
8 changes: 8 additions & 0 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
from caikit_computer_vision.modules.object_detection import TransformersObjectDetector

# from caikit_computer_vision.modules.segmentation import ViTSegmenter
from caikit_computer_vision.modules.text_to_image import SDXL

### Constants used in fixtures
FIXTURES_DIR = os.path.join(os.path.dirname(__file__))
TINY_MODELS_DIR = os.path.join(FIXTURES_DIR, "tiny_models")
TRANSFORMER_OBJ_DETECT_MODEL = os.path.join(TINY_MODELS_DIR, "YolosForObjectDetection")
# SEMANTIC_SEGMENTATION_MODEL_DIR = os.path.join(TINY_MODELS_DIR, "ImageSegmentation")
# SEGMENTATION_MODEL_CKPT = os.path.join(SEMANTIC_SEGMENTATION_MODEL_DIR, "model.pt")
SDXL_MODEL = os.path.join(TINY_MODELS_DIR, "SDXL")


@pytest.fixture
Expand All @@ -40,3 +42,9 @@ def detector_transformer_dummy_model():
# def segmentation_dummy_model():
# """Load torch scripted model weights for ViT Segmentation"""
# return ViTSegmenter.bootstrap(SEGMENTATION_MODEL_CKPT)


@pytest.fixture
def sdxl_dummy_model():
"""Load an SDXL dummy model."""
return SDXL.load(SDXL_MODEL)
3 changes: 3 additions & 0 deletions tests/fixtures/tiny_models/SDXL/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
This model is a bootstrapped export of one of HuggingFace's internal testing repositories.

More specifically: https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-xl-pipe
10 changes: 10 additions & 0 deletions tests/fixtures/tiny_models/SDXL/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
caikit_computer_vision_version: 0.0.1
created: "2024-06-24 14:03:31.930040"
model_name: hf-internal-testing/tiny-stable-diffusion-xl-pipe
module_class: caikit_computer_vision.modules.text_to_image.sdxl.SDXL
module_id: 28aa938b-1a33-11a0-11a3-bb9c3b1cbb11
name: Text to Image via SDXL
saved: "2024-06-24 14:03:31.930044"
sdxl_model: ./sdxl_model
tracking_id: 8cc3a483-7eb6-4e81-beb6-b6c352a0a3bc
version: 0.1.0
15 changes: 15 additions & 0 deletions tests/fixtures/tiny_models/SDXL/sdxl_model/model_index.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"_class_name": "StableDiffusionXLPipeline",
"_diffusers_version": "0.29.1",
"_name_or_path": "hf-internal-testing/tiny-stable-diffusion-xl-pipe",
"feature_extractor": [null, null],
"force_zeros_for_empty_prompt": true,
"image_encoder": [null, null],
"scheduler": ["diffusers", "EulerDiscreteScheduler"],
"text_encoder": ["transformers", "CLIPTextModel"],
"text_encoder_2": ["transformers", "CLIPTextModelWithProjection"],
"tokenizer": ["transformers", "CLIPTokenizer"],
"tokenizer_2": ["transformers", "CLIPTokenizer"],
"unet": ["diffusers", "UNet2DConditionModel"],
"vae": ["diffusers", "AutoencoderKL"]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"_class_name": "EulerDiscreteScheduler",
"_diffusers_version": "0.29.1",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"final_sigmas_type": "zero",
"interpolation_type": "linear",
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"rescale_betas_zero_snr": false,
"sigma_max": null,
"sigma_min": null,
"steps_offset": 1,
"timestep_spacing": "leading",
"timestep_type": "discrete",
"trained_betas": null,
"use_karras_sigmas": false
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"_name_or_path": "/Users/alexanderjbrooks/.cache/huggingface/hub/models--hf-internal-testing--tiny-stable-diffusion-xl-pipe/snapshots/fa06b5c3d6d45fabd978486616c6dae068519e85/text_encoder",
"architectures": ["CLIPTextModel"],
"attention_dropout": 0.0,
"bos_token_id": 0,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": 32,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 37,
"layer_norm_eps": 1e-5,
"max_position_embeddings": 77,
"model_type": "clip_text_model",
"num_attention_heads": 4,
"num_hidden_layers": 5,
"pad_token_id": 1,
"projection_dim": 32,
"torch_dtype": "float32",
"transformers_version": "4.31.0",
"vocab_size": 1000
}
Binary file not shown.
Loading

0 comments on commit 2f9cd56

Please sign in to comment.