From 18b6a0e17199069575dce2c30cdb553450275593 Mon Sep 17 00:00:00 2001 From: Sean DuBois Date: Thu, 21 Dec 2023 10:13:48 -0500 Subject: [PATCH] Support all Certificate Fingerprint Algorithms Before was hardcoded to sha256 Resolves #1076 --- include/rtc/description.hpp | 54 ++++++++++++++++++++- src/description.cpp | 71 ++++++++++++++++++++------- src/impl/certificate.cpp | 87 +++++++++++++++++++++++++++------- src/impl/certificate.hpp | 9 ++-- src/impl/dtlssrtptransport.cpp | 3 +- src/impl/dtlssrtptransport.hpp | 5 +- src/impl/dtlstransport.cpp | 27 ++++++----- src/impl/dtlstransport.hpp | 2 + src/impl/peerconnection.cpp | 14 ++++-- test/connectivity.cpp | 5 +- 10 files changed, 216 insertions(+), 61 deletions(-) diff --git a/include/rtc/description.hpp b/include/rtc/description.hpp index b06014fd9..5b0332044 100644 --- a/include/rtc/description.hpp +++ b/include/rtc/description.hpp @@ -33,6 +33,53 @@ class RTC_CPP_EXPORT Description { enum class Type { Unspec, Offer, Answer, Pranswer, Rollback }; enum class Role { ActPass, Passive, Active }; + enum class FingerprintAlgorithm { + MD5, // RFC3279 + SHA1, // RFC3279 + SHA224, // RFC4055 + SHA256, // RFC4055 + SHA384, // RFC4055 + SHA512, // RFC4055 + }; + + static size_t fingerprintAlgorithmSize(FingerprintAlgorithm fingerprintAlgorithm) { + switch (fingerprintAlgorithm) { + case FingerprintAlgorithm::MD5: + return 16; + case FingerprintAlgorithm::SHA1: + return 20; + case FingerprintAlgorithm::SHA224: + return 28; + case FingerprintAlgorithm::SHA256: + return 32; + case FingerprintAlgorithm::SHA384: + return 48; + case FingerprintAlgorithm::SHA512: + return 64; + } + + return 0; + }; + + static std::string fingerprintAlgorithmIdentifier(FingerprintAlgorithm fingerprintAlgorithm) { + switch (fingerprintAlgorithm) { + case FingerprintAlgorithm::MD5: + return "md5"; + case FingerprintAlgorithm::SHA1: + return "sha-1"; + case FingerprintAlgorithm::SHA224: + return "sha-224"; + case FingerprintAlgorithm::SHA256: + return "sha-256"; + case FingerprintAlgorithm::SHA384: + return "sha-256"; + case FingerprintAlgorithm::SHA512: + return "sha-512"; + } + + return ""; + } + enum class Direction { SendOnly = RTC_DIRECTION_SENDONLY, RecvOnly = RTC_DIRECTION_RECVONLY, @@ -52,10 +99,11 @@ class RTC_CPP_EXPORT Description { optional iceUfrag() const; optional icePwd() const; optional fingerprint() const; + optional fingerprintAlgorithm() const; bool ended() const; void hintType(Type type); - void setFingerprint(string fingerprint); + void setFingerprint(string fingerprint, Description::FingerprintAlgorithm fingerprintAlgorithm); void addIceOption(string option); void removeIceOption(const string &option); @@ -292,6 +340,7 @@ class RTC_CPP_EXPORT Description { std::vector mIceOptions; optional mIceUfrag, mIcePwd; optional mFingerprint; + optional mFingerprintAlgorithm; std::vector mAttributes; // other attributes // Entries @@ -308,6 +357,7 @@ class RTC_CPP_EXPORT Description { RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, const rtc::Description &description); RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, rtc::Description::Type type); RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, rtc::Description::Role role); -RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, const rtc::Description::Direction &direction); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, + const rtc::Description::Direction &direction); #endif diff --git a/src/description.cpp b/src/description.cpp index f7910e390..d42346730 100644 --- a/src/description.cpp +++ b/src/description.cpp @@ -33,11 +33,6 @@ inline bool match_prefix(string_view str, string_view prefix) { std::mismatch(prefix.begin(), prefix.end(), str.begin()).first == prefix.end(); } -inline void trim_begin(string &str) { - str.erase(str.begin(), - std::find_if(str.begin(), str.end(), [](char c) { return !std::isspace(c); })); -} - inline void trim_end(string &str) { str.erase( std::find_if(str.rbegin(), str.rend(), [](char c) { return !std::isspace(c); }).base(), @@ -64,9 +59,13 @@ template T to_integer(string_view s) { } } -inline bool is_sha256_fingerprint(string_view f) { - if (f.size() != 32 * 3 - 1) +inline bool is_valid_fingerprint(string_view f, size_t expectedSize) { + auto valuesOnly = std::string(f); + valuesOnly.erase(std::remove(valuesOnly.begin(), valuesOnly.end(), ':'), valuesOnly.end()); + + if (valuesOnly.size() != expectedSize * 2) { return false; + } for (size_t i = 0; i < f.size(); ++i) { if (i % 3 == 2) { @@ -124,13 +123,41 @@ Description::Description(const string &sdp, Type type, Role role) // media-level SDP attribute. If it is a session-level attribute, it applies to all // TLS sessions for which no media-level fingerprint attribute is defined. if (!mFingerprint || index == 0) { // first media overrides session-level - if (match_prefix(value, "sha-256 ") || match_prefix(value, "SHA-256 ")) { - string fingerprint{value.substr(8)}; - trim_begin(fingerprint); - setFingerprint(std::move(fingerprint)); - } else { + auto fingerprintExploded = utils::explode(string(value), ' '); + if (fingerprintExploded.size() != 2) { PLOG_WARNING << "Unknown SDP fingerprint format: " << value; + continue; } + + auto first = fingerprintExploded.at(0); + std::transform(first.begin(), first.end(), first.begin(), + [](char c) { return char(std::tolower(c)); }); + + Description::FingerprintAlgorithm fingerprintAlgorithm; + if (first == + fingerprintAlgorithmIdentifier(Description::FingerprintAlgorithm::MD5)) { + fingerprintAlgorithm = Description::FingerprintAlgorithm::MD5; + } else if (first == fingerprintAlgorithmIdentifier( + Description::FingerprintAlgorithm::SHA1)) { + fingerprintAlgorithm = Description::FingerprintAlgorithm::SHA1; + } else if (first == fingerprintAlgorithmIdentifier( + Description::FingerprintAlgorithm::SHA224)) { + fingerprintAlgorithm = Description::FingerprintAlgorithm::SHA224; + } else if (first == fingerprintAlgorithmIdentifier( + Description::FingerprintAlgorithm::SHA256)) { + fingerprintAlgorithm = Description::FingerprintAlgorithm::SHA256; + } else if (first == fingerprintAlgorithmIdentifier( + Description::FingerprintAlgorithm::SHA384)) { + fingerprintAlgorithm = Description::FingerprintAlgorithm::SHA384; + } else if (first == fingerprintAlgorithmIdentifier( + Description::FingerprintAlgorithm::SHA512)) { + fingerprintAlgorithm = Description::FingerprintAlgorithm::SHA512; + } else { + PLOG_WARNING << "Unknown certificate fingerprint algorithm: " << first; + continue; + } + + setFingerprint(std::move(fingerprintExploded.at(1)), fingerprintAlgorithm); } } else if (key == "ice-ufrag") { // RFC 8839: The "ice-pwd" and "ice-ufrag" attributes can appear at either the @@ -200,6 +227,10 @@ optional Description::icePwd() const { return mIcePwd; } optional Description::fingerprint() const { return mFingerprint; } +optional Description::fingerprintAlgorithm() const { + return mFingerprintAlgorithm; +} + bool Description::ended() const { return mEnded; } void Description::hintType(Type type) { @@ -207,13 +238,15 @@ void Description::hintType(Type type) { mType = type; } -void Description::setFingerprint(string fingerprint) { - if (!is_sha256_fingerprint(fingerprint)) - throw std::invalid_argument("Invalid SHA256 fingerprint \"" + fingerprint + "\""); +void Description::setFingerprint(string fingerprint, FingerprintAlgorithm fingerprintAlgorithm) { + if (!is_valid_fingerprint(fingerprint, + Description::fingerprintAlgorithmSize(fingerprintAlgorithm))) + throw std::invalid_argument("Invalid certificate fingerprint \"" + fingerprint + "\""); std::transform(fingerprint.begin(), fingerprint.end(), fingerprint.begin(), [](char c) { return char(std::toupper(c)); }); mFingerprint.emplace(std::move(fingerprint)); + mFingerprintAlgorithm = fingerprintAlgorithm; } void Description::addIceOption(string option) { @@ -308,7 +341,9 @@ string Description::generateSdp(string_view eol) const { if (!mIceOptions.empty()) sdp << "a=ice-options:" << utils::implode(mIceOptions, ',') << eol; if (mFingerprint) - sdp << "a=fingerprint:sha-256 " << *mFingerprint << eol; + sdp << "a=fingerprint:" + << Description::fingerprintAlgorithmIdentifier(mFingerprintAlgorithm.value()) << " " + << *mFingerprint << eol; for (const auto &attr : mAttributes) sdp << "a=" << attr << eol; @@ -371,7 +406,9 @@ string Description::generateApplicationSdp(string_view eol) const { if (!mIceOptions.empty()) sdp << "a=ice-options:" << utils::implode(mIceOptions, ',') << eol; if (mFingerprint) - sdp << "a=fingerprint:sha-256 " << *mFingerprint << eol; + sdp << "a=fingerprint:" + << Description::fingerprintAlgorithmIdentifier(mFingerprintAlgorithm.value()) << " " + << *mFingerprint << eol; for (const auto &attr : mAttributes) sdp << "a=" << attr << eol; diff --git a/src/impl/certificate.cpp b/src/impl/certificate.cpp index 7bc2848e3..14109adfa 100644 --- a/src/impl/certificate.cpp +++ b/src/impl/certificate.cpp @@ -100,18 +100,20 @@ Certificate Certificate::Generate(CertificateType type, const string &commonName Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey) : mCredentials(gnutls::new_credentials(), gnutls::free_credentials), - mFingerprint(make_fingerprint(crt)) { + mFingerprint(make_fingerprint(crt, Description::FingerprintAlgorithm::SHA256)) { gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey), "Unable to set certificate and key pair in credentials"); } Certificate::Certificate(shared_ptr creds) - : mCredentials(std::move(creds)), mFingerprint(make_fingerprint(*mCredentials)) {} + : mCredentials(std::move(creds)), + mFingerprint(make_fingerprint(*mCredentials, Description::FingerprintAlgorithm::SHA256)) {} gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; } -string make_fingerprint(gnutls_certificate_credentials_t credentials) { +string make_fingerprint(gnutls_certificate_credentials_t credentials, + Description::FingerprintAlgorithm fingerprintAlgorithm) { auto new_crt_list = [credentials]() -> gnutls_x509_crt_t * { gnutls_x509_crt_t *crt_list = nullptr; unsigned int crt_list_size = 0; @@ -127,14 +129,28 @@ string make_fingerprint(gnutls_certificate_credentials_t credentials) { unique_ptr crt_list(new_crt_list(), free_crt_list); - return make_fingerprint(*crt_list); + return make_fingerprint(*crt_list, fingerprintAlgorithm); } -string make_fingerprint(gnutls_x509_crt_t crt) { - const size_t size = 32; +string make_fingerprint(gnutls_x509_crt_t crt, + Description::FingerprintAlgorithm fingerprintAlgorithm) { + const size_t size = Description::fingerprintAlgorithmSize(fingerprintAlgorithm); unsigned char buffer[size]; size_t len = size; - gnutls::check(gnutls_x509_crt_get_fingerprint(crt, GNUTLS_DIG_SHA256, buffer, &len), + auto hashFunc = GNUTLS_DIG_SHA256; + if (fingerprintAlgorithm == Description::FingerprintAlgorithm::MD5) { + hashFunc = GNUTLS_DIG_MD5; + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA1) { + hashFunc = GNUTLS_DIG_SHA1; + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA224) { + hashFunc = GNUTLS_DIG_SHA224; + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA384) { + hashFunc = GNUTLS_DIG_SHA384; + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA512) { + hashFunc = GNUTLS_DIG_SHA512; + } + + gnutls::check(gnutls_x509_crt_get_fingerprint(crt, hashFunc, buffer, &len), "X509 fingerprint error"); std::ostringstream oss; @@ -148,14 +164,33 @@ string make_fingerprint(gnutls_x509_crt_t crt) { } #elif USE_MBEDTLS -string make_fingerprint(mbedtls_x509_crt *crt) { - const int size = 32; +string make_fingerprint(mbedtls_x509_crt *crt, + Description::FingerprintAlgorithm fingerprintAlgorithm) { + const int size = Description::fingerprintAlgorithmSize(fingerprintAlgorithm); uint8_t buffer[size]; std::stringstream fingerprint; - mbedtls::check( - mbedtls_sha256(crt->raw.p, crt->raw.len, reinterpret_cast(buffer), 0), - "Failed to generate certificate fingerprint"); + if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA1) { + mbedtls::check( + mbedtls_sha1(crt->raw.p, crt->raw.len, reinterpret_cast(buffer)), + "Failed to generate certificate fingerprint"); + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA224) { + mbedtls::check( + mbedtls_sha256(crt->raw.p, crt->raw.len, reinterpret_cast(buffer), 1), + "Failed to generate certificate fingerprint"); + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA256) { + mbedtls::check( + mbedtls_sha256(crt->raw.p, crt->raw.len, reinterpret_cast(buffer), 0), + "Failed to generate certificate fingerprint"); + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA384) { + mbedtls::check( + mbedtls_sha512(crt->raw.p, crt->raw.len, reinterpret_cast(buffer), 1), + "Failed to generate certificate fingerprint"); + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA512) { + mbedtls::check( + mbedtls_sha512(crt->raw.p, crt->raw.len, reinterpret_cast(buffer), 0), + "Failed to generate certificate fingerprint"); + } for (auto i = 0; i < size; i++) { fingerprint << std::setfill('0') << std::setw(2) << std::hex << static_cast(buffer[i]); @@ -168,7 +203,8 @@ string make_fingerprint(mbedtls_x509_crt *crt) { } Certificate::Certificate(shared_ptr crt, shared_ptr pk) - : mCrt(crt), mPk(pk), mFingerprint(make_fingerprint(crt.get())) {} + : mCrt(crt), mPk(pk), + mFingerprint(make_fingerprint(crt.get(), Description::FingerprintAlgorithm::SHA256)) {} Certificate Certificate::FromString(string crt_pem, string key_pem) { PLOG_DEBUG << "Importing certificate from PEM string (MbedTLS)"; @@ -465,17 +501,32 @@ Certificate Certificate::Generate(CertificateType type, const string &commonName } Certificate::Certificate(shared_ptr x509, shared_ptr pkey) - : mX509(std::move(x509)), mPKey(std::move(pkey)), mFingerprint(make_fingerprint(mX509.get())) {} + : mX509(std::move(x509)), mPKey(std::move(pkey)), + mFingerprint(make_fingerprint(mX509.get(), Description::FingerprintAlgorithm::SHA256)) {} std::tuple Certificate::credentials() const { return {mX509.get(), mPKey.get()}; } -string make_fingerprint(X509 *x509) { - const size_t size = 32; +string make_fingerprint(X509 *x509, Description::FingerprintAlgorithm fingerprintAlgorithm) { + const size_t size = Description::fingerprintAlgorithmSize(fingerprintAlgorithm); unsigned char buffer[size]; - unsigned int len = size; - if (!X509_digest(x509, EVP_sha256(), buffer, &len)) + auto len = static_cast(size); + + auto hashFunc = EVP_sha256(); + if (fingerprintAlgorithm == Description::FingerprintAlgorithm::MD5) { + hashFunc = EVP_md5(); + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA1) { + hashFunc = EVP_sha1(); + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA224) { + hashFunc = EVP_sha224(); + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA384) { + hashFunc = EVP_sha384(); + } else if (fingerprintAlgorithm == Description::FingerprintAlgorithm::SHA512) { + hashFunc = EVP_sha512(); + } + + if (!X509_digest(x509, hashFunc, buffer, &len)) throw std::runtime_error("X509 fingerprint error"); std::ostringstream oss; diff --git a/src/impl/certificate.hpp b/src/impl/certificate.hpp index 800bcd2d9..184f8b839 100644 --- a/src/impl/certificate.hpp +++ b/src/impl/certificate.hpp @@ -9,6 +9,7 @@ #ifndef RTC_IMPL_CERTIFICATE_H #define RTC_IMPL_CERTIFICATE_H +#include "description.hpp" #include "common.hpp" #include "configuration.hpp" // for CertificateType #include "init.hpp" @@ -57,12 +58,12 @@ class Certificate { }; #if USE_GNUTLS -string make_fingerprint(gnutls_certificate_credentials_t credentials); -string make_fingerprint(gnutls_x509_crt_t crt); +string make_fingerprint(gnutls_certificate_credentials_t credentials, Description::FingerprintAlgorithm fingerprintAlgorithm); +string make_fingerprint(gnutls_x509_crt_t crt, Description::FingerprintAlgorithm fingerprintAlgorithm); #elif USE_MBEDTLS -string make_fingerprint(mbedtls_x509_crt *crt); +string make_fingerprint(mbedtls_x509_crt *crt, Description::FingerprintAlgorithm fingerprintAlgorithm); #else -string make_fingerprint(X509 *x509); +string make_fingerprint(X509 *x509, Description::FingerprintAlgorithm fingerprintAlgorithm); #endif using certificate_ptr = shared_ptr; diff --git a/src/impl/dtlssrtptransport.cpp b/src/impl/dtlssrtptransport.cpp index 8566338ab..536c33f74 100644 --- a/src/impl/dtlssrtptransport.cpp +++ b/src/impl/dtlssrtptransport.cpp @@ -58,10 +58,11 @@ bool DtlsSrtpTransport::IsGcmSupported() { DtlsSrtpTransport::DtlsSrtpTransport(shared_ptr lower, shared_ptr certificate, optional mtu, + Description::FingerprintAlgorithm fingerprintAlgorithm, verifier_callback verifierCallback, message_callback srtpRecvCallback, state_callback stateChangeCallback) - : DtlsTransport(lower, certificate, mtu, std::move(verifierCallback), + : DtlsTransport(lower, certificate, mtu, fingerprintAlgorithm, std::move(verifierCallback), std::move(stateChangeCallback)), mSrtpRecvCallback(std::move(srtpRecvCallback)) { // distinct from Transport recv callback diff --git a/src/impl/dtlssrtptransport.hpp b/src/impl/dtlssrtptransport.hpp index cd6276b9a..896357be5 100644 --- a/src/impl/dtlssrtptransport.hpp +++ b/src/impl/dtlssrtptransport.hpp @@ -31,8 +31,9 @@ class DtlsSrtpTransport final : public DtlsTransport { static bool IsGcmSupported(); DtlsSrtpTransport(shared_ptr lower, certificate_ptr certificate, - optional mtu, verifier_callback verifierCallback, - message_callback srtpRecvCallback, state_callback stateChangeCallback); + optional mtu, Description::FingerprintAlgorithm fingerprintAlgorithm, + verifier_callback verifierCallback, message_callback srtpRecvCallback, + state_callback stateChangeCallback); ~DtlsSrtpTransport(); bool sendMedia(message_ptr message); diff --git a/src/impl/dtlstransport.cpp b/src/impl/dtlstransport.cpp index aa5182411..05f356070 100644 --- a/src/impl/dtlstransport.cpp +++ b/src/impl/dtlstransport.cpp @@ -48,10 +48,11 @@ void DtlsTransport::Init() { void DtlsTransport::Cleanup() { gnutls_global_deinit(); } DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, - optional mtu, verifier_callback verifierCallback, - state_callback stateChangeCallback) + optional mtu, + Description::FingerprintAlgorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback) : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), - mVerifierCallback(std::move(verifierCallback)), + mFingerprintAlgorithm(fingerprintAlgorithm), mVerifierCallback(std::move(verifierCallback)), mIsClient(lower->role() == Description::Role::Active) { PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)"; @@ -295,7 +296,7 @@ int DtlsTransport::CertificateCallback(gnutls_session_t session) { return GNUTLS_E_CERTIFICATE_ERROR; } - string fingerprint = make_fingerprint(crt); + string fingerprint = make_fingerprint(crt, t->mFingerprintAlgorithm); gnutls_x509_crt_deinit(crt); bool success = t->mVerifierCallback(fingerprint); @@ -374,10 +375,11 @@ const mbedtls_ssl_srtp_profile srtpSupportedProtectionProfiles[] = { }; DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, - optional mtu, verifier_callback verifierCallback, - state_callback stateChangeCallback) + optional mtu, + Description::FingerprintAlgorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback) : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), - mVerifierCallback(std::move(verifierCallback)), + mFingerprintAlgorithm(fingerprintAlgorithm), mVerifierCallback(std::move(verifierCallback)), mIsClient(lower->role() == Description::Role::Active) { PLOG_DEBUG << "Initializing DTLS transport (MbedTLS)"; @@ -609,7 +611,7 @@ void DtlsTransport::doRecv() { int DtlsTransport::CertificateCallback(void *ctx, mbedtls_x509_crt *crt, int /*depth*/, uint32_t * /*flags*/) { auto this_ = static_cast(ctx); - string fingerprint = make_fingerprint(crt); + string fingerprint = make_fingerprint(crt, this_->mFingerprintAlgorithm); std::transform(fingerprint.begin(), fingerprint.end(), fingerprint.begin(), [](char c) { return char(std::toupper(c)); }); return this_->mVerifierCallback(fingerprint) ? 0 : 1; @@ -725,10 +727,11 @@ void DtlsTransport::Cleanup() { } DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, - optional mtu, verifier_callback verifierCallback, - state_callback stateChangeCallback) + optional mtu, + Description::FingerprintAlgorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback) : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), - mVerifierCallback(std::move(verifierCallback)), + mFingerprintAlgorithm(fingerprintAlgorithm), mVerifierCallback(std::move(verifierCallback)), mIsClient(lower->role() == Description::Role::Active) { PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)"; @@ -1034,7 +1037,7 @@ int DtlsTransport::CertificateCallback(int /*preverify_ok*/, X509_STORE_CTX *ctx static_cast(SSL_get_ex_data(ssl, DtlsTransport::TransportExIndex)); X509 *crt = X509_STORE_CTX_get_current_cert(ctx); - string fingerprint = make_fingerprint(crt); + string fingerprint = make_fingerprint(crt, t->mFingerprintAlgorithm); return t->mVerifierCallback(fingerprint) ? 1 : 0; } diff --git a/src/impl/dtlstransport.hpp b/src/impl/dtlstransport.hpp index d84b312e8..e3a30c506 100644 --- a/src/impl/dtlstransport.hpp +++ b/src/impl/dtlstransport.hpp @@ -32,6 +32,7 @@ class DtlsTransport : public Transport, public std::enable_shared_from_this; DtlsTransport(shared_ptr lower, certificate_ptr certificate, optional mtu, + Description::FingerprintAlgorithm fingerprintAlgorithm, verifier_callback verifierCallback, state_callback stateChangeCallback); ~DtlsTransport(); @@ -52,6 +53,7 @@ class DtlsTransport : public Transport, public std::enable_shared_from_this mMtu; const certificate_ptr mCertificate; + Description::FingerprintAlgorithm mFingerprintAlgorithm; const verifier_callback mVerifierCallback; const bool mIsClient; diff --git a/src/impl/peerconnection.cpp b/src/impl/peerconnection.cpp index 6c1f648d8..79e8b86fc 100644 --- a/src/impl/peerconnection.cpp +++ b/src/impl/peerconnection.cpp @@ -211,6 +211,12 @@ shared_ptr PeerConnection::initDtlsTransport() { PLOG_VERBOSE << "Starting DTLS transport"; + auto fingerprintAlgorithm = Description::FingerprintAlgorithm::SHA256; + if (auto remote = remoteDescription(); + remote && remote->fingerprintAlgorithm().has_value()) { + fingerprintAlgorithm = remote->fingerprintAlgorithm().value(); + } + auto lower = std::atomic_load(&mIceTransport); if (!lower) throw std::logic_error("No underlying ICE transport for DTLS transport"); @@ -254,7 +260,7 @@ shared_ptr PeerConnection::initDtlsTransport() { // DTLS-SRTP transport = std::make_shared( - lower, certificate, config.mtu, verifierCallback, + lower, certificate, config.mtu, fingerprintAlgorithm, verifierCallback, weak_bind(&PeerConnection::forwardMedia, this, _1), dtlsStateChangeCallback); #else PLOG_WARNING << "Ignoring media support (not compiled with media support)"; @@ -264,7 +270,8 @@ shared_ptr PeerConnection::initDtlsTransport() { if (!transport) { // DTLS only transport = std::make_shared(lower, certificate, config.mtu, - verifierCallback, dtlsStateChangeCallback); + fingerprintAlgorithm, verifierCallback, + dtlsStateChangeCallback); } return emplaceTransport(this, &mDtlsTransport, std::move(transport)); @@ -994,7 +1001,8 @@ void PeerConnection::processLocalDescription(Description description) { } // Set local fingerprint (wait for certificate if necessary) - description.setFingerprint(mCertificate.get()->fingerprint()); + description.setFingerprint(mCertificate.get()->fingerprint(), + Description::FingerprintAlgorithm::SHA256); PLOG_VERBOSE << "Issuing local description: " << description; diff --git a/test/connectivity.cpp b/test/connectivity.cpp index ddafcae68..36c2fa05b 100644 --- a/test/connectivity.cpp +++ b/test/connectivity.cpp @@ -51,11 +51,12 @@ void test_connectivity(bool signal_wrong_fingerprint) { cout << "Description 1: " << sdp << endl; if (signal_wrong_fingerprint) { auto f = sdp.fingerprint(); - if (f.has_value()) { + auto a = sdp.fingerprintAlgorithm(); + if (f.has_value() && a.has_value()) { auto s = f.value(); auto& c = s[0]; if (c == 'F' || c == 'f') c = '0'; else c++; - sdp.setFingerprint(s); + sdp.setFingerprint(s, a.value()); } } pc2.setRemoteDescription(string(sdp));