diff --git a/CMakeLists.txt b/CMakeLists.txt index ab50f61e3..fa2c6be3a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,7 +67,6 @@ include(cmake/libs/libfaiss.cmake) include(cmake/libs/libhnsw.cmake) include_directories(thirdparty/faiss) -include_directories(thirdparty/bitset) find_package(OpenMP REQUIRED) @@ -148,14 +147,6 @@ add_library(knowhere SHARED ${KNOWHERE_SRCS}) add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS}) if(WITH_RAFT) list(APPEND KNOWHERE_LINKER_LIBS raft::raft raft::compiled) - find_library(LIBRAFT_FOUND raft) - if (NOT LIBRAFT_FOUND) - message(WARNING "libraft not found") - else() - message(STATUS "libraft found") - list(APPEND KNOWHERE_LINKER_LIBS ${LIBRAFT_FOUND}) - add_definitions(-DRAFT_COMPILED) - endif() endif() target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS}) target_include_directories( diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index a757b2ec6..1a3478fd3 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -532,11 +532,11 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { if (!this->index_->is_trained) { return unexpected(Status::index_not_trained); } - auto dim = Dim(); - auto rows = dataset.GetRows(); - auto ids = dataset.GetIds(); - if constexpr (std::is_same::value) { + auto dim = Dim(); + auto rows = dataset.GetRows(); + auto ids = dataset.GetIds(); + uint8_t* data = nullptr; try { data = new uint8_t[dim * rows / 8]; @@ -553,6 +553,10 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { return unexpected(Status::faiss_inner_error); } } else if constexpr (std::is_same::value) { + auto dim = Dim(); + auto rows = dataset.GetRows(); + auto ids = dataset.GetIds(); + float* data = nullptr; try { data = new float[dim * rows]; @@ -569,6 +573,10 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { return unexpected(Status::faiss_inner_error); } } else if constexpr (std::is_same::value) { + auto dim = Dim(); + auto rows = dataset.GetRows(); + auto ids = dataset.GetIds(); + float* data = nullptr; try { data = new float[dim * rows]; diff --git a/src/index/ivf_raft/ivf_raft.cuh b/src/index/ivf_raft/ivf_raft.cuh index ab9f79ad2..7226181d8 100644 --- a/src/index/ivf_raft/ivf_raft.cuh +++ b/src/index/ivf_raft/ivf_raft.cuh @@ -45,42 +45,86 @@ namespace knowhere { __global__ void -filter(const int k1, const int k2, const int nq, const uint8_t* bs, int64_t* ids_before, float* dis_before, - int64_t* ids, float* dis) { +filter(const uint8_t* bs, const int64_t* ids_before, int32_t* ids_block, int len) { int tx = blockIdx.x * blockDim.x + threadIdx.x; - int ty = blockIdx.y * blockDim.y + threadIdx.y; - extern __shared__ char s[]; - int64_t* ids_ = (int64_t*)s; - float* dis_ = (float*)&s[k2 * sizeof(int64_t)]; - if (tx >= k2) + if (tx >= len) return; - int64_t i = ids_before[ty * k2 + tx]; - float d = dis_before[ty * k2 + tx]; - bool check = bs[i >> 3] & (0x1 << (i & 0x7)); - if (!check) { - ids_[tx] = i; - dis_[tx] = d; - } else { - ids_[tx] = -1; - dis_[tx] = -1.0f; + __shared__ int64_t ids_before_s[1024]; + __shared__ int32_t ids_block_s[1024]; + + ids_before_s[threadIdx.x] = ids_before[tx]; + int64_t i = ids_before_s[threadIdx.x]; + ids_block_s[threadIdx.x] = 1 - (bs[i >> 3] && (0x1 << (i & 0x7))); + ids_block[tx] = ids_block_s[threadIdx.x]; +} + +__global__ void +prescan(int32_t* g_odata, int32_t* g_idata, int n) { + extern __shared__ int32_t temp[]; + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * blockDim.y + threadIdx.y; + int offset = 1; + temp[2 * tx] = g_idata[2 * tx + ty * n]; + temp[2 * tx + 1] = g_idata[2 * tx + 1 + ty * n]; + for (int d = n >> 1; d > 0; d >>= 1) { + __syncthreads(); + if (tx < d) { + int ai = offset * (2 * tx + 1) - 1; + int bi = offset * (2 * tx + 2) - 1; + temp[bi] += temp[ai]; + } + offset *= 2; } - __syncthreads(); if (tx == 0) { - int j = 0, k = 0; - while (j < k1 && k < k2) { - while (ids_[k] == -1) k++; - if (k >= k2) - break; - ids[ty * k1 + j] = ids_[k]; - dis[ty * k1 + j] = dis_[k]; - j++; - k++; - } - if (j != k1) { - ids_before[0] = -1; // destroy answer + temp[n - 1] = 0; + } + for (int d = 1; d < n; d *= 2) { + offset >>= 1; + __syncthreads(); + if (tx < d) { + int ai = offset * (2 * tx + 1) - 1; + int bi = offset * (2 * tx + 2) - 1; + int32_t t = temp[ai]; + temp[ai] = temp[bi]; + temp[bi] += t; } } __syncthreads(); + g_odata[2 * tx + ty * n] += temp[2 * tx]; + g_odata[2 * tx + 1 + ty * n] += temp[2 * tx + 1]; +} + +__global__ void +write_back(const int32_t* pre_ids_block, int64_t* ids_before, float* dis_before, int64_t* ids_after, float* dis_after, + int k1, int k2) { + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * blockDim.y + threadIdx.y; + + extern __shared__ int32_t temp[]; + for (int i = tx; i < k2; i += k1) { + temp[i] = pre_ids_block[ty * k2 + i]; + } + __syncthreads(); + int left = 0; + int right = k2; + while (left < right) { + int mid = (left + right) / 2; + if (temp[mid] == tx + 1) { + right = mid; + } else if (temp[mid] > tx + 1) { + right = mid; + } else if (temp[mid] < tx + 1) { + left = mid + 1; + } + } + if (left < k2) { + ids_after[ty * k1 + tx] = ids_before[ty * k2 + left]; + dis_after[ty * k1 + tx] = dis_before[ty * k2 + left]; + } else { + ids_after[ty * k1 + tx] = -1; + dis_after[ty * k1 + tx] = 1.0f / 0.0f; + ids_before[0] = -1; // destory old ans; + } } namespace raft_res_pool { @@ -443,7 +487,7 @@ class RaftIvfIndexNode : public IndexNode { k2 |= k2 >> 8; k2 |= k2 >> 14; k2 += 1; - while (k2 <= 1024) { + while (k2 <= (1 << 14)) { auto ids_gpu_before = raft::make_device_matrix(*res_, rows, k2); auto dis_gpu_before = raft::make_device_matrix(*res_, rows, k2); auto bs_gpu = raft::make_device_vector(*res_, bitset.byte_size()); @@ -453,9 +497,18 @@ class RaftIvfIndexNode : public IndexNode { raft::neighbors::ivf_flat::search( *res_, search_params, *gpu_index_, raft::make_const_mdspan(data_gpu.view()), ids_gpu_before.view(), dis_gpu_before.view()); - filter<<>>( - k1, k2, rows, bs_gpu.data_handle(), ids_gpu_before.data_handle(), - dis_gpu_before.data_handle(), ids_gpu.data_handle(), dis_gpu.data_handle()); + + auto ids_block = raft::make_device_vector(*res_, rows * k2); + filter<<<(rows * k2 + 1023) / 1024, 1024, 0, stream.value()>>>( + bs_gpu.data_handle(), ids_gpu_before.data_handle(), ids_block.data_handle(), rows * k2); + auto pre_ids_block = raft::make_device_vector(*res_, rows * k2); + RAFT_CUDA_TRY(cudaMemcpyAsync(pre_ids_block.data_handle(), ids_block.data_handle(), + rows * k2 * sizeof(int32_t), cudaMemcpyDefault, stream.value())); + prescan<<>>( + pre_ids_block.data_handle(), ids_block.data_handle(), k2); + write_back<<>>( + pre_ids_block.data_handle(), ids_gpu_before.data_handle(), dis_gpu_before.data_handle(), + ids_gpu.data_handle(), dis_gpu.data_handle(), k1, k2); std::int64_t is_fine = 0; RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data_handle(), sizeof(std::int64_t), @@ -506,7 +559,7 @@ class RaftIvfIndexNode : public IndexNode { k2 |= k2 >> 8; k2 |= k2 >> 14; k2 += 1; - while (k2 <= 1024) { + while (k2 <= (1 << 14)) { auto ids_gpu_before = raft::make_device_matrix(*res_, rows, k2); auto dis_gpu_before = raft::make_device_matrix(*res_, rows, k2); auto bs_gpu = raft::make_device_vector(*res_, bitset.byte_size()); @@ -517,9 +570,17 @@ class RaftIvfIndexNode : public IndexNode { *res_, search_params, *gpu_index_, raft::make_const_mdspan(data_gpu.view()), ids_gpu_before.view(), dis_gpu_before.view()); - filter<<>>( - k1, k2, rows, bs_gpu.data_handle(), ids_gpu_before.data_handle(), - dis_gpu_before.data_handle(), ids_gpu.data_handle(), dis_gpu.data_handle()); + auto ids_block = raft::make_device_vector(*res_, rows * k2); + filter<<<(rows * k2 + 1023) / 1024, 1024, 0, stream.value()>>>( + bs_gpu.data_handle(), ids_gpu_before.data_handle(), ids_block.data_handle(), rows * k2); + auto pre_ids_block = raft::make_device_vector(*res_, rows * k2); + RAFT_CUDA_TRY(cudaMemcpyAsync(pre_ids_block.data_handle(), ids_block.data_handle(), + rows * k2 * sizeof(int32_t), cudaMemcpyDefault, stream.value())); + prescan<<>>( + pre_ids_block.data_handle(), ids_block.data_handle(), k2); + write_back<<>>( + pre_ids_block.data_handle(), ids_gpu_before.data_handle(), dis_gpu_before.data_handle(), + ids_gpu.data_handle(), dis_gpu.data_handle(), k1, k2); std::int64_t is_fine = 0; RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data_handle(), sizeof(std::int64_t), @@ -530,7 +591,6 @@ class RaftIvfIndexNode : public IndexNode { k2 = k2 << 1; } } - } else { static_assert(std::is_same_v); } @@ -630,7 +690,7 @@ class RaftIvfIndexNode : public IndexNode { } // TODO(yusheng.ma):support no raw data mode /* -#define RAW_DATA "RAW_DATA" + #define RAW_DATA "RAW_DATA" auto data = binset.GetByName(RAW_DATA); raft_gpu::raw_data_copy(*this->index_, data->data.get(), data->size); */