Skip to content

Commit

Permalink
remove extra
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Sep 6, 2024
1 parent c39ff0a commit 6acf230
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ int main(int argc, char* argv[]) try {
// create dataset

std::vector<std::string> prompt_examples = {
"The United Arab Emirates[c] (UAE), or simply the Emirates,[d] is a country in West Asia, in the Middle East, at the eastern end of the Arabian Peninsula. It is a federal, elective monarchy composed of seven emirates, with Abu Dhabi as its capital.[13] It shares land borders with Oman to the east and northwest, and with Saudi Arabia to the southwest; as well as maritime borders in the Persian Gulf with Qatar and Iran, and with Oman in the Gulf of Oman.",
"What is OpenVINO?",
"How are you?",
"What is your name?",
Expand Down
66 changes: 0 additions & 66 deletions samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,6 @@ class AssistedCandidateGenerator {
size_t draft_model_seq_length = 0;

public:
size_t max_match = 0, avg_match = 0;
int64_t m_speculative_model_duration = 0;
std::vector<size_t> m_matches_info;

AssistedCandidateGenerator(ov::InferRequest draft_model,
const size_t max_seq_length,
const size_t num_pred_tokens,
Expand Down Expand Up @@ -178,10 +174,7 @@ class AssistedCandidateGenerator {

draft_model.get_tensor("beam_idx").set_shape({BATCH_SIZE});
draft_model.get_tensor("beam_idx").data<int32_t>()[0] = 0;
auto start_time = std::chrono::system_clock::now();
draft_model.infer();
auto end_time = std::chrono::system_clock::now();
m_speculative_model_duration += std::chrono::duration_cast<std::chrono::nanoseconds>(end_time - start_time).count();

auto logits = draft_model.get_tensor("logits");
size_t vocab_size = logits.get_shape().back();
Expand Down Expand Up @@ -226,49 +219,8 @@ class AssistedCandidateGenerator {
} else {
num_pred_tokens = std::max(int64_t(num_pred_tokens) - 1, int64_t(1));
}
// std::cout << "num_matches: " << num_matches << " m_candidates_num: " << num_pred_tokens << std::endl;
max_match = std::max(num_matches, max_match);
avg_match += num_matches;
}


// inline size_t get_median(std::vector<size_t> values) {
// const auto size = values.size();
// if (size == 0) {
// return 0;
// }
// size_t offset = values.size() / 2;

// auto it = values.begin() + offset;
// std::nth_element(values.begin(), it, values.end());

// if (size % 2 != 0) {
// return *it;
// }
// auto it_1 = values.begin() + offset - 1;
// std::nth_element(values.begin(), it_1, values.end());
// return (*it + *it_1) / 2;
// }


// void update_candidate_strategy(size_t num_matches) {
// // std::cout << "num_matches: " << num_matches << " m_candidates_num: " << m_candidates_num << std::endl;
// if (max_pred_tokens == 0) {
// return;
// }

// if (max_match < num_matches) {
// max_match = num_matches;
// }
// if (num_matches == num_pred_tokens) {
// num_pred_tokens = std::min(std::max(num_pred_tokens + 1, max_match), max_pred_tokens);
// } else {
// num_pred_tokens = num_matches > 0 ? num_matches : std::max(get_median(m_matches_info), size_t(1));
// }
// m_matches_info.push_back(num_matches);
// avg_match += num_matches;
// }

void update_kv_cache(const size_t seq_length) {
// this is the case when main model accepted all candidates from draft model
// we need to collect kv cache for out_of_kv_cache_token by inferring it
Expand Down Expand Up @@ -383,10 +335,6 @@ int main(int argc, char* argv[]) try {
while (out_token != EOS_TOKEN && seq_len < max_sequence_length) {
// generate candidates from the draft model
std::vector<int64_t> candidates = candidateGenerator.generate_candidates(out_token);
// std::cout << "ASSISTING MODEL: " << std::endl;
// for (size_t i = 0; i < candidates.size(); ++i) {
// std::cout << "N max_sampled_tokens: " << candidates[i] << std::endl;
// }
size_t candidates_size = candidates.size();

// For the main network, candidates_size + 1 tokens will be fed at once in a single infer request.
Expand Down Expand Up @@ -415,13 +363,11 @@ int main(int argc, char* argv[]) try {
// 2.2 if it's a mismatch, stop iteration but still accept the current token as it was the last token generated by
// model from a valid sequence.
size_t accepted_tokens_number = 0;
// std::cout << "MODEL: " << std::endl;
for (size_t i = 0; i < candidates_size + 1; i++) {
auto start = data_logits + vocab_size * i;
auto stop = data_logits + vocab_size * (i + 1);
out_token = std::max_element(start, stop) - start;

// std::cout << "N max_sampled_tokens: " << out_token << std::endl;
if (out_token == EOS_TOKEN) {
break;
}
Expand All @@ -441,7 +387,6 @@ int main(int argc, char* argv[]) try {
if (accepted_tokens_number > 0) {
candidateGenerator.update_candidate_strategy(accepted_tokens_number - 1);
}
// std::cout << "=========================" << std::endl;

candidateGenerator.update_kv_cache(seq_len);
update_kv_cache(main_model, main_model_seq_len_axis, seq_len);
Expand All @@ -456,17 +401,6 @@ int main(int argc, char* argv[]) try {
// it is called for education purposes:
draft_model.reset_state();
main_model.reset_state();

auto end_time = std::chrono::system_clock::now();
std::chrono::duration<double> duration = end_time - start_time;
std::cout << std::endl;
std::cout << "Duration: " << duration.count() << std::endl;
std::cout << "Infer number: " << iteration_cnt << std::endl;
std::cout << "MAX matches number: " << candidateGenerator.max_match << std::endl;
// auto a = std::accumulate(pipe.m_matches_info.begin(), pipe.m_matches_info.end(), 0);
std::cout << "AVG matches number: " << float(candidateGenerator.avg_match) / iteration_cnt << std::endl;
double c = double(candidateGenerator.m_speculative_model_duration) * 100 / std::chrono::duration_cast<std::chrono::nanoseconds>(duration).count();
std::cout << "Speculative model time duration in %: " << c << std::endl;
} catch (const std::exception& error) {
try {
std::cerr << error.what() << '\n';
Expand Down

0 comments on commit 6acf230

Please sign in to comment.