From 9f4dac228ea12694e4b61e5d23180fc6be5d82d8 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sun, 28 Jul 2024 22:09:14 +0800 Subject: [PATCH] Update kyber.c KEM passed --- src/kyber.c | 308 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 204 insertions(+), 104 deletions(-) diff --git a/src/kyber.c b/src/kyber.c index bb702f3e..ecbe1209 100644 --- a/src/kyber.c +++ b/src/kyber.c @@ -54,6 +54,22 @@ #define KYBER_TEST +/* +CRYSTALS-Kyber Algorithm Specifications and Supporing Documentation (version 3.02) + + + FIPS-202 90s + + XOF SHAKE-128 AES256-CTR MGF1-SM3 + H SHA3-256 SHA256 SM3 + G SHA3-512 SHA512 MGF1-SM3 + PRF(s,b) SHAKE-256(s||b) AES256-CTR HKDF-SM3 + KDF SHAKE-256 SHA256 HKDF-SM3 + +*/ + + + typedef int16_t kyber_poly_t[256]; typedef struct { @@ -84,8 +100,6 @@ typedef KYBER_CPA_CIPHERTEXT KYBER_CIPHERTEXT; - - void kyber_h_hash(const uint8_t *in, size_t inlen, uint8_t out[32]) { SM3_CTX ctx; @@ -128,12 +142,18 @@ static int kyber_prf(const uint8_t seed[32], uint8_t N, size_t outlen, uint8_t * return 1; } - static int kyber_kdf(const uint8_t in[64], uint8_t out[32]) { - return 0; + uint8_t key[32]; + sm3_hkdf_extract(NULL, 0, in, 64, key); + sm3_hkdf_expand(key, NULL, 0, 32, out); + gmssl_secure_clear(key, 32); + return 1; } +#define KYBER_FMT_POLY 1 +#define KYBER_FMT_HEX 2 + int kyber_poly_print(FILE *fp, int fmt, int ind, const char *label, const kyber_poly_t a) { int i; @@ -816,7 +836,7 @@ static int test_kyber_poly_ntt_mul(void) return 1; } -static int test_kyber_poly_ops(void) +static int test_kyber_poly_add(void) { kyber_poly_t a, b; @@ -931,7 +951,7 @@ static int test_kyber_poly_compress(void) //printf("compress(-, 1) bound = %d\n", bound); for (i = 0; i < 256; i++) { if (b[i] < -bound || b[i] > bound) { - // 这块是有可能出现错误的 + // FIXME: might failed error_print(); return -1; } @@ -1058,7 +1078,6 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk) kyber_poly_t s[KYBER_K]; kyber_poly_t e[KYBER_K]; kyber_poly_t t[KYBER_K]; - uint8_t d[64]; uint8_t *rho = d; uint8_t *sigma = d + 32; @@ -1072,43 +1091,32 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk) kyber_g_hash(d, 32, d); - format_bytes(stderr, 0, 0, "rho", rho, 32); - format_bytes(stderr, 0, 0, "sigma", sigma, 32); - - // AHat[i][j] = Parse(XOR(rho, j, i)) for (i = 0; i < KYBER_K; i++) { for (j = 0; j < KYBER_K; j++) { kyber_poly_uniform_sample(A[i][j], rho, j, i); - kyber_poly_print(stderr, 0, 0, "A[i][j]", A[i][j]); } } // s[i] = CBD_eta1(PRF(sigma, N++)) for (i = 0; i < KYBER_K; i++) { kyber_poly_cbd_sample(s[i], KYBER_ETA1, sigma, N); - //kyber_poly_set_all(s[i], 1); - kyber_poly_print(stderr, 0, 0, "s[i]", s[i]); N++; } // e[i] = CBD_eta1(PRF(sigma, N++)) for (i = 0; i < KYBER_K; i++) { kyber_poly_cbd_sample(e[i], KYBER_ETA1, sigma, N); - //kyber_poly_set_all(e[i], 0); - kyber_poly_print(stderr, 0, 0, "e[i]", e[i]); N++; } // sHat = NTT(s) for (i = 0; i < KYBER_K; i++) { kyber_poly_ntt(s[i]); - kyber_poly_print(stderr, 0, 0, "ntt(s[i])", s[i]); } // eHat = NTT(e) for (i = 0; i < KYBER_K; i++) { kyber_poly_ntt(e[i]); - kyber_poly_print(stderr, 0, 0, "ntt(e[i])", e[i]); } for (i = 0; i < KYBER_K; i++) { @@ -1123,14 +1131,8 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk) kyber_poly_add(t[i], t[i], tmp); } kyber_poly_add(t[i], t[i], e[i]); - kyber_poly_print(stderr, 0, 0, "ntt(t[i])", t[i]); - } - - // 这里实际上t没有压缩,就是原来的值 - // t - A^T * s 实际上就是很小的值 - // output (pk, sk) for (i = 0; i < KYBER_K; i++) { kyber_poly_encode12(t[i], pk->t[i]); @@ -1138,8 +1140,9 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk) } memcpy(pk->rho, rho, 32); - - fprintf(stderr, "\n"); + gmssl_secure_clear(d, sizeof(d)); + gmssl_secure_clear(s, sizeof(s)); + gmssl_secure_clear(e, sizeof(e)); return 1; } @@ -1165,15 +1168,58 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk) */ +int kyber_cpa_ciphertext_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_CIPHERTEXT *c) +{ + int i; + + format_print(fp, fmt, ind, "%s\n", label); + ind += 4; + for (i = 0; i < KYBER_K; i++) { + format_print(fp, fmt, ind, "c1[%d] (Compress10(u[%d]))", i, i); + format_bytes(fp, fmt, 0, "", c->c1[i], KYBER_C1_SIZE); + } + format_bytes(fp, fmt, ind, "c2 (Compress4(v))", c->c2, KYBER_C2_SIZE); + return 1; +} + +int kyber_ciphertext_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_CIPHERTEXT *c) +{ + return kyber_cpa_ciphertext_print(fp, fmt, ind, label, c); +} + +int kyber_cpa_public_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_PUBLIC_KEY *pk) +{ + int i; + + format_print(fp, fmt, ind, "%s\n", label); + ind += 4; + for (i = 0; i < KYBER_K; i++) { + format_print(fp, fmt, ind, "ntt(t[%d])", i); + format_bytes(fp, fmt, 0, "", pk->t[i], 384); + } + format_bytes(fp, fmt, ind, "rho", pk->rho, 32); + return 1; +} + + + +int kyber_cpa_private_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_PRIVATE_KEY *sk) +{ + int i; + format_print(fp, fmt, ind, "%s\n", label); + ind += 4; + for (i = 0; i < KYBER_K; i++) { + format_print(fp, fmt, ind, "ntt(s[%d])", i); + format_bytes(fp, fmt, 0, "", sk->s[i], 384); + } + return 1; +} + int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32], const uint8_t rand[32], KYBER_CPA_CIPHERTEXT *out) { - int i, j; - int N = 0; - kyber_poly_t A[KYBER_K][KYBER_K]; - kyber_poly_t t[KYBER_K]; kyber_poly_t r[KYBER_K]; kyber_poly_t u[KYBER_K]; @@ -1181,52 +1227,41 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32], kyber_poly_t e2; kyber_poly_t v; kyber_poly_t m; - - printf("%s() ok\n", __FUNCTION__); + int i, j; + int N = 0; // tHat = Decode12(pk) for (i = 0; i < KYBER_K; i++) { kyber_poly_decode12(t[i], pk->t[i]); - kyber_poly_print(stderr, 0, 0, "ntt(t[i])", t[i]); } // AHat^T[i][j] = Parse(XOR(rho, i, j)) for (i = 0; i < KYBER_K; i++) { for (j = 0; j < KYBER_K; j++) { kyber_poly_uniform_sample(A[i][j], pk->rho, i, j); - kyber_poly_print(stderr, 0, 0, "A[i][j]", A[i][j]); } } // r[i] = CBD_eta1(PRF(rand, N++)) for (i = 0; i < KYBER_K; i++) { kyber_poly_cbd_sample(r[i], KYBER_ETA1, rand, N); - //kyber_poly_set_all(r[i], 2); - kyber_poly_print(stderr, 0, 0, "r[i]", r[i]); N++; } // e1[i] = CBD_eta2(PRF(rand, N++)) for (i = 0; i < KYBER_K; i++) { kyber_poly_cbd_sample(e1[i], KYBER_ETA2, rand, N); - //kyber_poly_set_all(e1[i], 0); - kyber_poly_print(stderr, 0, 0, "e1[i]", e1[i]); N++; } // e2 = CBD_eta2(PRF(rand, N)) kyber_poly_cbd_sample(e2, KYBER_ETA2, rand, N); - //kyber_poly_set_all(e2, 0); - kyber_poly_print(stderr, 0, 0, "e2", e2); // rHat = NTT(r) for (i = 0; i < KYBER_K; i++) { kyber_poly_ntt(r[i]); - kyber_poly_print(stderr, 0, 0, "ntt(r[i])", r[i]); } - // 实际上 u == A^T * r + e1 - // u = NTT^-1(A^T * r) + e1 for (i = 0; i < KYBER_K; i++) { kyber_poly_set_zero(u[i]); @@ -1240,11 +1275,9 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32], kyber_poly_inv_ntt(u[i]); kyber_poly_add(u[i], u[i], e1[i]); - kyber_poly_print(stderr, 0, 0, "u[i] = (A^T * r)[i]", u[i]); } // v = NTT^-1( t^T * r ) + e2 + round(q/2)*m - kyber_poly_set_zero(v); for (i = 0; i < KYBER_K; i++) { kyber_poly_t tmp; @@ -1253,15 +1286,6 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32], } kyber_poly_inv_ntt(v); kyber_poly_add(v, v, e2); - kyber_poly_print(stderr, 0, 0, "t^T * r + e2", v); - - - // check - - // v = t^T * r + e2 == s^T * (A^T * r) == s^T * (u) - - // 验证 v 和 s^T * u 大概是相等的 - // 这里的主要问题是 s 的值是不知道的,并且s 是ntt(s) 而不是原始s if (0) { kyber_poly_t s[KYBER_K]; @@ -1279,26 +1303,15 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32], kyber_poly_add(v_, v_, tmp); } - kyber_poly_print(stderr, 0, 0, "test v", v); - kyber_poly_print(stderr, 0, 0, "test v", v_); - kyber_poly_sub(v_, v_, v); kyber_poly_to_signed(v_, v_); - - kyber_poly_print(stderr, 0, 0, "delta", v_); } - - kyber_poly_decode1(m, in); kyber_poly_decompress(m, 1, m); kyber_poly_add(v, v, m); - - - - // c1 = Encode10(Compress(u, 10)) for (i = 0; i < KYBER_K; i++) { kyber_poly_compress(u[i], 10, u[i]); @@ -1309,6 +1322,11 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32], kyber_poly_compress(v, 4, v); kyber_poly_encode4(v, out->c2); + gmssl_secure_clear(m, sizeof(m)); + gmssl_secure_clear(r, sizeof(r)); + gmssl_secure_clear(e1, sizeof(e1)); + gmssl_secure_clear(e2, sizeof(e2)); + return 1; } @@ -1354,45 +1372,75 @@ int kyber_cpa_decrypt(const KYBER_CPA_PRIVATE_KEY *sk, const KYBER_CPA_CIPHERTEX kyber_poly_compress(m, 1, m); kyber_poly_encode1(m, out); + gmssl_secure_clear(s, sizeof(s)); + gmssl_secure_clear(m, sizeof(m)); return 1; } int kyber_keygen(KYBER_PUBLIC_KEY *pk, KYBER_PRIVATE_KEY *sk) { - if (rand_bytes(sk->z, 32) != 1) { + if (kyber_cpa_keygen(pk, &sk->sk) != 1) { error_print(); return -1; } - if (kyber_cpa_keygen(&sk->pk, &sk->sk) != 1) { + + memcpy(&sk->pk, pk, sizeof(KYBER_PUBLIC_KEY)); + + kyber_h_hash((uint8_t *)pk, sizeof(KYBER_CPA_PUBLIC_KEY), sk->pk_hash); + + if (rand_bytes(sk->z, 32) != 1) { error_print(); return -1; } - kyber_h_hash((uint8_t *)&sk->pk, sizeof(KYBER_CPA_PUBLIC_KEY), sk->pk_hash); + return 1; +} +int kyber_private_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_PRIVATE_KEY *sk) +{ + format_print(fp, fmt, ind, "%s\n", label); + ind += 4; + kyber_cpa_private_key_print(fp, fmt, ind, "privateKey", &sk->sk); + kyber_cpa_public_key_print(fp, fmt, ind, "publicKey", &sk->pk); + format_bytes(fp, fmt, ind, "publicKeyHash", sk->pk_hash, 32); + format_bytes(fp, fmt, ind, "z", sk->z, 32); return 1; } + + + +int kyber_public_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_PUBLIC_KEY *pk) +{ + return kyber_cpa_public_key_print(fp, fmt, ind, label, pk); +} + + int kyber_encap(const KYBER_PUBLIC_KEY *pk, KYBER_CIPHERTEXT *c, uint8_t K[32]) { - uint8_t m[64]; + uint8_t m_h[64]; uint8_t K_r[64]; + uint8_t *m = m_h; + uint8_t *h = m_h + 32; + uint8_t *K_ = K_r; uint8_t *r = K_r + 32; + // m = rand(32) if (rand_bytes(m, 32) != 1) { error_print(); return -1; } - // m = H(rand(32)) + // m = H(m) kyber_h_hash(m, 32, m); + // h = H(pk) + kyber_h_hash((const uint8_t *)pk, sizeof(KYBER_PUBLIC_KEY), h); // (K_, r) = G(m || H(pk)) - kyber_h_hash((const uint8_t *)pk, sizeof(KYBER_PUBLIC_KEY), m + 32); - kyber_g_hash(m, 64, K_r); + kyber_g_hash(m_h, 64, K_r); // c = Kyber.CPA.Enc(pk, m, r) if (kyber_cpa_encrypt(pk, m, r, c) != 1) { @@ -1400,10 +1448,14 @@ int kyber_encap(const KYBER_PUBLIC_KEY *pk, KYBER_CIPHERTEXT *c, uint8_t K[32]) return -1; } + // H(c) + kyber_h_hash((uint8_t *)c, sizeof(KYBER_CIPHERTEXT), r); + // K = KDF(K_ || H(c)) - kyber_h_hash((uint8_t *)c, sizeof(KYBER_CIPHERTEXT), K_r + 32); kyber_kdf(K_r, K); + gmssl_secure_clear(m_h, sizeof(m_h)); + gmssl_secure_clear(K_r, sizeof(K_r)); return 1; } @@ -1411,78 +1463,131 @@ int kyber_decap(const KYBER_PRIVATE_KEY *sk, const KYBER_CIPHERTEXT *c, uint8_t { uint8_t m_h[64]; uint8_t K_r[64]; + uint8_t *m = m_h; + uint8_t *h = m_h + 32; + uint8_t *K_ = K_r; uint8_t *r = K_r + 32; - - KYBER_CIPHERTEXT c_; uint8_t c_hash[32]; - if (kyber_cpa_decrypt(&sk->sk, c, m_h) != 1) { + + // m' = Dec(sk, c) + if (kyber_cpa_decrypt(&sk->sk, c, m) != 1) { error_print(); return -1; } - // (K, r) = G(m || H(pk)) - memcpy(m_h + 32, sk->pk_hash, 32); + // h = H(pk) + memcpy(h, sk->pk_hash, 32); + + // (K_, r) = G(m || h) kyber_g_hash(m_h, 64, K_r); - if (kyber_cpa_encrypt(&sk->pk, m_h, r, &c_) != 1) { + // c_ = CPA.Enc(pk, m, r) + if (kyber_cpa_encrypt(&sk->pk, m, r, &c_) != 1) { + gmssl_secure_clear(m_h, sizeof(m_h)); + gmssl_secure_clear(K_r, sizeof(K_r)); error_print(); return -1; } + // H(c) kyber_h_hash((uint8_t *)c, sizeof(KYBER_CIPHERTEXT), r); if (memcmp(c, &c_, sizeof(KYBER_CIPHERTEXT)) == 0) { + // K = KDF(K_||H(c)) kyber_kdf(K_r, K); } else { - memcpy(K_r, sk->z, 32); + error_print(); + memcpy(K_r, sk->z, 32); // TODO: const time kyber_kdf(K_r, K); } - + gmssl_secure_clear(m_h, sizeof(m_h)); + gmssl_secure_clear(K_r, sizeof(K_r)); return 1; } -static int test_kyber_cpa_keygen(void) +static int test_kyber_cpa(void) { KYBER_CPA_PUBLIC_KEY pk; KYBER_CPA_PRIVATE_KEY sk; KYBER_CPA_CIPHERTEXT c; + uint8_t m[32]; + uint8_t r[32]; + uint8_t m_[32]; - uint8_t r[32] = {0}; - uint8_t m[32] = {1,0,1,0}; - uint8_t K[32] = {0}; + if (rand_bytes(m, 32) != 1) { + error_print(); + return -1; + } + if (rand_bytes(r, 32) != 1) { + error_print(); + return -1; + } + if (kyber_cpa_keygen(&pk, &sk) != 1) { + error_print(); + return -1; + } + kyber_cpa_public_key_print(stderr, 0, 0, "publicKey", &pk); + kyber_cpa_private_key_print(stderr, 0, 0, "privateKey", &sk); + if (kyber_cpa_encrypt(&pk, m, r, &c) != 1) { + error_print(); + return -1; + } + kyber_cpa_ciphertext_print(stderr, 0, 0, "ciphertext", &c); - if (rand_bytes(r, 32) != 1) { + if (kyber_cpa_decrypt(&sk, &c, m_) != 1) { error_print(); return -1; } - if (rand_bytes(m, 32) != 1) { + if (memcmp(m_, m, 32) != 0) { error_print(); return -1; } + printf("%s() ok\n", __FUNCTION__); + return 1; +} - kyber_cpa_keygen(&pk, &sk); - - kyber_cpa_encrypt(&pk, m, r, &c); +static int test_kyber_kem(void) +{ + KYBER_PRIVATE_KEY sk; + KYBER_PUBLIC_KEY pk; + KYBER_CIPHERTEXT c; + uint8_t K[32]; + uint8_t K_[32]; + if (kyber_keygen(&pk, &sk) != 1) { + error_print(); + return -1; + } - kyber_cpa_decrypt(&sk, &c, K); + kyber_public_key_print(stderr, 0, 0, "pk", &pk); + kyber_private_key_print(stderr, 0, 0, "sk", &sk); - format_bytes(stderr, 0, 0, "m", m, 32); - format_bytes(stderr, 0, 0, "out", K, 32); + if (kyber_encap(&pk, &c, K) != 1) { + error_print(); + return -1; + } + kyber_ciphertext_print(stderr, 0, 0, "ciphertext", &c); + format_bytes(stderr, 0, 0, "KEM_K", K, 32); - if (memcmp(K, m, 32) != 0) { + if (kyber_decap(&sk, &c, K_) != 1) { error_print(); return -1; } + format_bytes(stderr, 0, 0, "DEC_K", K_, 32); + if (memcmp(K_, K, 32) != 0) { + error_print(); + return -1; + } + printf("%s() ok\n", __FUNCTION__); return 1; } @@ -1490,28 +1595,23 @@ static int test_kyber_cpa_keygen(void) - int main(void) { init_zeta(); - if (test_kyber_cpa_keygen() != 1) goto err; - - - return 1; - if (test_kyber_poly_ops() != 1) goto err; - if (test_kyber_poly_ntt_mul() != 1) goto err; - if (test_kyber_poly_ntt() != 1) goto err; if (test_kyber_poly_uniform_sample() != 1) goto err; if (test_kyber_poly_cbd_sample() != 1) goto err; if (test_kyber_poly_to_signed() != 1) goto err; if (test_kyber_poly_compress() != 1) goto err; - if (test_kyber_poly_encode12() != 1) goto err; if (test_kyber_poly_encode10() != 1) goto err; if (test_kyber_poly_encode4() != 1) goto err; if (test_kyber_poly_encode1() != 1) goto err; - + if (test_kyber_poly_add() != 1) goto err; + if (test_kyber_poly_ntt() != 1) goto err; + if (test_kyber_poly_ntt_mul() != 1) goto err; + if (test_kyber_cpa() != 1) goto err; + if (test_kyber_kem() != 1) goto err; printf("%s all tests passed\n", __FILE__); return 0;