Commit

Add sampling decoding (#6)
* greedy-sampling

* greedy sampling

* greedy sampling

* Update greedy_sampling.hpp

don't transform the logits during top_p

* Update greedy_causal_lm.cpp

* format

* early exit for arg_max

* Update default parameters

* Add multinomial sampling

* Remove unused hpp

* Reuse util functions

* Apply review comments

* Merge config

* Use size_t for iteration

* apply comments

* Move rand gen to class member

* Apply comments

* Add multinomial sampling

* Remove path

* Apply comments

* Fix merge

---------

Co-authored-by: wenyi5608 <[email protected]>
as-suvorov and wenyi5608 authored May 27, 2024
1 parent bbc8c25 commit 9e37273
Showing 9 changed files with 337 additions and 11 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -42,4 +42,6 @@ CMakeUserPresets.json
 # Python-specific
 *.?env*
 *.pyc
-__pycache__
+__pycache__
+
+*.so
6 changes: 3 additions & 3 deletions src/cpp/include/openvino/genai/generation_config.hpp
@@ -81,9 +81,9 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig {
 StopCriteria stop_criteria = StopCriteria::heuristic;

 // Multinomial
-float temperature = 0.0f;
+float temperature = 1.0f;
 float top_p = 1.0f;
-int top_k = -1;
+size_t top_k = 50;
 bool do_sample = false;
 float repetition_penalty = 1.0f;

@@ -99,7 +99,7 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig {
 size_t get_max_new_tokens(size_t prompt_length = 0) const;
 bool is_greedy_decoding() const;
 bool is_beam_search() const;
-bool is_multimomial() const;
+bool is_multinomial() const;
 static GenerationConfig anymap_to_generation_config(const ov::AnyMap& config_map = {});
 };

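In this header, the sampling defaults change: `temperature` moves from 0.0 to 1.0 (the sampler divides logits by it, so 0 would be rejected by the parameter validation added in multinomial_decoding.cpp further down), and `top_k` becomes an unsigned 50 instead of the -1 sentinel. As a quick refresher on what temperature does to the sampling distribution, here is a small self-contained sketch (illustrative only, not code from this commit):

// Illustrative only - not part of this commit.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Softmax over logits scaled by 1/T: p_i = exp(l_i / T) / sum_j exp(l_j / T).
std::vector<float> softmax_with_temperature(std::vector<float> logits, float temperature) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float& l : logits) {
        l = std::exp((l - max_logit) / temperature);  // subtract the max for numerical stability
        sum += l;
    }
    for (float& l : logits) {
        l /= sum;
    }
    return logits;
}

int main() {
    const std::vector<float> logits{2.0f, 1.0f, 0.1f};
    for (float t : {0.5f, 1.0f, 2.0f}) {
        const auto p = softmax_with_temperature(logits, t);
        // T < 1 sharpens the distribution, T > 1 flattens it, T == 1 (the new default) leaves it unscaled.
        std::printf("T=%.1f -> %.3f %.3f %.3f\n", t, p[0], p[1], p[2]);
    }
    return 0;
}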
2 changes: 1 addition & 1 deletion src/cpp/src/generation_config.cpp
@@ -100,7 +100,7 @@ bool GenerationConfig::is_beam_search() const {
 return num_beams > 1;
 }

-bool GenerationConfig::is_multimomial() const {
+bool GenerationConfig::is_multinomial() const {
 return do_sample;
 }

Empty file.
19 changes: 14 additions & 5 deletions src/cpp/src/llm_pipeline.cpp
@@ -70,6 +70,14 @@ ov::genai::EncodedResults greedy_decoding(
 const bool is_chat_conversation = false
 );

+ov::genai::EncodedResults multinominal_decoding(
+ov::InferRequest& model_runner,
+ov::Tensor prompts,
+ov::Tensor attentin_mask,
+GenerationConfig sampling_params,
+std::shared_ptr<StreamerBase> streamer
+);
+
 EncodedResults beam_search(ov::InferRequest& lm, ov::Tensor prompts, ov::Tensor attentin_mask, GenerationConfig config);


@@ -252,8 +260,8 @@ ov::genai::EncodedResults ov::genai::LLMPipeline::LLMPipelineImpl::generate(
 streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
 }
 auto batch_size = input_ids.get_shape().at(0);
-if ((batch_size != 1 || !config.is_greedy_decoding()) && streamer_ptr) {
-OPENVINO_THROW("Currently streaming is possible only with batch size=1 and greedy decoding");
+if ((batch_size != 1 || !(config.is_greedy_decoding() || config.is_multinomial())) && streamer_ptr) {
+OPENVINO_THROW("Currently streaming is possible only with batch size=1 and greedy or multinomial decoding");
 }

 auto attention_mask_data = attention_mask.has_value() ? *attention_mask : ov::genai::utils::init_attention_mask(input_ids);
@@ -262,10 +270,11 @@ ov::genai::EncodedResults ov::genai::LLMPipeline::LLMPipelineImpl::generate(
 result = ov::genai::greedy_decoding(m_model_runner, input_ids, attention_mask_data, config, streamer_ptr, is_chat_conversation);
 } else if (config.is_beam_search()) {
 result = beam_search(m_model_runner, input_ids, attention_mask_data, config);
+} else if (config.is_multinomial()) {
+result = multinominal_decoding(m_model_runner, input_ids, attention_mask_data, config, streamer_ptr);
 } else {
-// todo: implement multinomial sampling
-// result = multinomial_sampling(input_ids, config);
-}
+OPENVINO_THROW("No decoding algorithm found for provided configuration parameters.");
+}

if (!is_chat_conversation)
m_model_runner.reset_state();
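The relaxed guard above means a streamer can now be attached to sampled generation as well as greedy decoding. A hedged sketch of a custom streamer: the `put(int64_t)`/`end()` calls mirror how `multinominal_decoding` uses `StreamerBase` below, but the exact base-class signatures and header path are assumptions, not taken from this diff:

// Hypothetical sketch - signatures inferred from the calls in multinomial_decoding.cpp,
// not from the actual StreamerBase declaration.
#include <iostream>

#include "openvino/genai/streamer_base.hpp"  // header path is an assumption

class StdoutTokenStreamer : public ov::genai::StreamerBase {
public:
    void put(int64_t token_id) override {
        // Called once per sampled token; a real streamer would detokenize before printing.
        std::cout << token_id << ' ' << std::flush;
    }

    void end() override {
        std::cout << std::endl;  // called once when generation finishes
    }
};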
262 changes: 262 additions & 0 deletions src/cpp/src/multinomial_decoding.cpp
@@ -0,0 +1,262 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <random>
#include <regex>
#include <vector>

#include "generation_config_helper.hpp"
#include "openvino/genai/llm_pipeline.hpp"
#include "utils.hpp"


namespace {

struct TokenIdScore {
int64_t id;
float score;

bool operator<(const TokenIdScore& other) const {
return score < other.score;
}

bool operator>(const TokenIdScore& other) const {
return score > other.score;
}
};

void apply_softmax_inplace(std::vector<TokenIdScore>& tokens) {
float max_score = std::max_element(tokens.begin(), tokens.end())->score;
float sum = 0.f;

for (auto& token : tokens) {
float s = std::exp(token.score - max_score);
token.score = s;
sum += s;
}

float inv_sum = 1.f / sum;

for (auto& token : tokens) {
token.score *= inv_sum;
}
}

TokenIdScore* sample_top_p(TokenIdScore* first, TokenIdScore* last, float top_p) {
// sort score
std::sort(first, last, std::greater<TokenIdScore>());

const size_t tokens_size = last - first;
std::vector<TokenIdScore> token_scores(tokens_size);
for (size_t i = 0; i < tokens_size; i++) {
token_scores[i] = first[i];
}

// calculate softmax
apply_softmax_inplace(token_scores);

float prefix_sum = 0.0f;

// top_p
for (size_t i = 0; i < tokens_size; i++) {
prefix_sum += token_scores[i].score;
if (prefix_sum >= top_p) {
return first + (i + 1);
}
}

return last;
}
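// [Illustration added for this write-up, not part of the committed file]
// Worked example of the cut above: with sorted softmax scores 0.5, 0.3, 0.1, 0.06, 0.04
// and top_p = 0.85, the prefix sums are 0.5, 0.8, 0.9, ... - the threshold is reached at the
// third token (0.9 >= 0.85), so only the first three candidates survive nucleus filtering.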

void apply_repetition_penalty(float* first, float* last, const std::vector<int64_t>& input_ids, float penalty) {
const float inv_penalty = 1.f / penalty;
const int vocab_size = last - first;
std::vector<bool> occurrence(vocab_size, false);
for (const int64_t id : input_ids) {
if (!occurrence[id]) {
first[id] *= (first[id] > 0) ? inv_penalty : penalty;
}
occurrence[id] = true;
}
}
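// [Illustration added for this write-up, not part of the committed file]
// Worked example: with penalty = 1.5, a previously seen token with logit 3.0 is damped to
// 3.0 / 1.5 = 2.0, while one with logit -1.0 is pushed further down to -1.0 * 1.5 = -1.5.
// Each token id is penalized at most once per step, however often it appeared in the history.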

void apply_inv_temperature(float* first, float* last, float inv_temperature) {
for (float* it = first; it != last; it++) {
*it *= inv_temperature;
}
}

struct RandomSampling {
const size_t top_k;
const float top_p;
const float inv_temperature;
const float repetition_penalty;

std::mt19937 gen{std::random_device{}()};

RandomSampling(ov::genai::GenerationConfig generation_config)
: top_k{generation_config.top_k},
top_p{generation_config.top_p},
inv_temperature{1.f / generation_config.temperature},
repetition_penalty{generation_config.repetition_penalty} {
// parameters validation
OPENVINO_ASSERT(generation_config.top_k > 0,
"top_k must be strictly positive, but got ",
generation_config.top_k);
OPENVINO_ASSERT(generation_config.top_p > 0 && generation_config.top_p <= 1.0f,
"top_p must be within the (0, 1] range, but got ",
generation_config.top_p);
OPENVINO_ASSERT(generation_config.temperature > 0,
"Temperature must be a strictly positive float, but got ",
generation_config.temperature);
OPENVINO_ASSERT(generation_config.repetition_penalty > 0,
"Repetition penalty must be a strictly positive float, but got ",
generation_config.repetition_penalty);
}

TokenIdScore get_out_token(float* logits, size_t vocab_size, const std::vector<int64_t>& tokens) {
// logits pre-process
if (repetition_penalty != 1.0f) {
apply_repetition_penalty(logits, logits + vocab_size, tokens, repetition_penalty);
}

if (inv_temperature != 1.0f) {
apply_inv_temperature(logits, logits + vocab_size, inv_temperature);
}

std::vector<TokenIdScore> token_scores(vocab_size);
for (size_t i = 0; i < vocab_size; i++) {
token_scores[i] = TokenIdScore{int64_t(i), logits[i]};
}

// top_k sampling
if (0 < top_k && top_k < token_scores.size()) {
std::nth_element(token_scores.data(),
token_scores.data() + top_k,
token_scores.data() + token_scores.size(),
std::greater<TokenIdScore>());
token_scores.resize(top_k);
}

// top_p sampling
if (0.f < top_p && top_p < 1.0f) {
auto pos = sample_top_p(token_scores.data(), token_scores.data() + token_scores.size(), top_p);
token_scores.resize(pos - token_scores.data());
}

// sample next token
apply_softmax_inplace(token_scores);
for (size_t i = 0; i < token_scores.size(); i++) {
logits[i] = token_scores[i].score;
}

std::discrete_distribution<> dist(logits, logits + token_scores.size());
return token_scores[dist(gen)];
}
};
} // namespace

namespace ov {
namespace genai {

ov::genai::EncodedResults multinominal_decoding(ov::InferRequest& m_model_runner,
ov::Tensor input_ids,
ov::Tensor attention_mask,
ov::genai::GenerationConfig config,
std::shared_ptr<ov::genai::StreamerBase> streamer) {
ov::Shape prompts_shape = input_ids.get_shape();
size_t batch_size = prompts_shape[0];

OPENVINO_ASSERT(batch_size == 1, "Only batch size = 1 supported for multinomial decoding");

size_t prompt_len = prompts_shape[1];

ov::genai::EncodedResults results;
results.scores.resize(batch_size, 0);
results.tokens.resize(batch_size);

// Initialize inputs
m_model_runner.set_tensor("input_ids", input_ids);
m_model_runner.set_tensor("attention_mask", attention_mask);

ov::Tensor position_ids = m_model_runner.get_tensor("position_ids");
position_ids.set_shape(input_ids.get_shape());
std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), 0);

// Input values persist between inference calls,
// so values that never change only need to be set once.
m_model_runner.get_tensor("beam_idx").set_shape({batch_size});
m_model_runner.get_tensor("beam_idx").data<int32_t>()[0] = 0;

m_model_runner.infer();

auto logits_tensor = m_model_runner.get_tensor("logits");

int64_t sequence_offset = logits_tensor.get_shape().at(1) - 1;
size_t vocab_size = logits_tensor.get_shape().back();

float* logits = logits_tensor.data<float>() + sequence_offset * vocab_size;

const int64_t* input_ids_data = input_ids.data<const int64_t>();

std::vector<int64_t> tokens{input_ids_data, input_ids_data + input_ids.get_size()};

RandomSampling sampling{config};

TokenIdScore out_token = sampling.get_out_token(logits, vocab_size, tokens);

tokens.push_back(out_token.id);
results.tokens[0].push_back(out_token.id);
results.scores[0] += out_token.score;

if (streamer) {
streamer->put(out_token.id);
}

if (!config.ignore_eos && out_token.id == config.eos_token_id) {
return results;
}

m_model_runner.get_tensor("input_ids").set_shape({batch_size, 1});
m_model_runner.get_tensor("position_ids").set_shape({batch_size, 1});

size_t max_new_tokens = config.get_max_new_tokens(prompt_len);

for (size_t i = 0; i < max_new_tokens - 1; i++) {
ov::genai::utils::update_position_ids(m_model_runner.get_tensor("position_ids"),
m_model_runner.get_tensor("attention_mask"));
m_model_runner.set_tensor("attention_mask",
ov::genai::utils::extend_attention(m_model_runner.get_tensor("attention_mask")));

m_model_runner.get_tensor("input_ids").data<int64_t>()[0] = out_token.id;

m_model_runner.infer();

logits = m_model_runner.get_tensor("logits").data<float>();
out_token = sampling.get_out_token(logits, vocab_size, tokens);

tokens.push_back(out_token.id);
results.tokens[0].push_back(out_token.id);
results.scores[0] += out_token.score;

if (streamer) {
streamer->put(out_token.id);
}

if (!config.ignore_eos && out_token.id == config.eos_token_id) {
break;
}
}

if (streamer) {
streamer->end();
}

return results;
}
} // namespace genai
} // namespace ov
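To see the chain above end to end — temperature scaling, the `std::nth_element` top-k cut, the nucleus (top-p) cut over a softmax prefix sum, and the final draw from `std::discrete_distribution` — here is a small self-contained sketch on a toy logits vector. It is illustrative only (repetition penalty is omitted) and is not code from this commit:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

struct TokenScore {
    int64_t id;
    float score;
};

// Max-subtracted softmax in place, mirroring apply_softmax_inplace above.
void softmax(std::vector<TokenScore>& tokens) {
    const float max_score =
        std::max_element(tokens.begin(), tokens.end(), [](const TokenScore& a, const TokenScore& b) {
            return a.score < b.score;
        })->score;
    float sum = 0.0f;
    for (auto& t : tokens) {
        t.score = std::exp(t.score - max_score);
        sum += t.score;
    }
    for (auto& t : tokens) {
        t.score /= sum;
    }
}

int64_t sample(const std::vector<float>& logits, float temperature, size_t top_k, float top_p, std::mt19937& gen) {
    std::vector<TokenScore> scores;
    for (size_t i = 0; i < logits.size(); ++i) {
        scores.push_back({int64_t(i), logits[i] / temperature});  // temperature scaling
    }

    // top-k: keep only the k highest-scoring candidates (a partial sort is enough).
    if (top_k > 0 && top_k < scores.size()) {
        std::nth_element(scores.begin(), scores.begin() + top_k, scores.end(),
                         [](const TokenScore& a, const TokenScore& b) { return a.score > b.score; });
        scores.resize(top_k);
    }

    // top-p: keep the smallest descending-probability prefix whose mass reaches top_p.
    if (top_p > 0.0f && top_p < 1.0f) {
        std::sort(scores.begin(), scores.end(),
                  [](const TokenScore& a, const TokenScore& b) { return a.score > b.score; });
        std::vector<TokenScore> probs = scores;
        softmax(probs);
        float prefix = 0.0f;
        size_t keep = probs.size();
        for (size_t i = 0; i < probs.size(); ++i) {
            prefix += probs[i].score;
            if (prefix >= top_p) {
                keep = i + 1;
                break;
            }
        }
        scores.resize(keep);
    }

    // Draw the next token id proportionally to the softmax of the surviving scores.
    softmax(scores);
    std::vector<float> weights;
    for (const auto& t : scores) {
        weights.push_back(t.score);
    }
    std::discrete_distribution<size_t> dist(weights.begin(), weights.end());
    return scores[dist(gen)].id;
}

int main() {
    std::mt19937 gen{42};
    const std::vector<float> logits{4.0f, 3.5f, 1.0f, 0.2f, -2.0f};
    for (int i = 0; i < 5; ++i) {
        std::cout << sample(logits, /*temperature=*/0.8f, /*top_k=*/3, /*top_p=*/0.9f, gen) << ' ';
    }
    std::cout << '\n';  // mostly 0, sometimes 1: ids 2-4 are removed by the top-k/top-p cuts
    return 0;
}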
10 changes: 10 additions & 0 deletions src/cpp/src/utils.cpp
@@ -109,6 +109,11 @@ void set_attention_mask(ov::Tensor&& attention_mask, std::vector<int32_t> next_b
 }
 }

+/**
+ * Set position ids tensor data for next token inference based on provided attention mask
+ * Supports multi batch
+ * Supports sparse attention_mask
+ */
 void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
 const size_t batch_size = attention_mask.get_shape().at(0);
 const size_t atten_length = attention_mask.get_shape().at(1);
@@ -121,6 +126,11 @@ void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention
 }
 }

+/**
+ * Get attention mask tensor for next token inference
+ * Supports multi batch
+ * Supports sparse attention_mask
+ */
 ov::Tensor extend_attention(ov::Tensor attention_mask) {
 auto shape = attention_mask.get_shape();
 auto batch_size = shape[0];
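The new comments document helpers whose bodies fall mostly outside these hunks. Below is a toy, single-sequence illustration of the behavior they describe, as used by the generation loop in multinomial_decoding.cpp — this is an interpretation of the comments, not the actual utils.cpp implementation:

// Illustrative only - the real helpers operate on ov::Tensor and support multiple batches.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Next position id = number of attended (non-zero) mask entries so far.
int64_t next_position_id(const std::vector<int64_t>& attention_mask) {
    return std::accumulate(attention_mask.begin(), attention_mask.end(), int64_t{0});
}

// Grow the mask by one attended slot for the token generated at this step.
std::vector<int64_t> extend_attention(std::vector<int64_t> attention_mask) {
    attention_mask.push_back(1);
    return attention_mask;
}

int main() {
    std::vector<int64_t> mask{1, 1, 1, 1};  // 4-token prompt, fully attended
    for (int step = 0; step < 3; ++step) {
        std::cout << "step " << step << ": position_id=" << next_position_id(mask) << '\n';  // 4, 5, 6
        mask = extend_attention(mask);
    }
    return 0;
}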
8 changes: 7 additions & 1 deletion text_generation/causal_lm/cpp/CMakeLists.txt
@@ -52,7 +52,13 @@ target_include_directories(chat_sample PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}")
 set_target_properties(chat_sample PROPERTIES CXX_STANDARD 17)
 set_target_properties(chat_sample PROPERTIES CXX_STANDARD_REQUIRED ON)

-install(TARGETS greedy_causal_lm beam_search_causal_lm speculative_decoding_lm prompt_lookup_decoding_lm chat_sample
+add_executable(multinomial_causal_lm multinomial_causal_lm.cpp)
+target_link_libraries(multinomial_causal_lm PRIVATE openvino::genai)
+target_include_directories(multinomial_causal_lm PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}")
+set_target_properties(multinomial_causal_lm PROPERTIES CXX_STANDARD 17)
+set_target_properties(multinomial_causal_lm PROPERTIES CXX_STANDARD_REQUIRED ON)
+
+install(TARGETS greedy_causal_lm beam_search_causal_lm speculative_decoding_lm prompt_lookup_decoding_lm chat_sample multinomial_causal_lm
 RUNTIME DESTINATION samples_bin/
 COMPONENT samples_bin
 EXCLUDE_FROM_ALL)
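The new target references a `multinomial_causal_lm.cpp` whose diff is not shown on this page. A plausible outline, modeled on the repository's other causal-LM samples — the command-line handling, device choice, and exact generate() overload are assumptions, not the committed code:

// Hypothetical outline of multinomial_causal_lm.cpp - the committed sample may differ.
#include <iostream>
#include <stdexcept>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main(int argc, char* argv[]) try {
    if (argc != 3) {
        throw std::runtime_error(std::string("Usage: ") + argv[0] + " <MODEL_DIR> \"<PROMPT>\"");
    }

    ov::genai::LLMPipeline pipe(argv[1], "CPU");

    ov::genai::GenerationConfig config;
    config.do_sample = true;       // route generation through the new multinomial path
    config.temperature = 0.8f;
    config.top_k = 50;
    config.top_p = 0.9f;
    config.max_new_tokens = 100;

    std::cout << pipe.generate(argv[2], config) << std::endl;
    return 0;
} catch (const std::exception& error) {
    std::cerr << error.what() << std::endl;
    return 1;
}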