Skip to content

Commit

Permalink
keypair: spare a few additions
Browse files Browse the repository at this point in the history
  • Loading branch information
vbgl committed Mar 27, 2024
1 parent 3bd9229 commit 643103e
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 98 deletions.
17 changes: 7 additions & 10 deletions code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -4987,34 +4987,31 @@ module M(SC:Syscall_t) = {
aux <- (((3 * 384) + 32) %/ 8);
i <- 0;
while (i < aux) {
t64 <- (loadW64 Glob.mem (W64.to_uint (pkp + (W64.of_int (8 * i)))));
Glob.mem <-
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int 0))) (t64);
skp <- (skp + (W64.of_int 8));
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int (8 * i)))) ((loadW64 Glob.mem (W64.to_uint (pkp + (W64.of_int (8 * i))))));
i <- i + 1;
}
s_skp <- skp;
s_skp <- (s_skp + (W64.of_int ((3 * 384) + ((3 * 384) + 32))));
pkp <- s_pkp;
t64 <- (W64.of_int ((3 * 384) + 32));
h_pk <@ _isha3_256 (h_pk, pkp, t64);
skp <- s_skp;
i <- 0;
while (i < 4) {
t64 <- (get64 (WArray32.init8 (fun i_0 => h_pk.[i_0])) i);
Glob.mem <-
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int 0))) (t64);
skp <- (skp + (W64.of_int 8));
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int (8 * i)))) ((get64
(WArray32.init8 (fun i_0 => h_pk.[i_0]))
i));
i <- i + 1;
}
randomnessp <- s_randomnessp;
randomnessp2 <- (Array32.init (fun i_0 => randomnessp.[32 + i_0]));
aux <- (32 %/ 8);
i <- 0;
while (i < aux) {
t64 <- (get64 (WArray32.init8 (fun i_0 => randomnessp2.[i_0])) i);
Glob.mem <-
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int 0))) (t64);
skp <- (skp + (W64.of_int 8));
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int ((8 * i) + 32)))) (
(get64 (WArray32.init8 (fun i_0 => randomnessp2.[i_0])) i));
i <- i + 1;
}
return ();
Expand Down
14 changes: 4 additions & 10 deletions code/jasmin/mlkem_avx2/kem.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,25 @@ fn __crypto_kem_keypair_jazz(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES

for i=0 to MLKEM_INDCPA_PUBLICKEYBYTES/8
{
t64 = (u64)[pkp + 8*i];
(u64)[skp] = t64;
skp += 8;
(u64)[skp + 8 * i] = (u64)[pkp + 8 * i];
}

s_skp = skp;
s_skp += MLKEM_POLYVECBYTES + MLKEM_INDCPA_PUBLICKEYBYTES;
pkp = s_pkp;
t64 = MLKEM_PUBLICKEYBYTES;
h_pk = _isha3_256(h_pk, pkp, t64);
skp = s_skp;

for i=0 to 4
{
t64 = h_pk[u64 i];
(u64)[skp] = t64;
skp += 8;
(u64)[skp + 8 * i] = h_pk[u64 i];
}

randomnessp = s_randomnessp;
randomnessp2 = randomnessp[MLKEM_SYMBYTES:MLKEM_SYMBYTES];
for i=0 to MLKEM_SYMBYTES/8
{
t64 = randomnessp2[u64 i];
(u64)[skp] = t64;
skp += 8;
(u64)[skp + 8 * i + 32] = randomnessp2[u64 i];
}
}

