Skip to content

Commit

Permalink
re-add multi-hash support, just in case
Browse files Browse the repository at this point in the history
  • Loading branch information
acagliano committed Aug 29, 2024
1 parent d201045 commit bc5ec67
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 122 deletions.
46 changes: 23 additions & 23 deletions src/tls/core/aes.c
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ void aes_gcm_prepare_iv(struct tls_aes_context *ctx, const uint8_t *iv, size_t i
// Performs the action of generating the keys that will be used in every round of
// encryption. "key" is the user-supplied input key, "w" is the output key schedule,
// "keysize" is the length in bits of "key", must be 128, 192, or 256.
bool tls_aes_init(struct tls_aes_context *ctx, const uint8_t *key, size_t keysize, uint8_t* iv, size_t iv_len)
bool tls_aes_init(struct tls_aes_context *ctx, const uint8_t *key, size_t key_len, const uint8_t* iv, size_t iv_len)
{
if((ctx==NULL) ||
(key==NULL) ||
Expand All @@ -868,8 +868,8 @@ bool tls_aes_init(struct tls_aes_context *ctx, const uint8_t *key, size_t keysiz

if(iv_len>AES_BLOCK_SIZE) return false;
memset(ctx, 0, sizeof(struct tls_aes_context));
keysize<<=3;
switch (keysize) {
key_len<<=3;
switch (key_len) {
case 128: Nr = 10; Nk = 4; break;
case 192: Nr = 12; Nk = 6; break;
case 256: Nr = 14; Nk = 8; break;
Expand All @@ -878,7 +878,7 @@ bool tls_aes_init(struct tls_aes_context *ctx, const uint8_t *key, size_t keysiz

memcpy(ctx->iv, iv, iv_len);
//memset(&ctx->iv[iv_len], 0, 16-iv_len);
ctx->keysize = keysize;
ctx->keysize = key_len;
for (idx=0; idx < Nk; ++idx) {
ctx->round_keys[idx] = ((uint32_t)(key[4 * idx]) << 24) | ((uint32_t)(key[4 * idx + 1]) << 16) |
((uint32_t)(key[4 * idx + 2]) << 8) | ((uint32_t)(key[4 * idx + 3]));
Expand Down Expand Up @@ -913,7 +913,7 @@ enum GCM_OPS_ALLOWED {
};


bool tls_aes_update_aad(struct tls_aes_context* ctx, const void *aad, size_t aad_len){
bool tls_aes_update_aad(struct tls_aes_context* ctx, const uint8_t *aad, size_t aad_len){
if(ctx->_private.lock > LOCK_ALLOW_ALL) return false;

// update the tag for full blocks of aad in input, cache any partial blocks
Expand Down Expand Up @@ -947,14 +947,14 @@ bool tls_aes_digest(struct tls_aes_context* ctx, uint8_t* digest){
}

#define AES_BLOCKSIZE 16
bool tls_aes_encrypt(struct tls_aes_context* ctx, const void *in, size_t in_len, void *out)
bool tls_aes_encrypt(struct tls_aes_context* ctx, const uint8_t *inbuf, size_t in_len, uint8_t *outbuf)
{
uint8_t buf[AES_BLOCK_SIZE];
//int keysize = key->keysize;
//uint32_t *round_keys = key->round_keys;
int blocks, idx;

if(in==NULL || out==NULL || ctx==NULL) return false;
if(inbuf==NULL || outbuf==NULL || ctx==NULL) return false;
if(in_len == 0) return false;
if(ctx->op_assoc == AES_OP_DECRYPT) return false;
if(ctx->_private.lock > LOCK_ALLOW_ENCRYPT) return false;
Expand All @@ -977,20 +977,20 @@ bool tls_aes_encrypt(struct tls_aes_context* ctx, const void *in, size_t in_len,
// xor last bytes of encryption buf w/ new plaintext for new ciphertext
if(bytes_to_copy%AES_BLOCK_SIZE){
bytes_offset = AES_BLOCK_SIZE - bytes_to_copy;
memcpy(out, in, bytes_offset);
xor_buf(&ctx->_private.last_block[bytes_to_copy], out, bytes_offset);
memcpy(outbuf, inbuf, bytes_offset);
xor_buf(&ctx->_private.last_block[bytes_to_copy], outbuf, bytes_offset);
blocks = ((in_len - bytes_offset) / AES_BLOCK_SIZE);
}

// encrypt remaining plaintext
for(idx = 0; idx <= blocks; idx++){
bytes_to_copy = MIN(AES_BLOCK_SIZE, in_len - bytes_offset - (idx * AES_BLOCK_SIZE));
//bytes_to_pad = AES_BLOCK_SIZE-bytes_to_copy;
memcpy(&out[idx*AES_BLOCK_SIZE+bytes_offset], &in[idx*AES_BLOCK_SIZE+bytes_offset], bytes_to_copy);
memcpy(&outbuf[idx*AES_BLOCK_SIZE+bytes_offset], &inbuf[idx*AES_BLOCK_SIZE+bytes_offset], bytes_to_copy);
// memset(&buf[bytes_to_copy], 0, bytes_to_pad);
// if bytes_to_copy is less than blocksize, do nothing, since msg is truncated anyway
aes_encrypt_block(iv, buf, ctx);
xor_buf(buf, &out[idx*AES_BLOCK_SIZE+bytes_offset], bytes_to_copy);
xor_buf(buf, &outbuf[idx*AES_BLOCK_SIZE+bytes_offset], bytes_to_copy);
increment_iv(iv, AES_GCM_NONCE_LEN, AES_GCM_CTR_LEN); // increment iv for continued use
if(idx==blocks){
memcpy(ctx->_private.last_block, buf, AES_BLOCK_SIZE);
Expand All @@ -999,7 +999,7 @@ bool tls_aes_encrypt(struct tls_aes_context* ctx, const void *in, size_t in_len,
}

// authenticate the ciphertext
ghash(ctx, tag, out, in_len);
ghash(ctx, tag, outbuf, in_len);
ctx->_private.ct_len += in_len;


Expand All @@ -1008,20 +1008,20 @@ bool tls_aes_encrypt(struct tls_aes_context* ctx, const void *in, size_t in_len,
}

bool decrypt_call_from_verify = false;
bool tls_aes_decrypt(struct tls_aes_context* ctx, const void *in, size_t in_len, void *out)
bool tls_aes_decrypt(struct tls_aes_context* ctx, const uint8_t *inbuf, size_t in_len, uint8_t *outbuf)
{

if((ctx==NULL) ||
(in==NULL) ||
(inbuf==NULL) ||
(in_len==0)) return false;

if(!decrypt_call_from_verify && (out==NULL)) return false;
if(!decrypt_call_from_verify && (outbuf==NULL)) return false;
if(ctx->op_assoc == AES_OP_ENCRYPT) return false;
if(ctx->_private.lock > LOCK_ALLOW_ENCRYPT) return false;
ctx->op_assoc = AES_OP_DECRYPT;


uint8_t* iv = ctx->iv;
uint8_t *iv = ctx->iv;
uint8_t *tag = ctx->_private.auth_tag;
uint8_t buf_in[AES_BLOCK_SIZE], buf_out[AES_BLOCK_SIZE];
int blocks = in_len / AES_BLOCK_SIZE, idx;
Expand All @@ -1034,26 +1034,26 @@ bool tls_aes_decrypt(struct tls_aes_context* ctx, const void *in, size_t in_len,
memset(buf_in, 0, AES_BLOCK_SIZE);
ghash(ctx, tag, buf_in, AES_BLOCK_SIZE - ctx->_private.aad_cache_len);
}
ghash(ctx, tag, in, in_len);
ghash(ctx, tag, inbuf, in_len);
ctx->_private.ct_len += in_len;

if(out){
if(outbuf){
ctx->_private.lock = LOCK_ALLOW_ENCRYPT;
if(bytes_to_copy%AES_BLOCK_SIZE){
bytes_offset = AES_BLOCK_SIZE - bytes_to_copy;
memcpy(out, in, bytes_offset);
xor_buf(&ctx->_private.last_block[bytes_to_copy], out, bytes_offset);
memcpy(outbuf, inbuf, bytes_offset);
xor_buf(&ctx->_private.last_block[bytes_to_copy], outbuf, bytes_offset);
blocks = ((in_len - bytes_offset) / AES_BLOCK_SIZE);
}

// encrypt remaining plaintext
for(idx = 0; idx <= blocks; idx++){
bytes_to_copy = MIN(AES_BLOCK_SIZE, in_len - bytes_offset - (idx * AES_BLOCK_SIZE));
//bytes_to_pad = AES_BLOCK_SIZE-bytes_to_copy;
memcpy(&out[idx*AES_BLOCK_SIZE+bytes_offset], &in[idx*AES_BLOCK_SIZE+bytes_offset], bytes_to_copy);
memcpy(&outbuf[idx*AES_BLOCK_SIZE+bytes_offset], &inbuf[idx*AES_BLOCK_SIZE+bytes_offset], bytes_to_copy);
// memset(&buf[bytes_to_copy], 0, bytes_to_pad); // if bytes_to_copy is less than blocksize, do nothing, since msg is truncated anyway
aes_encrypt_block(iv, buf_in, ctx);
xor_buf(buf_in, &out[idx*AES_BLOCK_SIZE+bytes_offset], bytes_to_copy);
xor_buf(buf_in, &outbuf[idx*AES_BLOCK_SIZE+bytes_offset], bytes_to_copy);
increment_iv(iv, AES_GCM_NONCE_LEN, AES_GCM_CTR_LEN); // increment iv for continued use
if(idx==blocks){
memcpy(ctx->_private.last_block, buf_in, AES_BLOCK_SIZE);
Expand All @@ -1067,7 +1067,7 @@ bool tls_aes_decrypt(struct tls_aes_context* ctx, const void *in, size_t in_len,
}


bool tls_aes_verify(struct tls_aes_context *ctx, const void *aad, size_t aad_len, const void *ciphertext, size_t ciphertext_len, const uint8_t *tag){
bool tls_aes_verify(struct tls_aes_context *ctx, const uint8_t *aad, size_t aad_len, const uint8_t *ciphertext, size_t ciphertext_len, const uint8_t *tag){

if((ctx==NULL) ||
(ciphertext==NULL) ||
Expand Down
26 changes: 26 additions & 0 deletions src/tls/core/hash.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include <stdint.h>
#include "../includes/hash.h"


bool tls_hash_context_init(struct tls_hash_context *ctx, uint8_t algorithm){
if(ctx==NULL) return false;
switch(algorithm){
case TLS_SHA256:
ctx->digestlen = TLS_SHA256_DIGEST_LEN;
ctx->init = tls_sha256_init;
ctx->update = tls_sha256_update;
ctx->digest = tls_sha256_digest;
break;
default:
return false;
}
return ctx->init(&ctx->_private);
}

void tls_hash_update(struct tls_hash_context *ctx, const uint8_t *data, size_t len) {
ctx->update(&ctx->_private, data, len);
}

void tls_hash_digest(struct tls_hash_context *ctx, uint8_t *digest) {
ctx->digest(&ctx->_private, digest);
}
138 changes: 81 additions & 57 deletions src/tls/core/rsa.c
Original file line number Diff line number Diff line change
@@ -1,134 +1,158 @@

#include <stdint.h>
#include <string.h>
#include "../includes/rsa.h"
#include "../includes/bytes.h"
#include "../includes/random.h"
#include "../includes/hash.h"
#include "../includes/rsa.h"

#define ENCODE_START 0
#define ENCODE_SALT (1 + ENCODE_START)
bool tls_rsa_encode_oaep(const uint8_t* in, size_t in_len, uint8_t* out,
size_t modulus_len, const uint8_t *auth, uint8_t alg){
bool tls_rsa_encode_oaep(const uint8_t* inbuf, size_t in_len, uint8_t* outbuf,
size_t modulus_len, const char *auth, uint8_t hash_alg){

// initial sanity checks
if((modulus_len > RSA_MODULUS_MAX_SUPPORTED) ||
(modulus_len < RSA_MODULUS_MIN_SUPPORTED) ||
(in == NULL) ||
(out == NULL) ||
(inbuf == NULL) ||
(outbuf == NULL) ||
(in_len == 0)
) return false;

struct tls_context_sha256 ctx;
if(!tls_sha256_init(&ctx)) return 0;
size_t min_padding_len = (TLS_SHA256_DIGEST_LEN<<1) + 2;
size_t ps_len = modulus_len - len - min_padding_len;
size_t db_len = modulus_len - TLS_SHA256_DIGEST_LEN - 1;
size_t encode_lhash = ENCODE_SALT + TLS_SHA256_DIGEST_LEN;
size_t encode_ps = encode_lhash + TLS_SHA256_DIGEST_LEN;
uint8_t mgf1_digest[RSA_MODULUS_MAX];
struct tls_hash_context hash;
if(!tls_hash_context_init(&hash, hash_alg)) return false;
size_t min_padding_len = (hash.digestlen<<1) + 2;
size_t ps_len = modulus_len - in_len - min_padding_len;
size_t db_len = modulus_len - hash.digestlen - 1;
size_t encode_lhash = ENCODE_SALT + hash.digestlen;
size_t encode_ps = encode_lhash + hash.digestlen;
uint8_t mgf1_digest[RSA_MODULUS_MAX_SUPPORTED];

if((in_len + min_padding_len) > modulus_len) return false;

// set first byte to 00
out[ENCODE_START] = 0x00;
outbuf[ENCODE_START] = 0x00;
// seed next 32 bytes
tls_random_bytes(&out[ENCODE_SALT], TLS_SHA256_DIGEST_LEN);
tls_random_bytes(&outbuf[ENCODE_SALT], hash.digestlen);

// hash the authentication string
if(auth != NULL) tls_sha256_update(&ctx, auth, strlen(auth));
tls_sha256_digest(&ctx, &out[encode_lhash]); // nothing to actually hash
if(auth != NULL) hash.update(&hash._private, auth, strlen(auth));
hash.digest(&hash._private, &outbuf[encode_lhash]); // nothing to actually hash

memset(&out[encode_ps], 0, ps_len); // write padding zeros
out[encode_ps + ps_len] = 0x01; // write 0x01
memcpy(&out[encode_ps + ps_len + 1], in, in_len); // write plaintext to end of output
memset(&outbuf[encode_ps], 0, ps_len); // write padding zeros
outbuf[encode_ps + ps_len] = 0x01; // write 0x01
memcpy(&outbuf[encode_ps + ps_len + 1], inbuf, in_len); // write plaintext to end of output

// hash the salt with MGF1, return hash length of db
tls_mgf1(&out[ENCODE_SALT], TLS_SHA256_DIGEST_LEN, mgf1_digest, db_len);
tls_mgf1(&outbuf[ENCODE_SALT], hash.digestlen, mgf1_digest, db_len, hash_alg);

// XOR hash with db
for(size_t i=0; i < db_len; i++)
out[encode_lhash + i] ^= mgf1_digest[i];
outbuf[encode_lhash + i] ^= mgf1_digest[i];

// hash db with MGF1, return hash length of RSA_SALT_SIZE
tls_mgf1(&out[encode_lhash], db_len, mgf1_digest, TLS_SHA256_DIGEST_LEN);
tls_mgf1(&outbuf[encode_lhash], db_len, mgf1_digest, hash.digestlen, hash_alg);

// XOR hash with salt
for(size_t i=0; i<hash_len; i++)
out[ENCODE_SALT + i] ^= mgf1_digest[i];
for(size_t i=0; i<hash.digestlen; i++)
outbuf[ENCODE_SALT + i] ^= mgf1_digest[i];

// Return the static size of 256
return true;
}


size_t tls_rsa_decode_oaep(const uint8_t *in, size_t len, uint8_t* out, const uint8_t *auth){
size_t tls_rsa_decode_oaep(const uint8_t *inbuf, size_t in_len, uint8_t* outbuf, const char *auth, uint8_t hash_alg){

if((len > RSA_MODULUS_MAX_SUPPORTED) ||
(len < RSA_MODULUS_MIN_SUPPORTED) ||
(in == NULL) ||
(out == NULL)
) return 0;
if((in_len > RSA_MODULUS_MAX_SUPPORTED) ||
(in_len < RSA_MODULUS_MIN_SUPPORTED) ||
(inbuf == NULL) ||
(outbuf == NULL)) return 0;

struct tls_context_sha256 ctx;
if(!tls_sha256_init(&ctx)) return 0;
struct tls_hash_context hash;
if(!tls_hash_context_init(&hash, hash_alg)) return false;

size_t db_len = len - TLS_SHA256_DIGEST_LEN - 1;
size_t db_len = in_len - hash.digestlen - 1;
uint8_t sha256_digest[TLS_SHA256_DIGEST_LEN];
size_t encode_lhash = ENCODE_SALT + TLS_SHA256_DIGEST_LEN;
size_t encode_ps = encode_lhash + TLS_SHA256_DIGEST_LEN;
uint8_t mgf1_digest[RSA_MODULUS_MAX];
size_t encode_lhash = ENCODE_SALT + hash.digestlen;
size_t encode_ps = encode_lhash + hash.digestlen;
uint8_t mgf1_digest[RSA_MODULUS_MAX_SUPPORTED];
size_t i;
uint8_t tmp[RSA_MODULUS_MAX];
uint8_t tmp[RSA_MODULUS_MAX_SUPPORTED];

memcpy(tmp, in, len);
memcpy(tmp, inbuf, in_len);

// Copy last 16 bytes of input buf to salt to get encoded salt
// memcpy(salt, &in[len-RSA_SALT_SIZE-1], RSA_SALT_SIZE);

// SHA-256 hash db
tls_mgf1(&tmp[encode_lhash], db_len, mgf1_digest, TLS_SHA256_DIGEST_LEN);
tls_mgf1(&tmp[encode_lhash], db_len, mgf1_digest, hash.digestlen, hash_alg);

// XOR hash with encoded salt to return salt
for(i = 0; i < hash_len; i++)
for(i = 0; i < TLS_SHA256_DIGEST_LEN; i++)
tmp[ENCODE_SALT + i] ^= mgf1_digest[i];

// MGF1 hash the salt
tls_mgf1(&tmp[ENCODE_SALT], hash_len, mgf1_digest, db_len);
tls_mgf1(&tmp[ENCODE_SALT], hash.digestlen, mgf1_digest, db_len, hash_alg);

// XOR MGF1 of salt with encoded message to get decoded message
for(i = 0; i < db_len; i++)
tmp[encode_lhash + i] ^= mgf1_digest[i];

// verify authentication
if(auth != NULL) tls_sha256_update(&ctx, auth, strlen(auth));
tls_sha256_digest(&ctx, &out[encode_lhash]);
if(auth != NULL) hash.update(&hash._private, auth, strlen(auth));
hash.digest(&hash._private, &outbuf[encode_lhash]);

if(!tls_bytes_compare(sha256_digest, out, TLS_SHA256_DIGEST_LEN)) return 0;
if(!tls_bytes_compare(sha256_digest, outbuf, TLS_SHA256_DIGEST_LEN)) return 0;

for(i = encode_ps; i < len; i++)
for(i = encode_ps; i < in_len; i++)
if(tmp[i] == 0x01) break;
if(i==len) return false;
if(i==in_len) return false;
i++;
memcpy(out, &tmp[i], len-i);
memcpy(outbuf, &tmp[i], in_len-i);


return len-i;
return in_len-i;
}

void powmod_exp_u24(uint8_t size, uint8_t *restrict base, uint24_t exp, const uint8_t *restrict mod);
#define RSA_PUBLIC_EXP 65537
bool tls_rsa_encrypt(const void* msg, size_t msglen, uint8_t *out,
const uint8_t* pubkey, size_t keylen){
bool tls_rsa_encrypt(const uint8_t* inbuf, size_t in_len, uint8_t *outbuf,
const uint8_t* pubkey, size_t keylen, uint8_t hash_alg){
size_t spos = 0;
if((msg==NULL) ||
if((inbuf==NULL) ||
(pubkey==NULL) ||
(out==NULL) ||
(msglen==0) ||
(outbuf==NULL) ||
(in_len==0) ||
(keylen<RSA_MODULUS_MAX_SUPPORTED) ||
(keylen>RSA_MODULUS_MIN_SUPPORTED) ||
(!(pubkey[keylen-1]&1))) return false;

while(pubkey[spos]==0) {out[spos++] = 0;}
if(!tls_rsa_encode_oaep(msg, msglen, &out[spos], keylen-spos, NULL)) return false;
powmod_exp_u24((uint8_t)keylen, ct, RSA_PUBLIC_EXP, pubkey);
while(pubkey[spos]==0) {outbuf[spos++] = 0;}
if(!tls_rsa_encode_oaep(inbuf, in_len, &outbuf[spos], keylen-spos, NULL, hash_alg)) return false;
powmod_exp_u24((uint8_t)keylen, outbuf, RSA_PUBLIC_EXP, pubkey);
return true;
}


bool tls_mgf1(const uint8_t* data, size_t datalen, uint8_t* outbuf, size_t outlen, uint8_t hash_alg){
uint32_t ctr = 0;
uint8_t hash_digest[TLS_SHA256_DIGEST_LEN], ctr_data[4];
struct tls_hash_context hash_data, hash_ctr;
if(!tls_hash_context_init(&hash_data, hash_alg)) return false;
size_t hashlen = hash_data.digestlen;
hash_data.update(&hash_data._private, data, datalen);
for(size_t printlen=0; printlen<outlen; printlen+=hashlen, ctr++){
size_t copylen = (outlen-printlen > hashlen) ? hashlen : outlen-printlen;
//memcpy(ctr_data, &ctr, 4);
ctr_data[0] = (uint8_t) ((ctr >> 24) & 0xff);
ctr_data[1] = (uint8_t) ((ctr >> 16) & 0xff);
ctr_data[2] = (uint8_t) ((ctr >> 8) & 0xff);
ctr_data[3] = (uint8_t) ((ctr >> 0) & 0xff);
memcpy(&hash_ctr, &hash_data, sizeof hash_ctr);
hash_ctr.update(&hash_ctr._private, ctr_data, 4);
hash_ctr.digest(&hash_ctr._private, hash_digest);
memcpy(&outbuf[printlen], hash_digest, copylen);
}
return true;
}
Loading

0 comments on commit bc5ec67

Please sign in to comment.