diff --git a/mk/schemes.mk b/mk/schemes.mk index 64abf8604..ef59c606b 100644 --- a/mk/schemes.mk +++ b/mk/schemes.mk @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 SOURCES = $(wildcard mlkem/*.c) -ifneq ($(OPT),REF) - SOURCES += $(wildcard mlkem/asm/*.S) - CPPFLAGS += -D$(OPT) +ifeq ($(OPT),AARCH64) + SOURCES += $(wildcard mlkem/asm/aarch64/*.S) + CPPFLAGS += -DMLKEM_OPT_AARCH64 endif CPPFLAGS += -Imlkem diff --git a/mlkem/asm/intt_123_4567.S b/mlkem/asm/aarch64/intt_123_4567.S similarity index 100% rename from mlkem/asm/intt_123_4567.S rename to mlkem/asm/aarch64/intt_123_4567.S diff --git a/mlkem/asm/intt_kyber_123_45_67_twiddles.S b/mlkem/asm/aarch64/intt_kyber_123_45_67_twiddles.S similarity index 100% rename from mlkem/asm/intt_kyber_123_45_67_twiddles.S rename to mlkem/asm/aarch64/intt_kyber_123_45_67_twiddles.S diff --git a/mlkem/asm/macro.S b/mlkem/asm/aarch64/macro.S similarity index 100% rename from mlkem/asm/macro.S rename to mlkem/asm/aarch64/macro.S diff --git a/mlkem/asm/ntt_123_4567.S b/mlkem/asm/aarch64/ntt_123_4567.S similarity index 100% rename from mlkem/asm/ntt_123_4567.S rename to mlkem/asm/aarch64/ntt_123_4567.S diff --git a/mlkem/asm/ntt_kyber_123_45_67_twiddles.S b/mlkem/asm/aarch64/ntt_kyber_123_45_67_twiddles.S similarity index 100% rename from mlkem/asm/ntt_kyber_123_45_67_twiddles.S rename to mlkem/asm/aarch64/ntt_kyber_123_45_67_twiddles.S diff --git a/mlkem/asm/asm.h b/mlkem/asm/asm.h new file mode 100644 index 000000000..107a3f09d --- /dev/null +++ b/mlkem/asm/asm.h @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +#ifndef ASM_H +#define ASM_H + +#include +#include "params.h" + +#ifdef MLKEM_OPT_AARCH64 +void ntt_kyber_123_4567(int16_t *); +void intt_kyber_123_4567(int16_t *); +#endif + +#endif diff --git a/mlkem/ntt.c b/mlkem/ntt.c index bb44b6961..5a3302182 100644 --- a/mlkem/ntt.c +++ b/mlkem/ntt.c @@ -4,6 +4,8 @@ #include "reduce.h" #include +#include "asm/asm.h" + /* Code to generate zetas and zetas_inv used in the number-theoretic transform: #define KYBER_ROOT_OF_UNITY 17 @@ -37,11 +39,6 @@ void init_ntt() { } */ -#ifdef NTT123_4567 -void ntt_kyber_123_4567(int16_t *); -void intt_kyber_123_4567(int16_t *); -#endif - const int16_t zetas[128] = { -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, @@ -84,9 +81,9 @@ static int16_t fqmul(int16_t a, int16_t b) **************************************************/ void ntt(int16_t r[256]) { - #ifdef NTT123_4567 + #ifdef MLKEM_OPT_AARCH64 ntt_kyber_123_4567(r); - #else + #else /* OPT_AARCH64 */ unsigned int len, start, j, k; int16_t t, zeta; @@ -104,7 +101,7 @@ void ntt(int16_t r[256]) } } } - #endif + #endif /* OPT_AARCH64 */ } /************************************************* @@ -119,9 +116,9 @@ void ntt(int16_t r[256]) **************************************************/ void invntt(int16_t r[256]) { - #ifdef NTT123_4567 + #ifdef MLKEM_OPT_AARCH64 intt_kyber_123_4567(r); - #else + #else /* OPT_AARCH64 */ unsigned int start, len, j, k; int16_t t, zeta; const int16_t f = 1441; // mont^2/128 @@ -146,7 +143,7 @@ void invntt(int16_t r[256]) { r[j] = fqmul(r[j], f); } - #endif + #endif /* OPT_AARCH64 */ } /************************************************* diff --git a/scripts/tests b/scripts/tests index c1d1defb9..5daf997eb 100755 --- a/scripts/tests +++ b/scripts/tests @@ -292,7 +292,7 @@ _shared_options = [ ), click.option( "--opt", - type=click.Choice(["REF", "NTT123_4567"], case_sensitive=False), + type=click.Choice(["REF", "AARCH64"], case_sensitive=False), help="Choose optimized version", ), ] @@ -455,7 +455,7 @@ def bench( extra_make_envs=process_make_envs(cflags, arch_flags), extra_make_args=[ f"CYCLES={cycles}", - *(f"OPT={opt}" if opt is not None else ""), + f"OPT={opt}" if opt is not None else "", ], )