Commit

explicit load balancing for stabler parallel pairing under stress
mratsim committed Oct 6, 2023
1 parent 253e041 commit 00feff5
Showing 1 changed file with 56 additions and 40 deletions.
96 changes: 56 additions & 40 deletions constantine/signatures/bls_signatures_parallel.nim
@@ -21,6 +21,9 @@ import ./bls_signatures{.all.}
 export bls_signatures
 
 import
+  # Standard library
+  std/atomics,
+  # Constantine
   ../threadpool/[threadpool, partitioners],
   ../platforms/[abstractions, allocs, views],
   ../serialization/endians,
@@ -103,63 +106,76 @@ proc batchVerify_parallel*[Msg, Pubkey, Sig](
   type FF2 = Sig.F
   type FpK = Sig.F.C.getGT()
 
-  # Stage 0: Setup per-thread accumulators
-  let N = pubkeys.len
-  let numAccums = min(N, tp.numThreads)
+  # Stage 0a: Setup per-thread accumulators
+  debug: doAssert pubkeys.len <= 1 shl 32
+  let N = pubkeys.len.uint32
+  let numAccums = min(N, tp.numThreads.uint32)
   let accums = allocHeapArray(BLSBatchSigAccumulator[H, FF1, FF2, Fpk, ECP_ShortW_Jac[Sig.F, Sig.G], k], numAccums)
-  let chunkingDescriptor = balancedChunksPrioNumber(0, N, numAccums)
-  let
-    pubkeysView = pubkeys.toView()
-    messagesView = messages.toView()
-    signaturesView = signatures.toView()
-    dstView = domainSepTag.toView()
 
+  # Stage 0b: Setup synchronization
+  var currentItem {.noInit.}: Atomic[uint32]
+  var terminateSignal {.noInit.}: Atomic[bool]
+  currentItem.store(0, moRelaxed)
+  terminateSignal.store(false, moRelaxed)
+
   # Stage 1: Accumulate partial pairings (Miller Loops)
   # ---------------------------------------------------
-  proc accumChunk(
+  proc accumulate(
         ctx: ptr BLSBatchSigAccumulator,
-        pubkeys: View[Pubkey],
-        messages: View[Msg],
-        signatures: View[Sig],
+        pubkeys: ptr UncheckedArray[Pubkey],
+        messages: ptr UncheckedArray[Msg],
+        signatures: ptr UncheckedArray[Sig],
+        N: uint32,
        domainSepTag: View[byte],
-        secureRandomBytes: array[32, byte],
-        accumSepTag: array[sizeof(int), byte]): bool {.nimcall, gcsafe, tags: [Alloca, VarTime].} =
+        secureRandomBytes: ptr array[32, byte],
+        accumSepTag: array[sizeof(int), byte],
+        terminateSignal: ptr Atomic[bool],
+        currentItem: ptr Atomic[uint32]): bool {.nimcall, gcsafe.} =
     ctx[].init(
       domainSepTag.toOpenArray(),
-      secureRandomBytes,
+      secureRandomBytes[],
       accumSepTag)
 
-    for i in 0 ..< pubkeys.len:
+    while not terminateSignal[].load(moRelaxed):
+      let i = currentItem[].fetchAdd(1, moRelaxed)
+      if i >= N:
+        break
+
       if not ctx[].update(pubkeys[i], messages[i], signatures[i]):
+        terminateSignal[].store(true, moRelaxed)
        return false
 
     ctx[].handover()
     return true
 
+  # Stage 2: Schedule work
+  # ---------------------------------------------------
   let partialStates = allocStackArray(Flowvar[bool], numAccums)
-  for (id, start, size) in items(chunkingDescriptor):
-    partialStates[id] = tp.spawn accumChunk(
+  for id in 0 ..< numAccums:
+    partialStates[id] = tp.spawn accumulate(
      accums[id].addr,
-      pubkeysView.chunk(start, size),
-      messagesView.chunk(start, size),
-      signaturesView.chunk(start, size),
-      dstView,
-      secureRandomBytes,
-      id.uint.toBytes(bigEndian))
-
-  # Note: to avoid memory leaks, even if there is a `false` partial state
-  # (for example due to a point at infinity),
-  # we still need to call `sync` on all tasks.
-
-  # Stage 2: Reduce partial pairings
+      pubkeys.asUnchecked(),
+      messages.asUnchecked(),
+      signatures.asUnchecked(),
+      N,
+      domainSepTag.toView(),
+      secureRandomBytes.unsafeAddr,
+      id.uint.toBytes(bigEndian),
+      terminateSignal.addr,
+      currentItem.addr)
+
+  # Stage 3: Reduce partial pairings
   # --------------------------------
   # Linear merge with latency hiding, we could consider a parallel logarithmic merge via a binary tree merge / divide-and-conquer
-  result = sync partialStates[0]
-  for i in 1 ..< numAccums:
-    result = result and sync partialStates[i]
-    if result: # As long as no error is returned, accumulate
-      result = result and accums[0].merge(accums[i])
-  if not result: # Don't proceed to final exponentiation if there is already an error
-    return false
-
-  return accums[0].finalVerify()
+  block HappyPath: # sync must be called even if result is false in the middle to avoid tasks leaking
+    result = sync partialStates[0]
+    for i in 1 ..< numAccums:
+      result = result and sync partialStates[i]
+      if result: # As long as no error is returned, accumulate
+        result = result and accums[0].merge(accums[i])
+    if not result: # Don't proceed to final exponentiation if there is already an error
+      break HappyPath
+
+    result = accums[0].finalVerify()
+
+  freeHeap(accums)

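The change above replaces static chunk partitioning (`balancedChunksPrioNumber` plus per-chunk `View`s) with dynamic load balancing: every worker claims items one at a time from a shared atomic counter, and a shared flag lets any worker stop the whole batch early when it hits an invalid input. Below is a minimal, standalone sketch of that pattern using only the Nim standard library (`std/atomics`, `std/typedthreads`, assuming Nim 2.x). It is not Constantine's threadpool API; `worker`, `NumWorkers`, `NumItems` and `claimed` are illustrative names.

import std/[atomics, typedthreads]

const
  NumWorkers = 4
  NumItems = 100'u32

var
  currentItem: Atomic[uint32]        # next unclaimed work item
  terminateSignal: Atomic[bool]      # set by any worker to stop the whole batch
  claimed: array[NumWorkers, int]    # per-worker tally, to show the balancing

proc worker(id: int) {.thread.} =
  ## Claim items one at a time until the work runs out
  ## or another worker has signalled failure.
  while not terminateSignal.load(moRelaxed):
    let i = currentItem.fetchAdd(1, moRelaxed)
    if i >= NumItems:
      break
    # A real accumulator would process item `i` here and, on an invalid
    # input, stop every other worker early with:
    #   terminateSignal.store(true, moRelaxed); return
    inc claimed[id]

when isMainModule:
  currentItem.store(0, moRelaxed)
  terminateSignal.store(false, moRelaxed)

  var threads: array[NumWorkers, Thread[int]]
  for id in 0 ..< NumWorkers:
    createThread(threads[id], worker, id)
  joinThreads(threads)

  var total = 0
  for c in claimed: total += c
  doAssert total == NumItems.int     # every item was claimed exactly once

Because each item is claimed individually, a worker that gets slowed down (preemption, slower cores, NUMA effects) simply claims fewer items instead of stalling the whole batch on a pre-assigned chunk, which is the stability-under-stress property the commit title refers to.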
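The Stage 3 comment mentions that the linear merge could instead be a parallel logarithmic merge via binary tree / divide-and-conquer. The sketch below only shows the shape of that reduction; `Accum` and `merge` are stand-ins for `BLSBatchSigAccumulator` and its `merge` proc, and the two recursive halves are evaluated serially here where a real version would spawn them on the threadpool and `sync` the resulting Flowvars.

type Accum = object
  value: int            # stand-in for the accumulated pairing state
  ok: bool              # stand-in for the partial verification status

proc merge(a: var Accum, b: Accum): bool =
  ## Placeholder for BLSBatchSigAccumulator.merge: fold b into a.
  if not (a.ok and b.ok):
    return false
  a.value += b.value
  return true

proc treeMerge(accums: var openArray[Accum], lo, hi: int): bool =
  ## Merge accums[lo..hi] into accums[lo] by divide-and-conquer,
  ## giving a log-depth dependency chain instead of a linear one.
  if lo == hi:
    return accums[lo].ok
  let mid = (lo + hi) div 2
  let leftOk  = treeMerge(accums, lo, mid)       # could be spawned
  let rightOk = treeMerge(accums, mid + 1, hi)   # could be spawned
  if not (leftOk and rightOk):
    return false
  return accums[lo].merge(accums[mid + 1])

when isMainModule:
  var accums = [Accum(value: 1, ok: true), Accum(value: 2, ok: true),
                Accum(value: 3, ok: true), Accum(value: 4, ok: true)]
  doAssert treeMerge(accums, 0, accums.high)
  doAssert accums[0].value == 10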