Skip to content

Commit

Permalink
remove [nosmt] tags + fix proofs
Browse files Browse the repository at this point in the history
  • Loading branch information
strub committed Jul 21, 2024
1 parent 6d9e524 commit 02b5d9d
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 41 deletions.
2 changes: 1 addition & 1 deletion proof/correctness/MLKEMFCLib.ec
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ lemma initEq16 (f g: int -> 'a) :
(*-----------------------------------------------------------------------------*)
lemma nosmt set_neqiE (t : coeff Array256.t) x y a :
lemma set_neqiE (t : coeff Array256.t) x y a :
y <> x => t.[x <- a].[y] = t.[y].
proof. by rewrite get_set_if => /neqF ->. qed.
Expand Down
5 changes: 3 additions & 2 deletions proof/correctness/MLKEM_Poly.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1733,7 +1733,8 @@ have /= [#] redbl6 redbh6 redv6 :=
have /= [#] redbl7 redbh7 redv7 :=
(SREDCp_corr (to_sint r6 * to_sint (- jzetas.[64 + to_uint i{hr} %/ 4])) hq _).
+ rewrite /R /=; move : (zeta_bound (64 + to_uint i{hr} %/ 4)); rewrite /minimum_residues /bpos16 => zb.
rewrite to_sintN /=; do split; smt().
case/(_ _): zb => *; ~-1:smt().
by rewrite to_sintN /=; do split; smt().

have /= [#] redbl8 redbh8 redv8 :=
(SREDCp_corr (to_sint ap{hr}.[to_uint i{hr}+2] * to_sint bp{hr}.[to_uint i{hr}+2]) hq _).
Expand Down Expand Up @@ -1801,7 +1802,7 @@ case (k < to_uint i{hr} %/4).
+ move => kbb; move: (vprev k _); 1:smt(); rewrite !mapiE /=; 1..12:smt().
rewrite /doublemul /cmplx_mul_169 /=.
move => /> vprev0 vprev1 vprev2 vprev3.
by rewrite !set_neqiE /#.
by rewrite !set_neqiE //#.

move => *; have kval : (k = to_uint i{hr} %/ 4) by smt().
have -> : 4 * k = to_uint i{hr} by smt().
Expand Down
10 changes: 5 additions & 5 deletions proof/correctness/Montgomery.ec
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ op BREDC(a bits : int) =

require import Barrett_mlkem_general.

lemma nosmt BREDCp_corr a bits:
lemma BREDCp_corr a bits:
0 < 2 * q < R %/2 =>
R < 2^bits =>
2 ^ bits %/ q * q < 2 ^ bits =>
Expand Down Expand Up @@ -222,7 +222,7 @@ op SREDC (a: int) : int =
let t = smod (a - u %/ R * q) (R^2)in
smod (t %/ R %% (R^2)) R.

lemma nosmt SREDCp_corr a:
lemma SREDCp_corr a:
0 < q < R %/2 =>
-R %/ 2 * q <= a < R %/2 * q =>
-q <= SREDC a < q /\
Expand Down Expand Up @@ -353,7 +353,7 @@ op REDC' (T: int) : int =
let m = ((T %% R)*_N') %% R
in (T + m*_N) %/ R.

lemma nosmt aux_divR T:
lemma aux_divR T:
let m = ((T %% R)*_N') %% R
in (T + m*_N) %% R = 0.
proof.
Expand All @@ -364,7 +364,7 @@ smt().
qed.


lemma nosmt REDC'_congr T:
lemma REDC'_congr T:
REDC' T %% _N = T * Rinv %% _N.
proof.
pose m := ((T %% R)*_N') %% R.
Expand All @@ -378,7 +378,7 @@ have t_modN: t %% _N = T*Rinv %% _N.
by rewrite /REDC'.
qed.

lemma nosmt REDC'_bnds T n:
lemma REDC'_bnds T n:
0 <= n =>
0 <= T < _N + _N * R^(n+1) =>
0 <= REDC' T < _N + _N*R^n.
Expand Down
58 changes: 29 additions & 29 deletions proof/correctness/Montgomery16.ec
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require import MLKEMFCLib.

(* @JBA: MOVE THIS *)

lemma nosmt modzB_eq0 (x y m:int):
lemma modzB_eq0 (x y m:int):
0 < m => (x-y) %% m = 0 =>
x%%m = y%%m.
proof.
Expand Down Expand Up @@ -53,7 +53,7 @@ case: ( 2 ^ (16 - 1) <= x) => ?//.
by rewrite -modzDmr -modzNm modzz.
qed.

lemma nosmt to_sint_mod x:
lemma to_sint_mod x:
W16.to_sint x %% W16.modulus = to_uint x.
proof.
rewrite /to_sint /smod.
Expand All @@ -69,7 +69,7 @@ lemma to_sintK (w : W16.t) :
W16.of_int (W16.to_sint w) = w.
proof. by rewrite -of_int_mod to_sint_mod to_uintK //. qed.

lemma nosmt to_sint_eq (w1 w2: W16.t):
lemma to_sint_eq (w1 w2: W16.t):
to_sint w1 = to_sint w2 <=> w1=w2.
proof.
rewrite !to_sintE /smod /=.
Expand All @@ -82,7 +82,7 @@ case: (32768 <= to_uint w2) => CB /=.
by rewrite to_uint_eq.
qed.

lemma nosmt smod_small (x: int):
lemma smod_small (x: int):
-2^(16-1) <= x < 2^(16-1) =>
W16.smod (x %% W16.modulus) = x.
proof.
Expand All @@ -92,7 +92,7 @@ rewrite /smod; case: (x < 0) => C.
by have ->/#: ! 2 ^ (16 - 1) <= x %% W16.modulus by smt().
qed.

lemma nosmt modzM_sint (x y: W16.t):
lemma modzM_sint (x y: W16.t):
(to_sint x * to_sint y) %% W16.modulus
= (to_uint x * to_uint y) %% W16.modulus.
proof.
Expand All @@ -110,14 +110,14 @@ by rewrite modzMm.
done.
qed.

lemma nosmt to_sintM (x y: W16.t):
lemma to_sintM (x y: W16.t):
to_sint (x*y)
= W16.smod ((to_sint x * to_sint y) %% W16.modulus).
proof.
by rewrite {1}/W16.to_sint to_uintM modzM_sint.
qed.

lemma nosmt modzD_sint (x y: W16.t):
lemma modzD_sint (x y: W16.t):
(to_sint x + to_sint y) %% W16.modulus
= (to_uint x + to_uint y) %% W16.modulus.
proof.
Expand All @@ -126,22 +126,22 @@ case: (2 ^ (16 - 1) <= to_uint x);
case: (2 ^ (16 - 1) <= to_uint y); smt().
qed.

lemma nosmt to_sintD (x y: W16.t):
lemma to_sintD (x y: W16.t):
to_sint (x+y)
= W16.smod ((to_sint x + to_sint y)%%W16.modulus).
proof.
by rewrite {1}/W16.to_sint to_uintD modzD_sint.
qed.

lemma nosmt modzN_sint (x: W16.t):
lemma modzN_sint (x: W16.t):
(- to_sint x) %% W16.modulus
= (- to_uint x) %% W16.modulus.
proof.
rewrite /to_sint /smod.
case: (2 ^ (16 - 1) <= to_uint x); smt().
qed.

lemma nosmt to_sintN (x: W16.t):
lemma to_sintN (x: W16.t):
to_sint (-x)
= W16.smod ((-to_sint x) %% W16.modulus).
proof.
Expand All @@ -164,7 +164,7 @@ have X: 32768 <= (- to_uint x) %% 65536 by smt().
smt().
qed.

lemma nosmt modzB_sint (x y: W16.t):
lemma modzB_sint (x y: W16.t):
(to_sint x - to_sint y) %% W16.modulus
= (to_uint x - to_uint y) %% W16.modulus.
proof.
Expand All @@ -173,14 +173,14 @@ case: (2 ^ (16 - 1) <= to_uint x);
case: (2 ^ (16 - 1) <= to_uint y); smt().
qed.

lemma nosmt to_sintB (x y: W16.t):
lemma to_sintB (x y: W16.t):
to_sint (x-y)
= W16.smod ((to_sint x - to_sint y)%%W16.modulus).
proof.
rewrite {1}/W16.to_sint to_uintD to_uintN modzB_sint /#.
qed.

lemma nosmt wmulsE (x y: W16.t):
lemma wmulsE (x y: W16.t):
to_sint x * to_sint y
= to_sint (wmulhs x y) * W16.modulus
+ to_uint (x * y).
Expand All @@ -193,7 +193,7 @@ have /=?:= to_sint_cmp y.
smt().
qed.

lemma nosmt to_sint_wmulhs x y:
lemma to_sint_wmulhs x y:
to_sint (W16.wmulhs x y) = to_sint x * to_sint y %/ W16.modulus.
proof.
rewrite wmulsE divzDl 1:/# mulzK 1:/#; ring.
Expand Down Expand Up @@ -221,15 +221,15 @@ lemma sint_bndW (x: W16.t) (xL1 xH1 xL2 xH2: int):
sint_bnd xL2 xH2 x
by smt().

lemma nosmt to_sintD_small (x y: W16.t):
lemma to_sintD_small (x y: W16.t):
W16.min_sint <= to_sint x + to_sint y <= W16.max_sint =>
to_sint (x+y) = to_sint x + to_sint y.
proof.
move=> /=?; rewrite to_sintD smod_small /= /#.
qed.

(* a version of [to_sintD_small] with bounds *)
lemma nosmt to_sintD_small' (xL xH yL yH: int) (x y: W16.t):
lemma to_sintD_small' (xL xH yL yH: int) (x y: W16.t):
sint_bnd xL xH x =>
sint_bnd yL yH y =>
W16.min_sint <= xL+yL =>
Expand All @@ -239,7 +239,7 @@ proof.
by move=> /= *; rewrite to_sintD_small /#.
qed.
lemma nosmt to_sintN_small (x: W16.t):
lemma to_sintN_small (x: W16.t):
W16.min_sint < to_sint x =>
to_sint (-x) = - to_sint x.
proof.
Expand Down Expand Up @@ -384,7 +384,7 @@ qed.
abbrev sint32_bnd xL xH x =
xL <= W32.to_sint x <= xH.
lemma nosmt modzD_sint32 (x y: W32.t):
lemma modzD_sint32 (x y: W32.t):
(to_sint x + to_sint y) %% W32.modulus
= (to_uint x + to_uint y) %% W32.modulus.
proof.
Expand All @@ -393,14 +393,14 @@ case: (2 ^ (32 - 1) <= to_uint x);
case: (2 ^ (32 - 1) <= to_uint y); smt().
qed.
lemma nosmt to_sint32D (x y: W32.t):
lemma to_sint32D (x y: W32.t):
to_sint (x+y)
= W32.smod ((to_sint x + to_sint y)%%W32.modulus).
proof.
by rewrite {1}/W32.to_sint to_uintD modzD_sint32.
qed.
lemma nosmt smod32_small (x: int):
lemma smod32_small (x: int):
-2^(32-1) <= x < 2^(32-1) =>
W32.smod (x %% W32.modulus) = x.
proof.
Expand All @@ -410,7 +410,7 @@ rewrite /smod; case: (x < 0) => C.
by have ->/#: ! 2 ^ (32 - 1) <= x %% W32.modulus by smt().
qed.
lemma nosmt to_sint32D_small (x y: W32.t):
lemma to_sint32D_small (x y: W32.t):
W32.min_sint <= to_sint x + to_sint y <= W32.max_sint =>
to_sint (x+y) = to_sint x + to_sint y.
proof.
Expand All @@ -427,22 +427,22 @@ proof.
by move=> /> *; rewrite to_sint32D_small /#.
qed.
(*
lemma nosmt modzN_sint32 (x: W32.t):
lemma modzN_sint32 (x: W32.t):
(- to_sint x) %% W32.modulus
= (- to_uint x) %% W32.modulus.
proof.
rewrite /to_sint /smod.
case: (2 ^ (32 - 1) <= to_uint x); smt().
qed.
lemma nosmt to_sint32N (x: W32.t):
lemma to_sint32N (x: W32.t):
to_sint (-x)
= W32.smod ((-to_sint x) %% W32.modulus).
proof.
by rewrite {1}/W32.to_sint to_uintN modzN_sint32.
qed.
*)
lemma nosmt modzB_sint32 (x y: W32.t):
lemma modzB_sint32 (x y: W32.t):
(to_sint x - to_sint y) %% W32.modulus
= (to_uint x - to_uint y) %% W32.modulus.
proof.
Expand All @@ -451,7 +451,7 @@ case: (2 ^ (32 - 1) <= to_uint x);
case: (2 ^ (32 - 1) <= to_uint y); smt().
qed.
lemma nosmt to_sint32B (x y: W32.t):
lemma to_sint32B (x y: W32.t):
to_sint (x-y)
= W32.smod ((to_sint x - to_sint y)%%W32.modulus).
proof.
Expand Down Expand Up @@ -512,7 +512,7 @@ rewrite -{1}unpack16K /unpack16 /=; congr.
by rewrite init_of_list -JUtils.iotaredE /=.
qed.
lemma nosmt modz_sint32 (x: W32.t):
lemma modz_sint32 (x: W32.t):
(to_sint x) %% W16.modulus
= (to_uint x) %% W16.modulus.
proof.
Expand Down Expand Up @@ -567,7 +567,7 @@ op REDC16 (xyL xyH: W16.t): W16.t =
in xyH - (wmulhs m (W16.of_int q)).
(* general bounds... *)
lemma nosmt REDC16_correct bL bR (xyL xyH: W16.t):
lemma REDC16_correct bL bR (xyL xyH: W16.t):
W16.min_sint + 1664 <= bL <= 0 =>
0 <= bR <= W16.max_sint - 1665 =>
sint_bnd bL bR xyH =>
Expand Down Expand Up @@ -608,7 +608,7 @@ by rewrite eq_sym -modzDmr -Domain.mulNr -modzMm modzz mod0z.
qed.
(* useful specific case *)
lemma nosmt REDC16_correct_q (xyL xyH: W16.t):
lemma REDC16_correct_q (xyL xyH: W16.t):
sint_bnd (-q%/2) (q%/2 - 1) xyH =>
to_sint (REDC16 xyL xyH) %% q
= (to_sint xyH * R + to_uint xyL) * Rinv %% q
Expand All @@ -625,7 +625,7 @@ abbrev REDCmul16 (x y: W16.t): W16.t = REDC16 (x*y) (wmulhs x y).
(* correctness result for multiplication, for
the specific case of a reduced argument *)
lemma nosmt REDCmul16_correct (x y: W16.t):
lemma REDCmul16_correct (x y: W16.t):
sint_bnd 0 (q-1) y =>
to_sint (REDCmul16 x y) %% q
= to_sint x * to_sint y * Rinv %% q
Expand Down
2 changes: 1 addition & 1 deletion proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ move: (noise_coef_avx2_aux bytes j) => /=.
by rewrite /noise_coef C' to_sintE /smod to_uint_shr //= => <- /#.
qed.

lemma nosmt to_sint8_mod x:
lemma to_sint8_mod x:
W8.to_sint x %% W8.modulus = to_uint x.
proof.
rewrite /to_sint /smod.
Expand Down
8 changes: 6 additions & 2 deletions proof/correctness/avx2/NTT_AVX_j.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,11 @@ lemma wmuls16P n x y _x _y:
Iu16_sb n x _x =>
Iu16_sb n y _y =>
sint32_bnd (-n*n*q*q) (n*n*q*q) (wmuls16 x y).
proof. by move => [??] [??]; rewrite to_sint_wmuls16 /#. qed.
proof.
move=> [? [??]] [? [??]]; rewrite to_sint_wmuls16.
have ->: (n * n * q * q) = (n * q) * (n * q) by ring.
by rewrite &(ler_norml) normrM ler_pmul // 1,2:normr_ge0 /#.
qed.
phoare wmul_16u16_ph n _x _y:
[Jkem_avx2.M(Jkem_avx2.Syscall).__wmul_16u16:
Expand Down Expand Up @@ -1506,7 +1510,7 @@ qed.
(** Butterfly *)
lemma nosmt REDCmul16coeff (x y: W16.t):
lemma REDCmul16coeff (x y: W16.t):
sint_bnd 0 (q-1) y =>
incoeffW16 (Montgomery16.REDCmul16 x y)
= incoeffW16 x * incoeffW16 y * incoeff Montgomery16.Rinv
Expand Down

0 comments on commit 02b5d9d

Please sign in to comment.