Skip to content

Commit

Permalink
gtest adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
mzegla committed Jul 23, 2024
1 parent 8aaf164 commit f4857e1
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 97 deletions.
13 changes: 8 additions & 5 deletions src/cpp/src/logit_processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct Token {
struct Logits {
float * m_data = nullptr;
size_t m_size;
// Late initialized
// Late initialized for top_p or top_k transforms
std::vector<Token> m_vector;

Logits(float* data, size_t size): m_data(data), m_size(size) {}
Expand All @@ -29,8 +29,11 @@ struct Logits {
OPENVINO_ASSERT(m_vector.size() == 0, "Logits vector already initialized");
m_vector.reserve(m_size);
for (size_t i = 0; i < m_size; i++)
m_vector.emplace_back(m_data[i], i);

m_vector.emplace_back(m_data[i], i);
}

bool vector_initialized() const {
return m_vector.size() > 0;
}

void resize(size_t new_size) {
Expand All @@ -56,7 +59,7 @@ class TopPFilter : public ILogitTransformer {
TopPFilter(double top_p) : m_top_p(top_p) {}

void apply(Logits& logits) override {
if (logits.m_vector.size() == 0) {
if (!logits.vector_initialized()) {
// Initialize and sort vector
logits.initialize_vector();
std::sort(logits.m_vector.begin(), logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
Expand Down Expand Up @@ -84,7 +87,7 @@ class TopKFilter : public ILogitTransformer {
if (m_top_k >= logits.m_size)
return;

if (logits.m_vector.size() == 0) {
if (!logits.vector_initialized()) {
// Initialize and partially sort vector
logits.initialize_vector();
std::partial_sort(logits.m_vector.begin(), logits.m_vector.begin() + m_top_k, logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
Expand Down
4 changes: 2 additions & 2 deletions src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class Sampler {
// If top_p or top_k was applied we use sorted vector, if not we go with original buffer.
std::vector<float> multinomial_weights;
multinomial_weights.reserve(logits.m_size);
if (logits.m_vector.size() > 0)
if (logits.vector_initialized())
for (auto& logit: logits.m_vector) multinomial_weights.emplace_back(logit.m_log_prob);
else
multinomial_weights.assign(logits.m_data, logits.m_data + logits.m_size);
Expand All @@ -243,7 +243,7 @@ class Sampler {
std::vector<Token> out_tokens;
for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) {
size_t element_to_pick = dist(rng_engine);
if (logits.m_vector.size() > 0)
if (logits.vector_initialized())
out_tokens.push_back(logits.m_vector[element_to_pick]);
else
out_tokens.emplace_back(logits.m_data[element_to_pick], element_to_pick);
Expand Down
Loading

0 comments on commit f4857e1

Please sign in to comment.