Skip to content

Commit

Permalink
Removing admits
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Nov 1, 2024
1 parent 2181c6f commit e9ae902
Showing 1 changed file with 161 additions and 27 deletions.
188 changes: 161 additions & 27 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1766,7 +1766,7 @@ realize get_out by smt(Array4.get_out).
op init_256_16 (f: int -> W16.t) : W16.t Array256.t = Array256.init f.

bind op [W16.t & Array256.t] init_256_16 "ainit".
realize bvainitP by admit. (* Not provable *)
realize bvainitP by admit. (* Not provable bvinit has no semantics *)


op init_768_16 (f: int -> W16.t) : W16.t Array768.t = Array768.init f.
Expand All @@ -1792,27 +1792,27 @@ realize bvainitP by admit. (* Not provable *)
op sliceget256_16_256 (arr: W16.t Array256.t) (i: int) : W256.t = WArray512.get256 (WArray512.init16 (fun (i_0 : int) => arr.[i_0])) (i %/ 256).

bind op [W16.t & W256.t & Array256.t] sliceget256_16_256 "asliceget".
realize bvaslicegetP by admit.
realize bvaslicegetP by admit. (* We need a general framework for these *)

op sliceset256_16_256 (arr: W16.t Array256.t) (i: int) (bv: W256.t) : W16.t Array256.t = Array256.init (fun (i3 : int) => get16 (set256 ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (i %/ 256) bv) i3).

bind op [W16.t & W256.t & Array256.t] sliceset256_16_256 "asliceset".
realize bvaslicesetP by admit.
realize bvaslicesetP by admit. (* We need a general framework for these *)

op sliceget32_8_256 (arr: W8.t Array32.t) (i: int) : W256.t = get256 (WArray32.init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])) (i%/256).

bind op [W8.t & W256.t & Array32.t] sliceget32_8_256 "asliceget".
realize bvaslicegetP by admit.
realize bvaslicegetP by admit. (* We need a general framework for these *)

op sliceget768_16_256 (arr: W16.t Array768.t) (i: int) : W256.t = get256 (WArray1536.init16 (fun (i_0 : int) => arr.[i_0])) (i %/ 256).

bind op [W16.t & W256.t & Array768.t] sliceget768_16_256 "asliceget".
realize bvaslicegetP by admit.
realize bvaslicegetP by admit. (* We need a general framework for these *)

op sliceset960_8_128 (arr: W8.t Array960.t) (i: int) (bv: W128.t) : W8.t Array960.t = Array960.init (get8 (set128_direct ((init8 (fun (i_0 : int) => arr.[i_0])))%WArray960 (i %/ 8) bv)).

bind op [W8.t & W128.t & Array960.t] sliceset960_8_128 "asliceset".
realize bvaslicesetP by admit.
realize bvaslicesetP by admit. (* We need a general framework for these *)


op sliceset960_8_32 (arr: W8.t Array960.t) (i: int) (bv: W32.t) : W8.t Array960.t = Array960.init
Expand All @@ -1822,7 +1822,7 @@ op sliceset960_8_32 (arr: W8.t Array960.t) (i: int) (bv: W32.t) : W8.t Array960.


bind op [W8.t & W32.t & Array960.t] sliceset960_8_32 "asliceset".
realize bvaslicesetP by admit.
realize bvaslicesetP by admit. (* We need a general framework for these *)


theory W10.
Expand Down Expand Up @@ -1863,41 +1863,120 @@ by rewrite !nth_mkseq // /bits2w initiE //= nth_mkseq /#.
qed.

