Skip to content

Commit

Permalink
mlkem768 ref sct: work in progress - recording the state to try a different approach for gen_matrix -- current overhead between 3% and 5%
Browse files Browse the repository at this point in the history
  • Loading branch information
tfaoliveira-sb committed Apr 9, 2024
1 parent 5c14571 commit 97cc863
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 77 deletions.
3 changes: 3 additions & 0 deletions code/jasmin/mlkem_ref/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ test/test_polyvec_%: test/test_polyvec_%.c $(HEADERS) $(SOURCES) jpolyvec.s
ct:
$(JASMINC) -checkCT -infer jkem.jazz

sct:
$(JASMINC) -checkSCT jkem.jazz

clean:
-rm -f *.s
-rm -f jindcpa.o
Expand Down
95 changes: 44 additions & 51 deletions code/jasmin/mlkem_ref/fips202.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,9 @@ fn ____xtr_bytes(
}
}

// NOTE: the following code is currently unused. TODO: remove it after double-checking why it is here.

/*
inline
fn ____keccak1600_ref(
stack u64 s_out s_outlen,
Expand Down Expand Up @@ -462,17 +464,16 @@ fn __shake256(reg u64 out outlen in inlen)
config[1] = rate;
__keccak1600_ref(out, outlen, in, inlen, config);
}
*/


fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[128]
fn _shake256_128_33(#spill_to_mmx reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[128]
{
stack u64[25] state;
reg u8 c;
inline int i;

stack ptr u8[128] sout;

sout = out;
() = #spill(out);

state = __st0(state);

Expand All @@ -485,7 +486,7 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12

state = _keccakf1600_(state);

out = sout;
() = #unspill(out);

for i = 0 to 128 {
c = state[u8 (int) i];
Expand All @@ -494,16 +495,15 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12
return out;
}

fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) {
fn _shake256_1120_32(#spill_to_mmx reg u64 out in0 in1) {
stack u64[25] state;
stack u64 s_out s_in1;
stack u64 s_in s_ilen s_r8;
reg u64 ilen r8 t64 in;
#spill_to_mmx reg u64 ilen r8;
reg u64 t64;
reg u8 t8;
inline int i;

s_out = out;
s_in1 = in1;
() = #spill(out);

state = __st0(state);

for i = 0 to MLKEM_SYMBYTES/8 {
Expand All @@ -515,35 +515,34 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) {
t64 = (u64)[in1 + (i-MLKEM_SYMBYTES/8)*8];
state[u64 i] ^= t64;
}


() = #spill(in1);

state = _keccakf1600_(state);

() = #unspill(in1);

r8 = SHAKE256_RATE;
ilen = MLKEM_CT_LEN - (SHAKE256_RATE - MLKEM_SYMBYTES);
in = s_in1;
in += SHAKE256_RATE - MLKEM_SYMBYTES;
in1 += SHAKE256_RATE - MLKEM_SYMBYTES;

while(ilen >= r8)
{
state, in, ilen = __add_full_block(state, in, ilen, r8);
state, in1, ilen = __add_full_block(state, in1, ilen, r8);

s_in = in;
s_ilen = ilen;
s_r8 = r8;
() = #spill(in1, ilen, r8);

state = _keccakf1600_(state);

in = s_in;
ilen = s_ilen;
r8 = s_r8;
() = #unspill(in1, ilen, r8);
}

t8 = 0x1f;
state = __add_final_block(state, in, ilen, t8, r8);
state = __add_final_block(state, in1, ilen, t8, r8);

state = _keccakf1600_(state);

out = s_out;
() = #unspill(out);

for i=0 to MLKEM_SYMBYTES/8
{
Expand All @@ -554,14 +553,13 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) {
}


fn _sha3512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64]
fn _sha3512_32(#spill_to_mmx reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64]
{
stack u64[25] state;
reg u8 c;
inline int i;
stack ptr u8[64] s_out;

s_out = out;
() = #spill(out);

state = __st0(state);

Expand All @@ -574,11 +572,13 @@ fn _sha3512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64]

state = _keccakf1600_(state);

out = s_out;
() = #unspill(out);

for i = 0 to 64 {
c = state[u8 (int) i];
out[i] = c;
}

return out;
}

Expand All @@ -601,17 +601,17 @@ fn _shake128_absorb34(reg ptr u64[25] state, reg const ptr u8[34] in) -> reg ptr
}


fn _shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out) -> reg ptr u64[25], reg ptr u8[SHAKE128_RATE]
fn _shake128_squeezeblock(reg ptr u64[25] state, #spill_to_mmx reg ptr u8[SHAKE128_RATE] out) -> reg ptr u64[25], reg ptr u8[SHAKE128_RATE]
{
stack ptr u8[SHAKE128_RATE] s_out;
reg u8 c;
inline int i;

s_out = out;
() = #spill(out);

state = _keccakf1600_(state);

out = s_out;
() = #unspill(out);

for i = 0 to SHAKE128_RATE { // SHAKE128 rate is 168: or 21 u64: TODO: 'compress' this for loop
c = state[u8 (int) i];
out[i] = c;
Expand All @@ -621,16 +621,15 @@ fn _shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out)


#[returnaddress="stack"]
fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32]
fn _isha3_256(#spill_to_mmx reg ptr u8[32] out, #spill_to_mmx reg u64 in inlen) -> reg ptr u8[32]
{
stack u64[25] state;
stack ptr u8[32] s_out;
stack u64 s_in s_ilen s_r8;
reg u64 ilen r8 t64;
#spill_to_mmx reg u64 ilen r8;
reg u64 t64;
reg u8 t8;
inline int i;

s_out = out;
() = #spill(out);

state = __st0(state);

Expand All @@ -641,23 +640,19 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32]
{
state, in, ilen = __add_full_block(state, in, ilen, r8);

s_in = in;
s_ilen = ilen;
s_r8 = r8;
() = #spill(in, ilen, r8);

state = _keccakf1600_(state);

in = s_in;
ilen = s_ilen;
r8 = s_r8;
() = #unspill(in, ilen, r8);
}

t8 = 0x06;
state = __add_final_block(state, in, ilen, t8, r8);

state = _keccakf1600_(state);

out = s_out;
() = #unspill(out);

for i=0 to 4
{
Expand All @@ -669,14 +664,13 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32]
}

#[returnaddress="stack"]
fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u8[32]
fn _isha3_256_32(#spill_to_mmx reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u8[32]
{
stack u64[25] state;
stack ptr u8[32] s_out;
reg u64 t64;
inline int i;

s_out = out;
() = #spill(out);

state = __st0(state);

Expand All @@ -691,7 +685,7 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u

state = _keccakf1600_(state);

out = s_out;
() = #unspill(out);

for i=0 to 4
{
Expand All @@ -703,10 +697,9 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u
}

#[returnaddress="stack"]
fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64]
fn _sha3_512_64(#spill_to_mmx reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64]
{
stack u64[25] state;
stack ptr u8[64] out_s;
reg u64 t64;
inline int i;

Expand All @@ -721,11 +714,11 @@ fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64]
state[u8 64] ^= 0x06;
state[u8 SHA3_512_RATE - 1] ^= 0x80;

out_s = out;
() = #spill(out);

state = _keccakf1600_(state);

out = out_s;
() = #unspill(out);

for i = 0 to 8
{
Expand Down
48 changes: 38 additions & 10 deletions code/jasmin/mlkem_ref/gen_matrix.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,71 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE]
reg u16 t;
reg u64 pos ctr;

#msf reg u64 ms;
reg bool cond;

ms = #init_msf();

ctr = offset;
pos = 0;

while (pos < SHAKE128_RATE - 2) {
if ctr < MLKEM_N {
while { cond = (pos < SHAKE128_RATE - 2); } (cond) {

ms = #update_msf(cond, ms);

cond = ctr < MLKEM_N;
if cond {
ms = #update_msf(cond, ms);

val1 = (16u)buf[pos];
t = (16u)buf[pos + 1];
val1 = #protect_16(val1, ms);

t = (16u)buf[pos + 1];
t = #protect_16(t, ms);

val2 = t;
val2 >>= 4;
t &= 0x0F;
t <<= 8;
val1 |= t;

t = (16u)buf[pos + 2];
t = #protect_16(t, ms);

t <<= 4;
val2 |= t;
pos += 3;

reg bool cond;
#[declassify]
cond = val1 < MLKEM_Q;
if cond {
ms = #update_msf(cond, ms);
rp[ctr] = val1;
ctr += 1;
} else {
ms = #update_msf(!cond, ms);
}

#[declassify]
cond = val2 < MLKEM_Q;
if cond {
if(ctr < MLKEM_N)
{
ms = #update_msf(cond, ms);

cond = ctr < MLKEM_N;
if cond {
ms = #update_msf(cond, ms);
rp[ctr] = val2;
ctr += 1;
} else {
ms = #update_msf(!cond, ms);
}
} else {
ms = #update_msf(!cond, ms);
}

} else {
ms = #update_msf(!cond, ms);
pos = SHAKE128_RATE;
}

}

return ctr, rp;
Expand All @@ -64,8 +90,8 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[
reg u8 c;
reg u16 t;
reg u64 ctr k;
stack u64 sctr;
stack u64 stransposed;
#mmx reg u64 sctr;
#mmx reg u64 stransposed;
inline int j i;

stransposed = transposed;
Expand All @@ -81,6 +107,7 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[
for j = 0 to MLKEM_K
{
transposed = stransposed;

if(transposed == 0)
{
extseed[MLKEM_SYMBYTES] = j;
Expand All @@ -100,6 +127,7 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[
sctr = ctr;
state, buf = _shake128_squeezeblock(state, buf);
ctr = sctr;

ctr, poly = __rej_uniform(poly, ctr, buf);
}

Expand Down
Loading

0 comments on commit 97cc863

Please sign in to comment.