diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index 644aa369c6..cb27ca60fb 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -105,6 +105,20 @@ std::pair get_lm_encoded_results( auto logits = m_llm.get_tensor("logits"); + // if slice matmul is not applied, logits will contain not only the result tokens, so keep only the last sequence position of each batch item + size_t vocab_size = logits.get_shape().back(); + if (!m_embedding.has_value()) { + ov::Tensor new_logits = ov::Tensor(logits.get_element_type(), {batch_size, 1, vocab_size}); + size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size; + + for (size_t batch_idx = 0; batch_idx < batch_size; batch_idx++) { + size_t batch_offset = batch_idx * logits.get_shape().at(1) * vocab_size; + const float* logits_data = logits.data() + batch_offset + sequence_offset; + std::copy(logits_data, logits_data + vocab_size, new_logits.data() + batch_idx * vocab_size); + } + logits = new_logits; + } + int64_t sequence_len = logits.get_shape().at(1); for (auto& sequence_group : sequence_groups) sequence_group->schedule_tokens(sequence_len);