From 949a463ef3412dcae2971989b678ffd3f62d3271 Mon Sep 17 00:00:00 2001
From: "Yu, Zhentao"
Date: Wed, 24 Jan 2024 07:15:36 +0000
Subject: [PATCH] refine scheduler naming

Signed-off-by: Yu, Zhentao
---
 neural_speed/application/main_pybind.cpp      | 13 ++++-
 neural_speed/models/model_utils/scheduler.cpp | 51 ++++++++++---------
 neural_speed/models/model_utils/scheduler.h   | 32 ++++++------
 3 files changed, 55 insertions(+), 41 deletions(-)

diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp
index 86ca372d4..0bca2b798 100644
--- a/neural_speed/application/main_pybind.cpp
+++ b/neural_speed/application/main_pybind.cpp
@@ -158,7 +158,7 @@ class ModelServer {
                          num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty,
                          early_stopping, n_keep, n_discard, shift_roped_k, batch_size, pad_token, memory_dtype,
                          true, max_request_num, model_scratch_enlarge_scale);
-      cbg_scheduler scheduler(this->params, policy, print_log ? 0 : 1);
+      Cont_batch_gen_scheduler scheduler(this->params, policy, print_log ? 0 : 1);
       std::vector<sequence> added_seqs;
       while (running) {
         {  // add waiting tasks to the running queue
@@ -279,15 +279,26 @@
   }

  private:
+  // response callback from outside for collecting generation results and checking server working status
   const ResponseCallback response;
+  // waiting pool for new queries added into the server
   std::vector waiting;
+  // lock for the waiting pool
   std::mutex queue_mtx;
+  // tells the server whether it should continue running
+  // true: check the waiting pool and perform one step (or wait for a new query)
+  // false: stop the server
   bool running;
   gpt_params params;
+  // serving policy (only FCFS (first come, first served) for now)
   std::string policy;
+  // whether the server scheduler has run out of queries to run
   bool scheduler_empty;
+  // current number of queries the server needs to handle
   uint64_t working_size;
+  // prepend the prompt token ids to the generated tokens in results if set to true
   bool return_prompt;
+  // server working thread
   std::thread worker;
 };

diff --git a/neural_speed/models/model_utils/scheduler.cpp b/neural_speed/models/model_utils/scheduler.cpp
index ab5db1980..b73beed8c 100644
--- a/neural_speed/models/model_utils/scheduler.cpp
+++ b/neural_speed/models/model_utils/scheduler.cpp
@@ -14,8 +14,8 @@
 
 #include "models/model_utils/scheduler.h"
 
-// il_worker
-il_worker::il_worker(const gpt_params& params) : m_ctx(model_init_from_gpt_params(params)) {
+// Iter_level_worker
+Iter_level_worker::Iter_level_worker(const gpt_params& params) : m_ctx(model_init_from_gpt_params(params)) {
   if (m_ctx == nullptr) {
     fprintf(stderr, "%s: error: unable to load model.\n", __func__);
     exit(0);
@@ -30,7 +30,7 @@ il_worker::il_worker(const gpt_params& params) : m_ctx(model_init_from_gpt_param
     threads = params.n_threads;
 }
 
-il_worker::~il_worker() {
+Iter_level_worker::~Iter_level_worker() {
   if (m_ctx != nullptr) {
     model_free(m_ctx);
   }
@@ -39,10 +39,12 @@
   }
 }
 
-// cbg_worker
-cbg_worker::cbg_worker(const gpt_params& params) : il_worker(params) { m_ctx->cont_batching = true; }
+// Cont_batch_gen_worker
+Cont_batch_gen_worker::Cont_batch_gen_worker(const gpt_params& params) : Iter_level_worker(params) {
+  m_ctx->cont_batching = true;
+}
 
-bool cbg_worker::prepare_inputs(std::vector<sequence>* seqs, const int& n_input, model_input* inputs) {
+bool Cont_batch_gen_worker::prepare_inputs(std::vector<sequence>* seqs, const int& n_input, model_input* inputs) {
   for (int i = 0; i < n_input; ++i) {
     if ((seqs->at(i)).status != seq_status::PREFILL && (seqs->at(i)).status != seq_status::DECODING) {
       fprintf(stderr, "%s: error: request %d status is invalid (%d).\n", __func__, seqs->at(i).request_idx,
@@ -73,7 +75,7 @@ bool cbg_worker::prepare_inputs(std::vector<sequence>* seqs, const int& n_input,
   return true;
 }
 
-bool cbg_worker::beam_search_step(std::vector<sequence>* seqs, const int& n_input) {
+bool Cont_batch_gen_worker::beam_search_step(std::vector<sequence>* seqs, const int& n_input) {
   std::vector<model_input> step_inputs(n_input);
   if (!prepare_inputs(seqs, n_input, step_inputs.data())) {
     return false;
   }
@@ -85,7 +87,7 @@ bool cbg_worker::beam_search_step(std::vector<sequence>* seqs, const int& n_inpu
   return true;
 }
 
-bool cbg_worker::step(std::vector<sequence>* seqs, const int& n_input) {
+bool Cont_batch_gen_worker::step(std::vector<sequence>* seqs, const int& n_input) {
   reqidx_to_vecid.clear();
   for (int ni = 0; ni < n_input; ++ni) {
     reqidx_to_vecid.emplace(seqs->at(ni).request_idx, ni);
@@ -99,7 +101,7 @@ bool cbg_worker::step(std::vector<sequence>* seqs, const int& n_input) {
   return update_seqs(seqs, n_input);
 }
 
-bool cbg_worker::update_seqs(std::vector<sequence>* seqs, const int& n_input) {
+bool Cont_batch_gen_worker::update_seqs(std::vector<sequence>* seqs, const int& n_input) {
   empty_request_done_ids();
   for (int ni = 0; ni < n_input; ++ni) {
     if (seqs->at(ni).status == seq_status::PREFILL) {
@@ -142,8 +144,8 @@ bool cbg_worker::update_seqs(std::vector<sequence>* seqs, const int& n_input) {
   return false;  // TODO (YZT) greedy search and top_p-top_k sampling
 }
 
-// il_scheduler
-il_scheduler::il_scheduler(const gpt_params& params, const std::string& policy, const int& log_level)
+// Iter_level_scheduler
+Iter_level_scheduler::Iter_level_scheduler(const gpt_params& params, const std::string& policy, const int& log_level)
     : params(params),
       policy(parse_serve_policy(policy)),
       waiting_pool(pool_property::WAITING),
@@ -151,9 +153,9 @@ il_scheduler::il_scheduler(const gpt_params& params, const std::string& policy,
       finished_pool(pool_property::FINISHED),
       log_level(log_level) {}
 
-il_scheduler::il_scheduler(const gpt_params& params) : il_scheduler(params, "fcfs", 1) {}
+Iter_level_scheduler::Iter_level_scheduler(const gpt_params& params) : Iter_level_scheduler(params, "fcfs", 1) {}
 
-std::vector<sequence> il_scheduler::pop_completed_requests() {
+std::vector<sequence> Iter_level_scheduler::pop_completed_requests() {
   std::vector<sequence> ret_seqs;
   const int length = finished_pool.size();
   if (length == 0) {
@@ -174,22 +176,23 @@ std::vector<sequence> il_scheduler::pop_completed_requests() {
   return ret_seqs;
 }
 
-// cbg_scheduler
-cbg_scheduler::cbg_scheduler(const gpt_params& params)
-    : il_scheduler(params),
+// Cont_batch_gen_scheduler
+Cont_batch_gen_scheduler::Cont_batch_gen_scheduler(const gpt_params& params)
+    : Iter_level_scheduler(params),
       max_requests(params.max_request_num),
       wr(params),
       free_req_idx(max_requests, true),
       waiting_free_req_idx_seqs_num(0) {}
 
-cbg_scheduler::cbg_scheduler(const gpt_params& params, const std::string& policy, const int& log_level)
-    : il_scheduler(params, policy, log_level),
+Cont_batch_gen_scheduler::Cont_batch_gen_scheduler(const gpt_params& params, const std::string& policy,
+                                                   const int& log_level)
+    : Iter_level_scheduler(params, policy, log_level),
       max_requests(params.max_request_num),
       wr(params),
       free_req_idx(max_requests, true),
       waiting_free_req_idx_seqs_num(0) {}
 
-int cbg_scheduler::query_free_req_idx() {
+int Cont_batch_gen_scheduler::query_free_req_idx() {
   auto iter = std::find_if(free_req_idx.begin(), free_req_idx.end(), [](const bool flag) { return flag; });
   if (iter == free_req_idx.end()) {
     return -1;
   } else {
@@ -200,7 +203,7 @@ int cbg_scheduler::query_free_req_idx() {
   }
 }
 
-bool cbg_scheduler::add_request(sequence seq) {
+bool Cont_batch_gen_scheduler::add_request(sequence seq) {
   seq.receive_time = model_time_us();
   if (seq.status != seq_status::UNKNOWN) {
     fprintf(stderr, "%s: error: seq status is not UNKNOWN, cannot decide which pool to add it into.\n", __func__);
@@ -216,7 +219,7 @@ bool cbg_scheduler::add_request(sequence seq) {
   return waiting_pool.add(seq);
 }
 
-bool cbg_scheduler::prepare_seqs() {
+bool Cont_batch_gen_scheduler::prepare_seqs() {
   executed_seqs.clear();
   cur_running_num = running_pool.size();
   if (cur_running_num > max_requests) {
@@ -269,7 +272,7 @@ bool cbg_scheduler::prepare_seqs() {
   return true;
 }
 
-bool cbg_scheduler::step() {
+bool Cont_batch_gen_scheduler::step() {
   int64_t s_t0 = model_time_us();
   if (done()) {
     fprintf(stderr,
@@ -309,7 +312,7 @@ bool cbg_scheduler::step() {
   return success;
 }
 
-bool cbg_scheduler::update_pools() {
+bool Cont_batch_gen_scheduler::update_pools() {
   for (int ns = 0; ns < executed_seqs.size(); ++ns) {
     if (executed_seqs[ns].status == seq_status::DECODING) {
       running_pool.add(executed_seqs[ns]);
@@ -332,7 +335,7 @@ bool cbg_scheduler::update_pools() {
   return true;
 }
 
-bool cbg_scheduler::done() {
+bool Cont_batch_gen_scheduler::done() {
   if (waiting_pool.empty() && running_pool.empty()) {
     return true;
   } else {
diff --git a/neural_speed/models/model_utils/scheduler.h b/neural_speed/models/model_utils/scheduler.h
index 575be400b..7b34044c5 100644
--- a/neural_speed/models/model_utils/scheduler.h
+++ b/neural_speed/models/model_utils/scheduler.h
@@ -19,10 +19,10 @@
 #include "models/model_utils/model_utils.h"
 
 // iteration-level worker
-class il_worker {
+class Iter_level_worker {
  public:
-  explicit il_worker(const gpt_params& params);
-  virtual ~il_worker();
+  explicit Iter_level_worker(const gpt_params& params);
+  virtual ~Iter_level_worker();
   virtual bool step(std::vector<sequence>* seqs, const int& n_input) = 0;
   // virtual bool greedy_search_step(sequence seqs, const int& n_input) = 0;
   virtual bool beam_search_step(std::vector<sequence>* seqs, const int& n_input) = 0;
@@ -43,11 +43,11 @@ class il_worker {
 };
 
 // continuous batching generation worker
-class cbg_worker : public il_worker {
+class Cont_batch_gen_worker : public Iter_level_worker {
  public:
-  explicit cbg_worker(const gpt_params& params);
-  cbg_worker(const gpt_params& params, const int& n_threads);
-  ~cbg_worker() = default;
+  explicit Cont_batch_gen_worker(const gpt_params& params);
+  Cont_batch_gen_worker(const gpt_params& params, const int& n_threads);
+  ~Cont_batch_gen_worker() = default;
 
   bool step(std::vector<sequence>* seqs, const int& n_input) override;
   // bool greedy_search_step(sequence seqs, const int& n_input) override;
@@ -59,11 +59,11 @@
 };
 
 // iteration-level scheduler
-class il_scheduler {
+class Iter_level_scheduler {
  public:
-  explicit il_scheduler(const gpt_params& params);
-  il_scheduler(const gpt_params& params, const std::string& policy, const int& log_level);
-  virtual ~il_scheduler() = default;
+  explicit Iter_level_scheduler(const gpt_params& params);
+  Iter_level_scheduler(const gpt_params& params, const std::string& policy, const int& log_level);
+  virtual ~Iter_level_scheduler() = default;
 
   // TODO (YZT) kv cache ptr as input params
   virtual bool add_request(sequence seq) = 0;
@@ -85,11 +85,11 @@
 };
 
 // continuous batching generation scheduler
-class cbg_scheduler : public il_scheduler {
+class Cont_batch_gen_scheduler : public Iter_level_scheduler {
  public:
-  explicit cbg_scheduler(const gpt_params& params);
-  cbg_scheduler(const gpt_params& params, const std::string& policy, const int& log_level);
-  ~cbg_scheduler() = default;
+  explicit Cont_batch_gen_scheduler(const gpt_params& params);
+  Cont_batch_gen_scheduler(const gpt_params& params, const std::string& policy, const int& log_level);
+  ~Cont_batch_gen_scheduler() = default;
 
   bool add_request(sequence seq) override;
   bool step() override;
@@ -101,7 +101,7 @@ class cbg_scheduler : public il_scheduler {
   int query_free_req_idx();
 
   const int max_requests;
-  cbg_worker wr;
+  Cont_batch_gen_worker wr;
   std::vector<sequence> executed_seqs;
   std::vector<bool> free_req_idx;
   int waiting_free_req_idx_seqs_num;
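
For reference, the sketch below shows how the renamed Cont_batch_gen_scheduler API declared in scheduler.h fits together. It is a minimal illustration under stated assumptions, not code from this patch: serve_requests is a hypothetical helper, and it assumes the caller has already filled in gpt_params and a vector of sequence objects whose status is seq_status::UNKNOWN.

#include <vector>

#include "models/model_utils/scheduler.h"

// Hypothetical helper (not part of this patch): drive one batch of requests
// through the continuous batching scheduler and return the finished ones.
std::vector<sequence> serve_requests(const gpt_params& params, std::vector<sequence>* requests) {
  // "fcfs" is currently the only supported policy; log_level 1 keeps logging quiet.
  Cont_batch_gen_scheduler scheduler(params, "fcfs", 1);
  for (auto& seq : *requests) {
    // add_request() expects seq.status == seq_status::UNKNOWN and routes the
    // sequence into the waiting pool (or a free request slot).
    scheduler.add_request(seq);
  }
  // Each step() prepares runnable sequences, performs one iteration-level
  // model step, and moves finished sequences into the finished pool.
  while (!scheduler.done()) {
    if (!scheduler.step()) break;  // bail out if a step fails
  }
  return scheduler.pop_completed_requests();
}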