Skip to content

Commit

Permalink
Fixed logprobs for greedy. (#886)
Browse files Browse the repository at this point in the history
Applied LogSoftmax to logits returned by _greedy_sample().
Ticket: CVS-152274
  • Loading branch information
popovaan authored Sep 20, 2024
1 parent 40f962d commit de77f96
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
15 changes: 15 additions & 0 deletions src/cpp/src/logit_processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,20 @@ class PresencePenaltyTransform : public IPenaltyTransformer {
}
};


/// Logit transformer that replaces every entry's `m_log_prob` with its
/// natural logarithm, in place.
///
/// NOTE(review): presumably the values hold (normalized) probabilities by the
/// time this transform runs, so the result is a log-probability — confirm
/// against the transforms that precede it in the pipeline.
class LogFilter : public ILogitTransformer {
public:
    LogFilter() = default;

    /// Applies `std::log` to the first `logits.m_size` entries of `logits.m_vector`.
    /// Asserts that the vector representation of `logits` has been initialized.
    void apply(Logits& logits) override {
        OPENVINO_ASSERT(logits.is_vector_initialized());
        const size_t entry_count = logits.m_size;
        for (size_t idx = 0; idx < entry_count; ++idx) {
            auto& entry = logits.m_vector[idx];
            entry.m_log_prob = std::log(entry.m_log_prob);
        }
    }
};


} // namespace LogitTransformers

class LogitProcessor {
Expand Down Expand Up @@ -352,6 +366,7 @@ class LogitProcessor {
if (sampling_params.top_k > 0 && sampling_params.top_k < std::numeric_limits<size_t>::max()) {
m_logit_transformers.emplace_back(new LogitTransformers::TopKFilter(sampling_params.top_k));
}
m_logit_transformers.emplace_back(new LogitTransformers::LogFilter());
}
}
}
Expand Down
10 changes: 9 additions & 1 deletion src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,15 @@ class Sampler {
max_index = i;
}
}
return Token(logits.m_data[max_index], max_index);

// apply log softmax to max value
float log_sum = std::log(std::accumulate(
logits.m_data, logits.m_data + logits.m_size, 0.0f, [max_value](float accumulated, float to_add) {
return accumulated + std::exp(to_add - max_value);
}));
max_value = -log_sum;

return Token(max_value, max_index);
}

std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence) {
Expand Down

0 comments on commit de77f96

Please sign in to comment.