Skip to content

Commit

Permalink
multi-threaded pos compute
Browse files Browse the repository at this point in the history
  • Loading branch information
madMAx43v3r committed Nov 6, 2023
1 parent a81ddb3 commit 6688e44
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 54 deletions.
148 changes: 102 additions & 46 deletions src/pos/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
#include <mmx/pos/verify.h>
#include <mmx/pos/mem_hash.h>
#include <mmx/hash_t.hpp>
#include <tuple>

#include <vnx/vnx.h>
#include <vnx/ThreadPool.h>


namespace mmx {
Expand All @@ -22,81 +24,131 @@ static constexpr int MEM_HASH_B = 5;

static constexpr uint64_t MEM_SIZE = uint64_t(MEM_HASH_N) << MEM_HASH_B;

// Held by compute() while it lazily creates g_threads on first multi-threaded use.
static std::mutex g_mutex;
// Worker pool shared by all compute() calls; stays null until compute() first needs threads.
static std::shared_ptr<vnx::ThreadPool> g_threads;


/*
 * Computes all table 1 entries derived from a single X value and appends them
 * to the shared output containers.
 *
 * For every i in [0, 2^xbits) the value X_i = (X << xbits) | i (truncated to 32 bits)
 * is placed in front of the 32 bytes read from `id`, passed through hash_t,
 * expanded with gen_mem_array() and reduced with calc_mem_hash().
 * Word 0 of the result, masked to `ksize` bits, becomes Y_i; the following
 * N_META words, masked the same way, become the meta data of the entry.
 *
 * X_tmp   : optional (may be nullptr); receives X_i of every entry.
 * M_tmp   : receives the meta data of every entry.
 * entries : receives (Y_i, index of the entry's meta data inside M_tmp).
 * mutex   : locked once per entry, after the hashing is done, so only the
 *           appends to X_tmp / M_tmp / entries are serialized between callers.
 * X       : base value, expanded by `xbits` low bits.
 * id      : pointer to at least 32 readable bytes.
 * ksize   : number of bits kept in Y_i and in each meta word.
 * xbits   : number of low bits enumerated below X.
 *
 * Neither ksize nor xbits is validated here; both are expected to be within [0, 32].
 */
void compute_f1(std::vector<uint32_t>* X_tmp,
				std::vector<std::array<uint32_t, N_META>>& M_tmp,
				std::vector<std::pair<uint32_t, uint32_t>>& entries,
				std::mutex& mutex,
				const uint32_t X,
				const uint8_t* id, const int ksize, const int xbits)
{
	const uint32_t kmask = ((uint64_t(1) << ksize) - 1);

	// Loop bound and shifted base are computed once, in 64-bit like kmask above:
	// a 32-bit shift by xbits == 32 would be undefined behaviour.
	const uint64_t num_steps = uint64_t(1) << xbits;
	const uint64_t X_base = uint64_t(X) << xbits;

	// Scratch buffer, reused by every iteration of this call.
	std::vector<uint32_t> mem_buf(MEM_SIZE);

	for(uint64_t i = 0; i < num_steps; ++i)
	{
		// Truncation to 32 bits yields the same value as the 32-bit expression for xbits < 32.
		const uint32_t X_i = uint32_t(X_base | i);

		// Message layout: 4 bytes of X_i followed by 32 bytes of id.
		uint32_t msg[9] = {};
		msg[0] = X_i;
		::memcpy(msg + 1, id, 32);

		const hash_t hash(&msg, 4 + 32);
		gen_mem_array(mem_buf.data(), hash.data(), MEM_SIZE);

		uint32_t mem_hash[16] = {};
		calc_mem_hash(mem_buf.data(), reinterpret_cast<uint8_t*>(mem_hash), MEM_HASH_M, MEM_HASH_B);

		const uint32_t Y_i = mem_hash[0] & kmask;

		std::array<uint32_t, N_META> meta = {};
		for(int k = 0; k < N_META; ++k) {
			meta[k] = mem_hash[1 + k] & kmask;
		}

		// Everything above worked on local data only; the lock covers just the appends
		// and is released at the end of this iteration.
		std::lock_guard<std::mutex> lock(mutex);

		if(X_tmp) {
			X_tmp->push_back(X_i);
		}
		// M_tmp.size() is read before the push_back, so it is the index of `meta`.
		entries.emplace_back(Y_i, M_tmp.size());
		M_tmp.push_back(meta);
	}
}

std::vector<std::pair<uint32_t, bytes_t<48>>>
compute(const std::vector<uint32_t>& X_values, std::vector<uint32_t>* X_out, const uint8_t* id, const int ksize, const int xbits)
{
if(ksize < 8 || ksize > 32) {
throw std::logic_error("invalid ksize");
}
if(xbits > ksize) {
if(xbits > ksize - 8 && xbits != ksize) {
throw std::logic_error("invalid xbits");
}
std::vector<uint32_t> mem_buf(MEM_SIZE);
const bool use_threads = (xbits > 6);

if(use_threads) {
std::lock_guard<std::mutex> lock(g_mutex);
if(!g_threads) {
const auto cpu_threads = std::thread::hardware_concurrency();
const auto num_threads = cpu_threads > 0 ? cpu_threads : 16;
g_threads = std::make_shared<vnx::ThreadPool>(num_threads, 1024);
vnx::log_info() << "Using " << num_threads << " CPU threads (for proof compute)";
}
}
const uint32_t kmask = ((uint64_t(1) << ksize) - 1);
const uint32_t num_entries_1 = X_values.size() << xbits;

std::mutex mutex;
std::vector<int64_t> jobs;
std::vector<uint32_t> X_tmp;
std::vector<uint32_t> mem_buf(MEM_SIZE);
std::vector<std::array<uint32_t, N_META>> M_tmp;
std::vector<std::tuple<uint32_t, uint32_t>> entries;
std::vector<std::vector<std::pair<uint32_t, uint32_t>>> LR(N_TABLE - 1);
std::vector<std::pair<uint32_t, uint32_t>> entries;
std::vector<std::vector<std::pair<uint32_t, uint32_t>>> LR_tmp(N_TABLE - 1);

const uint32_t kmask = ((uint32_t(1) << ksize) - 1);
const uint32_t num_entries_1 = X_values.size() << xbits;
// const auto t1_begin = vnx::get_time_millis();

if(X_out) {
X_tmp.reserve(num_entries_1);
}
M_tmp.reserve(num_entries_1);
entries.reserve(num_entries_1);

for(const uint32_t X : X_values)
{
for(uint32_t i = 0; i < (uint32_t(1) << xbits); ++i)
{
const uint32_t X_i = (X << xbits) | i;

if(X_out) {
X_tmp.push_back(X_i);
}
uint32_t msg[9] = {};
msg[0] = X_i;
::memcpy(msg + 1, id, 32);

const hash_t hash(&msg, 4 + 32);
gen_mem_array(mem_buf.data(), hash.data(), MEM_SIZE);

uint32_t mem_hash[16] = {};
calc_mem_hash(mem_buf.data(), (uint8_t*)mem_hash, MEM_HASH_M, MEM_HASH_B);

const uint32_t Y_i = mem_hash[0] & kmask;

std::array<uint32_t, N_META> meta = {};
for(int k = 0; k < N_META; ++k) {
meta[k] = mem_hash[1 + k] & kmask;
}
entries.emplace_back(Y_i, M_tmp.size());
M_tmp.push_back(meta);
if(use_threads) {
const auto job = g_threads->add_task(
[X_out, &X_tmp, &M_tmp, &entries, &mutex, X, id, ksize, xbits]() {
compute_f1(X_out ? &X_tmp : nullptr, M_tmp, entries, mutex, X, id, ksize, xbits);
});
jobs.push_back(job);
} else {
compute_f1(X_out ? &X_tmp : nullptr, M_tmp, entries, mutex, X, id, ksize, xbits);
}
}

if(use_threads) {
g_threads->sync(jobs);
jobs.clear();
}
// std::cout << "Table 1 took " << (vnx::get_time_millis() - t1_begin) << " ms" << std::endl;

for(int t = 2; t <= N_TABLE; ++t)
{
std::vector<std::array<uint32_t, N_META>> M_next;
std::vector<std::tuple<uint32_t, uint32_t>> matches;
// const auto time_begin = vnx::get_time_millis();

const size_t num_entries = entries.size();
std::vector<std::array<uint32_t, N_META>> M_next;
std::vector<std::pair<uint32_t, uint32_t>> matches;

std::sort(entries.begin(), entries.end());

for(size_t x = 0; x < num_entries; ++x)
// std::cout << "Table " << t << " sort took " << (vnx::get_time_millis() - time_begin) << " ms" << std::endl;

for(size_t x = 0; x < entries.size(); ++x)
{
const auto YL = std::get<0>(entries[x]);
const auto YL = entries[x].first;

for(size_t y = x + 1; y < num_entries; ++y)
for(size_t y = x + 1; y < entries.size(); ++y)
{
const auto YR = std::get<0>(entries[y]);
const auto YR = entries[y].first;

if(YR == YL + 1) {
const auto PL = std::get<1>(entries[x]);
const auto PR = std::get<1>(entries[y]);
const auto PL = entries[x].second;
const auto PR = entries[y].second;
const auto& L_meta = M_tmp[PL];
const auto& R_meta = M_tmp[PR];

Expand Down Expand Up @@ -127,7 +179,7 @@ compute(const std::vector<uint32_t>& X_values, std::vector<uint32_t>* X_out, con
matches.emplace_back(Y_i, M_next.size());

if(X_out) {
LR[t-2].emplace_back(PL, PR);
LR_tmp[t-2].emplace_back(PL, PR);
}
M_next.push_back(meta);
}
Expand All @@ -142,8 +194,12 @@ compute(const std::vector<uint32_t>& X_values, std::vector<uint32_t>* X_out, con
}
M_tmp = std::move(M_next);
entries = std::move(matches);

// std::cout << "Table " << t << " took " << (vnx::get_time_millis() - time_begin) << " ms, " << entries.size() << " entries" << std::endl;
}

std::sort(entries.begin(), entries.end());

std::vector<std::pair<uint32_t, bytes_t<48>>> out;
for(const auto& entry : entries)
{
Expand All @@ -155,17 +211,17 @@ compute(const std::vector<uint32_t>& X_values, std::vector<uint32_t>* X_out, con
{
std::vector<uint32_t> I_next;
for(const auto i : I_tmp) {
I_next.push_back(LR[k][i].first);
I_next.push_back(LR[k][i].second);
I_next.push_back(LR_tmp[k][i].first);
I_next.push_back(LR_tmp[k][i].second);
}
I_tmp = std::move(I_next);
}
for(const auto i : I_tmp) {
X_out->push_back(X_tmp[i]);
}
}
const auto& Y = std::get<0>(entry);
const auto& I = std::get<1>(entry);
const auto& Y = entry.first;
const auto& I = entry.second;
const auto& meta = M_tmp[I];
out.emplace_back(Y, bytes_t<48>(meta.data(), sizeof(meta)));
}
Expand Down
28 changes: 21 additions & 7 deletions test/test_pos_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/

#include <mmx/pos/verify.h>
#include <vnx/vnx.h>
#include <random>

using namespace mmx;
Expand All @@ -16,21 +17,27 @@ int main(int argc, char** argv)
const int ksize = argc > 1 ? std::atoi(argv[1]) : 32;
const int xbits = argc > 2 ? std::atoi(argv[2]) : 0;

vnx::init("test_pos_compute", 0, nullptr);

std::cout << "ksize = " << ksize << std::endl;
std::cout << "xbits = " << xbits << std::endl;

const bool full_mode = (xbits == ksize);

uint8_t id[32] = {};

std::mt19937_64 generator(1337);
std::uniform_int_distribution<uint64_t> dist(0, (1 << (ksize - xbits)) - 1);
std::uniform_int_distribution<uint64_t> dist(0, (uint64_t(1) << (ksize - xbits)) - 1);

std::vector<uint32_t> X_values;
if(xbits < ksize) {
if(!full_mode) {
for(int i = 0; i < 256; ++i) {
X_values.push_back(dist(generator));
}
} else {
X_values.push_back(0);
for(int i = 0; i < 256; ++i) {
X_values.push_back(i);
}
}

std::cout << "X = ";
Expand All @@ -40,14 +47,21 @@ int main(int argc, char** argv)
std::cout << std::endl;

std::vector<uint32_t> X_out;
const auto res = pos::compute(X_values, &X_out, id, ksize, xbits);
const auto res = pos::compute(X_values, &X_out, id, ksize, full_mode ? ksize - 8 : xbits);

for(const auto& entry : res) {
std::cout << "Y = " << entry.first << std::endl;
std::cout << "M = " << entry.second.to_string() << std::endl;
for(size_t i = 0; i < std::min<size_t>(res.size(), 5); ++i)
{
std::cout << "Y = " << res[i].first << std::endl;
std::cout << "M = " << res[i].second.to_string() << std::endl;
std::cout << "X = ";
for(size_t k = 0; k < 256; ++k) {
std::cout << X_out[i * 256 + k] << " ";
}
std::cout << std::endl;
}
std::cout << "num_entries = " << res.size() << std::endl;

vnx::close();
}


2 changes: 1 addition & 1 deletion vnx-base

0 comments on commit 6688e44

Please sign in to comment.