[Image generation] Added num_steps to callback (#1372)
With image-to-image and inpainting, the user-passed `num_inference_steps` is scaled based on the `strength` parameter, so we need to report the actual number of steps within the `callback`.

CC @RyanMetcalfeInt8
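For context, a minimal sketch of the scaling behaviour described above. The formula is an assumption for illustration only (the exact computation lives in the pipeline's scheduler setup), and all values are placeholders.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

// Illustrative approximation only: with image-to-image / inpainting, a
// 'strength' in (0, 1] keeps roughly that fraction of the requested schedule.
int main() {
    std::size_t num_inference_steps = 50; // what the user requested
    float strength = 0.6f;                // image-to-image strength

    std::size_t actual_steps = std::min<std::size_t>(
        static_cast<std::size_t>(num_inference_steps * strength), num_inference_steps);

    // 'actual_steps' (30 here) is what the callback now receives as 'num_steps',
    // instead of the originally requested 50.
    std::cout << "num_steps reported to callback: " << actual_steps << "\n";
    return 0;
}
```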
ilya-lavrenov authored Dec 13, 2024
1 parent b955ea6 commit d17f716
Showing 8 changed files with 31 additions and 45 deletions.
6 changes: 3 additions & 3 deletions samples/cpp/image_generation/README.md
@@ -52,9 +52,9 @@ Please find the template of the callback usage below.
 ```cpp
 ov::genai::Text2ImagePipeline pipe(models_path, device);
 
-auto callback = [&](size_t step, ov::Tensor& intermediate_res) -> bool {
-    std::cout << "Image generation step: " << step << std::endl;
-    ov::Tensor img = pipe.decode(intermediate_res); // get intermediate image tensor
+auto callback = [&](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
+    std::cout << "Image generation step: " << step << " / " << num_steps << std::endl;
+    ov::Tensor img = pipe.decode(latent); // get intermediate image tensor
     if (your_condition) // return true if you want to interrupt image generation
         return true;
     return false;
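As a hedged usage sketch (not part of this diff), the callback defined in the template above would typically be handed to `generate()` through the `callback` property; the prompt string is a placeholder:

```cpp
// Sketch: register the three-argument callback when generating an image.
// "a prompt" is a placeholder value.
ov::Tensor image = pipe.generate("a prompt", ov::genai::callback(callback));
```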
6 changes: 3 additions & 3 deletions samples/python/image_generation/README.md
@@ -52,9 +52,9 @@ Please find the template of the callback usage below.
 ```python
 pipe = openvino_genai.Text2ImagePipeline(model_dir, device)
 
-def callback(step, intermediate_res):
-    print("Image generation step: ", step)
-    image_tensor = pipe.decode(intermediate_res) # get intermediate image tensor
+def callback(step, num_steps, latent):
+    print(f"Image generation step: {step} / {num_steps}")
+    image_tensor = pipe.decode(latent) # get intermediate image tensor
     if your_condition: # return True if you want to interrupt image generation
         return True
     return False
@@ -216,11 +216,11 @@ static constexpr ov::Property<int> max_sequence_length{"max_sequence_length"};
 
 /**
  * User callback for image generation pipelines, which is called within a pipeline with the following arguments:
- * - Total number of inference steps. Note, that in case of 'strength' parameter, the number of inference steps is reduced linearly
  * - Current inference step
+ * - Total number of inference steps. Note, that in case of 'strength' parameter, the number of inference steps is reduced linearly
  * - Tensor representing current latent. Such latent can be converted to human-readable representation via image generation pipeline 'decode()' method
  */
-static constexpr ov::Property<std::function<bool(size_t, ov::Tensor&)>> callback{"callback"};
+static constexpr ov::Property<std::function<bool(size_t, size_t, ov::Tensor&)>> callback{"callback"};
 
 /**
  * Function to pass 'ImageGenerationConfig' as property to 'generate()' call.
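To show how the pipelines below consume this property, here is a minimal sketch of storing and retrieving the new three-argument callback through an `ov::AnyMap`. The lambda body, the helper function, and the include path are assumptions for illustration; this is not code from the commit.

```cpp
#include <functional>
// Include path assumed; any OpenVINO GenAI image generation header declaring
// ov::genai::callback would do.
#include "openvino/genai/image_generation/text2image_pipeline.hpp"

// Sketch: the user-side property and the pipeline-side lookup must agree on
// the exact std::function<bool(size_t, size_t, ov::Tensor&)> type.
void callback_property_roundtrip() {
    ov::AnyMap properties;
    properties[ov::genai::callback.name()] =
        std::function<bool(size_t, size_t, ov::Tensor&)>(
            [](size_t step, size_t num_steps, ov::Tensor& latent) {
                return false; // never interrupt generation
            });

    // Pipeline side (as in the hunks below): find the property and cast it back.
    auto it = properties.find(ov::genai::callback.name());
    if (it != properties.end()) {
        auto cb = it->second.as<std::function<bool(size_t, size_t, ov::Tensor&)>>();
        ov::Tensor latent; // placeholder latent
        cb(/*step=*/0, /*num_steps=*/30, latent);
    }
}
```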
22 changes: 7 additions & 15 deletions src/cpp/src/image_generation/flux_pipeline.hpp
@@ -326,9 +326,11 @@ class FluxPipeline : public DiffusionPipeline {
             m_custom_generation_config.strength = 1.0f;
         }
 
-        if (!initial_image) {
-            // in case of typical text to image generation, we need to ignore 'strength'
-            m_custom_generation_config.strength = 1.0f;
+        // Use callback if defined
+        std::function<bool(size_t, size_t, ov::Tensor&)> callback = nullptr;
+        auto callback_iter = properties.find(ov::genai::callback.name());
+        if (callback_iter != properties.end()) {
+            callback = callback_iter->second.as<std::function<bool(size_t, size_t, ov::Tensor&)>>();
         }
 
         const size_t vae_scale_factor = m_vae->get_vae_scale_factor();
@@ -355,14 +357,6 @@
         m_scheduler->set_timesteps_with_sigma(sigmas, mu);
         std::vector<float> timesteps = m_scheduler->get_float_timesteps();
 
-        // Use callback if defined
-        std::function<bool(size_t, ov::Tensor&)> callback;
-        auto callback_iter = properties.find(ov::genai::callback.name());
-        bool do_callback = callback_iter != properties.end();
-        if (do_callback) {
-            callback = callback_iter->second.as<std::function<bool(size_t, ov::Tensor&)>>();
-        }
-
         // 6. Denoising loop
         ov::Tensor timestep(ov::element::f32, {1});
         float* timestep_data = timestep.data<float>();
@@ -375,10 +369,8 @@
             auto scheduler_step_result = m_scheduler->step(noise_pred_tensor, latents, inference_step, m_custom_generation_config.generator);
             latents = scheduler_step_result["latent"];
 
-            if (do_callback) {
-                if (callback(inference_step, latents)) {
-                    return ov::Tensor(ov::element::u8, {});
-                }
+            if (callback && callback(inference_step, timesteps.size(), latents)) {
+                return ov::Tensor(ov::element::u8, {});
             }
         }
 
21 changes: 9 additions & 12 deletions src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp
@@ -431,6 +431,13 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
             generation_config.strength = 1.0f;
         }
 
+        // Use callback if defined
+        std::function<bool(size_t, size_t, ov::Tensor&)> callback = nullptr;
+        auto callback_iter = properties.find(ov::genai::callback.name());
+        if (callback_iter != properties.end()) {
+            callback = callback_iter->second.as<std::function<bool(size_t, size_t, ov::Tensor&)>>();
+        }
+
         const auto& transformer_config = m_transformer->get_config();
         const size_t vae_scale_factor = m_vae->get_vae_scale_factor();
         const size_t batch_size_multiplier = do_classifier_free_guidance(generation_config.guidance_scale)
@@ -467,14 +474,6 @@
         // 6. Denoising loop
         ov::Tensor noisy_residual_tensor(ov::element::f32, {});
 
-        // Use callback if defined
-        std::function<bool(size_t, ov::Tensor&)> callback;
-        auto callback_iter = properties.find(ov::genai::callback.name());
-        bool do_callback = callback_iter != properties.end();
-        if (do_callback) {
-            callback = callback_iter->second.as<std::function<bool(size_t, ov::Tensor&)>>();
-        }
-
         for (size_t inference_step = 0; inference_step < timesteps.size(); ++inference_step) {
             // concat the same latent twice along a batch dimension in case of CFG
             if (batch_size_multiplier > 1) {
@@ -510,10 +509,8 @@
             auto scheduler_step_result = m_scheduler->step(noisy_residual_tensor, latent, inference_step, generation_config.generator);
             latent = scheduler_step_result["latent"];
 
-            if (do_callback) {
-                if (callback(inference_step, latent)) {
-                    return ov::Tensor(ov::element::u8, {});
-                }
+            if (callback && callback(inference_step, timesteps.size(), latent)) {
+                return ov::Tensor(ov::element::u8, {});
             }
         }
 
13 changes: 5 additions & 8 deletions src/cpp/src/image_generation/stable_diffusion_pipeline.hpp
@@ -306,11 +306,10 @@ class StableDiffusionPipeline : public DiffusionPipeline {
         }
 
         // use callback if defined
-        std::function<bool(size_t, ov::Tensor&)> callback;
+        std::function<bool(size_t, size_t, ov::Tensor&)> callback = nullptr;
         auto callback_iter = properties.find(ov::genai::callback.name());
-        bool do_callback = callback_iter != properties.end();
-        if (do_callback) {
-            callback = callback_iter->second.as<std::function<bool(size_t, ov::Tensor&)>>();
+        if (callback_iter != properties.end()) {
+            callback = callback_iter->second.as<std::function<bool(size_t, size_t, ov::Tensor&)>>();
         }
 
         // Stable Diffusion pipeline
@@ -400,10 +399,8 @@
             const auto it = scheduler_step_result.find("denoised");
             denoised = it != scheduler_step_result.end() ? it->second : latent;
 
-            if (do_callback) {
-                if (callback(inference_step, denoised)) {
-                    return ov::Tensor(ov::element::u8, {});
-                }
+            if (callback && callback(inference_step, timesteps.size(), denoised)) {
+                return ov::Tensor(ov::element::u8, {});
             }
         }
 
2 changes: 1 addition & 1 deletion src/python/py_utils.cpp
@@ -280,7 +280,7 @@ ov::Any py_object_to_any(const py::object& py_obj, std::string property_name) {
     } else if (py::isinstance<ov::genai::Generator>(py_obj)) {
         return py::cast<std::shared_ptr<ov::genai::Generator>>(py_obj);
     } else if (py::isinstance<py::function>(py_obj) && property_name == "callback") {
-        return py::cast<std::function<bool(size_t, ov::Tensor&)>>(py_obj);
+        return py::cast<std::function<bool(size_t, size_t, ov::Tensor&)>>(py_obj);
     } else if ((py::isinstance<py::function>(py_obj) || py::isinstance<ov::genai::StreamerBase>(py_obj) || py::isinstance<std::monostate>(py_obj)) && property_name == "streamer") {
         auto streamer = py::cast<ov::genai::pybind::utils::PyBindStreamerVariant>(py_obj);
         return ov::genai::streamer(pystreamer_to_streamer(streamer)).second;
2 changes: 1 addition & 1 deletion tools/llm_bench/llm_bench_utils/ov_utils.py
@@ -366,7 +366,7 @@ def __init__(self) -> types.NoneType:
         self.start_time = time.perf_counter()
         self.duration = -1
 
-    def __call__(self, step, latents):
+    def __call__(self, step, num_steps, latents):
        self.iteration_time.append(time.perf_counter() - self.start_time)
        self.start_time = time.perf_counter()
        return False
