From ce534fa10ca7e6bcef2854380b53f8e18070d56a Mon Sep 17 00:00:00 2001 From: Amin Abdulrahman Date: Mon, 7 Oct 2024 13:23:07 +0200 Subject: [PATCH] Add dilithium 257 NTT, iNTT, basemul * Includes addition to the parser that allows register ranges (e.g., for vldm) --- example.py | 85 +++- .../armv7m/basemul_257_asymmetric_dilithium.s | 46 +++ examples/naive/armv7m/basemul_257_dilithium.s | 61 +++ examples/naive/armv7m/fnt_257_dilithium.s | 301 ++++++++++++++ examples/naive/armv7m/ifnt_257_dilithium.s | 372 ++++++++++++++++++ slothy/targets/arm_v7m/arch_v7m.py | 97 +++++ 6 files changed, 961 insertions(+), 1 deletion(-) create mode 100644 examples/naive/armv7m/basemul_257_asymmetric_dilithium.s create mode 100644 examples/naive/armv7m/basemul_257_dilithium.s create mode 100644 examples/naive/armv7m/fnt_257_dilithium.s create mode 100644 examples/naive/armv7m/ifnt_257_dilithium.s diff --git a/example.py b/example.py index 191cc4ee..b0d96b2b 100644 --- a/example.py +++ b/example.py @@ -1547,6 +1547,85 @@ def core(self, slothy): slothy.optimize(start="pointwise_montgomery_acc_start", end="pointwise_montgomery_acc_end") +class fnt_257_dilithium(Example): + def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=None): + name = "fnt_257_dilithium" + infile = name + funcname = "__asm_fnt_257" + + if var != "": + name += f"_{var}" + infile += f"_{var}" + name += f"_{target_label_dict[target]}" + + super().__init__(infile, name, rename=True, arch=arch, target=target, timeout=timeout, funcname=funcname) + + def core(self, slothy): + slothy.config.outputs = ["r14", "r12"] + slothy.config.inputs_are_outputs = True + + slothy.optimize(start="_fnt_0_1_2_start", end="_fnt_0_1_2_end") + slothy.optimize(start="_fnt_3_4_5_6_start", end="_fnt_3_4_5_6_end") + slothy.optimize(start="_fnt_to_16_bit_start", end="_fnt_to_16_bit_end") + +class ifnt_257_dilithium(Example): + def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=None): + name = "ifnt_257_dilithium" + infile = name + funcname = "__asm_ifnt_257" + + if var != "": + name += f"_{var}" + infile += f"_{var}" + name += f"_{target_label_dict[target]}" + + super().__init__(infile, name, rename=True, arch=arch, target=target, timeout=timeout, funcname=funcname) + + def core(self, slothy): + slothy.config.outputs = ["r14", "s1", "r12"] + slothy.config.inputs_are_outputs = True + + slothy.optimize(start="_ifnt_7_6_5_4_start", end="_ifnt_7_6_5_4_end") + slothy.optimize(start="_ifnt_0_1_2_start", end="_ifnt_0_1_2_end") + +class basemul_257_dilithium(Example): + def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=None): + name = "basemul_257_dilithium" + infile = name + funcname = "__asm_point_mul_257_16" + + if var != "": + name += f"_{var}" + infile += f"_{var}" + name += f"_{target_label_dict[target]}" + + super().__init__(infile, name, rename=True, arch=arch, target=target, timeout=timeout, funcname=funcname) + + def core(self, slothy): + slothy.config.outputs = ["r14", "r12"] + slothy.config.inputs_are_outputs = True + + slothy.optimize(start="_point_mul_16_loop_start", end="_point_mul_16_loop_end") + +class basemul_257_asymmetric_dilithium(Example): + def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7, timeout=None): + name = "basemul_257_asymmetric_dilithium" + infile = name + funcname = "__asm_asymmetric_mul_257_16" + + if var != "": + name += f"_{var}" + infile += f"_{var}" + name += f"_{target_label_dict[target]}" + + super().__init__(infile, name, rename=True, arch=arch, target=target, timeout=timeout, funcname=funcname) + + def core(self, slothy): + slothy.config.outputs = ["r14", "r12"] + slothy.config.inputs_are_outputs = True + + slothy.optimize(start="_asymmetric_mul_16_loop_start", end="_asymmetric_mul_16_loop_end") + def main(): examples = [ ExampleDilithium(), ExampleKyber(), @@ -1697,7 +1776,11 @@ def main(): intt_dilithium_123_456_78(), pointwise_montgomery_dilithium(), - pointwise_acc_montgomery_dilithium() + pointwise_acc_montgomery_dilithium(), + fnt_257_dilithium(), + ifnt_257_dilithium(), + basemul_257_dilithium(), + basemul_257_asymmetric_dilithium() ] all_example_names = [e.name for e in examples] diff --git a/examples/naive/armv7m/basemul_257_asymmetric_dilithium.s b/examples/naive/armv7m/basemul_257_asymmetric_dilithium.s new file mode 100644 index 00000000..0bc5ba08 --- /dev/null +++ b/examples/naive/armv7m/basemul_257_asymmetric_dilithium.s @@ -0,0 +1,46 @@ +// 2 +.macro barrett_32 a, Qbar, Q, tmp + smmulr.w \tmp, \a, \Qbar + mls.w \a, \tmp, \Q, \a +.endm + +.syntax unified +.cpu cortex-m4 + +.align 2 +.global __asm_asymmetric_mul_257_16 +.type __asm_asymmetric_mul_257_16, %function +__asm_asymmetric_mul_257_16: + push.w {r4-r11, lr} + + .equ width, 4 + + add.w r12, r0, #256*width + _asymmetric_mul_16_loop: + _asymmetric_mul_16_loop_start: + + ldr.w r7, [r1, #width] + ldr.w r4, [r1], #2*width + ldr.w r8, [r2, #width] + ldr.w r5, [r2], #2*width + ldr.w r9, [r3, #width] + ldr.w r6, [r3], #2*width + + smuad r10, r4, r6 + smuadx r11, r4, r5 + + str.w r11, [r0, #width] + str.w r10, [r0], #2*width + + smuad r10, r7, r9 + smuadx r11, r7, r8 + + str.w r11, [r0, #width] + str.w r10, [r0], #2*width + + _asymmetric_mul_16_loop_end: + + cmp.w r0, r12 + bne.w _asymmetric_mul_16_loop + + pop.w {r4-r11, pc} \ No newline at end of file diff --git a/examples/naive/armv7m/basemul_257_dilithium.s b/examples/naive/armv7m/basemul_257_dilithium.s new file mode 100644 index 00000000..f6821fbb --- /dev/null +++ b/examples/naive/armv7m/basemul_257_dilithium.s @@ -0,0 +1,61 @@ +// 2 +.macro barrett_32 a, Qbar, Q, tmp + smmulr.w \tmp, \a, \Qbar + mls.w \a, \tmp, \Q, \a +.endm + +.syntax unified +.cpu cortex-m4 + +.align 2 +.global __asm_point_mul_257_16 +.type __asm_point_mul_257_16, %function +__asm_point_mul_257_16: + push.w {r4-r11, lr} + + ldr.w r14, [sp, #36] + + .equ width, 4 + + add.w r12, r14, #64*width + _point_mul_16_loop: + _point_mul_16_loop_start: + + ldr.w r7, [r1, #2*width] + ldr.w r8, [r1, #3*width] + ldr.w r9, [r14, #1*width] + ldr.w r5, [r1, #1*width] + ldr.w r4, [r1], #4*width + ldr.w r6, [r14], #2*width + + smultb r10, r4, r6 + barrett_32 r10, r2, r3, r11 + pkhbt r4, r4, r10, lsl #16 + + neg.w r6, r6 + + smultb r10, r5, r6 + barrett_32 r10, r2, r3, r11 + pkhbt r5, r5, r10, lsl #16 + + str.w r5, [r0, #1*width] + str.w r4, [r0], #2*width + + smultb r10, r7, r9 + barrett_32 r10, r2, r3, r11 + pkhbt r7, r7, r10, lsl #16 + + neg.w r9, r9 + + smultb r10, r8, r9 + barrett_32 r10, r2, r3, r11 + pkhbt r8, r8, r10, lsl #16 + + str.w r8, [r0, #1*width] + str.w r7, [r0], #2*width + + _point_mul_16_loop_end: + cmp.w r14, r12 + bne.w _point_mul_16_loop + + pop.w {r4-r11, pc} diff --git a/examples/naive/armv7m/fnt_257_dilithium.s b/examples/naive/armv7m/fnt_257_dilithium.s new file mode 100644 index 00000000..31e4ea8a --- /dev/null +++ b/examples/naive/armv7m/fnt_257_dilithium.s @@ -0,0 +1,301 @@ +// 4 +.macro ldrstr4 ldrstr, target, c0, c1, c2, c3, mem0, mem1, mem2, mem3 + \ldrstr \c0, [\target, \mem0] + \ldrstr \c1, [\target, \mem1] + \ldrstr \c2, [\target, \mem2] + \ldrstr \c3, [\target, \mem3] +.endm + +// 4 +.macro ldrstr4jump ldrstr, target, c0, c1, c2, c3, mem1, mem2, mem3, jump + \ldrstr \c1, [\target, \mem1] + \ldrstr \c2, [\target, \mem2] + \ldrstr \c3, [\target, \mem3] + \ldrstr \c0, [\target], \jump +.endm + +// 8 +.macro ldrstrvec ldrstr, target, c0, c1, c2, c3, c4, c5, c6, c7, mem0, mem1, mem2, mem3, mem4, mem5, mem6, mem7 + ldrstr4 \ldrstr, \target, \c0, \c1, \c2, \c3, \mem0, \mem1, \mem2, \mem3 + ldrstr4 \ldrstr, \target, \c4, \c5, \c6, \c7, \mem4, \mem5, \mem6, \mem7 +.endm + +// 8 +.macro ldrstrvecjump ldrstr, target, c0, c1, c2, c3, c4, c5, c6, c7, mem1, mem2, mem3, mem4, mem5, mem6, mem7, jump + ldrstr4 \ldrstr, \target, \c4, \c5, \c6, \c7, \mem4, \mem5, \mem6, \mem7 + ldrstr4jump \ldrstr, \target, \c0, \c1, \c2, \c3, \mem1, \mem2, \mem3, \jump +.endm + +// 2 +.macro barrett_32 a, Qbar, Q, tmp + smmulr.w \tmp, \a, \Qbar + mls.w \a, \tmp, \Q, \a +.endm + +.macro FNT_CT_butterfly c0, c1, logW + add.w \c0, \c0, \c1, lsl #\logW + sub.w \c1, \c0, \c1, lsl #(\logW+1) +.endm + +// 46 +.macro _3_layer_CT_32_FNT c0, c1, c2, c3, c4, c5, c6, c7, xi0, xi1, xi2, xi3, xi4, xi5, xi6, twiddle, Qprime, Q, tmp, tmp2 + vmov.w \twiddle, \xi0 + + // c0, c1, c2, c3, c4, c5, c6, c7, c8 + // 0,4 + mla \tmp, \c4, \twiddle, \c0 + mls \c4, \c4, \twiddle, \c0 + + // 1,5 + mla \c0, \c5, \twiddle, \c1 + mls \c5, \c5, \twiddle, \c1 + + // 2,6 + mla \c1, \c6, \twiddle, \c2 + mls \c6, \c6, \twiddle, \c2 + + // 3,7 + mla \c2, \c7, \twiddle, \c3 + mls \c7, \c7, \twiddle, \c3 + + // tmp, c0, c1, c2, c4, c5, c6, c7 + + barrett_32 \tmp, \Qprime, \Q, \c3 + barrett_32 \c0, \Qprime, \Q, \c3 + barrett_32 \c1, \Qprime, \Q, \c3 + barrett_32 \c2, \Qprime, \Q, \c3 + barrett_32 \c4, \Qprime, \Q, \c3 + barrett_32 \c5, \Qprime, \Q, \c3 + barrett_32 \c6, \Qprime, \Q, \c3 + barrett_32 \c7, \Qprime, \Q, \c3 + + vmov.w \twiddle, \xi1 + // 0,2 + mla \tmp2, \c1, \twiddle, \tmp + mls \c3, \c1, \twiddle, \tmp + + // 1,3 + mla \tmp, \c2, \twiddle, \c0 + mls \c0, \c2, \twiddle, \c0 + + vmov.w \twiddle, \xi2 + + // 4,6 + mla \c2, \c6, \twiddle, \c4 + mls \c1, \c6, \twiddle, \c4 + + // 5,7 + mla \c6, \c7, \twiddle, \c5 + mls \c7, \c7, \twiddle, \c5 + + // tmp2, tmp, c3, c0 | c2, c6, c1, c7 + + // 4,5 + vmov.w \twiddle, \xi5 + mla \c4, \c6, \twiddle, \c2 + mls \c5, \c6, \twiddle, \c2 + + // 6,7 + vmov.w \twiddle, \xi6 + mla \c6, \c7, \twiddle, \c1 + mls \c7, \c7, \twiddle, \c1 + + // 2,3 + vmov.w \twiddle, \xi4 + mla \c2, \c0, \twiddle, \c3 + mls \c3, \c0, \twiddle, \c3 + + // 0,1 + vmov.w \twiddle, \xi3 + mla \c0, \tmp, \twiddle, \tmp2 + mls \c1, \tmp, \twiddle, \tmp2 +.endm + +.macro final_butterfly c0, c1f, twiddle, c0out, c1, qprime, q, tmp + vmov.w \c1, \c1f + vmov.w \tmp, \twiddle + + mla \c0out, \c1, \tmp, \c0 + mls \c1, \c1, \tmp, \c0 + + barrett_32 \c0out, \qprime, \q, \tmp + barrett_32 \c1, \qprime, \q, \tmp +.endm + + +.syntax unified +.cpu cortex-m4 + +.align 2 +.global __asm_fnt_257 +.type __asm_fnt_257, %function +__asm_fnt_257: + push.w {r4-r11, lr} + vpush.w {s16-s27} + + vmov.w s27, r1 + + .equ width, 4 + + add.w r12, r0, #32*width + _fnt_0_1_2: + _fnt_0_1_2_start: + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(32*0*width), #(32*1*width), #(32*2*width), #(32*3*width), #(32*4*width), #(32*5*width), #(32*6*width), #(32*7*width) + + FNT_CT_butterfly r4, r8, 4 + FNT_CT_butterfly r5, r9, 4 + FNT_CT_butterfly r6, r10, 4 + FNT_CT_butterfly r7, r11, 4 + + FNT_CT_butterfly r4, r6, 2 + FNT_CT_butterfly r5, r7, 2 + FNT_CT_butterfly r8, r10, 6 + FNT_CT_butterfly r9, r11, 6 + + FNT_CT_butterfly r4, r5, 1 + FNT_CT_butterfly r6, r7, 5 + FNT_CT_butterfly r8, r9, 3 + FNT_CT_butterfly r10, r11, 7 + + ldrstrvecjump str.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(32*1*width), #(32*2*width), #(32*3*width), #(32*4*width), #(32*5*width), #(32*6*width), #(32*7*width), #width + _fnt_0_1_2_end: + cmp.w r0, r12 + bne.w _fnt_0_1_2 + + sub.w r0, r0, #32*width + + add.w r12, r0, #256*width + vmov.w s25, r12 + _fnt_3_4_5_6: + _fnt_3_4_5_6_start: + vmov r1, s27 + vldm.w r1!, {s2-s16} + vmov s27, r1 + + // rep 1 + + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(4*0*width+2*width), #(4*1*width+2*width), #(4*2*width+2*width), #(4*3*width+2*width), #(4*4*width+2*width), #(4*5*width+2*width), #(4*6*width+2*width), #(4*7*width+2*width) + + _3_layer_CT_32_FNT r4, r5, r6, r7, r8, r9, r10, r11, s2, s3, s4, s5, s6, s7, s8, r14, r2, r3, r1, r12 + + vmov.w s17, s18, r4, r5 // a1, a3 + vmov.w s19, s20, r6, r7 // a5, a7 + vmov.w s21, s22, r8, r9 // a9, a11 + vmov.w s23, s24, r10, r11 // a13, a15 + + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(4*0*width), #(4*1*width), #(4*2*width), #(4*3*width), #(4*4*width), #(4*5*width), #(4*6*width), #(4*7*width) + + _3_layer_CT_32_FNT r4, r5, r6, r7, r8, r9, r10, r11, s2, s3, s4, s5, s6, s7, s8, r14, r2, r3, r1, r12 + + final_butterfly r5, s18, s10, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*1*width+2*width)] + str.w r1, [r0, #(4*1*width)] + + final_butterfly r6, s19, s11, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*2*width+2*width)] + str.w r1, [r0, #(4*2*width)] + + final_butterfly r7, s20, s12, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*3*width+2*width)] + str.w r1, [r0, #(4*3*width)] + + final_butterfly r8, s21, s13, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*4*width+2*width)] + str.w r1, [r0, #(4*4*width)] + + final_butterfly r9, s22, s14, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*5*width+2*width)] + str.w r1, [r0, #(4*5*width)] + + final_butterfly r10, s23, s15, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*6*width+2*width)] + str.w r1, [r0, #(4*6*width)] + + final_butterfly r11, s24, s16, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*7*width+2*width)] + str.w r1, [r0, #(4*7*width)] + + final_butterfly r4, s17, s9, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*0*width+2*width)] + str.w r1, [r0], #width + + // rep 2 + + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(4*0*width+2*width), #(4*1*width+2*width), #(4*2*width+2*width), #(4*3*width+2*width), #(4*4*width+2*width), #(4*5*width+2*width), #(4*6*width+2*width), #(4*7*width+2*width) + + _3_layer_CT_32_FNT r4, r5, r6, r7, r8, r9, r10, r11, s2, s3, s4, s5, s6, s7, s8, r14, r2, r3, r1, r12 + + vmov.w s17, s18, r4, r5 // a1, a3 + vmov.w s19, s20, r6, r7 // a5, a7 + vmov.w s21, s22, r8, r9 // a9, a11 + vmov.w s23, s24, r10, r11 // a13, a15 + + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(4*0*width), #(4*1*width), #(4*2*width), #(4*3*width), #(4*4*width), #(4*5*width), #(4*6*width), #(4*7*width) + + _3_layer_CT_32_FNT r4, r5, r6, r7, r8, r9, r10, r11, s2, s3, s4, s5, s6, s7, s8, r14, r2, r3, r1, r12 + + final_butterfly r5, s18, s10, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*1*width+2*width)] + str.w r1, [r0, #(4*1*width)] + + final_butterfly r6, s19, s11, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*2*width+2*width)] + str.w r1, [r0, #(4*2*width)] + + final_butterfly r7, s20, s12, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*3*width+2*width)] + str.w r1, [r0, #(4*3*width)] + + final_butterfly r8, s21, s13, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*4*width+2*width)] + str.w r1, [r0, #(4*4*width)] + + final_butterfly r9, s22, s14, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*5*width+2*width)] + str.w r1, [r0, #(4*5*width)] + + final_butterfly r10, s23, s15, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*6*width+2*width)] + str.w r1, [r0, #(4*6*width)] + + final_butterfly r11, s24, s16, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*7*width+2*width)] + str.w r1, [r0, #(4*7*width)] + + final_butterfly r4, s17, s9, r1, r12, r2, r3, r14 + str.w r12, [r0, #(4*0*width+2*width)] + str.w r1, [r0], #width + add.w r0, #((32-2)*width) + + vmov.w r12, s25 + _fnt_3_4_5_6_end: + cmp.w r0, r12 + bne.w _fnt_3_4_5_6 + + # switch to 16-bit representation + sub.w r0, r0, #256*width + mov.w r1, r0 + _fnt_to_16_bit: + _fnt_to_16_bit_start: + ldr.w r3, [r0, #1*width] + ldr.w r4, [r0, #2*width] + ldr.w r5, [r0, #3*width] + ldr.w r6, [r0, #4*width] + ldr.w r7, [r0, #5*width] + ldr.w r8, [r0, #6*width] + ldr.w r9, [r0, #7*width] + ldr.w r2, [r0], #8*width + strh.w r3, [r1, #1*2] + strh.w r4, [r1, #2*2] + strh.w r5, [r1, #3*2] + strh.w r6, [r1, #4*2] + strh.w r7, [r1, #5*2] + strh.w r8, [r1, #6*2] + strh.w r9, [r1, #7*2] + strh.w r2, [r1], #8*2 + _fnt_to_16_bit_end: + cmp.w r0, r12 + bne.w _fnt_to_16_bit + + vpop.w {s16-s27} + pop.w {r4-r11, pc} diff --git a/examples/naive/armv7m/ifnt_257_dilithium.s b/examples/naive/armv7m/ifnt_257_dilithium.s new file mode 100644 index 00000000..692352be --- /dev/null +++ b/examples/naive/armv7m/ifnt_257_dilithium.s @@ -0,0 +1,372 @@ +// 4 +.macro ldrstr4 ldrstr, target, c0, c1, c2, c3, mem0, mem1, mem2, mem3 + \ldrstr \c0, [\target, \mem0] + \ldrstr \c1, [\target, \mem1] + \ldrstr \c2, [\target, \mem2] + \ldrstr \c3, [\target, \mem3] +.endm + +// 4 +.macro ldrstr4jump ldrstr, target, c0, c1, c2, c3, mem1, mem2, mem3, jump + \ldrstr \c1, [\target, \mem1] + \ldrstr \c2, [\target, \mem2] + \ldrstr \c3, [\target, \mem3] + \ldrstr \c0, [\target], \jump +.endm + +// 8 +.macro ldrstrvec ldrstr, target, c0, c1, c2, c3, c4, c5, c6, c7, mem0, mem1, mem2, mem3, mem4, mem5, mem6, mem7 + ldrstr4 \ldrstr, \target, \c0, \c1, \c2, \c3, \mem0, \mem1, \mem2, \mem3 + ldrstr4 \ldrstr, \target, \c4, \c5, \c6, \c7, \mem4, \mem5, \mem6, \mem7 +.endm + +// 8 +.macro ldrstrvecjump ldrstr, target, c0, c1, c2, c3, c4, c5, c6, c7, mem1, mem2, mem3, mem4, mem5, mem6, mem7, jump + ldrstr4 \ldrstr, \target, \c4, \c5, \c6, \c7, \mem4, \mem5, \mem6, \mem7 + ldrstr4jump \ldrstr, \target, \c0, \c1, \c2, \c3, \mem1, \mem2, \mem3, \jump +.endm + +.macro addSub1 c0, c1 + add.w \c0, \c1 + sub.w \c1, \c0, \c1, lsl #1 +.endm + +.macro addSub2 c0, c1, c2, c3 + add \c0, \c1 + add \c2, \c3 + sub.w \c1, \c0, \c1, lsl #1 + sub.w \c3, \c2, \c3, lsl #1 +.endm + +.macro addSub4 c0, c1, c2, c3, c4, c5, c6, c7 + add \c0, \c1 + add \c2, \c3 + add \c4, \c5 + add \c6, \c7 + sub.w \c1, \c0, \c1, lsl #1 + sub.w \c3, \c2, \c3, lsl #1 + sub.w \c5, \c4, \c5, lsl #1 + sub.w \c7, \c6, \c7, lsl #1 +.endm + +// 2 +.macro barrett_32 a, Qbar, Q, tmp + smmulr.w \tmp, \a, \Qbar + mls.w \a, \tmp, \Q, \a +.endm + +.macro shift_subAdd c0, c1, shlv + sub.w \c0, \c0, \c1, lsl #(\shlv) + add.w \c1, \c0, \c1, lsl #(\shlv+1) +.endm + +.macro FNT_CT_ibutterfly c0, c1, shlv + shift_subAdd \c0, \c1, \shlv +.endm + +.macro final_butterfly c0, c1, c1f, twiddle + vmov.w \c1, \c1f + add.w \c0, \c1 + sub.w \c1, \c0, \c1, lsl #1 + mul.w \c1, \twiddle +.endm + +.macro final_butterfly2 c0, c0out, c1, c1f, twiddle, twiddle2 + vmov.w \c1, \c1f + mla.w \c0out, \twiddle2, \c1, \c0 + mls.w \c1, \twiddle2, \c1, \c0 + mul.w \c1, \twiddle +.endm + +.syntax unified +.cpu cortex-m4 +.align 2 +.global __asm_ifnt_257 +.type __asm_ifnt_257, %function +__asm_ifnt_257: + push.w {r4-r11, lr} + vpush.w {s16-s24} + + .equ width, 4 + + add.w r12, r0, #256*width + vmov.w s1, r12 + _ifnt_7_6_5_4: + _ifnt_7_6_5_4_start: + + vldm.w r1!, {s2-s16} + +// ================ + + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(2*8*width), #(2*9*width), #(2*10*width), #(2*11*width), #(2*12*width), #(2*13*width), #(2*14*width), #(2*15*width) + + addSub4 r4, r5, r6, r7, r8, r9, r10, r11 + vmov.w r14, s6 + mul.w r5, r5, r14 + vmov.w r14, s8 + mul.w r9, r9, r14 + addSub2 r4, r6, r8, r10 + vmov.w r14, s7 + mla.w r12, r7, r14, r5 + mls.w r7, r7, r14, r5 + vmov.w r14, s9 + mla.w r5, r11, r14, r9 + mls.w r11, r11, r14, r9 + + // r4, r12, r6, r7, r8, r5, r10, r11 + + vmov.w r14, s12 + mul.w r6, r6, r14 + mul.w r7, r7, r14 + vmov.w r14, s13 + mul.w r10, r10, r14 + mul.w r11, r11, r14 + + barrett_32 r4, r2, r3, r14 + barrett_32 r12, r2, r3, r14 + barrett_32 r6, r2, r3, r14 + barrett_32 r7, r2, r3, r14 + barrett_32 r8, r2, r3, r14 + barrett_32 r5, r2, r3, r14 + barrett_32 r10, r2, r3, r14 + barrett_32 r11, r2, r3, r14 + + addSub4 r4, r8, r6, r10, r12, r5, r7, r11 + + vmov.w s17, s18, r4, r12 + vmov.w s19, s20, r6, r7 + vmov.w s21, s22, r8, r5 + vmov.w s23, s24, r10, r11 + + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(2*0*width), #(2*1*width), #(2*2*width), #(2*3*width), #(2*4*width), #(2*5*width), #(2*6*width), #(2*7*width) + + addSub4 r4, r5, r6, r7, r8, r9, r10, r11 + vmov.w r14, s2 + mul.w r5, r5, r14 + vmov.w r14, s4 + mul.w r9, r9, r14 + addSub2 r4, r6, r8, r10 + vmov.w r14, s3 + mla.w r12, r7, r14, r5 + mls.w r7, r7, r14, r5 + vmov.w r14, s5 + mla.w r5, r11, r14, r9 + mls.w r11, r11, r14, r9 + + // r4, r12, r6, r7, r8, r5, r10, r11 + + vmov.w r14, s10 + mul.w r6, r6, r14 + mul.w r7, r7, r14 + vmov.w r14, s11 + mul.w r10, r10, r14 + mul.w r11, r11, r14 + + barrett_32 r4, r2, r3, r14 + barrett_32 r12, r2, r3, r14 + barrett_32 r6, r2, r3, r14 + barrett_32 r7, r2, r3, r14 + barrett_32 r8, r2, r3, r14 + barrett_32 r5, r2, r3, r14 + barrett_32 r10, r2, r3, r14 + barrett_32 r11, r2, r3, r14 + + addSub4 r4, r8, r6, r10, r12, r5, r7, r11 + vmov.w r14, s14 + mul.w r8, r8, r14 + mul.w r5, r5, r14 + mul.w r10, r10, r14 + mul.w r11, r11, r14 + vmov.w r14, s16 + final_butterfly r12, r9, s18, r14 + str.w r12, [r0, #(2*1*width)] + str.w r9, [r0, #(2*9*width)] + final_butterfly r6, r9, s19, r14 + str.w r6, [r0, #(2*2*width)] + str.w r9, [r0, #(2*10*width)] + final_butterfly r7, r9, s20, r14 + str.w r7, [r0, #(2*3*width)] + str.w r9, [r0, #(2*11*width)] + vmov.w r12, s15 + final_butterfly2 r8, r6, r9, s21, r14, r12 + str.w r6, [r0, #(2*4*width)] + str.w r9, [r0, #(2*12*width)] + final_butterfly2 r5, r6, r9, s22, r14, r12 + str.w r6, [r0, #(2*5*width)] + str.w r9, [r0, #(2*13*width)] + final_butterfly2 r10, r6, r9, s23, r14, r12 + str.w r6, [r0, #(2*6*width)] + str.w r9, [r0, #(2*14*width)] + final_butterfly2 r11, r6, r9, s24, r14, r12 + str.w r6, [r0, #(2*7*width)] + str.w r9, [r0, #(2*15*width)] + final_butterfly r4, r9, s17, r14 + str.w r9, [r0, #(2*8*width)] + str.w r4, [r0], #width + +// ================ + + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(2*8*width), #(2*9*width), #(2*10*width), #(2*11*width), #(2*12*width), #(2*13*width), #(2*14*width), #(2*15*width) + + addSub4 r4, r5, r6, r7, r8, r9, r10, r11 + vmov.w r14, s6 + mul.w r5, r5, r14 + vmov.w r14, s8 + mul.w r9, r9, r14 + addSub2 r4, r6, r8, r10 + vmov.w r14, s7 + mla.w r12, r7, r14, r5 + mls.w r7, r7, r14, r5 + vmov.w r14, s9 + mla.w r5, r11, r14, r9 + mls.w r11, r11, r14, r9 + + // r4, r12, r6, r7, r8, r5, r10, r11 + + vmov.w r14, s12 + mul.w r6, r6, r14 + mul.w r7, r7, r14 + vmov.w r14, s13 + mul.w r10, r10, r14 + mul.w r11, r11, r14 + + barrett_32 r4, r2, r3, r14 + barrett_32 r12, r2, r3, r14 + barrett_32 r6, r2, r3, r14 + barrett_32 r7, r2, r3, r14 + barrett_32 r8, r2, r3, r14 + barrett_32 r5, r2, r3, r14 + barrett_32 r10, r2, r3, r14 + barrett_32 r11, r2, r3, r14 + + addSub4 r4, r8, r6, r10, r12, r5, r7, r11 + + vmov.w s17, s18, r4, r12 + vmov.w s19, s20, r6, r7 + vmov.w s21, s22, r8, r5 + vmov.w s23, s24, r10, r11 + + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(2*0*width), #(2*1*width), #(2*2*width), #(2*3*width), #(2*4*width), #(2*5*width), #(2*6*width), #(2*7*width) + + addSub4 r4, r5, r6, r7, r8, r9, r10, r11 + vmov.w r14, s2 + mul.w r5, r5, r14 + vmov.w r14, s4 + mul.w r9, r9, r14 + addSub2 r4, r6, r8, r10 + vmov.w r14, s3 + mla.w r12, r7, r14, r5 + mls.w r7, r7, r14, r5 + vmov.w r14, s5 + mla.w r5, r11, r14, r9 + mls.w r11, r11, r14, r9 + + // r4, r12, r6, r7, r8, r5, r10, r11 + + vmov.w r14, s10 + mul.w r6, r6, r14 + mul.w r7, r7, r14 + vmov.w r14, s11 + mul.w r10, r10, r14 + mul.w r11, r11, r14 + + barrett_32 r4, r2, r3, r14 + barrett_32 r12, r2, r3, r14 + barrett_32 r6, r2, r3, r14 + barrett_32 r7, r2, r3, r14 + barrett_32 r8, r2, r3, r14 + barrett_32 r5, r2, r3, r14 + barrett_32 r10, r2, r3, r14 + barrett_32 r11, r2, r3, r14 + + addSub4 r4, r8, r6, r10, r12, r5, r7, r11 + vmov.w r14, s14 + mul.w r8, r8, r14 + mul.w r5, r5, r14 + mul.w r10, r10, r14 + mul.w r11, r11, r14 + vmov.w r14, s16 + + final_butterfly r12, r9, s18, r14 + str.w r12, [r0, #(2*1*width)] + str.w r9, [r0, #(2*9*width)] + final_butterfly r6, r9, s19, r14 + str.w r6, [r0, #(2*2*width)] + str.w r9, [r0, #(2*10*width)] + final_butterfly r7, r9, s20, r14 + str.w r7, [r0, #(2*3*width)] + str.w r9, [r0, #(2*11*width)] + vmov.w r12, s15 + final_butterfly2 r8, r6, r9, s21, r14, r12 + str.w r6, [r0, #(2*4*width)] + str.w r9, [r0, #(2*12*width)] + final_butterfly2 r5, r6, r9, s22, r14, r12 + str.w r6, [r0, #(2*5*width)] + str.w r9, [r0, #(2*13*width)] + final_butterfly2 r10, r6, r9, s23, r14, r12 + str.w r6, [r0, #(2*6*width)] + str.w r9, [r0, #(2*14*width)] + final_butterfly2 r11, r6, r9, s24, r14, r12 + str.w r6, [r0, #(2*7*width)] + str.w r9, [r0, #(2*15*width)] + final_butterfly r4, r9, s17, r14 + str.w r9, [r0, #(2*8*width)] + str.w r4, [r0], #31*width + +// ================ + + vmov.w r12, s1 + _ifnt_7_6_5_4_end: + cmp.w r0, r12 + bne.w _ifnt_7_6_5_4 + + sub.w r0, r0, #256*width + + mov.w r14, #0 + + add.w r1, r0, #32*width + _ifnt_0_1_2: +.rept 2 + _ifnt_0_1_2_start: + ldrstrvec ldr.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(32*0*width), #(32*1*width), #(32*2*width), #(32*3*width), #(32*4*width), #(32*5*width), #(32*6*width), #(32*7*width) + + addSub4 r4, r5, r6, r7, r8, r9, r10, r11 + + addSub2 r4, r6, r8, r10 + FNT_CT_ibutterfly r5, r7, 4 + FNT_CT_ibutterfly r9, r11, 4 + + addSub1 r4, r8 + barrett_32 r9, r2, r3, r12 + FNT_CT_ibutterfly r5, r9, 6 + FNT_CT_ibutterfly r6, r10, 4 + FNT_CT_ibutterfly r7, r11, 2 + + barrett_32 r6, r2, r3, r12 + barrett_32 r7, r2, r3, r12 + sub.w r4, r14, r4, lsl #1 + neg.w r5, r5 + lsl.w r6, r6, #7 + lsl.w r7, r7, #6 + lsl.w r8, r8, #5 + lsl.w r9, r9, #4 + lsl.w r10, r10, #3 + lsl.w r11, r11, #2 + + barrett_32 r4, r2, r3, r12 + barrett_32 r5, r2, r3, r12 + barrett_32 r6, r2, r3, r12 + barrett_32 r7, r2, r3, r12 + barrett_32 r8, r2, r3, r12 + barrett_32 r9, r2, r3, r12 + barrett_32 r10, r2, r3, r12 + barrett_32 r11, r2, r3, r12 + + ldrstrvecjump str.w, r0, r4, r5, r6, r7, r8, r9, r10, r11, #(32*1*width), #(32*2*width), #(32*3*width), #(32*4*width), #(32*5*width), #(32*6*width), #(32*7*width), #width + _ifnt_0_1_2_end: +.endr + + cmp.w r0, r1 + bne.w _ifnt_0_1_2 + vpop.w {s16-s24} + pop.w {r4-r11, pc} diff --git a/slothy/targets/arm_v7m/arch_v7m.py b/slothy/targets/arm_v7m/arch_v7m.py index a10e0093..67c5d8eb 100644 --- a/slothy/targets/arm_v7m/arch_v7m.py +++ b/slothy/targets/arm_v7m/arch_v7m.py @@ -293,6 +293,7 @@ def __init__(self, *, mnemonic, self.flag = None self.width = None self.barrel = None + self.range = None def extract_read_writes(self): """Extracts 'reads'/'writes' clauses from the source line of the instruction""" @@ -521,6 +522,7 @@ def pattern_i(i): index_pattern = "[0-9]+" width_pattern = "(?:\.w|\.n|)" barrel_pattern = "(?:lsl|ror|lsr|asr)" + range_pattern = "\{(?P[rs])(?P\\\\d+)-[rs](?P\\\\d+)\}" src = re.sub(" ", "\\\\s+", src) src = re.sub(",", "\\\\s*,\\\\s*", src) @@ -531,6 +533,7 @@ def pattern_i(i): src = replace_placeholders(src, "flag", flag_pattern, "flag") # TODO: Are any changes required for IT syntax? src = replace_placeholders(src, "width", width_pattern, "width") src = replace_placeholders(src, "barrel", barrel_pattern, "barrel") + src = replace_placeholders(src, "range", range_pattern, "range") src = r"\s*" + src + r"\s*(//.*)?\Z" return src @@ -683,6 +686,10 @@ def group_name_i(i): group_to_attribute('flag', 'flag') group_to_attribute('width', 'width') group_to_attribute('barrel', 'barrel') + group_to_attribute('range', 'range') + group_to_attribute('range_start', 'range_start', int) + group_to_attribute('range_end', 'range_end', int) + group_to_attribute('range_type', 'range_type') for s, ty in obj.pattern_inputs: if ty == RegisterType.FLAGS: @@ -755,6 +762,7 @@ def t_default(x): out = replace_pattern(out, "index", "index", str) out = replace_pattern(out, "width", "width", lambda x: x.lower()) out = replace_pattern(out, "barrel", "barrel", lambda x: x.lower()) + out = replace_pattern(out, "range", "range", lambda x: x.lower()) out = out.replace("\\[", "[") out = out.replace("\\]", "]") @@ -787,6 +795,11 @@ class vmov_gpr2(Armv7mFPInstruction): # pylint: disable=missing-docstring,invali pattern = "vmov , " inputs = ["Ra"] outputs = ["Sd"] + +class vmov_gpr2_dual(Armv7mFPInstruction): # pylint: disable=missing-docstring,invalid-name + pattern = "vmov , , , " + inputs = ["Ra", "Rb"] + outputs = ["Sd1", "Sd2"] # movs class movw_imm(Armv7mBasicArithmetic): # pylint: disable=missing-docstring,invalid-name @@ -864,6 +877,21 @@ class mul(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-nam pattern = "mul , , " inputs = ["Ra","Rb"] outputs = ["Rd"] + +class mul_short(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-name + pattern = "mul , " + inputs = ["Ra"] + in_outs = ["Rd"] + +class mla(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-name + pattern = "mla , , , " + inputs = ["Ra","Rb", "Rc"] + outputs = ["Rd"] + +class mls(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-name + pattern = "mls , , , " + inputs = ["Ra","Rb", "Rc"] + outputs = ["Rd"] class smulwb(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-name pattern = "smulwb , , " @@ -874,6 +902,11 @@ class smulwt(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid- pattern = "smulwt , , " inputs = ["Ra","Rb"] outputs = ["Rd"] + +class smultb(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-name + pattern = "smultb , , " + inputs = ["Ra","Rb"] + outputs = ["Rd"] class smlabt(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-name pattern = "smlabt , , , " @@ -890,7 +923,28 @@ class smlal(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-n inputs = ["Rc","Rd"] in_outs = ["Ra", "Rb"] +class smmulr(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-name + pattern = "smmulr , , " + inputs = ["Rb","Rc"] + outputs = ["Ra"] + +class smuad(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-name + pattern = "smuad , , " + inputs = ["Rb","Rc"] + outputs = ["Ra"] + +class smuadx(Armv7mMultiplication): # pylint: disable=missing-docstring,invalid-name + pattern = "smuadx , , " + inputs = ["Rb","Rc"] + outputs = ["Ra"] + + # Logical + +class neg_short(Armv7mLogical): # pylint: disable=missing-docstring,invalid-name + pattern = "neg , " + inputs = ["Ra"] + in_outs = ["Rd"] class log_and(Armv7mLogical): # pylint: disable=missing-docstring,invalid-name pattern = "and , , " inputs = ["Ra", "Rb"] @@ -961,12 +1015,22 @@ class rors_short(Armv7mLogical): # pylint: disable=missing-docstring,invalid-nam pattern = "rors , " in_outs = ["Rd"] modifiesFlags = True + +class lsl(Armv7mLogical): # pylint: disable=missing-docstring,invalid-name + pattern = "lsl , , " + inputs = ["Ra"] + outputs = ["Rd"] class pkhtb(Armv7mLogical): # pylint: disable=missing-docstring,invalid-name pattern = "pkhtb , , , " inputs = ["Ra", "Rb"] outputs = ["Rd"] +class pkhbt(Armv7mLogical): # pylint: disable=missing-docstring,invalid-name + pattern = "pkhbt , , , " + inputs = ["Ra", "Rb"] + outputs = ["Rd"] + # Load class ldr(Armv7mLoadInstruction): # pylint: disable=missing-docstring,invalid-name pattern = "ldr , []" @@ -1040,6 +1104,15 @@ def make(cls, src): obj.addr = obj.args_in_out[0] return obj +class vldm_interval_inc_writeback(Armv7mLoadInstruction): # pylint: disable=missing-docstring,invalid-name + pattern = "vldm !, " + in_outs = ["Ra"] + outputs = [] + @classmethod + def make(cls, src): + obj = Armv7mLoadInstruction.build(cls, src) + obj.outputs += [f"{obj.range_type}{i}" for i in range(obj.range_start, obj.range_end+1)] + return obj # Store class str_no_off(Armv7mStoreInstruction): # pylint: disable=missing-docstring,invalid-name @@ -1053,6 +1126,18 @@ def make(cls, src): obj.addr = obj.args_in[0] return obj +class strh_with_imm(Armv7mStoreInstruction): # pylint: disable=missing-docstring,invalid-name + pattern = "strh , [, ]" + inputs = ["Ra", "Rd"] + outputs = [] + @classmethod + def make(cls, src): + obj = Armv7mInstruction.build(cls, src) + obj.increment = None + obj.pre_index = obj.immediate + obj.addr = obj.args_in[0] + return obj + class str_with_imm(Armv7mStoreInstruction): # pylint: disable=missing-docstring,invalid-name pattern = "str , [, ]" inputs = ["Ra", "Rd"] @@ -1089,6 +1174,18 @@ def make(cls, src): obj.addr = obj.args_in_out[0] return obj +class strh_with_postinc(Armv7mStoreInstruction): # pylint: disable=missing-docstring,invalid-name + pattern = "strh , [], " + inputs = ["Rd"] + in_outs = ["Ra"] + @classmethod + def make(cls, src): + obj = Armv7mStoreInstruction.build(cls, src) + obj.increment = obj.immediate + obj.pre_index = None + obj.addr = obj.args_in_out[0] + return obj + # Other class cmp(Armv7mBasicArithmetic): # pylint: disable=missing-docstring,invalid-name pattern = "cmp , "