feat: add sd3.5 medium and skip layer guidance support (#451)
* mmdit-x

* add support for sd3.5 medium

* add skip layer guidance support (mmdit only)

* ignore slg if slg_scale is zero (optimization)

* init out_skip once

* slg support for flux (experimental)

* warn if version doesn't support slg

* refactor slg cli args

* set default slg_scale to 0 (oops)

* format code

---------

Co-authored-by: leejet <[email protected]>
stduhpf and leejet authored Nov 23, 2024
1 parent ac54e00 commit 65fa646
Showing 9 changed files with 416 additions and 81 deletions.
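Note: skip layer guidance (SLG) runs one extra denoising pass per step with the listed transformer layers skipped, then pushes the prediction away from that degraded output. A minimal sketch of the blend in ggml style, assuming the float overload of ggml_scale; the helper name slg_blend and its arguments are illustrative, not part of this commit:

    // cond, uncond, skip_cond: denoiser outputs for the conditional,
    // unconditional, and layer-skipped conditional passes.
    // result = uncond + cfg_scale*(cond - uncond) + slg_scale*(cond - skip_cond)
    struct ggml_tensor* slg_blend(struct ggml_context* ctx,
                                  struct ggml_tensor* cond,
                                  struct ggml_tensor* uncond,
                                  struct ggml_tensor* skip_cond,
                                  float cfg_scale,
                                  float slg_scale) {
        struct ggml_tensor* out = ggml_add(ctx, uncond,
            ggml_scale(ctx, ggml_sub(ctx, cond, uncond), cfg_scale));
        if (slg_scale != 0.f) {  // matches the "ignore slg if slg_scale is zero" optimization
            out = ggml_add(ctx, out,
                ggml_scale(ctx, ggml_sub(ctx, cond, skip_cond), slg_scale));
        }
        return out;
    }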
17 changes: 11 additions & 6 deletions diffusion_model.hpp
@@ -17,7 +17,8 @@ struct DiffusionModel {
 std::vector<struct ggml_tensor*> controls = {},
 float control_strength = 0.f,
 struct ggml_tensor** output = NULL,
-struct ggml_context* output_ctx = NULL) = 0;
+struct ggml_context* output_ctx = NULL,
+std::vector<int> skip_layers = std::vector<int>()) = 0;
 virtual void alloc_params_buffer() = 0;
 virtual void free_params_buffer() = 0;
 virtual void free_compute_buffer() = 0;
@@ -70,7 +71,9 @@ struct UNetModel : public DiffusionModel {
 std::vector<struct ggml_tensor*> controls = {},
 float control_strength = 0.f,
 struct ggml_tensor** output = NULL,
-struct ggml_context* output_ctx = NULL) {
+struct ggml_context* output_ctx = NULL,
+std::vector<int> skip_layers = std::vector<int>()) {
+(void)skip_layers; // SLG doesn't work with UNet models
 return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
 }
 };
@@ -119,8 +122,9 @@ struct MMDiTModel : public DiffusionModel {
 std::vector<struct ggml_tensor*> controls = {},
 float control_strength = 0.f,
 struct ggml_tensor** output = NULL,
-struct ggml_context* output_ctx = NULL) {
-return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx);
+struct ggml_context* output_ctx = NULL,
+std::vector<int> skip_layers = std::vector<int>()) {
+return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
 }
 };

@@ -168,8 +172,9 @@ struct FluxModel : public DiffusionModel {
 std::vector<struct ggml_tensor*> controls = {},
 float control_strength = 0.f,
 struct ggml_tensor** output = NULL,
-struct ggml_context* output_ctx = NULL) {
-return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
+struct ggml_context* output_ctx = NULL,
+std::vector<int> skip_layers = std::vector<int>()) {
+return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
 }
 };

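One C++ note on the interface change above: default arguments on virtual functions bind to the static type of the expression they are called through, so the defaults in DiffusionModel and in each override should stay in sync, as they do here. A minimal standalone illustration of the pitfall (types are illustrative, not from the repository):

    #include <cstdio>

    struct Base {
        virtual void f(int x = 1) { printf("Base::f(%d)\n", x); }
        virtual ~Base() = default;
    };

    struct Derived : Base {
        void f(int x = 2) override { printf("Derived::f(%d)\n", x); }
    };

    int main() {
        Derived d;
        Base* b = &d;
        b->f();  // prints "Derived::f(1)": Derived's body, Base's default argument
        return 0;
    }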
83 changes: 82 additions & 1 deletion examples/cli/main.cpp
@@ -119,6 +119,11 @@ struct SDParams {
 bool canny_preprocess = false;
 bool color = false;
 int upscale_repeats = 1;
+
+std::vector<int> skip_layers = {7, 8, 9};
+float slg_scale = 0.;
+float skip_layer_start = 0.01;
+float skip_layer_end = 0.2;
 };
 
 void print_params(SDParams params) {
@@ -151,6 +156,7 @@ void print_params(SDParams params) {
 printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
 printf(" min_cfg: %.2f\n", params.min_cfg);
 printf(" cfg_scale: %.2f\n", params.cfg_scale);
+printf(" slg_scale: %.2f\n", params.slg_scale);
 printf(" guidance: %.2f\n", params.guidance);
 printf(" clip_skip: %d\n", params.clip_skip);
 printf(" width: %d\n", params.width);
@@ -197,6 +203,12 @@ void print_usage(int argc, const char* argv[]) {
 printf(" -p, --prompt [PROMPT] the prompt to render\n");
 printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
 printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
+printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
+printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
+printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
+printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n");
+printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
+printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
 printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
 printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n");
 printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n");
@@ -534,6 +546,61 @@ void parse_args(int argc, const char** argv, SDParams& params) {
 params.verbose = true;
 } else if (arg == "--color") {
 params.color = true;
+} else if (arg == "--slg-scale") {
+if (++i >= argc) {
+invalid_arg = true;
+break;
+}
+params.slg_scale = std::stof(argv[i]);
+} else if (arg == "--skip-layers") {
+if (++i >= argc) {
+invalid_arg = true;
+break;
+}
+if (argv[i][0] != '[') {
+invalid_arg = true;
+break;
+}
+std::string layers_str = argv[i];
+while (layers_str.back() != ']') {
+if (++i >= argc) {
+invalid_arg = true;
+break;
+}
+layers_str += " " + std::string(argv[i]);
+}
+layers_str = layers_str.substr(1, layers_str.size() - 2);
+
+std::regex regex("[, ]+");
+std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
+std::sregex_token_iterator end;
+std::vector<std::string> tokens(iter, end);
+std::vector<int> layers;
+for (const auto& token : tokens) {
+try {
+layers.push_back(std::stoi(token));
+} catch (const std::invalid_argument& e) {
+invalid_arg = true;
+break;
+}
+}
+params.skip_layers = layers;
+
+if (invalid_arg) {
+break;
+}
+} else if (arg == "--skip-layer-start") {
+if (++i >= argc) {
+invalid_arg = true;
+break;
+}
+params.skip_layer_start = std::stof(argv[i]);
+} else if (arg == "--skip-layer-end") {
+if (++i >= argc) {
+invalid_arg = true;
+break;
+}
+params.skip_layer_end = std::stof(argv[i]);
 } else {
 fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
 print_usage(argc, argv);
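Note that the bracketed list can span several argv entries (the shell splits on spaces), so the loop above re-joins tokens until it sees ']'; both --skip-layers [7,8,9] and --skip-layers [7, 8, 9] parse the same way. A self-contained sketch of just the tokenization step, with an illustrative input string:

    #include <cstdio>
    #include <regex>
    #include <string>
    #include <vector>

    int main() {
        std::string layers_str = "7, 8,9";  // contents between '[' and ']'
        std::regex regex("[, ]+");          // same separator pattern as the parser
        std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
        std::vector<int> layers;
        for (std::sregex_token_iterator end; iter != end; ++iter) {
            layers.push_back(std::stoi(*iter));
        }
        for (int l : layers) {
            printf("%d\n", l);  // prints 7, 8, 9 on separate lines
        }
        return 0;
    }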
@@ -624,6 +691,16 @@ std::string get_image_params(SDParams params, int64_t seed) {
 }
 parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
 parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
+if (params.slg_scale != 0 && params.skip_layers.size() != 0) {
+parameter_string += "SLG scale: " + std::to_string(params.slg_scale) + ", ";
+parameter_string += "Skip layers: [";
+for (const auto& layer : params.skip_layers) {
+parameter_string += std::to_string(layer) + ", ";
+}
+parameter_string += "], ";
+parameter_string += "Skip layer start: " + std::to_string(params.skip_layer_start) + ", ";
+parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", ";
+}
 parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
 parameter_string += "Seed: " + std::to_string(seed) + ", ";
 parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
@@ -840,7 +917,11 @@ int main(int argc, const char* argv[]) {
 params.control_strength,
 params.style_ratio,
 params.normalize_input,
-params.input_id_images_path.c_str());
+params.input_id_images_path.c_str(),
+params.skip_layers,
+params.slg_scale,
+params.skip_layer_start,
+params.skip_layer_end);
 } else {
 sd_image_t input_image = {(uint32_t)params.width,
 (uint32_t)params.height,
26 changes: 19 additions & 7 deletions flux.hpp
@@ -711,7 +711,8 @@ namespace Flux {
 struct ggml_tensor* timesteps,
 struct ggml_tensor* y,
 struct ggml_tensor* guidance,
-struct ggml_tensor* pe) {
+struct ggml_tensor* pe,
+std::vector<int> skip_layers = std::vector<int>()) {
 auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
 auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
 auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
@@ -733,6 +734,10 @@
 txt = txt_in->forward(ctx, txt);
 
 for (int i = 0; i < params.depth; i++) {
+if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
+continue;
+}
+
 auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
 
 auto img_txt = block->forward(ctx, img, txt, vec, pe);
@@ -742,6 +747,9 @@
 
 auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
 for (int i = 0; i < params.depth_single_blocks; i++) {
+if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
+continue;
+}
 auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
 
 txt_img = block->forward(ctx, txt_img, vec, pe);
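The two loops above share one skip_layers list across both block types: an index i < params.depth skips double_blocks.i, while an index depth + j skips single_blocks.j. A small illustration of the mapping (the depth values are assumptions about typical Flux configurations, not read from this diff):

    #include <cstdio>

    int main() {
        int depth               = 19;  // assumed number of double stream blocks
        int depth_single_blocks = 38;  // assumed number of single stream blocks
        int skip_layers[]       = {7, 8, 9, 25};
        for (int layer : skip_layers) {
            if (layer < depth) {
                printf("skipping double_blocks.%d\n", layer);
            } else if (layer < depth + depth_single_blocks) {
                printf("skipping single_blocks.%d\n", layer - depth);
            }
        }
        return 0;
    }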
@@ -769,7 +777,8 @@
 struct ggml_tensor* context,
 struct ggml_tensor* y,
 struct ggml_tensor* guidance,
-struct ggml_tensor* pe) {
+struct ggml_tensor* pe,
+std::vector<int> skip_layers = std::vector<int>()) {
 // Forward pass of DiT.
 // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
 // timestep: (N,) tensor of diffusion timesteps
@@ -791,7 +800,7 @@
 // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
 auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
 
-auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size]
+auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
 
 // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
 out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
@@ -829,7 +838,8 @@
 struct ggml_tensor* timesteps,
 struct ggml_tensor* context,
 struct ggml_tensor* y,
-struct ggml_tensor* guidance) {
+struct ggml_tensor* guidance,
+std::vector<int> skip_layers = std::vector<int>()) {
 GGML_ASSERT(x->ne[3] == 1);
 struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
 
@@ -856,7 +866,8 @@
 context,
 y,
 guidance,
-pe);
+pe,
+skip_layers);
 
 ggml_build_forward_expand(gf, out);
 
@@ -870,14 +881,15 @@
 struct ggml_tensor* y,
 struct ggml_tensor* guidance,
 struct ggml_tensor** output = NULL,
-struct ggml_context* output_ctx = NULL) {
+struct ggml_context* output_ctx = NULL,
+std::vector<int> skip_layers = std::vector<int>()) {
 // x: [N, in_channels, h, w]
 // timesteps: [N, ]
 // context: [N, max_position, hidden_size]
 // y: [N, adm_in_channels] or [1, adm_in_channels]
 // guidance: [N, ]
 auto get_graph = [&]() -> struct ggml_cgraph* {
-return build_graph(x, timesteps, context, y, guidance);
+return build_graph(x, timesteps, context, y, guidance, skip_layers);
 };
 
 GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
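For orientation, a hedged sketch of how a caller might drive the new parameter during an SLG step (a fragment, not compilable on its own; the surrounding variable names are illustrative, and only the compute signature above comes from this diff):

    // Extra denoiser pass with the configured layers skipped; the result is
    // blended against the regular conditional output as sketched earlier.
    std::vector<int> skip_layers = {7, 8, 9};
    struct ggml_tensor* out_skip = NULL;
    flux.compute(n_threads, x, timesteps, context, y, guidance,
                 &out_skip, output_ctx, skip_layers);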
(diffs for the remaining six changed files were not loaded)
