Skip to content

Commit

Permalink
batched-bench : add fattn arg
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Apr 18, 2024
1 parent c16a7c2 commit 9ca8698
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions examples/batched-bench/batched-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ int main(int argc, char ** argv) {
gpt_params params;

if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
printf(" <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
return 1 ;
Expand All @@ -41,6 +41,7 @@ int main(int argc, char ** argv) {
int n_kv_max = 2048;
int n_batch = 2048;
int n_ubatch = 512;
bool flash_attn = false;
int is_pp_shared = 0;
int n_gpu_layers = 0;

Expand All @@ -66,23 +67,27 @@ int main(int argc, char ** argv) {
}

if (argc >= 6) {
is_pp_shared = std::atoi(argv[5]);
flash_attn = std::atoi(argv[5]);
}

if (argc >= 7) {
n_gpu_layers = std::atoi(argv[6]);
is_pp_shared = std::atoi(argv[6]);
}

if (argc >= 8) {
n_pp = parse_list(argv[7]);
n_gpu_layers = std::atoi(argv[7]);
}

if (argc >= 9) {
n_tg = parse_list(argv[8]);
n_pp = parse_list(argv[8]);
}

if (argc >= 10) {
n_pl = parse_list(argv[9]);
n_tg = parse_list(argv[9]);
}

if (argc >= 11) {
n_pl = parse_list(argv[10]);
}

// init LLM
Expand All @@ -108,10 +113,11 @@ int main(int argc, char ** argv) {

llama_context_params ctx_params = llama_context_default_params();

ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = n_batch;
ctx_params.n_ubatch = n_ubatch;
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = n_batch;
ctx_params.n_ubatch = n_ubatch;
ctx_params.flash_attn = flash_attn;

ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
Expand Down Expand Up @@ -169,7 +175,7 @@ int main(int argc, char ** argv) {
}

LOG_TEE("\n");
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");

LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
Expand Down

0 comments on commit 9ca8698

Please sign in to comment.