From 18161658c7b1f003c86910cb657be50dc1254e14 Mon Sep 17 00:00:00 2001 From: Sean DuBois Date: Thu, 21 Dec 2023 10:13:48 -0500 Subject: [PATCH] Support all Certificate Fingerprint Algorithms Before was hardcoded to sha256 Resolves #1076 --- include/rtc/description.hpp | 18 ++++-- src/description.cpp | 102 +++++++++++++++++++++++------- src/impl/certificate.cpp | 112 ++++++++++++++++++++++++++------- src/impl/certificate.hpp | 9 +-- src/impl/dtlssrtptransport.cpp | 3 +- src/impl/dtlssrtptransport.hpp | 5 +- src/impl/dtlstransport.cpp | 27 ++++---- src/impl/dtlstransport.hpp | 2 + src/impl/peerconnection.cpp | 25 +++++--- test/connectivity.cpp | 4 +- 10 files changed, 230 insertions(+), 77 deletions(-) diff --git a/include/rtc/description.hpp b/include/rtc/description.hpp index b06014fd9..b1afd3e3e 100644 --- a/include/rtc/description.hpp +++ b/include/rtc/description.hpp @@ -28,6 +28,15 @@ const string DEFAULT_OPUS_AUDIO_PROFILE = const string DEFAULT_H264_VIDEO_PROFILE = "profile-level-id=42e01f;packetization-mode=1;level-asymmetry-allowed=1"; +struct CertificateFingerprint { + enum Algorithm { Sha1, Sha224, Sha256, Sha384, Sha512 }; + static string AlgorithmIdentifier(Algorithm algorithm); + static size_t AlgorithmSize(Algorithm algorithm); + + Algorithm algorithm; + string value; +}; + class RTC_CPP_EXPORT Description { public: enum class Type { Unspec, Offer, Answer, Pranswer, Rollback }; @@ -51,11 +60,11 @@ class RTC_CPP_EXPORT Description { std::vector iceOptions() const; optional iceUfrag() const; optional icePwd() const; - optional fingerprint() const; + optional fingerprint() const; bool ended() const; void hintType(Type type); - void setFingerprint(string fingerprint); + void setFingerprint(string fingerprint, CertificateFingerprint::Algorithm fingerprintAlgorithm); void addIceOption(string option); void removeIceOption(const string &option); @@ -291,7 +300,7 @@ class RTC_CPP_EXPORT Description { string mSessionId; std::vector mIceOptions; optional mIceUfrag, mIcePwd; - optional mFingerprint; + optional mFingerprint; std::vector mAttributes; // other attributes // Entries @@ -308,6 +317,7 @@ class RTC_CPP_EXPORT Description { RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, const rtc::Description &description); RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, rtc::Description::Type type); RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, rtc::Description::Role role); -RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, const rtc::Description::Direction &direction); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, + const rtc::Description::Direction &direction); #endif diff --git a/src/description.cpp b/src/description.cpp index 73dee1df4..345abf1c8 100644 --- a/src/description.cpp +++ b/src/description.cpp @@ -33,11 +33,6 @@ inline bool match_prefix(string_view str, string_view prefix) { std::mismatch(prefix.begin(), prefix.end(), str.begin()).first == prefix.end(); } -inline void trim_begin(string &str) { - str.erase(str.begin(), - std::find_if(str.begin(), str.end(), [](char c) { return !std::isspace(c); })); -} - inline void trim_end(string &str) { str.erase( std::find_if(str.rbegin(), str.rend(), [](char c) { return !std::isspace(c); }).base(), @@ -71,9 +66,10 @@ template T to_integer(string_view s) { } } -inline bool is_sha256_fingerprint(string_view f) { - if (f.size() != 32 * 3 - 1) +inline bool is_valid_fingerprint(string_view f, size_t expectedSize) { + if (f.size() != expectedSize * 3 - 1) { return false; + } for (size_t i = 0; i < f.size(); ++i) { if (i % 3 == 2) { @@ -131,12 +127,35 @@ Description::Description(const string &sdp, Type type, Role role) // media-level SDP attribute. If it is a session-level attribute, it applies to all // TLS sessions for which no media-level fingerprint attribute is defined. if (!mFingerprint || index == 0) { // first media overrides session-level - if (match_prefix(value, "sha-256 ") || match_prefix(value, "SHA-256 ")) { - string fingerprint{value.substr(8)}; - trim_begin(fingerprint); - setFingerprint(std::move(fingerprint)); - } else { + auto fingerprintExploded = utils::explode(string(value), ' '); + if (fingerprintExploded.size() != 2) { PLOG_WARNING << "Unknown SDP fingerprint format: " << value; + continue; + } + + auto first = fingerprintExploded.at(0); + std::transform(first.begin(), first.end(), first.begin(), + [](char c) { return char(std::tolower(c)); }); + + std::optional fingerprintAlgorithm; + + for (auto a : std::vector{ + CertificateFingerprint::Algorithm::Sha1, + CertificateFingerprint::Algorithm::Sha224, + CertificateFingerprint::Algorithm::Sha256, + CertificateFingerprint::Algorithm::Sha384, + CertificateFingerprint::Algorithm::Sha512}) { + if (first == CertificateFingerprint::AlgorithmIdentifier(a)) { + fingerprintAlgorithm = a; + break; + } + } + + if (fingerprintAlgorithm.has_value()) { + setFingerprint(std::move(fingerprintExploded.at(1)), + fingerprintAlgorithm.value()); + } else { + PLOG_WARNING << "Unknown certificate fingerprint algorithm: " << first; } } } else if (key == "ice-ufrag") { @@ -205,7 +224,7 @@ std::vector Description::iceOptions() const { return mIceOptions; } optional Description::icePwd() const { return mIcePwd; } -optional Description::fingerprint() const { return mFingerprint; } +optional Description::fingerprint() const { return mFingerprint; } bool Description::ended() const { return mEnded; } @@ -214,13 +233,15 @@ void Description::hintType(Type type) { mType = type; } -void Description::setFingerprint(string fingerprint) { - if (!is_sha256_fingerprint(fingerprint)) - throw std::invalid_argument("Invalid SHA256 fingerprint \"" + fingerprint + "\""); +void Description::setFingerprint(string fingerprint, + CertificateFingerprint::Algorithm fingerprintAlgorithm) { + if (!is_valid_fingerprint(fingerprint, + CertificateFingerprint::AlgorithmSize(fingerprintAlgorithm))) + throw std::invalid_argument("Invalid certificate fingerprint \"" + fingerprint + "\""); std::transform(fingerprint.begin(), fingerprint.end(), fingerprint.begin(), [](char c) { return char(std::toupper(c)); }); - mFingerprint.emplace(std::move(fingerprint)); + mFingerprint = CertificateFingerprint{fingerprintAlgorithm, std::move(fingerprint)}; } void Description::addIceOption(string option) { @@ -315,7 +336,9 @@ string Description::generateSdp(string_view eol) const { if (!mIceOptions.empty()) sdp << "a=ice-options:" << utils::implode(mIceOptions, ',') << eol; if (mFingerprint) - sdp << "a=fingerprint:sha-256 " << *mFingerprint << eol; + sdp << "a=fingerprint:" + << CertificateFingerprint::AlgorithmIdentifier(mFingerprint.value().algorithm) << " " + << mFingerprint.value().value << eol; for (const auto &attr : mAttributes) sdp << "a=" << attr << eol; @@ -378,7 +401,9 @@ string Description::generateApplicationSdp(string_view eol) const { if (!mIceOptions.empty()) sdp << "a=ice-options:" << utils::implode(mIceOptions, ',') << eol; if (mFingerprint) - sdp << "a=fingerprint:sha-256 " << *mFingerprint << eol; + sdp << "a=fingerprint:" + << CertificateFingerprint::AlgorithmIdentifier(mFingerprint.value().algorithm) << " " + << mFingerprint.value().value << eol; for (const auto &attr : mAttributes) sdp << "a=" << attr << eol; @@ -876,8 +901,7 @@ void Description::Application::parseSdpLine(string_view line) { } } -Description::Media::Media(const string &sdp) - : Entry(get_first_line(sdp), "", Direction::Unknown) { +Description::Media::Media(const string &sdp) : Entry(get_first_line(sdp), "", Direction::Unknown) { string line; std::istringstream ss(sdp); std::getline(ss, line); // discard first line @@ -1288,6 +1312,42 @@ string Description::typeToString(Type type) { } } +size_t +CertificateFingerprint::AlgorithmSize(CertificateFingerprint::Algorithm fingerprintAlgorithm) { + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + return 20; + case CertificateFingerprint::Algorithm::Sha224: + return 28; + case CertificateFingerprint::Algorithm::Sha256: + return 32; + case CertificateFingerprint::Algorithm::Sha384: + return 48; + case CertificateFingerprint::Algorithm::Sha512: + return 64; + } + + return 0; +} + +std::string CertificateFingerprint::AlgorithmIdentifier( + CertificateFingerprint::Algorithm fingerprintAlgorithm) { + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + return "sha-1"; + case CertificateFingerprint::Algorithm::Sha224: + return "sha-224"; + case CertificateFingerprint::Algorithm::Sha256: + return "sha-256"; + case CertificateFingerprint::Algorithm::Sha384: + return "sha-256"; + case CertificateFingerprint::Algorithm::Sha512: + return "sha-512"; + } + + return ""; +} + } // namespace rtc std::ostream &operator<<(std::ostream &out, const rtc::Description &description) { diff --git a/src/impl/certificate.cpp b/src/impl/certificate.cpp index 3df419da6..65fe1a924 100644 --- a/src/impl/certificate.cpp +++ b/src/impl/certificate.cpp @@ -100,18 +100,20 @@ Certificate Certificate::Generate(CertificateType type, const string &commonName Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey) : mCredentials(gnutls::new_credentials(), gnutls::free_credentials), - mFingerprint(make_fingerprint(crt)) { + mFingerprint(make_fingerprint(crt, CertificateFingerprint::Algorithm::Sha256)) { gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey), "Unable to set certificate and key pair in credentials"); } Certificate::Certificate(shared_ptr creds) - : mCredentials(std::move(creds)), mFingerprint(make_fingerprint(*mCredentials)) {} + : mCredentials(std::move(creds)), + mFingerprint(make_fingerprint(*mCredentials, CertificateFingerprint::Algorithm::Sha256)) {} gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; } -string make_fingerprint(gnutls_certificate_credentials_t credentials) { +string make_fingerprint(gnutls_certificate_credentials_t credentials, + CertificateFingerprint::Algorithm fingerprintAlgorithm) { auto new_crt_list = [credentials]() -> gnutls_x509_crt_t * { gnutls_x509_crt_t *crt_list = nullptr; unsigned int crt_list_size = 0; @@ -127,14 +129,34 @@ string make_fingerprint(gnutls_certificate_credentials_t credentials) { unique_ptr crt_list(new_crt_list(), free_crt_list); - return make_fingerprint(*crt_list); + return make_fingerprint(*crt_list, fingerprintAlgorithm); } -string make_fingerprint(gnutls_x509_crt_t crt) { - const size_t size = 32; - unsigned char buffer[size]; +string make_fingerprint(gnutls_x509_crt_t crt, + CertificateFingerprint::Algorithm fingerprintAlgorithm) { + const size_t size = CertificateFingerprint::AlgorithmSize(fingerprintAlgorithm); + std::vector buffer(size); size_t len = size; - gnutls::check(gnutls_x509_crt_get_fingerprint(crt, GNUTLS_DIG_SHA256, buffer, &len), + + gnutls_digest_algorithm_t hashFunc; + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + hashFunc = GNUTLS_DIG_SHA1; + break; + case CertificateFingerprint::Algorithm::Sha224: + hashFunc = GNUTLS_DIG_SHA224; + break; + case CertificateFingerprint::Algorithm::Sha384: + hashFunc = GNUTLS_DIG_SHA384; + break; + case CertificateFingerprint::Algorithm::Sha512: + hashFunc = GNUTLS_DIG_SHA512; + break; + default: + hashFunc = GNUTLS_DIG_SHA256; + } + + gnutls::check(gnutls_x509_crt_get_fingerprint(crt, hashFunc, buffer.data(), &len), "X509 fingerprint error"); std::ostringstream oss; @@ -142,23 +164,44 @@ string make_fingerprint(gnutls_x509_crt_t crt) { for (size_t i = 0; i < len; ++i) { if (i) oss << std::setw(1) << ':'; - oss << std::setw(2) << unsigned(buffer[i]); + oss << std::setw(2) << unsigned(buffer.at(i)); } return oss.str(); } #elif USE_MBEDTLS -string make_fingerprint(mbedtls_x509_crt *crt) { - const int size = 32; - uint8_t buffer[size]; +string make_fingerprint(mbedtls_x509_crt *crt, + CertificateFingerprint::Algorithm fingerprintAlgorithm) { + const int size = CertificateFingerprint::AlgorithmSize(fingerprintAlgorithm); + std::vector buffer(size); std::stringstream fingerprint; - mbedtls::check( - mbedtls_sha256(crt->raw.p, crt->raw.len, reinterpret_cast(buffer), 0), - "Failed to generate certificate fingerprint"); + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + mbedtls::check(mbedtls_sha1(crt->raw.p, crt->raw.len, buffer.data()), + "Failed to generate certificate fingerprint"); + break; + case CertificateFingerprint::Algorithm::Sha224: + mbedtls::check(mbedtls_sha256(crt->raw.p, crt->raw.len, buffer.data(), 1), + "Failed to generate certificate fingerprint"); + + break; + case CertificateFingerprint::Algorithm::Sha384: + mbedtls::check(mbedtls_sha512(crt->raw.p, crt->raw.len, buffer.data(), 1), + "Failed to generate certificate fingerprint"); + break; + case CertificateFingerprint::Algorithm::Sha512: + mbedtls::check(mbedtls_sha512(crt->raw.p, crt->raw.len, buffer.data(), 0), + "Failed to generate certificate fingerprint"); + break; + default: + mbedtls::check(mbedtls_sha256(crt->raw.p, crt->raw.len, buffer.data(), 0), + "Failed to generate certificate fingerprint"); + } for (auto i = 0; i < size; i++) { - fingerprint << std::setfill('0') << std::setw(2) << std::hex << static_cast(buffer[i]); + fingerprint << std::setfill('0') << std::setw(2) << std::hex + << static_cast(buffer.at(i)); if (i != (size - 1)) { fingerprint << ":"; } @@ -168,7 +211,8 @@ string make_fingerprint(mbedtls_x509_crt *crt) { } Certificate::Certificate(shared_ptr crt, shared_ptr pk) - : mCrt(crt), mPk(pk), mFingerprint(make_fingerprint(crt.get())) {} + : mCrt(crt), mPk(pk), + mFingerprint(make_fingerprint(crt.get(), CertificateFingerprint::Algorithm::Sha256)) {} Certificate Certificate::FromString(string crt_pem, string key_pem) { PLOG_DEBUG << "Importing certificate from PEM string (MbedTLS)"; @@ -465,17 +509,37 @@ Certificate Certificate::Generate(CertificateType type, const string &commonName } Certificate::Certificate(shared_ptr x509, shared_ptr pkey) - : mX509(std::move(x509)), mPKey(std::move(pkey)), mFingerprint(make_fingerprint(mX509.get())) {} + : mX509(std::move(x509)), mPKey(std::move(pkey)), + mFingerprint(make_fingerprint(mX509.get(), CertificateFingerprint::Algorithm::Sha256)) {} std::tuple Certificate::credentials() const { return {mX509.get(), mPKey.get()}; } -string make_fingerprint(X509 *x509) { - const size_t size = 32; - unsigned char buffer[size]; - unsigned int len = size; - if (!X509_digest(x509, EVP_sha256(), buffer, &len)) +string make_fingerprint(X509 *x509, CertificateFingerprint::Algorithm fingerprintAlgorithm) { + size_t size = CertificateFingerprint::AlgorithmSize(fingerprintAlgorithm); + std::vector buffer(size); + auto len = static_cast(size); + + const EVP_MD *hashFunc; + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + hashFunc = EVP_sha1(); + break; + case CertificateFingerprint::Algorithm::Sha224: + hashFunc = EVP_sha224(); + break; + case CertificateFingerprint::Algorithm::Sha384: + hashFunc = EVP_sha384(); + break; + case CertificateFingerprint::Algorithm::Sha512: + hashFunc = EVP_sha512(); + break; + default: + hashFunc = EVP_sha256(); + } + + if (!X509_digest(x509, hashFunc, buffer.data(), &len)) throw std::runtime_error("X509 fingerprint error"); std::ostringstream oss; @@ -483,7 +547,7 @@ string make_fingerprint(X509 *x509) { for (size_t i = 0; i < len; ++i) { if (i) oss << std::setw(1) << ':'; - oss << std::setw(2) << unsigned(buffer[i]); + oss << std::setw(2) << unsigned(buffer.at(i)); } return oss.str(); } diff --git a/src/impl/certificate.hpp b/src/impl/certificate.hpp index 800bcd2d9..1ee23da83 100644 --- a/src/impl/certificate.hpp +++ b/src/impl/certificate.hpp @@ -9,6 +9,7 @@ #ifndef RTC_IMPL_CERTIFICATE_H #define RTC_IMPL_CERTIFICATE_H +#include "description.hpp" #include "common.hpp" #include "configuration.hpp" // for CertificateType #include "init.hpp" @@ -57,12 +58,12 @@ class Certificate { }; #if USE_GNUTLS -string make_fingerprint(gnutls_certificate_credentials_t credentials); -string make_fingerprint(gnutls_x509_crt_t crt); +string make_fingerprint(gnutls_certificate_credentials_t credentials, CertificateFingerprint::Algorithm fingerprintAlgorithm); +string make_fingerprint(gnutls_x509_crt_t crt, CertificateFingerprint::Algorithm fingerprintAlgorithm); #elif USE_MBEDTLS -string make_fingerprint(mbedtls_x509_crt *crt); +string make_fingerprint(mbedtls_x509_crt *crt, CertificateFingerprint::Algorithm fingerprintAlgorithm); #else -string make_fingerprint(X509 *x509); +string make_fingerprint(X509 *x509, CertificateFingerprint::Algorithm fingerprintAlgorithm); #endif using certificate_ptr = shared_ptr; diff --git a/src/impl/dtlssrtptransport.cpp b/src/impl/dtlssrtptransport.cpp index 8566338ab..a9cbc8f5a 100644 --- a/src/impl/dtlssrtptransport.cpp +++ b/src/impl/dtlssrtptransport.cpp @@ -58,10 +58,11 @@ bool DtlsSrtpTransport::IsGcmSupported() { DtlsSrtpTransport::DtlsSrtpTransport(shared_ptr lower, shared_ptr certificate, optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, verifier_callback verifierCallback, message_callback srtpRecvCallback, state_callback stateChangeCallback) - : DtlsTransport(lower, certificate, mtu, std::move(verifierCallback), + : DtlsTransport(lower, certificate, mtu, fingerprintAlgorithm, std::move(verifierCallback), std::move(stateChangeCallback)), mSrtpRecvCallback(std::move(srtpRecvCallback)) { // distinct from Transport recv callback diff --git a/src/impl/dtlssrtptransport.hpp b/src/impl/dtlssrtptransport.hpp index cd6276b9a..208afab2a 100644 --- a/src/impl/dtlssrtptransport.hpp +++ b/src/impl/dtlssrtptransport.hpp @@ -31,8 +31,9 @@ class DtlsSrtpTransport final : public DtlsTransport { static bool IsGcmSupported(); DtlsSrtpTransport(shared_ptr lower, certificate_ptr certificate, - optional mtu, verifier_callback verifierCallback, - message_callback srtpRecvCallback, state_callback stateChangeCallback); + optional mtu, CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, message_callback srtpRecvCallback, + state_callback stateChangeCallback); ~DtlsSrtpTransport(); bool sendMedia(message_ptr message); diff --git a/src/impl/dtlstransport.cpp b/src/impl/dtlstransport.cpp index aa5182411..aeff1a526 100644 --- a/src/impl/dtlstransport.cpp +++ b/src/impl/dtlstransport.cpp @@ -48,10 +48,11 @@ void DtlsTransport::Init() { void DtlsTransport::Cleanup() { gnutls_global_deinit(); } DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, - optional mtu, verifier_callback verifierCallback, - state_callback stateChangeCallback) + optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback) : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), - mVerifierCallback(std::move(verifierCallback)), + mFingerprintAlgorithm(fingerprintAlgorithm), mVerifierCallback(std::move(verifierCallback)), mIsClient(lower->role() == Description::Role::Active) { PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)"; @@ -295,7 +296,7 @@ int DtlsTransport::CertificateCallback(gnutls_session_t session) { return GNUTLS_E_CERTIFICATE_ERROR; } - string fingerprint = make_fingerprint(crt); + string fingerprint = make_fingerprint(crt, t->mFingerprintAlgorithm); gnutls_x509_crt_deinit(crt); bool success = t->mVerifierCallback(fingerprint); @@ -374,10 +375,11 @@ const mbedtls_ssl_srtp_profile srtpSupportedProtectionProfiles[] = { }; DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, - optional mtu, verifier_callback verifierCallback, - state_callback stateChangeCallback) + optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback) : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), - mVerifierCallback(std::move(verifierCallback)), + mFingerprintAlgorithm(fingerprintAlgorithm), mVerifierCallback(std::move(verifierCallback)), mIsClient(lower->role() == Description::Role::Active) { PLOG_DEBUG << "Initializing DTLS transport (MbedTLS)"; @@ -609,7 +611,7 @@ void DtlsTransport::doRecv() { int DtlsTransport::CertificateCallback(void *ctx, mbedtls_x509_crt *crt, int /*depth*/, uint32_t * /*flags*/) { auto this_ = static_cast(ctx); - string fingerprint = make_fingerprint(crt); + string fingerprint = make_fingerprint(crt, this_->mFingerprintAlgorithm); std::transform(fingerprint.begin(), fingerprint.end(), fingerprint.begin(), [](char c) { return char(std::toupper(c)); }); return this_->mVerifierCallback(fingerprint) ? 0 : 1; @@ -725,10 +727,11 @@ void DtlsTransport::Cleanup() { } DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, - optional mtu, verifier_callback verifierCallback, - state_callback stateChangeCallback) + optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback) : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), - mVerifierCallback(std::move(verifierCallback)), + mFingerprintAlgorithm(fingerprintAlgorithm), mVerifierCallback(std::move(verifierCallback)), mIsClient(lower->role() == Description::Role::Active) { PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)"; @@ -1034,7 +1037,7 @@ int DtlsTransport::CertificateCallback(int /*preverify_ok*/, X509_STORE_CTX *ctx static_cast(SSL_get_ex_data(ssl, DtlsTransport::TransportExIndex)); X509 *crt = X509_STORE_CTX_get_current_cert(ctx); - string fingerprint = make_fingerprint(crt); + string fingerprint = make_fingerprint(crt, t->mFingerprintAlgorithm); return t->mVerifierCallback(fingerprint) ? 1 : 0; } diff --git a/src/impl/dtlstransport.hpp b/src/impl/dtlstransport.hpp index d84b312e8..96565b6a1 100644 --- a/src/impl/dtlstransport.hpp +++ b/src/impl/dtlstransport.hpp @@ -32,6 +32,7 @@ class DtlsTransport : public Transport, public std::enable_shared_from_this; DtlsTransport(shared_ptr lower, certificate_ptr certificate, optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, verifier_callback verifierCallback, state_callback stateChangeCallback); ~DtlsTransport(); @@ -52,6 +53,7 @@ class DtlsTransport : public Transport, public std::enable_shared_from_this mMtu; const certificate_ptr mCertificate; + CertificateFingerprint::Algorithm mFingerprintAlgorithm; const verifier_callback mVerifierCallback; const bool mIsClient; diff --git a/src/impl/peerconnection.cpp b/src/impl/peerconnection.cpp index 6c1f648d8..3820ce146 100644 --- a/src/impl/peerconnection.cpp +++ b/src/impl/peerconnection.cpp @@ -211,6 +211,12 @@ shared_ptr PeerConnection::initDtlsTransport() { PLOG_VERBOSE << "Starting DTLS transport"; + auto fingerprintAlgorithm = CertificateFingerprint::Algorithm::Sha256; + if (auto remote = remoteDescription(); + remote && remote->fingerprint().has_value()) { + fingerprintAlgorithm = remote->fingerprint().value().algorithm; + } + auto lower = std::atomic_load(&mIceTransport); if (!lower) throw std::logic_error("No underlying ICE transport for DTLS transport"); @@ -254,7 +260,7 @@ shared_ptr PeerConnection::initDtlsTransport() { // DTLS-SRTP transport = std::make_shared( - lower, certificate, config.mtu, verifierCallback, + lower, certificate, config.mtu, fingerprintAlgorithm, verifierCallback, weak_bind(&PeerConnection::forwardMedia, this, _1), dtlsStateChangeCallback); #else PLOG_WARNING << "Ignoring media support (not compiled with media support)"; @@ -264,7 +270,8 @@ shared_ptr PeerConnection::initDtlsTransport() { if (!transport) { // DTLS only transport = std::make_shared(lower, certificate, config.mtu, - verifierCallback, dtlsStateChangeCallback); + fingerprintAlgorithm, verifierCallback, + dtlsStateChangeCallback); } return emplaceTransport(this, &mDtlsTransport, std::move(transport)); @@ -417,14 +424,17 @@ void PeerConnection::rollbackLocalDescription() { bool PeerConnection::checkFingerprint(const std::string &fingerprint) const { std::lock_guard lock(mRemoteDescriptionMutex); - auto expectedFingerprint = mRemoteDescription ? mRemoteDescription->fingerprint() : nullopt; - if (expectedFingerprint && *expectedFingerprint == fingerprint) { + std::string expectedFingerprint = "[none]"; + if (mRemoteDescription && mRemoteDescription->fingerprint().has_value()) { + expectedFingerprint = mRemoteDescription->fingerprint().value().value; + } + + if (expectedFingerprint == fingerprint) { PLOG_VERBOSE << "Valid fingerprint \"" << fingerprint << "\""; return true; } - PLOG_ERROR << "Invalid fingerprint \"" << fingerprint << "\", expected \"" - << expectedFingerprint.value_or("[none]") << "\""; + PLOG_ERROR << "Invalid fingerprint \"" << fingerprint << "\", expected \"" << expectedFingerprint << "\""; return false; } @@ -994,7 +1004,8 @@ void PeerConnection::processLocalDescription(Description description) { } // Set local fingerprint (wait for certificate if necessary) - description.setFingerprint(mCertificate.get()->fingerprint()); + description.setFingerprint(mCertificate.get()->fingerprint(), + CertificateFingerprint::Algorithm::Sha256); PLOG_VERBOSE << "Issuing local description: " << description; diff --git a/test/connectivity.cpp b/test/connectivity.cpp index ddafcae68..005af4498 100644 --- a/test/connectivity.cpp +++ b/test/connectivity.cpp @@ -52,10 +52,10 @@ void test_connectivity(bool signal_wrong_fingerprint) { if (signal_wrong_fingerprint) { auto f = sdp.fingerprint(); if (f.has_value()) { - auto s = f.value(); + auto s = f.value().value; auto& c = s[0]; if (c == 'F' || c == 'f') c = '0'; else c++; - sdp.setFingerprint(s); + sdp.setFingerprint(s, f->algorithm); } } pc2.setRemoteDescription(string(sdp));