Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
raft index support bitset filter (#850)
Browse files Browse the repository at this point in the history
Signed-off-by: Yusheng.Ma <[email protected]>
  • Loading branch information
Presburger authored May 4, 2023
1 parent 6448cf0 commit 78d3351
Showing 1 changed file with 113 additions and 6 deletions.
119 changes: 113 additions & 6 deletions src/index/ivf_raft/ivf_raft.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,45 @@

namespace knowhere {

/**
 * Post-filter for top-k search results: drops every candidate whose id has its bit set
 * in `bs` and compacts the survivors of each query row, in their original order, into
 * the first k1 slots of the output row.
 *
 * Launch layout (as used by the callers in this file): grid = dim3(1, rows),
 * block = k2 threads, dynamic shared memory = k2 * sizeof(int64_t) + k2 * sizeof(float).
 * One block handles one query row; thread tx handles candidate tx of that row.
 * The kernel expects gridDim.x == 1, since the shared buffers are indexed by tx.
 *
 * @param k1          number of results wanted per query (row stride of ids / dis)
 * @param k2          number of candidates per query (row stride of ids_before / dis_before)
 * @param nq          number of query rows
 * @param bs          bitset, one bit per id; a set bit means "filtered out".
 *                    NOTE(review): the kernel does not know the bitset length, so ids
 *                    beyond it cannot be range-checked here - callers must guarantee it.
 * @param ids_before  [nq x k2] candidate ids; element 0 is overwritten with -1 when any
 *                    row ends up with fewer than k1 survivors (the host reads that slot
 *                    back as a "retry with a larger k2" flag)
 * @param dis_before  [nq x k2] candidate distances
 * @param ids         [nq x k1] output ids; rows with too few survivors are padded with -1
 * @param dis         [nq x k1] output distances; padded slots are left untouched
 */
__global__ void
filter(const int k1, const int k2, const int nq, const uint8_t* bs, int64_t* ids_before, float* dis_before,
       int64_t* ids, float* dis) {
    int tx = blockIdx.x * blockDim.x + threadIdx.x;
    int ty = blockIdx.y * blockDim.y + threadIdx.y;
    // Force 8-byte alignment so the int64_t view of the dynamic shared buffer is valid.
    extern __shared__ __align__(8) char s[];
    int64_t* ids_ = (int64_t*)s;
    float* dis_ = (float*)&s[k2 * sizeof(int64_t)];

    // 64-bit row index so row * k2 / row * k1 cannot overflow for large batches.
    const int64_t row = ty;
    // Use a predicate instead of an early return: every thread of the block must reach
    // the barrier below, even if the block is launched wider than k2.
    const bool active = (tx < k2) && (ty < nq);

    if (active) {
        const int64_t i = ids_before[row * k2 + tx];
        const float d = dis_before[row * k2 + tx];
        // Negative ids are treated as filtered. This also covers the case where another
        // block has already stored the -1 retry flag into ids_before[0]; indexing the
        // bitset with a negative id would read out of bounds.
        const bool keep = (i >= 0) && !(bs[i >> 3] & (0x1 << (i & 0x7)));
        ids_[tx] = keep ? i : -1;
        dis_[tx] = keep ? d : -1.0f;
    }
    __syncthreads();

    if (tx == 0 && ty < nq) {
        // Serial compaction of the surviving candidates; both indices are bounded so the
        // scan never walks past the k2 staged entries.
        int j = 0;
        for (int k = 0; k < k2 && j < k1; ++k) {
            if (ids_[k] == -1) {
                continue;
            }
            ids[row * k1 + j] = ids_[k];
            dis[row * k1 + j] = dis_[k];
            ++j;
        }
        if (j != k1) {
            // Not enough survivors: mark the unfilled tail as invalid so stale or
            // uninitialized ids never leak out, then raise the retry flag for the host.
            for (int p = j; p < k1; ++p) {
                ids[row * k1 + p] = -1;
            }
            ids_before[0] = -1;  // destroy answer
        }
    }
}

namespace raft_res_pool {

struct context {
Expand Down Expand Up @@ -381,9 +420,42 @@ class RaftIvfIndexNode : public IndexNode {
if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
auto search_params = raft::neighbors::ivf_flat::search_params{};
search_params.n_probes = ivf_raft_cfg.nprobe;
raft::neighbors::ivf_flat::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
data_gpu.data(), rows, ivf_raft_cfg.k,
ids_gpu.data(), dis_gpu.data());
if (bitset.empty()) {
raft::neighbors::ivf_flat::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
data_gpu.data(), rows, ivf_raft_cfg.k,
ids_gpu.data(), dis_gpu.data());
} else {
auto k1 = ivf_raft_cfg.k;
auto k2 = k1;
k2 |= k2 >> 1;
k2 |= k2 >> 2;
k2 |= k2 >> 4;
k2 |= k2 >> 8;
k2 |= k2 >> 14;
k2 += 1;
while (k2 <= 1024) {
auto ids_gpu_before = rmm::device_uvector<std::int64_t>(k2 * rows, stream);
auto dis_gpu_before = rmm::device_uvector<float>(k2 * rows, stream);
auto bs_gpu = rmm::device_uvector<uint8_t>(bitset.byte_size(), stream);
RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data(), bitset.data(), bitset.byte_size(),
cudaMemcpyDefault, stream.value()));

raft::neighbors::ivf_flat::search<float, std::int64_t>(
*res_, search_params, *gpu_index_, data_gpu.data(), rows, k2, ids_gpu_before.data(),
dis_gpu_before.data());
filter<<<dim3(1, rows), k2, k2 * sizeof(std::int64_t) + k2 * sizeof(float), stream.value()>>>(
k1, k2, rows, bs_gpu.data(), ids_gpu_before.data(), dis_gpu_before.data(), ids_gpu.data(),
dis_gpu.data());

std::int64_t is_fine = 0;
RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data(), sizeof(std::int64_t),
cudaMemcpyDefault, stream.value()));
stream.synchronize();
if (is_fine != -1)
break;
k2 = k2 << 1;
}
}
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
auto search_params = raft::neighbors::ivf_pq::search_params{};
search_params.n_probes = ivf_raft_cfg.nprobe;
Expand Down Expand Up @@ -411,9 +483,44 @@ class RaftIvfIndexNode : public IndexNode {
}
search_params.internal_distance_dtype = internal_distance_dtype.value();
search_params.preferred_shmem_carveout = search_params.preferred_shmem_carveout;
raft::neighbors::ivf_pq::search<float, std::int64_t>(*res_, search_params, *gpu_index_, data_gpu.data(),
rows, ivf_raft_cfg.k, ids_gpu.data(),
dis_gpu.data());
if (bitset.empty()) {
raft::neighbors::ivf_pq::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
data_gpu.data(), rows, ivf_raft_cfg.k,
ids_gpu.data(), dis_gpu.data());
} else {
auto k1 = ivf_raft_cfg.k;
auto k2 = k1;
k2 |= k2 >> 1;
k2 |= k2 >> 2;
k2 |= k2 >> 4;
k2 |= k2 >> 8;
k2 |= k2 >> 14;
k2 += 1;
while (k2 <= 1024) {
auto ids_gpu_before = rmm::device_uvector<std::int64_t>(k2 * rows, stream);
auto dis_gpu_before = rmm::device_uvector<float>(k2 * rows, stream);
auto bs_gpu = rmm::device_uvector<uint8_t>(bitset.byte_size(), stream);
RAFT_CUDA_TRY(cudaMemcpyAsync(bs_gpu.data(), bitset.data(), bitset.byte_size(),
cudaMemcpyDefault, stream.value()));

raft::neighbors::ivf_pq::search<float, std::int64_t>(
*res_, search_params, *gpu_index_, data_gpu.data(), rows, k2, ids_gpu_before.data(),
dis_gpu_before.data());

filter<<<dim3(1, rows), k2, k2 * sizeof(std::int64_t) + k2 * sizeof(float), stream.value()>>>(
k1, k2, rows, bs_gpu.data(), ids_gpu_before.data(), dis_gpu_before.data(), ids_gpu.data(),
dis_gpu.data());

std::int64_t is_fine = 0;
RAFT_CUDA_TRY(cudaMemcpyAsync(&is_fine, ids_gpu_before.data(), sizeof(std::int64_t),
cudaMemcpyDefault, stream.value()));
stream.synchronize();
if (is_fine != -1)
break;
k2 = k2 << 1;
}
}

} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
Expand Down

0 comments on commit 78d3351

Please sign in to comment.