From 26219aab47f6fb81b51a986a4f6ee29a1ad38b87 Mon Sep 17 00:00:00 2001 From: Tiago Oliveira Date: Sat, 10 Feb 2024 00:05:55 +0100 Subject: [PATCH] ref: poly: improve scheduling for ntt and invntt --- code/jasmin/mlkem_ref/extraction/jkem.ec | 14 +++++++------- code/jasmin/mlkem_ref/poly.jinc | 8 ++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec index a4907309..c3ce1bda 100644 --- a/code/jasmin/mlkem_ref/extraction/jkem.ec +++ b/code/jasmin/mlkem_ref/extraction/jkem.ec @@ -1078,9 +1078,9 @@ module M(SC:Syscall_t) = { var zeta_0:W16.t; var j:W64.t; var cmp:W64.t; + var t:W16.t; var offset:W64.t; var s:W16.t; - var t:W16.t; var m:W16.t; zetasp <- witness; zetasp <- jzetas_inv; @@ -1098,10 +1098,10 @@ module M(SC:Syscall_t) = { cmp <- (cmp + len); while ((j \ult cmp)) { + t <- rp.[(W64.to_uint j)]; offset <- j; offset <- (offset + len); s <- rp.[(W64.to_uint offset)]; - t <- rp.[(W64.to_uint j)]; m <- s; m <- (m + t); m <@ __barrett_reduce (m); @@ -1137,10 +1137,10 @@ module M(SC:Syscall_t) = { var zeta_0:W16.t; var j:W64.t; var cmp:W64.t; - var offset:W64.t; - var t:W16.t; var s:W16.t; var m:W16.t; + var offset:W64.t; + var t:W16.t; zetasp <- witness; zetasp <- jzetas; zetasctr <- (W64.of_int 0); @@ -1157,15 +1157,15 @@ module M(SC:Syscall_t) = { cmp <- (cmp + len); while ((j \ult cmp)) { + s <- rp.[(W64.to_uint j)]; + m <- s; offset <- j; offset <- (offset + len); t <- rp.[(W64.to_uint offset)]; t <@ __fqmul (t, zeta_0); - s <- rp.[(W64.to_uint j)]; - m <- s; m <- (m - t); - rp.[(W64.to_uint offset)] <- m; t <- (t + s); + rp.[(W64.to_uint offset)] <- m; rp.[(W64.to_uint j)] <- t; j <- (j + (W64.of_int 1)); } diff --git a/code/jasmin/mlkem_ref/poly.jinc b/code/jasmin/mlkem_ref/poly.jinc index fade03c8..19cc55bb 100644 --- a/code/jasmin/mlkem_ref/poly.jinc +++ b/code/jasmin/mlkem_ref/poly.jinc @@ -486,9 +486,9 @@ fn _poly_invntt(reg ptr u16[MLKEM_N] rp) -> reg ptr u16[MLKEM_N] cmp = start; cmp += len; while (j < cmp) { + t = rp[(int)j]; offset = j; offset += len; s = rp[(int)offset]; - t = rp[(int)j]; m = s; m += t; m = __barrett_reduce(m); rp[(int)j] = m; @@ -544,14 +544,14 @@ fn _poly_ntt(reg ptr u16[MLKEM_N] rp) -> reg ptr u16[MLKEM_N] cmp = start; cmp += len; while (j < cmp) { + s = rp[(int)j]; + m = s; offset = j; offset += len; t = rp[(int)offset]; t = __fqmul(t, zeta); - s = rp[(int)j]; - m = s; m -= t; - rp[(int)offset] = m; t += s; + rp[(int)offset] = m; rp[(int)j] = t; j += 1; }