Skip to content

Commit

Permalink
performance optimization w mt
Browse files Browse the repository at this point in the history
  • Loading branch information
alex v committed Aug 25, 2024
1 parent 69bf803 commit 708ba07
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 99 deletions.
83 changes: 73 additions & 10 deletions src/bls/src/bls_c_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#define BLS_MULTI_VERIFY_THREAD
#endif

#include <future>
#include <thread>
#include <vector>

inline void Gmul(G1& z, const G1& x, const Fr& y) { G1::mul(z, x, y); }
inline void Gmul(G2& z, const G2& x, const Fr& y) { G2::mul(z, x, y); }
Expand Down Expand Up @@ -440,30 +443,90 @@ int blsAggregateVerifyNoCheck(const blsSignature *sig, const blsPublicKey *pubVe
if (n == 0) return 0;
#if 1 // 1.1 times faster
GT e;
const char *msg = (const char*)msgVec;
const size_t N = 16;
G1 g1Vec[N+1];
G2 g2Vec[N+1];
const char* msg = (const char*)msgVec;
constexpr size_t N = 16;
G1 g1Vec[N + 1];
G2 g2Vec[N + 1];
bool initE = true;

std::vector<std::future<GT>> futuresMiller;
std::vector<std::future<bool>> futures;
size_t numThreads = std::thread::hardware_concurrency();

while (n > 0) {
size_t m = mcl::fp::min_<size_t>(n, N);
for (size_t i = 0; i < m; i++) {
g1Vec[i] = *cast(&pubVec[i].v);
if (g1Vec[i].isZero()) return 0;
hashAndMapToG(g2Vec[i], &msg[i * msgSize], msgSize);

// Lambda function to compute g1Vec and g2Vec in parallel
auto computeVectors = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
g1Vec[i] = *cast(&pubVec[i].v);
if (g1Vec[i].isZero()) return false; // Indicate failure to main thread
hashAndMapToG(g2Vec[i], &msg[i * msgSize], msgSize);
}
return true;
};

// Divide work among threads
size_t chunkSize = m / numThreads;
futures.clear();

// Launch threads to compute g1Vec and g2Vec
for (size_t t = 0; t < numThreads; ++t) {
size_t start = t * chunkSize;
size_t end = (t == numThreads - 1) ? m : (t + 1) * chunkSize;
futures.emplace_back(std::async(std::launch::async, computeVectors, start, end));
}

// Check for any failures in thread execution
for (auto& fut : futures) {
if (!fut.get()) return 0; // If any thread found a zero vector, return 0
}

pubVec += m;
msg += m * msgSize;
n -= m;

if (n == 0) {
g1Vec[m] = getBasePoint();
G2::neg(g2Vec[m], *cast(&sig->v));
m++;
}
millerLoopVec(e, g1Vec, g2Vec, m, initE);
initE = false;

// Prepare for parallel miller loop
auto millerLoopTask = [&](size_t start, size_t end, bool initE) {
GT localE;
millerLoopVec(localE, g1Vec + start, g2Vec + start, end - start, initE);
return localE;
};

// Launch threads to execute miller loop in parallel
futuresMiller.clear();
size_t mlChunkSize = m / numThreads; // Divide miller loop work among threads
std::vector<GT> partialResults(numThreads); // To store partial results

for (size_t t = 0; t < numThreads; ++t) {
size_t start = t * mlChunkSize;
size_t end = (t == numThreads - 1) ? m : (t + 1) * mlChunkSize;
futuresMiller.emplace_back(std::async(std::launch::async, millerLoopTask, start, end, true));
}

// Combine results from each thread
for (size_t t = 0; t < numThreads; ++t) {
partialResults[t] = futuresMiller[t].get();
}

if (initE)
e = partialResults[0];

// Combine partial results into final result e
for (size_t t = initE ? 1 : 0; t < numThreads; ++t) {
e *= partialResults[t]; // Combine partial results
}

initE = false; // Ensure next iteration is not initialized
}

