Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefix caching. #675

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/cpp/include/openvino/genai/scheduler_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,13 @@ struct SchedulerConfig {

// max number of scheduled sequences (you can think of it as "max batch size")
std::size_t max_num_seqs = 256;

// Enable caching of KV-blocks.
// When turned on, all previously calculated KV-caches are kept in memory for future usage.
// KV-caches can be rewritten if the KV-cache limit is reached, but blocks are not released.
// This results in more RAM usage; maximum RAM usage is determined by the cache_size or num_kv_blocks parameters.
// When turned off, only the KV-cache required for batch calculation is kept in memory and
// when a sequence has finished generation its cache is released.
bool enable_prefix_caching = false;
};
}
259 changes: 244 additions & 15 deletions src/cpp/src/block_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,25 @@
#include <algorithm>
#include <chrono>
#include <list>
#include <map>
#include <memory>

#include "sequence_group.hpp"

namespace ov::genai {
class KVCacheBlock {
int m_ref_count;
int m_index;
size_t m_hash;
size_t m_num_hashed_tokens;
std::chrono::time_point<std::chrono::system_clock> m_timestamp;
public:
using Ptr = std::shared_ptr<KVCacheBlock>;
using CPtr = std::shared_ptr<const KVCacheBlock>;

// Construct an unreferenced block with the given physical index.
// Hash fields start at zero so get_hash() is well-defined even before
// set_hash() has been called (e.g. when a block is freed without caching).
explicit KVCacheBlock(int index)
    : m_ref_count(0),
      m_index(index),
      m_hash(0),
      m_num_hashed_tokens(0),
      m_timestamp(std::chrono::system_clock::now()) { }

int get_index() const {
return m_index;
Expand All @@ -34,6 +39,7 @@ class KVCacheBlock {
}

// Drop one reference to this block.
// Asserts the block is currently referenced (guards against double free).
void release() {
OPENVINO_ASSERT(m_ref_count > 0);
--m_ref_count;
}

Expand All @@ -44,15 +50,79 @@ class KVCacheBlock {
// Current reference count, i.e. how many users hold this block.
int get_references_count() const {
return m_ref_count;
}

// Content hash of the tokens stored in this block (set via set_hash()).
size_t get_hash() const {
return m_hash;
}

// Number of tokens covered by the stored hash.
size_t get_num_hashed_tokens() const {
return m_num_hashed_tokens;
}

// Record the content hash of this block together with the number of
// tokens that hash covers.
void set_hash(size_t hash, size_t num_hashed_tokens) {
    m_hash = hash;
    m_num_hashed_tokens = num_hashed_tokens;
}

// Update the last-use timestamp (consumed by the evictor for LRU ordering).
void set_timestamp(const std::chrono::time_point<std::chrono::system_clock>& timestamp) {
m_timestamp = timestamp;
}

std::chrono::time_point<std::chrono::system_clock> get_timestamp() {
return m_timestamp;
}
};


class Evictor {
std::map<size_t, KVCacheBlock::Ptr> blocks;
popovaan marked this conversation as resolved.
Show resolved Hide resolved
public:
void add(size_t hash, KVCacheBlock::Ptr block) {
blocks[hash] = block;
}

static bool block_is_less(const std::pair<size_t, KVCacheBlock::Ptr>& lhs, const std::pair<size_t, KVCacheBlock::Ptr>& rhs) {
return lhs.second->get_timestamp() < rhs.second->get_timestamp();
}
popovaan marked this conversation as resolved.
Show resolved Hide resolved

KVCacheBlock::Ptr get_block(size_t hash) {
if (blocks.find(hash)== blocks.end())
{
return nullptr;
}
KVCacheBlock::Ptr block = blocks[hash];
popovaan marked this conversation as resolved.
Show resolved Hide resolved
block->set_timestamp(std::chrono::system_clock::now());
block->increment();
blocks.erase(hash);
return block;
}

KVCacheBlock::Ptr get_lru_block() {
if (!blocks.size()) {
return nullptr;
}
auto hash_block = std::min_element(std::begin(blocks), std::end(blocks), block_is_less);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think we need to store blocks in already sorted manner?
In this case get_lru_block will take O(1)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we store blocks in a sorted structure like priority_queue, get_lru_block() will be O(1) but get_block() by hash will be O(n), as we will need to loop over all blocks and check hashes:

KVCacheBlock::Ptr get_block(size_t hash) {

Currently with hash table get_block() is O(1). We can probably have two structures in evictor both hash table and priority_queue?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, we can try this optimization.

But maybe let's first perform comparison with vLLM implementation? We can measure how much time this piece of code takes averagely.

auto block = hash_block->second;
block->set_timestamp(std::chrono::system_clock::now());
block->increment();
blocks.erase(hash_block->first);
return block;
}

size_t num_blocks() const {
return blocks.size();
}
};


class BlockAllocator {
std::list<KVCacheBlock::Ptr> m_free_blocks;
ov::genai::Evictor m_evictor;
int m_total_num_blocks;
bool m_enable_prefix_caching;
public:
BlockAllocator(int num_blocks) :
m_total_num_blocks(num_blocks) {
BlockAllocator(int num_blocks, bool enable_prefix_caching) :
m_total_num_blocks(num_blocks), m_enable_prefix_caching(enable_prefix_caching) {
for (int block_id = 0; block_id < m_total_num_blocks; ++block_id) {
m_free_blocks.push_back(std::make_shared<KVCacheBlock>(block_id));
}
Expand All @@ -64,42 +134,113 @@ class BlockAllocator {
}

// Number of blocks that can still be handed out: truly free blocks plus
// blocks parked in the evictor (those can be reclaimed at any time).
size_t num_free_blocks() const {
    return m_free_blocks.size() + m_evictor.num_blocks();
}

// Whether `num_blocks` more blocks can be allocated (free list + evictor).
bool can_allocate_blocks(size_t num_blocks) const {
    return num_blocks <= num_free_blocks();
}

// Release one reference to `block`. When the last reference is dropped the
// block returns to the pool: with prefix caching it is parked in the evictor
// (its KV contents stay reusable via its hash), otherwise it goes back to
// the plain free list.
void free(KVCacheBlock::Ptr block) {
    block->release();
    if (block->is_free()) {
        if (m_enable_prefix_caching) {
            m_evictor.add(block->get_hash(), block);
        } else {
            m_free_blocks.push_back(block);
        }
    }
}

// Allocate a fresh block from the free list (non-prefix-caching mode only).
// The returned block already carries one reference for the caller.
KVCacheBlock::Ptr allocate_block() {
OPENVINO_ASSERT(!m_enable_prefix_caching);
OPENVINO_ASSERT(can_allocate_blocks(1));
KVCacheBlock::Ptr allocated_block = m_free_blocks.front();
allocated_block->increment();
m_free_blocks.pop_front();
return allocated_block;
}

// Allocate a block for the content identified by `hash` (prefix-caching mode).
// Lookup order: evictor cache -> shared cached_blocks map -> free list ->
// LRU eviction. Returns nullptr only when the pool is fully exhausted.
// `num_hashed_tokens` is the number of tokens the hash covers; `cached_blocks`
// is the manager-owned hash -> block index updated as a side effect.
KVCacheBlock::Ptr allocate_block(size_t hash, size_t num_hashed_tokens, std::map<uint64_t, KVCacheBlock::Ptr>& cached_blocks) {
OPENVINO_ASSERT(m_enable_prefix_caching);
OPENVINO_ASSERT(can_allocate_blocks(1));
auto block = m_evictor.get_block(hash);
if (block != nullptr) {
// use cached block from evictor (get_block already took a reference)
cached_blocks[hash] = block;
return block;
}
// TODO: Currently we cache all allocated blocks which might be redundant for beam search,
// where blocks of non-used candidates are not needed in cache.
// This part can be improved if we cache only blocks for prompt.
if (cached_blocks.find(hash) != cached_blocks.end()) {
// use cached block from cached_blocks (still referenced elsewhere)
block = cached_blocks[hash];
cached_blocks[hash]->increment();
return block;
}
if (m_free_blocks.size() > 0) {
// allocate new empty block
KVCacheBlock::Ptr allocated_block = m_free_blocks.front();
allocated_block->increment();
allocated_block->set_hash(hash, num_hashed_tokens);
cached_blocks[hash] = allocated_block;

m_free_blocks.pop_front();
return allocated_block;
}
if (m_evictor.num_blocks() > 0) {
// get least recently used block from evictor and reuse it
KVCacheBlock::Ptr block = m_evictor.get_lru_block();
cached_blocks.erase(block->get_hash());

// update block with new hash
block->set_hash(hash, num_hashed_tokens);
cached_blocks[hash] = block;
return block;
}
// out of memory
return nullptr;
}

// Look up an already-cached block for `hash` without allocating a new one.
// Checks the evictor first, then the shared cached_blocks map; the returned
// block has one extra reference taken. Returns nullptr on a cache miss.
KVCacheBlock::Ptr get_cached_block(size_t hash, std::map<uint64_t, KVCacheBlock::Ptr>& cached_blocks) {
auto block = m_evictor.get_block(hash);
if (block != nullptr) {
// use cached block from evictor (get_block already took a reference)
cached_blocks[hash] = block;
return block;
}
if (cached_blocks.find(hash) != cached_blocks.end()) {
// use cached block from cached_blocks
// TODO: add tokens validation in case of hash collision
block = cached_blocks[hash];
cached_blocks[hash]->increment();
return block;
}
return nullptr;
}


// Fraction of the pool currently in use, in [0.0, 1.0]. Blocks parked in the
// evictor count as free since they can be reclaimed on demand.
float get_used_percentage() const {
    return static_cast<float>(m_total_num_blocks - num_free_blocks()) / m_total_num_blocks;
}
};

class BlockManager {
BlockAllocator m_allocator;
bool m_enable_prefix_caching;
size_t m_block_size;
// TODO: caching time can probably be improved if we use the prefix tree
std::map<uint64_t, KVCacheBlock::Ptr> cached_blocks;
ilya-lavrenov marked this conversation as resolved.
Show resolved Hide resolved

// stores blocks for each sequence (not sequence group)
// the same block can be seen in multiple block_tables for different sequences
std::map<uint64_t, std::vector<KVCacheBlock::Ptr>> m_block_table;
public:
// Construct the manager over a fixed pool of `num_blocks` KV blocks of
// `block_size` tokens each; `enable_prefix_caching` turns on hash-based reuse.
BlockManager(int num_blocks, bool enable_prefix_caching, size_t block_size)
    : m_allocator(num_blocks, enable_prefix_caching), m_enable_prefix_caching(enable_prefix_caching), m_block_size(block_size) { }

~BlockManager() {
// sanity check that all sequences are freed
Expand Down Expand Up @@ -195,11 +336,32 @@ class BlockManager {
return m_allocator.can_allocate_blocks(num_blocks);
}

// Append `num_blocks` physical blocks to the block table of `sequence`.
// With prefix caching enabled, each new block is keyed by the hash of all
// tokens up to the end of that block (capped at the current content length),
// so identical prefixes can later be restored from cache; `prompt_ids` is
// required in that mode for hash computation.
void allocate(ov::genai::Sequence::CPtr sequence, size_t num_blocks, const ov::genai::TokenIds& prompt_ids = {}) {
    OPENVINO_ASSERT(num_blocks > 0 && can_allocate_blocks(num_blocks));
    if (m_enable_prefix_caching) {
        OPENVINO_ASSERT(prompt_ids.size() > 0, "prompt_ids should be set for hash calculation.");
    }
    auto sequence_id = sequence->get_id();
    // const ref: we only need the current table size; copying the vector here
    // would be wasted work
    const auto& block_table = m_block_table[sequence_id];
    auto content_length = sequence->get_generated_len() + prompt_ids.size();
    size_t num_hashed_tokens = block_table.size() * m_block_size;

    for (size_t i = 0; i < num_blocks; ++i) {
        ov::genai::KVCacheBlock::Ptr block = nullptr;
        if (m_enable_prefix_caching) {
            num_hashed_tokens += m_block_size;
            if (num_hashed_tokens > content_length) {
                num_hashed_tokens = content_length;
            }
            auto hash = sequence->get_hash(num_hashed_tokens, prompt_ids);
            block = m_allocator.allocate_block(hash, num_hashed_tokens, cached_blocks);
        }
        else {
            block = m_allocator.allocate_block();
        }
        OPENVINO_ASSERT(block != nullptr);
        m_block_table[sequence_id].push_back(block);
    }
}

Expand Down Expand Up @@ -324,27 +486,94 @@ class BlockManager {

if (num_logical_blocks > num_physical_blocks) {
OPENVINO_ASSERT(can_allocate_blocks(num_logical_blocks - num_physical_blocks));
allocate(seq_id, num_logical_blocks - num_physical_blocks);
allocate(sequence, num_logical_blocks - num_physical_blocks, seq_group->get_prompt_ids());
} else {
OPENVINO_ASSERT(num_logical_blocks == num_physical_blocks, "A number of physical and logic blocks must be the same in this code path");
KVCacheBlock::Ptr last_block = block_table.back();

if (last_block->copy_on_write()) {
// we need to fork current block, because reference counter is more than 1
KVCacheBlock::Ptr new_block = m_allocator.allocate_block();
KVCacheBlock::Ptr new_block = nullptr;
if (m_enable_prefix_caching) {
auto hash = sequence->get_hash(seq_group->get_context_len(), seq_group->get_prompt_ids());
new_block = m_allocator.allocate_block(hash, seq_group->get_context_len(), cached_blocks);
cached_blocks[hash] = new_block;
}
else {
new_block = m_allocator.allocate_block();
}
block_table[num_physical_blocks - 1] = new_block;
// write information about block forking for later usage in CacheManager
copy_blocks_map[last_block->get_index()].push_back(new_block->get_index());
// release `last_block` usage
m_allocator.free(last_block);
} else {
// nothing to do, because we are the only users of this block
// we are the only users of this block
if (m_enable_prefix_caching) {
// update hash of block
auto prev_hash = last_block->get_hash();
auto hash = sequence->get_hash(seq_group->get_context_len(), seq_group->get_prompt_ids());
last_block->set_hash(hash, seq_group->get_context_len());
cached_blocks.erase(prev_hash);
cached_blocks[hash] = last_block;
}
}
}
}

// it returns information which blocks should be forked by CacheManager
return copy_blocks_map;
}


// Restore previously computed KV blocks for the prompt of `group` from the
// prefix cache so that matching prompt tokens are not recomputed.
// Walks the prompt one block-sized chunk at a time, looking up each
// full-chunk hash; on the first full-chunk miss it tries to reuse a cached
// block covering a shorter prefix of that chunk, then stops scanning.
// Advances the group's processed-token counter for every restored token.
void _restore_cached_blocks(SequenceGroup::Ptr group, size_t block_size) {
auto prompt_ids = group->get_prompt_ids();
auto sequences = group->get_not_finished_sequences();
// prefix restore is only defined for a group with a single running sequence
OPENVINO_ASSERT(sequences.size() == 1);
auto sequence = sequences[0];
auto seq_id = sequence->get_id();
// NOTE(review): block_table is not used below; m_block_table[seq_id] is
// accessed directly — confirm whether this local is intentional
auto& block_table = m_block_table[seq_id];

size_t content_len = 0;
while (content_len < prompt_ids.size()) {
size_t prev_iteration_content_len = content_len;
// advance by one block, clamped to the prompt length
content_len += block_size;
if (content_len > prompt_ids.size()) {
content_len = prompt_ids.size();
}
// restore fully filled blocks
auto hash = sequence->get_hash(content_len, prompt_ids);
auto block = m_allocator.get_cached_block(hash, cached_blocks);
if (block != nullptr) {
block->set_timestamp(std::chrono::system_clock::now());
m_block_table[seq_id].push_back(block);
group->update_processed_tokens_num(content_len);
}
else {
// restore partially filled block
// NOTE(review): i scans upward from 1, so the FIRST (shortest) cached
// partial match wins, not the longest — confirm this is intended
for (size_t i = 1; i < block_size; i++) {
if (prev_iteration_content_len + i > prompt_ids.size()) {
break;
}
auto hash = sequence->get_hash(prev_iteration_content_len + i, prompt_ids);
auto block = m_allocator.get_cached_block(hash, cached_blocks);
if (block != nullptr) {
block->set_timestamp(std::chrono::system_clock::now());
m_block_table[seq_id].push_back(block);
group->update_processed_tokens_num(prev_iteration_content_len + i);

// if the block will hold more prompt tokens than were matched,
// re-key it in the cache under the hash of the fuller content
size_t new_tokens_count_in_block = std::min(content_len, prev_iteration_content_len + block_size);
if (new_tokens_count_in_block > prev_iteration_content_len + i) {
cached_blocks.erase(hash);
auto new_hash = sequence->get_hash(new_tokens_count_in_block, prompt_ids);
cached_blocks[new_hash] = block;
}

break;
}
}
// stop after the first non-full chunk: everything beyond it is uncached
break;
}
}
}
};
}
Loading
Loading