diff --git a/src/common/common.c b/src/common/common.c index 44280c141..4adbc4602 100644 --- a/src/common/common.c +++ b/src/common/common.c @@ -257,6 +257,9 @@ OQS_API int OQS_MEM_secure_bcmp(const void *a, const void *b, size_t len) { } OQS_API void OQS_MEM_cleanse(void *ptr, size_t len) { + if (ptr == NULL) { + return; + } #if defined(OQS_USE_OPENSSL) OSSL_FUNC(OPENSSL_cleanse)(ptr, len); #elif defined(_WIN32) @@ -267,7 +270,7 @@ OQS_API void OQS_MEM_cleanse(void *ptr, size_t len) { explicit_memset(ptr, 0, len); #elif defined(__STDC_LIB_EXT1__) || defined(OQS_HAVE_MEMSET_S) if (0U < len && memset_s(ptr, (rsize_t)len, 0, (rsize_t)len) != 0) { - abort(); + return NULL; //abort(); } #else typedef void *(*memset_t)(void *, int, size_t); @@ -275,12 +278,11 @@ OQS_API void OQS_MEM_cleanse(void *ptr, size_t len) { memset_func(ptr, 0, len); #endif } - void *OQS_MEM_checked_malloc(size_t len) { void *ptr = OQS_MEM_malloc(len); if (ptr == NULL) { fprintf(stderr, "Memory allocation failed\n"); - abort(); + return NULL; //abort(); } return ptr; @@ -290,7 +292,7 @@ void *OQS_MEM_checked_aligned_alloc(size_t alignment, size_t size) { void *ptr = OQS_MEM_aligned_alloc(alignment, size); if (ptr == NULL) { fprintf(stderr, "Memory allocation failed\n"); - abort(); + return NULL; //abort(); } return ptr; @@ -391,12 +393,13 @@ void *OQS_MEM_aligned_alloc(size_t alignment, size_t size) { } void OQS_MEM_aligned_free(void *ptr) { + if (ptr == NULL) { + return; + } #if defined(OQS_USE_OPENSSL) // Use OpenSSL's free function - if (ptr) { - uint8_t *u8ptr = ptr; - OPENSSL_free(u8ptr - u8ptr[-1]); - } + uint8_t *u8ptr = ptr; + OPENSSL_free(u8ptr - u8ptr[-1]); #elif defined(OQS_HAVE_ALIGNED_ALLOC) || defined(OQS_HAVE_POSIX_MEMALIGN) || defined(OQS_HAVE_MEMALIGN) free(ptr); // IGNORE free-check #elif defined(__MINGW32__) || defined(__MINGW64__) @@ -404,11 +407,9 @@ void OQS_MEM_aligned_free(void *ptr) { #elif defined(_MSC_VER) _aligned_free(ptr); #else - if (ptr) { - // Reconstruct the pointer returned from malloc using the difference - // stored one byte ahead of ptr. - uint8_t *u8ptr = ptr; - free(u8ptr - u8ptr[-1]); // IGNORE free-check - } + // Reconstruct the pointer returned from malloc using the difference + // stored one byte ahead of ptr. + uint8_t *u8ptr = ptr; + free(u8ptr - u8ptr[-1]); // IGNORE free-check #endif } diff --git a/src/common/ossl_helpers.c b/src/common/ossl_helpers.c index 2eaf4f586..dd5c7b5f0 100644 --- a/src/common/ossl_helpers.c +++ b/src/common/ossl_helpers.c @@ -20,6 +20,57 @@ static EVP_MD *sha256_ptr, *sha384_ptr, *sha512_ptr, static EVP_CIPHER *aes128_ecb_ptr, *aes128_ctr_ptr, *aes256_ecb_ptr, *aes256_ctr_ptr; +static void free_ossl_objects(void) { + if (sha256_ptr) { + OSSL_FUNC(EVP_MD_free)(sha256_ptr); + sha256_ptr = NULL; + } + if (sha384_ptr) { + OSSL_FUNC(EVP_MD_free)(sha384_ptr); + sha384_ptr = NULL; + } + if (sha512_ptr) { + OSSL_FUNC(EVP_MD_free)(sha512_ptr); + sha512_ptr = NULL; + } + if (sha3_256_ptr) { + OSSL_FUNC(EVP_MD_free)(sha3_256_ptr); + sha3_256_ptr = NULL; + } + if (sha3_384_ptr) { + OSSL_FUNC(EVP_MD_free)(sha3_384_ptr); + sha3_384_ptr = NULL; + } + if (sha3_512_ptr) { + OSSL_FUNC(EVP_MD_free)(sha3_512_ptr); + sha3_512_ptr = NULL; + } + if (shake128_ptr) { + OSSL_FUNC(EVP_MD_free)(shake128_ptr); + shake128_ptr = NULL; + } + if (shake256_ptr) { + OSSL_FUNC(EVP_MD_free)(shake256_ptr); + shake256_ptr = NULL; + } + if (aes128_ecb_ptr) { + OSSL_FUNC(EVP_CIPHER_free)(aes128_ecb_ptr); + aes128_ecb_ptr = NULL; + } + if (aes128_ctr_ptr) { + OSSL_FUNC(EVP_CIPHER_free)(aes128_ctr_ptr); + aes128_ctr_ptr = NULL; + } + if (aes256_ecb_ptr) { + OSSL_FUNC(EVP_CIPHER_free)(aes256_ecb_ptr); + aes256_ecb_ptr = NULL; + } + if (aes256_ctr_ptr) { + OSSL_FUNC(EVP_CIPHER_free)(aes256_ctr_ptr); + aes256_ctr_ptr = NULL; + } +} + static void fetch_ossl_objects(void) { sha256_ptr = OSSL_FUNC(EVP_MD_fetch)(NULL, "SHA256", NULL); sha384_ptr = OSSL_FUNC(EVP_MD_fetch)(NULL, "SHA384", NULL); @@ -40,35 +91,10 @@ static void fetch_ossl_objects(void) { !sha3_384_ptr || !sha3_512_ptr || !shake128_ptr || !shake256_ptr || !aes128_ecb_ptr || !aes128_ctr_ptr || !aes256_ecb_ptr || !aes256_ctr_ptr) { fprintf(stderr, "liboqs warning: OpenSSL initialization failure. Is provider for SHA, SHAKE, AES enabled?\n"); + free_ossl_objects(); } } -static void free_ossl_objects(void) { - OSSL_FUNC(EVP_MD_free)(sha256_ptr); - sha256_ptr = NULL; - OSSL_FUNC(EVP_MD_free)(sha384_ptr); - sha384_ptr = NULL; - OSSL_FUNC(EVP_MD_free)(sha512_ptr); - sha512_ptr = NULL; - OSSL_FUNC(EVP_MD_free)(sha3_256_ptr); - sha3_256_ptr = NULL; - OSSL_FUNC(EVP_MD_free)(sha3_384_ptr); - sha3_384_ptr = NULL; - OSSL_FUNC(EVP_MD_free)(sha3_512_ptr); - sha3_512_ptr = NULL; - OSSL_FUNC(EVP_MD_free)(shake128_ptr); - shake128_ptr = NULL; - OSSL_FUNC(EVP_MD_free)(shake256_ptr); - shake256_ptr = NULL; - OSSL_FUNC(EVP_CIPHER_free)(aes128_ecb_ptr); - aes128_ecb_ptr = NULL; - OSSL_FUNC(EVP_CIPHER_free)(aes128_ctr_ptr); - aes128_ctr_ptr = NULL; - OSSL_FUNC(EVP_CIPHER_free)(aes256_ecb_ptr); - aes256_ecb_ptr = NULL; - OSSL_FUNC(EVP_CIPHER_free)(aes256_ctr_ptr); - aes256_ctr_ptr = NULL; -} #endif // OPENSSL_VERSION_NUMBER >= 0x30000000L void oqs_ossl_destroy(void) { @@ -76,11 +102,7 @@ void oqs_ossl_destroy(void) { #if defined(OQS_USE_PTHREADS) pthread_once(&free_once_control, free_ossl_objects); #else - if (sha256_ptr || sha384_ptr || sha512_ptr || sha3_256_ptr || - sha3_384_ptr || sha3_512_ptr || shake128_ptr || shake256_ptr || - aes128_ecb_ptr || aes128_ctr_ptr || aes256_ecb_ptr || aes256_ctr_ptr) { - free_ossl_objects(); - } + free_ossl_objects(); #endif #endif } @@ -237,7 +259,6 @@ const EVP_CIPHER *oqs_aes_128_ecb(void) { return OSSL_FUNC(EVP_aes_128_ecb)(); #endif } - const EVP_CIPHER *oqs_aes_128_ctr(void) { #if OPENSSL_VERSION_NUMBER >= 0x30000000L #if defined(OQS_USE_PTHREADS) @@ -301,19 +322,19 @@ static pthread_once_t dlopen_once_control = PTHREAD_ONCE_INIT; #define ENSURE_LIBRARY pthread_once(&dlopen_once_control, ensure_library) #else #define ENSURE_LIBRARY do { \ - if (!libcrypto_dlhandle) { \ - ensure_library(); \ - } \ - } while (0) + if (!libcrypto_dlhandle) { \ + ensure_library(); \ + } \ + } while (0) #endif // OQS_USE_PTHREADS /* Define redirection symbols */ #if (2 <= __GNUC__ || (4 <= __clang_major__)) #define FUNC(ret, name, args, cargs) \ - static __typeof__(name)(*_oqs_ossl_sym_##name); + static __typeof__(name)(*_oqs_ossl_sym_##name); #else #define FUNC(ret, name, args, cargs) \ - static ret(*_oqs_ossl_sym_##name)args; + static ret(*_oqs_ossl_sym_##name)args; #endif #define VOID_FUNC FUNC #include "ossl_functions.h" @@ -322,19 +343,23 @@ static pthread_once_t dlopen_once_control = PTHREAD_ONCE_INIT; /* Define redirection wrapper functions */ #define FUNC(ret, name, args, cargs) \ -ret _oqs_ossl_##name args \ -{ \ - ENSURE_LIBRARY; \ - assert(_oqs_ossl_sym_##name); \ - return _oqs_ossl_sym_##name cargs; \ -} + ret _oqs_ossl_##name args \ + { \ + ENSURE_LIBRARY; \ + if (!_oqs_ossl_sym_##name) { \ + return (ret)0; \ + } \ + return _oqs_ossl_sym_##name cargs; \ + } #define VOID_FUNC(ret, name, args, cargs) \ -ret _oqs_ossl_##name args \ -{ \ - ENSURE_LIBRARY; \ - assert(_oqs_ossl_sym_##name); \ - _oqs_ossl_sym_##name cargs; \ -} + ret _oqs_ossl_##name args \ + { \ + ENSURE_LIBRARY; \ + if (!_oqs_ossl_sym_##name) { \ + return; \ + } \ + _oqs_ossl_sym_##name cargs; \ + } #include "ossl_functions.h" #undef VOID_FUNC #undef FUNC @@ -359,9 +384,9 @@ static void ensure_library(void) { } #define ENSURE_SYMBOL(name) \ - ensure_symbol(#name, (void **)&_oqs_ossl_sym_##name) + ensure_symbol(#name, (void **)&_oqs_ossl_sym_##name) #define FUNC(ret, name, args, cargs) \ - ENSURE_SYMBOL(name); + ENSURE_SYMBOL(name); #define VOID_FUNC FUNC #include "ossl_functions.h" #undef VOID_FUNC