// Final exponentiation outside the loop
BN::finalExp(e, e);
return e.isOne();
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
namespace bulletproofs {

template <typename T>
AmountRecoveryRequest<T> AmountRecoveryRequest<T>::of(const RangeProofWithSeed<T>& proof, const range_proof::GammaSeed<T>& nonce)
AmountRecoveryRequest<T> AmountRecoveryRequest<T>::of(const RangeProofWithSeed<T>& proof, const range_proof::GammaSeed<T>& nonce, const size_t& id)
{
auto proof_with_transcript = RangeProofWithTranscript<T>::Build(proof);

AmountRecoveryRequest<T> req{
1,
id,
proof.seed,
proof_with_transcript.x,
proof_with_transcript.z,
Expand All @@ -31,7 +31,7 @@ AmountRecoveryRequest<T> AmountRecoveryRequest<T>::of(const RangeProofWithSeed<T
0};
return req;
}
template AmountRecoveryRequest<Mcl> AmountRecoveryRequest<Mcl>::of(const RangeProofWithSeed<Mcl>&, const range_proof::GammaSeed<Mcl>&);
template AmountRecoveryRequest<Mcl> AmountRecoveryRequest<Mcl>::of(const RangeProofWithSeed<Mcl>&, const range_proof::GammaSeed<Mcl>&, const size_t&);

} // namespace bulletproofs

3 changes: 2 additions & 1 deletion src/blsct/range_proof/bulletproofs/amount_recovery_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ struct AmountRecoveryRequest

static AmountRecoveryRequest<T> of(
const RangeProofWithSeed<T>& proof,
const range_proof::GammaSeed<T>& nonce);
const range_proof::GammaSeed<T>& nonce,
const size_t& id = 0);
};

} // namespace bulletproofs
Expand Down
148 changes: 67 additions & 81 deletions src/blsct/range_proof/bulletproofs/range_proof_logic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#include <blsct/range_proof/bulletproofs/range_proof_logic.h>
#include <blsct/range_proof/common.h>
#include <blsct/range_proof/msg_amt_cipher.h>
#include <future>
#include <optional>
#include <stdexcept>
#include <variant>
#include <vector>

