[Opt] let LRU mode use device clock
- Remove `cur_score` from `Bucket`
- Switch benchmark to LRU strategy
rhdong committed Jun 7, 2023
1 parent a16c6e1 commit 460db25
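
In short: when the caller does not supply explicit scores, the table now stamps each touched slot with the GPU's global nanosecond timer instead of a per-bucket insertion counter, so LRU ordering follows device wall-clock time. A rough sketch of the idea (illustrative only; `lru_score` is a made-up name, not the library's API):

    // Old behavior (removed): every insert bumped a per-bucket counter and used it as the score.
    //   S score = bucket->cur_score.fetch_add(1, memory_order_relaxed) + 1;
    // New behavior: read the device-wide %globaltimer register (a global nanosecond clock),
    // so a larger score always means "touched more recently", across buckets and kernels.
    __device__ __forceinline__ uint64_t lru_score() {
      uint64_t ns;
      asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(ns));
      return ns;
    }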
Showing 9 changed files with 77 additions and 49 deletions.
31 changes: 16 additions & 15 deletions README.md
@@ -122,6 +122,7 @@ Your environment must meet the following requirements:
* Key Type = uint64_t
* Value Type = float32 * {dim}
* Key-Values per OP = 1048576
* Evict strategy: LRU
* `λ`: load factor
* `find*` means the `find` API that directly returns the addresses of values.
* `find_or_insert*` means the `find_or_insert` API that directly returns the addresses of values.
@@ -131,37 +132,37 @@ Your environment must meet the following requirements:

* dim = 4, capacity = 64 Million-KV, HBM = 32 GB, HMEM = 0 GB

| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|------:|---------------:|-------:|------:|----------------:|-----------------:|
| 0.50 | 1.397 | 2.923 | 1.724 | 1.945 | 3.609 | 1.756 | 1.158 |
| 0.75 | 1.062 | 1.607 | 0.615 | 0.910 | 1.836 | 1.175 | 0.900 |
| 1.00 | 0.352 | 0.826 | 0.342 | 0.551 | 0.894 | 0.357 | 0.302 |
| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|-----------------:|
| 0.50 | 1.418 | 3.616 | 1.925 | 1.973 | 4.522 | 1.943 | 1.186 |
| 0.75 | 1.095 | 1.829 | 0.686 | 0.915 | 2.106 | 1.291 | 0.923 |
| 1.00 | 0.360 | 0.887 | 0.362 | 0.546 | 0.963 | 0.380 | 0.311 |

* dim = 64, capacity = 64 Million-KV, HBM = 16 GB, HMEM = 0 GB

| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|------:|---------------:|-------:|------:|----------------:|-----------------:|
| 0.50 | 0.924 | 1.587 | 0.888 | 1.125 | 3.628 | 1.756 | 0.789 |
| 0.75 | 0.662 | 1.115 | 0.540 | 0.833 | 1.844 | 1.177 | 0.566 |
| 1.00 | 0.323 | 0.642 | 0.314 | 0.512 | 0.897 | 0.358 | 0.177 |
| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* | insert_and_evict |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|-----------------:|
| 0.50 | 0.943 | 1.766 | 0.936 | 1.134 | 4.569 | 1.954 | 0.806 |
| 0.75 | 0.675 | 1.216 | 0.589 | 0.825 | 2.107 | 1.293 | 0.577 |
| 1.00 | 0.328 | 0.678 | 0.329 | 0.503 | 0.963 | 0.380 | 0.179 |

### On HBM+HMEM hybrid mode:

* dim = 64, capacity = 128 Million-KV, HBM = 16 GB, HMEM = 16 GB

| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|
| 0.50 | 0.122 | 0.149 | 0.120 | 0.148 | 3.414 | 1.690 |
| 0.75 | 0.117 | 0.145 | 0.115 | 0.143 | 1.808 | 1.161 |
| 1.00 | 0.088 | 0.125 | 0.087 | 0.114 | 0.884 | 0.355 |
| 0.50 | 0.121 | 0.150 | 0.121 | 0.147 | 4.254 | 1.875 |
| 0.75 | 0.116 | 0.146 | 0.116 | 0.143 | 2.054 | 1.281 |
| 1.00 | 0.088 | 0.126 | 0.088 | 0.114 | 0.949 | 0.377 |

* dim = 64, capacity = 1024 Million-KV, HBM = 56 GB, HMEM = 200 GB

| λ | insert_or_assign | find | find_or_insert | assign | find* | find_or_insert* |
|-----:|-----------------:|-------:|---------------:|-------:|-------:|----------------:|
| 0.50 | 0.037 | 0.053 | 0.034 | 0.050 | 2.822 | 1.715 |
| 0.75 | 0.036 | 0.053 | 0.033 | 0.049 | 1.920 | 1.082 |
| 1.00 | 0.032 | 0.049 | 0.030 | 0.044 | 0.855 | 0.351 |
| 0.75 | 0.027 | 0.040 | 0.025 | 0.037 | 1.744 | 0.905 |
| 1.00 | 0.033 | 0.050 | 0.030 | 0.044 | 0.917 | 0.373 |

### Support and Feedback:

11 changes: 2 additions & 9 deletions benchmark/merlin_hashtable_benchmark.cc.cu
@@ -177,7 +177,7 @@ float test_one_api(const API_Select api, const size_t dim,
options.dim = dim;
options.max_hbm_for_vectors = nv::merlin::GB(hbm4values);
options.io_by_cpu = io_by_cpu;
options.evict_strategy = EvictStrategy::kCustomized;
options.evict_strategy = EvictStrategy::kLru;
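// With LRU the table generates scores from the device clock itself, so the explicit
// d_scores allocation and host-to-device score copies are no longer needed here.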

std::unique_ptr<Table> table = std::make_unique<Table>();
table->init(options);
@@ -203,7 +203,6 @@ float test_one_api(const API_Select api, const size_t dim,
S* d_evict_scores;

CUDA_CHECK(cudaMalloc(&d_keys, key_num_per_op * sizeof(K)));
CUDA_CHECK(cudaMalloc(&d_scores, key_num_per_op * sizeof(S)));
CUDA_CHECK(cudaMalloc(&d_vectors, key_num_per_op * sizeof(V) * options.dim));
CUDA_CHECK(cudaMalloc(&d_def_val, key_num_per_op * sizeof(V) * options.dim));
CUDA_CHECK(cudaMalloc(&d_vectors_ptr, key_num_per_op * sizeof(V*)));
@@ -239,8 +238,6 @@ float test_one_api(const API_Select api, const size_t dim,
create_continuous_keys<K, S>(h_keys, h_scores, key_num_cur_insert, start);
CUDA_CHECK(cudaMemcpy(d_keys, h_keys, key_num_cur_insert * sizeof(K),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_scores, h_scores, key_num_cur_insert * sizeof(S),
cudaMemcpyHostToDevice));
table->insert_or_assign(key_num_cur_insert, d_keys, d_vectors, d_scores,
stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -258,8 +255,6 @@ float test_one_api(const API_Select api, const size_t dim,
create_continuous_keys<K, S>(h_keys, h_scores, key_num_append, start);
CUDA_CHECK(cudaMemcpy(d_keys, h_keys, key_num_append * sizeof(K),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_scores, h_scores, key_num_append * sizeof(S),
cudaMemcpyHostToDevice));
table->insert_or_assign(key_num_append, d_keys, d_vectors, d_scores,
stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -344,8 +339,6 @@ float test_one_api(const API_Select api, const size_t dim,
Hit_Mode::last_insert, start, true /*reset*/);
CUDA_CHECK(cudaMemcpy(d_keys, h_keys, key_num_per_op * sizeof(K),
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(d_scores, h_scores, key_num_per_op * sizeof(S),
cudaMemcpyHostToDevice));
auto timer = benchmark::Timer<double>();
switch (api) {
case API_Select::find: {
@@ -433,7 +426,6 @@ float test_one_api(const API_Select api, const size_t dim,
CUDA_CHECK(cudaFreeHost(h_found));

CUDA_CHECK(cudaFree(d_keys));
CUDA_CHECK(cudaFree(d_scores));
CUDA_CHECK(cudaFree(d_vectors));
CUDA_CHECK(cudaFree(d_def_val));
CUDA_CHECK(cudaFree(d_vectors_ptr));
@@ -557,6 +549,7 @@ int main() {
<< "* Key Type = uint64_t" << endl
<< "* Value Type = float32 * {dim}" << endl
<< "* Key-Values per OP = " << key_num_per_op << endl
<< "* Evict strategy: LRU" << endl
<< "* `λ`: load factor" << endl
<< "* `find*` means the `find` API that directly returns the addresses "
"of values."
7 changes: 2 additions & 5 deletions include/merlin/core_kernels.cuh
@@ -103,8 +103,6 @@ __global__ void create_atomic_scores(Bucket<K, V, S>* __restrict buckets,
new (buckets[start + tid].scores(i))
AtomicScore<S>{static_cast<S>(EMPTY_SCORE)};
}
new (&(buckets[start + tid].cur_score))
AtomicScore<S>{static_cast<S>(EMPTY_SCORE)};
new (&(buckets[start + tid].min_score))
AtomicScore<S>{static_cast<S>(EMPTY_SCORE)};
new (&(buckets[start + tid].min_pos)) AtomicPos<int>{1};
@@ -1064,9 +1062,8 @@ __forceinline__ __device__ void update_score(Bucket<K, V, S>* __restrict bucket,
const S* __restrict scores,
const int key_idx) {
if (scores == nullptr) {
S cur_score =
bucket->cur_score.fetch_add(1, cuda::std::memory_order_relaxed) + 1;
bucket->scores(key_pos)->store(cur_score, cuda::std::memory_order_relaxed);
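// LRU path: stamp the slot with the current device time so more recently
// touched entries carry larger scores.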
bucket->scores(key_pos)->store(device_nano<S>(),
cuda::std::memory_order_relaxed);
} else {
bucket->scores(key_pos)->store(scores[key_idx],
cuda::std::memory_order_relaxed);
5 changes: 0 additions & 5 deletions include/merlin/types.cuh
@@ -59,11 +59,6 @@ struct Bucket {
AtomicScore<S>* scores_;
V* vectors; // Pinned memory or HBM

/* For upsert_kernel without user specified scores
recording the current score, the cur_score will
increment by 1 when a new inserting happens. */
AtomicScore<S> cur_score;

/* min_score and min_pos are for the upsert_kernel
with user-specified scores. They record the minimum
score and its pos in the bucket. */
7 changes: 7 additions & 0 deletions include/merlin/utils.cuh
@@ -71,6 +71,13 @@ __inline__ __device__ uint64_t atomicAdd(uint64_t* address,
namespace nv {
namespace merlin {

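// device_nano: reads the PTX %globaltimer special register, a 64-bit global timer
// with nanosecond resolution, used here as the score source for LRU.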
template <class S>
static __forceinline__ __device__ S device_nano() {
S mclk;
asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(mclk));
return mclk;
}

inline void __cudaCheckError(const char* file, const int line) {
#ifdef CUDA_ERROR_CHECK
cudaError err = cudaGetLastError();
16 changes: 11 additions & 5 deletions tests/find_or_insert_ptr_test.cc.cu
@@ -1506,6 +1506,8 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {

size_t total_size = 0;
size_t dump_counter = 0;
S start_ts;
S end_ts;
for (int i = 0; i < TEST_TIMES; i++) {
std::unique_ptr<Table> table = std::make_unique<Table>();
table->init(options);
@@ -1530,11 +1532,13 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
CUDA_CHECK(cudaMalloc(&d_vectors_ptr, BASE_KEY_NUM * sizeof(V*)));
test_util::array2ptr(d_vectors_ptr, d_vectors_temp, options.dim,
BASE_KEY_NUM, stream);
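// Bracket the operation with device-clock timestamps: every LRU score written
// in between must fall within [start_ts, end_ts].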
start_ts = test_util::host_nano<S>(stream);
table->find_or_insert(BASE_KEY_NUM, d_keys_temp, d_vectors_ptr, d_found,
nullptr, stream);
test_util::read_or_write_ptr(d_vectors_ptr, d_vectors_temp, d_found,
options.dim, BASE_KEY_NUM, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
end_ts = test_util::host_nano<S>(stream);
CUDA_CHECK(cudaFree(d_vectors_ptr));
CUDA_CHECK(cudaFree(d_found));
}
@@ -1560,8 +1564,8 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
std::array<S, BASE_KEY_NUM> h_scores_temp_sorted(h_scores_temp);
std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end());

ASSERT_TRUE(
(h_scores_temp_sorted == test_util::range<S, TEMP_KEY_NUM>(1)));
ASSERT_GE(h_scores_temp_sorted[0], start_ts);
ASSERT_LE(h_scores_temp_sorted[BASE_KEY_NUM - 1], end_ts);
for (int i = 0; i < dump_counter; i++) {
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
@@ -1578,6 +1582,7 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
CUDA_CHECK(cudaMemcpy(d_vectors_temp, h_vectors_test.data(),
TEST_KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyHostToDevice));
start_ts = test_util::host_nano<S>(stream);
table->assign(TEST_KEY_NUM, d_keys_temp, d_vectors_temp, nullptr, stream);

{
@@ -1592,6 +1597,7 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
test_util::read_or_write_ptr(d_vectors_ptr, d_vectors_temp, d_found,
options.dim, TEST_KEY_NUM, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
end_ts = test_util::host_nano<S>(stream);
CUDA_CHECK(cudaFree(d_vectors_ptr));
CUDA_CHECK(cudaFree(d_found));
}
@@ -1622,14 +1628,14 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
ASSERT_GT(h_scores_temp[i], BUCKET_MAX_SIZE);
h_scores_temp_sorted[ctr++] = h_scores_temp[i];
} else {
ASSERT_LE(h_scores_temp[i], BUCKET_MAX_SIZE);
ASSERT_LE(h_scores_temp[i], start_ts);
}
}
std::sort(h_scores_temp_sorted.begin(),
h_scores_temp_sorted.begin() + ctr);

ASSERT_TRUE((h_scores_temp_sorted ==
test_util::range<S, TEST_KEY_NUM>(BUCKET_MAX_SIZE + 1)));
ASSERT_GE(h_scores_temp_sorted[0], start_ts);
ASSERT_LE(h_scores_temp_sorted[ctr - 1], end_ts);
for (int i = 0; i < dump_counter; i++) {
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
14 changes: 9 additions & 5 deletions tests/find_or_insert_test.cc.cu
@@ -1342,9 +1342,11 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
CUDA_CHECK(cudaMemcpy(d_vectors_temp, h_vectors_base.data(),
BASE_KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyHostToDevice));
S start_ts = test_util::host_nano<S>(stream);
table->find_or_insert(BASE_KEY_NUM, d_keys_temp, d_vectors_temp, nullptr,
stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
S end_ts = test_util::host_nano<S>(stream);

size_t total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -1365,8 +1367,8 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
std::array<S, BASE_KEY_NUM> h_scores_temp_sorted(h_scores_temp);
std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end());

ASSERT_TRUE(
(h_scores_temp_sorted == test_util::range<S, TEMP_KEY_NUM>(1)));
ASSERT_GE(h_scores_temp_sorted[0], start_ts);
ASSERT_LE(h_scores_temp_sorted[TEST_KEY_NUM - 1], end_ts);
for (int i = 0; i < dump_counter; i++) {
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
@@ -1383,10 +1385,12 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
CUDA_CHECK(cudaMemcpy(d_vectors_temp, h_vectors_test.data(),
TEST_KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyHostToDevice));
S start_ts = test_util::host_nano<S>(stream);
table->assign(TEST_KEY_NUM, d_keys_temp, d_vectors_temp, nullptr, stream);
table->find_or_insert(TEST_KEY_NUM, d_keys_temp, d_vectors_temp, nullptr,
stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
S end_ts = test_util::host_nano<S>(stream);

size_t total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -1412,14 +1416,14 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
ASSERT_GT(h_scores_temp[i], BUCKET_MAX_SIZE);
h_scores_temp_sorted[ctr++] = h_scores_temp[i];
} else {
ASSERT_LE(h_scores_temp[i], BUCKET_MAX_SIZE);
ASSERT_LE(h_scores_temp[i], start_ts);
}
}
std::sort(h_scores_temp_sorted.begin(),
h_scores_temp_sorted.begin() + ctr);

ASSERT_TRUE((h_scores_temp_sorted ==
test_util::range<S, TEST_KEY_NUM>(BUCKET_MAX_SIZE + 1)));
ASSERT_GE(h_scores_temp_sorted[0], start_ts);
ASSERT_LE(h_scores_temp_sorted[ctr - 1], end_ts);
for (int i = 0; i < dump_counter; i++) {
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
14 changes: 9 additions & 5 deletions tests/merlin_hashtable_test.cc.cu
@@ -1241,9 +1241,11 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
CUDA_CHECK(cudaMemcpy(d_vectors_temp, h_vectors_base.data(),
BASE_KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyHostToDevice));
S start_ts = test_util::host_nano<S>(stream);
table->insert_or_assign(BASE_KEY_NUM, d_keys_temp, d_vectors_temp,
nullptr, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
S end_ts = test_util::host_nano<S>(stream);

size_t total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -1264,8 +1266,8 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
std::array<S, BASE_KEY_NUM> h_scores_temp_sorted(h_scores_temp);
std::sort(h_scores_temp_sorted.begin(), h_scores_temp_sorted.end());

ASSERT_TRUE(
(h_scores_temp_sorted == test_util::range<S, TEMP_KEY_NUM>(1)));
ASSERT_GE(h_scores_temp_sorted[0], start_ts);
ASSERT_LE(h_scores_temp_sorted[BASE_KEY_NUM - 1], end_ts);
for (int i = 0; i < dump_counter; i++) {
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
@@ -1282,9 +1284,11 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
CUDA_CHECK(cudaMemcpy(d_vectors_temp, h_vectors_test.data(),
TEST_KEY_NUM * sizeof(V) * options.dim,
cudaMemcpyHostToDevice));
S start_ts = test_util::host_nano<S>(stream);
table->insert_or_assign(TEST_KEY_NUM, d_keys_temp, d_vectors_temp,
nullptr, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
S end_ts = test_util::host_nano<S>(stream);

size_t total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -1310,14 +1314,14 @@ void test_evict_strategy_lru_basic(size_t max_hbm_for_vectors) {
ASSERT_GT(h_scores_temp[i], BUCKET_MAX_SIZE);
h_scores_temp_sorted[ctr++] = h_scores_temp[i];
} else {
ASSERT_LE(h_scores_temp[i], BUCKET_MAX_SIZE);
ASSERT_LE(h_scores_temp[i], start_ts);
}
}
std::sort(h_scores_temp_sorted.begin(),
h_scores_temp_sorted.begin() + ctr);

ASSERT_TRUE((h_scores_temp_sorted ==
test_util::range<S, TEST_KEY_NUM>(BUCKET_MAX_SIZE + 1)));
ASSERT_GE(h_scores_temp_sorted[0], start_ts);
ASSERT_LE(h_scores_temp_sorted[ctr - 1], end_ts);
for (int i = 0; i < dump_counter; i++) {
for (int j = 0; j < options.dim; j++) {
ASSERT_EQ(h_vectors_temp[i * options.dim + j],
21 changes: 21 additions & 0 deletions tests/test_util.cuh
@@ -44,6 +44,27 @@

namespace test_util {

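// host_nano: launches a single-thread kernel that samples %globaltimer, letting
// host-side test code take device-clock timestamps comparable with the table's LRU scores.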
template <class S>
__global__ void host_nano_kernel(S* d_clk) {
S mclk;
asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(mclk));
*d_clk = mclk;
}

template <class S>
S host_nano(cudaStream_t stream = 0) {
S h_clk = 0;
S* d_clk;

CUDA_CHECK(cudaMalloc((void**)&(d_clk), sizeof(S)));
host_nano_kernel<S><<<1, 1, 0, stream>>>(d_clk);
CUDA_CHECK(cudaStreamSynchronize(stream));

CUDA_CHECK(cudaMemcpy(&h_clk, d_clk, sizeof(S), cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaFree(d_clk));
return h_clk;
}

__global__ void all_true(const bool* conds, size_t n, int* nfalse) {
const size_t stripe =
(n + gridDim.x - 1) /