Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

poly_compress: one counter #21

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -816,25 +816,22 @@ module M(SC:Syscall_t) = {
proc _poly_compress (rp:W64.t, a:W16.t Array256.t) : W16.t Array256.t = {

var i:W64.t;
var j:W64.t;
var t:W16.t;
var d0:W32.t;
var d1:W32.t;

a <@ _poly_csubq (a);
i <- (W64.of_int 0);
j <- (W64.of_int 0);

while ((i \ult (W64.of_int 128))) {
t <- a.[(W64.to_uint j)];
t <- a.[(W64.to_uint ((W64.of_int 2) * i))];
d0 <- (zeroextu32 t);
d0 <- (d0 `<<` (W8.of_int 4));
d0 <- (d0 + (W32.of_int 1665));
d0 <- (d0 * (W32.of_int 80635));
d0 <- (d0 `>>` (W8.of_int 28));
d0 <- (d0 `&` (W32.of_int 15));
j <- (j + (W64.of_int 1));
t <- a.[(W64.to_uint j)];
t <- a.[(W64.to_uint (((W64.of_int 2) * i) + (W64.of_int 1)))];
d1 <- (zeroextu32 t);
d1 <- (d1 `<<` (W8.of_int 4));
d1 <- (d1 + (W32.of_int 1665));
Expand All @@ -845,7 +842,6 @@ module M(SC:Syscall_t) = {
d0 <- (d0 `|` d1);
Glob.mem <- storeW8 Glob.mem (W64.to_uint (rp + i)) ((truncateu8 d0));
i <- (i + (W64.of_int 1));
j <- (j + (W64.of_int 1));
}
return (a);
}
Expand Down
9 changes: 3 additions & 6 deletions code/jasmin/mlkem_ref/poly.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,21 @@ fn _poly_compress(reg u64 rp, reg ptr u16[MLKEM_N] a) -> reg ptr u16[MLKEM_N]
{
reg u16 t;
reg u32 d0, d1;
reg u64 i j;
reg u64 i;

a = _poly_csubq(a);

i = 0;
j = 0;
while(i < 128)
{
t = a[(int)j];
t = a[2 * i];
d0 = (32u)t;
d0 <<= 4;
d0 += 1665;
d0 *= 80635;
d0 >>= 28;
d0 &= 0xf;
j += 1;
t = a[(int)j];
t = a[2 * i + 1];
d1 = (32u)t;
d1 <<= 4;
d1 += 1665;
Expand All @@ -171,7 +169,6 @@ fn _poly_compress(reg u64 rp, reg ptr u16[MLKEM_N] a) -> reg ptr u16[MLKEM_N]
d0 |= d1;
(u8)[rp+i] = d0;
i += 1;
j += 1;
}
return a;
}
Expand Down
21 changes: 11 additions & 10 deletions proof/correctness/MLKEM_Poly.ec
Original file line number Diff line number Diff line change
Expand Up @@ -782,25 +782,26 @@ lemma poly_compress_corr _a (_p : address) mem :
touches mem Glob.mem{1} _p 128 /\
load_array128 Glob.mem{1} _p = res{2}].
proc => /=.
seq 3 3 : (#{/~a{1}}pre /\ to_uint i{1} = i{2} /\ i{2} = 0 /\
to_uint j{1} = j{2} /\ j{2} = 0 /\
seq 2 3 : (#{/~a{1}}pre /\ to_uint i{1} = i{2} /\ i{2} = 0 /\
j{2} = 0 /\
pos_bound256_cxq a{1} 0 256 1 /\ lift_array256 a{1} = _a).
wp => /=;call{1} (poly_csubq_corr _a); 1: by auto => /#.

while (#{/~mem}{~i{2}=0}{~j{2}=0}pre /\ to_uint i{1} = i{2} /\ 0<=i{2}<=128 /\
to_uint j{1} = j{2} /\ j{2} = 2* i{2} /\
j{2} = 2* i{2} /\
touches mem Glob.mem{1} _p i{2} /\
forall k, 0<=k<i{2} => loadW8 Glob.mem{1} (_p + k) = r{2}.[k]); last first.
+ auto => /> &1 &2; rewrite ultE of_uintK /load_array32 /loadW8 /ptr /= =>
vpl vph bnd ??; split; 1: by smt().
move => mem' i' j' ra'; rewrite ultE of_uintK /= => exit _ ibl ibh jv touch load.
vpl vph bnd ?; split; 1: by smt().
move => mem' i' ra'; rewrite ultE of_uintK /= => exit _ ibl ibh touch load.
split; 1: smt().
by rewrite tP => k kb; rewrite initiE //= (load k _) /#.

auto => /> &1 &2 ??; rewrite /pos_bound256_cxq /touches /storeW8 /loadW8 /=.
rewrite ultE of_uintK /= => ????????.
rewrite !to_uintD_small /=; 1..4: smt().
do split; 1..4: by smt(get_set_neqE_s).
rewrite ultE of_uintK /= => ???X???.
rewrite !to_uintD_small ?to_uintM_small 1..5:/#.
do split; 1..3: smt().
+ move => *; rewrite get_set_neqE_s 1:/#; apply: X; smt().
+ move => k kbl kbh.
case (k = to_uint i{1}); last first.
+ move => neq; rewrite get_set_neqE_s; 1: by smt().
Expand All @@ -815,9 +816,9 @@ do split; 1..4: by smt(get_set_neqE_s).
case (k = to_uint i{1}); last by smt(Array128.set_neqiE).
move => iv; have -> : 15 = 2^4 - 1 by auto.
rewrite !and_mod //.
pose x := (((zeroextu32 a{1}.[to_uint j{1}] `<<` (of_int 4)%W8) +
pose x := (((zeroextu32 a{1}.[2 * to_uint i{1}] `<<` (of_int 4)%W8) +
(of_int 1665)%W32) * (of_int 80635)%W32 `>>`(of_int 28)%W8).
pose y := (((zeroextu32 a{1}.[to_uint j{1} + 1] `<<` (of_int 4)%W8) +
pose y := (((zeroextu32 a{1}.[2 * to_uint i{1} + 1] `<<` (of_int 4)%W8) +
(of_int 1665)%W32) * (of_int 80635)%W32 `>>` (of_int 28)%W8).
rewrite to_uint_eq to_uint_truncateu8 !of_uintK to_uint_orw_disjoint.
+ apply W32.ext_eq => i ib; rewrite /W32.(`&`) map2E initiE //=.
Expand Down