namespace bulletproofs {

Expand Down Expand Up @@ -248,102 +250,86 @@ bool RangeProofLogic<T>::VerifyProofs(
using Scalar = typename T::Scalar;
using Scalars = Elements<Scalar>;

// Vector to hold future results from async tasks
std::vector<std::future<bool>> futures;

// Launch a verification task for each proof transcript in parallel
for (const RangeProofWithTranscript<T>& p : proof_transcripts) {
if (p.proof.Ls.Size() != p.proof.Rs.Size()) return false;

const range_proof::Generators<T> gens = m_common.Gf().GetInstance(p.proof.seed);
G_H_Gi_Hi_ZeroVerifier<T> verifier(max_mn);

auto num_rounds = range_proof::Common<T>::GetNumRoundsExclLast(p.proof.Vs.Size());
Scalar weight_y = Scalar::Rand();
Scalar weight_z = Scalar::Rand();

Scalars z_pows_from_2 = Scalars::FirstNPow(p.z, p.num_input_values_power_2 + 1, 2); // z^2, z^3, ... // VectorPowers(pd.z, M+3);
Scalar y_pows_sum = Scalars::FirstNPow(p.y, p.concat_input_values_in_bits).Sum(); // VectorPowerSum(p.y, MN);

//////// (65)
// g^t_hat * h^tau_x = V^(z^2) * g^delta_yz * T1^x * T2^(x^2)
// g^(t_hat - delta_yz) = h^(-tau_x) * V^(z^2) * T1^x * T2^(x^2)

// LHS (65)
verifier.AddNegativeH(p.proof.tau_x * weight_y); // LHS (65)

// delta(y,z) in (39)
// = (z - z^2)*<1^n, y^n> - z^3<1^n,2^n>
// = z*<1^n, y^n> (1) - z^2*<1^n, y^n> (2) - z^3<1^n,2^n> (3)
Scalar delta_yz =
p.z * y_pows_sum // (1)
- (z_pows_from_2[0] * y_pows_sum); // (2)
for (size_t i = 1; i <= p.num_input_values_power_2; ++i) {
// multiply z^3, z^4, ..., z^(mn+3)
delta_yz = delta_yz - z_pows_from_2[i] * m_common.InnerProd1x2Pows64(); // (3)
}
futures.emplace_back(std::async(std::launch::async, [this, &p, max_mn]() -> bool {
if (p.proof.Ls.Size() != p.proof.Rs.Size()) return false;

// g part of LHS in (65) where delta_yz on RHS is moved to LHS
// g^t_hat ... = ... g^delta_yz
// g^(t_hat - delta_yz) = ...
verifier.AddNegativeG((p.proof.t_hat - delta_yz) * weight_y);
const range_proof::Generators<T> gens = m_common.Gf().GetInstance(p.proof.seed);
G_H_Gi_Hi_ZeroVerifier<T> verifier(max_mn);

// V^(z^2) in RHS (65)
for (size_t i = 0; i < p.proof.Vs.Size(); ++i) {
verifier.AddPoint(LazyPoint<T>(p.proof.Vs[i] - (gens.G * p.proof.min_value), z_pows_from_2[i] * weight_y)); // multiply z^2, z^3, ...
}
auto num_rounds = range_proof::Common<T>::GetNumRoundsExclLast(p.proof.Vs.Size());
Scalar weight_y = Scalar::Rand();
Scalar weight_z = Scalar::Rand();

// T1^x and T2^(x^2) in RHS (65)
verifier.AddPoint(LazyPoint<T>(p.proof.T1, p.x * weight_y)); // T1^x
verifier.AddPoint(LazyPoint<T>(p.proof.T2, p.x.Square() * weight_y)); // T2^(x^2)
Scalars z_pows_from_2 = Scalars::FirstNPow(p.z, p.num_input_values_power_2 + 1, 2); // z^2, z^3, ...
Scalar y_pows_sum = Scalars::FirstNPow(p.y, p.concat_input_values_in_bits).Sum();

//////// (66)
// P = A * S^x * g^(-z) * (h')^(z * y^n + z^2 * 2^n)
// exponents of g and (h') are created in a loop later
//////// (65)
verifier.AddNegativeH(p.proof.tau_x * weight_y);

// A and S^x in RHS (66)
verifier.AddPoint(LazyPoint<T>(p.proof.A, weight_z)); // A
verifier.AddPoint(LazyPoint<T>(p.proof.S, p.x * weight_z)); // S^x
Scalar delta_yz = p.z * y_pows_sum - (z_pows_from_2[0] * y_pows_sum);
for (size_t i = 1; i <= p.num_input_values_power_2; ++i) {
delta_yz = delta_yz - z_pows_from_2[i] * m_common.InnerProd1x2Pows64();
}

//////// (67), (68)
auto gen_exps = ImpInnerProdArg::GenGeneratorExponents<T>(num_rounds, p.xs);
verifier.AddNegativeG((p.proof.t_hat - delta_yz) * weight_y);

// for all bits of concat input values, do:
ImpInnerProdArg::LoopWithYPows<Mcl>(p.concat_input_values_in_bits, p.y,
[&](const size_t& i, const Scalar& y_pow, const Scalar& y_inv_pow) {
// g^a * h^b (16)
Scalar gi_exp = p.proof.a * gen_exps[i]; // g^a in (16) is distributed to each generator
Scalar hi_exp = p.proof.b *
y_inv_pow *
gen_exps[p.concat_input_values_in_bits - 1 - i]; // h^b in (16) is distributed to each generator. y_inv_pow to turn generator to (h')
for (size_t i = 0; i < p.proof.Vs.Size(); ++i) {
verifier.AddPoint(LazyPoint<T>(p.proof.Vs[i] - (gens.G * p.proof.min_value), z_pows_from_2[i] * weight_y));
}

gi_exp = gi_exp + p.z; // g^(-z) in RHS (66)
verifier.AddPoint(LazyPoint<T>(p.proof.T1, p.x * weight_y));
verifier.AddPoint(LazyPoint<T>(p.proof.T2, p.x.Square() * weight_y));

// ** z^2 * 2^n in (h')^(z * y^n + z^2 * 2^n) in RHS (66)
Scalar tmp =
z_pows_from_2[i / range_proof::Setup::num_input_value_bits] * // skipping the first 2 powers. different z_pow is assigned to each number
m_common.TwoPows64()[i % range_proof::Setup::num_input_value_bits]; // power of 2 corresponding to i-th bit of the number being processed
//////// (66)
verifier.AddPoint(LazyPoint<T>(p.proof.A, weight_z));
verifier.AddPoint(LazyPoint<T>(p.proof.S, p.x * weight_z));

// ** z * y^n in (h')^(z * y^n + z^2 * 2^n) (66)
hi_exp = hi_exp - (tmp + p.z * y_pow) * y_inv_pow;
//////// (67), (68)
auto gen_exps = ImpInnerProdArg::GenGeneratorExponents<T>(num_rounds, p.xs);

verifier.SetGiExp(i, (gi_exp * weight_z).Negate()); // (16) g^a moved to LHS
verifier.SetHiExp(i, (hi_exp * weight_z).Negate()); // (16) h^b moved to LHS
});
ImpInnerProdArg::LoopWithYPows<Mcl>(p.concat_input_values_in_bits, p.y,
[&](const size_t& i, const Scalar& y_pow, const Scalar& y_inv_pow) {
Scalar gi_exp = p.proof.a * gen_exps[i];
Scalar hi_exp = p.proof.b * y_inv_pow * gen_exps[p.concat_input_values_in_bits - 1 - i];

verifier.AddNegativeH(p.proof.mu * weight_z); // ** h^mu (67) RHS
auto x_invs = p.xs.Invert();
gi_exp = gi_exp + p.z;

// add L and R of all rounds to RHS (66) which equals P to generate the P of the final round on LHS (16)
for (size_t i = 0; i < num_rounds; ++i) {
verifier.AddPoint(LazyPoint<T>(p.proof.Ls[i], p.xs[i].Square() * weight_z));
verifier.AddPoint(LazyPoint<T>(p.proof.Rs[i], x_invs[i].Square() * weight_z));
}
Scalar tmp = z_pows_from_2[i / range_proof::Setup::num_input_value_bits] *
m_common.TwoPows64()[i % range_proof::Setup::num_input_value_bits];

hi_exp = hi_exp - (tmp + p.z * y_pow) * y_inv_pow;

verifier.AddPositiveG((p.proof.t_hat - p.proof.a * p.proof.b) * p.c_factor * weight_z);
verifier.SetGiExp(i, (gi_exp * weight_z).Negate());
verifier.SetHiExp(i, (hi_exp * weight_z).Negate());
});

verifier.AddNegativeH(p.proof.mu * weight_z);
auto x_invs = p.xs.Invert();

for (size_t i = 0; i < num_rounds; ++i) {
verifier.AddPoint(LazyPoint<T>(p.proof.Ls[i], p.xs[i].Square() * weight_z));
verifier.AddPoint(LazyPoint<T>(p.proof.Rs[i], x_invs[i].Square() * weight_z));
}

verifier.AddPositiveG((p.proof.t_hat - p.proof.a * p.proof.b) * p.c_factor * weight_z);

bool res = verifier.Verify(
gens.G,
gens.H,
gens.GetGiSubset(max_mn),
gens.GetHiSubset(max_mn));
return res;
}));
}

bool res = verifier.Verify(
gens.G,
gens.H,
gens.GetGiSubset(max_mn),
gens.GetHiSubset(max_mn));
if (!res) return false;
// Wait for all threads to finish and collect results
for (auto& fut : futures) {
if (!fut.get()) return false;
}

return true;
Expand Down Expand Up @@ -446,7 +432,7 @@ AmountRecoveryResult<T> RangeProofLogic<T>::RecoverAmounts(
auto msg_amt = maybe_msg_amt.value();

auto x = range_proof::RecoveredData<T>(
i,
req.id,
msg_amt.amount,
req.nonce.GetHashWithSalt(100), // gamma for vs[0]
msg_amt.msg);
Expand Down
4 changes: 3 additions & 1 deletion src/blsct/wallet/keyman.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,9 +528,11 @@ bulletproofs::AmountRecoveryResult<Arith> KeyMan::RecoverOutputs(const std::vect

for (size_t i = 0; i < outs.size(); i++) {
CTxOut out = outs[i];
if (out.blsctData.viewTag != CalculateViewTag(out.blsctData.blindingKey, viewKey.GetScalar()))
continue;
auto nonce = CalculateNonce(out.blsctData.blindingKey, viewKey.GetScalar());
bulletproofs::RangeProofWithSeed<Arith> proof = {out.blsctData.rangeProof, out.tokenId};
reqs.push_back(bulletproofs::AmountRecoveryRequest<Arith>::of(proof, nonce));
reqs.push_back(bulletproofs::AmountRecoveryRequest<Arith>::of(proof, nonce, i));
}

return rp.RecoverAmounts(reqs);
Expand Down
4 changes: 2 additions & 2 deletions src/test/coins_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ bool operator==(const Coin &a, const Coin &b) {
class CCoinsViewTest : public CCoinsView
{
uint256 hashBestBlock_;
CStakedCommitmentsMap cacheStakedCommitments_;
CStakedCommitmentsMap deltaStakedCommitments_;
std::map<COutPoint, Coin> map_;

public:
Expand Down Expand Up @@ -72,7 +72,7 @@ class CCoinsViewTest : public CCoinsView
hashBestBlock_ = hashBlock;

for (auto& it : stakedCommitments) {
cacheStakedCommitments_[it.first] = it.second;
deltaStakedCommitments_[it.first] = it.second;
};
if (erase)
stakedCommitments.clear();
Expand Down
2 changes: 1 addition & 1 deletion src/wallet/rpc/transactions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ RPCHelpMan listtransactions()
// iterate backwards until we have nCount items to return:
for (CWallet::TxItems::const_reverse_iterator it = txOrdered.rbegin(); it != txOrdered.rend(); ++it) {
CWalletTx* const pwtx = (*it).second;
ListTransactions(*pwallet, *pwtx, 0, 100000000, true, ret, filter, filter_label);
ListTransactions(*pwallet, *pwtx, 1, 100000000, true, ret, filter, filter_label);
if ((int)ret.size() >= (nCount + nFrom)) break;
}
}
Expand Down

0 comments on commit 708ba07

Please sign in to comment.