[Image generation] Rework max_sequence_length handling in T5 (#1211)
`max_sequence_length` is not available as a parameter in T5's config.json;
it is set as a parameter of the `generate` method instead:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py#L696

So we need to rely on a per-pipeline default value.
ilya-lavrenov authored Nov 14, 2024
1 parent 860c39a commit f99751c
Showing 7 changed files with 43 additions and 58 deletions.
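Before the per-file diffs, a minimal usage sketch of the reworked API. The model directory, device, and chosen length are illustrative, and the include path is assumed from the repository layout; the flow mirrors the TODO comment added in flux_pipeline.hpp below:

#include "openvino/genai/image_generation/text2image_pipeline.hpp"

int main() {
    // Hypothetical model directory with an exported Flux pipeline
    ov::genai::Text2ImagePipeline pipe("/path/to/flux", "CPU");

    // max_sequence_length is now a field of ImageGenerationConfig with a
    // per-pipeline default (512 for Flux, 256 for SD3); override it before generate()
    ov::genai::ImageGenerationConfig config = pipe.get_generation_config();
    config.max_sequence_length = 256;  // Flux's check_inputs asserts this is < 512
    pipe.set_generation_config(config);

    ov::Tensor image = pipe.generate("a photo of a cat");
    return 0;
}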
@@ -58,7 +58,7 @@ struct OPENVINO_GENAI_EXPORTS ImageGenerationConfig {
     size_t num_inference_steps = 50;
 
     // the following value used by t5_encoder_model (Flux, SD3 pipelines)
-    size_t max_sequence_length = -1;
+    int max_sequence_length = -1;
 
     // used by some image to image pipelines to balance between noise and initial image
     // higher 'stregth' value means more noise is added to initial latent image
@@ -20,12 +20,6 @@ namespace genai {
 
 class OPENVINO_GENAI_EXPORTS T5EncoderModel {
 public:
-    struct OPENVINO_GENAI_EXPORTS Config {
-        size_t max_sequence_length = 512;
-
-        explicit Config(const std::filesystem::path& config_path);
-    };
-
     explicit T5EncoderModel(const std::filesystem::path& root_dir);
 
     T5EncoderModel(const std::filesystem::path& root_dir,
@@ -41,11 +35,7 @@ class OPENVINO_GENAI_EXPORTS T5EncoderModel {
 
     T5EncoderModel(const T5EncoderModel&);
 
-    const Config& get_config() const;
-
-    void set_max_sequence_length(size_t max_sequence_length);
-
-    T5EncoderModel& reshape(int batch_size);
+    T5EncoderModel& reshape(int batch_size, int max_sequence_length);
 
     T5EncoderModel& compile(const std::string& device, const ov::AnyMap& properties = {});
 
@@ -56,12 +46,11 @@ class OPENVINO_GENAI_EXPORTS T5EncoderModel {
         return compile(device, ov::AnyMap{std::forward<Properties>(properties)...});
     }
 
-    ov::Tensor infer(const std::string& pos_prompt);
+    ov::Tensor infer(const std::string& pos_prompt, int max_sequence_length);
 
     ov::Tensor get_output_tensor(const size_t idx);
 
 private:
-    Config m_config;
     AdapterController m_adapter_controller;
     ov::InferRequest m_request;
     std::shared_ptr<ov::Model> m_model;
31 changes: 17 additions & 14 deletions src/cpp/src/image_generation/flux_pipeline.hpp
@@ -218,8 +218,16 @@ class FluxPipeline : public DiffusionPipeline {
         check_image_size(height, width);
 
         m_clip_text_encoder->reshape(1);
-        m_t5_text_encoder->reshape(1);
-        m_transformer->reshape(num_images_per_prompt, height, width, m_t5_text_encoder->get_config().max_sequence_length);
+
+        // TODO: max_sequence_length cannot be specified easily outside, only via:
+        // Text2ImagePipeline pipe("/path");
+        // ImageGenerationConfig default_config = pipe.get_generation_config();
+        // default_config.max_sequence_length = 30;
+        // pipe.set_generation_config(default_config);
+        // pipe.reshape(1, 512, 512, default_config.guidance_scale);
+        m_t5_text_encoder->reshape(1, m_generation_config.max_sequence_length);
+        m_transformer->reshape(num_images_per_prompt, height, width, m_generation_config.max_sequence_length);
 
         m_vae->reshape(num_images_per_prompt, height, width);
     }
@@ -255,10 +263,8 @@ class FluxPipeline : public DiffusionPipeline {
         using namespace numpy_utils;
         ImageGenerationConfig generation_config = m_generation_config;
         generation_config.update_generation_config(properties);
-        m_t5_text_encoder->set_max_sequence_length(generation_config.max_sequence_length);
 
         const size_t vae_scale_factor = m_vae->get_vae_scale_factor();
-
         const auto& transformer_config = m_transformer->get_config();
 
         if (generation_config.height < 0)
@@ -275,7 +281,7 @@ class FluxPipeline : public DiffusionPipeline {
         m_clip_text_encoder->infer(positive_prompt, "", false);
         ov::Tensor pooled_prompt_embeds_out = m_clip_text_encoder->get_output_tensor(1);
 
-        ov::Tensor prompt_embeds_out = m_t5_text_encoder->infer(positive_prompt);
+        ov::Tensor prompt_embeds_out = m_t5_text_encoder->infer(positive_prompt, generation_config.max_sequence_length);
 
         ov::Tensor pooled_prompt_embeds, prompt_embeds;
         if (generation_config.num_images_per_prompt == 1) {
@@ -344,6 +350,7 @@ class FluxPipeline : public DiffusionPipeline {
         if (class_name == "FluxPipeline") {
             m_generation_config.guidance_scale = 3.5f;
             m_generation_config.num_inference_steps = 28;
+            m_generation_config.max_sequence_length = 512;
         } else {
             OPENVINO_THROW("Unsupported class_name '", class_name, "'. Please, contact OpenVINO GenAI developers");
         }
@@ -361,16 +368,12 @@ class FluxPipeline : public DiffusionPipeline {
     void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override {
         check_image_size(generation_config.width, generation_config.height);
 
-        const char* const pipeline_name = "Flux";
-
-        OPENVINO_ASSERT(generation_config.negative_prompt == std::nullopt,
-                        "Negative prompt is not used by ", pipeline_name);
-        OPENVINO_ASSERT(generation_config.negative_prompt_2 == std::nullopt,
-                        "Negative prompt 2 is not used by ", pipeline_name);
-        OPENVINO_ASSERT(generation_config.negative_prompt_3 == std::nullopt,
-                        "Negative prompt 3 is not used by ", pipeline_name);
-        OPENVINO_ASSERT(generation_config.prompt_3 == std::nullopt, "Prompt 3 is not used by ", pipeline_name);
+        OPENVINO_ASSERT(generation_config.max_sequence_length < 512, "T5's 'max_sequence_length' must be less than 512");
 
+        OPENVINO_ASSERT(generation_config.negative_prompt == std::nullopt, "Negative prompt is not used by FluxPipeline");
+        OPENVINO_ASSERT(generation_config.negative_prompt_2 == std::nullopt, "Negative prompt 2 is not used by FluxPipeline");
+        OPENVINO_ASSERT(generation_config.negative_prompt_3 == std::nullopt, "Negative prompt 3 is not used by FluxPipeline");
+        OPENVINO_ASSERT(generation_config.prompt_3 == std::nullopt, "Prompt 3 is not used by FluxPipeline");
     }
 
     std::shared_ptr<FluxTransformer2DModel> m_transformer;
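Since generate() folds caller properties into the configuration via update_generation_config(properties) (second hunk above), a per-call override should also be possible without touching the stored config. A hedged sketch — the "max_sequence_length" property key is an assumption here, not something this diff confirms:

// Assumes update_generation_config() recognizes a "max_sequence_length" key;
// if it does not, fall back to get/set_generation_config as in the TODO above.
ov::Tensor image = pipe.generate("a photo of a cat",
                                 ov::AnyMap{{"max_sequence_length", 128}});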
37 changes: 14 additions & 23 deletions src/cpp/src/image_generation/models/t5_encoder_model.cpp
@@ -14,14 +14,8 @@ namespace genai {
 
 std::filesystem::path get_tokenizer_path_by_text_encoder(const std::filesystem::path& text_encoder_path);
 
-T5EncoderModel::Config::Config(const std::filesystem::path& config_path) {
-    std::ifstream file(config_path);
-    OPENVINO_ASSERT(file.is_open(), "Failed to open ", config_path);
-}
-
 T5EncoderModel::T5EncoderModel(const std::filesystem::path& root_dir) :
-    m_tokenizer(get_tokenizer_path_by_text_encoder(root_dir)),
-    m_config(root_dir / "config.json") {
+    m_tokenizer(get_tokenizer_path_by_text_encoder(root_dir)) {
     ov::Core core = utils::singleton_core();
     m_model = core.read_model((root_dir / "openvino_model.xml").string());
 }
@@ -35,21 +29,12 @@ T5EncoderModel::T5EncoderModel(const std::filesystem::path& root_dir,
 
 T5EncoderModel::T5EncoderModel(const T5EncoderModel&) = default;
 
-const T5EncoderModel::Config& T5EncoderModel::get_config() const {
-    return m_config;
-}
-
-void T5EncoderModel::set_max_sequence_length(size_t max_sequence_length) {
-    if (max_sequence_length != -1)
-        m_config.max_sequence_length = max_sequence_length;
-}
-
-T5EncoderModel& T5EncoderModel::reshape(int batch_size) {
+T5EncoderModel& T5EncoderModel::reshape(int batch_size, int max_sequence_length) {
     OPENVINO_ASSERT(m_model, "Model has been already compiled. Cannot reshape already compiled model");
 
     ov::PartialShape input_shape = m_model->input(0).get_partial_shape();
     input_shape[0] = batch_size;
-    input_shape[1] = m_config.max_sequence_length;
+    input_shape[1] = max_sequence_length;
     std::map<size_t, ov::PartialShape> idx_to_shape{{0, input_shape}};
     m_model->reshape(idx_to_shape);
@@ -68,24 +53,30 @@ T5EncoderModel& T5EncoderModel::compile(const std::string& device, const ov::AnyMap& properties) {
     return *this;
 }
 
-ov::Tensor T5EncoderModel::infer(const std::string& pos_prompt) {
+ov::Tensor T5EncoderModel::infer(const std::string& pos_prompt, int max_sequence_length) {
     OPENVINO_ASSERT(m_request, "T5 encoder model must be compiled first. Cannot infer non-compiled model");
 
     const int32_t pad_token_id = m_tokenizer.get_pad_token_id();
 
     auto perform_tokenization = [&](const std::string& prompt, ov::Tensor input_ids) {
         ov::Tensor input_ids_token = m_tokenizer.encode(prompt).input_ids;
-        size_t min_size = std::min(input_ids.get_size(), input_ids_token.get_size());
+        size_t min_length = std::min(input_ids.get_size(), input_ids_token.get_size());
 
         std::fill_n(input_ids.data<int32_t>(), input_ids.get_size(), pad_token_id);
-        std::copy_n(input_ids_token.data<std::int64_t>(), min_size, input_ids.data<std::int32_t>());
+        std::copy_n(input_ids_token.data<std::int64_t>(), min_length, input_ids.data<std::int32_t>());
     };
 
     ov::Tensor input_ids = m_request.get_input_tensor();
-
-    // reshape in case of dynamic model
-    if (input_ids.get_shape()[0] == 0 || input_ids.get_shape()[1] == 0) {
-        input_ids.set_shape({1, m_config.max_sequence_length});
+    ov::Shape input_ids_shape = input_ids.get_shape();
+
+    OPENVINO_ASSERT(input_ids_shape[1] == 0 || max_sequence_length == input_ids_shape[1],
+        "In case of T5EncoderModel was reshaped before, reshape's max_sequence_length ", input_ids_shape[1], " must be equal to ",
+        "infer's max_sequence_length ", max_sequence_length);
+
+    // reshape in case of dynamic model
+    if (input_ids_shape[0] == 0 || input_ids_shape[1] == 0) {
+        input_ids.set_shape({1, static_cast<size_t>(max_sequence_length)});
     }
 
     perform_tokenization(pos_prompt, input_ids);
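To make the new reshape()/infer() contract concrete, a short sketch; the model directory and device are hypothetical:

// Static path: the length given to reshape() must be repeated at infer() time,
// otherwise the new OPENVINO_ASSERT above fires.
ov::genai::T5EncoderModel t5("/path/to/text_encoder_3");
t5.reshape(1, 256);
t5.compile("CPU");
ov::Tensor embeds = t5.infer("a photo of a cat", 256);  // 256 matches reshape()

// Dynamic path: skip reshape(); infer() itself sets the input tensor
// shape to {1, max_sequence_length} before tokenization.
ov::genai::T5EncoderModel t5_dyn("/path/to/text_encoder_3");
t5_dyn.compile("CPU");
ov::Tensor embeds_dyn = t5_dyn.infer("a photo of a cat", 256);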
7 changes: 4 additions & 3 deletions src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp
@@ -184,11 +184,11 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
                              const CLIPTextModelWithProjection& clip_text_model_1,
                              const CLIPTextModelWithProjection& clip_text_model_2,
                              const SD3Transformer2DModel& transformer,
-                             const AutoencoderKL& vae_decoder)
+                             const AutoencoderKL& vae)
         : DiffusionPipeline(pipeline_type),
           m_clip_text_encoder_1(std::make_shared<CLIPTextModelWithProjection>(clip_text_model_1)),
           m_clip_text_encoder_2(std::make_shared<CLIPTextModelWithProjection>(clip_text_model_2)),
-          m_vae(std::make_shared<AutoencoderKL>(vae_decoder)),
+          m_vae(std::make_shared<AutoencoderKL>(vae)),
           m_transformer(std::make_shared<SD3Transformer2DModel>(transformer)) {
         initialize_generation_config("StableDiffusion3Pipeline");
     }
@@ -594,6 +594,7 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
         if (class_name == "StableDiffusion3Pipeline") {
             m_generation_config.guidance_scale = 7.0f;
             m_generation_config.num_inference_steps = 28;
+            m_generation_config.max_sequence_length = 256;
         } else {
             OPENVINO_THROW("Unsupported class_name '", class_name, "'. Please, contact OpenVINO GenAI developers");
         }
@@ -616,8 +617,8 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
         check_image_size(generation_config.width, generation_config.height);
 
         const bool is_classifier_free_guidance = do_classifier_free_guidance(generation_config.guidance_scale);
-        const char* const pipeline_name = "Stable Diffusion 3";
 
+        OPENVINO_ASSERT(generation_config.max_sequence_length < 512, "T5's 'max_sequence_length' must be less than 512");
         OPENVINO_ASSERT(
             generation_config.prompt_3 == std::nullopt || generation_config.negative_prompt_3 == std::nullopt,
             "T5Encoder is not currently supported, 'prompt_3' and 'negative_prompt_3' can't be used. Please, add "
4 changes: 2 additions & 2 deletions src/cpp/src/image_generation/text2image_pipeline.cpp
@@ -115,8 +115,8 @@ Text2ImagePipeline Text2ImagePipeline::flux(
     const CLIPTextModel& clip_text_model,
     const T5EncoderModel t5_encoder_model,
     const FluxTransformer2DModel& transformer,
-    const AutoencoderKL& vae_decoder){
-    auto impl = std::make_shared<FluxPipeline>(PipelineType::TEXT_2_IMAGE, clip_text_model, t5_encoder_model, transformer, vae_decoder);
+    const AutoencoderKL& vae){
+    auto impl = std::make_shared<FluxPipeline>(PipelineType::TEXT_2_IMAGE, clip_text_model, t5_encoder_model, transformer, vae);
     assert(scheduler != nullptr);
     impl->set_scheduler(scheduler);
     return Text2ImagePipeline(impl);
5 changes: 3 additions & 2 deletions src/cpp/src/utils.cpp
@@ -235,9 +235,10 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token
 }
 
 void slice_matmul_statefull_model(std::shared_ptr<ov::Model> model) {
-    ov::Node* matmul = nullptr;
     auto last_node = model->output(0).get_node()->input_value(0).get_node();
-    if (matmul = dynamic_cast<ov::op::v0::MatMul*>(last_node)) {
+    ov::Node* matmul = dynamic_cast<ov::op::v0::MatMul*>(last_node);
+    if (matmul) {
         // we have found matmul, do nothing
     } else if(auto add = dynamic_cast<ov::op::v1::Add*>(last_node)) {
         matmul = dynamic_cast<ov::op::v0::MatMul*>(add->input_value(0).get_node());
     } else if (auto transpose = dynamic_cast<ov::op::v1::Transpose*>(last_node)) {
