Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Jul 22, 2024
1 parent e7ab3c7 commit cc1c465
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class SpeculativeDecodingPipeline {
for (const auto& request : candidate_sequences) {
const auto& request_id = request.first;
for (const auto& sequence : request.second) {
model_pipeline.update_generated_sequence(sequence.second, request_id, sequence.first);
model_pipeline.update_generated_sequence(sequence.second.first, sequence.second.second, request_id, sequence.first);
}
}
}
Expand All @@ -69,14 +69,14 @@ class SpeculativeDecodingPipeline {
const auto& sequence_id = sequence.first;
const auto& generated_sequence = sequence.second;
const auto& candidate_sequence = candidate_sequences[request_id].at(sequence_id);
const auto generated_sequence_size = generated_sequence.size();
const auto candidate_sequence_size = candidate_sequence.size();
const auto generated_sequence_size = generated_sequence.first.size();
const auto candidate_sequence_size = candidate_sequence.first.size();
if (generated_sequence_size <= candidate_sequence_size) {
const auto dist = candidate_sequence_size - generated_sequence_size + 1;
assisting_pipeline.remove_tokens_from_sequences(dist, request_id, sequence_id);
max_removed_token_cnt = std::max(max_removed_token_cnt, dist);
}
assisting_pipeline.update_generated_sequence(generated_sequence, request_id, sequence_id);
assisting_pipeline.update_generated_sequence(generated_sequence.first, generated_sequence.second, request_id, sequence_id);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class ContinuousBatchingPipeline {

public:
using GeneratedTokensSeq = std::vector<int64_t>;
using GeneratedLogProbs = std::vector<float>;
// { request_id, { sequence_id, { tokens, log_probs } } }
using GeneratedTokensMap = std::map<uint64_t, std::map<uint64_t, GeneratedTokensSeq>>;
using RemoveTokensMap = std::map<uint64_t, std::map<uint64_t, uint64_t>>;
using GeneratedTokensMap = std::map<uint64_t, std::map<uint64_t, std::pair<GeneratedTokensSeq, GeneratedLogProbs>>>;

ContinuousBatchingPipeline(const std::string& models_path,
const SchedulerConfig& scheduler_config,
Expand All @@ -46,10 +46,8 @@ class ContinuousBatchingPipeline {
GenerationHandle add_request(uint64_t request_id, std::string prompt, ov::genai::GenerationConfig sampling_params);

GeneratedTokensMap get_generated_sequences();
void update_generated_sequence(const GeneratedTokensSeq& tokens, uint64_t request_id, uint64_t sequence_id);
void update_generated_sequence(const GeneratedTokensSeq& tokens, const GeneratedLogProbs& log_probs, uint64_t request_id, uint64_t sequence_id);
void remove_tokens_from_sequences(size_t k, uint64_t request_id, uint64_t sequence_id);
void set_to_free_sequences(bool is_free);
void free_all_sequences();
void set_speculative_decoding_mode();

void step();
Expand Down
21 changes: 6 additions & 15 deletions src/cpp/continuous_batching/src/continuous_batching_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ class ContinuousBatchingPipeline::Impl {
}
for (const auto& sequence : request->get_sequences()) {
auto generated_ids = sequence->get_generated_ids();
res[request_id].insert({ sequence->get_grouped_id(), generated_ids });
auto log_probs = sequence->get_log_probs();
res[request_id].insert({ sequence->get_grouped_id(), { generated_ids, log_probs }});
}
}
return res;
Expand All @@ -127,7 +128,7 @@ class ContinuousBatchingPipeline::Impl {
m_sampler->set_validation_mode();
}

void update_generated_sequence(const GeneratedTokensSeq& tokens, uint64_t request_id, uint64_t sequence_id) {
void update_generated_sequence(const GeneratedTokensSeq& tokens, const GeneratedLogProbs& log_probs, uint64_t request_id, uint64_t sequence_id) {
// Pull awaiting requests
if (!m_awaiting_requests.empty()) {
std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
Expand All @@ -146,8 +147,7 @@ class ContinuousBatchingPipeline::Impl {
OPENVINO_ASSERT(generated_len <= updated_len);
max_inserted_token_cnt = std::max(max_inserted_token_cnt, updated_len - generated_len);
for (size_t i = generated_len; i < updated_len; ++i) {
// todo: insert correct prob
sequence->append_token(tokens[i], 1.f);
sequence->append_token(tokens[i], log_probs[i]);
}
break;
}
Expand Down Expand Up @@ -416,24 +416,15 @@ ContinuousBatchingPipeline::GeneratedTokensMap ContinuousBatchingPipeline::get_g
return m_impl->get_generated_sequences();
}

void ContinuousBatchingPipeline::update_generated_sequence(const GeneratedTokensSeq& tokens, uint64_t request_id, uint64_t sequence_id) {
return m_impl->update_generated_sequence(tokens, request_id, sequence_id);
void ContinuousBatchingPipeline::update_generated_sequence(const GeneratedTokensSeq& tokens, const GeneratedLogProbs& log_probs, uint64_t request_id, uint64_t sequence_id) {
return m_impl->update_generated_sequence(tokens, log_probs, request_id, sequence_id);
}


// Thin public forwarder: passes `k`, `request_id` and `sequence_id` unchanged
// to the same-named method on the implementation object (m_impl).
void ContinuousBatchingPipeline::remove_tokens_from_sequences(size_t k, uint64_t request_id, uint64_t sequence_id) {
    m_impl->remove_tokens_from_sequences(k, request_id, sequence_id);
}

// Thin public forwarder to the same-named method on the implementation
// object (m_impl); takes no arguments and returns nothing.
void ContinuousBatchingPipeline::free_all_sequences() {
    m_impl->free_all_sequences();
}


// Thin public forwarder: hands the `is_free` flag unchanged to the same-named
// method on the implementation object (m_impl).
void ContinuousBatchingPipeline::set_to_free_sequences(bool is_free) {
    m_impl->set_to_free_sequences(is_free);
}

// Thin public forwarder to the same-named method on the implementation
// object (m_impl); takes no arguments and returns nothing.
void ContinuousBatchingPipeline::set_speculative_decoding_mode() {
    m_impl->set_speculative_decoding_mode();
}
11 changes: 9 additions & 2 deletions src/cpp/continuous_batching/src/sequence_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Sequence {
uint64_t m_id = _get_next_global_sequence_id();
SequenceStatus m_status = SequenceStatus::RUNNING;
float m_cumulative_log_prob = 0.0f;
std::vector<float> m_log_probs;

public:
using Ptr = std::shared_ptr<Sequence>;
Expand All @@ -43,6 +44,7 @@ class Sequence {
// don't use directly
Sequence(const Sequence& seq, const uint64_t id) :
m_generated_ids(seq.m_generated_ids),
m_log_probs(seq.m_log_probs),
m_grouped_id(id),
m_status(seq.m_status),
m_cumulative_log_prob(seq.m_cumulative_log_prob) {
Expand Down Expand Up @@ -92,12 +94,13 @@ class Sequence {
// Appends one generated token to the sequence.
//
// token_id - id of the newly generated token.
// log_prob - log probability of that token; it is added to the running
//            m_cumulative_log_prob and also stored per token in m_log_probs.
//
// The token id is pushed exactly once. Pushing it twice (as the duplicated
// push_back line did) would leave m_generated_ids twice as long as
// m_log_probs, so entry i of one vector would no longer describe entry i of
// the other.
void append_token(int64_t token_id, float log_prob) {
    m_cumulative_log_prob += log_prob;
    m_generated_ids.push_back(token_id);
    m_log_probs.push_back(log_prob);
}

// todo: iefode: remove probs
void remove_last_n_tokens(size_t n) {
m_generated_ids.resize(m_generated_ids.size() - n);
m_log_probs.resize(m_log_probs.size() - n);
}

GenerationOutput get_last_generation_output() {
Expand All @@ -116,6 +119,10 @@ class Sequence {
return m_generated_ids;
}

// Read-only view of m_log_probs, the log probabilities recorded by
// append_token(); returned by const reference, so no copy is made here.
const std::vector<float>& get_log_probs() const { return m_log_probs; }

// Returns the current value of m_cumulative_log_prob, the running sum that
// append_token() adds each token's log probability to.
float get_cumulative_log_probs() const { return m_cumulative_log_prob; }
Expand Down

0 comments on commit cc1c465

Please sign in to comment.