diff --git a/oqsprov/oqs_hyb_kem.c b/oqsprov/oqs_hyb_kem.c index 41f43d11..dc0599e0 100644 --- a/oqsprov/oqs_hyb_kem.c +++ b/oqsprov/oqs_hyb_kem.c @@ -10,6 +10,7 @@ #include "oqs_prov.h" #include #include +#include static OSSL_FUNC_kem_encapsulate_fn oqs_hyb_kem_encaps; static OSSL_FUNC_kem_decapsulate_fn oqs_hyb_kem_decaps; @@ -386,7 +387,7 @@ static int oqs_cmp_kem_encaps(void *vpkemctx, unsigned char *ct, size_t *ctlen, ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_ct0); secret1 = OPENSSL_zalloc(secretLen1); ON_ERR_SET_GOTO(!secret1, ret, 0, err_secret1); - ct1 = OPENSSL_zalloc(ctLen1); + ct1 = OPENSSL_zalloc(ctLen1); // do not free ct1, freed later with cmpCT ON_ERR_SET_GOTO(!ct1, ret, 0, err_ct1); cmpCT = CompositeCiphertext_new(); @@ -404,14 +405,12 @@ static int oqs_cmp_kem_encaps(void *vpkemctx, unsigned char *ct, size_t *ctlen, OPENSSL_free(temp); ON_ERR_SET_GOTO(1, ret, 0, err_cmpct); } - cmpCT->ct1->flags = 8; // do not check for unused bits ret2 = ASN1_STRING_set(cmpCT->ct2, ct1, ctLen1); if (!ret2) { OPENSSL_free(temp); ON_ERR_SET_GOTO(1, ret, 0, err_cmpct); } - cmpCT->ct2->flags = 8; // do not check for unused bits *ctlen = (size_t)i2d_CompositeCiphertext(cmpCT, &temp); if (ctlen <= 0) { @@ -441,10 +440,8 @@ static int oqs_cmp_kem_encaps(void *vpkemctx, unsigned char *ct, size_t *ctlen, ret2 = ASN1_STRING_set(cmpCT->ct1, ct0, ctLen0); ON_ERR_SET_GOTO(!ret2, ret, 0, err_cmpct); - cmpCT->ct1->flags = 8; // do not check for unused bits ret2 = ASN1_STRING_set(cmpCT->ct2, ct1, ctLen1); - cmpCT->ct2->flags = 8; // do not check for unused bits ON_ERR_SET_GOTO(!ret2, ret, 0, err_cmpct); *ctlen = (size_t)i2d_CompositeCiphertext(cmpCT, &p); @@ -483,10 +480,11 @@ static int oqs_cmp_kem_decaps(void *vpkemctx, unsigned char *secret, pkemctx->kem->oqsx_provider_ctx.oqsx_evp_ctx->evp_info; size_t secretLen0 = 0, secretLen1 = 0; - size_t ctLen0 = 0, ctLen1 = 0; - const unsigned char *ct0 = NULL, *ct1 = NULL; unsigned char *secret0 = NULL, *secret1 = NULL; + CompositeCiphertext *cmpCT; + const unsigned char *p = ct; // temp ptr because d2i_* may move input ct ptr + ret2 = oqs_qs_kem_decaps_keyslot(vpkemctx, NULL, &secretLen0, NULL, 0, 0); ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err); secret0 = OPENSSL_malloc(secretLen0); @@ -497,39 +495,35 @@ static int oqs_cmp_kem_decaps(void *vpkemctx, unsigned char *secret, secret1 = OPENSSL_malloc(secretLen1); ON_ERR_SET_GOTO(!secret1, ret, 0, err_secret1); - if (secret == NULL) { - ret2 = oqs_kem_combiner( - pkemctx, NULL, secretLen1, NULL, secretLen0, NULL, ctLen1, NULL, - pkemctx->kem->pubkeylen_cmp[1], NULL, secretlen); - ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err); - - ON_ERR_SET_GOTO(1, ret, 1, err); - } - - CompositeCiphertext *cmpCT; - const unsigned char *p = ct; // temp ptr because d2i_* may move input ct ptr + cmpCT = CompositeCiphertext_new(); + ON_ERR_SET_GOTO(!cmpCT, ret, 0, err_cmpct); cmpCT = d2i_CompositeCiphertext(&cmpCT, (const unsigned char **)&p, ctlen); - ON_ERR_SET_GOTO(!cmpCT, ret, 0, err); + ON_ERR_SET_GOTO(!cmpCT, ret, 0, err_cmpct); - ct0 = ASN1_STRING_get0_data(cmpCT->ct1); - ctLen0 = ASN1_STRING_length(cmpCT->ct1); - ct1 = ASN1_STRING_get0_data(cmpCT->ct2); - ctLen1 = ASN1_STRING_length(cmpCT->ct2); - ON_ERR_SET_GOTO(!ct0 || !ct1, ret, 0, err_cmpct); + if (secret == NULL) { + ret2 = + oqs_kem_combiner(pkemctx, NULL, secretLen1, NULL, secretLen0, NULL, + cmpCT->ct2->length, NULL, + pkemctx->kem->pubkeylen_cmp[1], NULL, secretlen); + ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err_secret1); - ret2 = oqs_qs_kem_decaps_keyslot(vpkemctx, secret0, &secretLen0, ct0, - ctLen0, 0); - ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_secret1); + ON_ERR_SET_GOTO(1, ret, 1, err_secret1); + } - ret2 = oqs_evp_kem_decaps_keyslot(vpkemctx, secret1, &secretLen1, ct1, - ctLen1, 1); - ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_secret1); + ret2 = oqs_qs_kem_decaps_keyslot(vpkemctx, secret0, &secretLen0, + cmpCT->ct1->data, cmpCT->ct1->length, 0); + ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_cmpct); + + ret2 = oqs_evp_kem_decaps_keyslot(vpkemctx, secret1, &secretLen1, + cmpCT->ct2->data, cmpCT->ct2->length, 1); + ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_cmpct); ret2 = oqs_kem_combiner(pkemctx, secret1, secretLen1, secret0, secretLen0, - ct1, ctLen1, pkemctx->kem->comp_pubkey[1], + cmpCT->ct2->data, cmpCT->ct2->length, + pkemctx->kem->comp_pubkey[1], pkemctx->kem->pubkeylen_cmp[1], secret, secretlen); - ON_ERR_SET_GOTO(!ret2, ret, 0, err_secret1); + ON_ERR_SET_GOTO(!ret2, ret, 0, err_cmpct); err_cmpct: CompositeCiphertext_free(cmpCT); diff --git a/oqsprov/oqs_kem.c b/oqsprov/oqs_kem.c index ab59dd88..b7607e55 100644 --- a/oqsprov/oqs_kem.c +++ b/oqsprov/oqs_kem.c @@ -21,6 +21,7 @@ #include #include #include +#include #include "oqs_prov.h" @@ -217,26 +218,31 @@ static int oqs_kem_combiner(const PROV_OQSKEM_CTX *pkemctx, ON_ERR_SET_GOTO(1, ret, 1, err); } - bufferLen = 4 + tradSSLen + mlkemSSLen + tradCTLen + tradPKLen + - sizeof(info->domSep); - buffer = OPENSSL_malloc(bufferLen); - ON_ERR_SET_GOTO(buffer == NULL, ret, 0, err); - - p = buffer; - memcpy(p, counter, 4); - p += 4; - memcpy(p, tradSS, tradSSLen); - p += tradSSLen; - memcpy(p, mlkemSS, mlkemSSLen); - p += mlkemSSLen; - memcpy(p, tradCT, tradCTLen); - p += tradCTLen; - memcpy(p, tradPK, tradPKLen); - p += tradPKLen; - memcpy(p, info->domSep, sizeof(info->domSep)); - - ret2 = EVP_Digest(buffer, bufferLen, output, (unsigned int *)outputLen, md, - NULL); + mdctx = EVP_MD_CTX_new(); + ON_ERR_SET_GOTO(!mdctx, ret, 0, err_buffer); + + ret2 = EVP_DigestInit_ex(mdctx, md, NULL); + ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err_buffer); + + ret2 = EVP_DigestUpdate(mdctx, counter, 4); + ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err_buffer); + + ret2 = EVP_DigestUpdate(mdctx, tradSS, tradSSLen); + ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err_buffer); + + ret2 = EVP_DigestUpdate(mdctx, mlkemSS, mlkemSSLen); + ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err_buffer); + + ret2 = EVP_DigestUpdate(mdctx, tradCT, tradCTLen); + ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err_buffer); + + ret2 = EVP_DigestUpdate(mdctx, tradPK, tradPKLen); + ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err_buffer); + + ret2 = EVP_DigestUpdate(mdctx, tradPK, tradPKLen); + ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err_buffer); + + ret2 = EVP_DigestFinal_ex(mdctx, output, (unsigned int *)outputLen); ON_ERR_SET_GOTO(ret2 != 1, ret, 0, err_buffer); err_buffer: