
Commit

refine scheduler naming
Signed-off-by: Yu, Zhentao <[email protected]>
zhentaoyu committed Jan 24, 2024
1 parent 074885b commit 949a463
Showing 3 changed files with 55 additions and 41 deletions.
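
For reference, the renames this commit introduces (old name -> new name):

il_worker      -> Iter_level_worker
cbg_worker     -> Cont_batch_gen_worker
il_scheduler   -> Iter_level_scheduler
cbg_scheduler  -> Cont_batch_gen_scheduler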
13 changes: 12 additions & 1 deletion neural_speed/application/main_pybind.cpp
@@ -158,7 +158,7 @@ class ModelServer {
num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty,
early_stopping, n_keep, n_discard, shift_roped_k, batch_size, pad_token, memory_dtype,
true, max_request_num, model_scratch_enlarge_scale);
-cbg_scheduler scheduler(this->params, policy, print_log ? 0 : 1);
+Cont_batch_gen_scheduler scheduler(this->params, policy, print_log ? 0 : 1);
std::vector<sequence> added_seqs;
while (running) {
{ // add waiting tasks to the running queue
@@ -279,15 +279,26 @@ class ModelServer {
}

private:
+// response function from outside for collecting generation results and checking server working status
const ResponseCallback response;
+// waiting pool for new queries added into the server
std::vector<Query> waiting;
+// lock for the waiting pool
std::mutex queue_mtx;
+// status telling the server whether it needs to keep running
+// true: check the waiting pool and perform one step (or wait for a new query)
+// false: stop the server
bool running;
gpt_params params;
+// serving policy (only FCFS (first come, first served) for now)
std::string policy;
+// whether the server scheduler has run out of queries to run
bool scheduler_empty;
+// current number of queries the server needs to deal with
uint64_t working_size;
+// if true, prepend the prompt token ids to the generated tokens in the results
bool return_prompt;
+// server working thread
std::thread worker;
};
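
The comments added above document a classic producer/consumer split: client threads append queries to the waiting pool while the server's worker thread drains it under the same lock. A minimal sketch of that interaction, using invented stand-in types (QuerySketch, ServerSketch) rather than the real neural_speed classes:

#include <mutex>
#include <utility>
#include <vector>

// Sketch only (assumed shapes, not the real neural_speed types).
struct QuerySketch { int id; std::vector<int> token_ids; };

class ServerSketch {
 public:
  // called from client threads: enqueue under the lock
  void add_query(QuerySketch q) {
    std::lock_guard<std::mutex> lock(queue_mtx);
    waiting.push_back(std::move(q));
  }
  // body of the worker thread: drain the waiting pool, then run one step
  void loop() {
    while (running) {
      std::vector<QuerySketch> batch;
      {
        std::lock_guard<std::mutex> lock(queue_mtx);
        batch.swap(waiting);  // take all waiting queries in O(1)
      }
      // hand `batch` to the scheduler and perform one generation step...
    }
  }
 private:
  std::vector<QuerySketch> waiting;  // waiting pool, mirrors ModelServer::waiting
  std::mutex queue_mtx;              // guards the waiting pool
  bool running = true;               // the real code would need synchronization here
};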

51 changes: 27 additions & 24 deletions neural_speed/models/model_utils/scheduler.cpp
@@ -14,8 +14,8 @@

#include "models/model_utils/scheduler.h"

-// il_worker
-il_worker::il_worker(const gpt_params& params) : m_ctx(model_init_from_gpt_params(params)) {
+// Iter_level_worker
+Iter_level_worker::Iter_level_worker(const gpt_params& params) : m_ctx(model_init_from_gpt_params(params)) {
if (m_ctx == nullptr) {
fprintf(stderr, "%s: error: unable to load model.\n", __func__);
exit(0);
@@ -30,7 +30,7 @@ il_worker::il_worker(const gpt_params& params) : m_ctx(model_init_from_gpt_param
threads = params.n_threads;
}

-il_worker::~il_worker() {
+Iter_level_worker::~Iter_level_worker() {
if (m_ctx != nullptr) {
model_free(m_ctx);
}
@@ -39,10 +39,12 @@ il_worker::~il_worker() {
}
}

-// cbg_worker
-cbg_worker::cbg_worker(const gpt_params& params) : il_worker(params) { m_ctx->cont_batching = true; }
+// Cont_batch_gen_worker
+Cont_batch_gen_worker::Cont_batch_gen_worker(const gpt_params& params) : Iter_level_worker(params) {
+  m_ctx->cont_batching = true;
+}

-bool cbg_worker::prepare_inputs(std::vector<sequence>* seqs, const int& n_input, model_input* inputs) {
+bool Cont_batch_gen_worker::prepare_inputs(std::vector<sequence>* seqs, const int& n_input, model_input* inputs) {
for (int i = 0; i < n_input; ++i) {
if ((seqs->at(i)).status != seq_status::PREFILL && (seqs->at(i)).status != seq_status::DECODING) {
fprintf(stderr, "%s: error: request %d status is unright (%d).\n", __func__, seqs->at(i).request_idx,
@@ -73,7 +75,7 @@ bool cbg_worker::prepare_inputs(std::vector<sequence>* seqs, const int& n_input,
return true;
}

-bool cbg_worker::beam_search_step(std::vector<sequence>* seqs, const int& n_input) {
+bool Cont_batch_gen_worker::beam_search_step(std::vector<sequence>* seqs, const int& n_input) {
std::vector<model_input> step_inputs(n_input);
if (!prepare_inputs(seqs, n_input, step_inputs.data())) {
return false;
@@ -85,7 +87,7 @@ bool cbg_worker::beam_search_step(std::vector<sequence>* seqs, const int& n_inpu
return true;
}

-bool cbg_worker::step(std::vector<sequence>* seqs, const int& n_input) {
+bool Cont_batch_gen_worker::step(std::vector<sequence>* seqs, const int& n_input) {
reqidx_to_vecid.clear();
for (int ni = 0; ni < n_input; ++ni) {
reqidx_to_vecid.emplace(seqs->at(ni).request_idx, ni);
@@ -99,7 +101,7 @@ bool cbg_worker::step(std::vector<sequence>* seqs, const int& n_input) {
return update_seqs(seqs, n_input);
}

-bool cbg_worker::update_seqs(std::vector<sequence>* seqs, const int& n_input) {
+bool Cont_batch_gen_worker::update_seqs(std::vector<sequence>* seqs, const int& n_input) {
empty_request_done_ids();
for (int ni = 0; ni < n_input; ++ni) {
if (seqs->at(ni).status == seq_status::PREFILL) {
@@ -142,18 +144,18 @@ bool cbg_worker::update_seqs(std::vector<sequence>* seqs, const int& n_input) {
return false; // TODO (YZT) greedy search and top_p-top_k sampling
}

-// il_scheduler
-il_scheduler::il_scheduler(const gpt_params& params, const std::string& policy, const int& log_level)
+// Iter_level_scheduler
+Iter_level_scheduler::Iter_level_scheduler(const gpt_params& params, const std::string& policy, const int& log_level)
: params(params),
policy(parse_serve_policy(policy)),
waiting_pool(pool_property::WAITING),
running_pool(pool_property::RUNNING),
finished_pool(pool_property::FINISHED),
log_level(log_level) {}

-il_scheduler::il_scheduler(const gpt_params& params) : il_scheduler(params, "fcfs", 1) {}
+Iter_level_scheduler::Iter_level_scheduler(const gpt_params& params) : Iter_level_scheduler(params, "fcfs", 1) {}

-std::vector<sequence> il_scheduler::pop_completed_requests() {
+std::vector<sequence> Iter_level_scheduler::pop_completed_requests() {
std::vector<sequence> ret_seqs;
const int length = finished_pool.size();
if (length == 0) {
@@ -174,22 +176,23 @@ std::vector<sequence> il_scheduler::pop_completed_requests() {
return ret_seqs;
}
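
Taken together, add_request, step, done, and pop_completed_requests form the scheduler's public loop. A minimal usage sketch with the renamed classes; serve_one_batch is an invented helper name, and parameter setup is elided:

#include <vector>
#include "models/model_utils/scheduler.h"

// Driver sketch, not part of this commit. Assumes `params` and the request
// `sequence`s are already populated by the caller.
void serve_one_batch(const gpt_params& params, std::vector<sequence>& requests) {
  Cont_batch_gen_scheduler scheduler(params);  // FCFS policy, default log level
  for (auto& seq : requests) {
    scheduler.add_request(seq);                // enters the waiting pool
  }
  while (!scheduler.done()) {                  // pools still hold work
    if (!scheduler.step()) break;              // one iteration-level step
  }
  std::vector<sequence> finished = scheduler.pop_completed_requests();
  // post-process `finished` here
}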

-// cbg_scheduler
-cbg_scheduler::cbg_scheduler(const gpt_params& params)
-    : il_scheduler(params),
+// Cont_batch_gen_scheduler
+Cont_batch_gen_scheduler::Cont_batch_gen_scheduler(const gpt_params& params)
+    : Iter_level_scheduler(params),
max_requests(params.max_request_num),
wr(params),
free_req_idx(max_requests, true),
waiting_free_req_idx_seqs_num(0) {}

-cbg_scheduler::cbg_scheduler(const gpt_params& params, const std::string& policy, const int& log_level)
-    : il_scheduler(params, policy, log_level),
+Cont_batch_gen_scheduler::Cont_batch_gen_scheduler(const gpt_params& params, const std::string& policy,
+                                                   const int& log_level)
+    : Iter_level_scheduler(params, policy, log_level),
max_requests(params.max_request_num),
wr(params),
free_req_idx(max_requests, true),
waiting_free_req_idx_seqs_num(0) {}
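
One subtlety in the initializer lists above: free_req_idx(max_requests, true) is only safe because C++ initializes non-static data members in declaration order, and scheduler.h (shown later in this diff) declares max_requests before free_req_idx:

const int max_requests;          // declared first, so initialized first
std::vector<bool> free_req_idx;  // sized from the already-initialized member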

-int cbg_scheduler::query_free_req_idx() {
+int Cont_batch_gen_scheduler::query_free_req_idx() {
auto iter = std::find_if(free_req_idx.begin(), free_req_idx.end(), [](const bool flag) { return flag; });
if (iter == free_req_idx.end()) {
return -1;
@@ -200,7 +203,7 @@ int cbg_scheduler::query_free_req_idx() {
}
}
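
The lookup above scans free_req_idx for the first true flag, i.e. the first free request slot. The claiming branch is partially elided in this view, so the following standalone reconstruction hedges on that detail (marking the slot taken and returning its index is an assumption):

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <vector>

int claim_free_slot(std::vector<bool>& free_req_idx) {
  auto iter = std::find_if(free_req_idx.begin(), free_req_idx.end(),
                           [](const bool flag) { return flag; });
  if (iter == free_req_idx.end()) return -1;  // every request slot is busy
  *iter = false;                              // mark the slot as taken
  return static_cast<int>(std::distance(free_req_idx.begin(), iter));
}

int main() {
  std::vector<bool> slots = {false, true, true};
  printf("claimed slot %d\n", claim_free_slot(slots));  // prints: claimed slot 1
  printf("next free   %d\n", claim_free_slot(slots));   // prints: next free   2
}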

-bool cbg_scheduler::add_request(sequence seq) {
+bool Cont_batch_gen_scheduler::add_request(sequence seq) {
seq.receive_time = model_time_us();
if (seq.status != seq_status::UNKNOWN) {
fprintf(stderr, "%s: error: seq status is not UNKNOWN, can not decide to add into which pool.\n", __func__);
@@ -216,7 +219,7 @@ bool cbg_scheduler::add_request(sequence seq) {
return waiting_pool.add(seq);
}

-bool cbg_scheduler::prepare_seqs() {
+bool Cont_batch_gen_scheduler::prepare_seqs() {
executed_seqs.clear();
cur_running_num = running_pool.size();
if (cur_running_num > max_requests) {
@@ -269,7 +272,7 @@ bool cbg_scheduler::prepare_seqs() {
return true;
}

-bool cbg_scheduler::step() {
+bool Cont_batch_gen_scheduler::step() {
int64_t s_t0 = model_time_us();
if (done()) {
fprintf(stderr,
@@ -309,7 +312,7 @@ bool cbg_scheduler::step() {
return success;
}

-bool cbg_scheduler::update_pools() {
+bool Cont_batch_gen_scheduler::update_pools() {
for (int ns = 0; ns < executed_seqs.size(); ++ns) {
if (executed_seqs[ns].status == seq_status::DECODING) {
running_pool.add(executed_seqs[ns]);
@@ -332,7 +335,7 @@ bool cbg_scheduler::update_pools() {
return true;
}

-bool cbg_scheduler::done() {
+bool Cont_batch_gen_scheduler::done() {
if (waiting_pool.empty() && running_pool.empty()) {
return true;
} else {
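
update_pools routes each executed sequence by status, and done() reports whether both the waiting and running pools have drained: a request enters via add_request, stays in the running pool while DECODING, and is eventually exposed through pop_completed_requests. A compressed sketch of the per-step routing; the pool type name (serve_pool), the finished branch, and the slot release are assumptions, since those parts of the function are elided in this view:

void route_executed(std::vector<sequence>& executed_seqs, serve_pool& running_pool,
                    serve_pool& finished_pool, std::vector<bool>& free_req_idx) {
  for (const sequence& s : executed_seqs) {
    if (s.status == seq_status::DECODING) {
      running_pool.add(s);  // still generating: keep it schedulable
    } else {
      finished_pool.add(s);                // expose to pop_completed_requests()
      free_req_idx[s.request_idx] = true;  // assumed: release the request slot
    }
  }
}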
32 changes: 16 additions & 16 deletions neural_speed/models/model_utils/scheduler.h
@@ -19,10 +19,10 @@
#include "models/model_utils/model_utils.h"

// iteration-level worker
-class il_worker {
+class Iter_level_worker {
public:
-  explicit il_worker(const gpt_params& params);
-  virtual ~il_worker();
+  explicit Iter_level_worker(const gpt_params& params);
+  virtual ~Iter_level_worker();
virtual bool step(std::vector<sequence>* seqs, const int& n_input) = 0;
// virtual bool greedy_search_step(sequence seqs, const int& n_input) = 0;
virtual bool beam_search_step(std::vector<sequence>* seqs, const int& n_input) = 0;
@@ -43,11 +43,11 @@ class il_worker {
};

// continuous batching generation worker
-class cbg_worker : public il_worker {
+class Cont_batch_gen_worker : public Iter_level_worker {
public:
-  explicit cbg_worker(const gpt_params& params);
-  cbg_worker(const gpt_params& params, const int& n_threads);
-  ~cbg_worker() = default;
+  explicit Cont_batch_gen_worker(const gpt_params& params);
+  Cont_batch_gen_worker(const gpt_params& params, const int& n_threads);
+  ~Cont_batch_gen_worker() = default;

bool step(std::vector<sequence>* seqs, const int& n_input) override;
// bool greedy_search_step(sequence seqs, const int& n_input) override;
@@ -59,11 +59,11 @@ class cbg_worker : public il_worker {
};

// iteration-level scheduler
-class il_scheduler {
+class Iter_level_scheduler {
public:
-  explicit il_scheduler(const gpt_params& params);
-  il_scheduler(const gpt_params& params, const std::string& policy, const int& log_level);
-  virtual ~il_scheduler() = default;
+  explicit Iter_level_scheduler(const gpt_params& params);
+  Iter_level_scheduler(const gpt_params& params, const std::string& policy, const int& log_level);
+  virtual ~Iter_level_scheduler() = default;

// TODO (YZT) kv cache ptr as input params
virtual bool add_request(sequence seq) = 0;
@@ -85,11 +85,11 @@ class il_scheduler {
};

// continuous batching generation scheduler
-class cbg_scheduler : public il_scheduler {
+class Cont_batch_gen_scheduler : public Iter_level_scheduler {
public:
-  explicit cbg_scheduler(const gpt_params& params);
-  cbg_scheduler(const gpt_params& params, const std::string& policy, const int& log_level);
-  ~cbg_scheduler() = default;
+  explicit Cont_batch_gen_scheduler(const gpt_params& params);
+  Cont_batch_gen_scheduler(const gpt_params& params, const std::string& policy, const int& log_level);
+  ~Cont_batch_gen_scheduler() = default;

bool add_request(sequence seq) override;
bool step() override;
@@ -101,7 +101,7 @@ class cbg_scheduler : public il_scheduler {
int query_free_req_idx();

const int max_requests;
-  cbg_worker wr;
+  Cont_batch_gen_worker wr;
std::vector<sequence> executed_seqs;
std::vector<bool> free_req_idx;
int waiting_free_req_idx_seqs_num;
