Whisper: Fix decoder inputs for static pipeline #1469
base: master
Conversation
```diff
@@ -25,6 +25,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi

 private:
     WhisperInitializedModels m_models;
+    std::shared_ptr<ov::Model> m_decoder_model;
```
Why do we need to store the model? Once the model is compiled, we should release the `ov::Model` to free the memory consumed by its weights.
We can't compile this model until `generate()` is called.
Well, I'd assume we'd have something like this:

```cpp
class DecoderCache {
public:
    ov::CompiledModel get_model(uint8_t input_ids_size) {
        // Get from the hash table; otherwise compile and store...
    }
private:
    // [input_ids_size -> CompiledModel]
    std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
    std::shared_ptr<ov::Model> m_decoder_model; // <- this is dynamic, w/o transformations applied
};

// Whenever we need a model:
auto decoder = m_decoder_cache.get_model(input_ids_size);
```
```diff
     std::fill(input_ids_data + init_ids.size(),
               input_ids_data + input_ids_tensor.get_size(),
               static_cast<int32_t>(pad_token));
+    // std::fill(input_ids_data + init_ids.size(),
```
Remove?
```diff
-    std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
-    std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
+    std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
+    std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
```
Why `- 2`?
Btw, since we reshape the model anyway, we don't need `attention_mask` at all; we could probably skip the transformation that exposes it.
```diff
@@ -489,7 +489,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
         preprocessor.input("attention_mask").preprocess().convert_element_type();
     } else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
         preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
-        preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);
```
Why was this removed?
```diff
@@ -654,7 +657,13 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(

     // prepare init_ids just once for whole input
     if (init_ids.empty()) {
+        OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_shape().back() == 1);
```
What does this check do?
Tickets: