Skip to content

Commit

Permalink
vectorize poly_from/to_msg
Browse files Browse the repository at this point in the history
Signed-off-by: Duc Tri Nguyen <[email protected]>
  • Loading branch information
cothan committed Jun 9, 2024
1 parent d7c40f8 commit 8536950
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ CFLAGS_NISTRANDOMBYTES = ${CFLAGS} ${INCLUDE_NISTRANDOM}
NISTFLAGS += -Wno-unused-result -O3 -fomit-frame-pointer
RM = /bin/rm

ASM_CLEAN = mlkem/asm/clean/rej_uniform_asm.s
ASM_CLEAN = mlkem/asm/clean/rej_uniform_asm.s mlkem/asm/clean/poly_asm.s

SOURCES = $(ASM_CLEAN) mlkem/kem.c mlkem/indcpa.c mlkem/polyvec.c mlkem/poly.c mlkem/ntt.c mlkem/cbd.c mlkem/reduce.c mlkem/verify.c mlkem/rej_uniform.c
SOURCESKECCAK = $(SOURCES) fips202/keccakf1600.c fips202/fips202.c mlkem/symmetric-shake.c
Expand Down
15 changes: 15 additions & 0 deletions mlkem/asm/clean/poly_asm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// SPDX-License-Identifier: Apache-2.0
#ifndef POLY_ASM_H
#define POLY_ASM_H

#include "params.h"

void poly_frommsg_asm(int16_t coeffs[KYBER_N],
const uint8_t msg[KYBER_INDCPA_MSGBYTES],
const uint16_t bits[8]);

void poly_tomsg_asm(uint8_t msg[KYBER_INDCPA_MSGBYTES],
const int16_t coeffs[KYBER_N],
const uint16_t position[8]);

#endif
224 changes: 224 additions & 0 deletions mlkem/asm/clean/poly_asm.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
// SPDX-License-Identifier: Apache-2.0
/*************************************************
* Name: poly_frommsg_asm
*
* Description: Convert 32-byte message to polynomial
*
* Arguments: - int16_t *coeffs: pointer to output polynomial
* - const uint8_t *msg: pointer to input message
* - const uint16_t *bits: pointer to bit_table
**************************************************/
.align 4
.global poly_frommsg_asm
.global _poly_frommsg_asm
poly_frommsg_asm:
_poly_frommsg_asm:

/* Input registers */
coeffs .req x0
msg .req x1
bit_table .req x2

/* Temporary registers */
iter .req x9
tmp .req w10

/* Vector registers */
const .req v16
bits .req v17
bitsq .req q17
a0 .req v18
a0q .req q18

