Merge pull request #101 from skunkworxdark/main
add negative prompts to validation images
RyanJDick authored Apr 24, 2024
2 parents 38fe8f5 + d5e618a commit fefaedb
Showing 18 changed files with 373 additions and 50 deletions.
22 changes: 18 additions & 4 deletions src/invoke_training/_shared/stable_diffusion/validation.py
@@ -70,7 +70,13 @@ def generate_validation_images_sd( # noqa: C901

     # Run inference.
     with torch.no_grad():
-        for prompt_idx, prompt in enumerate(config.validation_prompts):
+        for prompt_idx in range(len(config.validation_prompts)):
+            positive_prompt = config.validation_prompts[prompt_idx]
+            negative_prompt = None
+            if config.negative_validation_prompts is not None:
+                negative_prompt = config.negative_validation_prompts[prompt_idx]
+            logger.info(f"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt}'")
+
             generator = torch.Generator(device=accelerator.device)
             if config.seed is not None:
                 generator = generator.manual_seed(config.seed)
@@ -80,11 +86,12 @@ def generate_validation_images_sd( # noqa: C901
             with accelerator.autocast():
                 images.append(
                     pipeline(
-                        prompt,
+                        positive_prompt,
                         num_inference_steps=30,
                         generator=generator,
                         height=validation_resolution.height,
                         width=validation_resolution.width,
+                        negative_prompt=negative_prompt,
                     ).images[0]
                 )

@@ -184,7 +191,13 @@ def generate_validation_images_sdxl( # noqa: C901

     # Run inference.
     with torch.no_grad():
-        for prompt_idx, prompt in enumerate(config.validation_prompts):
+        for prompt_idx in range(len(config.validation_prompts)):
+            positive_prompt = config.validation_prompts[prompt_idx]
+            negative_prompt = None
+            if config.negative_validation_prompts is not None:
+                negative_prompt = config.negative_validation_prompts[prompt_idx]
+            logger.info(f"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt}'")
+
             generator = torch.Generator(device=accelerator.device)
             if config.seed is not None:
                 generator = generator.manual_seed(config.seed)
@@ -194,11 +207,12 @@ def generate_validation_images_sdxl( # noqa: C901
             with accelerator.autocast():
                 images.append(
                     pipeline(
-                        prompt,
+                        positive_prompt,
                         num_inference_steps=30,
                         generator=generator,
                         height=validation_resolution.height,
                         width=validation_resolution.width,
+                        negative_prompt=negative_prompt,
                     ).images[0]
                 )

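The new loop pairs prompts by index: when `negative_validation_prompts` is set, its i-th entry is applied to the i-th entry of `validation_prompts`, and it falls back to `None` otherwise. Here is a standalone sketch of the same pairing (illustrative only, not code from this PR; the prompt strings are made up):

```python
# Illustration of the index-based pairing used in the validation loops above.
validation_prompts = ["a photo of a cat", "a watercolor landscape"]
negative_validation_prompts = ["blurry, low quality", "oversaturated, jpeg artifacts"]  # may also be None

# Fall back to None for every prompt when no negative prompts are configured.
negatives = negative_validation_prompts or [None] * len(validation_prompts)
for prompt_idx, (positive_prompt, negative_prompt) in enumerate(zip(validation_prompts, negatives)):
    # In the training code, each pair is passed to the diffusers pipeline as the
    # positional prompt and the `negative_prompt` keyword argument.
    print(f"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt}'")
```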
Additional changed file
@@ -1,6 +1,6 @@
 from typing import Annotated, Literal, Union

-from pydantic import Field
+from pydantic import Field, model_validator

 from invoke_training.config.base_pipeline_config import BasePipelineConfig
 from invoke_training.config.config_base_model import ConfigBaseModel
@@ -183,6 +183,11 @@ class SdDirectPreferenceOptimizationLoraConfig(BasePipelineConfig):
     See also 'validate_every_n_epochs'.
     """

+    negative_validation_prompts: list[str] | None = None
+    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
+    the same length as 'validation_prompts'.
+    """
+
     num_validation_images_per_prompt: int = 4
     """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
     become quite slow if this number is too large.
@@ -212,3 +217,14 @@ class SdDirectPreferenceOptimizationLoraConfig(BasePipelineConfig):
     Typical values for `beta` are in the range [1000.0, 10000.0].
     """
+
+    @model_validator(mode="after")
+    def check_validation_prompts(self):
+        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
+            self.validation_prompts
+        ):
+            raise ValueError(
+                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
+                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
+            )
+        return self
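The same `model_validator` is added to every pipeline config touched by this PR. To see the constraint in action without building a full training config (which has many other required fields), here is a minimal self-contained pydantic model that reuses the validator body; the `PromptConfig` class name and the example prompts are invented for illustration:

```python
from pydantic import BaseModel, ValidationError, model_validator


class PromptConfig(BaseModel):
    """Minimal stand-in for the training configs above; only the two prompt fields are modeled."""

    validation_prompts: list[str]
    negative_validation_prompts: list[str] | None = None

    @model_validator(mode="after")
    def check_validation_prompts(self):
        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
            self.validation_prompts
        ):
            raise ValueError(
                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
            )
        return self


# Matching lengths pass validation.
PromptConfig(validation_prompts=["a cat"], negative_validation_prompts=["blurry"])

# A length mismatch surfaces as a pydantic ValidationError at construction time.
try:
    PromptConfig(validation_prompts=["a cat", "a dog"], negative_validation_prompts=["blurry"])
except ValidationError as e:
    print(e)
```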
23 changes: 18 additions & 5 deletions src/invoke_training/pipelines/stable_diffusion/lora/config.py
@@ -1,12 +1,9 @@
 from typing import Annotated, Literal, Union

-from pydantic import Field
+from pydantic import Field, model_validator

 from invoke_training.config.base_pipeline_config import BasePipelineConfig
-from invoke_training.config.data.data_loader_config import (
-    DreamboothSDDataLoaderConfig,
-    ImageCaptionSDDataLoaderConfig,
-)
+from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig
 from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig


@@ -146,6 +143,11 @@ class SdLoraConfig(BasePipelineConfig):
     See also 'validate_every_n_epochs'.
     """

+    negative_validation_prompts: list[str] | None = None
+    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
+    the same length as 'validation_prompts'.
+    """
+
     num_validation_images_per_prompt: int = 4
     """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
     become quite slow if this number is too large.
@@ -158,3 +160,14 @@
     data_loader: Annotated[
         Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator="type")
     ]
+
+    @model_validator(mode="after")
+    def check_validation_prompts(self):
+        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
+            self.validation_prompts
+        ):
+            raise ValueError(
+                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
+                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
+            )
+        return self
Additional changed file
@@ -1,5 +1,7 @@
 from typing import Literal

+from pydantic import model_validator
+
 from invoke_training.config.base_pipeline_config import BasePipelineConfig
 from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
 from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
@@ -148,6 +150,11 @@ class SdTextualInversionConfig(BasePipelineConfig):
     """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
     """

+    negative_validation_prompts: list[str] | None = None
+    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
+    the same length as 'validation_prompts'.
+    """
+
     num_validation_images_per_prompt: int = 4
     """The number of validation images to generate for each prompt in `validation_prompts`. Careful, validation can
     become very slow if this number is too large.
@@ -164,3 +171,14 @@
     [`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig]
     for details.
     """
+
+    @model_validator(mode="after")
+    def check_validation_prompts(self):
+        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
+            self.validation_prompts
+        ):
+            raise ValueError(
+                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
+                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
+            )
+        return self
23 changes: 18 additions & 5 deletions src/invoke_training/pipelines/stable_diffusion_xl/lora/config.py
@@ -1,12 +1,9 @@
 from typing import Annotated, Literal, Union

-from pydantic import Field
+from pydantic import Field, model_validator

 from invoke_training.config.base_pipeline_config import BasePipelineConfig
-from invoke_training.config.data.data_loader_config import (
-    DreamboothSDDataLoaderConfig,
-    ImageCaptionSDDataLoaderConfig,
-)
+from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig
 from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig


@@ -146,6 +143,11 @@ class SdxlLoraConfig(BasePipelineConfig):
     See also 'validate_every_n_epochs'.
     """

+    negative_validation_prompts: list[str] | None = None
+    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
+    the same length as 'validation_prompts'.
+    """
+
     num_validation_images_per_prompt: int = 4
     """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
     become quite slow if this number is too large.
@@ -164,3 +166,14 @@
     model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
     with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
     """
+
+    @model_validator(mode="after")
+    def check_validation_prompts(self):
+        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
+            self.validation_prompts
+        ):
+            raise ValueError(
+                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
+                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
+            )
+        return self
Additional changed file
@@ -1,5 +1,7 @@
 from typing import Literal

+from pydantic import model_validator
+
 from invoke_training.config.base_pipeline_config import BasePipelineConfig
 from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
 from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
@@ -174,6 +176,11 @@ class SdxlLoraAndTextualInversionConfig(BasePipelineConfig):
     """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
     """

+    negative_validation_prompts: list[str] | None = None
+    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
+    the same length as 'validation_prompts'.
+    """
+
     num_validation_images_per_prompt: int = 4
     """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
     become quite slow if this number is too large.
@@ -196,3 +203,14 @@
     model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
     with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
     """
+
+    @model_validator(mode="after")
+    def check_validation_prompts(self):
+        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
+            self.validation_prompts
+        ):
+            raise ValueError(
+                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
+                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
+            )
+        return self
Additional changed file
@@ -1,5 +1,7 @@
 from typing import Literal

+from pydantic import model_validator
+
 from invoke_training.config.base_pipeline_config import BasePipelineConfig
 from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
 from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
@@ -148,6 +150,11 @@ class SdxlTextualInversionConfig(BasePipelineConfig):
     """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
     """

+    negative_validation_prompts: list[str] | None = None
+    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
+    the same length as 'validation_prompts'.
+    """
+
     num_validation_images_per_prompt: int = 4
     """The number of validation images to generate for each prompt in `validation_prompts`. Careful, validation can
     become very slow if this number is too large.
@@ -170,3 +177,14 @@
     base model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL 1.0
     shipped with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
     """
+
+    @model_validator(mode="after")
+    def check_validation_prompts(self):
+        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
+            self.validation_prompts
+        ):
+            raise ValueError(
+                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
+                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
+            )
+        return self
22 changes: 16 additions & 6 deletions src/invoke_training/ui/config_groups/sd_lora_config_group.py
@@ -9,7 +9,11 @@
 )
 from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
 from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
-from invoke_training.ui.utils import get_typing_literal_options
+from invoke_training.ui.utils.prompts import (
+    convert_pos_neg_prompts_to_ui_prompts,
+    convert_ui_prompts_to_pos_neg_prompts,
+)
+from invoke_training.ui.utils.utils import get_typing_literal_options


class SdLoraConfigGroup(UIConfigElement):
@@ -157,7 +161,11 @@ def __init__(self):
         gr.Markdown("## Validation")
         with gr.Group():
             self.validation_prompts = gr.Textbox(
-                label="Validation Prompts", info="Enter one validation prompt per line.", lines=5, interactive=True
+                label="Validation Prompts",
+                info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
+                "delimiter. For example: `positive prompt[NEG]negative prompt`. ",
+                lines=5,
+                interactive=True,
             )
             self.num_validation_images_per_prompt = gr.Number(
                 label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
@@ -184,7 +192,9 @@ def update_ui_components_with_config_data(self, config: SdLoraConfig) -> dict[gr
             self.gradient_checkpointing: config.gradient_checkpointing,
             self.lora_rank_dim: config.lora_rank_dim,
             self.min_snr_gamma: config.min_snr_gamma,
-            self.validation_prompts: "\n".join(config.validation_prompts),
+            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
+                config.validation_prompts, config.negative_validation_prompts
+            ),
             self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
         }
         update_dict.update(
@@ -224,9 +234,9 @@ def update_config_with_ui_component_data(
         new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)
         new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)

-        validation_prompts: list[str] = ui_data.pop(self.validation_prompts).split("\n")
-        validation_prompts = [x.strip() for x in validation_prompts if x.strip() != ""]
-        new_config.validation_prompts = validation_prompts
+        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))
+        new_config.validation_prompts = positive_prompts
+        new_config.negative_validation_prompts = negative_prompts

         new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(
             new_config.data_loader, ui_data
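Both UI config groups now round-trip the textbox contents through `convert_ui_prompts_to_pos_neg_prompts` and `convert_pos_neg_prompts_to_ui_prompts` from `invoke_training.ui.utils.prompts`, whose implementation is not shown in this excerpt. The sketch below is an assumption of what those helpers might look like, inferred from the '[NEG]' delimiter described in the textbox help text; the actual implementation in the PR may differ:

```python
# Assumed behavior of the prompt conversion helpers (not the actual implementation).
def convert_ui_prompts_to_pos_neg_prompts(ui_text: str) -> tuple[list[str], list[str] | None]:
    """Split one-prompt-per-line UI text into positive and negative prompt lists."""
    positives: list[str] = []
    negatives: list[str] = []
    has_negative = False
    for line in ui_text.split("\n"):
        line = line.strip()
        if not line:
            continue
        positive, _, negative = line.partition("[NEG]")
        positives.append(positive.strip())
        negatives.append(negative.strip())
        if negative.strip():
            has_negative = True
    # If no line used the '[NEG]' delimiter, report no negative prompts at all.
    return positives, (negatives if has_negative else None)


def convert_pos_neg_prompts_to_ui_prompts(positives: list[str], negatives: list[str] | None) -> str:
    """Inverse direction: join paired prompts back into one-prompt-per-line UI text."""
    if negatives is None:
        return "\n".join(positives)
    return "\n".join(f"{pos}[NEG]{neg}" if neg else pos for pos, neg in zip(positives, negatives))
```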
Additional changed file
@@ -9,7 +9,11 @@
     TextualInversionSDDataLoaderConfigGroup,
 )
 from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
-from invoke_training.ui.utils import get_typing_literal_options
+from invoke_training.ui.utils.prompts import (
+    convert_pos_neg_prompts_to_ui_prompts,
+    convert_ui_prompts_to_pos_neg_prompts,
+)
+from invoke_training.ui.utils.utils import get_typing_literal_options


class SdTextualInversionConfigGroup(UIConfigElement):
@@ -157,7 +161,11 @@ def __init__(self):
         gr.Markdown("## Validation")
         with gr.Group():
             self.validation_prompts = gr.Textbox(
-                label="Validation Prompts", info="Enter one validation prompt per line.", lines=5, interactive=True
+                label="Validation Prompts",
+                info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
+                "delimiter. For example: `positive prompt[NEG]negative prompt`. ",
+                lines=5,
+                interactive=True,
             )
             self.num_validation_images_per_prompt = gr.Number(
                 label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
@@ -184,7 +192,9 @@ def update_ui_components_with_config_data(
             self.mixed_precision: config.mixed_precision,
             self.gradient_checkpointing: config.gradient_checkpointing,
             self.min_snr_gamma: config.min_snr_gamma,
-            self.validation_prompts: "\n".join(config.validation_prompts),
+            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
+                config.validation_prompts, config.negative_validation_prompts
+            ),
             self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
         }
         update_dict.update(
@@ -222,9 +232,9 @@ def update_config_with_ui_component_data(
         new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)
         new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)

-        validation_prompts: list[str] = ui_data.pop(self.validation_prompts).split("\n")
-        validation_prompts = [x.strip() for x in validation_prompts if x.strip() != ""]
-        new_config.validation_prompts = validation_prompts
+        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))
+        new_config.validation_prompts = positive_prompts
+        new_config.negative_validation_prompts = negative_prompts

         new_config.data_loader = (
             self.textual_inversion_sd_data_loader_config_group.update_config_with_ui_component_data(
(remaining changed files not shown)
