diff --git a/text_generation/causal_lm/cpp/continuous_batching/library/CMakeLists.txt b/text_generation/causal_lm/cpp/continuous_batching/library/CMakeLists.txt index fa55d19784..e42e52944b 100644 --- a/text_generation/causal_lm/cpp/continuous_batching/library/CMakeLists.txt +++ b/text_generation/causal_lm/cpp/continuous_batching/library/CMakeLists.txt @@ -67,7 +67,7 @@ FetchContent_MakeAvailable(googletest) set(TEST_TARGET_NAME "tests_continuous_batching") -add_executable(${TEST_TARGET_NAME} "src/tests/scheduler.cpp") +add_executable(${TEST_TARGET_NAME} "src/tests/scheduler.cpp" "src/tests/block_manager.cpp") target_link_libraries(${TEST_TARGET_NAME} PUBLIC ${TARGET_NAME} openvino::runtime gtest_main) target_include_directories(${TEST_TARGET_NAME} PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src/" PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") diff --git a/text_generation/causal_lm/cpp/continuous_batching/library/src/block_manager.hpp b/text_generation/causal_lm/cpp/continuous_batching/library/src/block_manager.hpp index 57a3babd93..b0c3055bce 100644 --- a/text_generation/causal_lm/cpp/continuous_batching/library/src/block_manager.hpp +++ b/text_generation/causal_lm/cpp/continuous_batching/library/src/block_manager.hpp @@ -154,13 +154,11 @@ class BlockManager { // TODO: support for groups with multiple sequences auto block_table = m_block_table[seq_id]; - if (block_num == block_table.size()) - return free_sequence(seq_id); - OPENVINO_ASSERT(block_table.size() >= block_num); for (size_t idx = 0; idx < block_num; idx++) { - m_allocator.free(block_table.back()); - OPENVINO_ASSERT(block_table.back()->is_free()); + size_t block_idx = m_block_table[seq_id].size() - idx - 1; + m_allocator.free(block_table[block_idx]); + OPENVINO_ASSERT(block_table[block_idx]->is_free()); } m_block_table[seq_id].resize(m_block_table[seq_id].size() - block_num); diff --git a/text_generation/causal_lm/cpp/continuous_batching/library/src/sequence_group.hpp b/text_generation/causal_lm/cpp/continuous_batching/library/src/sequence_group.hpp index cbb1e9a65b..29b0af5133 100644 --- a/text_generation/causal_lm/cpp/continuous_batching/library/src/sequence_group.hpp +++ b/text_generation/causal_lm/cpp/continuous_batching/library/src/sequence_group.hpp @@ -87,12 +87,6 @@ class Sequence { return m_cumulative_log_prob; } - // TODO: need to remove this when sampling is fixed to properly handle the case when sequnce group is returned after preemption - void remove_tokens(size_t count) { - OPENVINO_ASSERT(m_generated_ids.size() >= count); - m_generated_ids.erase(m_generated_ids.end() - count, m_generated_ids.end()); - } - float get_beam_search_score(const GenerationConfig& sampling_params) const { float cumulative_log_prob = get_cumulative_log_probs(), current_length = get_generated_len(); float score = cumulative_log_prob / std::pow(current_length, sampling_params.length_penalty); @@ -245,13 +239,6 @@ class SequenceGroup { void preempt_tokens(size_t num_preempt_tokens) { OPENVINO_ASSERT(num_preempt_tokens <= m_num_processed_tokens); m_num_processed_tokens -= num_preempt_tokens; - m_max_content_len -= num_preempt_tokens; - - // this removal of tokens prevents duplicating of generated tokens after preemption of a sequence - // TODO: need to remove this when sampling is fixed to properly handle the case when sequnce group is returned after preemption - for (auto seq: m_sequences) { - seq->remove_tokens(std::min(num_preempt_tokens, seq->get_generated_len())); - } } // returns context length taking into account scheduled tokens @@ -261,7 +248,7 @@ class SequenceGroup { } bool requires_sampling() const { - return get_context_len() >= get_prompt_len(); + return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len; } void schedule_tokens(size_t num_tokens) { @@ -327,8 +314,6 @@ class SequenceGroup { } bool is_empty() { - if (m_max_content_len > 0 || m_num_processed_tokens > 0) - return false; if (m_sequences.size() > 1) return false; OPENVINO_ASSERT(m_sequences.size() == 1); diff --git a/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/block_manager.cpp b/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/block_manager.cpp new file mode 100644 index 0000000000..79762318c9 --- /dev/null +++ b/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/block_manager.cpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "openvino/runtime/core.hpp" +#include "continuous_batching_pipeline.hpp" +#include "sequence_group.hpp" +#include "scheduler.hpp" +#include "generation_config.hpp" + +TEST(TestBlockManager, general_test) { + BlockManager bm = BlockManager(6); + + bm.allocate(0, 6); + EXPECT_TRUE(bm.has_block_table(0)); + EXPECT_EQ(bm.get_block_table(0).size(), 6); + EXPECT_EQ(bm.num_free_blocks(), 0); + + bm.free_sequence_partially(0, 4); + EXPECT_EQ(bm.get_block_table(0).size(), 2); + EXPECT_EQ(bm.num_free_blocks(), 4); + + bm.free_sequence(0); + EXPECT_FALSE(bm.has_block_table(0)); + EXPECT_EQ(bm.num_free_blocks(), 6); + + bm.allocate(0, 2); + bm.fork_sequence(0, 1); + EXPECT_TRUE(bm.has_block_table(1)); + EXPECT_EQ(bm.get_block_table(1).back()->get_references_count(), 2); +} diff --git a/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/scheduler.cpp b/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/scheduler.cpp index bce92c3557..3b0ae698c8 100644 --- a/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/scheduler.cpp +++ b/text_generation/causal_lm/cpp/continuous_batching/library/src/tests/scheduler.cpp @@ -108,17 +108,9 @@ TEST(TestScheduler, general_test) { EXPECT_FALSE(out4.m_block_tables[idx2][1]->is_free()); EXPECT_EQ(out4.m_block_tables[idx2][1]->get_index(), 1); - if (scheduler_config.dynamic_split_fuse) { - // requests1[1] should be fully scheduled plus 1 slot for requests[0] for generate phase - EXPECT_EQ(out4.m_total_num_scheduled_tokens, requests[1]->get_context_len() + 1); - EXPECT_EQ(out4.is_prompt, false); - } - else { - // requests1[1] should be fully scheduled on prompt phase, generate phase is not scheduled here - EXPECT_EQ(out4.m_total_num_scheduled_tokens, requests[1]->get_context_len()); - EXPECT_EQ(out4.is_prompt, true); - } - + // requests1[1] should be fully scheduled plus 1 slot for requests[0] for generate phase + EXPECT_EQ(out4.m_total_num_scheduled_tokens, requests[1]->get_context_len() + 1); + EXPECT_EQ(out4.is_prompt, false); } }