From c4b6587fc8d92f9156c44b64ea848cf356efd256 Mon Sep 17 00:00:00 2001 From: Songling Han Date: Tue, 24 Sep 2024 03:24:35 +0000 Subject: [PATCH] Add allocator check in tests/test_code_conventions.py Signed-off-by: Songling Han --- src/common/common.c | 12 +++---- src/common/sha2/sha2_c.c | 8 ++--- src/common/sha3/ossl_sha3x4.c | 8 ++--- src/sig_stfl/lms/external/hss_alloc.c | 10 +++--- src/sig_stfl/lms/external/hss_generate.c | 4 +-- src/sig_stfl/lms/external/hss_keygen.c | 4 +-- .../lms/external/hss_thread_pthread.c | 10 +++--- tests/test_code_conventions.py | 36 ++++++++++++------- 8 files changed, 52 insertions(+), 40 deletions(-) diff --git a/src/common/common.c b/src/common/common.c index aa12b4e717..06fbda8167 100644 --- a/src/common/common.c +++ b/src/common/common.c @@ -301,7 +301,7 @@ void *OQS_MEM_checked_aligned_alloc(size_t alignment, size_t size) { OQS_API void OQS_MEM_secure_free(void *ptr, size_t len) { if (ptr != NULL) { OQS_MEM_cleanse(ptr, len); - OQS_MEM_insecure_free(ptr); // IGNORE free-check + OQS_MEM_insecure_free(ptr); } } @@ -309,7 +309,7 @@ OQS_API void OQS_MEM_insecure_free(void *ptr) { #if (defined(OQS_USE_OPENSSL) || defined(OQS_DLOPEN_OPENSSL)) && defined(OPENSSL_VERSION_NUMBER) OPENSSL_free(ptr); #else - free(ptr); + free(ptr); // IGNORE memory-check #endif } @@ -374,7 +374,7 @@ void *OQS_MEM_aligned_alloc(size_t alignment, size_t size) { // | // diff = ptr - buffer const size_t offset = alignment - 1 + sizeof(uint8_t); - uint8_t *buffer = malloc(size + offset); + uint8_t *buffer = malloc(size + offset);// IGNORE memory-check if (!buffer) { return NULL; } @@ -384,7 +384,7 @@ void *OQS_MEM_aligned_alloc(size_t alignment, size_t size) { ptrdiff_t diff = ptr - buffer; if (diff > UINT8_MAX) { // This should never happen in our code, but just to be safe - free(buffer); // IGNORE free-check + free(buffer); // IGNORE memory-check errno = EINVAL; return NULL; } @@ -405,7 +405,7 @@ void OQS_MEM_aligned_free(void *ptr) { uint8_t *u8ptr = ptr; OPENSSL_free(u8ptr - u8ptr[-1]); #elif defined(OQS_HAVE_ALIGNED_ALLOC) || defined(OQS_HAVE_POSIX_MEMALIGN) || defined(OQS_HAVE_MEMALIGN) - free(ptr); // IGNORE free-check + free(ptr); // IGNORE memory-check #elif defined(__MINGW32__) || defined(__MINGW64__) __mingw_aligned_free(ptr); #elif defined(_MSC_VER) @@ -414,6 +414,6 @@ void OQS_MEM_aligned_free(void *ptr) { // Reconstruct the pointer returned from malloc using the difference // stored one byte ahead of ptr. uint8_t *u8ptr = ptr; - free(u8ptr - u8ptr[-1]); // IGNORE free-check + free(u8ptr - u8ptr[-1]); // IGNORE memory-check #endif } diff --git a/src/common/sha2/sha2_c.c b/src/common/sha2/sha2_c.c index d35d3e6496..a8751a375f 100644 --- a/src/common/sha2/sha2_c.c +++ b/src/common/sha2/sha2_c.c @@ -588,22 +588,22 @@ void oqs_sha2_sha512_inc_ctx_clone_c(sha512ctx *stateout, const sha512ctx *state /* Destroy the hash state. */ void oqs_sha2_sha224_inc_ctx_release_c(sha224ctx *state) { - OQS_MEM_insecure_free(state->ctx); // IGNORE free-check + OQS_MEM_insecure_free(state->ctx); } /* Destroy the hash state. */ void oqs_sha2_sha256_inc_ctx_release_c(sha256ctx *state) { - OQS_MEM_insecure_free(state->ctx); // IGNORE free-check + OQS_MEM_insecure_free(state->ctx); } /* Destroy the hash state. */ void oqs_sha2_sha384_inc_ctx_release_c(sha384ctx *state) { - OQS_MEM_insecure_free(state->ctx); // IGNORE free-check + OQS_MEM_insecure_free(state->ctx); } /* Destroy the hash state. */ void oqs_sha2_sha512_inc_ctx_release_c(sha512ctx *state) { - OQS_MEM_insecure_free(state->ctx); // IGNORE free-check + OQS_MEM_insecure_free(state->ctx); } void oqs_sha2_sha256_inc_blocks_c(sha256ctx *state, const uint8_t *in, size_t inblocks) { diff --git a/src/common/sha3/ossl_sha3x4.c b/src/common/sha3/ossl_sha3x4.c index a1a69949a7..a5cfeb5242 100644 --- a/src/common/sha3/ossl_sha3x4.c +++ b/src/common/sha3/ossl_sha3x4.c @@ -94,7 +94,7 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx3); OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen); memcpy(out3, tmp + s->n_out, outlen); - OQS_MEM_insecure_free(tmp); // IGNORE free-check + OQS_MEM_insecure_free(tmp); } OSSL_FUNC(EVP_MD_CTX_free)(clone); s->n_out += outlen; @@ -117,7 +117,7 @@ static void SHA3_shake128_x4_inc_ctx_release(OQS_SHA3_shake128_x4_inc_ctx *state OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3); - OQS_MEM_insecure_free(s); // IGNORE free-check + OQS_MEM_insecure_free(s); } static void SHA3_shake128_x4_inc_ctx_reset(OQS_SHA3_shake128_x4_inc_ctx *state) { @@ -215,7 +215,7 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx3); OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen); memcpy(out3, tmp + s->n_out, outlen); - OQS_MEM_insecure_free(tmp); // IGNORE free-check + OQS_MEM_insecure_free(tmp); } OSSL_FUNC(EVP_MD_CTX_free)(clone); s->n_out += outlen; @@ -238,7 +238,7 @@ static void SHA3_shake256_x4_inc_ctx_release(OQS_SHA3_shake256_x4_inc_ctx *state OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3); - OQS_MEM_insecure_free(s); // IGNORE free-check + OQS_MEM_insecure_free(s); } static void SHA3_shake256_x4_inc_ctx_reset(OQS_SHA3_shake256_x4_inc_ctx *state) { diff --git a/src/sig_stfl/lms/external/hss_alloc.c b/src/sig_stfl/lms/external/hss_alloc.c index 70b5ca0a68..00c0d628b5 100644 --- a/src/sig_stfl/lms/external/hss_alloc.c +++ b/src/sig_stfl/lms/external/hss_alloc.c @@ -542,15 +542,15 @@ void hss_free_working_key(struct hss_working_key *w) { unsigned j, k; for (j=0; jsubtree[j][k]); // IGNORE free-check + OQS_MEM_insecure_free(tree->subtree[j][k]); hss_zeroize( tree, sizeof *tree ); /* We have seeds here */ } - OQS_MEM_insecure_free(tree); // IGNORE free-check + OQS_MEM_insecure_free(tree); } for (i=0; isigned_pk[i]); // IGNORE free-check + OQS_MEM_insecure_free(w->signed_pk[i]); } - OQS_MEM_insecure_free(w->stack); // IGNORE free-check + OQS_MEM_insecure_free(w->stack); hss_zeroize( w, sizeof *w ); /* We have secret information here */ - OQS_MEM_insecure_free(w); // IGNORE free-check + OQS_MEM_insecure_free(w); } diff --git a/src/sig_stfl/lms/external/hss_generate.c b/src/sig_stfl/lms/external/hss_generate.c index 44171abdc4..359706ad6f 100644 --- a/src/sig_stfl/lms/external/hss_generate.c +++ b/src/sig_stfl/lms/external/hss_generate.c @@ -796,7 +796,7 @@ bool hss_generate_working_key( #if DO_FLOATING_POINT /* Don't leak suborders on an intermediate error */ for (i=0; i<(sequence_t)count_order; i++) { - OQS_MEM_insecure_free( order[i].sub ); // IGNORE free-check + OQS_MEM_insecure_free( order[i].sub ); } #endif info->error_code = got_error; @@ -831,7 +831,7 @@ bool hss_generate_working_key( hash_size, tree->h, I); } - OQS_MEM_insecure_free( sub ); // IGNORE free-check + OQS_MEM_insecure_free( sub ); p_order->sub = 0; } #endif diff --git a/src/sig_stfl/lms/external/hss_keygen.c b/src/sig_stfl/lms/external/hss_keygen.c index 2f1482a298..6dc0d02b78 100644 --- a/src/sig_stfl/lms/external/hss_keygen.c +++ b/src/sig_stfl/lms/external/hss_keygen.c @@ -278,7 +278,7 @@ bool hss_generate_private_key( } else { hss_zeroize( context, PRIVATE_KEY_LEN ); } - OQS_MEM_insecure_free(temp_buffer); // IGNORE free-check + OQS_MEM_insecure_free(temp_buffer); return false; } @@ -355,7 +355,7 @@ bool hss_generate_private_key( /* Hey, what do you know -- it all worked! */ hss_zeroize( private_key, sizeof private_key ); /* Zeroize local copy of */ /* the private key */ - OQS_MEM_insecure_free(temp_buffer); // IGNORE free-check + OQS_MEM_insecure_free(temp_buffer); return true; } #endif diff --git a/src/sig_stfl/lms/external/hss_thread_pthread.c b/src/sig_stfl/lms/external/hss_thread_pthread.c index 741bae0c36..b5df4a6054 100644 --- a/src/sig_stfl/lms/external/hss_thread_pthread.c +++ b/src/sig_stfl/lms/external/hss_thread_pthread.c @@ -91,13 +91,13 @@ struct thread_collection *hss_thread_init(int num_thread) { col->num_thread = num_thread; if (0 != pthread_mutex_init( &col->lock, 0 )) { - OQS_MEM_insecure_free(col); // IGNORE free-check + OQS_MEM_insecure_free(col); return 0; } if (0 != pthread_mutex_init( &col->write_lock, 0 )) { pthread_mutex_destroy( &col->lock ); - OQS_MEM_insecure_free(col); // IGNORE free-check + OQS_MEM_insecure_free(col); return 0; } @@ -126,7 +126,7 @@ static void *worker_thread( void *arg ) { (w->function)(w->x.detail, col); /* Ok, we did that */ - OQS_MEM_insecure_free(w); // IGNORE free-check + OQS_MEM_insecure_free(w); /* Check if there's anything else to do */ pthread_mutex_lock( &col->lock ); @@ -219,7 +219,7 @@ void hss_thread_issue_work(struct thread_collection *col, /* Hmmm, couldn't spawn it; fall back */ default: /* On error condition */ pthread_mutex_unlock( &col->lock ); - OQS_MEM_insecure_free(w); // IGNORE free-check + OQS_MEM_insecure_free(w); function( detail, col ); return; } @@ -277,7 +277,7 @@ void hss_thread_done(struct thread_collection *col) { pthread_mutex_destroy( &col->lock ); pthread_mutex_destroy( &col->write_lock ); - OQS_MEM_insecure_free(col); // IGNORE free-check + OQS_MEM_insecure_free(col); } void hss_thread_before_write(struct thread_collection *col) { diff --git a/tests/test_code_conventions.py b/tests/test_code_conventions.py index ed88f483ab..081bf8dd9c 100644 --- a/tests/test_code_conventions.py +++ b/tests/test_code_conventions.py @@ -48,26 +48,38 @@ def test_spdx(): print(result) assert False -# Ensure "free" is not used unprotected in the main OQS code. -@helpers.filtered_test -@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows") -def test_free(): +def test_memory_functions(): c_files = [] for path, _, files in os.walk('src'): - c_files += [os.path.join(path,f) for f in files if f[-2:] == '.c'] + c_files += [os.path.join(path, f) for f in files if f.endswith('.c')] + + memory_functions = ['free', 'malloc', 'calloc', 'realloc', 'strdup'] okay = True + for fn in c_files: with open(fn) as f: - # Find all lines that contain 'free(' but not '_free(' - for no, line in enumerate(f,1): - if not re.match(r'^.*[^_]free\(.*$', line): + content = f.read() + lines = content.splitlines() + for no, line in enumerate(lines, 1): + # Skip comments + if line.strip().startswith('//') or line.strip().startswith('/*'): continue - if 'IGNORE free-check' in line: + # Check if we're inside a multi-line comment + if '/*' in content[:content.find(line)] and '*/' not in content[:content.find(line)]: continue - okay = False - print("Suspicious `free` in {}:{}:{}".format(fn,no,line)) - assert okay, "'free' is used in some files. These should be changed to 'OQS_MEM_secure_free' or 'OQS_MEM_insecure_free' as appropriate. If you are sure you want to use 'free' in a particular spot, add the comment '// IGNORE free-check' on the line where 'free' occurs." + for func in memory_functions: + if re.search(r'\b{}\('.format(func), line) and not re.search(r'\b_{}\('.format(func), line): + if 'IGNORE memory-check' in line: + continue + okay = False + print(f"Suspicious `{func}` in {fn}:{no}:{line.strip()}") + + assert okay, ("Standard memory functions are used in some files. " + "These should be changed to OQS_MEM_* equivalents as appropriate. " + "If you are sure you want to use these functions in a particular spot, " + "add the comment '// IGNORE memory-check' on the line where the function occurs.") if __name__ == "__main__": + test_memory_functions() import sys pytest.main(sys.argv)