From 881051817b0f1d71d742c04191734c00cb19b7e7 Mon Sep 17 00:00:00 2001 From: Max Wittal Date: Thu, 30 Nov 2023 20:12:19 +0700 Subject: [PATCH] mem hash update --- include/mmx/pos/mem_hash.h | 5 ++--- src/pos/mem_hash.cpp | 29 +++++++++++++-------------- src/pos/verify.cpp | 8 +++----- test/test_mem_hash.cpp | 40 +++++++++++++++++++++++++++----------- 4 files changed, 47 insertions(+), 35 deletions(-) diff --git a/include/mmx/pos/mem_hash.h b/include/mmx/pos/mem_hash.h index 30f82e209..e2b5a27a6 100644 --- a/include/mmx/pos/mem_hash.h +++ b/include/mmx/pos/mem_hash.h @@ -21,11 +21,10 @@ namespace pos { void gen_mem_array(uint32_t* mem, const uint8_t* key, const int key_size, const uint32_t mem_size); /* - * M = log2 number of iterations - * mem = array of size (32 << B) + * mem = array of size 1024 * hash = array of size 128 */ -void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int M, const int B); +void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int num_iter); } // pos diff --git a/src/pos/mem_hash.cpp b/src/pos/mem_hash.cpp index c1a78092c..5cb0709cb 100644 --- a/src/pos/mem_hash.cpp +++ b/src/pos/mem_hash.cpp @@ -65,36 +65,33 @@ void gen_mem_array(uint32_t* mem, const uint8_t* key, const int key_size, const } } -void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int M, const int B) +void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int num_iter) { static constexpr int N = 32; - const int num_iter = (1 << M); - const uint32_t index_mask = ((1 << B) - 1); + const uint32_t offset_mask = 31u << 5; uint32_t state[N]; for(int i = 0; i < N; ++i) { - state[i] = mem[index_mask * N + i]; + state[i] = mem[31 * N + i]; } for(int k = 0; k < num_iter; ++k) { - uint32_t tmp = 0; + uint32_t sum = 0; for(int i = 0; i < N; ++i) { - tmp += rotl_32(state[i] ^ 0x55555555, (k + i) % N); + sum += rotl_32(state[i], (k + i) % 32); } - tmp ^= 0x55555555; + const uint32_t dir = sum % 1193u; // mod by prime - const auto bits = tmp % 32; -// const auto offset = ((tmp >> 5) & index_mask) * N; - const auto offset = tmp & (index_mask << 5); + const uint32_t bits = dir % 32u; + const uint32_t offset = dir & offset_mask; - for(int i = 0; i < N; ++i) { - const int shift = (k + i) % N; - state[i] += rotl_32(mem[offset + shift] ^ state[i], bits); - } - for(int i = 0; i < N; ++i) { - mem[offset + i] = state[i]; + for(int i = 0; i < N; ++i) + { + state[i] += rotl_32(mem[offset + (k + i) % N], bits) ^ sum; + + mem[offset + i] ^= state[i]; } } diff --git a/src/pos/verify.cpp b/src/pos/verify.cpp index 3eaadbd38..a9e3ff44b 100644 --- a/src/pos/verify.cpp +++ b/src/pos/verify.cpp @@ -17,11 +17,9 @@ namespace mmx { namespace pos { -static constexpr int MEM_HASH_N = 32; -static constexpr int MEM_HASH_M = 8; -static constexpr int MEM_HASH_B = 5; +static constexpr int MEM_HASH_ITER = 256; -static constexpr uint64_t MEM_SIZE = uint64_t(MEM_HASH_N) << MEM_HASH_B; +static constexpr uint64_t MEM_SIZE = 32 * 32; static std::mutex g_mutex; static std::shared_ptr g_threads; @@ -50,7 +48,7 @@ void compute_f1(std::vector* X_tmp, gen_mem_array(mem_buf.data(), key.data(), key.size(), MEM_SIZE); uint8_t mem_hash[128 + 64] = {}; - calc_mem_hash(mem_buf.data(), mem_hash, MEM_HASH_M, MEM_HASH_B); + calc_mem_hash(mem_buf.data(), mem_hash, MEM_HASH_ITER); ::memcpy(mem_hash + 128, key.data(), key.size()); diff --git a/test/test_mem_hash.cpp b/test/test_mem_hash.cpp index d0f182345..d90c4acf5 100644 --- a/test/test_mem_hash.cpp +++ b/test/test_mem_hash.cpp @@ -18,21 +18,23 @@ using namespace mmx::pos; int main(int argc, char** argv) { const int N = 32; - const int M = 8; - const int B = 5; - const int num_iter = argc > 1 ? std::atoi(argv[1]) : 1; + const int count = argc > 1 ? std::atoi(argv[1]) : 1; + const int num_iter = argc > 2 ? std::atoi(argv[2]) : 256; - const uint64_t mem_size = uint64_t(N) << B; + const uint64_t mem_size = uint64_t(N) * N; - std::cout << "N = " << N << std::endl; - std::cout << "M = " << M << std::endl; - std::cout << "B = " << B << std::endl; + std::cout << "count = " << count << std::endl; + std::cout << "num_iter = " << num_iter << std::endl; std::cout << "mem_size = " << mem_size << " (" << mem_size * 4 / 1024 << " KiB)" << std::endl; + size_t pop_sum = 0; + size_t num_pass = 0; + size_t min_pop_count = 1024; + uint32_t* mem = new uint32_t[mem_size]; - for(int iter = 0; iter < num_iter; ++iter) + for(int iter = 0; iter < count; ++iter) { uint8_t key[32] = {}; ::memcpy(key, &iter, sizeof(iter)); @@ -42,7 +44,7 @@ int main(int argc, char** argv) if(iter == 0) { std::map init_count; - for(int k = 0; k < (1 << B); ++k) { + for(int k = 0; k < 32; ++k) { std::cout << "[" << k << "] " << std::hex; for(int i = 0; i < N; ++i) { init_count[mem[k * N + i]]++; @@ -59,11 +61,27 @@ int main(int argc, char** argv) } mmx::bytes_t<128> hash; - calc_mem_hash(mem, hash.data(), M, B); + calc_mem_hash(mem, hash.data(), num_iter); + + size_t pop = 0; + for(int i = 0; i < 1024; ++i) { + pop += (hash[i / 8] >> (i % 8)) & 1; + } + pop_sum += pop; + + min_pop_count = std::min(min_pop_count, pop); - std::cout << "[" << iter << "] " << hash << std::endl; + if(pop <= 469) { + num_pass++; + } + + std::cout << "[" << iter << "] " << hash << " (" << pop << ")" << std::endl; } + std::cout << "num_pass = " << num_pass << " (" << num_pass / double(count) << ")" << std::endl; + std::cout << "min_pop_count = " << min_pop_count << std::endl; + std::cout << "avg_pop_count = " << pop_sum / double(count) << std::endl; + delete [] mem; return 0;