bind op [W64.t & W8.t] W8u8.truncateu8 "truncate".
realize bvtruncateP.
realize bvtruncateP. (* generalize *)
move => mv; rewrite /truncateu8 /W64.w2bits take_mkseq //= /w2bits.
apply (eq_from_nth witness);1: by smt(size_mkseq).
move => i; rewrite size_mkseq /= /max /= => ib.
rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // nth_mkseq //=.
rewrite get_to_uint //= /to_uint.
rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w //
nth_mkseq //= get_to_uint //= /to_uint /=.
have -> /=: (0 <= i && i < 64) by smt().
pose a := bs2int (w2bits mv).
admit.
pose a := bs2int (w2bits mv).
rewrite {1}(divz_eq a (2^(8-i)*2^i)) !mulrA divzMDl;
1: by smt(StdOrder.IntOrder.expr_gt0).
rewrite dvdz_modzDl; 1: by
have -> : 2^(8-i) = 2^((8-i-1)+1); [ by smt() |
rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)].
by have -> : (2 ^ (8 - i) * 2 ^ i) = 256;
[ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg
1,2:/# /= -!addrA /= | done ].
qed.

bind op [W16.t & W8.t] W2u8.truncateu8 "truncate".
realize bvtruncateP by admit.
realize bvtruncateP.
move => mv; rewrite /truncateu8 /W16.w2bits take_mkseq //= /w2bits.
apply (eq_from_nth witness);1: by smt(size_mkseq).
move => i; rewrite size_mkseq /= /max /= => ib.
rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w //
nth_mkseq //= get_to_uint //= /to_uint /=.
have -> /=: (0 <= i && i < 16) by smt().
pose a := bs2int (w2bits mv).
rewrite {1}(divz_eq a (2^(8-i)*2^i)) !mulrA divzMDl;
1: by smt(StdOrder.IntOrder.expr_gt0).
rewrite dvdz_modzDl; 1: by
have -> : 2^(8-i) = 2^((8-i-1)+1); [ by smt() |
rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)].
by have -> : (2 ^ (8 - i) * 2 ^ i) = 256;
[ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg
1,2:/# /= -!addrA /= | done ].
qed.


bind op [W16.t & W64.t] W4u16.zeroextu64 "zextend".
realize bvzextendP by admit.
realize bvzextendP
by move => bv; rewrite /zeroextu64 /= of_uintK /=; smt(W16.to_uint_cmp pow2_16).

bind op [W64.t & W16.t] W4u16.truncateu16 "truncate".
realize bvtruncateP by admit.
realize bvtruncateP.
move => mv; rewrite /truncateu16 /W64.w2bits take_mkseq //= /w2bits.
apply (eq_from_nth witness);1: by smt(size_mkseq).
move => i; rewrite size_mkseq /= /max /= => ib.
rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w //
nth_mkseq //= get_to_uint //= /to_uint /=.
have -> /=: (0 <= i && i < 64) by smt().
pose a := bs2int (w2bits mv).
rewrite {1}(divz_eq a (2^(16-i)*2^i)) !mulrA divzMDl;
1: by smt(StdOrder.IntOrder.expr_gt0).
rewrite dvdz_modzDl; 1: by
have -> : 2^(16-i) = 2^((16-i-1)+1); [ by smt() |
rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)].
by have -> : (2 ^ (16 - i) * 2 ^ i) = 65536;
[ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg
1,2:/# /= -!addrA /= | done ].
qed.


op sll_64 (w1 w2 : W64.t) : W64.t =
w1 `<<` (truncateu8 w2).

bind op [W64.t] sll_64 "shl".
realize bvshlP by admit.
realize bvshlP by admit. (* not provable. mod 2^64 missing? *)

bind op [W32.t & W16.t] W2u16.truncateu16 "truncate".
realize bvtruncateP by admit.
realize bvtruncateP.
move => mv; rewrite /truncateu16 /W32.w2bits take_mkseq //= /w2bits.
apply (eq_from_nth witness);1: by smt(size_mkseq).
move => i; rewrite size_mkseq /= /max /= => ib.
rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w //
nth_mkseq //= get_to_uint //= /to_uint /=.
have -> /=: (0 <= i && i < 32) by smt().
pose a := bs2int (w2bits mv).
rewrite {1}(divz_eq a (2^(16-i)*2^i)) !mulrA divzMDl;
1: by smt(StdOrder.IntOrder.expr_gt0).
rewrite dvdz_modzDl; 1: by
have -> : 2^(16-i) = 2^((16-i-1)+1); [ by smt() |
rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)].
by have -> : (2 ^ (16 - i) * 2 ^ i) = 65536;
[ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg
1,2:/# /= -!addrA /= | done ].
qed.


bind op [W16.t & W32.t] sigextu32 "sextend".
realize bvsextendP by admit.
realize bvsextendP.
move => bv;rewrite /sigextu32 /to_sint /smod /= !of_uintK /=.
case (32768 <= to_uint bv); 2: smt(W16.to_uint_cmp).
move =>?;rewrite -{2}(oppzK (to_uint bv - 65536)) modNz /=; smt(W16.to_uint_cmp pow2_16).
qed.

bind op [W32.t & W8.t] W4u8.truncateu8 "truncate".
realize bvtruncateP by admit.
realize bvtruncateP.
move => mv; rewrite /truncateu8 /W32.w2bits take_mkseq //= /w2bits.
apply (eq_from_nth witness);1: by smt(size_mkseq).
move => i; rewrite size_mkseq /= /max /= => ib.
rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w //
nth_mkseq //= get_to_uint //= /to_uint /=.
have -> /=: (0 <= i && i < 32) by smt().
pose a := bs2int (w2bits mv).
rewrite {1}(divz_eq a (2^(8-i)*2^i)) !mulrA divzMDl;
1: by smt(StdOrder.IntOrder.expr_gt0).
rewrite dvdz_modzDl; 1: by
have -> : 2^(8-i) = 2^((8-i-1)+1); [ by smt() |
rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)].
by have -> : (2 ^ (8 - i) * 2 ^ i) = 256;
[ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg
1,2:/# /= -!addrA /= | done ].
qed.


bind circuit VPBROADCAST_8u32 "VPBROADCAST_8u32".
bind circuit VPBROADCAST_4u64 "VPBROADCAST_4u64".
Expand All @@ -1919,7 +1998,24 @@ bind circuit VPEXTR_32 "VEXTRACTI32_256".
bind circuit W4u32.VPEXTR_32 "VEXTRACTI32_128".

bind op [W256.t & W128.t] truncateu128 "truncate".
realize bvtruncateP by admit.
realize bvtruncateP.
move => mv; rewrite /truncateu128 /W256.w2bits take_mkseq //= /w2bits.
apply (eq_from_nth witness);1: by smt(size_mkseq).
move => i; rewrite size_mkseq /= /max /= => ib.
rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w //
nth_mkseq //= get_to_uint //= /to_uint /=.
have -> /=: (0 <= i && i < 256) by smt().
pose a := bs2int (w2bits mv).
rewrite {1}(divz_eq a (2^(128-i)*2^i)) !mulrA divzMDl;
1: by smt(StdOrder.IntOrder.expr_gt0).
rewrite dvdz_modzDl; 1: by
have -> : 2^(128-i) = 2^((128-i-1)+1); [ by smt() |
rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)].
by have -> : (2 ^ (128 - i) * 2 ^ i) = 340282366920938463463374607431768211456;
[ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg
1,2:/# /= -!addrA /= | done ].
qed.


op sra_32 (w1 w2 : W32.t) : W32.t =
w1 `|>>` (truncateu8 w2).
Expand All @@ -1937,19 +2033,22 @@ op srl_16 (w1 w2 : W16.t) : W16.t =
w1 `>>` (truncateu8 w2).

bind op [W16.t] srl_16 "shr".
realize bvshrP by admit.
realize bvshrP.
move => v1 v2; rewrite /srl_16 /(`>>`) to_uint_shr;1:smt(W16.to_uint_cmp).
admit. (* not provable. missing %% 256? *)
qed.

op sll_16 (w1 w2 : W16.t) : W16.t =
w1 `<<` (truncateu8 w2).

bind op [W16.t] sll_16 "shl".
realize bvshlP by admit.
realize bvshlP by admit. (* not provable. missing %% 65536? *)

op srl_64 (w1 w2 : W64.t) : W64.t =
w1 `>>` (truncateu8 w2).

bind op [W64.t] srl_64 "shr".
realize bvshrP by admit.
realize bvshrP by admit. (* not provable. missing %% 256? *)

op lane_func_reduce(c : W16.t) : W16.t =
let t = (sigextu32 c) * (W32.of_int 20159) in
Expand All @@ -1970,9 +2069,24 @@ op pcond_reduced (w: W16.t) = w \ule W16.of_int (2*3329).
lemma reduce_commutes x xr : xr = lane_func_reduce x => pcond_reduced xr.
rewrite /lane_func_reduce /pcond_reduced. print Fq.SignedReductions.
have := Fq.SignedReductions.BREDCp_corr (to_sint x) 26 _ _ _ _ _ _; rewrite ?qE /R //=.
+ admit. smt().
rewrite /BREDC.
admit.
+ by have /= := W16.to_sint_cmp x;smt().
+ by smt().
move => [#] ??? Hr.
have -> : xr = W16.of_int (Fq.SignedReductions.BREDC (to_sint x) 26); last by smt(W16.to_uintK W16.of_uintK pow2_16). search W32.(`|>>`).
rewrite Hr /W16_sub /sra_32 /sigextu32 /truncateu16 /= Fq.SAR_sem26 !W32.of_sintK /= !W32.of_uintK W16.to_uint_eq to_uintD.
pose xx := (smod (to_sint x * 20159 %% 4294967296))%W32 %/ 67108864 * 3329.
have -> /= : to_uint (- (of_int (xx %% W32.modulus))%W16) = to_uint (W16.of_int (-xx)).
+ rewrite of_uintK of_intN' W16.of_uintK /=.
case (xx %% 2^32 = 0); 1: by smt().
move => /= *; rewrite modNz; 1,2: smt(modz_ge0).
by rewrite -modzDml (modz_dvd xx (65536 * 65536) 65536); smt().
rewrite /BREDC /R /= !Fq.smod_W32 !Fq.smod_W16 qE /= -/xx !of_uintK /=.
congr;congr.
rewrite to_sintE /smod /=.
case (32768 <= to_uint x); last by admit.
move => H.
have ->: to_uint x - 65536 = to_uint x + (-1) * 65536; 1: by ring.
rewrite modzMDr. admit.
qed.

import BitEncoding.BitChunking.
Expand Down Expand Up @@ -2128,6 +2242,20 @@ map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.
map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))] = 1%r.
admitted.

