Skip to content

Commit

Permalink
Properly implement the renegotiation needed mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
paullouisageneau committed Jul 30, 2024
1 parent 9daab5b commit 9928725
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 62 deletions.
6 changes: 3 additions & 3 deletions include/rtc/description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,9 @@ class RTC_CPP_EXPORT Description {
int addAudio(string mid = "audio", Direction dir = Direction::SendOnly);
void clearMedia();

variant<Media *, Application *> media(unsigned int index);
variant<const Media *, const Application *> media(unsigned int index) const;
unsigned int mediaCount() const;
variant<Media *, Application *> media(int index);
variant<const Media *, const Application *> media(int index) const;
int mediaCount() const;

const Application *application() const;
Application *application();
Expand Down
1 change: 1 addition & 0 deletions include/rtc/peerconnection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class RTC_CPP_EXPORT PeerConnection final : CheshireCat<impl::PeerConnection> {
IceState iceState() const;
GatheringState gatheringState() const;
SignalingState signalingState() const;
bool negotiationNeeded() const;
bool hasMedia() const;
optional<Description> localDescription() const;
optional<Description> remoteDescription() const;
Expand Down
4 changes: 2 additions & 2 deletions src/capi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,7 @@ int rtcGetSsrcsForType(const char *mediaType, const char *sdp, uint32_t *buffer,
auto oldSDP = string(sdp);
auto description = Description(oldSDP, "unspec");
auto mediaCount = description.mediaCount();
for (unsigned int i = 0; i < mediaCount; i++) {
for (int i = 0; i < mediaCount; i++) {
if (std::holds_alternative<Description::Media *>(description.media(i))) {
auto media = std::get<Description::Media *>(description.media(i));
auto currentMediaType = lowercased(media->type());
Expand All @@ -1461,7 +1461,7 @@ int rtcSetSsrcForType(const char *mediaType, const char *sdp, char *buffer, cons
auto prevSDP = string(sdp);
auto description = Description(prevSDP, "unspec");
auto mediaCount = description.mediaCount();
for (unsigned int i = 0; i < mediaCount; i++) {
for (int i = 0; i < mediaCount; i++) {
if (std::holds_alternative<Description::Media *>(description.media(i))) {
auto media = std::get<Description::Media *>(description.media(i));
auto currentMediaType = lowercased(media->type());
Expand Down
11 changes: 5 additions & 6 deletions src/description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,8 @@ void Description::clearMedia() {
mApplication.reset();
}

variant<Description::Media *, Description::Application *> Description::media(unsigned int index) {
if (index >= mEntries.size())
variant<Description::Media *, Description::Application *> Description::media(int index) {
if (index < 0 || index >= int(mEntries.size()))
throw std::out_of_range("Media index out of range");

const auto &entry = mEntries[index];
Expand All @@ -514,9 +514,8 @@ variant<Description::Media *, Description::Application *> Description::media(uns
}
}

variant<const Description::Media *, const Description::Application *>
Description::media(unsigned int index) const {
if (index >= mEntries.size())
variant<const Description::Media *, const Description::Application *> Description::media(int index) const {
if (index < 0 || index >= int(mEntries.size()))
throw std::out_of_range("Media index out of range");

const auto &entry = mEntries[index];
Expand All @@ -536,7 +535,7 @@ Description::media(unsigned int index) const {
}
}

unsigned int Description::mediaCount() const { return unsigned(mEntries.size()); }
int Description::mediaCount() const { return int(mEntries.size()); }

Description::Entry::Entry(const string &mline, string mid, Direction dir)
: mMid(std::move(mid)), mDirection(dir) {
Expand Down
122 changes: 97 additions & 25 deletions src/impl/peerconnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ PeerConnection::~PeerConnection() {
}

void PeerConnection::close() {
negotiationNeeded = false;
if (!closing.exchange(true)) {
PLOG_VERBOSE << "Closing PeerConnection";
if (auto transport = std::atomic_load(&mSctpTransport))
Expand Down Expand Up @@ -829,27 +828,58 @@ void PeerConnection::iterateTracks(std::function<void(shared_ptr<Track> track)>
}
}

void PeerConnection::iterateRemoteTracks(std::function<void(shared_ptr<Track> track)> func) {
auto remote = remoteDescription();
if(!remote)
return;

std::vector<shared_ptr<Track>> locked;
{
std::shared_lock lock(mTracksMutex); // read-only
locked.reserve(mTracks.size());
for(int i = 0; i < remote->mediaCount(); ++i) {
if (std::holds_alternative<Description::Media *>(remote->media(i))) {
auto remoteMedia = std::get<Description::Media *>(remote->media(i));
if (!remoteMedia->isRemoved())
if (auto it = mTracks.find(remoteMedia->mid()); it != mTracks.end())
if (auto track = it->second.lock())
locked.push_back(std::move(track));
}
}
}

for (auto &track : locked) {
try {
func(std::move(track));
} catch (const std::exception &e) {
PLOG_WARNING << e.what();
}
}
}


void PeerConnection::openTracks() {
#if RTC_ENABLE_MEDIA
if (auto transport = std::atomic_load(&mDtlsTransport)) {
auto srtpTransport = std::dynamic_pointer_cast<DtlsSrtpTransport>(transport);

iterateTracks([&](const shared_ptr<Track> &track) {
if (!track->isOpen()) {
if (srtpTransport) {
track->open(srtpTransport);
} else {
// A track was added during a latter renegotiation, whereas SRTP transport was
// not initialized. This is an optimization to use the library with data
// channels only. Set forceMediaTransport to true to initialize the transport
// before dynamically adding tracks.
auto errorMsg = "The connection has no media transport";
PLOG_ERROR << errorMsg;
track->triggerError(errorMsg);
}
auto transport = std::atomic_load(&mDtlsTransport);
if (!transport)
return;

auto srtpTransport = std::dynamic_pointer_cast<DtlsSrtpTransport>(transport);
iterateRemoteTracks([&](shared_ptr<Track> track) {
if(!track->isOpen()) {
if (srtpTransport) {
track->open(srtpTransport);
} else {
// A track was added during a latter renegotiation, whereas SRTP transport was
// not initialized. This is an optimization to use the library with data
// channels only. Set forceMediaTransport to true to initialize the transport
// before dynamically adding tracks.
auto errorMsg = "The connection has no media transport";
PLOG_ERROR << errorMsg;
track->triggerError(errorMsg);
}
});
}
}
});
#endif
}

Expand All @@ -872,7 +902,7 @@ void PeerConnection::validateRemoteDescription(const Description &description) {
throw std::invalid_argument("Remote description has no media line");

int activeMediaCount = 0;
for (unsigned int i = 0; i < description.mediaCount(); ++i)
for (int i = 0; i < description.mediaCount(); ++i)
std::visit(rtc::overloaded{[&](const Description::Application *application) {
if (!application->isRemoved())
++activeMediaCount;
Expand Down Expand Up @@ -900,7 +930,7 @@ void PeerConnection::processLocalDescription(Description description) {

if (auto remote = remoteDescription()) {
// Reciprocate remote description
for (unsigned int i = 0; i < remote->mediaCount(); ++i)
for (int i = 0; i < remote->mediaCount(); ++i)
std::visit( // reciprocate each media
rtc::overloaded{
[&](Description::Application *remoteApp) {
Expand Down Expand Up @@ -1027,8 +1057,7 @@ void PeerConnection::processLocalDescription(Description description) {
}
}

// There might be no media at this point if the user created a Track, deleted it,
// then called setLocalDescription().
// There might be no media at this point, for instance if the user deleted tracks
if (description.mediaCount() == 0)
throw std::runtime_error("No DataChannel or Track to negotiate");
}
Expand Down Expand Up @@ -1102,15 +1131,19 @@ void PeerConnection::processRemoteDescription(Description description) {
mRemoteDescription->addCandidates(std::move(existingCandidates));
}

auto dtlsTransport = std::atomic_load(&mDtlsTransport);
if (description.hasApplication()) {
auto dtlsTransport = std::atomic_load(&mDtlsTransport);
auto sctpTransport = std::atomic_load(&mSctpTransport);
if (!sctpTransport && dtlsTransport &&
dtlsTransport->state() == Transport::State::Connected)
initSctpTransport();
} else {
mProcessor.enqueue(&PeerConnection::remoteCloseDataChannels, shared_from_this());
}

if (dtlsTransport && dtlsTransport->state() == Transport::State::Connected)
mProcessor.enqueue(&PeerConnection::openTracks, shared_from_this());

}

void PeerConnection::processRemoteCandidate(Candidate candidate) {
Expand Down Expand Up @@ -1156,6 +1189,45 @@ string PeerConnection::localBundleMid() const {
return mLocalDescription ? mLocalDescription->bundleMid() : "0";
}

bool PeerConnection::negotiationNeeded() const {
auto description = localDescription();

{
std::shared_lock lock(mDataChannelsMutex);
if (!mDataChannels.empty() || !mUnassignedDataChannels.empty())
if(!description || !description->hasApplication()) {
PLOG_DEBUG << "Negotiation needed for data channels";
return true;
}
}

{
std::shared_lock lock(mTracksMutex);
for(const auto &[mid, weakTrack] : mTracks)
if (auto track = weakTrack.lock())
if (!description || !description->hasMid(track->mid())) {
PLOG_DEBUG << "Negotiation needed to add track, mid=" << track->mid();
return true;
}

if(description) {
for(int i = 0; i < description->mediaCount(); ++i) {
if (std::holds_alternative<Description::Media *>(description->media(i))) {
auto media = std::get<Description::Media *>(description->media(i));
if (!media->isRemoved())
if (auto it = mTracks.find(media->mid()); it != mTracks.end())
if (auto track = it->second.lock(); !track || track->isClosed()) {
PLOG_DEBUG << "Negotiation needed to remove track, mid=" << track->mid();
return true;
}
}
}
}
}

return false;
}

void PeerConnection::setMediaHandler(shared_ptr<MediaHandler> handler) {
std::unique_lock lock(mMediaHandlerMutex);
mMediaHandler = handler;
Expand Down Expand Up @@ -1321,7 +1393,7 @@ void PeerConnection::updateTrackSsrcCache(const Description &description) {
std::unique_lock lock(mTracksMutex); // for safely writing to mTracksBySsrc

// Setup SSRC -> Track mapping
for (unsigned int i = 0; i < description.mediaCount(); ++i)
for (int i = 0; i < description.mediaCount(); ++i)
std::visit( // ssrc -> track mapping
rtc::overloaded{
[&](Description::Application const *) { return; },
Expand Down
8 changes: 5 additions & 3 deletions src/impl/peerconnection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {

shared_ptr<Track> emplaceTrack(Description::Media description);
void iterateTracks(std::function<void(shared_ptr<Track> track)> func);
void iterateRemoteTracks(std::function<void(shared_ptr<Track> track)> func);
void openTracks();
void closeTracks();

Expand All @@ -80,6 +81,8 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
void processRemoteCandidate(Candidate candidate);
string localBundleMid() const;

bool negotiationNeeded() const;

void setMediaHandler(shared_ptr<MediaHandler> handler);
shared_ptr<MediaHandler> getMediaHandler();

Expand Down Expand Up @@ -115,7 +118,6 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {
std::atomic<IceState> iceState = IceState::New;
std::atomic<GatheringState> gatheringState = GatheringState::New;
std::atomic<SignalingState> signalingState = SignalingState::Stable;
std::atomic<bool> negotiationNeeded = false;
std::atomic<bool> closing = false;
std::mutex signalingMutex;

Expand Down Expand Up @@ -154,12 +156,12 @@ struct PeerConnection : std::enable_shared_from_this<PeerConnection> {

std::unordered_map<uint16_t, weak_ptr<DataChannel>> mDataChannels; // by stream ID
std::vector<weak_ptr<DataChannel>> mUnassignedDataChannels;
std::shared_mutex mDataChannelsMutex;
mutable std::shared_mutex mDataChannelsMutex;

std::unordered_map<string, weak_ptr<Track>> mTracks; // by mid
std::unordered_map<uint32_t, weak_ptr<Track>> mTracksBySsrc; // by SSRC
std::vector<weak_ptr<Track>> mTrackLines; // by SDP order
std::shared_mutex mTracksMutex;
mutable std::shared_mutex mTracksMutex;

Queue<shared_ptr<DataChannel>> mPendingDataChannels;
Queue<shared_ptr<Track>> mPendingTracks;
Expand Down
Loading

0 comments on commit 9928725

Please sign in to comment.