From 0ff2cec8e79a9148fb164be2c49fd016ebfe49ef Mon Sep 17 00:00:00 2001 From: "Matthias J. Kannwischer" Date: Thu, 5 Dec 2024 10:03:02 +0800 Subject: [PATCH] matacc_asm_opt_16_32_kyber --- example.py | 13 +- .../naive/armv7m/matacc_asm_opt_16_32_kyber.s | 155 +++++----- .../matacc_asm_opt_16_32_kyber_opt_m7.s | 267 +++++++++++++----- 3 files changed, 287 insertions(+), 148 deletions(-) diff --git a/example.py b/example.py index 0fefa758..ebe4258a 100644 --- a/example.py +++ b/example.py @@ -2677,8 +2677,17 @@ def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=Non super().__init__(infile, name, rename=True, arch=arch, target=target, timeout=timeout, funcname=funcname) def core(self, slothy): - # TODO: do actual opt - slothy.write_source_to_file(self.outfile_full) + slothy.config.inputs_are_outputs = True + slothy.config.variable_size = True + slothy.config.unsafe_address_offset_fixup = False + + # TODO: r10, r11, r12 shouldn't actually be needed as q,qa,qinv are unused in this code. + slothy.config.reserved_regs = [f"s{i}" for i in range(0, 32)] + ["sp", "r13"] + ["r10", "r11", "r12"] + + slothy.config.outputs = ["r9"] + slothy.optimize(start="slothy_start_1", end="slothy_end_1") + slothy.config.outputs = ["r9"] + slothy.optimize(start="slothy_start_2", end="slothy_end_2") class matacc_asm_opt_32_32_kyber(Example): def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=None): diff --git a/examples/naive/armv7m/matacc_asm_opt_16_32_kyber.s b/examples/naive/armv7m/matacc_asm_opt_16_32_kyber.s index d9c59439..a089309d 100644 --- a/examples/naive/armv7m/matacc_asm_opt_16_32_kyber.s +++ b/examples/naive/armv7m/matacc_asm_opt_16_32_kyber.s @@ -4,6 +4,21 @@ .extern shake128_squeezeblocks + rptr .req r0 + bptr .req r1 + cptr .req r2 + bufptr .req r3 + zetaptr .req r4 + val0 .req r5 + val1 .req r6 + tmp .req r7 + tmp2 .req r8 + k .req r9 + q .req r10 + qa .req r11 + qinv .req r12 + ctr .req r14 + // q locates in the bottom half of the register .macro plant_red_b q, qa, qinv, tmp mul \tmp, \tmp, \qinv @@ -12,63 +27,6 @@ // result in high half .endm -// Checks if val0 is suitable and multiplies with values from bptr using func -.macro first_if func, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, zetaptr, k, q, qa, qinv, ctr - // if (val0 < KYBER_Q) - cmp.w \val0, \q - bhs.w 2f - strh \val0, [\cptr], #2 - add \k, #1 - cmp.w \k, #4 - bne.w 2f - sub \cptr, #4*2 - vmov s18, \bufptr - vmov s19, \ctr - vmov s20, \val1 - \func \rptr, \bptr, \cptr, \zetaptr, \bufptr, \k, \val0, \val1, \q, \qa, \qinv, \tmp, \tmp2, \ctr - vmov \bufptr, s18 - vmov \ctr, s19 - vmov \val1, s20 - - add \ctr, #1 - - movw \k, #0 - 2: -.endm - -// Checks if val1 is suitable and multiplies with values from bptr using func -.macro second_if func, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, zetaptr, k, q, qa, qinv, ctr -// if (val1 < KYBER_Q && ctr < KYBER_N/4) - cmp.w \val1, \q - bhs.w 2f - cmp.w \ctr, #256/4 - bge.w 2f - strh \val1, [\cptr], #2 - add \k, #1 - cmp.w \k, #4 - bne.w 2f - sub \cptr, #4*2 - vmov s18, \bufptr - vmov s19, \ctr - \func \rptr, \bptr, \cptr, \zetaptr, \bufptr, \k, \val0, \val1, \q, \qa, \qinv, \tmp, \tmp2, \ctr - vmov \bufptr, s18 - vmov \ctr, s19 - - add \ctr, #1 - - movw \k, #0 - 2: -.endm - -.macro load_vals val0, val1, bufptr, tmp - ldrh \val0, [\bufptr], #2 - ldrb \val1, [\bufptr], #1 - ubfx \tmp, \val0, #12, #4 - orr \val1, \tmp, \val1, lsl #4 - ubfx \val0, \val0, #0, #12 - ubfx \val1, \val1, #0, #12 -.endm - .macro doublebasemul_asm_opt_16_32 rptr_tmp, aptr, bptr, tmp3, poly0, poly1, poly2, poly3, q, qa, qinv, tmp, aprimeptr, tmp2 vmov \aprimeptr, s27 ldr \poly0, [\aptr], #4 @@ -98,6 +56,69 @@ vmov s27, \aprimeptr .endm + +// Checks if val0 is suitable and multiplies with values from bptr using func +.macro first_if + // if (val0 < KYBER_Q) + cmp.w val0, q + bhs.w 2f + strh val0, [cptr], #2 + add k, #1 + cmp.w k, #4 + bne.w 2f + slothy_start_1: + sub cptr, #4*2 + vmov s18, bufptr + vmov s19, ctr + vmov s20, val1 + doublebasemul_asm_opt_16_32 rptr, bptr, cptr, zetaptr, bufptr, k, val0, val1, q, qa, qinv, tmp, tmp2, ctr + vmov bufptr, s18 + vmov ctr, s19 + vmov val1, s20 + + add ctr, #1 + + movw k, #0 + slothy_end_1: + 2: +.endm + +// Checks if val1 is suitable and multiplies with values from bptr using func +.macro second_if +// if (val1 < KYBER_Q && ctr < KYBER_N/4) + cmp.w val1, q + bhs.w 2f + cmp.w ctr, #256/4 + bge.w 2f + strh val1, [cptr], #2 + add k, #1 + cmp.w k, #4 + bne.w 2f + slothy_start_2: + sub cptr, #4*2 + vmov s18, bufptr + vmov s19, ctr + doublebasemul_asm_opt_16_32 rptr, bptr, cptr, zetaptr, bufptr, k, val0, val1, q, qa, qinv, tmp, tmp2, ctr + vmov bufptr, s18 + vmov ctr, s19 + + add ctr, #1 + + movw k, #0 + slothy_end_2: + 2: +.endm + +.macro load_vals val0, val1, bufptr, tmp + ldrh \val0, [\bufptr], #2 + ldrb \val1, [\bufptr], #1 + ubfx \tmp, \val0, #12, #4 + orr \val1, \tmp, \val1, lsl #4 + ubfx \val0, \val0, #0, #12 + ubfx \val1, \val1, #0, #12 +.endm + + // shake128_squeezeblocks into buffer if all bytes have been used .macro third_if tmp, tmp2, rptr, bptr, cptr, bufptr, ctr // if (pos + 3 > buflen && ctr < KYBER_N/4) @@ -137,20 +158,7 @@ .align 2 matacc_asm_opt_16_32: push {r0-r11, r14} - rptr .req r0 - bptr .req r1 - cptr .req r2 - bufptr .req r3 - tmp3 .req r4 - val0 .req r5 - val1 .req r6 - tmp .req r7 - tmp2 .req r8 - k .req r9 - q .req r10 - qa .req r11 - qinv .req r12 - ctr .req r14 + vpush {s16-s31} movw qa, #26632 movw q, #3329 @@ -172,13 +180,14 @@ matacc_asm_opt_16_32: load_vals val0, val1, bufptr, tmp - first_if doublebasemul_asm_opt_16_32, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, tmp3, k, q, qa, qinv, ctr + first_if - second_if doublebasemul_asm_opt_16_32, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, tmp3, k, q, qa, qinv, ctr + second_if third_if tmp, tmp2, rptr, bptr, cptr, bufptr, ctr cmp ctr, #256/4 blt.w 1b + vpop {s16-s31} pop {r0-r11, pc} \ No newline at end of file diff --git a/examples/opt/armv7m/matacc_asm_opt_16_32_kyber_opt_m7.s b/examples/opt/armv7m/matacc_asm_opt_16_32_kyber_opt_m7.s index 50610c3d..087200b3 100644 --- a/examples/opt/armv7m/matacc_asm_opt_16_32_kyber_opt_m7.s +++ b/examples/opt/armv7m/matacc_asm_opt_16_32_kyber_opt_m7.s @@ -4,6 +4,21 @@ .extern shake128_squeezeblocks + rptr .req r0 + bptr .req r1 + cptr .req r2 + bufptr .req r3 + zetaptr .req r4 + val0 .req r5 + val1 .req r6 + tmp .req r7 + tmp2 .req r8 + k .req r9 + q .req r10 + qa .req r11 + qinv .req r12 + ctr .req r14 + // q locates in the bottom half of the register .macro plant_red_b q, qa, qinv, tmp mul \tmp, \tmp, \qinv @@ -12,63 +27,6 @@ // result in high half .endm -// Checks if val0 is suitable and multiplies with values from bptr using func -.macro first_if func, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, zetaptr, k, q, qa, qinv, ctr - // if (val0 < KYBER_Q) - cmp.w \val0, \q - bhs.w 2f - strh \val0, [\cptr], #2 - add \k, #1 - cmp.w \k, #4 - bne.w 2f - sub \cptr, #4*2 - vmov s18, \bufptr - vmov s19, \ctr - vmov s20, \val1 - \func \rptr, \bptr, \cptr, \zetaptr, \bufptr, \k, \val0, \val1, \q, \qa, \qinv, \tmp, \tmp2, \ctr - vmov \bufptr, s18 - vmov \ctr, s19 - vmov \val1, s20 - - add \ctr, #1 - - movw \k, #0 - 2: -.endm - -// Checks if val1 is suitable and multiplies with values from bptr using func -.macro second_if func, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, zetaptr, k, q, qa, qinv, ctr -// if (val1 < KYBER_Q && ctr < KYBER_N/4) - cmp.w \val1, \q - bhs.w 2f - cmp.w \ctr, #256/4 - bge.w 2f - strh \val1, [\cptr], #2 - add \k, #1 - cmp.w \k, #4 - bne.w 2f - sub \cptr, #4*2 - vmov s18, \bufptr - vmov s19, \ctr - \func \rptr, \bptr, \cptr, \zetaptr, \bufptr, \k, \val0, \val1, \q, \qa, \qinv, \tmp, \tmp2, \ctr - vmov \bufptr, s18 - vmov \ctr, s19 - - add \ctr, #1 - - movw \k, #0 - 2: -.endm - -.macro load_vals val0, val1, bufptr, tmp - ldrh \val0, [\bufptr], #2 - ldrb \val1, [\bufptr], #1 - ubfx \tmp, \val0, #12, #4 - orr \val1, \tmp, \val1, lsl #4 - ubfx \val0, \val0, #0, #12 - ubfx \val1, \val1, #0, #12 -.endm - .macro doublebasemul_asm_opt_16_32 rptr_tmp, aptr, bptr, tmp3, poly0, poly1, poly2, poly3, q, qa, qinv, tmp, aprimeptr, tmp2 vmov \aprimeptr, s27 ldr \poly0, [\aptr], #4 @@ -98,6 +56,181 @@ vmov s27, \aprimeptr .endm + +// Checks if val0 is suitable and multiplies with values from bptr using func +.macro first_if + // if (val0 < KYBER_Q) + cmp.w val0, q + bhs.w 2f + strh val0, [cptr], #2 + add k, #1 + cmp.w k, #4 + bne.w 2f + slothy_start_1: + // Instructions: 25 + // Expected cycles: 13 + // Expected IPC: 1.92 + // + // Cycle bound: 13.0 + // IPC bound: 1.92 + // + // Wall time: 0.15s + // User time: 0.15s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + vmov s19, r14 // *............................. + sub r2, #4*2 // *............................. + ldr r4, [r2] // .*............................ + vmov r9, s27 // .*............................ + ldr.w r8, [r9], #4 // ..*........................... + vmov s20, r6 // ..*........................... + ldr r7, [r1], #4 // ...*.......................... + vmov r14, s19 // ...*.......................... + ldr.w r5, [r2, #4] // ....*......................... + smuad r6, r8, r4 // ....*......................... + ldr r8, [r9], #4 // .....*........................ + smuadx r4, r7, r4 // .....*........................ + str r6, [r0], #4 // ......*....................... + ldr r7, [r1], #4 // ......*....................... + vmov s18, r3 // .......*...................... + smuad r3, r8, r5 // .......*...................... + str r4, [r0], #4 // ........*..................... + vmov r6, s20 // ........*..................... + smuadx r8, r7, r5 // .........*.................... + str.w r3, [r0], #4 // ..........*................... + vmov s27, r9 // ..........*................... + str.w r8, [r0], #4 // ...........*.................. + vmov r3, s18 // ...........*.................. + add r14, #1 // ............*................. + movw r9, #0 // ............*................. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // sub r2, #4*2 // *.............................. + // vmov s18, r3 // .......*....................... + // vmov s19, r14 // *.............................. + // vmov s20, r6 // ..*............................ + // vmov r8, s27 // .*............................. + // ldr r3, [r1], #4 // ...*........................... + // ldr r9, [r2] // .*............................. + // ldr r5, [r1], #4 // ......*........................ + // ldr.w r6, [r2, #4] // ....*.......................... + // ldr.w r14, [r8], #4 // ..*............................ + // smuad r7, r14, r9 // ....*.......................... + // smuadx r4, r3, r9 // .....*......................... + // str r7, [r0], #4 // ......*........................ + // str r4, [r0], #4 // ........*...................... + // ldr r7, [r8], #4 // .....*......................... + // smuad r14, r7, r6 // .......*....................... + // smuadx r4, r5, r6 // .........*..................... + // str.w r14, [r0], #4 // ..........*.................... + // str.w r4, [r0], #4 // ...........*................... + // vmov s27, r8 // ..........*.................... + // vmov r3, s18 // ...........*................... + // vmov r14, s19 // ...*........................... + // vmov r6, s20 // ........*...................... + // add r14, #1 // ............*.................. + // movw r9, #0 // ............*.................. + + slothy_end_1: + + 2: +.endm + +// Checks if val1 is suitable and multiplies with values from bptr using func +.macro second_if +// if (val1 < KYBER_Q && ctr < KYBER_N/4) + cmp.w val1, q + bhs.w 2f + cmp.w ctr, #256/4 + bge.w 2f + strh val1, [cptr], #2 + add k, #1 + cmp.w k, #4 + bne.w 2f + slothy_start_2: + // Instructions: 23 + // Expected cycles: 12 + // Expected IPC: 1.92 + // + // Cycle bound: 12.0 + // IPC bound: 1.92 + // + // Wall time: 0.12s + // User time: 0.12s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + sub r2, #4*2 // *............................. + ldr r8, [r1], #4 // *............................. + ldr r7, [r2] // .*............................ + vmov r9, s27 // .*............................ + ldr.w r6, [r9], #4 // ..*........................... + vmov s18, r3 // ..*........................... + vmov s19, r14 // ...*.......................... + smuadx r8, r8, r7 // ...*.......................... + ldr r14, [r9], #4 // ....*......................... + smuad r7, r6, r7 // ....*......................... + ldr.w r5, [r2, #4] // .....*........................ + str r7, [r0], #4 // .....*........................ + ldr r4, [r1], #4 // ......*....................... + str r8, [r0], #4 // ......*....................... + vmov r3, s18 // .......*...................... + smuad r6, r14, r5 // .......*...................... + smuadx r14, r4, r5 // ........*..................... + str.w r6, [r0], #4 // .........*.................... + vmov s27, r9 // .........*.................... + str.w r14, [r0], #4 // ..........*................... + vmov r14, s19 // ..........*................... + add r14, #1 // ...........*.................. + movw r9, #0 // ...........*.................. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // sub r2, #4*2 // *.............................. + // vmov s18, r3 // ..*............................ + // vmov s19, r14 // ...*........................... + // vmov r8, s27 // .*............................. + // ldr r3, [r1], #4 // *.............................. + // ldr r9, [r2] // .*............................. + // ldr r5, [r1], #4 // ......*........................ + // ldr.w r6, [r2, #4] // .....*......................... + // ldr.w r14, [r8], #4 // ..*............................ + // smuad r7, r14, r9 // ....*.......................... + // smuadx r4, r3, r9 // ...*........................... + // str r7, [r0], #4 // .....*......................... + // str r4, [r0], #4 // ......*........................ + // ldr r7, [r8], #4 // ....*.......................... + // smuad r14, r7, r6 // .......*....................... + // smuadx r4, r5, r6 // ........*...................... + // str.w r14, [r0], #4 // .........*..................... + // str.w r4, [r0], #4 // ..........*.................... + // vmov s27, r8 // .........*..................... + // vmov r3, s18 // .......*....................... + // vmov r14, s19 // ..........*.................... + // add r14, #1 // ...........*................... + // movw r9, #0 // ...........*................... + + slothy_end_2: + + 2: +.endm + +.macro load_vals val0, val1, bufptr, tmp + ldrh \val0, [\bufptr], #2 + ldrb \val1, [\bufptr], #1 + ubfx \tmp, \val0, #12, #4 + orr \val1, \tmp, \val1, lsl #4 + ubfx \val0, \val0, #0, #12 + ubfx \val1, \val1, #0, #12 +.endm + + // shake128_squeezeblocks into buffer if all bytes have been used .macro third_if tmp, tmp2, rptr, bptr, cptr, bufptr, ctr // if (pos + 3 > buflen && ctr < KYBER_N/4) @@ -137,20 +270,7 @@ .align 2 matacc_asm_opt_16_32_opt_m7: push {r0-r11, r14} - rptr .req r0 - bptr .req r1 - cptr .req r2 - bufptr .req r3 - tmp3 .req r4 - val0 .req r5 - val1 .req r6 - tmp .req r7 - tmp2 .req r8 - k .req r9 - q .req r10 - qa .req r11 - qinv .req r12 - ctr .req r14 + vpush {s16-s31} movw qa, #26632 movw q, #3329 @@ -172,13 +292,14 @@ matacc_asm_opt_16_32_opt_m7: load_vals val0, val1, bufptr, tmp - first_if doublebasemul_asm_opt_16_32, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, tmp3, k, q, qa, qinv, ctr + first_if - second_if doublebasemul_asm_opt_16_32, tmp, tmp2, val0, val1, rptr, bptr, cptr, bufptr, tmp3, k, q, qa, qinv, ctr + second_if third_if tmp, tmp2, rptr, bptr, cptr, bufptr, ctr cmp ctr, #256/4 blt.w 1b + vpop {s16-s31} pop {r0-r11, pc} \ No newline at end of file