diff --git a/CMakeLists.txt b/CMakeLists.txt index 2095ffd6d..51c76211d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,6 +185,7 @@ add_library(mmx_db SHARED add_library(mmx_pos SHARED src/pos/mem_hash.cpp + src/pos/verify.cpp ) add_library(mmx_modules SHARED @@ -237,7 +238,7 @@ target_include_directories(mmx_chiapos target_link_libraries(mmx_iface uint256_t secp256k1 bech32 bls vnx_base vnx_addons) target_link_libraries(mmx_db mmx_iface) - +target_link_libraries(mmx_pos mmx_iface) target_link_libraries(mmx_vm mmx_db mmx_iface) target_link_libraries(mmx_modules mmx_chiapos mmx_pos mmx_vm mmx_db mmx_iface) diff --git a/include/mmx/pos/mem_hash.h b/include/mmx/pos/mem_hash.h index 80313be94..6f8431166 100644 --- a/include/mmx/pos/mem_hash.h +++ b/include/mmx/pos/mem_hash.h @@ -23,7 +23,7 @@ void gen_mem_array(uint32_t* mem, const uint8_t* key, const uint64_t mem_size); /* * M = log2 number of iterations * mem = array of size (32 << B) - * hash = array of size 32 + * hash = array of size 64 */ void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int M, const int B); diff --git a/include/mmx/pos/util.h b/include/mmx/pos/util.h index e3a8b5e9a..ee3c8b8f0 100644 --- a/include/mmx/pos/util.h +++ b/include/mmx/pos/util.h @@ -9,6 +9,22 @@ #define INCLUDE_MMX_POS_UTIL_H_ #include +#include + +// compiler-specific byte swap macros. +#if defined(_MSC_VER) + #include + // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/byteswap-uint64-byteswap-ulong-byteswap-ushort?view=msvc-160 + inline uint16_t bswap_16(uint16_t x) { return _byteswap_ushort(x); } + inline uint32_t bswap_32(uint32_t x) { return _byteswap_ulong(x); } + inline uint64_t bswap_64(uint64_t x) { return _byteswap_uint64(x); } +#elif defined(__clang__) || defined(__GNUC__) + inline uint16_t bswap_16(uint16_t x) { return __builtin_bswap16(x); } + inline uint32_t bswap_32(uint32_t x) { return __builtin_bswap32(x); } + inline uint64_t bswap_64(uint64_t x) { return __builtin_bswap64(x); } +#else +#error "unknown compiler, don't know how to swap bytes" +#endif #define MMXPOS_QUARTERROUND(a, b, c, d) \ a = a + b; \ @@ -32,6 +48,41 @@ inline uint64_t rotl_64(const uint64_t v, int bits) { return (v << bits) | (v >> (64 - bits)); } +inline +uint64_t write_bits(uint64_t* dst, const uint64_t value, const uint64_t bit_offset, const int num_bits) +{ + const int free_bits = 64 - (bit_offset % 64); + if(free_bits >= num_bits) { + dst[bit_offset / 64] |= bswap_64(value << (free_bits - num_bits)); + } else { + const int suffix_size = num_bits - free_bits; + const uint64_t suffix = value & ((uint64_t(1) << suffix_size) - 1); + dst[bit_offset / 64] |= bswap_64(value >> suffix_size); // prefix (high bits) + dst[bit_offset / 64 + 1] |= bswap_64(suffix << (64 - suffix_size)); // suffix (low bits) + } + return bit_offset + num_bits; +} + +inline +uint64_t read_bits(const uint64_t* src, const uint64_t bit_offset, const int num_bits) +{ + int count = 0; + uint64_t offset = bit_offset; + uint64_t result = 0; + while(count < num_bits) { + const int shift = offset % 64; + const int bits = std::min(num_bits - count, 64 - shift); + uint64_t value = bswap_64(src[offset / 64]) << shift; + if(bits < 64) { + value >>= (64 - bits); + } + result |= value << (num_bits - count - bits); + count += bits; + offset += bits; + } + return result; +} + } // pos } // mmx diff --git a/include/mmx/pos/verify.h b/include/mmx/pos/verify.h index bb3c271f1..944f8ac76 100644 --- a/include/mmx/pos/verify.h +++ b/include/mmx/pos/verify.h @@ -15,16 +15,14 @@ namespace mmx { namespace pos { -std::vector>> -compute(const std::vector& X_values, std::vector* X_out, +std::vector>> +compute(const std::vector& X_values, std::vector* X_out, const uint8_t* id, const int ksize, const int xbits); -hash_t verify(const std::vector& X_values, const uint8_t* id, const hash_t& challenge, const int ksize); +hash_t verify(const std::vector& X_values, const hash_t& challenge, const uint8_t* id, const int ksize); } // pos } // mmx - - #endif /* INCLUDE_MMX_POS_VERIFY_H_ */ diff --git a/src/pos/mem_hash.cpp b/src/pos/mem_hash.cpp index 9650812a9..1d3349203 100644 --- a/src/pos/mem_hash.cpp +++ b/src/pos/mem_hash.cpp @@ -65,7 +65,6 @@ void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int M, const int B) for(int i = 0; i < N; ++i) { state[i] = mem[index_mask * N + i]; } -// std::map count; for(int k = 0; k < num_iter; ++k) { @@ -78,7 +77,6 @@ void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int M, const int B) const auto bits = tmp % 32; // const auto offset = ((tmp >> 5) & index_mask) * N; const auto offset = tmp & (index_mask << 5); -// count[offset]++; for(int i = 0; i < N; ++i) { const int shift = (k + i) % N; @@ -89,13 +87,7 @@ void calc_mem_hash(uint32_t* mem, uint8_t* hash, const int M, const int B) } } -// for(const auto& entry : count) { -// if(entry.second > uint32_t(3 << (M - B))) { -// std::cout << "WARN [" << entry.first << "] " << entry.second << std::endl; -// } -// } - - for(int i = 0; i < 8; ++i) { + for(int i = 0; i < 16; ++i) { for(int k = 0; k < 4; ++k) { hash[i * 4 + k] = state[i] >> (24 - k * 8); } diff --git a/src/pos/verify.cpp b/src/pos/verify.cpp index 8004f6e5c..5a6df812d 100644 --- a/src/pos/verify.cpp +++ b/src/pos/verify.cpp @@ -8,130 +8,128 @@ #include #include #include +#include namespace mmx { namespace pos { -static constexpr int N = 32; -static constexpr int M = 8; -static constexpr int B = 5; +static constexpr int N_META = 12; +static constexpr int N_TABLE = 9; +static constexpr int MEM_HASH_N = 32; +static constexpr int MEM_HASH_M = 8; +static constexpr int MEM_HASH_B = 5; -static constexpr uint64_t MEM_SIZE = uint64_t(N) << B; +static constexpr uint64_t MEM_SIZE = uint64_t(MEM_HASH_N) << MEM_HASH_B; -int get_meta_count(const int table) + +std::vector>> +compute(const std::vector& X_values, std::vector* X_out, const uint8_t* id, const int ksize, const int xbits) { - if(table < 1) { - return 0; + if(ksize < 8 || ksize > 32) { + throw std::logic_error("invalid ksize"); } - switch(table) { - case 1: return 2; - case 2: return 4; - case 3: return 8; + if(xbits > ksize) { + throw std::logic_error("invalid xbits"); } - return 12; -} - -std::vector>> -compute(const std::vector& X_values, std::vector* X_out, const uint8_t* id, const int ksize, const int xbits) -{ std::vector mem_buf(MEM_SIZE); - std::vector X_tmp; - std::vector>> LR(8); - std::vector>> entries; + std::vector X_tmp; + std::vector> M_tmp; + std::vector> entries; + std::vector>> LR(N_TABLE - 1); - for(const uint64_t X : X_values) + const uint32_t kmask = ((uint32_t(1) << ksize) - 1); + const uint32_t num_entries_1 = X_values.size() << xbits; + + if(X_out) { + X_tmp.reserve(num_entries_1); + } + M_tmp.reserve(num_entries_1); + + for(const uint32_t X : X_values) { for(uint32_t i = 0; i < (uint32_t(1) << xbits); ++i) { - const uint64_t X_i = (X << xbits) | i; + const uint32_t X_i = (X << xbits) | i; if(X_out) { X_tmp.push_back(X_i); } - uint64_t msg[5] = {}; - msg[0] = bswap_64(X_i); + uint32_t msg[9] = {}; + msg[0] = X_i; ::memcpy(msg + 1, id, 32); - const hash_t hash(&msg, 8 + 32); + const hash_t hash(&msg, 4 + 32); gen_mem_array(mem_buf.data(), hash.data(), MEM_SIZE); - uint64_t mem_hash[4] = {}; - calc_mem_hash(mem_buf.data(), (uint8_t*)hash, M, B); + uint32_t mem_hash[16] = {}; + calc_mem_hash(mem_buf.data(), (uint8_t*)mem_hash, MEM_HASH_M, MEM_HASH_B); - const uint64_t Y_i = read_bits(mem_hash, 0, ksize); + const uint32_t Y_i = mem_hash[0] & kmask; - std::array meta = {}; - for(int k = 0; k < get_meta_count(1); ++k) - { - const auto tmp = read_bits(mem_hash, (k + 1) * ksize, ksize); - write_bits(meta.data(), tmp, k * ksize, ksize); + std::array meta = {}; + for(int k = 0; k < N_META; ++k) { + meta[k] = mem_hash[1 + k] & kmask; } - entries.emplace_back(Y_i, X_tmp.size(), meta); + entries.emplace_back(Y_i, M_tmp.size()); + M_tmp.push_back(meta); } } - for(int t = 2; t <= 9; ++t) + for(int t = 2; t <= N_TABLE; ++t) { - std::vector>> matches; + std::vector> M_next; + std::vector> matches; - const int meta_count_in = get_meta_count(t - 1); - const int meta_count_out = get_meta_count(t); + const size_t num_entries = entries.size(); std::sort(entries.begin(), entries.end()); - for(auto iter = entries.begin(); iter != entries.end(); ++iter) + for(size_t x = 0; x < num_entries; ++x) { - const uint64_t YL = std::get<0>(*iter); - const uint64_t PL = std::get<1>(*iter); - const uint64_t* L_meta = (const uint64_t*)std::get<2>(*iter).data(); + const auto YL = std::get<0>(entries[x]); - for(auto iter2 = iter; iter2 != entries.end(); ++iter2) + for(size_t y = x + 1; y < num_entries; ++y) { - const uint64_t YR = std::get<0>(*iter2); - const uint64_t PR = std::get<1>(*iter2); - const uint64_t* R_meta = (const uint64_t*)std::get<2>(*iter2).data(); + const auto YR = std::get<0>(entries[y]); if(YR == YL + 1) { - uint64_t hash[8] = {}; + const auto PL = std::get<1>(entries[x]); + const auto PR = std::get<1>(entries[y]); + const auto& L_meta = M_tmp[PL]; + const auto& R_meta = M_tmp[PR]; - for(int i = 0; i < (t < 3 ? 1 : 2); ++i) - { - uint64_t msg[16] = {}; - auto bit_offset = write_bits(msg, (t << 4) | i, 0, 8); + uint32_t hash[16] = {}; - bit_offset = write_bits(msg, YL, bit_offset, ksize); + for(int i = 0; i < 2; ++i) + { + uint32_t msg[2 + 2 * N_META] = {}; + msg[0] = (t << 8) | i; + msg[1] = YL; - for(int k = 0; k < meta_count_in; ++k) - { - const auto tmp = read_bits(L_meta, k * ksize, ksize); - bit_offset = write_bits(msg, tmp, bit_offset, ksize); + for(int k = 0; k < N_META; ++k) { + msg[2 + k] = L_meta[k]; } - for(int k = 0; k < meta_count_in; ++k) - { - const auto tmp = read_bits(R_meta, k * ksize, ksize); - bit_offset = write_bits(msg, tmp, bit_offset, ksize); + for(int k = 0; k < N_META; ++k) { + msg[2 + N_META + k] = R_meta[k]; } - const hash_t hash_i(&msg, (bit_offset + 7) / 8); + const hash_t hash_i(&msg, sizeof(msg)); - ::memcpy(hash + i * 4, hash_i.data(), 32); + ::memcpy(hash + i * 8, hash_i.data(), 32); } - const uint64_t Y_new = read_bits(hash, 0, ksize); + const uint32_t Y_i = hash[0] & kmask; - uint64_t bit_offset = 0; - std::array meta = {}; - - for(int k = 0; k < meta_count_out; ++k) - { - const auto tmp = read_bits(hash, (k + 1) * ksize, ksize); - bit_offset = write_bits(meta.data(), tmp, bit_offset, ksize); + std::array meta = {}; + for(int k = 0; k < N_META; ++k) { + meta[k] = hash[1 + k] & kmask; } - matches.emplace_back(Y_new, LR[t-2].size(), meta); + matches.emplace_back(Y_i, M_next.size()); if(X_out) { LR[t-2].emplace_back(PL, PR); } + M_next.push_back(meta); } else if(YR > YL) { break; @@ -142,55 +140,60 @@ compute(const std::vector& X_values, std::vector* X_out, con if(matches.empty()) { throw std::logic_error("zero matches at table " + std::to_string(t)); } + M_tmp = std::move(M_next); entries = std::move(matches); } - std::vector>> out; + std::vector>> out; for(const auto& entry : entries) { if(X_out) { std::vector I_tmp; I_tmp.push_back(std::get<1>(entry)); - for(int t = 7; t >= 0; --t) + for(int k = N_TABLE - 2; k >= 0; --k) { std::vector I_next; for(const auto i : I_tmp) { - I_next.push_back(LR[t][i].first); - I_next.push_back(LR[t][i].second); + I_next.push_back(LR[k][i].first); + I_next.push_back(LR[k][i].second); } I_tmp = std::move(I_next); } - X_out->reserve(I_tmp.size()); - for(const auto i : I_tmp) { X_out->push_back(X_tmp[i]); } } const auto& Y = std::get<0>(entry); - const auto& meta = std::get<2>(entry); + const auto& I = std::get<1>(entry); + const auto& meta = M_tmp[I]; out.emplace_back(Y, bytes_t<48>(meta.data(), sizeof(meta))); } return out; } -hash_t verify(const std::vector& X_values, const uint8_t* id, const hash_t& challenge, const int ksize) +hash_t verify(const std::vector& X_values, const hash_t& challenge, const uint8_t* id, const int ksize) { - const auto entries = compute(X_values, nullptr, id, ksize, 0); - + std::vector X_out; + const auto entries = compute(X_values, &X_out, id, ksize, 0); if(entries.empty()) { throw std::logic_error("invalid proof"); } const auto& result = entries[0]; - uint64_t tmp[4] = {}; - ::memcpy(tmp, challenge.data(), 32); + const uint32_t kmask = ((uint32_t(1) << ksize) - 1); - const auto Y_challenge = read_bits(tmp, 0, ksize); + uint32_t tmp = {}; + ::memcpy(&tmp, challenge.data(), sizeof(tmp)); + + const auto Y_challenge = tmp & kmask; if(result.first != Y_challenge) { throw std::logic_error("Y output value != challenge"); } - return hash_t(std::string("proof_quality") + result.second + challenge); + if(X_out != X_values) { + throw std::logic_error("invalid proof order"); + } + return hash_t(result.second + challenge); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bd506d6bb..25cc163b5 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -7,6 +7,7 @@ add_executable(test_transactions test_transactions.cpp) add_executable(test_swap_algo test_swap_algo.cpp) add_executable(test_database_reads test_database_reads.cpp) add_executable(test_mem_hash test_mem_hash.cpp) +add_executable(test_pos_compute test_pos_compute.cpp) add_executable(database_fill database_fill.cpp) add_executable(mmx_tests mmx_tests.cpp) @@ -21,6 +22,7 @@ target_link_libraries(test_swap_algo mmx_iface) target_link_libraries(test_database_reads mmx_db mmx_iface) target_link_libraries(database_fill mmx_db mmx_iface) target_link_libraries(test_mem_hash vnx_base mmx_pos) +target_link_libraries(test_pos_compute mmx_iface mmx_pos) target_link_libraries(mmx_tests mmx_iface) target_link_libraries(vm_engine_tests mmx_vm) diff --git a/test/test_mem_hash.cpp b/test/test_mem_hash.cpp index e68e86d41..e5c4c0582 100644 --- a/test/test_mem_hash.cpp +++ b/test/test_mem_hash.cpp @@ -57,7 +57,7 @@ int main(int argc, char** argv) } } } - mmx::bytes_t<32> hash; + mmx::bytes_t<64> hash; calc_mem_hash(mem, hash.data(), M, B); diff --git a/test/test_pos_compute.cpp b/test/test_pos_compute.cpp new file mode 100644 index 000000000..e435f2d4f --- /dev/null +++ b/test/test_pos_compute.cpp @@ -0,0 +1,53 @@ +/* + * test_pos_compute.cpp + * + * Created on: Nov 6, 2023 + * Author: mad + */ + +#include +#include + +using namespace mmx; + + +int main(int argc, char** argv) +{ + const int ksize = argc > 1 ? std::atoi(argv[1]) : 32; + const int xbits = argc > 2 ? std::atoi(argv[2]) : 0; + + std::cout << "ksize = " << ksize << std::endl; + std::cout << "xbits = " << xbits << std::endl; + + uint8_t id[32] = {}; + + std::mt19937_64 generator(1337); + std::uniform_int_distribution dist(0, (1 << (ksize - xbits)) - 1); + + std::vector X_values; + if(xbits < ksize) { + for(int i = 0; i < 256; ++i) { + X_values.push_back(dist(generator)); + } + } else { + X_values.push_back(0); + } + + std::cout << "X = "; + for(auto X : X_values) { + std::cout << X << " "; + } + std::cout << std::endl; + + std::vector X_out; + const auto res = pos::compute(X_values, &X_out, id, ksize, xbits); + + for(const auto& entry : res) { + std::cout << "Y = " << entry.first << std::endl; + std::cout << "M = " << entry.second.to_string() << std::endl; + } + std::cout << "num_entries = " << res.size() << std::endl; + +} + +