Skip to content

Commit

Permalink
Fix wrong logits processing without applying of slice matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Nov 15, 2024
1 parent 96bcffe commit 44650f4
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/cpp/src/lm_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(

auto logits = m_llm.get_tensor("logits");

// if slice matmul is not applied logits will contains not only result tokens
size_t vocab_size = logits.get_shape().back();
if (!m_embedding.has_value()) {
ov::Tensor new_logits = ov::Tensor(logits.get_element_type(), {batch_size, 1, vocab_size});
size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size;

for (size_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
size_t batch_offset = batch_idx * logits.get_shape().at(1) * vocab_size;
const float* logits_data = logits.data<const float>() + batch_offset + sequence_offset;
std::copy(logits_data, logits_data + vocab_size, new_logits.data<float>() + batch_idx * vocab_size);
}
logits = new_logits;
}

int64_t sequence_len = logits.get_shape().at(1);
for (auto& sequence_group : sequence_groups)
sequence_group->schedule_tokens(sequence_len);
Expand Down

0 comments on commit 44650f4

Please sign in to comment.