Merge pull request #129 from invoke-ai/masks
Add support for image masks
RyanJDick authored May 22, 2024
2 parents 4848266 + 2b17c95 commit 11da280
Showing 46 changed files with 1,077 additions and 168 deletions.
Binary file added docs/images/bruce_masks/001_mask.png
Binary file added docs/images/bruce_masks/002_mask.png
Binary file added docs/images/bruce_masks/bruce_masks_step_300.jpg
93 changes: 93 additions & 0 deletions docs/tutorials/stable_diffusion/gnome_lora_masks_sdxl.md
@@ -0,0 +1,93 @@
# LoRA with Masks - SDXL

This tutorial explains how to prepare masks for an image dataset and then use that dataset to train an SDXL LoRA model.

Masks weight regions of the training images to control how much each region contributes to the training loss. In this tutorial, we will train on a small dataset of 4 images of Bruce the Gnome. With such a small dataset, there is a high risk of overfitting to background elements in the images, so we will use masks to focus training on the object of interest.
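
To make this concrete, below is a minimal sketch of how a per-pixel mask can weight a diffusion training loss. This is illustrative only, not the exact invoke-training implementation; the tensor shapes and the normalization by mask area are assumptions.

```python
import torch
import torch.nn.functional as F


def masked_mse_loss(model_pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Weight the per-pixel loss by the mask so that only the masked subject drives the gradient.

    model_pred, target: (B, C, H, W); mask: (B, 1, H, W) with values in [0, 1].
    """
    per_pixel = F.mse_loss(model_pred, target, reduction="none")  # (B, C, H, W)
    weighted = per_pixel * mask  # background pixels contribute ~0
    # Normalize by the masked area so that small masks do not shrink the loss magnitude.
    return weighted.sum() / (mask.sum() * per_pixel.shape[1] + 1e-8)
```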

## 1 - Dataset Preparation

For this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome:

| | |
| - | - |
| ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) |
| ![bruce_the_gnome dataset image 3.](../../images/bruce_the_gnome/003.jpg) | ![bruce_the_gnome dataset image 4.](../../images/bruce_the_gnome/004.jpg) |

This sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome).
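
If you later want to adapt this tutorial to your own images, the dataset format is simply a directory of images plus a `data.jsonl` file with one `image`/`text` entry per line (see `sample_data/bruce_the_gnome/data.jsonl` for the real example). Below is a minimal sketch for writing such a file; the directory, filenames, and captions are hypothetical.

```python
import json
from pathlib import Path

# Hypothetical entries; replace with your own image filenames and captions.
entries = [
    {"image": "001.png", "text": "A stuffed gnome sits on a wooden floor."},
    {"image": "002.png", "text": "A stuffed gnome stands on a black tiled floor."},
]

dataset_dir = Path("sample_data/my_dataset")
dataset_dir.mkdir(parents=True, exist_ok=True)

with (dataset_dir / "data.jsonl").open("w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")
```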

## 2 - Generate Masks

Use the `generate_masks_for_jsonl_dataset.py` script to generate masks for your dataset based on a single prompt. In this case we are using the prompt `"a stuffed gnome"`:
```bash
python src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py \
--in-jsonl sample_data/bruce_the_gnome/data.jsonl \
--out-jsonl sample_data/bruce_the_gnome/data_masks.jsonl \
--prompt "a stuffed gnome"
```

The mask generation script will produce the following outputs:
- A directory of generated masks: `sample_data/bruce_the_gnome/masks/`
- A new `.jsonl` file that references the mask images: `sample_data/bruce_the_gnome/data_masks.jsonl` (a quick sanity check of this file is sketched below)
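
As a quick sanity check, you can print which mask each image was paired with. The field names used here are assumptions (the data loaders in this commit expect a `mask` entry per example), so inspect one line of the generated file and adjust if it differs:

```python
import json
from pathlib import Path

jsonl_path = Path("sample_data/bruce_the_gnome/data_masks.jsonl")
for line in jsonl_path.read_text().splitlines():
    example = json.loads(line)
    # "image" and "mask" field names are assumed; confirm against the actual file contents.
    print(example.get("image"), "->", example.get("mask"))
```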

## 3 - Review the Generated Masks

Review the generated masks to confirm that the intended regions were masked. You may need to adjust the prompt and re-generate the masks to achieve the desired result, or you can edit the masks manually. The masks are simply single-channel grayscale images (0 = background, 255 = foreground).
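
For manual inspection or clean-up, the masks can be opened and re-binarized with standard imaging tools. Here is a minimal sketch using Pillow and NumPy; the mask filename is an assumption, so check the files in the generated `masks/` directory:

```python
import numpy as np
from PIL import Image

# Path is an assumption; check the filenames in sample_data/bruce_the_gnome/masks/.
mask_path = "sample_data/bruce_the_gnome/masks/001_mask.png"

mask = Image.open(mask_path).convert("L")  # single-channel grayscale
arr = np.array(mask)
print(arr.shape, arr.min(), arr.max())      # expect (H, W), 0, 255
print(f"foreground coverage: {(arr > 127).mean():.1%}")

# After editing a mask in an image editor, re-binarize it to clean 0/255 values.
binary = ((arr > 127) * 255).astype(np.uint8)
Image.fromarray(binary).save(mask_path)
```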

Here are some examples of the masks that we just generated:

| | |
| - | - |
| ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 1 mask.](../../images/bruce_masks/001_mask.png) |
| ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) | ![bruce_the_gnome dataset image 2 mask.](../../images/bruce_masks/002_mask.png) |

## 4 - Configuration

Below is the training configuration that we'll use for this tutorial.

Raw config file: [src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml).


```yaml title="sdxl_lora_masks_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml"
```

Full documentation of all of the configuration options is here: [LoRA SDXL Config](../../reference/config/pipelines/sdxl_lora.md)

There are a few things to note about this training config:

- We set `use_masks: True` so that the masks we generated are used during training. This option is only compatible with datasets that include mask data.
- The `learning_rate`, `max_train_steps`, `save_every_n_steps`, and `validate_every_n_steps` values are all _lower_ than typical for an SDXL LoRA training pipeline. The combination of masking and the small dataset size causes training to progress very quickly, so these fields were lowered to avoid overfitting. A sketch for tweaking these values programmatically follows below.
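
If you want to experiment with these values, one option is to copy the sample config and edit it programmatically before launching training. This is a sketch only: the key names and their locations in the YAML are assumptions, so print the loaded structure first and adjust the paths to match the actual file.

```python
import yaml

src = "src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml"
with open(src) as f:
    config = yaml.safe_load(f)

print(config.keys())  # inspect the real structure before changing anything

# Hypothetical tweaks; the exact key locations may differ in the real file:
# config["max_train_steps"] = 400
# config["save_every_n_steps"] = 100

with open("my_sdxl_lora_masks_config.yaml", "w") as f:
    yaml.safe_dump(config, f)
```

You would then pass the edited file to `invoke-train -c`, as shown in the next step.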

## 5 - Start Training

Launch the training run:
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml
```

Training takes ~30 mins on an NVIDIA RTX 4090.

## 6 - Monitor

In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.

Sample images will be logged to Tensorboard so that you can see how the model is evolving.

Once training is complete, select the model checkpoint that produces the best visual results. For this tutorial, we'll use the checkpoint from step 300:

![Screenshot of the Tensorboard UI showing the validation images for step 300.](../../images/bruce_masks/bruce_masks_step_300.jpg)
*Screenshot of the Tensorboard UI showing the validation images for step 300. The validation prompt was: "A stuffed gnome at the beach with a pina colada in its hand."*


## 7 - Import into InvokeAI

If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.

Import your trained LoRA model from the 'Models' tab.

Congratulations, you can now use your new Bruce-the-Gnome model! 🎉
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -32,6 +32,7 @@ nav:
- tutorials/index.md
- Stable Diffusion:
- tutorials/stable_diffusion/robocats_finetune_sdxl.md
- tutorials/stable_diffusion/gnome_lora_masks_sdxl.md
- tutorials/stable_diffusion/textual_inversion_sdxl.md
- tutorials/stable_diffusion/dpo_lora_sd.md
- Concepts:
4 changes: 4 additions & 0 deletions sample_data/bruce_the_gnome/data.jsonl
@@ -0,0 +1,4 @@
{"image": "001.png", "text": "A stuffed gnome sits on a wooden floor, facing right with a gray couch in the background."}
{"image": "002.png", "text": "A stuffed gnome stands on a black tiled floor, with a silver refrigerator and white wall in the background."}
{"image": "004.png", "text": "A stuffed gnome sits on a white marble floor, photorealistic."}
{"image": "003.png", "text": "A stuffed gnome sits on a gray tiled floor, facing the camera."}
(next changed file: the module containing `build_dreambooth_sd_dataloader`; filename not shown)
@@ -9,14 +9,10 @@
)
from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import (
AspectRatioBucketBatchSampler,
)
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.samplers.batch_offset_sampler import BatchOffsetSampler
from invoke_training._shared.data.samplers.concat_sampler import ConcatSampler
from invoke_training._shared.data.samplers.interleaved_sampler import (
InterleavedSampler,
)
from invoke_training._shared.data.samplers.interleaved_sampler import InterleavedSampler
from invoke_training._shared.data.samplers.offset_sampler import OffsetSampler
from invoke_training._shared.data.transforms.constant_field_transform import ConstantFieldTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
@@ -123,6 +119,8 @@ def build_dreambooth_sd_dataloader(
    if vae_output_cache_dir is None:
        all_transforms.append(
            SDImageTransform(
                image_field_names=["image"],
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
(next changed file: the module containing `build_image_caption_sd_dataloader`; filename not shown)
@@ -9,19 +9,14 @@
build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import (
AspectRatioBucketBatchSampler,
)
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager
from invoke_training.config.data.data_loader_config import (
AspectRatioBucketConfig,
ImageCaptionSDDataLoaderConfig,
)
from invoke_training.config.data.data_loader_config import AspectRatioBucketConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
@@ -60,6 +55,9 @@ def sd_image_caption_collate_fn(examples):
if "vae_output" in examples[0]:
out_examples["vae_output"] = torch.stack([example["vae_output"] for example in examples])

if "mask" in examples[0]:
out_examples["mask"] = torch.stack([example["mask"] for example in examples])

return out_examples


@@ -72,9 +70,10 @@ def build_aspect_ratio_bucket_manager(config: AspectRatioBucketConfig):
    )


def build_image_caption_sd_dataloader(
def build_image_caption_sd_dataloader(  # noqa: C901
    config: ImageCaptionSDDataLoaderConfig,
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
@@ -125,29 +124,43 @@ def build_image_caption_sd_dataloader(
all_transforms.append(CaptionPrefixTransform(caption_field_name="caption", prefix=config.caption_prefix + " "))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))

        all_transforms.append(
            SDImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        all_transforms.append(DropFieldTransform("mask"))

        vae_cache = TensorDiskCache(vae_output_cache_dir)

        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            cache_field_to_output_field["mask"] = "mask"
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field={
                    "vae_output": "vae_output",
                    "original_size_hw": "original_size_hw",
                    "crop_top_left_yx": "crop_top_left_yx",
                },
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))

    if text_encoder_output_cache_dir is not None:
        assert text_encoder_cache_field_to_output_field is not None
(next changed file: the module containing `build_image_pair_preference_sd_dataloader`; filename not shown)
@@ -3,9 +3,7 @@
import torch
from torch.utils.data import DataLoader

from invoke_training._shared.data.datasets.build_dataset import (
build_hf_image_pair_preference_dataset,
)
from invoke_training._shared.data.datasets.build_dataset import build_hf_image_pair_preference_dataset
from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
@@ -80,24 +78,28 @@ def build_image_pair_preference_sd_dataloader(

    all_transforms = []
    if vae_output_cache_dir is None:
        # TODO(ryand): Should I process both images in a single SDImageTransform so that they undergo the same
        # transformations?
        all_transforms.append(
            SDImageTransform(
                image_field_names=["image_0"],
                fields_to_normalize_to_range_minus_one_to_one=["image_0"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=None,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
                image_field_name="image_0",
                orig_size_field_name="original_size_hw_0",
                crop_field_name="crop_top_left_yx_0",
            )
        )
        all_transforms.append(
            SDImageTransform(
                image_field_names=["image_1"],
                fields_to_normalize_to_range_minus_one_to_one=["image_1"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=None,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
                image_field_name="image_1",
                orig_size_field_name="original_size_hw_1",
                crop_field_name="crop_top_left_yx_1",
            )
(next changed file: the module containing `build_textual_inversion_sd_dataloader`; filename not shown)
@@ -13,17 +13,13 @@
)
from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import (
AspectRatioBucketBatchSampler,
)
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.transforms.concat_fields_transform import ConcatFieldsTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.shuffle_caption_transform import ShuffleCaptionTransform
from invoke_training._shared.data.transforms.template_caption_transform import (
TemplateCaptionTransform,
)
from invoke_training._shared.data.transforms.template_caption_transform import TemplateCaptionTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
from invoke_training.config.data.dataset_config import (
@@ -102,6 +98,7 @@ def build_textual_inversion_sd_dataloader( # noqa: C901
    config: TextualInversionSDDataLoaderConfig,
    placeholder_token: str,
    batch_size: int,
    use_masks: bool = False,
    vae_output_cache_dir: Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
@@ -181,29 +178,44 @@ def build_textual_inversion_sd_dataloader( # noqa: C901
all_transforms.append(ShuffleCaptionTransform(field_name="caption", delimiter=config.shuffle_caption_delimiter))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))

        all_transforms.append(
            SDImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        all_transforms.append(DropFieldTransform("mask"))

        vae_cache = TensorDiskCache(vae_output_cache_dir)

        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            cache_field_to_output_field["mask"] = "mask"

        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field={
                    "vae_output": "vae_output",
                    "original_size_hw": "original_size_hw",
                    "crop_top_left_yx": "crop_top_left_yx",
                },
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))

    dataset = TransformDataset(base_dataset, all_transforms)

(remaining changed files not shown)