Skip to content

Commit

Permalink
Merge pull request #291 from pq-code-package/input-validation
Browse files Browse the repository at this point in the history
Add input validation for public key and secret key
  • Loading branch information
hanno-becker authored Oct 31, 2024
2 parents 10fd1dd + 2f72643 commit 289f01c
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 8 deletions.
72 changes: 67 additions & 5 deletions mlkem/kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,58 @@
#include "randombytes.h"
#include "symmetric.h"
#include "verify.h"

/*************************************************
* Name: check_pk
*
* Description: Implements modulus check mandated by FIPS203,
* i.e., ensures that coefficients are in [0,q-1].
* Described in Section 7.2 of FIPS203.
*
* Arguments: - const uint8_t *pk: pointer to input public key
* (an already allocated array of MLKEM_PUBLICKEYBYTES bytes)
**
* Returns 0 on success, and -1 on failure
**************************************************/
static int check_pk(const uint8_t pk[MLKEM_PUBLICKEYBYTES]) {
polyvec p;
uint8_t p_reencoded[MLKEM_POLYVECBYTES];
polyvec_frombytes(&p, pk);
polyvec_reduce(&p);
polyvec_tobytes(p_reencoded, &p);
// Data is public, so a variable-time memcmp() is OK
if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) {
return -1;
}
return 0;
}

/*************************************************
* Name: check_sk
*
* Description: Implements public key hash check mandated by FIPS203,
* i.e., ensures that
* sk[768𝑘+32 ∶ 768𝑘+64] = H(pk)= H(sk[384𝑘 : 768𝑘+32])
* Described in Section 7.3 of FIPS203.
*
* Arguments: - const uint8_t *sk: pointer to input private key
* (an already allocated array of MLKEM_SECRETKEYBYTES bytes)
*
* Returns 0 on success, and -1 on failure
**************************************************/
static int check_sk(const uint8_t sk[MLKEM_SECRETKEYBYTES]) {
uint8_t test[MLKEM_SYMBYTES];
// The parts of `sk` being hashed and compared here are public, so
// no public information is leaked through the runtime or the return value
// of this function.
hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_PUBLICKEYBYTES);
if (memcmp(sk + MLKEM_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test,
MLKEM_SYMBYTES)) {
return -1;
}
return 0;
}

/*************************************************
* Name: crypto_kem_keypair_derand
*
Expand Down Expand Up @@ -71,14 +123,19 @@ int crypto_kem_keypair(uint8_t *pk, uint8_t *sk) {
* (an already allocated array filled with MLKEM_SYMBYTES random
*bytes)
**
* Returns 0 (success)
* Returns 0 on success, and -1 if the public key modulus check (see Section 7.2
* of FIPS203) fails.
**************************************************/
int crypto_kem_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk,
const uint8_t *coins) {
uint8_t buf[2 * MLKEM_SYMBYTES] ALIGN;
/* Will contain key, coins */
uint8_t kr[2 * MLKEM_SYMBYTES] ALIGN;

if (check_pk(pk)) {
return -1;
}

memcpy(buf, coins, MLKEM_SYMBYTES);

/* Multitarget countermeasure for coins + contributory KEM */
Expand All @@ -105,13 +162,13 @@ int crypto_kem_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk,
* - const uint8_t *pk: pointer to input public key
* (an already allocated array of MLKEM_PUBLICKEYBYTES bytes)
*
* Returns 0 (success)
* Returns 0 on success, and -1 if the public key modulus check (see Section 7.2
* of FIPS203) fails.
**************************************************/
int crypto_kem_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk) {
uint8_t coins[MLKEM_SYMBYTES] ALIGN;
randombytes(coins, MLKEM_SYMBYTES);
crypto_kem_enc_derand(ct, ss, pk, coins);
return 0;
return crypto_kem_enc_derand(ct, ss, pk, coins);
}

