diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec index 12e08009..3a5bc9c6 100644 --- a/code/jasmin/mlkem_ref/extraction/jkem.ec +++ b/code/jasmin/mlkem_ref/extraction/jkem.ec @@ -816,25 +816,22 @@ module M(SC:Syscall_t) = { proc _poly_compress (rp:W64.t, a:W16.t Array256.t) : W16.t Array256.t = { var i:W64.t; - var j:W64.t; var t:W16.t; var d0:W32.t; var d1:W32.t; a <@ _poly_csubq (a); i <- (W64.of_int 0); - j <- (W64.of_int 0); while ((i \ult (W64.of_int 128))) { - t <- a.[(W64.to_uint j)]; + t <- a.[(W64.to_uint ((W64.of_int 2) * i))]; d0 <- (zeroextu32 t); d0 <- (d0 `<<` (W8.of_int 4)); d0 <- (d0 + (W32.of_int 1665)); d0 <- (d0 * (W32.of_int 80635)); d0 <- (d0 `>>` (W8.of_int 28)); d0 <- (d0 `&` (W32.of_int 15)); - j <- (j + (W64.of_int 1)); - t <- a.[(W64.to_uint j)]; + t <- a.[(W64.to_uint (((W64.of_int 2) * i) + (W64.of_int 1)))]; d1 <- (zeroextu32 t); d1 <- (d1 `<<` (W8.of_int 4)); d1 <- (d1 + (W32.of_int 1665)); @@ -845,7 +842,6 @@ module M(SC:Syscall_t) = { d0 <- (d0 `|` d1); Glob.mem <- storeW8 Glob.mem (W64.to_uint (rp + i)) ((truncateu8 d0)); i <- (i + (W64.of_int 1)); - j <- (j + (W64.of_int 1)); } return (a); } diff --git a/code/jasmin/mlkem_ref/poly.jinc b/code/jasmin/mlkem_ref/poly.jinc index 0b69a270..9cee29bc 100644 --- a/code/jasmin/mlkem_ref/poly.jinc +++ b/code/jasmin/mlkem_ref/poly.jinc @@ -144,23 +144,21 @@ fn _poly_compress(reg u64 rp, reg ptr u16[MLKEM_N] a) -> reg ptr u16[MLKEM_N] { reg u16 t; reg u32 d0, d1; - reg u64 i j; + reg u64 i; a = _poly_csubq(a); i = 0; - j = 0; while(i < 128) { - t = a[(int)j]; + t = a[2 * i]; d0 = (32u)t; d0 <<= 4; d0 += 1665; d0 *= 80635; d0 >>= 28; d0 &= 0xf; - j += 1; - t = a[(int)j]; + t = a[2 * i + 1]; d1 = (32u)t; d1 <<= 4; d1 += 1665; @@ -171,7 +169,6 @@ fn _poly_compress(reg u64 rp, reg ptr u16[MLKEM_N] a) -> reg ptr u16[MLKEM_N] d0 |= d1; (u8)[rp+i] = d0; i += 1; - j += 1; } return a; } diff --git a/proof/correctness/MLKEM_Poly.ec b/proof/correctness/MLKEM_Poly.ec index 791a87f3..2e84ce92 100644 --- a/proof/correctness/MLKEM_Poly.ec +++ b/proof/correctness/MLKEM_Poly.ec @@ -782,25 +782,26 @@ lemma poly_compress_corr _a (_p : address) mem : touches mem Glob.mem{1} _p 128 /\ load_array128 Glob.mem{1} _p = res{2}]. proc => /=. -seq 3 3 : (#{/~a{1}}pre /\ to_uint i{1} = i{2} /\ i{2} = 0 /\ - to_uint j{1} = j{2} /\ j{2} = 0 /\ +seq 2 3 : (#{/~a{1}}pre /\ to_uint i{1} = i{2} /\ i{2} = 0 /\ + j{2} = 0 /\ pos_bound256_cxq a{1} 0 256 1 /\ lift_array256 a{1} = _a). wp => /=;call{1} (poly_csubq_corr _a); 1: by auto => /#. while (#{/~mem}{~i{2}=0}{~j{2}=0}pre /\ to_uint i{1} = i{2} /\ 0<=i{2}<=128 /\ - to_uint j{1} = j{2} /\ j{2} = 2* i{2} /\ + j{2} = 2* i{2} /\ touches mem Glob.mem{1} _p i{2} /\ forall k, 0<=k loadW8 Glob.mem{1} (_p + k) = r{2}.[k]); last first. + auto => /> &1 &2; rewrite ultE of_uintK /load_array32 /loadW8 /ptr /= => - vpl vph bnd ??; split; 1: by smt(). - move => mem' i' j' ra'; rewrite ultE of_uintK /= => exit _ ibl ibh jv touch load. + vpl vph bnd ?; split; 1: by smt(). + move => mem' i' ra'; rewrite ultE of_uintK /= => exit _ ibl ibh touch load. split; 1: smt(). by rewrite tP => k kb; rewrite initiE //= (load k _) /#. auto => /> &1 &2 ??; rewrite /pos_bound256_cxq /touches /storeW8 /loadW8 /=. -rewrite ultE of_uintK /= => ????????. -rewrite !to_uintD_small /=; 1..4: smt(). -do split; 1..4: by smt(get_set_neqE_s). +rewrite ultE of_uintK /= => ???X???. +rewrite !to_uintD_small ?to_uintM_small 1..5:/#. +do split; 1..3: smt(). ++ move => *; rewrite get_set_neqE_s 1:/#; apply: X; smt(). + move => k kbl kbh. case (k = to_uint i{1}); last first. + move => neq; rewrite get_set_neqE_s; 1: by smt(). @@ -815,9 +816,9 @@ do split; 1..4: by smt(get_set_neqE_s). case (k = to_uint i{1}); last by smt(Array128.set_neqiE). move => iv; have -> : 15 = 2^4 - 1 by auto. rewrite !and_mod //. - pose x := (((zeroextu32 a{1}.[to_uint j{1}] `<<` (of_int 4)%W8) + + pose x := (((zeroextu32 a{1}.[2 * to_uint i{1}] `<<` (of_int 4)%W8) + (of_int 1665)%W32) * (of_int 80635)%W32 `>>`(of_int 28)%W8). - pose y := (((zeroextu32 a{1}.[to_uint j{1} + 1] `<<` (of_int 4)%W8) + + pose y := (((zeroextu32 a{1}.[2 * to_uint i{1} + 1] `<<` (of_int 4)%W8) + (of_int 1665)%W32) * (of_int 80635)%W32 `>>` (of_int 28)%W8). rewrite to_uint_eq to_uint_truncateu8 !of_uintK to_uint_orw_disjoint. + apply W32.ext_eq => i ib; rewrite /W32.(`&`) map2E initiE //=.