diff --git a/samples/cpp/continuous_batching_speculative_decoding/continuous_batching_speculative_decoding.cpp b/samples/cpp/continuous_batching_speculative_decoding/continuous_batching_speculative_decoding.cpp
index f66797e4d6..782003d789 100644
--- a/samples/cpp/continuous_batching_speculative_decoding/continuous_batching_speculative_decoding.cpp
+++ b/samples/cpp/continuous_batching_speculative_decoding/continuous_batching_speculative_decoding.cpp
@@ -51,6 +51,7 @@ int main(int argc, char* argv[]) try {
 
     // create dataset
     std::vector<std::string> prompt_examples = {
+        "The United Arab Emirates[c] (UAE), or simply the Emirates,[d] is a country in West Asia, in the Middle East, at the eastern end of the Arabian Peninsula. It is a federal, elective monarchy composed of seven emirates, with Abu Dhabi as its capital.[13] It shares land borders with Oman to the east and northwest, and with Saudi Arabia to the southwest; as well as maritime borders in the Persian Gulf with Qatar and Iran, and with Oman in the Gulf of Oman.",
         "What is OpenVINO?",
         "How are you?",
         "What is your name?",
diff --git a/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp b/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
index 66c7c231fa..f535ebb751 100644
--- a/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
+++ b/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
@@ -147,10 +147,6 @@ class AssistedCandidateGenerator {
     size_t draft_model_seq_length = 0;
 
 public:
-    size_t max_match = 0, avg_match = 0;
-    int64_t m_speculative_model_duration = 0;
-    std::vector<size_t> m_matches_info;
-
     AssistedCandidateGenerator(ov::InferRequest draft_model,
                                const size_t max_seq_length,
                                const size_t num_pred_tokens,
@@ -178,10 +174,7 @@ class AssistedCandidateGenerator {
         draft_model.get_tensor("beam_idx").set_shape({BATCH_SIZE});
         draft_model.get_tensor("beam_idx").data<int32_t>()[0] = 0;
 
-        auto start_time = std::chrono::system_clock::now();
         draft_model.infer();
-        auto end_time = std::chrono::system_clock::now();
-        m_speculative_model_duration += std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();
 
         auto logits = draft_model.get_tensor("logits");
         size_t vocab_size = logits.get_shape().back();
@@ -226,49 +219,8 @@ class AssistedCandidateGenerator {
         } else {
             num_pred_tokens = std::max(int64_t(num_pred_tokens) - 1, int64_t(1));
         }
-        // std::cout << "num_matches: " << num_matches << " m_candidates_num: " << num_pred_tokens << std::endl;
-        max_match = std::max(num_matches, max_match);
-        avg_match += num_matches;
     }
-
-//     inline size_t get_median(std::vector<size_t> values) {
-//         const auto size = values.size();
-//         if (size == 0) {
-//             return 0;
-//         }
-//         size_t offset = values.size() / 2;
-
-//         auto it = values.begin() + offset;
-//         std::nth_element(values.begin(), it, values.end());
-
-//         if (size % 2 != 0) {
-//             return *it;
-//         }
-//         auto it_1 = values.begin() + offset - 1;
-//         std::nth_element(values.begin(), it_1, values.end());
-//         return (*it + *it_1) / 2;
-//     }
-
-
-//     void update_candidate_strategy(size_t num_matches) {
-//         // std::cout << "num_matches: " << num_matches << " m_candidates_num: " << m_candidates_num << std::endl;
-//         if (max_pred_tokens == 0) {
-//             return;
-//         }
-
-//         if (max_match < num_matches) {
-//             max_match = num_matches;
-//         }
-//         if (num_matches == num_pred_tokens) {
-//             num_pred_tokens = std::min(std::max(num_pred_tokens + 1, max_match), max_pred_tokens);
-//         } else {
-//             num_pred_tokens = num_matches > 0 ? num_matches : std::max(get_median(m_matches_info), size_t(1));
-//         }
-//         m_matches_info.push_back(num_matches);
-//         avg_match += num_matches;
-//     }
-
 
     void update_kv_cache(const size_t seq_length) {
         // this is the case when main model accepted all candidates from draft model
         // we need to collect kv cache for out_of_kv_cache_token by inferring it
@@ -383,10 +335,6 @@ int main(int argc, char* argv[]) try {
     while (out_token != EOS_TOKEN && seq_len < max_sequence_length) {
         // generate candidates from the draft model
         std::vector<int64_t> candidates = candidateGenerator.generate_candidates(out_token);
-        // std::cout << "ASIISTING MODEL: " << std::endl;
-        // for (size_t i = 0; i < candidates.size(); ++i) {
-        //     std::cout << "N max_sampled_tokens: " << candidates[i] << std::endl;
-        // }
 
         size_t candidates_size = candidates.size();
         // For the main network, candidates_size + 1 tokens will be fed at once in a single infer request.
@@ -415,13 +363,11 @@ int main(int argc, char* argv[]) try {
         //    2.2 if it's a mismatch, stop iteration but still accept current token as it was last token generated by
         //    model from a valid sequence.
         size_t accepted_tokens_number = 0;
-        // std::cout << "MODEL: " << std::endl;
         for (size_t i = 0; i < candidates_size + 1; i++) {
             auto start = data_logits + vocab_size * i;
             auto stop = data_logits + vocab_size * (i + 1);
             out_token = std::max_element(start, stop) - start;
-            // std::cout << "N max_sampled_tokens: " << out_token << std::endl;
 
             if (out_token == EOS_TOKEN) {
                 break;
             }
@@ -441,7 +387,6 @@ int main(int argc, char* argv[]) try {
         if (accepted_tokens_number > 0) {
             candidateGenerator.update_candidate_strategy(accepted_tokens_number - 1);
         }
-        // std::cout << "=========================" << std::endl;
 
         candidateGenerator.update_kv_cache(seq_len);
         update_kv_cache(main_model, main_model_seq_len_axis, seq_len);
@@ -456,17 +401,6 @@ int main(int argc, char* argv[]) try {
     // it is called for education purposes:
     draft_model.reset_state();
    main_model.reset_state();
-
-    auto end_time = std::chrono::system_clock::now();
-    std::chrono::duration<double> duration = end_time - start_time;
-    std::cout << std::endl;
-    std::cout << "Duration: " << duration.count() << std::endl;
-    std::cout << "Infer number: " << iteration_cnt << std::endl;
-    std::cout << "MAX matches number: " << candidateGenerator.max_match << std::endl;
-    // auto a = std::accumulate(pipe.m_matches_info.begin(), pipe.m_matches_info.end(), 0);
-    std::cout << "AVG matches number: " << float(candidateGenerator.avg_match) / iteration_cnt << std::endl;
-    double c = double(candidateGenerator.m_speculative_model_duration) * 100 / std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
-    std::cout << "Speculative model time duration in %: " << c << std::endl;
 } catch (const std::exception& error) {
     try {
         std::cerr << error.what() << '\n';