Skip to content

Commit

Permalink
StaticWhisperPipeline: fix encoder input_features reshape (openvinoto…
Browse files: browse the repository at this point in the history
  • Loading branch information
eshiryae authored Dec 4, 2024
1 parent bc388a3 commit 0b81108
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/cpp/src/whisper_pipeline_static.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -432,15 +432,17 @@ void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_si
model->reshape(new_shapes);
}

void reshape_to_static_encoder(std::shared_ptr<ov::Model> model) {
void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t feature_size) {
std::map<std::string, ov::PartialShape> new_shapes;
for (auto input : model->inputs()) {
const auto& input_name = input.get_any_name();
ov::PartialShape new_shape;
if (input_name.find("input_features") != std::string::npos) {
const auto& partial_shape = input.get_partial_shape();
OPENVINO_ASSERT(partial_shape.size() >= 3);
new_shape = partial_shape;
new_shape[0] = 1; // batch_dim
new_shape[1] = feature_size;
}
new_shapes.emplace(input_name, new_shape);
}
Expand Down Expand Up @@ -540,7 +542,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys

size_t max_sequence_length = 448;

reshape_to_static_encoder(encoder_model);
reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);

auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
Expand Down

0 comments on commit 0b81108

Please sign in to comment.