diff --git a/src/cpp/src/logit_processor.hpp b/src/cpp/src/logit_processor.hpp index b0303e5bde..bc7ba50648 100644 --- a/src/cpp/src/logit_processor.hpp +++ b/src/cpp/src/logit_processor.hpp @@ -299,6 +299,20 @@ class PresencePenaltyTransform : public IPenaltyTransformer { } }; + +class LogFilter : public ILogitTransformer { +public: + LogFilter() {} + + void apply(Logits& logits) override { + OPENVINO_ASSERT(logits.is_vector_initialized()); + for (size_t i = 0; i < logits.m_size; i++) { + logits.m_vector[i].m_log_prob = std::log(logits.m_vector[i].m_log_prob); + } + } +}; + + } // namespace LogitTransformers class LogitProcessor { @@ -352,6 +366,7 @@ class LogitProcessor { if (sampling_params.top_k > 0 && sampling_params.top_k < std::numeric_limits::max()) { m_logit_transformers.emplace_back(new LogitTransformers::TopKFilter(sampling_params.top_k)); } + m_logit_transformers.emplace_back(new LogitTransformers::LogFilter()); } } } diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index d7a56b0bfb..cd164619f5 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -358,7 +358,15 @@ class Sampler { max_index = i; } } - return Token(logits.m_data[max_index], max_index); + + // apply log softmax to max value + float log_sum = std::log(std::accumulate( + logits.m_data, logits.m_data + logits.m_size, 0.0f, [max_value](float accumulated, float to_add) { + return accumulated + std::exp(to_add - max_value); + })); + max_value = -log_sum; + + return Token(max_value, max_index); } std::vector _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence) {