Skip to content

Commit

Permalink
Modernize Key Derivation Functions
Browse files Browse the repository at this point in the history
Pass the parameters from the abstract (user-facing) interface as
std::span_s and use the latest convenience tools in the implementations
of the KDFs.
  • Loading branch information
reneme committed Dec 3, 2024
1 parent 09a7a98 commit cab79aa
Show file tree
Hide file tree
Showing 20 changed files with 403 additions and 620 deletions.
119 changes: 42 additions & 77 deletions src/lib/kdf/hkdf/hkdf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <botan/exceptn.h>
#include <botan/internal/fmt.h>
#include <botan/internal/loadstor.h>
#include <botan/internal/stl_util.h>

namespace Botan {

Expand All @@ -22,20 +23,16 @@ std::string HKDF::name() const {
return fmt("HKDF({})", m_prf->name());
}

void HKDF::kdf(uint8_t key[],
size_t key_len,
const uint8_t secret[],
size_t secret_len,
const uint8_t salt[],
size_t salt_len,
const uint8_t label[],
size_t label_len) const {
void HKDF::perform_kdf(std::span<uint8_t> key,
std::span<const uint8_t> secret,
std::span<const uint8_t> salt,
std::span<const uint8_t> label) const {
HKDF_Extract extract(m_prf->new_object());
HKDF_Expand expand(m_prf->new_object());
secure_vector<uint8_t> prk(m_prf->output_length());

extract.kdf(prk.data(), prk.size(), secret, secret_len, salt, salt_len, nullptr, 0);
expand.kdf(key, key_len, prk.data(), prk.size(), nullptr, 0, label, label_len);
extract.derive_key(prk, secret, salt, {});
expand.derive_key(key, prk, {}, label);
}

std::unique_ptr<KDF> HKDF_Extract::new_object() const {
Expand All @@ -46,42 +43,31 @@ std::string HKDF_Extract::name() const {
return fmt("HKDF-Extract({})", m_prf->name());
}

void HKDF_Extract::kdf(uint8_t key[],
size_t key_len,
const uint8_t secret[],
size_t secret_len,
const uint8_t salt[],
size_t salt_len,
const uint8_t /*label*/[],
size_t label_len) const {
if(key_len == 0) {
void HKDF_Extract::perform_kdf(std::span<uint8_t> key,
std::span<const uint8_t> secret,
std::span<const uint8_t> salt,
std::span<const uint8_t> label) const {
if(key.empty()) {
return;
}

const size_t prf_output_len = m_prf->output_length();
BOTAN_ARG_CHECK(key.size() <= prf_output_len, "HKDF-Extract maximum output length exceeeded");
BOTAN_ARG_CHECK(label.empty(), "HKDF-Extract does not support a label input");

if(key_len > prf_output_len) {
throw Invalid_Argument("HKDF-Extract maximum output length exceeeded");
}

if(label_len > 0) {
throw Invalid_Argument("HKDF-Extract does not support a label input");
}

if(salt_len == 0) {
if(salt.empty()) {
m_prf->set_key(std::vector<uint8_t>(prf_output_len));
} else {
m_prf->set_key(salt, salt_len);
m_prf->set_key(salt);
}

m_prf->update(secret, secret_len);
m_prf->update(secret);

if(key_len == prf_output_len) {
if(key.size() == prf_output_len) {
m_prf->final(key);
} else {
secure_vector<uint8_t> prk;
m_prf->final(prk);
copy_mem(&key[0], prk.data(), key_len);
const auto prk = m_prf->final();
copy_mem(key, std::span{prk}.first(key.size()));
}
}

Expand All @@ -93,75 +79,54 @@ std::string HKDF_Expand::name() const {
return fmt("HKDF-Expand({})", m_prf->name());
}

void HKDF_Expand::kdf(uint8_t key[],
size_t key_len,
const uint8_t secret[],
size_t secret_len,
const uint8_t salt[],
size_t salt_len,
const uint8_t label[],
size_t label_len) const {
if(key_len == 0) {
void HKDF_Expand::perform_kdf(std::span<uint8_t> key,
std::span<const uint8_t> secret,
std::span<const uint8_t> salt,
std::span<const uint8_t> label) const {
if(key.empty()) {
return;
}

if(key_len > m_prf->output_length() * 255) {
throw Invalid_Argument("HKDF-Expand maximum output length exceeeded");
}

m_prf->set_key(secret, secret_len);
BOTAN_ARG_CHECK(key.size() <= m_prf->output_length() * 255, "HKDF-Expand maximum output length exceeeded");

uint8_t counter = 1;
secure_vector<uint8_t> h;
size_t offset = 0;

while(offset != key_len) {
BufferStuffer k(key);
m_prf->set_key(secret);
for(uint8_t counter = 1; !k.full(); ++counter) {
m_prf->update(h);
m_prf->update(label, label_len);
m_prf->update(salt, salt_len);
m_prf->update(counter++);
m_prf->update(label);
m_prf->update(salt);
m_prf->update(counter);
m_prf->final(h);

const size_t written = std::min(h.size(), key_len - offset);
copy_mem(&key[offset], h.data(), written);
offset += written;
const auto bytes_to_write = std::min(h.size(), k.remaining_capacity());
k.append(std::span{h}.first(bytes_to_write));
}
}

secure_vector<uint8_t> hkdf_expand_label(std::string_view hash_fn,
const uint8_t secret[],
size_t secret_len,
std::span<const uint8_t> secret,
std::string_view label,
const uint8_t hash_val[],
size_t hash_val_len,
std::span<const uint8_t> hash_val,
size_t length) {
BOTAN_ARG_CHECK(length <= 0xFFFF, "HKDF-Expand-Label requested output too large");
BOTAN_ARG_CHECK(label.size() <= 0xFF, "HKDF-Expand-Label label too long");
BOTAN_ARG_CHECK(hash_val_len <= 0xFF, "HKDF-Expand-Label hash too long");

const uint16_t length16 = static_cast<uint16_t>(length);
BOTAN_ARG_CHECK(hash_val.size() <= 0xFF, "HKDF-Expand-Label hash too long");

HKDF_Expand hkdf(MessageAuthenticationCode::create_or_throw(fmt("HMAC({})", hash_fn)));

secure_vector<uint8_t> output(length16);
std::vector<uint8_t> prefix(3 + label.size() + 1);

prefix[0] = get_byte<0>(length16);
prefix[1] = get_byte<1>(length16);
prefix[2] = static_cast<uint8_t>(label.size());

copy_mem(prefix.data() + 3, cast_char_ptr_to_uint8(label.data()), label.size());

prefix[3 + label.size()] = static_cast<uint8_t>(hash_val_len);
const auto prefix = concat<std::vector<uint8_t>>(store_be(static_cast<uint16_t>(length)),
store_be(static_cast<uint8_t>(label.size())),
std::span{cast_char_ptr_to_uint8(label.data()), label.size()},
store_be(static_cast<uint8_t>(hash_val.size())));

/*
* We do something a little dirty here to avoid copying the hash_val,
* making use of the fact that Botan's KDF interface supports label+salt,
* and knowing that our HKDF hashes first param label then param salt.
*/
hkdf.kdf(output.data(), output.size(), secret, secret_len, hash_val, hash_val_len, prefix.data(), prefix.size());

return output;
return hkdf.derive_key(length, secret, hash_val, prefix);
}

} // namespace Botan
47 changes: 17 additions & 30 deletions src/lib/kdf/hkdf/hkdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,11 @@ class HKDF final : public KDF {

std::string name() const override;

void kdf(uint8_t key[],
size_t key_len,
const uint8_t secret[],
size_t secret_len,
const uint8_t salt[],
size_t salt_len,
const uint8_t label[],
size_t label_len) const override;
private:
void perform_kdf(std::span<uint8_t> key,
std::span<const uint8_t> secret,
std::span<const uint8_t> salt,
std::span<const uint8_t> label) const override;

private:
std::unique_ptr<MessageAuthenticationCode> m_prf;
Expand All @@ -55,14 +52,11 @@ class HKDF_Extract final : public KDF {

std::string name() const override;

void kdf(uint8_t key[],
size_t key_len,
const uint8_t secret[],
size_t secret_len,
const uint8_t salt[],
size_t salt_len,
const uint8_t label[],
size_t label_len) const override;
private:
void perform_kdf(std::span<uint8_t> key,
std::span<const uint8_t> secret,
std::span<const uint8_t> salt,
std::span<const uint8_t> label) const override;

private:
std::unique_ptr<MessageAuthenticationCode> m_prf;
Expand All @@ -82,14 +76,11 @@ class HKDF_Expand final : public KDF {

std::string name() const override;

void kdf(uint8_t key[],
size_t key_len,
const uint8_t secret[],
size_t secret_len,
const uint8_t salt[],
size_t salt_len,
const uint8_t label[],
size_t label_len) const override;
private:
void perform_kdf(std::span<uint8_t> key,
std::span<const uint8_t> secret,
std::span<const uint8_t> salt,
std::span<const uint8_t> label) const override;

private:
std::unique_ptr<MessageAuthenticationCode> m_prf;
Expand All @@ -99,19 +90,15 @@ class HKDF_Expand final : public KDF {
* HKDF-Expand-Label from TLS 1.3/QUIC
* @param hash_fn the hash to use
* @param secret the secret bits
* @param secret_len the length of secret
* @param label the full label (no "TLS 1.3, " or "tls13 " prefix
* is applied)
* @param hash_val the previous hash value (used for chaining, may be empty)
* @param hash_val_len the length of hash_val
* @param length the desired output length
*/
secure_vector<uint8_t> BOTAN_TEST_API hkdf_expand_label(std::string_view hash_fn,
const uint8_t secret[],
size_t secret_len,
std::span<const uint8_t> secret,
std::string_view label,
const uint8_t hash_val[],
size_t hash_val_len,
std::span<const uint8_t> hash_val,
size_t length);

} // namespace Botan
Expand Down
68 changes: 39 additions & 29 deletions src/lib/kdf/kdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,16 @@ class BOTAN_PUBLIC_API(2, 0) KDF {
* @param label purpose for the derived keying material
* @param label_len size of label in bytes
*/
virtual void kdf(uint8_t key[],
size_t key_len,
const uint8_t secret[],
size_t secret_len,
const uint8_t salt[],
size_t salt_len,
const uint8_t label[],
size_t label_len) const = 0;
void kdf(uint8_t key[],
size_t key_len,
const uint8_t secret[],
size_t secret_len,
const uint8_t salt[],
size_t salt_len,
const uint8_t label[],
size_t label_len) const {
perform_kdf({key, key_len}, {secret, secret_len}, {salt, salt_len}, {label, label_len});
}

/**
* Derive a key
Expand All @@ -90,9 +92,7 @@ class BOTAN_PUBLIC_API(2, 0) KDF {
size_t salt_len,
const uint8_t label[] = nullptr,
size_t label_len = 0) const {
T key(key_len);
kdf(key.data(), key.size(), secret, secret_len, salt, salt_len, label, label_len);
return key;
return derive_key<T>(key_len, {secret, secret_len}, {salt, salt_len}, {label, label_len});
}

/**
Expand All @@ -109,12 +109,9 @@ class BOTAN_PUBLIC_API(2, 0) KDF {
std::string_view salt = "",
std::string_view label = "") const {
return derive_key<T>(key_len,
secret.data(),
secret.size(),
cast_char_ptr_to_uint8(salt.data()),
salt.length(),
cast_char_ptr_to_uint8(label.data()),
label.length());
secret,
{cast_char_ptr_to_uint8(salt.data()), salt.length()},
{cast_char_ptr_to_uint8(label.data()), label.length()});
}

/**
Expand All @@ -128,8 +125,7 @@ class BOTAN_PUBLIC_API(2, 0) KDF {
std::span<const uint8_t> secret,
std::span<const uint8_t> salt,
std::span<const uint8_t> label) const {
return kdf(
key.data(), key.size(), secret.data(), secret.size(), salt.data(), salt.size(), label.data(), label.size());
return perform_kdf(key, secret, salt, label);
}

/**
Expand All @@ -145,8 +141,9 @@ class BOTAN_PUBLIC_API(2, 0) KDF {
std::span<const uint8_t> secret,
std::span<const uint8_t> salt,
std::span<const uint8_t> label) const {
return derive_key<T>(
key_len, secret.data(), secret.size(), salt.data(), salt.size(), label.data(), label.size());
T key(key_len);
perform_kdf(key, secret, salt, label);
return key;
}

/**
Expand All @@ -164,8 +161,7 @@ class BOTAN_PUBLIC_API(2, 0) KDF {
const uint8_t salt[],
size_t salt_len,
std::string_view label = "") const {
return derive_key<T>(
key_len, secret.data(), secret.size(), salt, salt_len, cast_char_ptr_to_uint8(label.data()), label.size());
return derive_key<T>(key_len, secret, {salt, salt_len}, {cast_char_ptr_to_uint8(label.data()), label.size()});
}

/**
Expand All @@ -184,12 +180,9 @@ class BOTAN_PUBLIC_API(2, 0) KDF {
std::string_view salt = "",
std::string_view label = "") const {
return derive_key<T>(key_len,
secret,
secret_len,
cast_char_ptr_to_uint8(salt.data()),
salt.length(),
cast_char_ptr_to_uint8(label.data()),
label.length());
{secret, secret_len},
{cast_char_ptr_to_uint8(salt.data()), salt.length()},
{cast_char_ptr_to_uint8(label.data()), label.length()});
}

/**
Expand All @@ -201,6 +194,23 @@ class BOTAN_PUBLIC_API(2, 0) KDF {
* @return new object representing the same algorithm as *this
*/
KDF* clone() const { return this->new_object().release(); }

protected:
/**
* Internal customization point for subclasses
*
* The byte size of the @p key span is the number of bytes to be produced
* by the concrete key derivation function.
*
* @param key the output buffer for the to-be-derived key
* @param secret the secret input
* @param salt a diversifier
* @param label purpose for the derived keying material
*/
virtual void perform_kdf(std::span<uint8_t> key,
std::span<const uint8_t> secret,
std::span<const uint8_t> salt,
std::span<const uint8_t> label) const = 0;
};

/**
Expand Down
Loading

0 comments on commit cab79aa

Please sign in to comment.