Skip to content

Commit

Permalink
refactor structure, add python sample
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Jul 22, 2024
1 parent 7cab496 commit bb1113c
Show file tree
Hide file tree
Showing 15 changed files with 279 additions and 154 deletions.
1 change: 1 addition & 0 deletions samples/cpp/benchmark_vanilla_genai/README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# benchmark OpenVINO GenAI sample

TODO: adapt from python sample to c++
22 changes: 12 additions & 10 deletions samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ int main(int argc, char* argv[]) try {
("p,prompt", "Prompt", cxxopts::value<std::string>()->default_value("The Sky is blue because"))
("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
("nw,num_warmup", "Number of warmup iterations", cxxopts::value<size_t>()->default_value(std::to_string(1)))
("n,num_iter", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(1)))
("n,num_iter", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(5)))
("mt,max_new_tokens", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(20)))
("d,device", "device", cxxopts::value<std::string>()->default_value("CPU"))
("h,help", "Print usage");

Expand All @@ -36,26 +37,27 @@ int main(int argc, char* argv[]) try {
size_t num_iter = result["num_iter"].as<size_t>();

ov::genai::GenerationConfig config;
config.max_new_tokens = 100;
config.num_beam_groups = 3;
config.num_beams = 15;
config.max_new_tokens = result["max_new_tokens"].as<size_t>();

ov::genai::LLMPipeline pipe(model_path, device);

for (size_t i = 0; i < num_warmup; i++)
pipe.generate(prompt, config);

ov::genai::PerfMetrics metrics;
for (size_t i = 0; i < num_iter; i++) {
ov::genai::DecodedResults res = pipe.generate(prompt, config);
ov::genai::DecodedResults res = pipe.generate(prompt, config);
ov::genai::PerfMetrics metrics = res.metrics;
for (size_t i = 0; i < num_iter - 1; i++) {
res = pipe.generate(prompt, config);
metrics = metrics + res.metrics;
metrics.load_time = res.metrics.load_time;
}

std::cout << "Load time: " << metrics.load_time << " ms" << std::endl;
std::cout << "Generate time: " << metrics.mean_generate_duration << " ± " << metrics.std_generate_duration << " ms" << std::endl;
std::cout << "Tokenization time: " << metrics.mean_tokenization_duration << " ± " << metrics.std_tokenization_duration << " ms" << std::endl;
std::cout << "Detokenization time: " << metrics.mean_detokenization_duration << " ± " << metrics.std_detokenization_duration << " ms" << std::endl;
std::cout << "ttft: " << metrics.mean_ttft << " ± " << metrics.std_ttft << " ms" << std::endl;
std::cout << "tpot: " << metrics.mean_tpot << " ± " << metrics.std_tpot << " ms" << std::endl;
std::cout << "Tokens/s: " << metrics.mean_throughput << std::endl;
std::cout << "tpot: " << metrics.mean_tpot << " ± " << metrics.std_tpot << " ms " << std::endl;
std::cout << "Tokens/s: " << metrics.mean_throughput << " ± " << metrics.std_throughput << std::endl;

return 0;
} catch (const std::exception& error) {
Expand Down
66 changes: 66 additions & 0 deletions samples/python/benchmark_vanilla_genai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Benchmark Vanilla GenAI

This sample script demonstrates how to benchmark an LLMModel in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text, and calculating various performance metrics.

# ov.genai.PerfMetrics structure
ov.genai.PerfMetrics is a structure which holds performance metric for each generate call. Each generate call calcualtes the following metrics:
- mean_ttft
- std_ttft
- mean_tpot
- std_tpot
- load_time
- mean_generate_duration
- std_generate_duration
- mean_tokenization_duration
- std_tokenization_duration
- mean_detokenization_duration
- std_detokenization_duration
- mean_throughput
- std_throughput
- num_generated_tokens
- num_input_tokens

Performance metrics can be added to one another and accumulated using the += operator or the + operator. In that case the mean values accumulated by several generate calls will be calculated.


## Download and convert the model and tokenizers

The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.

It's not required to install [../../requirements.txt](../../requirements.txt) for deployment if the model has already been exported.

```sh
pip install --upgrade-strategy eager -r ../../requirements.txt
optimum-cli export openvino --trust-remote-code --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
```

## Usage

```sh
python benchmark_vanilla_genai.py [OPTIONS]
```

### Options

- `-m, --model`: Path to the model and tokenizers base directory.
- `-p, --prompt` (default: `"The Sky is blue because"`): The prompt to generate text.
- `-nw, --num_warmup` (default: `1`): Number of warmup iterations.
- `-mt, --max_new_tokens` (default: `20`): Number of warmup iterations.
- `-n, --num_iter` (default: `3`): Number of iterations.
- `-d, --device` (default: `"CPU"`): Device to run the model on.

### Output:

```
python benchmark_vanilla_genai.py -m TinyLlama-1.1B-Chat-v1.0/
```

```
Load time: 3446 ms
Generate time: 876.2 ± 3.30719 ms
Tokenization time: 0 ± 0 ms
Detokenization time: 0 ± 0 ms
ttft: 168 ± 0 ms
tpot: 174.68 ± 4.08671 ms
Tokens/s: 5.72475 ± 0.133933
```
50 changes: 50 additions & 0 deletions samples/python/benchmark_vanilla_genai/benchmark_vanilla_genai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import openvino_genai as ov_genai
import pdb

def main():
parser = argparse.ArgumentParser(description="Help command")
parser.add_argument("-m", "--model", type=str, help="Path to model and tokenizers base directory")
parser.add_argument("-p", "--prompt", type=str, default="The Sky is blue because", help="Prompt")
parser.add_argument("-nw", "--num_warmup", type=int, default=1, help="Number of warmup iterations")
parser.add_argument("-n", "--num_iter", type=int, default=3, help="Number of iterations")
parser.add_argument("-mt", "--max_new_tokens", type=int, default=20, help="Maximal number of new tokens")
parser.add_argument("-d", "--device", type=str, default="CPU", help="Device")

args = parser.parse_args()

prompt = [args.prompt]
model_path = args.model
device = args.device
num_warmup = args.num_warmup
num_iter = args.num_iter


config = ov_genai.GenerationConfig()
config.max_new_tokens = args.num_new_tokens

pipe = ov_genai.LLMPipeline(model_path, device)

for _ in range(num_warmup):
pipe.generate(prompt, config)

res = pipe.generate(prompt, config)
metrics = res.metrics
for _ in range(num_iter - 1):
# pdb.set_trace()
res = pipe.generate(prompt, config)
metrics += res.metrics

print(f"Load time: {metrics.load_time} ms")
print(f"Generate time: {metrics.mean_generate_duration:.2f} ± {metrics.std_generate_duration:.2f} ms")
print(f"Tokenization time: {metrics.mean_tokenization_duration:.2f} ± {metrics.std_tokenization_duration:.2f} ms")
print(f"Detokenization time: {metrics.mean_detokenization_duration:.2f} ± {metrics.std_detokenization_duration:.2f} ms")
print(f"TTFT: {metrics.mean_ttft:.2f} ± {metrics.std_ttft:.2f} ms")
print(f"TPOT: {metrics.mean_tpot:.2f} ± {metrics.std_tpot:.2f} ms")
print(f"Throughput tokens/s: {metrics.mean_throughput:.2f} ± {metrics.std_throughput:.2f}")

if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using StringInputs = std::variant<std::string, std::vector<std::string>>;
*
* @param tokens sequence of resulting tokens
* @param scores sum of logarithmic probabilities of all tokens in the sequence
* @param metrics performance metrics with tpot, ttft, etc. of type ov::genai::PerfMetrics
*/
class EncodedResults {
public:
Expand All @@ -45,6 +46,7 @@ class EncodedResults {
*
* @param texts vector of resulting sequences
* @param scores scores for each sequence
* @param metrics performance metrics with tpot, ttft, etc. of type ov::genai::PerfMetrics
*/
class DecodedResults {
public:
Expand Down
37 changes: 29 additions & 8 deletions src/cpp/include/openvino/genai/perf_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,34 @@
#include "openvino/genai/visibility.hpp"
#include <vector>
#include <memory>
#include <optional>

namespace ov {
namespace genai {

using TimePoint = std::chrono::steady_clock::time_point;

struct PerfCounters;
/**
* @brief Structure with raw performance metrics for each generation before any statistics calculated.
*/
struct OPENVINO_GENAI_EXPORTS RawPerfMetrics {
std::vector<float> generate_durations;
std::vector<float> tokenization_durations;
std::vector<float> detokenization_durations;

std::vector<float> m_times_to_first_token;
std::vector<TimePoint> m_new_token_times;
std::vector<size_t> m_batch_sizes;
std::vector<float> m_durations;

size_t num_generated_tokens;
size_t num_input_tokens;
};

/**
* @brief Structure to store performance metric for each generation
*
*/
struct OPENVINO_GENAI_EXPORTS PerfMetrics {
// First token time.
float mean_ttft;
Expand All @@ -25,25 +45,26 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics {
float std_tpot;

float load_time;
float start_time;

float mean_generate_duration;
float mean_decoding_duration;
float mean_encoding_duration;
float std_generate_duration;
float mean_tokenization_duration;
float std_tokenization_duration;
float mean_detokenization_duration;
float std_detokenization_duration;

float mean_throughput;
float std_throughput;

size_t num_generated_tokens;
size_t num_input_tokens;

std::shared_ptr<PerfCounters> m_counters;
void evaluate(TimePoint start_time);

void evaluate_statistics(std::optional<TimePoint> start_time = std::nullopt);
static float get_duration_ms(std::chrono::steady_clock::duration duration);
PerfMetrics operator+(const PerfMetrics& metrics) const;
PerfMetrics& operator+=(const PerfMetrics& right);


RawPerfMetrics raw_counters;
};

} // namespace genai
Expand Down
10 changes: 6 additions & 4 deletions src/cpp/src/greedy_decoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "openvino/genai/perf_metrics.hpp"
#include "perf_counters.hpp"
// #include "perf_counters.hpp"
#include "utils.hpp"

namespace ov {
Expand All @@ -23,7 +23,7 @@ EncodedResults greedy_decoding(
size_t max_new_tokens = generation_config.get_max_new_tokens(prompt_len);

EncodedResults results;
auto& perf_counters = results.metrics.m_counters;
auto& raw_perf_counters = results.metrics.raw_counters;

results.scores.resize(running_batch_size);
results.tokens.resize(running_batch_size);
Expand Down Expand Up @@ -54,7 +54,8 @@ EncodedResults greedy_decoding(
eos_met[batch] = (out_token == generation_config.eos_token_id);
m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
}
perf_counters->add_timestamp(running_batch_size);
raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
raw_perf_counters.m_batch_sizes.emplace_back(batch_size);

if (streamer && streamer->put(token_iter_results[0])) {
return results;
Expand Down Expand Up @@ -86,7 +87,8 @@ EncodedResults greedy_decoding(

m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
}
perf_counters->add_timestamp(running_batch_size);
raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
raw_perf_counters.m_batch_sizes.emplace_back(batch_size);

if (streamer && streamer->put(token_iter_results[0]))
return results;
Expand Down
20 changes: 12 additions & 8 deletions src/cpp/src/group_beam_searcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,6 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
auto batch_size = input_ids.get_shape().at(0);
auto sequence_length = input_ids.get_shape().at(1);

// Initialize time metric counters.
// ov::genai::TimePoints tok_times;
// tok_times.reserve(config.get_max_new_tokens(sequence_length));
// tok_times.emplace_back(std::chrono::steady_clock::now());

// Initialize beam search.
const int64_t* prompt_data = input_ids.data<const int64_t>();
std::vector<std::vector<int64_t>> prompts;
Expand Down Expand Up @@ -407,12 +402,19 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,

std::vector<int64_t> next_tokens;
std::vector<int32_t> next_beams;


// Reserve for performance counters.
std::vector<std::chrono::steady_clock::time_point> new_token_times;
std::vector<size_t> batch_sizes;
new_token_times.reserve(parameters.max_new_tokens);
batch_sizes.reserve(parameters.max_new_tokens);

for (size_t length_count = 0; ; ++length_count) {
lm.infer();

std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits"));
// tok_times.emplace_back(std::chrono::steady_clock::now());
new_token_times.emplace_back(std::chrono::steady_clock::now());
batch_sizes.emplace_back(batch_size);

if (next_tokens.empty() || length_count == parameters.max_new_tokens - 1) {
// Break the cycle before masks are extended in update_attention_mask_with_beams.
Expand Down Expand Up @@ -442,6 +444,9 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
int32_t res_selected_beam_idx = 0;
results.scores.reserve(config.num_return_sequences * result.size());
results.tokens.reserve(config.num_return_sequences * result.size());
auto& raw_perf_counters = results.metrics.raw_counters;
raw_perf_counters.m_new_token_times = new_token_times;
raw_perf_counters.m_batch_sizes = batch_sizes;

// align output with HF
for (size_t prompt_id = 0; prompt_id < result.size(); prompt_id++) {
Expand Down Expand Up @@ -471,7 +476,6 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
}
}

// results.metrics = PerfCounters(tok_times);
return {results, res_selected_beam_idx};
}

Expand Down
Loading

0 comments on commit bb1113c

Please sign in to comment.