diff --git a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec index e670cafa..6896d69b 100644 --- a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec +++ b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec @@ -4350,58 +4350,48 @@ module M(SC:Syscall_t) = { var ctr:W64.t; var pos:W64.t; - var exit:W64.t; var val1:W16.t; var t:W16.t; var val2:W16.t; var cond:bool; - var cnd0:W64.t; - var cnd1:W64.t; ctr <- offset; pos <- (W64.of_int 0); - exit <- (W64.of_int 0); - - while ((exit = (W64.of_int 0))) { - val1 <- (zeroextu16 buf.[(W64.to_uint pos)]); - pos <- (pos + (W64.of_int 1)); - t <- (zeroextu16 buf.[(W64.to_uint pos)]); - val2 <- t; - val2 <- (val2 `>>` (W8.of_int 4)); - t <- (t `&` (W16.of_int 15)); - t <- (t `<<` (W8.of_int 8)); - val1 <- (val1 `|` t); - pos <- (pos + (W64.of_int 1)); - t <- (zeroextu16 buf.[(W64.to_uint pos)]); - t <- (t `<<` (W8.of_int 4)); - val2 <- (val2 `|` t); - pos <- (pos + (W64.of_int 1)); - cond <- (val1 \ult (W16.of_int 3329)); - if (cond) { - rp.[(W64.to_uint ctr)] <- val1; - ctr <- (ctr + (W64.of_int 1)); - } else { - - } - cond <- (val2 \ult (W16.of_int 3329)); - if (cond) { - if ((ctr \ult (W64.of_int 256))) { - rp.[(W64.to_uint ctr)] <- val2; + + while ((pos \ult (W64.of_int (168 - 2)))) { + if ((ctr \ult (W64.of_int 256))) { + val1 <- (zeroextu16 buf.[(W64.to_uint pos)]); + t <- (zeroextu16 buf.[(W64.to_uint (pos + (W64.of_int 1)))]); + val2 <- t; + val2 <- (val2 `>>` (W8.of_int 4)); + t <- (t `&` (W16.of_int 15)); + t <- (t `<<` (W8.of_int 8)); + val1 <- (val1 `|` t); + t <- (zeroextu16 buf.[(W64.to_uint (pos + (W64.of_int 2)))]); + t <- (t `<<` (W8.of_int 4)); + val2 <- (val2 `|` t); + pos <- (pos + (W64.of_int 3)); + cond <- (val1 \ult (W16.of_int 3329)); + if (cond) { + rp.[(W64.to_uint ctr)] <- val1; ctr <- (ctr + (W64.of_int 1)); } else { + } + cond <- (val2 \ult (W16.of_int 3329)); + if (cond) { + if ((ctr \ult (W64.of_int 256))) { + rp.[(W64.to_uint ctr)] <- val2; + ctr <- (ctr + (W64.of_int 1)); + } else { + + } + } else { + } } else { - + pos <- (W64.of_int 168); } - cnd0 <- (W64.of_int 256); - cnd0 <- (cnd0 - ctr); - cnd0 <- (cnd0 - (W64.of_int 1)); - cnd1 <- (W64.of_int 168); - cnd1 <- (cnd1 - pos); - cnd1 <- (cnd1 - (W64.of_int 3)); - exit <- (cnd0 `|` cnd1); - exit <- (exit `>>` (W8.of_int 63)); } return (ctr, rp); } diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jinc b/code/jasmin/mlkem_avx2/gen_matrix.jinc index 0cfba0b3..59e3b518 100644 --- a/code/jasmin/mlkem_avx2/gen_matrix.jinc +++ b/code/jasmin/mlkem_avx2/gen_matrix.jinc @@ -9,59 +9,46 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] reg u16 val1 val2; reg u16 t; reg u64 pos ctr; - reg u64 cnd0 cnd1 exit; ctr = offset; pos = 0; - exit = 0; - while(exit == 0) - { - val1 = (16u)buf[(int)pos]; - pos += 1; - t = (16u)buf[(int)pos]; - val2 = t; - val2 >>= 4; - t &= 0x0F; - t <<= 8; - val1 |= t; - pos += 1; - - t = (16u)buf[(int)pos]; - t <<= 4; - val2 |= t; - pos += 1; - - reg bool cond; - #[declassify] - cond = val1 < MLKEM_Q; - if cond - { - rp[(int)ctr] = val1; - ctr += 1; - } - - #[declassify] - cond = val2 < MLKEM_Q; - if cond - { - if(ctr < MLKEM_N) - { - rp[(int)ctr] = val2; + while (pos < SHAKE128_RATE - 2) { + if ctr < MLKEM_N { + val1 = (16u)buf[pos]; + t = (16u)buf[pos + 1]; + val2 = t; + val2 >>= 4; + t &= 0x0F; + t <<= 8; + val1 |= t; + + t = (16u)buf[pos + 2]; + t <<= 4; + val2 |= t; + pos += 3; + + reg bool cond; + #[declassify] + cond = val1 < MLKEM_Q; + if cond { + rp[ctr] = val1; ctr += 1; } - } - // Check if we should exit the loop - cnd0 = MLKEM_N; - cnd0 -= ctr; - cnd0 -= 1; - cnd1 = SHAKE128_RATE; - cnd1 -= pos; - cnd1 -= 3; //TODO: (potentially) wasting 2 'good' bytes - exit = cnd0 | cnd1; - exit >>= 63; + #[declassify] + cond = val2 < MLKEM_Q; + if cond { + if(ctr < MLKEM_N) + { + rp[ctr] = val2; + ctr += 1; + } + } + } else { + pos = SHAKE128_RATE; + } } return ctr, rp; diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec index a2d16249..12e08009 100644 --- a/code/jasmin/mlkem_ref/extraction/jkem.ec +++ b/code/jasmin/mlkem_ref/extraction/jkem.ec @@ -1621,58 +1621,48 @@ module M(SC:Syscall_t) = { var ctr:W64.t; var pos:W64.t; - var exit:W64.t; var val1:W16.t; var t:W16.t; var val2:W16.t; var cond:bool; - var cnd0:W64.t; - var cnd1:W64.t; ctr <- offset; pos <- (W64.of_int 0); - exit <- (W64.of_int 0); - - while ((exit = (W64.of_int 0))) { - val1 <- (zeroextu16 buf.[(W64.to_uint pos)]); - pos <- (pos + (W64.of_int 1)); - t <- (zeroextu16 buf.[(W64.to_uint pos)]); - val2 <- t; - val2 <- (val2 `>>` (W8.of_int 4)); - t <- (t `&` (W16.of_int 15)); - t <- (t `<<` (W8.of_int 8)); - val1 <- (val1 `|` t); - pos <- (pos + (W64.of_int 1)); - t <- (zeroextu16 buf.[(W64.to_uint pos)]); - t <- (t `<<` (W8.of_int 4)); - val2 <- (val2 `|` t); - pos <- (pos + (W64.of_int 1)); - cond <- (val1 \ult (W16.of_int 3329)); - if (cond) { - rp.[(W64.to_uint ctr)] <- val1; - ctr <- (ctr + (W64.of_int 1)); - } else { - - } - cond <- (val2 \ult (W16.of_int 3329)); - if (cond) { - if ((ctr \ult (W64.of_int 256))) { - rp.[(W64.to_uint ctr)] <- val2; + + while ((pos \ult (W64.of_int (168 - 2)))) { + if ((ctr \ult (W64.of_int 256))) { + val1 <- (zeroextu16 buf.[(W64.to_uint pos)]); + t <- (zeroextu16 buf.[(W64.to_uint (pos + (W64.of_int 1)))]); + val2 <- t; + val2 <- (val2 `>>` (W8.of_int 4)); + t <- (t `&` (W16.of_int 15)); + t <- (t `<<` (W8.of_int 8)); + val1 <- (val1 `|` t); + t <- (zeroextu16 buf.[(W64.to_uint (pos + (W64.of_int 2)))]); + t <- (t `<<` (W8.of_int 4)); + val2 <- (val2 `|` t); + pos <- (pos + (W64.of_int 3)); + cond <- (val1 \ult (W16.of_int 3329)); + if (cond) { + rp.[(W64.to_uint ctr)] <- val1; ctr <- (ctr + (W64.of_int 1)); } else { + } + cond <- (val2 \ult (W16.of_int 3329)); + if (cond) { + if ((ctr \ult (W64.of_int 256))) { + rp.[(W64.to_uint ctr)] <- val2; + ctr <- (ctr + (W64.of_int 1)); + } else { + + } + } else { + } } else { - + pos <- (W64.of_int 168); } - cnd0 <- (W64.of_int 256); - cnd0 <- (cnd0 - ctr); - cnd0 <- (cnd0 - (W64.of_int 1)); - cnd1 <- (W64.of_int 168); - cnd1 <- (cnd1 - pos); - cnd1 <- (cnd1 - (W64.of_int 3)); - exit <- (cnd0 `|` cnd1); - exit <- (exit `>>` (W8.of_int 63)); } return (ctr, rp); } diff --git a/code/jasmin/mlkem_ref/gen_matrix.jinc b/code/jasmin/mlkem_ref/gen_matrix.jinc index 7461a13f..f261b711 100644 --- a/code/jasmin/mlkem_ref/gen_matrix.jinc +++ b/code/jasmin/mlkem_ref/gen_matrix.jinc @@ -7,59 +7,46 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] reg u16 val1 val2; reg u16 t; reg u64 pos ctr; - reg u64 cnd0 cnd1 exit; ctr = offset; pos = 0; - exit = 0; - while(exit == 0) - { - val1 = (16u)buf[(int)pos]; - pos += 1; - t = (16u)buf[(int)pos]; - val2 = t; - val2 >>= 4; - t &= 0x0F; - t <<= 8; - val1 |= t; - pos += 1; - - t = (16u)buf[(int)pos]; - t <<= 4; - val2 |= t; - pos += 1; - - reg bool cond; - #[declassify] - cond = val1 < MLKEM_Q; - if cond - { - rp[(int)ctr] = val1; - ctr += 1; - } - - #[declassify] - cond = val2 < MLKEM_Q; - if cond - { - if(ctr < MLKEM_N) - { - rp[(int)ctr] = val2; + while (pos < SHAKE128_RATE - 2) { + if ctr < MLKEM_N { + val1 = (16u)buf[pos]; + t = (16u)buf[pos + 1]; + val2 = t; + val2 >>= 4; + t &= 0x0F; + t <<= 8; + val1 |= t; + + t = (16u)buf[pos + 2]; + t <<= 4; + val2 |= t; + pos += 3; + + reg bool cond; + #[declassify] + cond = val1 < MLKEM_Q; + if cond { + rp[ctr] = val1; ctr += 1; } - } - // Check if we should exit the loop - cnd0 = MLKEM_N; - cnd0 -= ctr; - cnd0 -= 1; - cnd1 = SHAKE128_RATE; - cnd1 -= pos; - cnd1 -= 3; //TODO: (potentially) wasting 2 'good' bytes - exit = cnd0 | cnd1; - exit >>= 63; + #[declassify] + cond = val2 < MLKEM_Q; + if cond { + if(ctr < MLKEM_N) + { + rp[ctr] = val2; + ctr += 1; + } + } + } else { + pos = SHAKE128_RATE; + } } return ctr, rp; diff --git a/proof/correctness/MLKEM_InnerPKE.ec b/proof/correctness/MLKEM_InnerPKE.ec index 3f598fbc..29cbb3de 100644 --- a/proof/correctness/MLKEM_InnerPKE.ec +++ b/proof/correctness/MLKEM_InnerPKE.ec @@ -561,6 +561,10 @@ swap {1} 2 -1; seq 1 1 : (#pre /\ buf{1} = b168{2}). by auto => />. wp; conseq />. +splitwhile{1} ^while : to_uint ctr0 < 256. +while{1} (!(to_uint pos{1} < 166 /\ to_uint ctr0{1} < 256) /\ #post) (168 - to_uint pos{1}). ++ move => &h z; rcondf ^if; auto => /> &m; rewrite !ultE /#. + while(0<=j0{2}<=256 /\ 0<=k{2}<=168 /\to_uint ctr0{1} = j0{2} /\ buf0{1} = b168{2} /\ to_uint pos{1} = k{2} /\ k{2} %% 3 = 0 /\ (forall (k0 : int), @@ -571,20 +575,21 @@ while(0<=j0{2}<=256 /\ 0<=k{2}<=168 /\to_uint ctr0{1} = j0{2} /\ buf0{1} = b168{ 0 <= to_sint rp{1}.[k0] /\ to_sint rp{1}.[k0] < q)); last first. -+ auto => /> &1 &2 ???????; do split; 1,2: smt(). - move => ctrl rpl aar kr => *; do split; 1,2: smt(); - by rewrite ultE /=. ++ auto => /> &1 &2 *; do split; 1,2: smt(). + move => ctrl rpl aar kr => *; do split; 1, 2, 3: smt(). + * by rewrite ultE /=. + * by rewrite ultE /=. + by move => *; rewrite ultE /#. + +rcondt{1} ^if; first by move => &m; auto => &h /> *; rewrite ultE. -seq 13 6 : (#{/~k{2} < 168}pre /\ to_uint val1{1} = d1{2} /\ to_uint val2{1} = d2{2}). +seq 11 6 : (#{/~k{2} < 168}{/~pos{1} \ult (of_int (168 - 2))%W64}pre /\ to_uint val1{1} = d1{2} /\ to_uint val2{1} = d2{2}). -auto => /> &1 &2 ?????????; do split; 1,2,4:smt(). -+ by rewrite to_uintD_small; smt(). -+ by rewrite mergebytes to_uintD_small; smt(). -by rewrite mergebytes2 !to_uintD_small; smt(). +auto => /> &1 &2 *; do split; rewrite ?ultE ?to_uintD_small // ?mergebytes // ?mergebytes2 // /#. seq 4 2 : (to_uint ctr0{1} = j0{2} /\ to_uint pos{1} = k{2} /\ - #{/~exit{1}}post). + #{/~pos{1} \ult (of_int (168 - 2))%W64}post). + sp 1 0; if; 1: by move => &1 &2; rewrite ultE qE; smt(). + sp 3 2; if{2}. @@ -652,27 +657,7 @@ seq 4 2 : (to_uint ctr0{1} = j0{2} /\ rcondf{1} 1; 1: by move => *; auto => /> &1; rewrite !ultE /= /#. by auto => />. -auto => /> &1 &2 *. - -rewrite extract_msb. - -have : -(! ((of_int 256)%W64 - ctr0{1} - W64.one `|` ((of_int 168)%W64 - pos{1} - (of_int 3)%W64)).[63]) <=> -to_uint ctr0{1} < 256 /\ (to_uint ctr0{1} < 256 => to_uint pos{1} < 168); last by smt(). - -rewrite /W64.(`|`) map2E //=. - -have ->: ((of_int 256)%W64 - ctr0{1} - W64.one).[63] = (255 - to_uint ctr0{1} < 0). -+ have -> : W64.of_int 256 - ctr0{1} - W64.one = W64.of_int (255 - to_uint ctr0{1}) - by ring; rewrite to_uintK; ring. - by rewrite of_intwE /= /int_bit /= /#. - -have -> : ((of_int 168)%W64 - pos{1} - (of_int 3)%W64).[63] = (165 - to_uint pos{1} < 0); - last by smt(). - -+ have -> : (W64.of_int 168) - pos{1} - (W64.of_int 3) = W64.of_int (165 - to_uint pos{1}) - by ring; rewrite to_uintK; ring. - by rewrite of_intwE /= /int_bit /= /#. +auto => /> &1 &2 *; rewrite ultE /#. qed.