Expand Down
17 changes: 7 additions & 10 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -2214,34 +2214,31 @@ module M(SC:Syscall_t) = {
aux <- (((3 * 384) + 32) %/ 8);
i <- 0;
while (i < aux) {
t64 <- (loadW64 Glob.mem (W64.to_uint (pkp + (W64.of_int (8 * i)))));
Glob.mem <-
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int 0))) (t64);
skp <- (skp + (W64.of_int 8));
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int (8 * i)))) ((loadW64 Glob.mem (W64.to_uint (pkp + (W64.of_int (8 * i))))));
i <- i + 1;
}
s_skp <- skp;
s_skp <- (s_skp + (W64.of_int ((3 * 384) + ((3 * 384) + 32))));
pkp <- s_pkp;
t64 <- (W64.of_int ((3 * 384) + 32));
h_pk <@ _isha3_256 (h_pk, pkp, t64);
skp <- s_skp;
i <- 0;
while (i < 4) {
t64 <- (get64 (WArray32.init8 (fun i_0 => h_pk.[i_0])) i);
Glob.mem <-
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int 0))) (t64);
skp <- (skp + (W64.of_int 8));
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int (8 * i)))) ((get64
(WArray32.init8 (fun i_0 => h_pk.[i_0]))
i));
i <- i + 1;
}
randomnessp <- s_randomnessp;
randomnessp2 <- (Array32.init (fun i_0 => randomnessp.[32 + i_0]));
aux <- (32 %/ 8);
i <- 0;
while (i < aux) {
t64 <- (get64 (WArray32.init8 (fun i_0 => randomnessp2.[i_0])) i);
Glob.mem <-
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int 0))) (t64);
skp <- (skp + (W64.of_int 8));
storeW64 Glob.mem (W64.to_uint (skp + (W64.of_int ((8 * i) + 32)))) (
(get64 (WArray32.init8 (fun i_0 => randomnessp2.[i_0])) i));
i <- i + 1;
}
return ();
Expand Down
14 changes: 4 additions & 10 deletions code/jasmin/mlkem_ref/kem.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,25 @@ fn __crypto_kem_keypair_jazz(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES

for i=0 to MLKEM_INDCPA_PUBLICKEYBYTES/8
{
t64 = (u64)[pkp + 8*i];
(u64)[skp] = t64;
skp += 8;
(u64)[skp + 8 * i] = (u64)[pkp + 8 * i];
}

s_skp = skp;
s_skp += MLKEM_POLYVECBYTES + MLKEM_INDCPA_PUBLICKEYBYTES;
pkp = s_pkp;
t64 = MLKEM_POLYVECBYTES + MLKEM_SYMBYTES;
h_pk = _isha3_256(h_pk, pkp, t64);
skp = s_skp;

for i=0 to 4
{
t64 = h_pk[u64 i];
(u64)[skp] = t64;
skp += 8;
(u64)[skp + 8 * i] = h_pk[u64 i];
}

randomnessp = s_randomnessp;
randomnessp2 = randomnessp[MLKEM_SYMBYTES:MLKEM_SYMBYTES];
for i=0 to MLKEM_SYMBYTES/8
{
t64 = randomnessp2[u64 i];
(u64)[skp] = t64;
skp += 8;
(u64)[skp + 8 * i + 32] = randomnessp2[u64 i];
}
}

Expand Down
59 changes: 30 additions & 29 deletions proof/correctness/MLKEM_KEM.ec
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ swap {1} 1 14.

