From 00feff576b5290d35c7989387007363b973f7b4f Mon Sep 17 00:00:00 2001
From: Mamy Ratsimbazafy
Date: Fri, 6 Oct 2023 08:35:03 +0200
Subject: [PATCH] explicit load balancing for stabler parallel pairing under stress

---
 .../signatures/bls_signatures_parallel.nim | 96 +++++++++++--------
 1 file changed, 56 insertions(+), 40 deletions(-)

diff --git a/constantine/signatures/bls_signatures_parallel.nim b/constantine/signatures/bls_signatures_parallel.nim
index 70bccadab..cf83356c4 100644
--- a/constantine/signatures/bls_signatures_parallel.nim
+++ b/constantine/signatures/bls_signatures_parallel.nim
@@ -21,6 +21,9 @@ import ./bls_signatures{.all.}
 export bls_signatures
 
 import
+  # Standard library
+  std/atomics,
+  # Constantine
   ../threadpool/[threadpool, partitioners],
   ../platforms/[abstractions, allocs, views],
   ../serialization/endians,
@@ -103,63 +106,76 @@ proc batchVerify_parallel*[Msg, Pubkey, Sig](
   type FF2 = Sig.F
   type FpK = Sig.F.C.getGT()
 
-  # Stage 0: Setup per-thread accumulators
-  let N = pubkeys.len
-  let numAccums = min(N, tp.numThreads)
+  # Stage 0a: Setup per-thread accumulators
+  debug: doAssert pubkeys.len <= 1 shl 32
+  let N = pubkeys.len.uint32
+  let numAccums = min(N, tp.numThreads.uint32)
   let accums = allocHeapArray(BLSBatchSigAccumulator[H, FF1, FF2, Fpk, ECP_ShortW_Jac[Sig.F, Sig.G], k], numAccums)
-  let chunkingDescriptor = balancedChunksPrioNumber(0, N, numAccums)
-  let
-    pubkeysView = pubkeys.toView()
-    messagesView = messages.toView()
-    signaturesView = signatures.toView()
-    dstView = domainSepTag.toView()
+
+  # Stage 0b: Setup synchronization
+  var currentItem {.noInit.}: Atomic[uint32]
+  var terminateSignal {.noInit.}: Atomic[bool]
+  currentItem.store(0, moRelaxed)
+  terminateSignal.store(false, moRelaxed)
 
   # Stage 1: Accumulate partial pairings (Miller Loops)
   # ---------------------------------------------------
-  proc accumChunk(
+  proc accumulate(
       ctx: ptr BLSBatchSigAccumulator,
-      pubkeys: View[Pubkey],
-      messages: View[Msg],
-      signatures: View[Sig],
+      pubkeys: ptr UncheckedArray[Pubkey],
+      messages: ptr UncheckedArray[Msg],
+      signatures: ptr UncheckedArray[Sig],
+      N: uint32,
       domainSepTag: View[byte],
-      secureRandomBytes: array[32, byte],
-      accumSepTag: array[sizeof(int), byte]): bool {.nimcall, gcsafe, tags: [Alloca, VarTime].} =
+      secureRandomBytes: ptr array[32, byte],
+      accumSepTag: array[sizeof(int), byte],
+      terminateSignal: ptr Atomic[bool],
+      currentItem: ptr Atomic[uint32]): bool {.nimcall, gcsafe.} =
     ctx[].init(
       domainSepTag.toOpenArray(),
-      secureRandomBytes,
+      secureRandomBytes[],
       accumSepTag)
 
-    for i in 0 ..< pubkeys.len:
+    while not terminateSignal[].load(moRelaxed):
+      let i = currentItem[].fetchAdd(1, moRelaxed)
+      if i >= N:
+        break
+
       if not ctx[].update(pubkeys[i], messages[i], signatures[i]):
+        terminateSignal[].store(true, moRelaxed)
         return false
 
     ctx[].handover()
     return true
 
+  # Stage 2: Schedule work
+  # ---------------------------------------------------
   let partialStates = allocStackArray(Flowvar[bool], numAccums)
-  for (id, start, size) in items(chunkingDescriptor):
-    partialStates[id] = tp.spawn accumChunk(
+  for id in 0 ..< numAccums:
+    partialStates[id] = tp.spawn accumulate(
       accums[id].addr,
-      pubkeysView.chunk(start, size),
-      messagesView.chunk(start, size),
-      signaturesView.chunk(start, size),
-      dstView,
-      secureRandomBytes,
-      id.uint.toBytes(bigEndian))
-
-  # Note: to avoid memory leaks, even if there is a `false` partial state
-  #       (for example due to a point at infinity),
-  #       we still need to call `sync` on all tasks.
-
-  # Stage 2: Reduce partial pairings
+      pubkeys.asUnchecked(),
+      messages.asUnchecked(),
+      signatures.asUnchecked(),
+      N,
+      domainSepTag.toView(),
+      secureRandomBytes.unsafeAddr,
+      id.uint.toBytes(bigEndian),
+      terminateSignal.addr,
+      currentItem.addr)
+
+  # Stage 3: Reduce partial pairings
   # --------------------------------
   # Linear merge with latency hiding, we could consider a parallel logarithmic merge via a binary tree merge / divide-and-conquer
-  result = sync partialStates[0]
-  for i in 1 ..< numAccums:
-    result = result and sync partialStates[i]
-    if result: # As long as no error is returned, accumulate
-      result = result and accums[0].merge(accums[i])
-  if not result: # Don't proceed to final exponentiation if there is already an error
-    return false
-
-  return accums[0].finalVerify()
+  block HappyPath: # sync must be called even if result is false in the middle to avoid tasks leaking
+    result = sync partialStates[0]
+    for i in 1 ..< numAccums:
+      result = result and sync partialStates[i]
+      if result: # As long as no error is returned, accumulate
+        result = result and accums[0].merge(accums[i])
+    if not result: # Don't proceed to final exponentiation if there is already an error
+      break HappyPath
+
+    result = accums[0].finalVerify()
+
+  freeHeap(accums)
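
The heart of the change above is the scheduling strategy: instead of pre-splitting the inputs
into fixed chunks (balancedChunksPrioNumber plus per-chunk Views), every worker now claims the
next (pubkey, message, signature) triplet through an atomic counter (fetchAdd on currentItem),
and a shared terminateSignal lets any worker abort the whole batch as soon as one update fails,
while the HappyPath block still syncs every Flowvar so no task is leaked. Below is a minimal,
self-contained Nim sketch of that pattern, assuming plain std/threads instead of Constantine's
threadpool; verifyOne, NumWorkers and NumItems are hypothetical placeholders, not Constantine
APIs.

# Dynamic load balancing with an atomic work counter and an early-exit flag.
import std/atomics

const
  NumWorkers = 4
  NumItems   = 1000

var
  currentItem: Atomic[uint32]    # next item index to claim
  terminateSignal: Atomic[bool]  # set by any worker on failure
  items: array[NumItems, int]
  workers: array[NumWorkers, Thread[void]]

proc verifyOne(x: int): bool =
  ## Placeholder for the per-item work (one pairing accumulation in the patch).
  x >= 0

proc worker() {.thread.} =
  while not terminateSignal.load(moRelaxed):
    # Claim the next unprocessed item. Relaxed ordering is enough here because
    # each index is handed out exactly once and the items are independent.
    let i = currentItem.fetchAdd(1, moRelaxed)
    if i >= NumItems.uint32:
      break
    if not verifyOne(items[int(i)]):
      # Tell the other workers to stop as soon as possible.
      terminateSignal.store(true, moRelaxed)
      return

when isMainModule:
  for i in 0 ..< NumItems:
    items[i] = i
  currentItem.store(0, moRelaxed)
  terminateSignal.store(false, moRelaxed)

  for id in 0 ..< NumWorkers:
    createThread(workers[id], worker)

  # Every thread is always joined, even after an early failure, mirroring the
  # "sync every Flowvar" rule enforced by the HappyPath block in the patch.
  joinThreads(workers)

  echo "batch ok: ", not terminateSignal.load(moRelaxed)

Dynamic claiming keeps all workers busy until the items run out, so a few slow pairings or an
oversubscribed machine no longer leave one statically assigned chunk lagging behind the others,
which is the instability under stress that the commit title refers to.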