Fix decoder inputs for static pipeline
eshiryae committed Jan 3, 2025
1 parent 34dc469 commit dfba871
Showing 2 changed files with 19 additions and 9 deletions.
27 changes: 18 additions & 9 deletions src/cpp/src/whisper_pipeline_static.cpp
@@ -136,9 +136,9 @@ void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
// attention_mask [1, 1, 1, 0]
auto input_ids_data = input_ids_tensor.data<int32_t>();
std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
- std::fill(input_ids_data + init_ids.size(),
-           input_ids_data + input_ids_tensor.get_size(),
-           static_cast<int32_t>(pad_token));
+ // std::fill(input_ids_data + init_ids.size(),
+ //           input_ids_data + input_ids_tensor.get_size(),
+ //           static_cast<int32_t>(pad_token));

auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
std::fill_n(attention_mask_data, init_ids.size(), 1u);
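A minimal standalone sketch (not from this patch; the helper name and i32 token ids are assumptions) of why the pad fill can go away: once the decoder's input_ids is reshaped to exactly init_ids.size() tokens, the copy fills the whole tensor and there is nothing left to pad.

#include <algorithm>
#include <cstdint>
#include <vector>
#include <openvino/openvino.hpp>

// Sketch: size input_ids to the prompt itself, so no pad_token fill is needed.
ov::Tensor make_input_ids(const std::vector<int32_t>& init_ids) {
    ov::Tensor input_ids{ov::element::i32, ov::Shape{1, init_ids.size()}};
    std::copy(init_ids.begin(), init_ids.end(), input_ids.data<int32_t>());
    return input_ids;
}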
@@ -210,13 +210,13 @@ void zero_past_key_values(ov::InferRequest& request) {
}
}

- void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder) {
+ void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder, const size_t init_ids_size) {
// NB: Prepare attention mask to be in a format [0, 0, 0, 1, 1, 1, 1, ..., 0, 1]
// Mask should be inverted for decoder_with_past
auto attention_mask = decoder_with_past.get_tensor("attention_mask");
auto* attention_mask_ptr = attention_mask.data<ov::float16>();
- std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
- std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
+ std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
+ std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
attention_mask_ptr[attention_mask.get_size() - 2] = 0;
attention_mask_ptr[attention_mask.get_size() - 1] = 1;
// NB: Zero past_key_values.*.decoder.value tensors
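As a worked example of the inverted layout the comment describes (sizes chosen for illustration only, not taken from the patch): with init_ids_size = 3 and a mask of length 8, the fills above yield [0, 0, 0, 1, 1, 1, 0, 1].

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const size_t size = 8, init_ids_size = 3;                     // illustrative values
    std::vector<float> mask(size);
    std::fill(mask.begin(), mask.begin() + init_ids_size, 0.f);   // prompt positions
    std::fill(mask.begin() + init_ids_size, mask.end() - 2, 1.f); // not-yet-used KV slots
    mask[size - 2] = 0.f;                                         // slot for the current token
    mask[size - 1] = 1.f;
    for (float v : mask) std::printf("%.0f ", v);                 // prints: 0 0 0 1 1 1 0 1
}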
@@ -318,7 +318,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
return {false, output_tokens};
}

- prepare_decoder_with_past(models.decoder_with_past, models.decoder);
+ prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());

for (size_t i = 0; i < max_new_tokens - 1; i++) {
auto output_token = decode_with_past(models.decoder_with_past,
@@ -489,7 +489,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
preprocessor.input("attention_mask").preprocess().convert_element_type();
} else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32); // ()
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type();
} else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
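For context, a self-contained sketch of the preprocessing pattern this fix restores (the helper name is an assumption): the runtime tensor is declared f16, and the argument-free convert_element_type() casts to whatever element type the model input was built with, rather than hard-coding f32.

#include <memory>
#include <string>
#include <openvino/core/preprocess/pre_post_process.hpp>
#include <openvino/openvino.hpp>

void declare_f16_input(std::shared_ptr<ov::Model>& model, const std::string& name) {
    ov::preprocess::PrePostProcessor preprocessor(model);
    preprocessor.input(name).tensor().set_element_type(ov::element::f16);
    // No argument: convert back to the element type the model input already has.
    preprocessor.input(name).preprocess().convert_element_type();
    model = preprocessor.build();
}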
@@ -563,7 +563,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);

auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
- reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
+ reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape); // for detect_language()
reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);

// Replace KV-tensors for the entire cache to tensors only for new token
@@ -577,9 +577,12 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
compiled_model = core.compile_model(encoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
m_models.encoder = compiled_model.create_infer_request();

+ m_decoder_model = decoder_model; // for reshape in generate() when we get number of input tokens
compiled_model = core.compile_model(decoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_models.decoder = compiled_model.create_infer_request();

compiled_model = core.compile_model(decoder_with_past_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
m_models.decoder_with_past = compiled_model.create_infer_request();
@@ -654,7 +657,13 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(

// prepare init_ids just once for whole input
if (init_ids.empty()) {
+ OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_shape().back() == 1);
init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);

+ // Reshape decoder model for the number of input tokens
+ ov::Core core = utils::singleton_core();
+ reshape_to_static(m_decoder_model, init_ids.size(), init_ids.size(), hidden_state_tensor.get_shape());
+ m_models.decoder = core.compile_model(m_decoder_model, "NPU").create_infer_request();
}

auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
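The generate()-time recompilation above follows a reshape-then-recompile pattern: the decoder is first compiled with sequence length 1 for detect_language(), then re-pinned to init_ids.size() once the prompt is known. A sketch of that pattern under stated assumptions (the input names and [batch, seq] layout are guesses; the repo's reshape_to_static helper may differ):

#include <cstdint>
#include <map>
#include <memory>
#include <openvino/openvino.hpp>

// Pin the token dimension to seq_len and recompile for NPU, which requires
// fully static shapes. Assumes inputs "input_ids" and "attention_mask".
ov::InferRequest recompile_static_decoder(ov::Core& core,
                                          const std::shared_ptr<ov::Model>& model,
                                          size_t seq_len,
                                          const ov::PartialShape& hidden_state_shape) {
    std::map<std::string, ov::PartialShape> shapes;
    for (const auto& input : model->inputs()) {
        const std::string name = input.get_any_name();
        if (name == "input_ids" || name == "attention_mask") {
            shapes[name] = ov::PartialShape{1, static_cast<int64_t>(seq_len)};
        } else if (name == "encoder_hidden_states") {
            shapes[name] = hidden_state_shape;
        }
    }
    model->reshape(shapes);
    return core.compile_model(model, "NPU").create_infer_request();
}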
1 change: 1 addition & 0 deletions src/cpp/src/whisper_pipeline_static.hpp
@@ -25,6 +25,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi

private:
WhisperInitializedModels m_models;
+ std::shared_ptr<ov::Model> m_decoder_model;
};

} // namespace genai
