2x faster parallel EC sum for less than 8192 points
mratsim committed Oct 5, 2023
1 parent 0493154 commit 6b58d26
Showing 3 changed files with 85 additions and 58 deletions.
16 changes: 8 additions & 8 deletions benchmarks/bench_ec_g1_batch.nim
@@ -73,14 +73,14 @@ proc main() =
for numPoints in testNumPoints:
let batchIters = max(1, Iters div numPoints)
multiAddParallelBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, batchIters)
separator()
for numPoints in testNumPoints:
let batchIters = max(1, Iters div numPoints)
multiAddBench(ECP_ShortW_JacExt[Fp[curve], G1], numPoints, useBatching = false, batchIters)
separator()
for numPoints in testNumPoints:
let batchIters = max(1, Iters div numPoints)
multiAddBench(ECP_ShortW_JacExt[Fp[curve], G1], numPoints, useBatching = true, batchIters)
# separator()
# for numPoints in testNumPoints:
# let batchIters = max(1, Iters div numPoints)
# multiAddBench(ECP_ShortW_JacExt[Fp[curve], G1], numPoints, useBatching = false, batchIters)
# separator()
# for numPoints in testNumPoints:
# let batchIters = max(1, Iters div numPoints)
# multiAddBench(ECP_ShortW_JacExt[Fp[curve], G1], numPoints, useBatching = true, batchIters)
separator()
separator()

61 changes: 42 additions & 19 deletions constantine/math/elliptic/ec_shortweierstrass_batch_ops.nim
@@ -371,13 +371,6 @@ func accum_half_vartime[F; G: static Subgroup](
# Batch addition - High-level
# ------------------------------------------------------------

template `+=`[F; G: static Subgroup](P: var ECP_ShortW_JacExt[F, G], Q: ECP_ShortW_Aff[F, G]) =
# All vartime procedures MUST be tagged vartime
# Hence we do not expose `+=` for extended jacobian operation to prevent `vartime` mistakes
# The following algorithms are all tagged vartime, hence for genericity
# we create a local `+=` for this module only
madd_vartime(P, P, Q)

func accumSum_chunk_vartime*[F; G: static Subgroup](
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G] or ECP_ShortW_JacExt[F, G]),
points: ptr UncheckedArray[ECP_ShortW_Aff[F, G]], len: int) {.noInline, tags:[VarTime, Alloca].} =
@@ -398,7 +391,7 @@ func accumSum_chunk_vartime*[F; G: static Subgroup](
while n >= minNumPointsSerial:
if (n and 1) == 1: # odd number of points
## Accumulate the last
r += points[n-1]
r.madd_vartime(r, points[n-1])
n -= 1

# Compute [0, n/2) += [n/2, n)
@@ -409,7 +402,7 @@ func accumSum_chunk_vartime*[F; G: static Subgroup](

# Tail
for i in 0 ..< n:
r += points[i]
r.madd_vartime(r, points[i])
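
With the local `+=` template removed, every mixed addition in the accumulation loops above is spelled out as `madd_vartime`, so the variable-time nature of each call stays visible at the call site. Below is a minimal sketch of the resulting pattern, assuming it is compiled inside this module so that `ECP_ShortW_JacExt`, `ECP_ShortW_Aff`, `setInf`, `madd_vartime` and the `VarTime` effect tag are in scope; the proc name is illustrative, not part of the commit.

# Illustrative sketch only, not part of this commit.
func naiveSum_vartime[F; G: static Subgroup](
       r: var ECP_ShortW_JacExt[F, G],
       points: openArray[ECP_ShortW_Aff[F, G]]) {.tags: [VarTime].} =
  ## Serial reference accumulation with each mixed addition
  ## written explicitly as `madd_vartime`, as this commit now does.
  r.setInf()
  for i in 0 ..< points.len:
    r.madd_vartime(r, points[i])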

func accum_batch_vartime[F; G: static Subgroup](
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G] or ECP_ShortW_JacExt[F, G]),
@@ -472,36 +465,66 @@ func sum_reduce_vartime*[F; G: static Subgroup](
type EcAddAccumulator_vartime*[EC, F; G: static Subgroup; AccumMax: static int] = object
## Elliptic curve addition accumulator
## **Variable-Time**
# The `cur` is dereferenced first so better locality if at the beginning
# The `len` is dereferenced first so better locality if at the beginning
# Do we want alignment guarantees?
cur: uint32
len: uint32
accum: EC
buffer: array[AccumMax, ECP_ShortW_Aff[F, G]]

func init*(ctx: var EcAddAccumulator_vartime) =
static: doAssert EcAddAccumulator_vartime.AccumMax >= 16, "There is no point in a EcAddBatchAccumulator if the batch size is too small"
ctx.accum.setInf()
ctx.cur = 0
ctx.len = 0

func consumeBuffer[EC, F; G: static Subgroup; AccumMax: static int](
ctx: var EcAddAccumulator_vartime[EC, F, G, AccumMax]) {.noInline, tags: [VarTime, Alloca].}=
if ctx.cur == 0:
if ctx.len == 0:
return

ctx.accum.accumSum_chunk_vartime(ctx.buffer.asUnchecked(), ctx.cur)
ctx.cur = 0
ctx.accum.accumSum_chunk_vartime(ctx.buffer.asUnchecked(), ctx.len.int)
ctx.len = 0

func update*[EC, F, G; AccumMax: static int](
ctx: var EcAddAccumulator_vartime[EC, F, G, AccumMax],
P: ECP_ShortW_Aff[F, G]) =

if ctx.cur == AccumMax:
if P.isInf().bool:
return

if ctx.len == AccumMax:
ctx.consumeBuffer()

ctx.buffer[ctx.cur] = P
ctx.cur += 1
ctx.buffer[ctx.len] = P
ctx.len += 1

func handover*(ctx: var EcAddAccumulator_vartime) {.inline.} =
ctx.consumeBuffer()

func merge*[EC, F, G; AccumMax: static int](
ctxDst: var EcAddAccumulator_vartime[EC, F, G, AccumMax],
ctxSrc: EcAddAccumulator_vartime[EC, F, G, AccumMax]) =

var sCur = 0'u32
var itemsLeft = ctxSrc.len

if ctxDst.len + ctxSrc.len >= AccumMax:
# previous partial update, fill the buffer and do a batch addition
let free = AccumMax - ctxDst.len
for i in 0 ..< free:
ctxDst.buffer[ctxDst.len+i] = ctxSrc.buffer[i]
ctxDst.len = AccumMax
ctxDst.consumeBuffer()
sCur = free
itemsLeft -= free

# Store the tail
for i in 0 ..< itemsLeft:
ctxDst.buffer[ctxDst.len+i] = ctxSrc.buffer[sCur+i]

ctxDst.len += itemsLeft

ctxDst.accum.sum_vartime(ctxDst.accum, ctxSrc.accum)

# TODO: `merge` for parallel recursive divide-and-conquer processing

func finish*[EC, F, G; AccumMax: static int](
ctx: var EcAddAccumulator_vartime[EC, F, G, AccumMax],
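
Taken together, the accumulator now exposes init / update / handover / merge / finish: `update` buffers points (skipping infinities) and batch-consumes the buffer once `AccumMax` points have piled up, while `merge` tops up the destination buffer, batch-consumes it, copies the remaining source points, then folds in the source's running sum. The following usage sketch assumes in-module scope; the curve parameters, the `AccumMax` of 32 and the proc name are illustrative only, not part of the commit.

# Illustrative sketch only, not part of this commit.
func sumTwoSets_vartime[F; G: static Subgroup](
       r: var ECP_ShortW_Jac[F, G],
       a, b: openArray[ECP_ShortW_Aff[F, G]]) =
  ## Serial mock-up of how the parallel reduction combines
  ## two per-worker accumulators.
  var accA, accB: EcAddAccumulator_vartime[ECP_ShortW_Jac[F, G], F, G, 32]
  accA.init()
  accB.init()
  for P in a: accA.update(P)  # buffered, batch-added every 32 points
  for P in b: accB.update(P)
  accB.handover()             # flush B's buffer into its running sum
  accA.merge(accB)            # fold B's running sum (and any still-buffered points) into A
  accA.finish(r)              # flush A's buffer and write the total to r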
@@ -30,19 +30,13 @@ proc sum_reduce_vartime_parallelChunks[F; G: static Subgroup](
points: openArray[ECP_ShortW_Aff[F, G]]) {.noInline.} =
## Batch addition of `points` into `r`
## `r` is overwritten
## Compute is parallelized, if beneficial.
## This function can be nested in another parallel function
## Scales better for large number of points

# Chunking constants in ec_shortweierstrass_batch_ops.nim
const maxTempMem = 262144 # 2¹⁸ = 262144
const maxChunkSize = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
const minChunkSize = (maxChunkSize * 60) div 100 # We want 60%~100% full chunks

if points.len <= maxChunkSize:
r.setInf()
r.accumSum_chunk_vartime(points.asUnchecked(), points.len)
return

let chunkDesc = balancedChunksPrioSize(
start = 0, stopEx = points.len,
minChunkSize, maxChunkSize,
@@ -72,48 +66,58 @@ proc sum_reduce_vartime_parallelChunks[F; G: static Subgroup](
partialResultsAffine.batchAffine(partialResults, chunkDesc.numChunks)
r.sum_reduce_vartime(partialResultsAffine, chunkDesc.numChunks)
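
For scale, the chunking above caps scratch memory at 2¹⁸ bytes per chunk. Here is a standalone back-of-the-envelope check, assuming a hypothetical 96-byte affine point (two 48-byte coordinates, e.g. a 381-bit base field); the concrete numbers are illustrative, not part of the commit.

# Illustrative arithmetic only; the 96-byte point size is an assumption.
const
  maxTempMem   = 262144                      # 2^18 bytes of scratch per chunk
  pointSize    = 96                          # affine point: two 48-byte coordinates
  maxChunkSize = maxTempMem div pointSize    # = 2730 points per chunk
  minChunkSize = (maxChunkSize * 60) div 100 # = 1638, keeps chunks 60%..100% full

static:
  doAssert maxChunkSize == 2730
  doAssert minChunkSize == 1638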

proc sum_reduce_vartime_parallelFor[F; G: static Subgroup](
proc sum_reduce_vartime_parallelAccums[F; G: static Subgroup](
tp: Threadpool,
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
points: openArray[ECP_ShortW_Aff[F, G]]) =
## Batch addition of `points` into `r`
## `r` is overwritten
## Compute is parallelized, if beneficial.
## 2x faster for low number of points

mixin globalSum
const maxTempMem = 1 shl 18 # 2¹⁸ = 262144
const maxChunkSize = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
type Acc = EcAddAccumulator_vartime[typeof(r), F, G, maxChunkSize]

const maxTempMem = 262144 # 2¹⁸ = 262144
const maxStride = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
let ps = points.asUnchecked()
let N = points.len

let p = points.asUnchecked
let pointsLen = points.len
mixin globalAcc

tp.parallelFor i in 0 ..< points.len:
stride: maxStride
captures: {p, pointsLen}
reduceInto(globalSum: typeof(r)):
const chunkSize = 32

tp.parallelFor i in 0 ..< N:
stride: chunkSize
captures: {ps, N}
reduceInto(globalAcc: ptr Acc):
prologue:
var localSum {.noInit.}: typeof(r)
localSum.setInf()
var workerAcc = allocHeap(Acc)
workerAcc[].init()
forLoop:
let n = min(maxStride, pointsLen-i)
localSum.accumSum_chunk_vartime(p +% i, n)
merge(remoteSum: Flowvar[typeof(r)]):
localSum.sum_vartime(localSum, sync(remoteSum))
for j in i ..< min(i+chunkSize, N):
workerAcc[].update(ps[j])
merge(remoteAccFut: Flowvar[ptr Acc]):
let remoteAcc = sync(remoteAccFut)
workerAcc[].merge(remoteAcc[])
freeHeap(remoteAcc)
epilogue:
return localSum
workerAcc[].handover()
return workerAcc

r = sync(globalSum)
let ctx = sync(globalAcc)
ctx[].finish(r)
freeHeap(ctx)

proc sum_reduce_vartime_parallel*[F; G: static Subgroup](
tp: Threadpool,
r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
points: openArray[ECP_ShortW_Aff[F, G]]) {.inline.} =
## Batch addition of `points` into `r`
## Parallel Batch addition of `points` into `r`
## `r` is overwritten
## Compute is parallelized, if beneficial.
## This function cannot be nested in another parallel function
when false:
tp.sum_reduce_vartime_parallelFor(r, points)

if points.len < 256:
r.setInf()
r.accumSum_chunk_vartime(points.asUnchecked(), points.len)
elif points.len < 8192:
tp.sum_reduce_vartime_parallelAccums(r, points)
else:
tp.sum_reduce_vartime_parallelChunks(r, points)
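
The public entry point now dispatches on input size: below 256 points it falls back to the serial accumulation, for fewer than 8192 points the new per-worker accumulators are roughly 2x faster (hence the commit title), and beyond that the chunked reduction scales better. The usage sketch below assumes in-module scope (so `Threadpool`, the EC types and `sum_reduce_vartime_parallel` are available); the wrapper name is illustrative, not part of the commit.

# Illustrative sketch only, not part of this commit.
proc sumPoints[F; G: static Subgroup](
       tp: Threadpool,
       points: openArray[ECP_ShortW_Aff[F, G]]): ECP_ShortW_Jac[F, G] =
  ## < 256 points   -> serial accumSum_chunk_vartime
  ## < 8192 points  -> per-worker EcAddAccumulator_vartime reduction
  ## >= 8192 points -> chunked partial sums + batchAffine + final reduce
  tp.sum_reduce_vartime_parallel(result, points)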