seq 19 4 : (
z{2} =Array32.init (fun i => randomnessp{1}.[32 + i]) /\
to_uint skp{1} = _skp + 1152 + 1152 + 32 + 32 /\
to_uint skp{1} = _skp + 1152 + 1152 + 32 /\
valid_disj_reg _pkp (384*3+32) _skp (384*3 + 384*3 + 32 + 32 + 32 + 32) /\
touches2 Glob.mem{1} mem _pkp (384 * 3 + 32) _skp (384*3 + 384*3 + 32 + 32 + 32 + 32) /\
sk{2} = load_array1152 Glob.mem{1} _skp /\
Expand All @@ -79,7 +79,7 @@ seq 19 4 : (
); last first.
+ while {1} (aux{1} = 4 /\
z{2} = Array32.init (fun i => randomnessp2{1}.[i]) /\
to_uint skp{1} = _skp + 1152 + 1152 + 32 + 32 + i{1}*8 /\
to_uint skp{1} = _skp + 1152 + 1152 + 32 /\
valid_disj_reg _pkp (384*3+32) _skp (384*3 + 384*3 + 32 + 32 + 32 + 32) /\
touches2 Glob.mem{1} mem _pkp (384 * 3 + 32) _skp (384*3 + 384*3 + 32 + 32 + 32 + 32) /\
sk{2} = load_array1152 Glob.mem{1} _skp /\
Expand All @@ -94,9 +94,8 @@ seq 19 4 : (
pack8_t (W8u8.Pack.init (fun i => z{2}.[k*8+i])))
(4 - i{1}).
+ move => &m z0; auto => /> &hr; rewrite /touches2 /load_array1152 /load_array32 !tP =>
skv ????? touch pk1vs pk2vs pk1v pk2v ??prev? ; rewrite !to_uintD_small /=.
+ by smt().
do split; 1,9,12: by smt().
skv ????? touch pk1vs pk2vs pk1v pk2v ??prev? ; rewrite !to_uintD_small !to_uint_small /= 1..3:/#.
do split.
+ by move => a H1 H2; rewrite /storeW64 /loadW64 /stores /= !get_set_neqE_s /#.
+ by move => k kb; rewrite !initiE //= /storeW64 /loadW64 /stores /= !get_set_neqE_s /#.
+ move => k kb; rewrite !initiE //= /storeW64 /loadW64
Expand All @@ -114,6 +113,7 @@ seq 19 4 : (
/stores /= !get_set_neqE_s; 1..8: smt().
by rewrite pk2v // initiE //=.
+ by smt().
+ by smt().
+ move => k kbl kbh.
case (k < i{hr}).
+ move => hk.
Expand All @@ -132,13 +132,14 @@ seq 19 4 : (
rewrite WArray32.WArray32.get64E !pack8bE 1..8:/# !initiE 1..8:/# /= /init8.
rewrite !WArray32.WArray32.initiE 1..8:/#.
by smt(get_set_neqE_s get_set_eqE_s).
by smt().

auto => />;move => ????????touch????; do split.
+ rewrite tP => k kb; rewrite !initiE /#.
+ smt().
move => memL iL skpL.
move => memL iL.
split; 1: smt().
move => ???touch2????????store???.
move => ???touch2???????store???.
rewrite /load_array32 tP => k kb.
rewrite !initiE //=.
move : (store (k %/ 8) _); 1: smt().
Expand All @@ -152,7 +153,7 @@ swap {2} 2 2.
swap {1} [2..3] 1; sp 4 1.

wp;conseq (_: _ ==>
to_uint skp{1} = _skp + 2368 /\
to_uint skp{1} = _skp + 2336 /\
touches2 Glob.mem{1} mem (_pkp) 1184 (_skp) 2432 /\
sk{2} = load_array1152 Glob.mem{1} (_skp) /\
pk{2}.`1 = load_array1152 Glob.mem{1} (_skp + 1152) /\
Expand Down Expand Up @@ -185,19 +186,17 @@ seq 8 0 : (#{/~to_uint skp{1} = _skp}pre /\
pk{2}.`2 = load_array32 Glob.mem{1} (_skp+ 2304)
).

+ wp;while {1} (#{/~to_uint skp{1} = _skp}{~s_skp{1} = skp{1}}pre /\
+ wp; while {1} (#{/~to_uint skp{1} = _skp}{~s_skp{1} = skp{1}}pre /\
aux{1} = (3 * 384 + 32) %/ 8 /\ 0<=i{1} <= aux{1} /\
to_uint skp{1} = _skp + 3*384 + i{1}*8 /\
to_uint skp{1} = _skp + 3*384 /\
(forall k, 0<= k < min (8 * i{1}) 1152 =>
pk{2}.`1.[k] = Glob.mem{1}.[_skp + 3*384 + k]) /\
(forall k, min (8 * i{1}) 1152 <= k < min (8 * i{1}) (1152 + 32) =>
pk{2}.`2.[k-1152] = Glob.mem{1}.[_skp + 3*384 + k]))
((3 * 384 + 32) %/ 8 - i{1}).
move => &m z;auto => /> &hr. rewrite /load_array32 /load_array1152 !tP /touches2.
move => ???????touch pkv1 pkv2???prev1 prev2 ?;rewrite !to_uintD_small /=.
+ by rewrite of_uintK /= modz_small /=; smt().
by smt().
do split; 5..7:smt().
move => &m z;auto => /> &hr. rewrite !tP /touches2.
move => ???????touch pkv1 pkv2???prev1 prev2 ?; rewrite !to_uintD_small !to_uint_small /=; 1..5: smt().
do split.
+ move => i ib ibb.
rewrite /storeW64 /loadW64 /stores /=.
rewrite !get_set_neqE_s; 1..8:smt().
Expand All @@ -214,12 +213,13 @@ seq 8 0 : (#{/~to_uint skp{1} = _skp}pre /\
rewrite /storeW64 /loadW64 /stores /=.
rewrite !get_set_neqE_s; 1..8:smt().
by move : (pkv2 k kb); rewrite initiE //=.
+ smt().
+ smt().
+ move => kk kkbl kkbh.
rewrite /storeW64 /loadW64 /stores /=.
case (kk < i{hr} * 8).
+ by move => *; rewrite !get_set_neqE_s; smt().
move => ?.
rewrite !of_uintK /= modz_small;1:smt().
move : (pkv1 kk _); 1: smt().
rewrite initiE /=; 1: smt().
by smt(get_set_neqE_s get_set_eqE_s).
Expand All @@ -228,21 +228,24 @@ seq 8 0 : (#{/~to_uint skp{1} = _skp}pre /\
case (kk < i{hr} * 8).
+ by move => *; rewrite !get_set_neqE_s; smt().
move => ?.
rewrite !of_uintK /= modz_small;1:smt().
move : (pkv2 (kk - 1152) _); 1: smt().
rewrite initiE /=; 1: smt().
by smt(get_set_neqE_s get_set_eqE_s).
+ by smt().
auto => /> &1 &2; rewrite /load_array1152 /load_array32 !tP.

auto => /> &1 &2; rewrite !tP.
move => ??????????.
rewrite to_uintD_small /=; 1: by smt().
do split; 1..2: by smt().
move => meml il skpl.
move => meml il.
rewrite !tP; split; 1: smt().
move => ??????????; do split; 1: smt().
move => ????????X; do split.
+ rewrite to_uintD_small //= /#.
+ by move => *; rewrite initiE //= /#.
by move => *; rewrite initiE //= /#.

move => j hj.
have {1}-> : j = (j + 1152) - 1152 by ring.
rewrite X 1:/# initiE /#.

seq 4 1 :
(to_uint skp{1} = _skp + 2336 /\
valid_disj_reg _pkp 1184 _skp 2432 /\
Expand All @@ -258,10 +261,10 @@ inline *; auto => /> &1 &2; rewrite /touches /touches2 /load_array1152 /load_arr
+ move => i ib; congr; rewrite /H_pk; congr.
by smt(Array32.initiE Array1152.initiE Array32.tP Array1152.tP).

while {1} (#{/~to_uint skp{1} = _skp + 2336}pre /\ 0 <= i{1} <= 4 /\ to_uint skp{1} = _skp + 2336 + 8*i{1} /\ forall k, 0 <= k < i{1} * 8 => Glob.mem{1}.[_skp + 2336 + k] = (H_pk pk{2}).[k]) (4 - i{1}).
while {1} (#pre /\ 0 <= i{1} <= 4 /\ forall k, 0 <= k < i{1} * 8 => Glob.mem{1}.[_skp + 2336 + k] = (H_pk pk{2}).[k]) (4 - i{1}).
move => &m z; auto => /> &1 &2; rewrite /load_array1152 /load_array32 /touches2 !tP.
move => ?????pkv1s pkv2s pkv1 pkv2 ??? prev ?.
rewrite !to_uintD_small /= 1:/#.
move => ??????pkv1s pkv2s pkv1 pkv2 ?? prev ?.
rewrite !to_uintD_small /= !to_uint_small; 1..3: smt().
do split.
+ move => i ib ih.
rewrite /storeW64 /loadW64 /stores /=.
Expand All @@ -287,7 +290,6 @@ do split.
by move : (pkv2 i ib); rewrite initiE //=.
+ by smt().
+ by smt().
+ by smt().
+ move => k kbl kbh.
case (k < (i{1} * 8)).
+ by move => kl;rewrite /storeW64 /loadW64 /stores /= !get_set_neqE_s;smt().
Expand All @@ -306,9 +308,8 @@ do split.
+ move => i ib ih.
rewrite /storeW64 /loadW64 /stores /=.
by smt(get_set_neqE_s get_set_eqE_s).
move => memL iL skL; do split; 1: by smt().
move => *; split; 1: by smt().
by rewrite tP => i ib; smt(Array32.initiE).
move => memL iL; split; 1: by smt().
by move => 9? X; rewrite Array32.tP => *; rewrite -X 1:/# initiE.
qed.

lemma mlkem_kem_correct_enc mem _ctp _pkp _kp :
Expand Down
Loading

0 comments on commit 643103e

Please sign in to comment.