Skip to content

Commit

Permalink
Merge pull request #5 from popovaan/requires_sampling_fix
Browse files Browse the repository at this point in the history
Fixed bug in free_sequence_partially(), fixed requires_sampling().
  • Loading branch information
ilya-lavrenov authored May 8, 2024
2 parents 17fdc12 + fb63c52 commit 96b39c7
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ FetchContent_MakeAvailable(googletest)


set(TEST_TARGET_NAME "tests_continuous_batching")
add_executable(${TEST_TARGET_NAME} "src/tests/scheduler.cpp")
add_executable(${TEST_TARGET_NAME} "src/tests/scheduler.cpp" "src/tests/block_manager.cpp")
target_link_libraries(${TEST_TARGET_NAME} PUBLIC ${TARGET_NAME} openvino::runtime gtest_main)
target_include_directories(${TEST_TARGET_NAME} PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src/"
PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,11 @@ class BlockManager {
// TODO: support for groups with multiple sequences
auto block_table = m_block_table[seq_id];

if (block_num == block_table.size())
return free_sequence(seq_id);

OPENVINO_ASSERT(block_table.size() >= block_num);
for (size_t idx = 0; idx < block_num; idx++) {
m_allocator.free(block_table.back());
OPENVINO_ASSERT(block_table.back()->is_free());
size_t block_idx = m_block_table[seq_id].size() - idx - 1;
m_allocator.free(block_table[block_idx]);
OPENVINO_ASSERT(block_table[block_idx]->is_free());
}
m_block_table[seq_id].resize(m_block_table[seq_id].size() - block_num);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,6 @@ class Sequence {
return m_cumulative_log_prob;
}

// Erases the trailing `count` entries of m_generated_ids.
// Asserts that at least `count` entries are present.
// TODO: remove this once sampling properly handles a sequence group that is returned after preemption
void remove_tokens(size_t count) {
    OPENVINO_ASSERT(m_generated_ids.size() >= count);
    const auto tail_begin = m_generated_ids.end() - count;
    m_generated_ids.erase(tail_begin, m_generated_ids.end());
}

float get_beam_search_score(const GenerationConfig& sampling_params) const {
float cumulative_log_prob = get_cumulative_log_probs(), current_length = get_generated_len();
float score = cumulative_log_prob / std::pow(current_length, sampling_params.length_penalty);
Expand Down Expand Up @@ -245,13 +239,6 @@ class SequenceGroup {
void preempt_tokens(size_t num_preempt_tokens) {
OPENVINO_ASSERT(num_preempt_tokens <= m_num_processed_tokens);
m_num_processed_tokens -= num_preempt_tokens;
m_max_content_len -= num_preempt_tokens;

// this removal of tokens prevents duplicating of generated tokens after preemption of a sequence
// TODO: need to remove this when sampling is fixed to properly handle the case when sequnce group is returned after preemption
for (auto seq: m_sequences) {
seq->remove_tokens(std::min<size_t>(num_preempt_tokens, seq->get_generated_len()));
}
}

// returns context length taking into account scheduled tokens
Expand All @@ -261,7 +248,7 @@ class SequenceGroup {
}

bool requires_sampling() const {
return get_context_len() >= get_prompt_len();
return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len;
}

void schedule_tokens(size_t num_tokens) {
Expand Down Expand Up @@ -327,8 +314,6 @@ class SequenceGroup {
}

bool is_empty() {
if (m_max_content_len > 0 || m_num_processed_tokens > 0)
return false;
if (m_sequences.size() > 1)
return false;
OPENVINO_ASSERT(m_sequences.size() == 1);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include "openvino/runtime/core.hpp"
#include "continuous_batching_pipeline.hpp"
#include "sequence_group.hpp"
#include "scheduler.hpp"
#include "generation_config.hpp"

// Walks a BlockManager through allocate -> partial free -> full free -> fork and checks the
// block table size and free-block count reported after each step.
TEST(TestBlockManager, general_test) {
    constexpr int total_blocks = 6;
    constexpr int released_blocks = 4;
    constexpr int forked_blocks = 2;

    BlockManager manager(total_blocks);

    // Hand every block to sequence 0.
    manager.allocate(0, total_blocks);
    EXPECT_TRUE(manager.has_block_table(0));
    EXPECT_EQ(manager.get_block_table(0).size(), total_blocks);
    EXPECT_EQ(manager.num_free_blocks(), 0);

    // Release part of the sequence: only the remaining blocks stay in its table.
    manager.free_sequence_partially(0, released_blocks);
    EXPECT_EQ(manager.get_block_table(0).size(), total_blocks - released_blocks);
    EXPECT_EQ(manager.num_free_blocks(), released_blocks);

    // Release the rest: the table disappears and every block is free again.
    manager.free_sequence(0);
    EXPECT_FALSE(manager.has_block_table(0));
    EXPECT_EQ(manager.num_free_blocks(), total_blocks);

    // Fork sequence 0 into sequence 1: the last block is then referenced twice.
    manager.allocate(0, forked_blocks);
    manager.fork_sequence(0, 1);
    EXPECT_TRUE(manager.has_block_table(1));
    EXPECT_EQ(manager.get_block_table(1).back()->get_references_count(), 2);
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,9 @@ TEST(TestScheduler, general_test) {
EXPECT_FALSE(out4.m_block_tables[idx2][1]->is_free());
EXPECT_EQ(out4.m_block_tables[idx2][1]->get_index(), 1);

if (scheduler_config.dynamic_split_fuse) {
// requests1[1] should be fully scheduled plus 1 slot for requests[0] for generate phase
EXPECT_EQ(out4.m_total_num_scheduled_tokens, requests[1]->get_context_len() + 1);
EXPECT_EQ(out4.is_prompt, false);
}
else {
// requests1[1] should be fully scheduled on prompt phase, generate phase is not scheduled here
EXPECT_EQ(out4.m_total_num_scheduled_tokens, requests[1]->get_context_len());
EXPECT_EQ(out4.is_prompt, true);
}

// requests1[1] should be fully scheduled plus 1 slot for requests[0] for generate phase
EXPECT_EQ(out4.m_total_num_scheduled_tokens, requests[1]->get_context_len() + 1);
EXPECT_EQ(out4.is_prompt, false);
}

}
Expand Down

0 comments on commit 96b39c7

Please sign in to comment.