Skip to content

Commit

Permalink
add matacc_acc_kyber opt
Browse files Browse the repository at this point in the history
  • Loading branch information
mkannwischer committed Dec 5, 2024
1 parent 65fcd4b commit 5d8e6fc
Show file tree
Hide file tree
Showing 3 changed files with 380 additions and 186 deletions.
10 changes: 8 additions & 2 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2638,8 +2638,14 @@ def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=Non
super().__init__(infile, name, rename=True, arch=arch, target=target, timeout=timeout, funcname=funcname)

def core(self, slothy):
# TODO: do actual opt
slothy.write_source_to_file(self.outfile_full)
slothy.config.inputs_are_outputs = True
slothy.config.variable_size = True

slothy.config.outputs = ["r9"]
slothy.optimize(start="slothy_start_1", end="slothy_end_1")
slothy.config.outputs = ["r9"]
slothy.optimize(start="slothy_start_2", end="slothy_end_2")



class matacc_asm_opt_16_32_kyber(Example):
Expand Down
190 changes: 98 additions & 92 deletions examples/naive/armv7m/matacc_acc_kyber.s
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@

.extern shake128_squeezeblocks

rptr .req r0
bptr .req r1
cptr .req r2
bufptr .req r3
zetaptr .req r4
val0 .req r5
val1 .req r6
tmp .req r7
tmp2 .req r8
k .req r9
q .req r10
qa .req r11
qinv .req r12
ctr .req r14

// q locates in the bottom half of the register
.macro plant_red_b q, qa, qinv, tmp
mul \tmp, \tmp, \qinv
Expand All @@ -12,53 +27,103 @@
// result in high half
.endm

