Skip to content

Commit

Permalink
Merge pull request #109 from invoke-ai/callbacks
Browse files Browse the repository at this point in the history
Add on_save_checkpoint() and on_save_validation_images() callbacks
  • Loading branch information
RyanJDick authored Apr 22, 2024
2 parents 391e9c4 + 861a54c commit cfb95f3
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 18 deletions.
33 changes: 29 additions & 4 deletions src/invoke_training/_shared/stable_diffusion/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from transformers import CLIPTextModel, CLIPTokenizer

from invoke_training._shared.data.utils.resolution import Resolution
from invoke_training.pipelines.callbacks import PipelineCallbacks, ValidationImage, ValidationImages
from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig
from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig


def generate_validation_images_sd(
def generate_validation_images_sd( # noqa: C901
epoch: int,
step: int,
out_dir: str,
Expand All @@ -32,6 +33,7 @@ def generate_validation_images_sd(
unet: UNet2DConditionModel,
config: SdLoraConfig,
logger: logging.Logger,
callbacks: list[PipelineCallbacks] | None = None,
):
"""Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout
training.
Expand Down Expand Up @@ -64,6 +66,8 @@ def generate_validation_images_sd(

validation_resolution = Resolution.parse(config.data_loader.resolution)

validation_images = ValidationImages(images=[], epoch=epoch, step=step)

# Run inference.
with torch.no_grad():
for prompt_idx, prompt in enumerate(config.validation_prompts):
Expand Down Expand Up @@ -93,7 +97,11 @@ def generate_validation_images_sd(
)
os.makedirs(validation_dir)
for image_idx, image in enumerate(images):
image.save(os.path.join(validation_dir, f"{image_idx:0>4}.jpg"))
image_path = os.path.join(validation_dir, f"{image_idx:0>4}.jpg")
validation_images.images.append(
ValidationImage(file_path=image_path, prompt=prompt, image_idx=image_idx)
)
image.save(image_path)

# Log images to trackers. Currently, only tensorboard is supported.
for tracker in accelerator.trackers:
Expand All @@ -120,8 +128,13 @@ def generate_validation_images_sd(
vae.to(vae_device)
text_encoder.to(text_encoder_device)

# Run callbacks.
if callbacks is not None:
for cb in callbacks:
cb.on_save_validation_images(images=validation_images)


def generate_validation_images_sdxl(
def generate_validation_images_sdxl( # noqa: C901
epoch: int,
step: int,
out_dir: str,
Expand All @@ -135,6 +148,7 @@ def generate_validation_images_sdxl(
unet: UNet2DConditionModel,
config: SdxlLoraConfig,
logger: logging.Logger,
callbacks: list[PipelineCallbacks] | None = None,
):
"""Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout
training.
Expand Down Expand Up @@ -166,6 +180,8 @@ def generate_validation_images_sdxl(

validation_resolution = Resolution.parse(config.data_loader.resolution)

validation_images = ValidationImages(images=[], epoch=epoch, step=step)

# Run inference.
with torch.no_grad():
for prompt_idx, prompt in enumerate(config.validation_prompts):
Expand Down Expand Up @@ -195,7 +211,11 @@ def generate_validation_images_sdxl(
)
os.makedirs(validation_dir)
for image_idx, image in enumerate(images):
image.save(os.path.join(validation_dir, f"{image_idx:0>4}.jpg"))
image_path = os.path.join(validation_dir, f"{image_idx:0>4}.jpg")
validation_images.images.append(
ValidationImage(file_path=image_path, prompt=prompt, image_idx=image_idx)
)
image.save(image_path)

# Log images to trackers. Currently, only tensorboard is supported.
for tracker in accelerator.trackers:
Expand All @@ -222,3 +242,8 @@ def generate_validation_images_sdxl(
vae.to(vae_device)
text_encoder_1.to(text_encoder_1_device)
text_encoder_2.to(text_encoder_2_device)

# Run callbacks.
if callbacks is not None:
for cb in callbacks:
cb.on_save_validation_images(images=validation_images)
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd
from invoke_training.pipelines._experimental.sd_dpo_lora.config import SdDirectPreferenceOptimizationLoraConfig
from invoke_training.pipelines.callbacks import PipelineCallbacks
from invoke_training.pipelines.stable_diffusion.lora.train import cache_text_encoder_outputs


Expand Down Expand Up @@ -186,7 +187,10 @@ def train_forward_dpo( # noqa: C901
return loss


def train(config: SdDirectPreferenceOptimizationLoraConfig): # noqa: C901
def train(config: SdDirectPreferenceOptimizationLoraConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901
if callbacks:
raise ValueError(f"This pipeline does not support callbacks, but {len(callbacks)} were provided.")

# Give a clear error message if an unsupported base model was chosen.
# TODO(ryan): Update this check to work with single-file SD checkpoints.
# check_base_model_version(
Expand Down
78 changes: 78 additions & 0 deletions src/invoke_training/pipelines/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from abc import ABC
from dataclasses import dataclass
from enum import Enum


class ModelType(Enum):
    """Identifies the file format / base model of a model file produced by invoke-training.

    Used to tag a `ModelCheckpoint` so that callback consumers know what kind of model file they received.
    """

    # At first glance, it feels like these model types should be further broken down into separate enums (e.g.
    # base_model, model_type, checkpoint_format). But, I haven't yet come up with a taxonomy that feels sufficiently
    # future-proof. So, for now, there is one enum for each file type that invoke-training can produce.

    # A Stable Diffusion 1.x LoRA model in Kohya format.
    SD1_LORA_KOHYA = "SD1_LORA_KOHYA"
    # A Stable Diffusion 1.x LoRA model in PEFT format.
    SD1_LORA_PEFT = "SD1_LORA_PEFT"
    # A Stable Diffusion XL LoRA model in Kohya format.
    SDXL_LORA_KOHYA = "SDXL_LORA_KOHYA"
    # A Stable Diffusion XL LoRA model in PEFT format.
    SDXL_LORA_PEFT = "SDXL_LORA_PEFT"

    # A Stable Diffusion 1.x Textual Inversion model.
    SD1_TEXTUAL_INVERSION = "SD1_TEXTUAL_INVERSION"
    # A Stable Diffusion XL Textual Inversion model.
    SDXL_TEXTUAL_INVERSION = "SDXL_TEXTUAL_INVERSION"


@dataclass
class ModelCheckpoint:
    """A single model checkpoint file.

    Attributes:
        file_path (str): Path to the saved model file.
        model_type (ModelType): The type/format of the saved model file.
    """

    file_path: str
    model_type: ModelType


@dataclass
class TrainingCheckpoint:
    """A training checkpoint. May contain multiple model checkpoints if multiple models are being trained
    simultaneously.

    Attributes:
        models (list[ModelCheckpoint]): The model checkpoint(s) saved at this training point.
        epoch (int): The last completed epoch at the time that this checkpoint was saved.
        step (int): The last completed training step at the time that this checkpoint was saved.
    """

    models: list[ModelCheckpoint]
    epoch: int
    step: int


@dataclass
class ValidationImage:
    """A single validation image.

    Attributes:
        file_path (str): Path to the image file.
        prompt (str): The prompt used to generate the image.
        image_idx (int): The index of this image in the current validation set (i.e. in the set of images generated
            with the same prompt at the same validation point).
    """

    file_path: str
    prompt: str
    image_idx: int


@dataclass
class ValidationImages:
    """A collection of validation images.

    Attributes:
        images (list[ValidationImage]): The validation images. (Mutable: producers may append to this list after
            construction.)
        epoch (int): The last completed epoch at the time that these images were generated.
        step (int): The last completed training step at the time that these images were generated.
    """

    images: list[ValidationImage]
    epoch: int
    step: int


class PipelineCallbacks(ABC):
    """Base class for callbacks that can be passed to a training pipeline.

    Subclasses override any subset of the hook methods; the default implementations are no-ops.
    """

    def on_save_checkpoint(self, checkpoint: TrainingCheckpoint):
        """Called after a training checkpoint has been saved to disk.

        Args:
            checkpoint (TrainingCheckpoint): Metadata describing the saved model file(s).
        """
        pass

    def on_save_validation_images(self, images: ValidationImages):
        """Called after a set of validation images has been saved to disk.

        Args:
            images (ValidationImages): Metadata describing the saved images.
        """
        pass
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from invoke_training.config.pipeline_config import PipelineConfig
from invoke_training.pipelines._experimental.sd_dpo_lora.train import train as train_sd_ddpo_lora
from invoke_training.pipelines.callbacks import PipelineCallbacks
from invoke_training.pipelines.stable_diffusion.lora.train import train as train_sd_lora
from invoke_training.pipelines.stable_diffusion.textual_inversion.train import train as train_sd_ti
from invoke_training.pipelines.stable_diffusion_xl.lora.train import train as train_sdxl_lora
Expand All @@ -9,19 +10,25 @@
from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.train import train as train_sdxl_ti


def train(config: PipelineConfig, callbacks: list[PipelineCallbacks] | None = None):
    """This is the main entry point for all training pipelines.

    Args:
        config (PipelineConfig): The training pipeline configuration; `config.type` selects the pipeline.
        callbacks (list[PipelineCallbacks] | None): Optional callbacks that the selected pipeline invokes at key
            points (e.g. checkpoint saves, validation-image saves).

    Raises:
        TypeError: If any element of `callbacks` is not a `PipelineCallbacks` instance.
        ValueError: If `config.type` is not a recognized pipeline type.
    """

    # Fail early if invalid callback types are provided, rather than failing later when the callbacks are used.
    # Use an explicit raise rather than `assert` so the check is not stripped under `python -O`.
    for cb in callbacks or []:
        if not isinstance(cb, PipelineCallbacks):
            raise TypeError(f"Expected a PipelineCallbacks instance, but received: {type(cb)}.")

    if config.type == "SD_LORA":
        train_sd_lora(config, callbacks)
    elif config.type == "SDXL_LORA":
        train_sdxl_lora(config, callbacks)
    elif config.type == "SD_TEXTUAL_INVERSION":
        train_sd_ti(config, callbacks)
    elif config.type == "SDXL_TEXTUAL_INVERSION":
        train_sdxl_ti(config, callbacks)
    elif config.type == "SDXL_LORA_AND_TEXTUAL_INVERSION":
        train_sdxl_lora_and_ti(config, callbacks)
    elif config.type == "SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA":
        print(f"Running EXPERIMENTAL pipeline: '{config.type}'.")
        train_sd_ddpo_lora(config, callbacks)
    else:
        raise ValueError(f"Unexpected pipeline type: '{config.type}'.")
16 changes: 15 additions & 1 deletion src/invoke_training/pipelines/stable_diffusion/lora/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint
from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig


Expand All @@ -51,6 +52,7 @@ def _save_sd_lora_checkpoint(
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
lora_checkpoint_format: Literal["invoke_peft", "kohya"],
callbacks: list[PipelineCallbacks] | None,
):
# Prune checkpoints and get new checkpoint path.
num_pruned = checkpoint_tracker.prune(1)
Expand All @@ -59,12 +61,22 @@ def _save_sd_lora_checkpoint(
save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)

if lora_checkpoint_format == "invoke_peft":
model_type = ModelType.SD1_LORA_PEFT
save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
elif lora_checkpoint_format == "kohya":
model_type = ModelType.SD1_LORA_KOHYA
save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
else:
raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")

if callbacks is not None:
for cb in callbacks:
cb.on_save_checkpoint(
TrainingCheckpoint(
models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step
)
)


def _build_data_loader(
data_loader_config: Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig],
Expand Down Expand Up @@ -241,7 +253,7 @@ def train_forward( # noqa: C901
return loss.mean()


def train(config: SdLoraConfig): # noqa: C901
def train(config: SdLoraConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901
# Give a clear error message if an unsupported base model was chosen.
# TODO(ryan): Update this check to work with single-file SD checkpoints.
# check_base_model_version(
Expand Down Expand Up @@ -509,6 +521,7 @@ def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):
logger=logger,
checkpoint_tracker=checkpoint_tracker,
lora_checkpoint_format=config.lora_checkpoint_format,
callbacks=callbacks,
)
accelerator.wait_for_everyone()

Expand All @@ -527,6 +540,7 @@ def validate(num_completed_epochs: int, num_completed_steps: int):
unet=unet,
config=config,
logger=logger,
callbacks=callbacks,
)
accelerator.wait_for_everyone()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
restore_original_embeddings,
)
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd
from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint
from invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs, train_forward
from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig

Expand All @@ -44,6 +45,7 @@ def _save_ti_embeddings(
accelerator: Accelerator,
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
callbacks: list[PipelineCallbacks] | None,
):
"""Save a Textual Inversion checkpoint. Old checkpoints are deleted if necessary to respect the checkpoint_tracker
limits.
Expand All @@ -63,6 +65,16 @@ def _save_ti_embeddings(

save_state_dict(learned_embeds_dict, save_path)

if callbacks is not None:
for cb in callbacks:
cb.on_save_checkpoint(
TrainingCheckpoint(
models=[ModelCheckpoint(file_path=save_path, model_type=ModelType.SD1_TEXTUAL_INVERSION)],
epoch=epoch,
step=step,
)
)


def _initialize_placeholder_tokens(
config: SdTextualInversionConfig,
Expand Down Expand Up @@ -120,7 +132,7 @@ def _initialize_placeholder_tokens(
return placeholder_tokens, placeholder_token_ids


def train(config: SdTextualInversionConfig): # noqa: C901
def train(config: SdTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901
# Create a timestamped directory for all outputs.
out_dir = os.path.join(config.base_output_dir, f"{time.time()}")
ckpt_dir = os.path.join(out_dir, "checkpoints")
Expand Down Expand Up @@ -308,6 +320,7 @@ def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):
accelerator=accelerator,
logger=logger,
checkpoint_tracker=checkpoint_tracker,
callbacks=callbacks,
)
accelerator.wait_for_everyone()

Expand All @@ -326,6 +339,7 @@ def validate(num_completed_epochs: int, num_completed_steps: int):
unet=unet,
config=config,
logger=logger,
callbacks=callbacks,
)
accelerator.wait_for_everyone()

Expand Down
Loading

0 comments on commit cfb95f3

Please sign in to comment.