/* Vectorize code start */
mov tmp, #1665 // (KYBER_Q + 1) / 2
dup const.8h, tmp
ldr bitsq, [bit_table]
mov iter, xzr
loop:
ldrb tmp, [msg, iter]
dup a0.8h, tmp
and a0.16b, a0.16b, bits.16b
cmeq a0.8h, a0.8h, #0
bic a0.16b, const.16b, a0.16b
str a0q, [coeffs, iter, lsl #4]
add iter, iter, #1
cmp iter, #32 // KYBER_N / 8
b.ne loop
ret

/* Input registers */
.unreq coeffs
.unreq msg
.unreq bit_table

/* Temporary registers */
.unreq iter
.unreq tmp

/* Vector registers */
.unreq const
.unreq bits
.unreq bitsq
.unreq a0
.unreq a0q
/*************************************************
* Name: poly_tomsg_asm
*
* Description: Convert polynomial to 32-byte message
*
* Arguments: - uint8_t *msg: pointer to output message
* - int16_t *coeffs: pointer to input polynomial
**************************************************/
.align 4
.global poly_tomsg_asm
.global _poly_tomsg_asm
poly_tomsg_asm:
_poly_tomsg_asm:

/* Input registers */
msg .req x0
coeffs .req x1
position .req x2

/* Temporary registers */
iter .req x9
tmp .req w10
idx_addr .req x11

r0 .req w12
r1 .req w13
r2 .req w14
r3 .req w15

/* Vector registers */
vhq .req v16
vhqinv .req v17

a0 .req v18
a1 .req v19
a2 .req v20
a3 .req v21

idx .req v22
idxq .req q22

t0 .req h23
t1 .req h24
t2 .req h25
t3 .req h26

/* Vectorize code start */

mov w9, #1164 // KYBER_Q / 2
dup vhq.8h, w9
mov w10, #10079 // 2^26 / KYBER_Q / 2
dup vhqinv.8h, w10
ldr idxq, [position]

mov iter, xzr

loop32:
ld1 {a0.8h, a1.8h, a2.8h, a3.8h}, [x1], #64

/* t << = 1; */
add a0.8h, a0.8h, a0.8h
add a1.8h, a1.8h, a1.8h
add a2.8h, a2.8h, a2.8h
add a3.8h, a3.8h, a3.8h

/* t += KYBER_Q/2 */
add a0.8h, a0.8h, vhq.8h
add a1.8h, a1.8h, vhq.8h
add a2.8h, a2.8h, vhq.8h
add a3.8h, a3.8h, vhq.8h

/*
* t = t / KYBER_Q
* Instead of direct division, we multiply with inverse of KYBER_Q and utilize the sqdmulh instruction.
* To do so, we have a few options:
* 80635 = round(2^28/KYBER_Q) as in the reference C implementation
* However, we need number that fit in the range [-2^15..2^15]
* So we pick:
* 20159 = round(2^26/KYBER_Q)
* Because we use sqdmulh instruction, the constant will be:
* 10079 = round(2^26/KYBER_Q/2)
* sqdmulh helps us shift right by 16, we need additional shift right by 10 to complete shift right by 26.
* The other approach is to use smull/umull instructions, but they are inefficient.
*/
sqdmulh a0.8h, a0.8h, vhqinv.8h
sqdmulh a1.8h, a1.8h, vhqinv.8h
sqdmulh a2.8h, a2.8h, vhqinv.8h
sqdmulh a3.8h, a3.8h, vhqinv.8h

ushr a0.8h, a0.8h, #10
ushr a1.8h, a1.8h, #10
ushr a2.8h, a2.8h, #10
ushr a3.8h, a3.8h, #10

/* t = t & 1 */

bic a0.8h, #62
bic a1.8h, #62
bic a2.8h, #62
bic a3.8h, #62

/* Position the bits */
ushl a0.8h, a0.8h, idx.8h
ushl a1.8h, a1.8h, idx.8h
ushl a2.8h, a2.8h, idx.8h
ushl a3.8h, a3.8h, idx.8h

/* Extract the result */
addv t0, a0.8h
addv t1, a1.8h
addv t2, a2.8h
addv t3, a3.8h

fmov r0, t0
fmov r1, t1
fmov r2, t2
fmov r3, t3

strb r0, [x0], #1
strb r1, [x0], #1
strb r2, [x0], #1
strb r3, [x0], #1

add iter, iter, #4
cmp iter, #32
b.ne loop32

ret

/* Input registers */
.unreq msg
.unreq coeffs
.unreq position

/* Temporary registers */
.unreq iter
.unreq tmp
.unreq idx_addr

.unreq r0
.unreq r1
.unreq r2
.unreq r3

/* Vector registers */
.unreq vhq
.unreq vhqinv

.unreq a0
.unreq a1
.unreq a2
.unreq a3

.unreq idx
.unreq idxq

.unreq t0
.unreq t1
.unreq t2
.unreq t3
4 changes: 2 additions & 2 deletions mlkem/asm/clean/rej_uniform_asm.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// SPDX-License-Identifier: Apache-2.0
#ifndef REJ_UNIFORM_ASM
#define REJ_UNIFORM_ASM
#ifndef REJ_UNIFORM_ASM_H
#define REJ_UNIFORM_ASM_H

unsigned int rej_uniform_asm(int16_t *r,
const uint8_t *buf,
Expand Down
34 changes: 6 additions & 28 deletions mlkem/poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "reduce.h"
#include "cbd.h"
#include "symmetric.h"
#include "poly_asm.h"
#include "rej_uniform.h"

/************************************************************
* Name: scalar_compress_q_16
Expand Down Expand Up @@ -258,19 +260,10 @@ void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) {
* - const uint8_t *msg: pointer to input message
**************************************************/
void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) {
unsigned int i, j;
int16_t mask;

#if (KYBER_INDCPA_MSGBYTES != KYBER_N/8)
#error "KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!"
#endif

for (i = 0; i < KYBER_N / 8; i++) {
for (j = 0; j < 8; j++) {
mask = -(int16_t)((msg[i] >> j) & 1);
r->coeffs[8 * i + j] = mask & ((KYBER_Q + 1) / 2);
}
}
poly_frommsg_asm(r->coeffs, msg, bit_table);
}

/*************************************************
Expand All @@ -281,24 +274,9 @@ void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) {
* Arguments: - uint8_t *msg: pointer to output message
* - const poly *a: pointer to input polynomial
**************************************************/
void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *a) {
unsigned int i, j;
uint32_t t;

for (i = 0; i < KYBER_N / 8; i++) {
msg[i] = 0;
for (j = 0; j < 8; j++) {
t = a->coeffs[8 * i + j];
// t += ((int16_t)t >> 15) & KYBER_Q;
// t = (((t << 1) + KYBER_Q/2)/KYBER_Q) & 1;
t <<= 1;
t += 1665;
t *= 80635;
t >>= 28;
t &= 1;
msg[i] |= t << j;
}
}
void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *r) {
const uint16_t position[8] = {0, 1, 2, 3, 4, 5, 6, 7};
poly_tomsg_asm(msg, r->coeffs, position);
}

/*************************************************
Expand Down
2 changes: 2 additions & 0 deletions mlkem/rej_uniform.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ unsigned int rej_uniform(int16_t *r,
const uint8_t *buf,
unsigned int buflen);

const uint16_t bit_table[8];

#endif

0 comments on commit 8536950

Please sign in to comment.