// res replace poly2
.macro doublebasemul_asm_acc rptr, aptr, bptr, zetaptr, poly0, poly1, res, poly3, q, qa, qinv, tmp, tmp2, zeta
ldr.w \poly0, [\aptr], #4
ldr.w \poly1, [\bptr]
ldr.w \poly3, [\bptr, #4]
ldr.w \res, [\rptr]
ldr.w \zeta, [\zetaptr], #4

//basemul(r->coeffs + 4 * i, a->coeffs + 4 * i, b->coeffs + 4 * i, zetas[64 + i]);
smulwt \tmp, \zeta, \poly1
// b_1*zeta*qinv*plant_const; in low half
smlabb \tmp, \tmp, \q, \qa
// b_1*zeta
smultt \tmp, \poly0, \tmp
//a_1*b_1*zeta <2^32
smlabb \tmp, \poly0, \poly1, \tmp
// a1*b1*zeta+a0*b0
plant_red_b \q, \qa, \qinv, \tmp
// r[0] in upper half of tmp
smuadx \tmp2, \poly0, \poly1
plant_red_b \q, \qa, \qinv, \tmp2
// r[1] in upper half of tmp2
pkhtb \tmp, \tmp2, \tmp, asr#16
uadd16 \res, \res, \tmp
str \res, [\rptr], #4

neg \zeta, \zeta

ldr.w \res, [\rptr]
ldr \poly0, [\aptr], #4
//basemul(r->coeffs + 4 * i + 2, a->coeffs + 4 * i + 2, b->coeffs + 4 * i + 2, - zetas[64 + i]);
smulwt \tmp, \zeta, \poly3
smlabb \tmp, \tmp, \q, \qa
smultt \tmp, \poly0, \tmp
smlabb \tmp, \poly0, \poly3, \tmp
plant_red_b \q, \qa, \qinv, \tmp
// r[0] in upper half of tmp

smuadx \tmp2, \poly0, \poly3
plant_red_b \q, \qa, \qinv, \tmp2
// r[1] in upper half of tmp2
pkhtb \tmp, \tmp2, \tmp, asr#16
uadd16 \res, \res, \tmp
str \res, [\rptr], #4
.endm


// s17: bufptr; s26: state
// Checks if val0 is suitable and multiplies with values from bptr using func
.macro first_if func, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, zetaptr, k, q, qa, qinv, ctr
.macro first_if
// if (val0 < KYBER_Q)
cmp.w \val0, \q
cmp.w val0, q
bhs.w 2f
strh \val0, [\cptr], #2
add \k, #1
cmp.w \k, #4
strh val0, [cptr], #2
add k, #1
cmp.w k, #4
bne.w 2f
sub \cptr, #4*2
vmov s18, \bufptr
vmov s19, \ctr
vmov s20, \val1
\func \rptr, \bptr, \cptr, \zetaptr, \bufptr, \k, \val0, \val1, \q, \qa, \qinv, \tmp, \tmp2, \ctr
vmov \bufptr, s18
vmov \ctr, s19
vmov \val1, s20

add \ctr, #1
slothy_start_1:
sub cptr, #4*2
vmov s18, bufptr
vmov s19, ctr
vmov s20, val1
doublebasemul_asm_acc rptr, bptr, cptr, zetaptr, bufptr, k, val0, val1, q, qa, qinv, tmp, tmp2, ctr
vmov bufptr, s18
vmov ctr, s19
vmov val1, s20

add ctr, #1

movw \k, #0
movw k, #0
slothy_end_1:
2:
.endm

// Checks if val1 is suitable and multiplies with values from bptr using func
.macro second_if func, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, zetaptr, k, q, qa, qinv, ctr
.macro second_if
// if (val1 < KYBER_Q && ctr < KYBER_N/4)
cmp.w \val1, \q
cmp.w val1, q
bhs.w 2f
cmp.w \ctr, #256/4
cmp.w ctr, #256/4
bge.w 2f
strh \val1, [\cptr], #2
add \k, #1
cmp.w \k, #4
strh val1, [cptr], #2
add k, #1
cmp.w k, #4
bne.w 2f
sub \cptr, #4*2
vmov s18, \bufptr
vmov s19, \ctr
\func \rptr, \bptr, \cptr, \zetaptr, \bufptr, \k, \val0, \val1, \q, \qa, \qinv, \tmp, \tmp2, \ctr
vmov \bufptr, s18
vmov \ctr, s19

add \ctr, #1
slothy_start_2:
sub cptr, #4*2
vmov s18, bufptr
vmov s19, ctr
doublebasemul_asm_acc rptr, bptr, cptr, zetaptr, bufptr, k, val0, val1, q, qa, qinv, tmp, tmp2, ctr
vmov bufptr, s18
vmov ctr, s19

add ctr, #1

movw \k, #0
movw k, #0
slothy_end_2:
2:
.endm

Expand Down Expand Up @@ -95,72 +160,13 @@
2:
.endm

// res replace poly2
.macro doublebasemul_asm_acc rptr, aptr, bptr, zetaptr, poly0, poly1, res, poly3, q, qa, qinv, tmp, tmp2, zeta
ldr.w \poly0, [\aptr], #4
ldr.w \poly1, [\bptr]
ldr.w \poly3, [\bptr, #4]
ldr.w \res, [\rptr]
ldr.w \zeta, [\zetaptr], #4

//basemul(r->coeffs + 4 * i, a->coeffs + 4 * i, b->coeffs + 4 * i, zetas[64 + i]);
smulwt \tmp, \zeta, \poly1
// b_1*zeta*qinv*plant_const; in low half
smlabb \tmp, \tmp, \q, \qa
// b_1*zeta
smultt \tmp, \poly0, \tmp
//a_1*b_1*zeta <2^32
smlabb \tmp, \poly0, \poly1, \tmp
// a1*b1*zeta+a0*b0
plant_red_b \q, \qa, \qinv, \tmp
// r[0] in upper half of tmp
smuadx \tmp2, \poly0, \poly1
plant_red_b \q, \qa, \qinv, \tmp2
// r[1] in upper half of tmp2
pkhtb \tmp, \tmp2, \tmp, asr#16
uadd16 \res, \res, \tmp
str \res, [\rptr], #4

neg \zeta, \zeta

ldr.w \res, [\rptr]
ldr \poly0, [\aptr], #4
//basemul(r->coeffs + 4 * i + 2, a->coeffs + 4 * i + 2, b->coeffs + 4 * i + 2, - zetas[64 + i]);
smulwt \tmp, \zeta, \poly3
smlabb \tmp, \tmp, \q, \qa
smultt \tmp, \poly0, \tmp
smlabb \tmp, \poly0, \poly3, \tmp
plant_red_b \q, \qa, \qinv, \tmp
// r[0] in upper half of tmp

smuadx \tmp2, \poly0, \poly3
plant_red_b \q, \qa, \qinv, \tmp2
// r[1] in upper half of tmp2
pkhtb \tmp, \tmp2, \tmp, asr#16
uadd16 \res, \res, \tmp
str \res, [\rptr], #4
.endm

// void matacc_asm(int16_t *r, const int16_t *b, int16_t c[4], unsigned char buf[XOF_BLOCKBYTES+2], const int32_t zetas[64], xof_state *state)
.global matacc_asm_acc
.type matacc_asm_acc, %function
.align 2
matacc_asm_acc:
push {r0-r11, r14}
rptr .req r0
bptr .req r1
cptr .req r2
bufptr .req r3
zetaptr .req r4
val0 .req r5
val1 .req r6
tmp .req r7
tmp2 .req r8
k .req r9
q .req r10
qa .req r11
qinv .req r12
ctr .req r14

ldr.w zetaptr, [sp, #13*4] // load zetaptr from stack
ldr.w tmp, [sp, #14*4] // load state from stack
Expand All @@ -185,9 +191,9 @@ matacc_asm_acc:
ubfx val0, val0, #0, #12
ubfx val1, val1, #0, #12

first_if doublebasemul_asm_acc, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, zetaptr, k, q, qa, qinv, ctr
first_if

second_if doublebasemul_asm_acc, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, zetaptr, k, q, qa, qinv, ctr
second_if

third_if tmp, tmp2, rptr, bptr, cptr, bufptr, ctr

Expand Down
Loading

0 comments on commit 5d8e6fc

Please sign in to comment.