lemma unbits16 (bpl : W16.t list) :
(map W16.bits2w (chunk 16 (flatten (map W16.w2bits bpl)))) = bpl.
rewrite flattenK 1:/#;1:smt(mapP W16.size_w2bits).
by rewrite -map_comp /(\o) /= -/idfun map_id.
qed.

lemma injective_fun_8_10 l1 l2 :
size l1 = size l2 =>
map W10.bits2w (chunk 10 (flatten (map W8.w2bits l1))) =
map W10.bits2w (chunk 10 (flatten (map W8.w2bits l2))) =>
l1 = l2.
admit.
qed.

(* MAP REDUCE GOAL *)
lemma compress10_mr :
equiv [AuxPolyVecCompress10.avx2 ~ AuxPolyVecCompress10.ref : lift_array768 bp{1} = lift_array768 bp{2}==> ={res}].
Expand All @@ -2136,7 +2264,13 @@ exlim bp{1}, bp{2} => _bp1 _bp2.
call{1} (avx_correctness_p _bp1).
call{2} (ref_correctness_p _bp2).
auto => /> Hpre r1 Hr1 r2 Hr2.
admit.
rewrite !flatten1 !unbits16 in Hr1.
rewrite !flatten1 !unbits16 in Hr2.
have Hmap : map lane_polyvec_redcomp10 (to_list _bp2) = map lane_polyvec_redcomp10 (to_list _bp1).
+ rewrite /lane_polyvec_redcomp10.
admit.
rewrite -(Array960.to_listK W8.zero r2) -(Array960.to_listK W8.zero r1).
by smt(injective_fun_8_10 Array960.size_to_list).
qed.

(*****************************************************************)
Expand Down

0 comments on commit e9ae902

Please sign in to comment.