Whisper: Fix decoder inputs for static pipeline #1469

Open · wants to merge 1 commit into base: master

Conversation

@eshiryae (Contributor) commented Jan 3, 2025

Tickets:

@github-actions bot added the category: whisper (Whisper pipeline) label Jan 3, 2025
@@ -25,6 +25,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi

 private:
     WhisperInitializedModels m_models;
+    std::shared_ptr<ov::Model> m_decoder_model;
Contributor commented:
Why do we need to store the model? Once the model is compiled, we need to release the ov::Model to free the memory consumed by its weights.

@TolyaTalamanov (Collaborator) replied Jan 3, 2025:
We can't compile this model until generate() is called.
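For context, a minimal sketch of the compile-then-release pattern the reviewer alludes to (core, decoder_model, the "NPU" device string and the function name are assumptions for illustration, not from the patch):

#include <openvino/openvino.hpp>

// Hypothetical sketch: compile on first use inside generate(), then drop the
// ov::Model so the memory held by its weights can be reclaimed.
ov::CompiledModel compile_and_release(ov::Core& core,
                                      std::shared_ptr<ov::Model>& decoder_model) {
    ov::CompiledModel compiled = core.compile_model(decoder_model, "NPU");
    decoder_model.reset();  // weights freed; only the compiled model remains
    return compiled;
}

The tension in this thread is that generate() may need several static shapes, so releasing the model after the first compilation is not enough, which motivates the cache proposal below.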

@TolyaTalamanov (Collaborator) left a comment:
Well, I'd assume we have something like this:

class DecoderCache {
public:
    ov::CompiledModel get_model(uint8_t input_ids_size) {
        // Get from hash table, otherwise compile and store...
        if (m_cache.count(input_ids_size) == 0) {
            // reshape m_decoder_model to this input_ids size, compile,
            // and insert the result into m_cache
        }
        return m_cache.at(input_ids_size);
    }
private:
    // [input_ids_size -> CompiledModel]
    std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
    std::shared_ptr<ov::Model> m_decoder_model; // <- this is dynamic, w/o transformations applied
};


// Whenever we need a model:
auto decoder = m_decoder_cache.get_model(input_ids_size);
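A design note on the sketch above (an observation, not stated in the thread): keying the cache on the input_ids size means the dynamic ov::Model must stay alive for possible later compilations, which is the memory trade-off raised in the earlier comment.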

+    std::fill(input_ids_data + init_ids.size(),
+              input_ids_data + input_ids_tensor.get_size(),
+              static_cast<int32_t>(pad_token));
+    // std::fill(input_ids_data + init_ids.size(),
Collaborator commented:
Remove?

-    std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
-    std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
+    std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
+    std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
Collaborator commented:
Why -2?

Collaborator commented:
Btw, since we reshape the model anyway, we don't need attention_mask at all; we could probably skip applying the transformation that exposes it.
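For illustration, a minimal sketch of the reshape in question (the function name and device string are assumptions, not from the patch): once the decoder is reshaped to a fixed input_ids length, the pipeline knows at compile time which positions are padding, so a separate attention_mask input carries no extra information.

#include <openvino/openvino.hpp>

// Hypothetical sketch: pin the decoder's input_ids length before compiling.
// With a static sequence length, padding positions are known up front.
ov::CompiledModel compile_static_decoder(ov::Core& core,
                                         const std::shared_ptr<ov::Model>& decoder_model,
                                         int64_t input_ids_size) {
    decoder_model->reshape({{"input_ids", ov::PartialShape{1, input_ids_size}}});
    return core.compile_model(decoder_model, "NPU");
}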

@@ -489,7 +489,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
     preprocessor.input("attention_mask").preprocess().convert_element_type();
 } else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
     preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
-    preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32); // ()
Collaborator commented:
Why is it removed?
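For context, the code around this hunk uses OpenVINO's ov::preprocess::PrePostProcessor; a minimal sketch of the f16/f32 pattern on the line in question (the free-standing function is illustrative):

#include <openvino/core/preprocess/pre_post_process.hpp>
#include <openvino/openvino.hpp>

// Sketch: declare that callers feed encoder_hidden_states as f16 tensors and
// insert an explicit convert step so the graph still computes in f32.
void set_hidden_states_f16(std::shared_ptr<ov::Model>& model) {
    ov::preprocess::PrePostProcessor preprocessor(model);
    preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::f16);
    preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::f32);
    model = preprocessor.build();
}

The removed line is that explicit f16-to-f32 convert step; the diff keeps only the tensor-type declaration.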

@@ -654,7 +657,13 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(

     // prepare init_ids just once for whole input
     if (init_ids.empty()) {
+        OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_shape().back() == 1);
Collaborator commented:
What does this check do?
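For readers following along, a decomposed reading of what the assertion inspects (an interpretation, not the author's answer): it fetches the shape of the compiled decoder's input_ids tensor and requires the last dimension, the sequence length, to be 1.

// Equivalent standalone form, assuming m_models.decoder is an ov::InferRequest:
// the decoder must currently expect single-token (sequence length 1) inputs.
ov::Tensor input_ids = m_models.decoder.get_tensor("input_ids");
ov::Shape shape = input_ids.get_shape();
OPENVINO_ASSERT(shape.back() == 1);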

@dmatveev changed the title Fix decoder inputs for static pipeline → Whisper: Fix decoder inputs for static pipeline Jan 3, 2025
@dmatveev added this to the 2025.0 milestone Jan 3, 2025