From 7a02d2bca6cf29dfe8fdcd796fca0d33ef275426 Mon Sep 17 00:00:00 2001 From: Anna Likholat Date: Thu, 19 Dec 2024 08:41:00 +0100 Subject: [PATCH] [ImageGeneration] EulerAncestralDiscreteScheduler (#1407) ![image](https://github.com/user-attachments/assets/6b688510-50d9-4f32-b80d-cb8cfa0b4b79) CVS-156803 CVS-158965 --------- Co-authored-by: Ilya Lavrenov --- .../genai/image_generation/scheduler.hpp | 3 +- .../schedulers/euler_ancestral_discrete.cpp | 261 ++++++++++++++++++ .../schedulers/euler_ancestral_discrete.hpp | 61 ++++ .../image_generation/schedulers/scheduler.cpp | 3 + .../src/image_generation/schedulers/types.cpp | 2 + src/docs/SUPPORTED_MODELS.md | 1 + .../openvino_genai/py_openvino_genai.pyi | 5 +- src/python/py_image_generation_pipelines.cpp | 3 +- tools/llm_bench/llm_bench_utils/ov_utils.py | 2 +- 9 files changed, 337 insertions(+), 4 deletions(-) create mode 100644 src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.cpp create mode 100644 src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.hpp diff --git a/src/cpp/include/openvino/genai/image_generation/scheduler.hpp b/src/cpp/include/openvino/genai/image_generation/scheduler.hpp index 21c266aa50..25c5e07a2f 100644 --- a/src/cpp/include/openvino/genai/image_generation/scheduler.hpp +++ b/src/cpp/include/openvino/genai/image_generation/scheduler.hpp @@ -19,7 +19,8 @@ class OPENVINO_GENAI_EXPORTS Scheduler { DDIM, EULER_DISCRETE, FLOW_MATCH_EULER_DISCRETE, - PNDM + PNDM, + EULER_ANCESTRAL_DISCRETE }; static std::shared_ptr from_config(const std::filesystem::path& scheduler_config_path, diff --git a/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.cpp b/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.cpp new file mode 100644 index 0000000000..a63a073cfc --- /dev/null +++ b/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.cpp @@ -0,0 +1,261 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include + +#include "image_generation/schedulers/euler_ancestral_discrete.hpp" +#include "image_generation/numpy_utils.hpp" + +namespace ov { +namespace genai { + +EulerAncestralDiscreteScheduler::Config::Config(const std::filesystem::path& scheduler_config_path) { + std::ifstream file(scheduler_config_path); + OPENVINO_ASSERT(file.is_open(), "Failed to open ", scheduler_config_path); + + nlohmann::json data = nlohmann::json::parse(file); + using utils::read_json_param; + + read_json_param(data, "num_train_timesteps", num_train_timesteps); + read_json_param(data, "beta_start", beta_start); + read_json_param(data, "beta_end", beta_end); + read_json_param(data, "beta_schedule", beta_schedule); + read_json_param(data, "trained_betas", trained_betas); + read_json_param(data, "steps_offset", steps_offset); + read_json_param(data, "prediction_type", prediction_type); + read_json_param(data, "timestep_spacing", timestep_spacing); + read_json_param(data, "rescale_betas_zero_snr", rescale_betas_zero_snr); +} + +EulerAncestralDiscreteScheduler::EulerAncestralDiscreteScheduler(const std::filesystem::path& scheduler_config_path) + : EulerAncestralDiscreteScheduler(Config(scheduler_config_path)) { +} + +EulerAncestralDiscreteScheduler::EulerAncestralDiscreteScheduler(const Config& scheduler_config): m_config(scheduler_config) { + std::vector alphas, betas; + + using numpy_utils::linspace; + + if (!m_config.trained_betas.empty()) { + betas = m_config.trained_betas; + } else if (m_config.beta_schedule == BetaSchedule::LINEAR) { + betas = linspace(m_config.beta_start, m_config.beta_end, m_config.num_train_timesteps); + } else if (m_config.beta_schedule == BetaSchedule::SCALED_LINEAR) { + float start = std::sqrt(m_config.beta_start); + float end = std::sqrt(m_config.beta_end); + betas = linspace(start, end, m_config.num_train_timesteps); + std::for_each(betas.begin(), betas.end(), [](float& x) { + x *= x; + }); + // TODO: else if beta_schedule == "squaredcos_cap_v2" + } else { + OPENVINO_THROW( + "'beta_schedule' must be one of 'LINEAR' or 'SCALED_LINEAR'. Please, add support of other types"); + } + + if (m_config.rescale_betas_zero_snr) { + using numpy_utils::rescale_zero_terminal_snr; + rescale_zero_terminal_snr(betas); + } + + std::transform(betas.begin(), betas.end(), std::back_inserter(alphas), [](float b) { + return 1.0f - b; + }); + + for (size_t i = 1; i <= alphas.size(); ++i) { + float alpha_cumprod = + std::accumulate(std::begin(alphas), std::begin(alphas) + i, 1.0, std::multiplies{}); + m_alphas_cumprod.push_back(alpha_cumprod); + } + + if (m_config.rescale_betas_zero_snr) { + m_alphas_cumprod.back() = std::pow(2, -24); + } + + for (auto it = m_alphas_cumprod.rbegin(); it != m_alphas_cumprod.rend(); ++it) { + float sigma = std::pow(((1 - (*it)) / (*it)), 0.5); + m_sigmas.push_back(sigma); + } + m_sigmas.push_back(0); + + // setable values + auto linspaced = + linspace(0.0f, static_cast(m_config.num_train_timesteps - 1), m_config.num_train_timesteps, true); + for (auto it = linspaced.rbegin(); it != linspaced.rend(); ++it) { + m_timesteps.push_back(static_cast(std::round(*it))); + } + m_num_inference_steps = -1; + m_step_index = -1; + m_begin_index = -1; + m_is_scale_input_called = false; +} + +void EulerAncestralDiscreteScheduler::set_timesteps(size_t num_inference_steps, float strength) { + m_timesteps.clear(); + m_sigmas.clear(); + m_step_index = m_begin_index = -1; + m_num_inference_steps = num_inference_steps; + std::vector sigmas; + + switch (m_config.timestep_spacing) { + case TimestepSpacing::LINSPACE: { + using numpy_utils::linspace; + float end = static_cast(m_config.num_train_timesteps - 1); + auto linspaced = linspace(0.0f, end, num_inference_steps, true); + for (auto it = linspaced.rbegin(); it != linspaced.rend(); ++it) { + m_timesteps.push_back(static_cast(std::round(*it))); + } + break; + } + case TimestepSpacing::LEADING: { + size_t step_ratio = m_config.num_train_timesteps / m_num_inference_steps; + for (size_t i = num_inference_steps - 1; i != -1; --i) { + m_timesteps.push_back(i * step_ratio + m_config.steps_offset); + } + break; + } + case TimestepSpacing::TRAILING: { + float step_ratio = static_cast(m_config.num_train_timesteps) / static_cast(m_num_inference_steps); + for (float i = m_config.num_train_timesteps; i > 0; i -= step_ratio) { + m_timesteps.push_back(static_cast(std::round(i)) - 1); + } + break; + } + default: + OPENVINO_THROW("Unsupported value for 'timestep_spacing'"); + } + + for (const float& i : m_alphas_cumprod) { + float sigma = std::pow(((1 - i) / i), 0.5); + sigmas.push_back(sigma); + } + + using numpy_utils::interp; + std::vector x_data_points(sigmas.size()); + std::iota(x_data_points.begin(), x_data_points.end(), 0); + m_sigmas = interp(m_timesteps, x_data_points, sigmas); + m_sigmas.push_back(0.0f); + + // apply 'strength' used in image generation + // in diffusers, it's https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L650 + { + size_t init_timestep = std::min(num_inference_steps * strength, num_inference_steps); + size_t t_start = std::max(num_inference_steps - init_timestep, 0); + // keep original timesteps + m_schedule_timesteps = m_timesteps; + // while return patched ones by 'strength' parameter + m_timesteps = std::vector(m_timesteps.begin() + t_start, m_timesteps.end()); + m_begin_index = t_start; + } +} + +std::map EulerAncestralDiscreteScheduler::step(ov::Tensor noise_pred, ov::Tensor latents, size_t inference_step, std::shared_ptr generator) { + // noise_pred - model_output + // latents - sample + // inference_step + + size_t timestep = m_timesteps[inference_step]; + + if (m_step_index == -1) + m_step_index = m_begin_index; + + float sigma = m_sigmas[m_step_index]; + + float* model_output_data = noise_pred.data(); + float* sample_data = latents.data(); + + ov::Tensor pred_original_sample(noise_pred.get_element_type(), noise_pred.get_shape()); + float* pred_original_sample_data = pred_original_sample.data(); + + switch (m_config.prediction_type) { + case PredictionType::EPSILON: + for (size_t i = 0; i < noise_pred.get_size(); ++i) { + pred_original_sample_data[i] = sample_data[i] - sigma * model_output_data[i]; + } + break; + case PredictionType::V_PREDICTION: + for (size_t i = 0; i < noise_pred.get_size(); ++i) { + pred_original_sample_data[i] = model_output_data[i] * (-sigma / std::pow((std::pow(sigma, 2) + 1), 0.5)) + + (sample_data[i] / (std::pow(sigma, 2) + 1)); + } + break; + default: + OPENVINO_THROW("Unsupported value for 'PredictionType': must be one of `epsilon`, or `v_prediction`"); + } + + float sigma_from = m_sigmas[m_step_index]; + float sigma_to = m_sigmas[m_step_index + 1]; + float sigma_up = std::sqrt(std::pow(sigma_to, 2) * (std::pow(sigma_from, 2) - std::pow(sigma_to, 2)) / std::pow(sigma_from, 2)); + float sigma_down = std::sqrt(std::pow(sigma_to, 2) - std::pow(sigma_up, 2)); + float dt = sigma_down - sigma; + + ov::Tensor prev_sample = ov::Tensor(latents.get_element_type(), latents.get_shape()); + float* prev_sample_data = prev_sample.data(); + + ov::Tensor noise = generator->randn_tensor(noise_pred.get_shape()); + const float* noise_data = noise.data(); + + for (size_t i = 0; i < prev_sample.get_size(); ++i) { + float derivative = (sample_data[i] - pred_original_sample_data[i]) / sigma; + prev_sample_data[i] = (sample_data[i] + derivative * dt) + noise_data[i] * sigma_up; + } + + m_step_index++; + + return {{"latent", prev_sample}, {"denoised", pred_original_sample}}; +} + +size_t EulerAncestralDiscreteScheduler::_index_for_timestep(int64_t timestep) const{ + for (size_t i = 0; i < m_schedule_timesteps.size(); ++i) { + if (timestep == m_schedule_timesteps[i]) { + return i; + } + } + + OPENVINO_THROW("Failed to find index for timestep ", timestep); +} + +void EulerAncestralDiscreteScheduler::add_noise(ov::Tensor init_latent, ov::Tensor noise, int64_t latent_timestep) const { + size_t index_for_timestep = _index_for_timestep(latent_timestep); + const float sigma = m_sigmas[index_for_timestep]; + + float * init_latent_data = init_latent.data(); + const float * noise_data = noise.data(); + + for (size_t i = 0; i < init_latent.get_size(); ++i) { + init_latent_data[i] = init_latent_data[i] + sigma * noise_data[i]; + } +} + +std::vector EulerAncestralDiscreteScheduler::get_timesteps() const { + return m_timesteps; +} + +void EulerAncestralDiscreteScheduler::scale_model_input(ov::Tensor sample, size_t inference_step) { + if (m_step_index == -1) + m_step_index = m_begin_index; + + float sigma = m_sigmas[m_step_index]; + float* sample_data = sample.data(); + for (size_t i = 0; i < sample.get_size(); i++) { + sample_data[i] /= std::pow((std::pow(sigma, 2) + 1), 0.5); + } + m_is_scale_input_called = true; +} + +float EulerAncestralDiscreteScheduler::get_init_noise_sigma() const { + float max_sigma = *std::max_element(m_sigmas.begin(), m_sigmas.end()); + + if (m_config.timestep_spacing == TimestepSpacing::LINSPACE || + m_config.timestep_spacing == TimestepSpacing::TRAILING) { + return max_sigma; + } + + return std::sqrt(std::pow(max_sigma, 2) + 1); +} + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.hpp b/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.hpp new file mode 100644 index 0000000000..9d82c9a0a9 --- /dev/null +++ b/src/cpp/src/image_generation/schedulers/euler_ancestral_discrete.hpp @@ -0,0 +1,61 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "image_generation/schedulers/types.hpp" +#include "image_generation/schedulers/ischeduler.hpp" + +namespace ov { +namespace genai { + +class EulerAncestralDiscreteScheduler : public IScheduler { +public: + struct Config { + int32_t num_train_timesteps = 1000; + float beta_start = 0.0001f, beta_end = 0.02f; + BetaSchedule beta_schedule = BetaSchedule::LINEAR; + std::vector trained_betas = {}; + size_t steps_offset = 0; + PredictionType prediction_type = PredictionType::EPSILON; + TimestepSpacing timestep_spacing = TimestepSpacing::LEADING; + bool rescale_betas_zero_snr = false; + + Config() = default; + explicit Config(const std::filesystem::path& scheduler_config_path); + }; + + explicit EulerAncestralDiscreteScheduler(const std::filesystem::path& scheduler_config_path); + explicit EulerAncestralDiscreteScheduler(const Config& scheduler_config); + + void set_timesteps(size_t num_inference_steps, float strength) override; + + std::vector get_timesteps() const override; + + float get_init_noise_sigma() const override; + + void scale_model_input(ov::Tensor sample, size_t inference_step) override; + + std::map step(ov::Tensor noise_pred, ov::Tensor latents, size_t inference_step, std::shared_ptr generator) override; + + void add_noise(ov::Tensor init_latent, ov::Tensor noise, int64_t latent_timestep) const override; + +private: + Config m_config; + + std::vector m_alphas_cumprod, m_sigmas; + std::vector m_timesteps, m_schedule_timesteps; + size_t m_num_inference_steps; + + int m_step_index, m_begin_index; + bool m_is_scale_input_called; + + size_t _index_for_timestep(int64_t timestep) const; +}; + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/image_generation/schedulers/scheduler.cpp b/src/cpp/src/image_generation/schedulers/scheduler.cpp index f9cd098346..868f6f05cf 100644 --- a/src/cpp/src/image_generation/schedulers/scheduler.cpp +++ b/src/cpp/src/image_generation/schedulers/scheduler.cpp @@ -11,6 +11,7 @@ #include "image_generation/schedulers/euler_discrete.hpp" #include "image_generation/schedulers/flow_match_euler_discrete.hpp" #include "image_generation/schedulers/pndm.hpp" +#include "image_generation/schedulers/euler_ancestral_discrete.hpp" namespace ov { namespace genai { @@ -41,6 +42,8 @@ std::shared_ptr Scheduler::from_config(const std::filesystem::path& s scheduler = std::make_shared(scheduler_config_path); } else if (scheduler_type == Scheduler::Type::PNDM) { scheduler = std::make_shared(scheduler_config_path); + } else if (scheduler_type == Scheduler::Type::EULER_ANCESTRAL_DISCRETE) { + scheduler = std::make_shared(scheduler_config_path); } else { OPENVINO_THROW("Unsupported scheduler type '", scheduler_type, ". Please, manually create scheduler via supported one"); } diff --git a/src/cpp/src/image_generation/schedulers/types.cpp b/src/cpp/src/image_generation/schedulers/types.cpp index 2f7c6d3f25..5a9e5b6865 100644 --- a/src/cpp/src/image_generation/schedulers/types.cpp +++ b/src/cpp/src/image_generation/schedulers/types.cpp @@ -57,6 +57,8 @@ void read_json_param(const nlohmann::json& data, const std::string& name, Schedu param = Scheduler::FLOW_MATCH_EULER_DISCRETE; else if (scheduler_type_str == "PNDMScheduler") param = Scheduler::PNDM; + else if (scheduler_type_str == "EulerAncestralDiscreteScheduler") + param = Scheduler::EULER_ANCESTRAL_DISCRETE; else if (!scheduler_type_str.empty()) { OPENVINO_THROW("Unsupported value for 'scheduler' ", scheduler_type_str); } diff --git a/src/docs/SUPPORTED_MODELS.md b/src/docs/SUPPORTED_MODELS.md index 8c922ee644..9762874596 100644 --- a/src/docs/SUPPORTED_MODELS.md +++ b/src/docs/SUPPORTED_MODELS.md @@ -217,6 +217,7 @@ The pipeline can work with other similar topologies produced by `optimum-intel` diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 524ff0f921..bfcb869157 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -1343,15 +1343,18 @@ class Scheduler: FLOW_MATCH_EULER_DISCRETE PNDM + + EULER_ANCESTRAL_DISCRETE """ AUTO: typing.ClassVar[Scheduler.Type] # value = DDIM: typing.ClassVar[Scheduler.Type] # value = + EULER_ANCESTRAL_DISCRETE: typing.ClassVar[Scheduler.Type] # value = EULER_DISCRETE: typing.ClassVar[Scheduler.Type] # value = FLOW_MATCH_EULER_DISCRETE: typing.ClassVar[Scheduler.Type] # value = LCM: typing.ClassVar[Scheduler.Type] # value = LMS_DISCRETE: typing.ClassVar[Scheduler.Type] # value = PNDM: typing.ClassVar[Scheduler.Type] # value = - __members__: typing.ClassVar[dict[str, Scheduler.Type]] # value = {'AUTO': , 'LCM': , 'LMS_DISCRETE': , 'DDIM': , 'EULER_DISCRETE': , 'FLOW_MATCH_EULER_DISCRETE': , 'PNDM': } + __members__: typing.ClassVar[dict[str, Scheduler.Type]] # value = {'AUTO': , 'LCM': , 'LMS_DISCRETE': , 'DDIM': , 'EULER_DISCRETE': , 'FLOW_MATCH_EULER_DISCRETE': , 'PNDM': , 'EULER_ANCESTRAL_DISCRETE': } def __eq__(self, other: typing.Any) -> bool: ... def __getstate__(self) -> int: diff --git a/src/python/py_image_generation_pipelines.cpp b/src/python/py_image_generation_pipelines.cpp index f5347c279d..311f3f3760 100644 --- a/src/python/py_image_generation_pipelines.cpp +++ b/src/python/py_image_generation_pipelines.cpp @@ -198,7 +198,8 @@ void init_image_generation_pipelines(py::module_& m) { .value("DDIM", ov::genai::Scheduler::Type::DDIM) .value("EULER_DISCRETE", ov::genai::Scheduler::Type::EULER_DISCRETE) .value("FLOW_MATCH_EULER_DISCRETE", ov::genai::Scheduler::Type::FLOW_MATCH_EULER_DISCRETE) - .value("PNDM", ov::genai::Scheduler::Type::PNDM); + .value("PNDM", ov::genai::Scheduler::Type::PNDM) + .value("EULER_ANCESTRAL_DISCRETE", ov::genai::Scheduler::Type::EULER_ANCESTRAL_DISCRETE); image_generation_scheduler.def_static("from_config", &ov::genai::Scheduler::from_config, py::arg("scheduler_config_path"), diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py index c3df84925b..316c9d0b89 100644 --- a/tools/llm_bench/llm_bench_utils/ov_utils.py +++ b/tools/llm_bench/llm_bench_utils/ov_utils.py @@ -421,7 +421,7 @@ def get_vae_decoder_step_count(self): scheduler_type = data.get("scheduler", ["", ""])[1] if (scheduler_type not in ["LCMScheduler", "DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler", "EulerDiscreteScheduler", - "FlowMatchEulerDiscreteScheduler"]): + "FlowMatchEulerDiscreteScheduler", "EulerAncestralDiscreteScheduler"]): scheduler = openvino_genai.Scheduler.from_config(model_path / "scheduler/scheduler_config.json", openvino_genai.Scheduler.Type.DDIM) log.warning(f'Type of scheduler {scheduler_type} is unsupported. Please, be aware that it will be replaced to DDIMScheduler')