Skip to content

Commit

Permalink
Refactor rand functions with enhanced error handling and return code
Browse files Browse the repository at this point in the history
Signed-off-by: Songling Han <[email protected]>
  • Loading branch information
songlingatpan committed Oct 28, 2024
1 parent 7132473 commit 533fbdc
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 124 deletions.
52 changes: 27 additions & 25 deletions src/common/rand/rand.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@

#include <oqs/oqs.h>

void OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read);
OQS_STATUS OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read);
#ifdef OQS_USE_OPENSSL
void OQS_randombytes_openssl(uint8_t *random_array, size_t bytes_to_read);
OQS_STATUS OQS_randombytes_openssl(uint8_t *random_array, size_t bytes_to_read);
#endif

#ifdef OQS_USE_OPENSSL
#include "../ossl_helpers.h"
// Use OpenSSL's RAND_bytes as the default PRNG
static void (*oqs_randombytes_algorithm)(uint8_t *, size_t) = &OQS_randombytes_openssl;
static OQS_STATUS (*oqs_randombytes_algorithm)(uint8_t *, size_t) = &OQS_randombytes_openssl;
#else
static void (*oqs_randombytes_algorithm)(uint8_t *, size_t) = &OQS_randombytes_system;
static OQS_STATUS (*oqs_randombytes_algorithm)(uint8_t *, size_t) = &OQS_randombytes_system;
#endif
OQS_API OQS_STATUS OQS_randombytes_switch_algorithm(const char *algorithm) {
if (0 == strcasecmp(OQS_RAND_alg_system, algorithm)) {
Expand All @@ -49,67 +49,70 @@ OQS_API void OQS_randombytes_custom_algorithm(void (*algorithm_ptr)(uint8_t *, s
oqs_randombytes_algorithm = algorithm_ptr;
}

OQS_API void OQS_randombytes(uint8_t *random_array, size_t bytes_to_read) {
oqs_randombytes_algorithm(random_array, bytes_to_read);
OQS_API OQS_STATUS OQS_randombytes(uint8_t *random_array, size_t bytes_to_read) {
return oqs_randombytes_algorithm(random_array, bytes_to_read);
}

// Select the implementation for OQS_randombytes_system
#if defined(_WIN32)
void OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
OQS_STATUS OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
HCRYPTPROV hCryptProv;
if (!CryptAcquireContext(&hCryptProv, NULL, NULL, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT) ||
!CryptGenRandom(hCryptProv, (DWORD) bytes_to_read, random_array)) {
exit(EXIT_FAILURE); // better to fail than to return bad random data
return OQS_ERROR;
}
CryptReleaseContext(hCryptProv, 0);
return OQS_SUCCESS;
}
#elif defined(__APPLE__)
void OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
OQS_STATUS OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
arc4random_buf(random_array, bytes_to_read);
return OQS_SUCCESS;
}
#elif defined(OQS_EMBEDDED_BUILD)
void OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
OQS_STATUS OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
fprintf(stderr, "OQS_randombytes_system is not available in an embedded build.\n");
fprintf(stderr, "Call OQS_randombytes_custom_algorithm() to set a custom method for your system.\n");
exit(EXIT_FAILURE);
return OQS_ERROR;
}
#elif defined(OQS_HAVE_GETENTROPY)
void OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
OQS_STATUS OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
while (bytes_to_read > 256) {
if (getentropy(random_array, 256)) {
exit(EXIT_FAILURE);
return OQS_ERROR;
}
random_array += 256;
bytes_to_read -= 256;
}
if (getentropy(random_array, bytes_to_read)) {
exit(EXIT_FAILURE);
return OQS_ERROR;
}
return OQS_SUCCESS;
}
#else
void OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
OQS_STATUS OQS_randombytes_system(uint8_t *random_array, size_t bytes_to_read) {
FILE *handle;
size_t bytes_read;

handle = fopen("/dev/urandom", "rb");
if (!handle) {
perror("OQS_randombytes");
exit(EXIT_FAILURE);
return OQS_ERROR;
}

bytes_read = fread(random_array, 1, bytes_to_read, handle);
if (bytes_read < bytes_to_read || ferror(handle)) {
perror("OQS_randombytes");
exit(EXIT_FAILURE);
fclose(handle);

if (bytes_read < bytes_to_read) {
return OQS_ERROR;
}

fclose(handle);
return OQS_SUCCESS;
}
#endif

#ifdef OQS_USE_OPENSSL
#define OQS_RAND_POLL_RETRY 3 // in case failure to get randomness is a temporary problem, allow some repeats
void OQS_randombytes_openssl(uint8_t *random_array, size_t bytes_to_read) {
OQS_STATUS OQS_randombytes_openssl(uint8_t *random_array, size_t bytes_to_read) {
int rep = OQS_RAND_POLL_RETRY;
SIZE_T_TO_INT_OR_EXIT(bytes_to_read, bytes_to_read_int)
do {
Expand All @@ -120,9 +123,8 @@ void OQS_randombytes_openssl(uint8_t *random_array, size_t bytes_to_read) {
} while (rep-- >= 0);
if (OSSL_FUNC(RAND_bytes)(random_array, bytes_to_read_int) != 1) {
fprintf(stderr, "No OpenSSL randomness retrieved. DRBG available?\n");
// because of void signature we have no other way to signal the problem
// we cannot possibly return without randomness
exit(EXIT_FAILURE);
return OQS_ERROR;
}
return OQS_SUCCESS;
}
#endif
3 changes: 2 additions & 1 deletion src/common/rand/rand.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ OQS_API void OQS_randombytes_custom_algorithm(void (*algorithm_ptr)(uint8_t *, s
*
* @param[out] random_array Pointer to the memory to fill with (pseudo)random bytes
* @param[in] bytes_to_read The number of random bytes to read into memory
* @return OQS_SUCCESS on success, OQS_ERROR otherwise.
*/
OQS_API void OQS_randombytes(uint8_t *random_array, size_t bytes_to_read);
OQS_API OQS_STATUS OQS_randombytes(uint8_t *random_array, size_t bytes_to_read);

#if defined(__cplusplus)
} // extern "C"
Expand Down
209 changes: 115 additions & 94 deletions src/common/rand/rand_nist.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,118 +27,139 @@ You are solely responsible for determining the appropriateness of using and dist
#include <oqs/aes.h>
#endif

void OQS_randombytes_nist_kat(unsigned char *x, size_t xlen);
OQS_STATUS OQS_randombytes_nist_kat(unsigned char *x, size_t xlen);

static OQS_NIST_DRBG_struct DRBG_ctx;
static void AES256_CTR_DRBG_Update(unsigned char *provided_data, unsigned char *Key, unsigned char *V);
static OQS_NIST_DRBG_struct DRBG_ctx;
static OQS_STATUS AES256_CTR_DRBG_Update(unsigned char *provided_data, unsigned char *Key, unsigned char *V);

// Use whatever AES implementation you have. This uses AES from openSSL library
// key - 256-bit AES key
// ctr - a 128-bit plaintext value
// buffer - a 128-bit ciphertext value
static void AES256_ECB(unsigned char *key, unsigned char *ctr, unsigned char *buffer) {
#ifdef OQS_USE_OPENSSL
EVP_CIPHER_CTX *ctx;
// Use whatever AES implementation you have. This uses AES from openSSL library
// key - 256-bit AES key
// ctr - a 128-bit plaintext value
// buffer - a 128-bit ciphertext value
static OQS_STATUS AES256_ECB(unsigned char *key, unsigned char *ctr, unsigned char *buffer) {
#ifdef OQS_USE_OPENSSL
EVP_CIPHER_CTX *ctx;

int len;
int len;

/* Create and initialise the context */
ctx = OSSL_FUNC(EVP_CIPHER_CTX_new)();
OQS_EXIT_IF_NULLPTR(ctx, "OpenSSL");
/* Create and initialise the context */
ctx = OSSL_FUNC(EVP_CIPHER_CTX_new)();
if (ctx == NULL) {
return OQS_ERROR;
}

OQS_OPENSSL_GUARD(OSSL_FUNC(EVP_EncryptInit_ex)(ctx, oqs_aes_256_ecb(), NULL, key, NULL));
OQS_OPENSSL_GUARD(OSSL_FUNC(EVP_EncryptUpdate)(ctx, buffer, &len, ctr, 16));
if (OSSL_FUNC(EVP_EncryptInit_ex)(ctx, oqs_aes_256_ecb(), NULL, key, NULL) != 1 ||
OSSL_FUNC(EVP_EncryptUpdate)(ctx, buffer, &len, ctr, 16) != 1) {
OSSL_FUNC(EVP_CIPHER_CTX_free)(ctx);
return OQS_ERROR;
}

/* Clean up */
OSSL_FUNC(EVP_CIPHER_CTX_free)(ctx);
#else
void *schedule = NULL;
OQS_AES256_ECB_load_schedule(key, &schedule);
OQS_AES256_ECB_enc(ctr, 16, key, buffer);
OQS_AES256_free_schedule(schedule);
#endif
}
/* Clean up */
OSSL_FUNC(EVP_CIPHER_CTX_free)(ctx);
#else
void *schedule = NULL;
OQS_AES256_ECB_load_schedule(key, &schedule);
OQS_AES256_ECB_enc(ctr, 16, key, buffer);
OQS_AES256_free_schedule(schedule);
#endif
return OQS_SUCCESS;
}

void OQS_randombytes_nist_kat_init_256bit(const uint8_t *entropy_input, const uint8_t *personalization_string) {
unsigned char seed_material[48];
OQS_STATUS OQS_randombytes_nist_kat_init_256bit(const uint8_t *entropy_input, const uint8_t *personalization_string) {
unsigned char seed_material[48];

memcpy(seed_material, entropy_input, 48);
if (personalization_string)
for (int i = 0; i < 48; i++) {
seed_material[i] ^= personalization_string[i];
memcpy(seed_material, entropy_input, 48);
if (personalization_string)
for (int i = 0; i < 48; i++) {
seed_material[i] ^= personalization_string[i];
}
memset(DRBG_ctx.Key, 0x00, 32);
memset(DRBG_ctx.V, 0x00, 16);
if (AES256_CTR_DRBG_Update(seed_material, DRBG_ctx.Key, DRBG_ctx.V) != OQS_SUCCESS) {
return OQS_ERROR;
}
memset(DRBG_ctx.Key, 0x00, 32);
memset(DRBG_ctx.V, 0x00, 16);
AES256_CTR_DRBG_Update(seed_material, DRBG_ctx.Key, DRBG_ctx.V);
DRBG_ctx.reseed_counter = 1;
}

void OQS_randombytes_nist_kat(unsigned char *x, size_t xlen) {
unsigned char block[16];
int i = 0;

while (xlen > 0) {
//increment V
for (int j = 15; j >= 0; j--) {
if (DRBG_ctx.V[j] == 0xff) {
DRBG_ctx.V[j] = 0x00;
DRBG_ctx.reseed_counter = 1;
return OQS_SUCCESS;
}

OQS_STATUS OQS_randombytes_nist_kat(unsigned char *x, size_t xlen) {
unsigned char block[16];
int i = 0;

while (xlen > 0) {
//increment V
for (int j = 15; j >= 0; j--) {
if (DRBG_ctx.V[j] == 0xff) {
DRBG_ctx.V[j] = 0x00;
} else {
DRBG_ctx.V[j]++;
break;
}
}
if (AES256_ECB(DRBG_ctx.Key, DRBG_ctx.V, block) != OQS_SUCCESS) {
return OQS_ERROR;
}
if (xlen > 15) {
memcpy(x + i, block, 16);
i += 16;
xlen -= 16;
} else {
DRBG_ctx.V[j]++;
break;
memcpy(x + i, block, xlen);
xlen = 0;
}
}
AES256_ECB(DRBG_ctx.Key, DRBG_ctx.V, block);
if (xlen > 15) {
memcpy(x + i, block, 16);
i += 16;
xlen -= 16;
} else {
memcpy(x + i, block, xlen);
xlen = 0;
if (AES256_CTR_DRBG_Update(NULL, DRBG_ctx.Key, DRBG_ctx.V) != OQS_SUCCESS) {
return OQS_ERROR;
}
DRBG_ctx.reseed_counter++;
return OQS_SUCCESS;
}
AES256_CTR_DRBG_Update(NULL, DRBG_ctx.Key, DRBG_ctx.V);
DRBG_ctx.reseed_counter++;
}

void OQS_randombytes_nist_kat_get_state(void *out) {
OQS_NIST_DRBG_struct *out_state = (OQS_NIST_DRBG_struct *)out;
if (out_state != NULL) {
memcpy(out_state->Key, DRBG_ctx.Key, sizeof(DRBG_ctx.Key));
memcpy(out_state->V, DRBG_ctx.V, sizeof(DRBG_ctx.V));
out_state->reseed_counter = DRBG_ctx.reseed_counter;

OQS_STATUS OQS_randombytes_nist_kat_get_state(void *out) {
OQS_NIST_DRBG_struct *out_state = (OQS_NIST_DRBG_struct *)out;
if (out_state != NULL) {
memcpy(out_state->Key, DRBG_ctx.Key, sizeof(DRBG_ctx.Key));
memcpy(out_state->V, DRBG_ctx.V, sizeof(DRBG_ctx.V));
out_state->reseed_counter = DRBG_ctx.reseed_counter;
return OQS_SUCCESS;
}
return OQS_ERROR;
}
}

void OQS_randombytes_nist_kat_set_state(const void *in) {
const OQS_NIST_DRBG_struct *in_state = (const OQS_NIST_DRBG_struct *)in;
if (in_state != NULL) {
memcpy(DRBG_ctx.Key, in_state->Key, sizeof(DRBG_ctx.Key));
memcpy(DRBG_ctx.V, in_state->V, sizeof(DRBG_ctx.V));
DRBG_ctx.reseed_counter = in_state->reseed_counter;

OQS_STATUS OQS_randombytes_nist_kat_set_state(const void *in) {
const OQS_NIST_DRBG_struct *in_state = (const OQS_NIST_DRBG_struct *)in;
if (in_state != NULL) {
memcpy(DRBG_ctx.Key, in_state->Key, sizeof(DRBG_ctx.Key));
memcpy(DRBG_ctx.V, in_state->V, sizeof(DRBG_ctx.V));
DRBG_ctx.reseed_counter = in_state->reseed_counter;
return OQS_SUCCESS;
}
return OQS_ERROR;
}
}

static void AES256_CTR_DRBG_Update(unsigned char *provided_data, unsigned char *Key, unsigned char *V) {
unsigned char temp[48];
static OQS_STATUS AES256_CTR_DRBG_Update(unsigned char *provided_data, unsigned char *Key, unsigned char *V) {
unsigned char temp[48];

for (int i = 0; i < 3; i++) {
//increment V
for (int j = 15; j >= 0; j--) {
if (V[j] == 0xff) {
V[j] = 0x00;
} else {
V[j]++;
break;
}
}

for (int i = 0; i < 3; i++) {
//increment V
for (int j = 15; j >= 0; j--) {
if (V[j] == 0xff) {
V[j] = 0x00;
} else {
V[j]++;
break;
if (AES256_ECB(Key, V, temp + 16 * i) != OQS_SUCCESS) {
return OQS_ERROR;
}
}

AES256_ECB(Key, V, temp + 16 * i);
if (provided_data != NULL)
for (int i = 0; i < 48; i++) {
temp[i] ^= provided_data[i];
}
memcpy(Key, temp, 32);
memcpy(V, temp + 32, 16);
return OQS_SUCCESS;
}
if (provided_data != NULL)
for (int i = 0; i < 48; i++) {
temp[i] ^= provided_data[i];
}
memcpy(Key, temp, 32);
memcpy(V, temp + 32, 16);
}
Loading

0 comments on commit 533fbdc

Please sign in to comment.