/*************************************************
Expand All @@ -127,7 +184,8 @@ int crypto_kem_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk) {
* - const uint8_t *sk: pointer to input private key
* (an already allocated array of MLKEM_SECRETKEYBYTES bytes)
*
* Returns 0.
* Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of
* FIPS203) fails.
*
* On failure, ss will contain a pseudo-random value.
**************************************************/
Expand All @@ -139,6 +197,10 @@ int crypto_kem_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk) {
uint8_t cmp[MLKEM_CIPHERTEXTBYTES + MLKEM_SYMBYTES] ALIGN;
const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES;

if (check_sk(sk)) {
return -1;
}

indcpa_dec(buf, ct, sk);

/* Multitarget countermeasure for coins + contributory KEM */
Expand Down
72 changes: 69 additions & 3 deletions test/test_mlkem.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,60 @@ static int test_keys(void) {
return 0;
}

static int test_invalid_pk(void) {
uint8_t pk[CRYPTO_PUBLICKEYBYTES];
uint8_t sk[CRYPTO_SECRETKEYBYTES];
uint8_t ct[CRYPTO_CIPHERTEXTBYTES];
uint8_t key_b[CRYPTO_BYTES];
int rc;
// Alice generates a public key
crypto_kem_keypair(pk, sk);

// Bob derives a secret key and creates a response
rc = crypto_kem_enc(ct, key_b, pk);

if (rc) {
printf("ERROR test_invalid_pk\n");
return 1;
}

// set first public key coefficient to 4095 (0xFFF)
pk[0] = 0xFF;
pk[1] |= 0x0F;
// Bob derives a secret key and creates a response
rc = crypto_kem_enc(ct, key_b, pk);

if (!rc) {
printf("ERROR test_invalid_pk\n");
return 1;
}
return 0;
}

static int test_invalid_sk_a(void) {
uint8_t pk[CRYPTO_PUBLICKEYBYTES];
uint8_t sk[CRYPTO_SECRETKEYBYTES];
uint8_t ct[CRYPTO_CIPHERTEXTBYTES];
uint8_t key_a[CRYPTO_BYTES];
uint8_t key_b[CRYPTO_BYTES];
int rc;

// Alice generates a public key
crypto_kem_keypair(pk, sk);

// Bob derives a secret key and creates a response
crypto_kem_enc(ct, key_b, pk);

// Replace secret key with random values
randombytes(sk, CRYPTO_SECRETKEYBYTES);
// Replace first part of secret key with random values
randombytes(sk, 10);

// Alice uses Bobs response to get her shared key
crypto_kem_dec(key_a, ct, sk);
// This should fail due to wrong sk
rc = crypto_kem_dec(key_a, ct, sk);
if (rc) {
printf("ERROR test_invalid_sk_a\n");
return 1;
}

if (!memcmp(key_a, key_b, CRYPTO_BYTES)) {
printf("ERROR invalid sk\n");
Expand All @@ -58,6 +94,34 @@ static int test_invalid_sk_a(void) {
return 0;
}

static int test_invalid_sk_b(void) {
uint8_t pk[CRYPTO_PUBLICKEYBYTES];
uint8_t sk[CRYPTO_SECRETKEYBYTES];
uint8_t ct[CRYPTO_CIPHERTEXTBYTES];
uint8_t key_a[CRYPTO_BYTES];
uint8_t key_b[CRYPTO_BYTES];
int rc;

// Alice generates a public key
crypto_kem_keypair(pk, sk);

// Bob derives a secret key and creates a response
crypto_kem_enc(ct, key_b, pk);

// Replace H(pk) with radom values;
randombytes(sk + CRYPTO_SECRETKEYBYTES - 64, 32);

// Alice uses Bobs response to get her shared key
// This should fail due to the input validation
rc = crypto_kem_dec(key_a, ct, sk);
if (!rc) {
printf("ERROR test_invalid_sk_b\n");
return 1;
}

return 0;
}

static int test_invalid_ciphertext(void) {
uint8_t pk[CRYPTO_PUBLICKEYBYTES];
uint8_t sk[CRYPTO_SECRETKEYBYTES];
Expand Down Expand Up @@ -98,7 +162,9 @@ int main(void) {

for (i = 0; i < NTESTS; i++) {
r = test_keys();
r |= test_invalid_pk();
r |= test_invalid_sk_a();
r |= test_invalid_sk_b();
r |= test_invalid_ciphertext();
if (r) {
return 1;
Expand Down

0 comments on commit 289f01c

Please sign in to comment.