From 42faf08a12d852dd30b769ef24bac993d5a80207 Mon Sep 17 00:00:00 2001 From: Alan Frindell Date: Thu, 19 Dec 2024 08:58:44 -0800 Subject: [PATCH] Use Consumer interface for MoQSession reads (#13) Summary: This is the second half of the MoQSession rewrite. subscribe and fetch callers now supply a Consumer which the library drives as a callback. To make the consumer API work required changing the codec callbacks. The relay now connects a Forwarder (Consumer) to the upstream subscription directly. Differential Revision: D66881617 --- moxygen/MoQClient.cpp | 3 +- moxygen/MoQCodec.cpp | 126 ++-- moxygen/MoQCodec.h | 35 +- moxygen/MoQServer.cpp | 6 - moxygen/MoQServer.h | 1 - moxygen/MoQSession.cpp | 695 +++++++++++++----- moxygen/MoQSession.h | 188 ++--- moxygen/relay/MoQRelay.cpp | 73 +- moxygen/relay/MoQRelay.h | 4 - moxygen/samples/chat/MoQChatClient.cpp | 59 +- .../MoQFlvStreamerClient.cpp | 5 - moxygen/samples/text-client/MoQTextClient.cpp | 70 +- moxygen/test/MoQCodecTest.cpp | 195 ++++- moxygen/test/MoQSessionTest.cpp | 173 +++-- moxygen/test/Mocks.h | 23 +- 15 files changed, 1067 insertions(+), 589 deletions(-) diff --git a/moxygen/MoQClient.cpp b/moxygen/MoQClient.cpp index 0dd8949..42ecb54 100644 --- a/moxygen/MoQClient.cpp +++ b/moxygen/MoQClient.cpp @@ -190,7 +190,8 @@ void MoQClient::HTTPHandler::onError( void MoQClient::onSessionEnd(folly::Optional) { // TODO: cleanup? XLOG(DBG1) << "resetting moqSession_"; - moqSession_.reset(); + auto moqSession = std::move(moqSession_); + moqSession.reset(); CHECK(!moqSession_); } diff --git a/moxygen/MoQCodec.cpp b/moxygen/MoQCodec.cpp index eb3554e..bcd418a 100644 --- a/moxygen/MoQCodec.cpp +++ b/moxygen/MoQCodec.cpp @@ -103,12 +103,13 @@ void MoQCodec::onIngressEnd( void MoQObjectStreamCodec::onIngress( std::unique_ptr data, - bool eom) { + bool endOfStream) { onIngressStart(std::move(data)); folly::io::Cursor cursor(ingress_.front()); + bool isFetch = std::get_if(&curObjectHeader_.trackIdentifier); while (!connError_ && ((ingress_.chainLength() > 0 && !cursor.isAtEnd())/* || - (eom && parseState_ == ParseState::OBJECT_PAYLOAD_NO_LENGTH)*/)) { + (endOfStream && parseState_ == ParseState::OBJECT_PAYLOAD_NO_LENGTH)*/)) { switch (parseState_) { case ParseState::STREAM_HEADER_TYPE: { auto newCursor = cursor; @@ -146,7 +147,12 @@ void MoQObjectStreamCodec::onIngress( connError_ = res.error(); break; } - curObjectHeader_.trackIdentifier = SubscribeID(res.value()); + auto subscribeID = SubscribeID(res.value()); + curObjectHeader_.trackIdentifier = subscribeID; + isFetch = true; + if (callback_) { + callback_->onFetchHeader(subscribeID); + } parseState_ = ParseState::MULTI_OBJECT_HEADER; cursor = newCursor; break; @@ -160,6 +166,16 @@ void MoQObjectStreamCodec::onIngress( break; } curObjectHeader_ = res.value(); + auto trackAlias = + std::get_if(&curObjectHeader_.trackIdentifier); + XCHECK(trackAlias); + if (callback_) { + callback_->onSubgroup( + *trackAlias, + curObjectHeader_.group, + curObjectHeader_.subgroup, + curObjectHeader_.priority); + } parseState_ = ParseState::MULTI_OBJECT_HEADER; cursor = newCursor; [[fallthrough]]; @@ -174,20 +190,60 @@ void MoQObjectStreamCodec::onIngress( break; } curObjectHeader_ = res.value(); - if (callback_) { - callback_->onObjectHeader(std::move(res.value())); - } cursor = newCursor; if (curObjectHeader_.status == ObjectStatus::NORMAL) { - parseState_ = ParseState::OBJECT_PAYLOAD; + XLOG(DBG2) << "Parsing object with length, need=" + << *curObjectHeader_.length + << " have=" << cursor.totalLength(); + std::unique_ptr payload; + auto chunkLen = cursor.cloneAtMost(payload, *curObjectHeader_.length); + auto endOfObject = chunkLen == *curObjectHeader_.length; + if (endOfStream && !endOfObject) { + connError_ = ErrorCode::PARSE_ERROR; + XLOG(DBG4) << __func__ << " " << uint32_t(*connError_); + break; + } + if (callback_) { + callback_->onObjectBegin( + curObjectHeader_.group, + curObjectHeader_.subgroup, + curObjectHeader_.id, + *curObjectHeader_.length, + std::move(payload), + endOfObject, + endOfStream && cursor.isAtEnd()); + } + *curObjectHeader_.length -= chunkLen; + if (endOfObject) { + if (endOfStream && cursor.isAtEnd()) { + parseState_ = ParseState::STREAM_FIN_DELIVERED; + } else { + parseState_ = ParseState::MULTI_OBJECT_HEADER; + } + break; + } else { + parseState_ = ParseState::OBJECT_PAYLOAD; + } } else { - parseState_ = ParseState::MULTI_OBJECT_HEADER; + if (callback_) { + callback_->onObjectStatus( + curObjectHeader_.group, + curObjectHeader_.subgroup, + curObjectHeader_.id, + curObjectHeader_.status); + } + if (curObjectHeader_.status == ObjectStatus::END_OF_TRACK_AND_GROUP || + (!isFetch && + curObjectHeader_.status == ObjectStatus::END_OF_GROUP)) { + parseState_ = ParseState::STREAM_FIN_DELIVERED; + } else { + parseState_ = ParseState::MULTI_OBJECT_HEADER; + } break; } [[fallthrough]]; } case ParseState::OBJECT_PAYLOAD: { - auto newCursor = cursor; // need to check for bufLen == 0? std::unique_ptr payload; // TODO: skip clone and do split @@ -195,63 +251,41 @@ void MoQObjectStreamCodec::onIngress( XCHECK(curObjectHeader_.length); XLOG(DBG2) << "Parsing object with length, need=" << *curObjectHeader_.length; - if (ingress_.chainLength() > 0 && newCursor.canAdvance(1)) { - chunkLen = newCursor.cloneAtMost(payload, *curObjectHeader_.length); + if (ingress_.chainLength() > 0 && cursor.canAdvance(1)) { + chunkLen = cursor.cloneAtMost(payload, *curObjectHeader_.length); } *curObjectHeader_.length -= chunkLen; - if (eom && *curObjectHeader_.length != 0) { + if (endOfStream && *curObjectHeader_.length != 0) { connError_ = ErrorCode::PARSE_ERROR; XLOG(DBG4) << __func__ << " " << uint32_t(*connError_); break; } bool endOfObject = (*curObjectHeader_.length == 0); if (callback_ && (payload || endOfObject)) { - callback_->onObjectPayload( - curObjectHeader_.trackIdentifier, - curObjectHeader_.group, - curObjectHeader_.id, - std::move(payload), - endOfObject); + callback_->onObjectPayload(std::move(payload), endOfObject); } if (endOfObject) { parseState_ = ParseState::MULTI_OBJECT_HEADER; } - cursor = newCursor; break; } -#if 0 -// This code is no longer reachable, but I'm leaving it here in case -// the wire format changes back - case ParseState::OBJECT_PAYLOAD_NO_LENGTH: { - auto newCursor = cursor; - // need to check for bufLen == 0? - std::unique_ptr payload; - // TODO: skip clone and do split - if (ingress_.chainLength() > 0 && newCursor.canAdvance(1)) { - newCursor.cloneAtMost(payload, std::numeric_limits::max()); - } - XCHECK(!curObjectHeader_.length); - if (callback_ && (payload || eom)) { - callback_->onObjectPayload( - curObjectHeader_.trackIdentifier, - curObjectHeader_.group, - curObjectHeader_.id, - std::move(payload), - eom); - } - if (eom) { - parseState_ = ParseState::FRAME_HEADER_TYPE; - } - cursor = newCursor; + case ParseState::STREAM_FIN_DELIVERED: { + XLOG(DBG2) << "Bytes=" << cursor.totalLength() + << " remaining in STREAM_FIN_DELIVERED"; + connError_ = ErrorCode::PARSE_ERROR; + break; } -#endif } } size_t remainingLength = 0; - if (!eom && !cursor.isAtEnd()) { + if (!endOfStream && !cursor.isAtEnd()) { remainingLength = cursor.totalLength(); // must be less than 1 message } - onIngressEnd(remainingLength, eom, callback_); + if (endOfStream && parseState_ != ParseState::STREAM_FIN_DELIVERED && + !connError_ && callback_) { + callback_->onEndOfStream(); + } + onIngressEnd(remainingLength, endOfStream, callback_); } folly::Expected MoQControlCodec::parseFrame( diff --git a/moxygen/MoQCodec.h b/moxygen/MoQCodec.h index b8acf03..63f88d5 100644 --- a/moxygen/MoQCodec.h +++ b/moxygen/MoQCodec.h @@ -141,15 +141,27 @@ class MoQObjectStreamCodec : public MoQCodec { public: ~ObjectCallback() override = default; - virtual void onFetchHeader(uint64_t subscribeID) = 0; - virtual void onObjectHeader(ObjectHeader objectHeader) = 0; - - virtual void onObjectPayload( - TrackIdentifier trackIdentifier, - uint64_t groupID, - uint64_t id, - std::unique_ptr payload, - bool eom) = 0; + virtual void onFetchHeader(SubscribeID subscribeID) = 0; + virtual void onSubgroup( + TrackAlias alias, + uint64_t group, + uint64_t subgroup, + uint8_t priority) = 0; + virtual void onObjectBegin( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + uint64_t length, + Payload initialPayload, + bool objectComplete, + bool subgroupComplete) = 0; + virtual void onObjectStatus( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + ObjectStatus status) = 0; + virtual void onObjectPayload(Payload payload, bool objectComplete) = 0; + virtual void onEndOfStream() = 0; }; MoQObjectStreamCodec(ObjectCallback* callback) : callback_(callback) {} @@ -160,10 +172,6 @@ class MoQObjectStreamCodec : public MoQCodec { void onIngress(std::unique_ptr data, bool eom) override; - TrackIdentifier getTrackIdentifier() const { - return curObjectHeader_.trackIdentifier; - } - private: enum class ParseState { STREAM_HEADER_TYPE, @@ -171,6 +179,7 @@ class MoQObjectStreamCodec : public MoQCodec { FETCH_HEADER, MULTI_OBJECT_HEADER, OBJECT_PAYLOAD, + STREAM_FIN_DELIVERED, // OBJECT_PAYLOAD_NO_LENGTH }; ParseState parseState_{ParseState::STREAM_HEADER_TYPE}; diff --git a/moxygen/MoQServer.cpp b/moxygen/MoQServer.cpp index b25bf96..f443d0e 100644 --- a/moxygen/MoQServer.cpp +++ b/moxygen/MoQServer.cpp @@ -92,12 +92,6 @@ void MoQServer::ControlVisitor::operator()(Fetch fetch) const { XLOG(INFO) << "Fetch id=" << fetch.subscribeID; } -void MoQServer::ControlVisitor::operator()(SubscribeDone subscribeDone) const { - XLOG(INFO) << "SubscribeDone id=" << subscribeDone.subscribeID - << " code=" << folly::to_underlying(subscribeDone.statusCode) - << " reason=" << subscribeDone.reasonPhrase; -} - void MoQServer::ControlVisitor::operator()(Unsubscribe unsubscribe) const { XLOG(INFO) << "Unsubscribe id=" << unsubscribe.subscribeID; } diff --git a/moxygen/MoQServer.h b/moxygen/MoQServer.h index c774e8c..ef26744 100644 --- a/moxygen/MoQServer.h +++ b/moxygen/MoQServer.h @@ -43,7 +43,6 @@ class MoQServer : public MoQSession::ServerSetupCallback { void operator()(AnnounceCancel announceCancel) const override; void operator()(SubscribeAnnounces subscribeAnnounces) const override; void operator()(UnsubscribeAnnounces unsubscribeAnnounces) const override; - void operator()(SubscribeDone subscribeDone) const override; void operator()(Unsubscribe unsubscribe) const override; void operator()(TrackStatusRequest trackStatusRequest) const override; void operator()(TrackStatus trackStatus) const override; diff --git a/moxygen/MoQSession.cpp b/moxygen/MoQSession.cpp index 8d391e2..ee6b1c7 100644 --- a/moxygen/MoQSession.cpp +++ b/moxygen/MoQSession.cpp @@ -5,7 +5,7 @@ */ #include "moxygen/MoQSession.h" -#include +#include #include #include @@ -771,7 +771,7 @@ void MoQSession::start() { .start(); co_withCancellation( cancellationSource_.getToken(), - readLoop(StreamType::CONTROL, controlStream.readHandle)) + controlReadLoop(controlStream.readHandle)) .scheduleOn(evb_) .start(); } @@ -942,167 +942,493 @@ MoQSession::controlMessages() { } } -folly::coro::Task MoQSession::readLoop( - StreamType streamType, +folly::coro::Task MoQSession::controlReadLoop( proxygen::WebTransport::StreamReadHandle* readHandle) { XLOG(DBG1) << __func__ << " sess=" << this; auto g = folly::makeGuard([func = __func__, this] { XLOG(DBG1) << "exit " << func << " sess=" << this; }); co_await folly::coro::co_safe_point; - std::unique_ptr codec; - MoQObjectStreamCodec* objCodec = nullptr; - if (streamType == StreamType::CONTROL) { - codec = std::make_unique(dir_, this); - } else { - auto res = std::make_unique(this); - objCodec = res.get(); - codec = std::move(res); - } + MoQControlCodec codec(dir_, this); auto streamId = readHandle->getID(); - codec->setStreamId(streamId); + codec.setStreamId(streamId); - // TODO: disallow OBJECT on control streams and non-object on non-control bool fin = false; auto token = co_await folly::coro::co_current_cancellation_token; - std::shared_ptr track; - folly::CancellationSource cs; while (!fin && !token.isCancellationRequested()) { auto streamData = co_await folly::coro::co_awaitTry( readHandle->readStreamData().via(evb_)); if (streamData.hasException()) { XLOG(ERR) << streamData.exception().what() << " id=" << streamId << " sess=" << this; - // TODO: possibly erase fetch - cs.requestCancellation(); break; } else { if (streamData->data || streamData->fin) { - codec->onIngress(std::move(streamData->data), streamData->fin); - } - if (!track && objCodec) { - // TODO: this might not be set - auto trackId = objCodec->getTrackIdentifier(); - if (auto subscribeID = std::get_if(&trackId)) { - // it's fetch - track = getTrack(trackId); - track->mergeReadCancelToken( - folly::CancellationToken::merge(cs.getToken(), token)); + try { + codec.onIngress(std::move(streamData->data), streamData->fin); + } catch (const std::exception& ex) { + XLOG(FATAL) << "exception thrown from onIngress ex=" << ex.what(); } } fin = streamData->fin; XLOG_IF(DBG3, fin) << "End of stream id=" << streamId << " sess=" << this; } } - if (track) { - track->fin(); - track->setAllDataReceived(); - if (track->fetchOkReceived()) { - fetches_.erase(track->subscribeID()); - checkForCloseOnDrain(); - } + // TODO: close session on control exit +} + +std::shared_ptr MoQSession::getSubscribeTrack( + TrackAlias trackAlias) { + auto trackIt = subTracks_.find(trackAlias); + if (trackIt == subTracks_.end()) { + // received an object for unknown track alias + XLOG(ERR) << "unknown track alias=" << trackAlias << " sess=" << this; + return nullptr; + } + return trackIt->second; +} + +std::shared_ptr MoQSession::getFetchTrack( + SubscribeID subscribeID) { + XLOG(DBG3) << "getTrack subID=" << subscribeID; + auto trackIt = fetches_.find(subscribeID); + if (trackIt == fetches_.end()) { + // received an object for unknown subscribe ID + XLOG(ERR) << "unknown subscribe ID=" << subscribeID << " sess=" << this; + return nullptr; } + return trackIt->second; } -std::shared_ptr MoQSession::getTrack( - TrackIdentifier trackIdentifier) { - // This can be cached in the handling stream - std::shared_ptr track; - auto alias = std::get_if(&trackIdentifier); - if (alias) { - auto trackIt = subTracks_.find(*alias); - if (trackIt == subTracks_.end()) { - // received an object for unknown track alias - XLOG(ERR) << "unknown track alias=" << alias->value << " sess=" << this; - return nullptr; +namespace { +class SubgroupCodecCallback : public MoQObjectStreamCodec::ObjectCallback { + public: + explicit SubgroupCodecCallback(std::shared_ptr session) + : session_(std::move(session)) {} + + void setTrack(std::shared_ptr track) { + track_ = track; + } + + void onSubgroup( + TrackAlias alias, + uint64_t group, + uint64_t subgroup, + Priority priority) override { + XCHECK(track_); + if (!track_->callback_) { + return; } - track = trackIt->second; - } else { - auto subscribeID = std::get(trackIdentifier); - XLOG(DBG3) << "getTrack subID=" << subscribeID; - auto trackIt = fetches_.find(subscribeID); - if (trackIt == fetches_.end()) { - // received an object for unknown subscribe ID - XLOG(ERR) << "unknown subscribe ID=" << subscribeID << " sess=" << this; - return nullptr; + auto res = track_->callback_->beginSubgroup(group, subgroup, priority); + if (res.hasValue()) { + subgroupCallback_ = *res; + } else { + error_ = std::move(res.error()); } - track = trackIt->second; } - return track; -} -void MoQSession::onObjectHeader(ObjectHeader objHeader) { - XLOG(DBG1) << "MoQSession::" << __func__ << " " << objHeader - << " sess=" << this; - auto track = getTrack(objHeader.trackIdentifier); - if (track) { - track->onObjectHeader(std::move(objHeader)); + void onFetchHeader(SubscribeID /*subscribeID*/) override { + XLOG(FATAL) << "unreachable"; } -} -void MoQSession::onObjectPayload( - TrackIdentifier trackIdentifier, - uint64_t groupID, - uint64_t id, - std::unique_ptr payload, - bool eom) { - XLOG(DBG1) << __func__ << " sess=" << this; - auto track = getTrack(trackIdentifier); - if (track) { - track->onObjectPayload(groupID, id, std::move(payload), eom); + void onObjectBegin( + uint64_t /*group*/, + uint64_t /*subgroup*/, + uint64_t objectID, + uint64_t length, + Payload initialPayload, + bool objectComplete, + bool subgroupComplete) override { + XCHECK(track_); + if (!subgroupCallback_ || !track_->callback_) { + return; + } + + folly::Expected res{folly::unit}; + if (objectComplete) { + res = subgroupCallback_->object( + objectID, std::move(initialPayload), subgroupComplete); + if (subgroupComplete) { + subgroupCallback_.reset(); + } + } else { + res = subgroupCallback_->beginObject( + objectID, length, std::move(initialPayload)); + } + if (!res) { + error_ = std::move(res.error()); + } } -} -void MoQSession::TrackHandle::onObjectHeader(ObjectHeader objHeader) { - XLOG(DBG1) << __func__ << " objHeader=" << objHeader - << " trackHandle=" << this; - uint64_t objectIdKey = objHeader.id; - auto status = objHeader.status; - if (status != ObjectStatus::NORMAL) { - objectIdKey |= (1ull << 63); + void onObjectPayload(Payload payload, bool objectComplete) override { + XCHECK(track_); + if (!subgroupCallback_ || !!track_->callback_) { + return; + } + + auto res = + subgroupCallback_->objectPayload(std::move(payload), objectComplete); + if (!res) { + error_ = std::move(res.error()); + } } - auto res = objects_.emplace( - std::piecewise_construct, - std::forward_as_tuple(std::make_pair(objHeader.group, objectIdKey)), - std::forward_as_tuple(std::make_shared())); - res.first->second->header = std::move(objHeader); - res.first->second->fullTrackName = fullTrackName_; - res.first->second->cancelToken = cancelToken_; - if (status != ObjectStatus::NORMAL) { - res.first->second->payloadQueue.enqueue(nullptr); - } - // TODO: objects_ accumulates the headers of all objects for the life of the - // track. Remove an entry from objects when returning the end of the payload, - // or the object itself for non-normal. - newObjects_.enqueue(res.first->second); -} - -void MoQSession::TrackHandle::fin() { - newObjects_.enqueue(nullptr); -} - -void MoQSession::TrackHandle::onObjectPayload( - uint64_t groupId, - uint64_t id, - std::unique_ptr payload, - bool eom) { - XLOG(DBG1) << __func__ << " g=" << groupId << " o=" << id - << " len=" << (payload ? payload->computeChainDataLength() : 0) - << " eom=" << uint64_t(eom) << " trackHandle=" << this; - auto objIt = objects_.find(std::make_pair(groupId, id)); - if (objIt == objects_.end()) { - // error; - XLOG(ERR) << "unknown object gid=" << groupId << " seq=" << id - << " trackHandle=" << this; - return; + + void onObjectStatus( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + ObjectStatus status) override { + XCHECK(track_); + if (!subgroupCallback_ || !track_->callback_) { + return; + } + + folly::Expected res{folly::unit}; + switch (status) { + case ObjectStatus::NORMAL: + XLOG(FATAL) << "Unreachable"; + break; + case ObjectStatus::OBJECT_NOT_EXIST: + res = subgroupCallback_->objectNotExists(objectID); + break; + case ObjectStatus::GROUP_NOT_EXIST: + res = track_->callback_->groupNotExists(group, subgroup, true); + subgroupCallback_.reset(); + break; + case ObjectStatus::END_OF_GROUP: + res = subgroupCallback_->endOfGroup(objectID); + subgroupCallback_.reset(); + break; + case ObjectStatus::END_OF_TRACK_AND_GROUP: + res = subgroupCallback_->endOfTrackAndGroup(objectID); + subgroupCallback_.reset(); + break; + case ObjectStatus::END_OF_SUBGROUP: + res = subgroupCallback_->endOfSubgroup(); + subgroupCallback_.reset(); + break; + } + if (!res) { + error_ = std::move(res.error()); + } + } + + void onEndOfStream() override { + XCHECK(track_); + if (subgroupCallback_ && track_->callback_) { + auto res = subgroupCallback_->endOfSubgroup(); + if (!res) { + error_ = std::move(res.error()); + } + subgroupCallback_.reset(); + } + } + + void onConnectionError(ErrorCode error) override { + XLOG(ERR) << "Parse error=" << folly::to_underlying(error); + session_->close(SessionCloseErrorCode::PROTOCOL_VIOLATION); + } + + // Called by read loop on read error (eg: RESET_STREAM) + bool reset(ResetStreamErrorCode error) { + if (subgroupCallback_) { + XCHECK(track_); + if (track_->callback_) { + // ignoring error from reset? + subgroupCallback_->reset(error); + } + subgroupCallback_.reset(); + return true; + } + return false; } - if (payload) { - XLOG(DBG1) << "payload enqueued trackHandle=" << this; - objIt->second->payloadQueue.enqueue(std::move(payload)); + + folly::Optional error() const { + return error_; + } + + private: + std::shared_ptr session_; + std::shared_ptr track_; + std::shared_ptr subgroupCallback_; + folly::Optional error_; +}; + +class FetchCodecCallback : public MoQObjectStreamCodec::ObjectCallback { + public: + explicit FetchCodecCallback(std::shared_ptr session) + : session_(std::move(session)) {} + + void setTrack(std::shared_ptr track) { + track_ = track; + } + + void onSubgroup(TrackAlias, uint64_t, uint64_t, Priority) override { + XLOG(FATAL) << "unreachable"; } - if (eom) { - XLOG(DBG1) << "eom enqueued trackHandle=" << this; - objIt->second->payloadQueue.enqueue(nullptr); + + void onFetchHeader(SubscribeID) override {} + + void onObjectBegin( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + uint64_t length, + Payload initialPayload, + bool objectComplete, + bool fetchComplete) override { + XCHECK(track_); + folly::Expected res{folly::unit}; + if (!track_->fetchCallback_) { + return; + } + if (objectComplete) { + res = track_->fetchCallback_->object( + group, subgroup, objectID, std::move(initialPayload), fetchComplete); + } else { + res = track_->fetchCallback_->beginObject( + group, subgroup, objectID, length, std::move(initialPayload)); + } + if (!res) { + error_ = std::move(res.error()); + } + } + + void onObjectPayload(Payload payload, bool objectComplete) override { + XCHECK(track_); + if (!track_->fetchCallback_) { + return; + } + auto res = track_->fetchCallback_->objectPayload( + std::move(payload), + /*finSubgroup=*/false); + if (!res) { + error_ = std::move(res.error()); + } else { + XCHECK_EQ(objectComplete, res.value() == ObjectPublishStatus::DONE); + } + } + + void onObjectStatus( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + ObjectStatus status) override { + XCHECK(track_); + if (!track_->fetchCallback_) { + return; + } + folly::Expected res{folly::unit}; + switch (status) { + case ObjectStatus::NORMAL: + break; + case ObjectStatus::OBJECT_NOT_EXIST: + res = + track_->fetchCallback_->objectNotExists(group, subgroup, objectID); + break; + case ObjectStatus::GROUP_NOT_EXIST: + res = track_->fetchCallback_->groupNotExists(group, subgroup, false); + break; + case ObjectStatus::END_OF_GROUP: + res = track_->fetchCallback_->endOfGroup( + group, subgroup, objectID, false); + break; + case ObjectStatus::END_OF_TRACK_AND_GROUP: + res = track_->fetchCallback_->endOfTrackAndGroup( + group, subgroup, objectID); + track_->fetchCallback_.reset(); + break; + case ObjectStatus::END_OF_SUBGROUP: + break; + } + if (!res) { + error_ = std::move(res.error()); + } + } + + void onEndOfStream() override { + XCHECK(track_); + if (track_->fetchCallback_) { + track_->fetchCallback_->endOfFetch(); + track_->fetchCallback_.reset(); + } + } + + void onConnectionError(ErrorCode error) override { + XLOG(ERR) << "Parse error=" << folly::to_underlying(error); + session_->close(SessionCloseErrorCode::PROTOCOL_VIOLATION); + } + + bool reset(ResetStreamErrorCode error) { + if (track_) { + if (track_->fetchCallback_) { + track_->fetchCallback_->reset(error); + track_->fetchCallback_.reset(); + } + return true; + } + return false; + } + + folly::Optional error() const { + return error_; + } + + private: + std::shared_ptr session_; + std::shared_ptr track_; + folly::Optional error_; +}; + +class DispatchCallback : public MoQObjectStreamCodec::ObjectCallback { + public: + DispatchCallback( + std::shared_ptr session, + MoQObjectStreamCodec& codec, + folly::CancellationToken& token) + : session_(session), + subgroupCallback_(session), + fetchCallback_(session), + codec_(codec), + token_(token) {} + + void onSubgroup( + TrackAlias alias, + uint64_t group, + uint64_t subgroup, + Priority priority) override { + auto track = session_->getSubscribeTrack(alias); + if (!track) { + error_ = MoQPublishError( + MoQPublishError::CANCELLED, "Subgroup for unknown track"); + return; + } + token_ = folly::CancellationToken::merge(token_, track->getCancelToken()); + codec_.setCallback(&subgroupCallback_); + subgroupCallback_.setTrack(std::move(track)); + subgroupCallback_.onSubgroup(alias, group, subgroup, priority); + } + + void onFetchHeader(SubscribeID subscribeID) override { + auto track = session_->getFetchTrack(subscribeID); + + if (!track) { + error_ = MoQPublishError( + MoQPublishError::CANCELLED, "Fetch response for unknown track"); + return; + } + token_ = folly::CancellationToken::merge(token_, track->getCancelToken()); + codec_.setCallback(&fetchCallback_); + fetchCallback_.setTrack(std::move(track)); + fetchCallback_.onFetchHeader(subscribeID); + } + + void + onObjectBegin(uint64_t, uint64_t, uint64_t, uint64_t, Payload, bool, bool) + override { + XCHECK(error_); + } + + void onObjectPayload(Payload, bool) override { + XCHECK(error_); + } + + void onObjectStatus(uint64_t, uint64_t, uint64_t, ObjectStatus) override { + XCHECK(error_); + } + + void onEndOfStream() override { + XCHECK(error_); + } + + void onConnectionError(ErrorCode error) override { + subgroupCallback_.onConnectionError(error); + } + + bool reset(ResetStreamErrorCode error) { + return (subgroupCallback_.reset(error) || fetchCallback_.reset(error)); + } + + folly::Optional error() const { + if (error_) { + return error_; + } + auto err = subgroupCallback_.error(); + if (err) { + return err; + } else { + return fetchCallback_.error(); + } + } + + private: + std::shared_ptr session_; + SubgroupCodecCallback subgroupCallback_; + FetchCodecCallback fetchCallback_; + MoQObjectStreamCodec& codec_; + folly::CancellationToken& token_; + folly::Optional error_; +}; +} // namespace + +folly::coro::Task MoQSession::unidirectionalReadLoop( + std::shared_ptr session, + proxygen::WebTransport::StreamReadHandle* readHandle) { + XLOG(DBG1) << __func__ << " sess=" << this; + auto g = folly::makeGuard([func = __func__, this] { + XLOG(DBG1) << "exit " << func << " sess=" << this; + }); + co_await folly::coro::co_safe_point; + auto token = co_await folly::coro::co_current_cancellation_token; + MoQObjectStreamCodec codec(nullptr); + DispatchCallback dcb(session, codec, /*by ref*/ token); + codec.setCallback(&dcb); + auto id = readHandle->getID(); + codec.setStreamId(id); + + bool fin = false; + while (!fin && !token.isCancellationRequested()) { + auto streamData = + co_await folly::coro::co_awaitTry(folly::coro::co_withCancellation( + token, + folly::coro::toTaskInterruptOnCancel( + readHandle->readStreamData().via(evb_)))); + if (streamData.hasException()) { + XLOG(ERR) << streamData.exception().what() << " id=" << id + << " sess=" << this; + ResetStreamErrorCode errorCode{ResetStreamErrorCode::INTERNAL_ERROR}; + auto wtEx = + streamData.tryGetExceptionObject(); + if (wtEx) { + errorCode = ResetStreamErrorCode(wtEx->error); + } else { + XLOG(ERR) << streamData.exception().what(); + } + if (!dcb.reset(errorCode)) { + XLOG(ERR) << __func__ << " terminating for unknown " + << "stream id=" << id << " sess=" << this; + } + break; + } else { + if (streamData->data || streamData->fin) { + fin = streamData->fin; + folly::Optional err; + try { + codec.onIngress(std::move(streamData->data), streamData->fin); + err = dcb.error(); + } catch (const std::exception& ex) { + err = MoQPublishError(MoQPublishError::CANCELLED, ex.what()); + } + XLOG_IF(DBG3, fin) << "End of stream id=" << id << " sess=" << this; + if (err) { + XLOG(ERR) << "Error parsing stream, err=" << err->what(); + if (!fin) { + readHandle->stopSending(/*error=*/0); + break; + } + } + } // else empty read + } } } @@ -1171,8 +1497,12 @@ void MoQSession::onSubscribeOk(SubscribeOk subOk) { << " sess=" << this; return; } - subTracks_[trackAliasIt->second]->subscribeOK( - subTracks_[trackAliasIt->second], subOk.groupOrder, subOk.latest); + auto trackHandleIt = subTracks_.find(trackAliasIt->second); + if (trackHandleIt != subTracks_.end()) { + trackHandleIt->second->subscribeOK(std::move(subOk)); + } else { + XLOG(ERR) << "Missing subTracks_ entry for alias=" << trackAliasIt->second; + } } void MoQSession::onSubscribeError(SubscribeError subErr) { @@ -1184,10 +1514,16 @@ void MoQSession::onSubscribeError(SubscribeError subErr) { << " sess=" << this; return; } - subTracks_[trackAliasIt->second]->subscribeError(std::move(subErr)); - subTracks_.erase(trackAliasIt->second); - subIdToTrackAlias_.erase(trackAliasIt); - checkForCloseOnDrain(); + + auto trackHandleIt = subTracks_.find(trackAliasIt->second); + if (trackHandleIt != subTracks_.end()) { + trackHandleIt->second->subscribeError(std::move(subErr)); + subTracks_.erase(trackHandleIt); + subIdToTrackAlias_.erase(trackAliasIt); + checkForCloseOnDrain(); + } else { + XLOG(ERR) << "Missing subTracks_ entry for alias=" << trackAliasIt->second; + } } void MoQSession::onSubscribeDone(SubscribeDone subscribeDone) { @@ -1210,13 +1546,12 @@ void MoQSession::onSubscribeDone(SubscribeDone subscribeDone) { if (trackHandleIt != subTracks_.end()) { auto trackHandle = trackHandleIt->second; subTracks_.erase(trackHandleIt); - trackHandle->fin(); + trackHandle->subscribeDone(std::move(subscribeDone)); } else { XLOG(DFATAL) << "trackAliasIt but no trackHandleIt for id=" << subscribeDone.subscribeID << " sess=" << this; } subIdToTrackAlias_.erase(trackAliasIt); - controlMessages_.enqueue(std::move(subscribeDone)); checkForCloseOnDrain(); } @@ -1277,7 +1612,7 @@ void MoQSession::onFetchCancel(FetchCancel fetchCancel) { // It's possible the fetch stream hasn't opened yet if the application // hasn't made it to fetchOK. pubTrackIt->second->reset(ResetStreamErrorCode::CANCELLED); - retireSubscribeId(/*signal=*/true); + retireSubscribeId(/*signalWriteLoop=*/true); } } @@ -1290,9 +1625,9 @@ void MoQSession::onFetchOk(FetchOk fetchOk) { return; } auto trackHandle = fetchIt->second; - trackHandle->fetchOK(trackHandle); + trackHandle->fetchOK(); if (trackHandle->allDataReceived()) { - fetches_.erase(trackHandle->subscribeID()); + fetches_.erase(fetchIt); } } @@ -1407,8 +1742,11 @@ void MoQSession::onGoaway(Goaway goaway) { void MoQSession::onConnectionError(ErrorCode error) { XLOG(DBG1) << __func__ << " sess=" << this; - XLOG(ERR) << "err=" << folly::to_underlying(error); - // TODO + XLOG(ERR) << "MoQCodec control stream parse error err=" + << folly::to_underlying(error); + // TODO: This error is coming from MoQCodec -- do we need a better + // error code? + close(SessionCloseErrorCode::PROTOCOL_VIOLATION); } folly::coro::Task> @@ -1500,37 +1838,10 @@ void MoQSession::subscribeAnnouncesError( controlWriteEvent_.signal(); } -folly::coro::AsyncGenerator< - std::shared_ptr> -MoQSession::TrackHandle::objects() { - XLOG(DBG1) << __func__ << " trackHandle=" << this; - auto g = - folly::makeGuard([func = __func__] { XLOG(DBG1) << "exit " << func; }); - auto cancelToken = co_await folly::coro::co_current_cancellation_token; - auto mergeToken = folly::CancellationToken::merge(cancelToken, cancelToken_); - folly::EventBaseThreadTimekeeper tk(*evb_); - while (!cancelToken.isCancellationRequested()) { - auto optionalObj = newObjects_.try_dequeue(); - std::shared_ptr obj; - if (optionalObj) { - obj = *optionalObj; - } else { - obj = co_await folly::coro::co_withCancellation( - mergeToken, - folly::coro::timeout(newObjects_.dequeue(), objectTimeout_, &tk)); - } - if (!obj) { - XLOG(DBG3) << "Out of objects for trackHandle=" << this - << " id=" << subscribeID_; - break; - } - co_yield obj; - } -} - -folly::coro::Task< - folly::Expected, SubscribeError>> -MoQSession::subscribe(SubscribeRequest sub) { +folly::coro::Task> +MoQSession::subscribe( + SubscribeRequest sub, + std::shared_ptr callback) { XLOG(DBG1) << __func__ << " sess=" << this; auto fullTrackName = sub.fullTrackName; if (nextSubscribeID_ >= peerMaxSubscribeID_) { @@ -1552,17 +1863,21 @@ MoQSession::subscribe(SubscribeRequest sub) { controlWriteEvent_.signal(); auto res = subIdToTrackAlias_.emplace(subID, trackAlias); XCHECK(res.second) << "Duplicate subscribe ID"; + auto trackHandle = std::make_shared( + fullTrackName, subID, getEventBase(), callback, nullptr); auto subTrack = subTracks_.emplace( std::piecewise_construct, std::forward_as_tuple(trackAlias), - std::forward_as_tuple(std::make_shared( - fullTrackName, subID, evb_, cancellationSource_.getToken()))); + std::forward_as_tuple(trackHandle)); - auto trackHandle = subTrack.first->second; auto res2 = co_await trackHandle->ready(); XLOG(DBG1) << "Subscribe ready trackHandle=" << trackHandle << " subscribeID=" << subID; - co_return res2; + if (res2.hasValue()) { + co_return std::move(res2.value()); + } else { + co_return folly::makeUnexpected(res2.error()); + } } std::shared_ptr MoQSession::subscribeOk(SubscribeOk subOk) { @@ -1602,7 +1917,7 @@ void MoQSession::subscribeError(SubscribeError subErr) { } pubTracks_.erase(it); auto res = writeSubscribeError(controlWriteBuf_, std::move(subErr)); - retireSubscribeId(/*signal=*/false); + retireSubscribeId(/*signalWriteLoop=*/false); if (!res) { XLOG(ERR) << "writeSubscribeError failed sess=" << this; return; @@ -1631,6 +1946,7 @@ void MoQSession::unsubscribe(Unsubscribe unsubscribe) { << " sess=" << this; // if there are open streams for this subscription, we should STOP_SENDING // them? + trackIt->second->removeCallback(); auto res = writeUnsubscribe(controlWriteBuf_, std::move(unsubscribe)); if (!res) { XLOG(ERR) << "writeUnsubscribe failed sess=" << this; @@ -1665,7 +1981,7 @@ void MoQSession::subscribeDone(SubscribeDone subDone) { return; } - retireSubscribeId(/*signal=*/false); + retireSubscribeId(/*signalWriteLoop=*/false); controlWriteEvent_.signal(); } @@ -1679,7 +1995,7 @@ void MoQSession::retireSubscribeId(bool signal) { } } -void MoQSession::sendMaxSubscribeID(bool signal) { +void MoQSession::sendMaxSubscribeID(bool signalWriteLoop) { XLOG(DBG1) << "Issuing new maxSubscribeID=" << maxSubscribeID_ << " sess=" << this; auto res = @@ -1688,7 +2004,7 @@ void MoQSession::sendMaxSubscribeID(bool signal) { XLOG(ERR) << "writeMaxSubscribeId failed sess=" << this; return; } - if (signal) { + if (signalWriteLoop) { controlWriteEvent_.signal(); } } @@ -1735,9 +2051,9 @@ void MoQSession::subscribeUpdate(SubscribeUpdate subUpdate) { controlWriteEvent_.signal(); } -folly::coro::Task< - folly::Expected, FetchError>> -MoQSession::fetch(Fetch fetch) { +folly::coro::Task> MoQSession::fetch( + Fetch fetch, + std::shared_ptr fetchCallback) { XLOG(DBG1) << __func__ << " sess=" << this; auto g = folly::makeGuard([func = __func__] { XLOG(DBG1) << "exit " << func; }); @@ -1757,13 +2073,13 @@ MoQSession::fetch(Fetch fetch) { FetchError({subID, 500, "local write failed"})); } controlWriteEvent_.signal(); + auto trackHandle = std::make_shared( + fullTrackName, subID, getEventBase(), nullptr, fetchCallback); auto subTrack = fetches_.emplace( std::piecewise_construct, std::forward_as_tuple(subID), - std::forward_as_tuple(std::make_shared( - fullTrackName, subID, evb_, cancellationSource_.getToken()))); + std::forward_as_tuple(trackHandle)); - auto trackHandle = subTrack.first->second; trackHandle->setNewObjectTimeout(std::chrono::seconds(2)); auto res = co_await trackHandle->fetchReady(); XLOG(DBG1) << __func__ << " fetchReady trackHandle=" << trackHandle; @@ -1830,6 +2146,7 @@ void MoQSession::fetchCancel(FetchCancel fetchCan) { << " sess=" << this; return; } + trackIt->second->removeCallback(); auto res = writeFetchCancel(controlWriteBuf_, std::move(fetchCan)); if (!res) { XLOG(ERR) << "writeFetchCancel failed sess=" << this; @@ -1848,7 +2165,7 @@ void MoQSession::onNewUniStream(proxygen::WebTransport::StreamReadHandle* rh) { // maybe not STREAM_HEADER_SUBGROUP, but at least not control co_withCancellation( cancellationSource_.getToken(), - readLoop(StreamType::STREAM_HEADER_SUBGROUP, rh)) + unidirectionalReadLoop(shared_from_this(), rh)) .scheduleOn(evb_) .start(); } @@ -1863,8 +2180,7 @@ void MoQSession::onNewBidiStream(proxygen::WebTransport::BidiStreamHandle bh) { } else { bh.writeHandle->setPriority(0, 0, false); co_withCancellation( - cancellationSource_.getToken(), - readLoop(StreamType::CONTROL, bh.readHandle)) + cancellationSource_.getToken(), controlReadLoop(bh.readHandle)) .scheduleOn(evb_) .start(); auto mergeToken = folly::CancellationToken::merge( @@ -1903,12 +2219,9 @@ void MoQSession::onDatagram(std::unique_ptr datagram) { readBuf.trimStart(dgLength - remainingLength); auto alias = std::get_if(&res->trackIdentifier); XCHECK(alias); - auto track = getTrack(*alias); - if (track) { - auto groupID = res->group; - auto objID = res->id; - track->onObjectHeader(std::move(*res)); - track->onObjectPayload(groupID, objID, readBuf.move(), true); + auto track = getSubscribeTrack(*alias); + if (track && track->callback_) { + track->callback_->datagram(std::move(*res), readBuf.move()); } } diff --git a/moxygen/MoQSession.h b/moxygen/MoQSession.h index a2c15bb..0d58db6 100644 --- a/moxygen/MoQSession.h +++ b/moxygen/MoQSession.h @@ -24,8 +24,8 @@ namespace moxygen { class MoQSession : public MoQControlCodec::ControlCallback, - public MoQObjectStreamCodec::ObjectCallback, - public proxygen::WebTransportHandler { + public proxygen::WebTransportHandler, + public std::enable_shared_from_this { public: class ServerSetupCallback { public: @@ -85,7 +85,6 @@ class MoQSession : public MoQControlCodec::ControlCallback, SubscribeRequest, SubscribeUpdate, Unsubscribe, - SubscribeDone, Fetch, TrackStatusRequest, TrackStatus, @@ -136,10 +135,6 @@ class MoQSession : public MoQControlCodec::ControlCallback, XLOG(INFO) << "SubscribeUpdate subID=" << subscribeUpdate.subscribeID; } - virtual void operator()(SubscribeDone subscribeDone) const { - XLOG(INFO) << "SubscribeDone subID=" << subscribeDone.subscribeID; - } - virtual void operator()(Unsubscribe unsubscribe) const { XLOG(INFO) << "Unsubscribe subID=" << unsubscribe.subscribeID; } @@ -189,21 +184,24 @@ class MoQSession : public MoQControlCodec::ControlCallback, class TrackHandle { public: + using SubscribeResult = folly::Expected; + TrackHandle( FullTrackName fullTrackName, SubscribeID subscribeID, folly::EventBase* evb, - folly::CancellationToken token) - : fullTrackName_(std::move(fullTrackName)), + std::shared_ptr callback, + std::shared_ptr fetchCallback) + : callback_(std::move(callback)), + fetchCallback_(std::move(fetchCallback)), + fullTrackName_(std::move(fullTrackName)), subscribeID_(subscribeID), - evb_(evb), - cancelToken_(std::move(token)) { - auto contract = folly::coro::makePromiseContract< - folly::Expected, SubscribeError>>(); + evb_(evb) { + auto contract = folly::coro::makePromiseContract(); promise_ = std::move(contract.first); future_ = std::move(contract.second); auto contract2 = folly::coro::makePromiseContract< - folly::Expected, FetchError>>(); + folly::Expected>(); fetchPromise_ = std::move(contract2.first); fetchFuture_ = std::move(contract2.second); } @@ -216,7 +214,7 @@ class MoQSession : public MoQControlCodec::ControlCallback, return fullTrackName_; } - SubscribeID subscribeID() const { + [[nodiscard]] SubscribeID subscribeID() const { return subscribeID_; } @@ -224,103 +222,57 @@ class MoQSession : public MoQControlCodec::ControlCallback, objectTimeout_ = objectTimeout; } - [[nodiscard]] folly::CancellationToken getCancelToken() const { - return cancelToken_; + folly::CancellationToken getCancelToken() { + return cancelSource_.getToken(); } - void mergeReadCancelToken(folly::CancellationToken readToken) { - cancelToken_ = folly::CancellationToken::merge(cancelToken_, readToken); + folly::coro::Task ready() { + co_return co_await std::move(future_); } - void fin(); - - folly::coro::Task< - folly::Expected, SubscribeError>> - ready() { - co_return co_await std::move(future_); + void removeCallback() { + callback_.reset(); + fetchCallback_.reset(); + cancelSource_.requestCancellation(); } - void subscribeOK( - std::shared_ptr self, - GroupOrder order, - folly::Optional latest) { - XCHECK_EQ(self.get(), this); - groupOrder_ = order; - latest_ = std::move(latest); - promise_.setValue(std::move(self)); + void subscribeOK(SubscribeOk subscribeOK) { + groupOrder_ = subscribeOK.groupOrder; + latest_ = subscribeOK.latest; + promise_.setValue(std::move(subscribeOK)); } void subscribeError(SubscribeError subErr) { + XLOG(DBG1) << __func__ << " trackHandle=" << this; if (!promise_.isFulfilled()) { subErr.subscribeID = subscribeID_; promise_.setValue(folly::makeUnexpected(std::move(subErr))); + } else { + subscribeDone( + {subscribeID_, + SubscribeDoneStatusCode::INTERNAL_ERROR, + "closed locally", + folly::none}); } } - folly::coro::Task, FetchError>> - fetchReady() { + void subscribeDone(SubscribeDone subDone) { + XLOG(DBG1) << __func__ << " trackHandle=" << this; + if (callback_) { + callback_->subscribeDone(std::move(subDone)); + } // else, unsubscribe raced with subscribeDone and callback was removed + } + + folly::coro::Task> fetchReady() { co_return co_await std::move(fetchFuture_); } - void fetchOK(std::shared_ptr self) { - XCHECK_EQ(self.get(), this); + void fetchOK() { XLOG(DBG1) << __func__ << " trackHandle=" << this; - fetchPromise_.setValue(std::move(self)); + fetchPromise_.setValue(subscribeID_); } void fetchError(FetchError fetchErr) { if (!promise_.isFulfilled()) { fetchErr.subscribeID = subscribeID_; fetchPromise_.setValue(folly::makeUnexpected(std::move(fetchErr))); - } - } - - struct ObjectSource { - ObjectHeader header; - FullTrackName fullTrackName; - folly::CancellationToken cancelToken; - - folly::coro::UnboundedQueue, true, true> - payloadQueue; - - folly::coro::Task> payload() { - if (header.status != ObjectStatus::NORMAL) { - co_return nullptr; - } - folly::IOBufQueue payloadBuf{folly::IOBufQueue::cacheChainLength()}; - auto curCancelToken = - co_await folly::coro::co_current_cancellation_token; - auto mergeToken = - folly::CancellationToken::merge(curCancelToken, cancelToken); - while (!curCancelToken.isCancellationRequested()) { - std::unique_ptr buf; - auto optionalBuf = payloadQueue.try_dequeue(); - if (optionalBuf) { - buf = std::move(*optionalBuf); - } else { - buf = co_await folly::coro::co_withCancellation( - cancelToken, payloadQueue.dequeue()); - } - if (!buf) { - break; - } - payloadBuf.append(std::move(buf)); - } - co_return payloadBuf.move(); - } - }; - - void onObjectHeader(ObjectHeader objHeader); - void onObjectPayload( - uint64_t groupId, - uint64_t id, - std::unique_ptr payload, - bool eom); - - folly::coro::AsyncGenerator> objects(); - - GroupOrder groupOrder() const { - return groupOrder_; - } - - folly::Optional latest() { - return latest_; + } // there's likely a missing case here from shutdown } void setAllDataReceived() { @@ -335,54 +287,40 @@ class MoQSession : public MoQControlCodec::ControlCallback, return fetchPromise_.isFulfilled(); } + std::shared_ptr callback_; + std::shared_ptr fetchCallback_; + private: FullTrackName fullTrackName_; SubscribeID subscribeID_; folly::EventBase* evb_; - using SubscribeResult = - folly::Expected, SubscribeError>; folly::coro::Promise promise_; folly::coro::Future future_; - using FetchResult = - folly::Expected, FetchError>; + using FetchResult = folly::Expected; folly::coro::Promise fetchPromise_; folly::coro::Future fetchFuture_; - folly:: - F14FastMap, std::shared_ptr> - objects_; - folly::coro::UnboundedQueue, true, true> - newObjects_; GroupOrder groupOrder_; folly::Optional latest_; - folly::CancellationToken cancelToken_; std::chrono::milliseconds objectTimeout_{std::chrono::hours(24)}; + folly::CancellationSource cancelSource_; bool allDataReceived_{false}; }; - folly::coro::Task< - folly::Expected, SubscribeError>> - subscribe(SubscribeRequest sub); + folly::coro::Task subscribe( + SubscribeRequest sub, + std::shared_ptr callback); std::shared_ptr subscribeOk(SubscribeOk subOk); void subscribeError(SubscribeError subErr); void unsubscribe(Unsubscribe unsubscribe); void subscribeUpdate(SubscribeUpdate subUpdate); - folly::coro::Task, FetchError>> - fetch(Fetch fetch); + folly::coro::Task> fetch( + Fetch fetch, + std::shared_ptr fetchCallback); std::shared_ptr fetchOk(FetchOk fetchOk); void fetchError(FetchError fetchError); void fetchCancel(FetchCancel fetchCancel); - class WebTransportException : public std::runtime_error { - public: - explicit WebTransportException( - proxygen::WebTransport::ErrorCode error, - const std::string& msg) - : std::runtime_error(msg), errorCode(error) {} - - proxygen::WebTransport::ErrorCode errorCode; - }; - class PublisherImpl { public: PublisherImpl( @@ -440,28 +378,24 @@ class MoQSession : public MoQControlCodec::ControlCallback, close(); } + std::shared_ptr getFetchTrack(SubscribeID subscribeID); + std::shared_ptr getSubscribeTrack(TrackAlias trackAlias); + private: void cleanup(); folly::coro::Task controlWriteLoop( proxygen::WebTransport::StreamWriteHandle* writeHandle); - folly::coro::Task readLoop( - StreamType streamType, + folly::coro::Task controlReadLoop( + proxygen::WebTransport::StreamReadHandle* readHandle); + folly::coro::Task unidirectionalReadLoop( + std::shared_ptr session, proxygen::WebTransport::StreamReadHandle* readHandle); - std::shared_ptr getTrack(TrackIdentifier trackidentifier); void subscribeDone(SubscribeDone subDone); void onClientSetup(ClientSetup clientSetup) override; void onServerSetup(ServerSetup setup) override; - void onObjectHeader(ObjectHeader objectHeader) override; - void onObjectPayload( - TrackIdentifier trackIdentifier, - uint64_t groupID, - uint64_t id, - std::unique_ptr payload, - bool eom) override; - void onFetchHeader(uint64_t) override {} void onSubscribe(SubscribeRequest subscribeRequest) override; void onSubscribeUpdate(SubscribeUpdate subscribeUpdate) override; void onSubscribeOk(SubscribeOk subscribeOk) override; diff --git a/moxygen/relay/MoQRelay.cpp b/moxygen/relay/MoQRelay.cpp index dffdcbb..c4680f4 100644 --- a/moxygen/relay/MoQRelay.cpp +++ b/moxygen/relay/MoQRelay.cpp @@ -180,25 +180,22 @@ folly::coro::Task MoQRelay::onSubscribe( // TODO: we only subscribe with the downstream locations. subReq.priority = 1; subReq.groupOrder = GroupOrder::Default; - auto subRes = co_await upstreamSession->subscribe(subReq); + forwarder = + std::make_shared(subReq.fullTrackName, folly::none); + // TODO: there's a race condition that the forwarder gets upstream objects + // before we add the downstream subscriber to it, below + auto subRes = co_await upstreamSession->subscribe(subReq, forwarder); if (subRes.hasError()) { session->subscribeError({subReq.subscribeID, 502, "subscribe failed"}); co_return; } - forwarder = std::make_shared( - subReq.fullTrackName, subRes.value()->latest()); - forwarder->setGroupOrder(subRes.value()->groupOrder()); - RelaySubscription rsub( - {forwarder, - upstreamSession, - (*subRes)->subscribeID(), - folly::CancellationSource()}); - auto token = rsub.cancellationSource.getToken(); + auto latest = subRes->latest; + if (latest) { + forwarder->updateLatest(latest->group, latest->object); + } + forwarder->setGroupOrder(subRes->groupOrder); + RelaySubscription rsub({forwarder, upstreamSession, subRes->subscribeID}); subscriptions_[subReq.fullTrackName] = std::move(rsub); - folly::coro::co_withCancellation( - token, forwardTrack(subRes.value(), forwarder)) - .scheduleOn(upstreamSession->getEventBase()) - .start(); } else { forwarder = subscriptionIt->second.forwarder; } @@ -223,52 +220,6 @@ folly::coro::Task MoQRelay::onSubscribe( } } -folly::coro::Task MoQRelay::forwardTrack( - std::shared_ptr track, - std::shared_ptr forwarder) { - while (auto obj = co_await track->objects().next()) { - XLOG(DBG1) << __func__ << " new object t=" << obj.value()->fullTrackName - << " g=" << obj.value()->header.group - << " o=" << obj.value()->header.id; - folly::IOBufQueue payloadBuf{folly::IOBufQueue::cacheChainLength()}; - bool eom = false; - // TODO: this is wrong - we're publishing each object in it's own subgroup - // stream now - auto res = forwarder->beginSubgroup( - obj.value()->header.group, - obj.value()->header.subgroup, - obj.value()->header.priority); - if (!res) { - XLOG(ERR) << "Failed to begin forwarding subgroup"; - // TODO: error - } - auto subgroupPub = std::move(res.value()); - subgroupPub->beginObject( - obj.value()->header.id, *obj.value()->header.length, nullptr); - while (!eom) { - auto payload = co_await obj.value()->payloadQueue.dequeue(); - if (payload) { - payloadBuf.append(std::move(payload)); - XLOG(DBG1) << __func__ - << " object bytes, buflen now=" << payloadBuf.chainLength(); - } else { - XLOG(DBG1) << __func__ - << " object eom, buflen now=" << payloadBuf.chainLength(); - eom = true; - } - auto payloadLength = payloadBuf.chainLength(); - if (eom || payloadLength > 1280) { - subgroupPub->objectPayload(payloadBuf.move(), eom); - } else { - XLOG(DBG1) << __func__ - << " Not publishing yet payloadLength=" << payloadLength - << " eom=" << uint64_t(eom); - } - } - subgroupPub.reset(); - } -} - void MoQRelay::onUnsubscribe( Unsubscribe unsub, std::shared_ptr session) { @@ -285,7 +236,6 @@ void MoQRelay::onUnsubscribe( subscription.forwarder->latest()}); if (subscription.forwarder->empty()) { XLOG(INFO) << "Removed last subscriber for " << subscriptionIt->first; - subscription.cancellationSource.requestCancellation(); subscription.upstream->unsubscribe({subscription.subscribeID}); subscriptionIt = subscriptions_.erase(subscriptionIt); } else { @@ -342,7 +292,6 @@ void MoQRelay::removeSession(const std::shared_ptr& session) { SubscribeDoneStatusCode::SUBSCRIPTION_ENDED, "upstream disconnect", subscription.forwarder->latest()}); - subscription.cancellationSource.requestCancellation(); } else { subscription.forwarder->removeSession(session); } diff --git a/moxygen/relay/MoQRelay.h b/moxygen/relay/MoQRelay.h index c5a9a06..601c867 100644 --- a/moxygen/relay/MoQRelay.h +++ b/moxygen/relay/MoQRelay.h @@ -54,11 +54,7 @@ class MoQRelay { std::shared_ptr forwarder; std::shared_ptr upstream; SubscribeID subscribeID; - folly::CancellationSource cancellationSource; }; - folly::coro::Task forwardTrack( - std::shared_ptr track, - std::shared_ptr forwarder); TrackNamespace allowedNamespacePrefix_; folly::F14FastMap diff --git a/moxygen/samples/chat/MoQChatClient.cpp b/moxygen/samples/chat/MoQChatClient.cpp index 227bec0..49d2924 100644 --- a/moxygen/samples/chat/MoQChatClient.cpp +++ b/moxygen/samples/chat/MoQChatClient.cpp @@ -5,6 +5,7 @@ */ #include "moxygen/samples/chat/MoQChatClient.h" +#include "moxygen/ObjectReceiver.h" #include #include @@ -112,12 +113,6 @@ folly::coro::Task MoQChatClient::controlReadLoop() { latest}); } - void operator()(SubscribeDone subDone) const override { - XLOG(INFO) << "SubscribeDone is=" << subDone.subscribeID; - client_.subscribeDone(std::move(subDone)); - // TODO: should be handled in session - } - void operator()(Unsubscribe unsubscribe) const override { XLOG(INFO) << "Unsubscribe id=" << unsubscribe.subscribeID; if (client_.chatSubscribeID_ && @@ -224,6 +219,43 @@ folly::coro::Task MoQChatClient::subscribeToUser( &userTracks.emplace_back(UserTrack({deviceId, timestamp, 0})); } // now subscribe and update timestamp. + class ChatObjectHandler : public ObjectReceiverCallback { + public: + explicit ChatObjectHandler(MoQChatClient& client, std::string username) + : client_(client), username_(username) {} + ~ChatObjectHandler() override = default; + FlowControlState onObject(const ObjectHeader&, Payload payload) override { + if (payload) { + std::cout << username_ << ": "; + payload->coalesce(); + std::cout << payload->moveToFbString() << std::endl; + } + return FlowControlState::UNBLOCKED; + } + void onObjectStatus(const ObjectHeader&) override {} + void onEndOfStream() override {} + void onError(ResetStreamErrorCode error) override { + std::cout << "Stream Error=" << folly::to_underlying(error) << std::endl; + } + + void onSubscribeDone(SubscribeDone subDone) override { + XLOG(INFO) << "SubscribeDone: " << subDone.reasonPhrase; + if (subDone.statusCode != SubscribeDoneStatusCode::UNSUBSCRIBED && + client_.moqClient_.moqSession_) { + client_.moqClient_.moqSession_->unsubscribe({subDone.subscribeID}); + } + client_.subscribeDone(std::move(subDone)); + baton.post(); + } + + folly::coro::Baton baton; + + private: + MoQChatClient& client_; + std::string username_; + }; + ChatObjectHandler handler(*this, username); + auto track = co_await co_awaitTry(moqClient_.moqSession_->subscribe( {0, 0, @@ -233,7 +265,8 @@ folly::coro::Task MoQChatClient::subscribeToUser( LocationType::LatestGroup, folly::none, folly::none, - {}})); + {}}, + std::make_shared(ObjectReceiver::SUBSCRIBE, &handler))); if (track.hasException()) { // subscribe failed XLOG(ERR) << track.exception(); @@ -246,17 +279,9 @@ folly::coro::Task MoQChatClient::subscribeToUser( co_return; } - userTrackPtr->subscribeId = track->value()->subscribeID(); + userTrackPtr->subscribeId = track->value().subscribeID; userTrackPtr->timestamp = timestamp; - while (auto obj = co_await track->value()->objects().next()) { - // how to cancel this loop - auto payload = co_await obj.value()->payload(); - if (payload) { - std::cout << username << ": "; - payload->coalesce(); - std::cout << payload->moveToFbString() << std::endl; - } - } + co_await handler.baton; } void MoQChatClient::subscribeDone(SubscribeDone subDone) { diff --git a/moxygen/samples/flv_streamer_client/MoQFlvStreamerClient.cpp b/moxygen/samples/flv_streamer_client/MoQFlvStreamerClient.cpp index 8e156c8..2dd0410 100644 --- a/moxygen/samples/flv_streamer_client/MoQFlvStreamerClient.cpp +++ b/moxygen/samples/flv_streamer_client/MoQFlvStreamerClient.cpp @@ -261,11 +261,6 @@ class MoQFlvStreamerClient { return; } - void operator()(SubscribeDone /* subscribeDone */) const override { - // Not expecxted to receive this - XLOG(INFO) << "SubscribeDone"; - } - void operator()(Unsubscribe unSubs) const override { XLOG(INFO) << "Unsubscribe"; // Delete subscribe diff --git a/moxygen/samples/text-client/MoQTextClient.cpp b/moxygen/samples/text-client/MoQTextClient.cpp index 57042f7..f903e3e 100644 --- a/moxygen/samples/text-client/MoQTextClient.cpp +++ b/moxygen/samples/text-client/MoQTextClient.cpp @@ -7,6 +7,7 @@ #include #include "moxygen/MoQClient.h" #include "moxygen/MoQLocation.h" +#include "moxygen/ObjectReceiver.h" #include #include @@ -66,6 +67,31 @@ SubParams flags2params() { return result; } +class TextHandler : public ObjectReceiverCallback { + public: + ~TextHandler() override = default; + FlowControlState onObject(const ObjectHeader&, Payload payload) override { + if (payload) { + std::cout << payload->moveToFbString() << std::endl; + } + return FlowControlState::UNBLOCKED; + } + void onObjectStatus(const ObjectHeader& objHeader) override { + std::cout << "ObjectStatus=" << uint32_t(objHeader.status) << std::endl; + } + void onEndOfStream() override {} + void onError(ResetStreamErrorCode error) override { + std::cout << "Stream Error=" << folly::to_underlying(error) << std::endl; + } + + void onSubscribeDone(SubscribeDone) override { + std::cout << __func__ << std::endl; + baton.post(); + } + + folly::coro::Baton baton; +}; + class MoQTextClient { public: MoQTextClient(folly::EventBase* evb, proxygen::URL url, FullTrackName ftn) @@ -92,11 +118,14 @@ class MoQTextClient { sub.locType = LocationType::LatestObject; sub.start = folly::none; sub.end = folly::none; - auto track = co_await moqClient_.moqSession_->subscribe(sub); + subTextHandler_ = std::make_shared( + ObjectReceiver::SUBSCRIBE, &textHandler_); + auto track = + co_await moqClient_.moqSession_->subscribe(sub, subTextHandler_); if (track.hasValue()) { - subscribeID_ = track.value()->subscribeID(); + subscribeID_ = track->subscribeID; XLOG(DBG1) << "subscribeID=" << subscribeID_; - auto latest = track.value()->latest(); + auto latest = track->latest; if (latest) { XLOG(INFO) << "Latest={" << latest->group << ", " << latest->object << "}"; @@ -122,6 +151,8 @@ class MoQTextClient { XLOG(DBG1) << "FETCH start={" << range.start.group << "," << range.start.object << "} end={" << fetchEnd.group << "," << fetchEnd.object << "}"; + fetchTextHandler_ = std::make_shared( + ObjectReceiver::FETCH, &textHandler_); auto fetchTrack = co_await moqClient_.moqSession_->fetch( {SubscribeID(0), sub.fullTrackName, @@ -129,16 +160,13 @@ class MoQTextClient { sub.groupOrder, range.start, fetchEnd, - {}}); + {}}, + fetchTextHandler_); if (fetchTrack.hasError()) { XLOG(ERR) << "Fetch failed err=" << fetchTrack.error().errorCode << " reason=" << fetchTrack.error().reasonPhrase; } else { - XLOG(DBG1) << "subscribeID=" - << fetchTrack.value()->subscribeID(); - readTrack(std::move(fetchTrack.value())) - .scheduleOn(exec) - .start(); + XLOG(DBG1) << "subscribeID=" << fetchTrack.value(); } } } // else we started from current or no content - nothing to FETCH @@ -154,7 +182,6 @@ class MoQTextClient { sub.params}); } } - co_await readTrack(std::move(track.value())); } else { XLOG(INFO) << "SubscribeError id=" << track.error().subscribeID << " code=" << track.error().errorCode @@ -165,10 +192,13 @@ class MoQTextClient { XLOG(ERR) << ex.what(); co_return; } + co_await textHandler_.baton; XLOG(INFO) << __func__ << " done"; } void stop() { + textHandler_.baton.post(); + // TODO: maybe need fetchCancel + fetchTextHandler_.baton.post() moqClient_.moqSession_->unsubscribe({subscribeID_}); moqClient_.moqSession_->close(); } @@ -191,10 +221,6 @@ class MoQTextClient { {subscribeReq.subscribeID, 404, "don't care"}); } - void operator()(SubscribeDone) const override { - XLOG(INFO) << "SubscribeDone"; - } - void operator()(Goaway) const override { XLOG(INFO) << "Goaway"; client_.moqClient_.moqSession_->unsubscribe({client_.subscribeID_}); @@ -214,22 +240,12 @@ class MoQTextClient { } } - folly::coro::Task readTrack( - std::shared_ptr track) { - XLOG(INFO) << __func__; - auto g = - folly::makeGuard([func = __func__] { XLOG(INFO) << "exit " << func; }); - // TODO: check track.value()->getCancelToken() - while (auto obj = co_await track->objects().next()) { - auto payload = co_await obj.value()->payload(); - if (payload) { - std::cout << payload->moveToFbString() << std::endl; - } - } - } MoQClient moqClient_; FullTrackName fullTrackName_; SubscribeID subscribeID_{0}; + TextHandler textHandler_; + std::shared_ptr subTextHandler_; + std::shared_ptr fetchTextHandler_; }; } // namespace diff --git a/moxygen/test/MoQCodecTest.cpp b/moxygen/test/MoQCodecTest.cpp index da8f707..0c95c19 100644 --- a/moxygen/test/MoQCodecTest.cpp +++ b/moxygen/test/MoQCodecTest.cpp @@ -62,15 +62,28 @@ TEST(MoQCodec, All) { TEST(MoQCodec, AllObject) { auto allMsgs = moxygen::test::writeAllObjectMessages(); - testing::NiceMock callback; + testing::StrictMock callback; MoQObjectStreamCodec codec(&callback); - EXPECT_CALL(callback, onObjectHeader(testing::_)).Times(2); + EXPECT_CALL( + callback, onSubgroup(testing::_, testing::_, testing::_, testing::_)); + EXPECT_CALL( + callback, + onObjectBegin( + testing::_, + testing::_, + testing::_, + testing::_, + testing::_, + true, + false)); EXPECT_CALL( callback, - onObjectPayload( - testing::_, testing::_, testing::_, testing::_, testing::_)) - .Times(1); + onObjectStatus( + testing::_, + testing::_, + testing::_, + ObjectStatus::END_OF_TRACK_AND_GROUP)); codec.onIngress(std::move(allMsgs), true); } @@ -129,11 +142,19 @@ TEST(MoQCodec, UnderflowObjects) { folly::IOBufQueue readBuf{folly::IOBufQueue::cacheChainLength()}; readBuf.append(std::move(allMsgs)); - EXPECT_CALL(callback, onObjectHeader(testing::_)).Times(2); + EXPECT_CALL( + callback, onSubgroup(testing::_, testing::_, testing::_, testing::_)); EXPECT_CALL( callback, - onObjectPayload( - testing::_, testing::_, testing::_, testing::_, testing::_)) + onObjectBegin( + testing::_, + testing::_, + testing::_, + testing::_, + testing::_, + testing::_, + testing::_)); + EXPECT_CALL(callback, onObjectPayload(testing::_, testing::_)) .Times(strlen("hello world")); while (!readBuf.empty()) { codec.onIngress(readBuf.split(1), false); @@ -141,6 +162,29 @@ TEST(MoQCodec, UnderflowObjects) { codec.onIngress(nullptr, true); } +TEST(MoQCodec, ObjectStreamPayloadFin) { + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + writeSingleObjectStream( + writeBuf, + {TrackAlias(1), + 2, + 3, + 4, + 5, + ForwardPreference::Subgroup, + ObjectStatus::NORMAL, + 11}, + folly::IOBuf::copyBuffer("hello world")); + testing::StrictMock callback; + MoQObjectStreamCodec codec(&callback); + + EXPECT_CALL(callback, onSubgroup(TrackAlias(1), 2, 3, 5)); + EXPECT_CALL( + callback, onObjectBegin(2, 3, 4, testing::_, testing::_, true, true)); + + codec.onIngress(writeBuf.move(), true); +} + TEST(MoQCodec, ObjectStreamPayload) { folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; writeSingleObjectStream( @@ -157,14 +201,12 @@ TEST(MoQCodec, ObjectStreamPayload) { testing::NiceMock callback; MoQObjectStreamCodec codec(&callback); - EXPECT_CALL(callback, onObjectHeader(testing::_)); + EXPECT_CALL(callback, onSubgroup(TrackAlias(1), 2, 3, 5)); EXPECT_CALL( - callback, - onObjectPayload( - testing::_, testing::_, testing::_, testing::_, testing::_)) - .Times(1); + callback, onObjectBegin(2, 3, 4, testing::_, testing::_, true, false)); codec.onIngress(writeBuf.move(), false); + EXPECT_CALL(callback, onEndOfStream()); codec.onIngress(std::unique_ptr(), true); } @@ -184,8 +226,10 @@ TEST(MoQCodec, EmptyObjectPayload) { testing::NiceMock callback; MoQObjectStreamCodec codec(&callback); - EXPECT_CALL(callback, onObjectHeader(testing::_)); - + EXPECT_CALL(callback, onSubgroup(TrackAlias(1), 2, 3, 5)); + EXPECT_CALL( + callback, onObjectStatus(2, 3, 4, ObjectStatus::OBJECT_NOT_EXIST)); + EXPECT_CALL(callback, onEndOfStream()); // extra coverage of underflow in header codec.onIngress(writeBuf.split(3), false); codec.onIngress(writeBuf.move(), false); @@ -221,11 +265,130 @@ TEST(MoQCodec, TruncatedObject) { testing::NiceMock callback; MoQObjectStreamCodec codec(&callback); - EXPECT_CALL(callback, onObjectHeader(testing::_)); + EXPECT_CALL( + callback, onSubgroup(testing::_, testing::_, testing::_, testing::_)); + EXPECT_CALL(callback, onConnectionError(testing::_)); codec.onIngress(writeBuf.move(), true); } +TEST(MoQCodec, TruncatedObjectPayload) { + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + auto res = writeStreamHeader( + writeBuf, + ObjectHeader({ + TrackAlias(1), + 2, + 3, + 4, + 5, + ForwardPreference::Subgroup, + ObjectStatus::NORMAL, + folly::none, + })); + res = writeObject( + writeBuf, + ObjectHeader( + {TrackAlias(1), + 2, + 3, + 4, + 5, + ForwardPreference::Subgroup, + ObjectStatus::NORMAL, + 11}), + nullptr); + testing::NiceMock callback; + MoQObjectStreamCodec codec(&callback); + + EXPECT_CALL( + callback, onSubgroup(testing::_, testing::_, testing::_, testing::_)); + + EXPECT_CALL( + callback, onObjectBegin(2, 3, 4, testing::_, testing::_, false, false)); + codec.onIngress(writeBuf.move(), false); + EXPECT_CALL(callback, onConnectionError(testing::_)); + writeBuf.append(std::string("hello")); + codec.onIngress(writeBuf.move(), true); +} + +TEST(MoQCodec, StreamTypeUnderflow) { + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + uint8_t big = 0xff; + writeBuf.append(&big, 1); + testing::NiceMock callback; + MoQObjectStreamCodec codec(&callback); + + EXPECT_CALL(callback, onConnectionError(ErrorCode::PARSE_UNDERFLOW)); + codec.onIngress(writeBuf.move(), true); +} + +TEST(MoQCodec, UnknownStreamType) { + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + uint8_t bad = 0x12; + writeBuf.append(&bad, 1); + testing::NiceMock callback; + MoQObjectStreamCodec codec(&callback); + + EXPECT_CALL(callback, onConnectionError(ErrorCode::PARSE_ERROR)); + codec.onIngress(writeBuf.move(), true); +} + +TEST(MoQCodec, Fetch) { + testing::StrictMock callback; + MoQObjectStreamCodec codec(&callback); + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + ObjectHeader obj{ + SubscribeID(1), + 2, + 3, + 4, + 5, + ForwardPreference::Fetch, + ObjectStatus::NORMAL, + folly::none, + }; + auto res = writeStreamHeader(writeBuf, ObjectHeader(obj)); + obj.length = 5; + res = writeObject(writeBuf, obj, folly::IOBuf::copyBuffer("hello")); + obj.id++; + obj.status = ObjectStatus::END_OF_TRACK_AND_GROUP; + obj.length = 0; + res = writeObject(writeBuf, obj, nullptr); + obj.id++; + obj.status = ObjectStatus::GROUP_NOT_EXIST; + obj.length = 0; + res = writeObject(writeBuf, obj, nullptr); + + EXPECT_CALL(callback, onFetchHeader(testing::_)); + EXPECT_CALL(callback, onObjectBegin(2, 3, 4, 5, testing::_, true, false)); + EXPECT_CALL( + callback, onObjectStatus(2, 3, 5, ObjectStatus::END_OF_TRACK_AND_GROUP)); + // object after terminal status + EXPECT_CALL(callback, onConnectionError(ErrorCode::PARSE_ERROR)); + codec.onIngress(writeBuf.move(), false); +} + +TEST(MoQCodec, FetchHeaderUnderflow) { + testing::StrictMock callback; + MoQObjectStreamCodec codec(&callback); + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + ObjectHeader obj{ + SubscribeID(0xffffffffffffff), + 2, + 3, + 4, + 5, + ForwardPreference::Fetch, + ObjectStatus::NORMAL, + folly::none, + }; + auto res = writeStreamHeader(writeBuf, ObjectHeader(obj)); + // only deliver first byte of fetch header + EXPECT_CALL(callback, onConnectionError(ErrorCode::PARSE_UNDERFLOW)); + codec.onIngress(writeBuf.splitAtMost(2), true); +} + TEST(MoQCodec, InvalidFrame) { folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; writeBuf.append(std::string(" ")); diff --git a/moxygen/test/MoQSessionTest.cpp b/moxygen/test/MoQSessionTest.cpp index fceb755..ebc6285 100644 --- a/moxygen/test/MoQSessionTest.cpp +++ b/moxygen/test/MoQSessionTest.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include using namespace moxygen; @@ -24,7 +25,6 @@ class MockControlVisitorBase { virtual ~MockControlVisitorBase() = default; virtual void onSubscribe(SubscribeRequest subscribeRequest) const = 0; virtual void onSubscribeUpdate(SubscribeUpdate subscribeUpdate) const = 0; - virtual void onSubscribeDone(SubscribeDone subscribeDone) const = 0; virtual void onUnsubscribe(Unsubscribe unsubscribe) const = 0; virtual void onFetch(Fetch fetch) const = 0; virtual void onAnnounce(Announce announce) const = 0; @@ -80,11 +80,6 @@ class MockControlVisitor : public MoQSession::ControlVisitor, onSubscribeUpdate(subscribeUpdate); } - MOCK_METHOD(void, onSubscribeDone, (SubscribeDone), (const)); - void operator()(SubscribeDone subscribeDone) const override { - onSubscribeDone(subscribeDone); - } - MOCK_METHOD(void, onUnsubscribe, (Unsubscribe), (const)); void operator()(Unsubscribe unsubscribe) const override { onUnsubscribe(unsubscribe); @@ -205,28 +200,34 @@ TEST_F(MoQSessionTest, Setup) { clientSession_->close(); } +MATCHER_P(HasChainDataLengthOf, n, "") { + return arg->computeChainDataLength() == n; +} + TEST_F(MoQSessionTest, Fetch) { setupMoQSession(); auto f = [](std::shared_ptr session) mutable -> folly::coro::Task { - auto handle = co_await session->fetch( + auto fetchCallback = + std::make_shared>(); + folly::coro::Baton baton; + EXPECT_CALL( + *fetchCallback, object(0, 0, 0, HasChainDataLengthOf(100), true)) + .WillOnce(testing::Invoke([&] { + baton.post(); + return folly::unit; + })); + auto res = co_await session->fetch( {SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 1}, - {}}); - EXPECT_TRUE(handle.hasValue()); - auto obj = co_await handle.value()->objects().next(); - EXPECT_NE(obj.value(), nullptr); - EXPECT_EQ( - *std::get_if(&obj.value()->header.trackIdentifier), - SubscribeID(0)); - auto payload = co_await obj.value()->payload(); - EXPECT_EQ(payload->computeChainDataLength(), 100); - obj = co_await handle.value()->objects().next(); - EXPECT_FALSE(obj.has_value()); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); + co_await baton; session->close(); }; EXPECT_CALL(serverControl, onFetch(testing::_)) @@ -258,15 +259,18 @@ TEST_F(MoQSessionTest, FetchCleanupFromStreamFin) { std::shared_ptr serverSession, std::shared_ptr& fetchPub) mutable -> folly::coro::Task { - auto handle = co_await session->fetch( + auto fetchCallback = + std::make_shared>(); + auto res = co_await session->fetch( {SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 1}, - {}}); - EXPECT_TRUE(handle.hasValue()); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); // publish here now we know FETCH_OK has been received at client XCHECK(fetchPub); fetchPub->object( @@ -275,16 +279,14 @@ TEST_F(MoQSessionTest, FetchCleanupFromStreamFin) { /*objectID=*/0, moxygen::test::makeBuf(100), /*finFetch=*/true); - - auto obj = co_await handle.value()->objects().next(); - EXPECT_NE(obj.value(), nullptr); - EXPECT_EQ( - *std::get_if(&obj.value()->header.trackIdentifier), - SubscribeID(0)); - auto payload = co_await obj.value()->payload(); - EXPECT_EQ(payload->computeChainDataLength(), 100); - obj = co_await handle.value()->objects().next(); - EXPECT_FALSE(obj.has_value()); + folly::coro::Baton baton; + EXPECT_CALL( + *fetchCallback, object(0, 0, 0, HasChainDataLengthOf(100), true)) + .WillOnce(testing::Invoke([&] { + baton.post(); + return folly::unit; + })); + co_await baton; session->close(); }; EXPECT_CALL(serverControl, onFetch(testing::_)) @@ -307,17 +309,20 @@ TEST_F(MoQSessionTest, FetchError) { setupMoQSession(); auto f = [](std::shared_ptr session) mutable -> folly::coro::Task { - auto handle = co_await session->fetch( + auto fetchCallback = + std::make_shared>(); + auto res = co_await session->fetch( {SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 1}, AbsoluteLocation{0, 0}, - {}}); - EXPECT_TRUE(handle.hasError()); + {}}, + fetchCallback); + EXPECT_TRUE(res.hasError()); EXPECT_EQ( - handle.error().errorCode, + res.error().errorCode, folly::to_underlying(FetchErrorCode::INVALID_RANGE)); session->close(); }; @@ -332,29 +337,38 @@ TEST_F(MoQSessionTest, FetchCancel) { std::shared_ptr serverSession, std::shared_ptr& fetchPub) mutable -> folly::coro::Task { - auto handle = co_await clientSession->fetch( - {SubscribeID(0), + auto fetchCallback = + std::make_shared>(); + SubscribeID subscribeID(0); + EXPECT_CALL( + *fetchCallback, object(0, 0, 0, HasChainDataLengthOf(100), false)) + .WillOnce(testing::Return(folly::unit)); + // TODO: fetchCancel removes the callback - should it also deliver a + // reset() call to the callback? + // EXPECT_CALL(*fetchCallback, reset(ResetStreamErrorCode::CANCELLED)); + auto res = co_await clientSession->fetch( + {subscribeID, FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 2}, - {}}); - EXPECT_TRUE(handle.hasValue()); - auto subscribeID = handle.value()->subscribeID(); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); clientSession->fetchCancel({subscribeID}); co_await folly::coro::co_reschedule_on_current_executor; co_await folly::coro::co_reschedule_on_current_executor; co_await folly::coro::co_reschedule_on_current_executor; XCHECK(fetchPub); - auto res = fetchPub->object( + auto res2 = fetchPub->object( /*groupID=*/0, /*subgroupID=*/0, /*objectID=*/1, moxygen::test::makeBuf(100), /*finFetch=*/true); // publish after fetchCancel fails - EXPECT_TRUE(res.hasError()); + EXPECT_TRUE(res2.hasError()); clientSession->close(); }; EXPECT_CALL(serverControl, onFetch(testing::_)) @@ -373,7 +387,7 @@ TEST_F(MoQSessionTest, FetchCancel) { /*subgroupID=*/0, fetch.start.object, moxygen::test::makeBuf(100), - true); + false); // published 1 object })); f(clientSession_, serverSession_, fetchPub).scheduleOn(&eventBase_).start(); @@ -384,16 +398,19 @@ TEST_F(MoQSessionTest, FetchEarlyCancel) { setupMoQSession(); auto f = [](std::shared_ptr clientSession) mutable -> folly::coro::Task { - auto handle = co_await clientSession->fetch( - {SubscribeID(0), + auto fetchCallback = + std::make_shared>(); + SubscribeID subscribeID(0); + auto res = co_await clientSession->fetch( + {subscribeID, FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 2}, - {}}); - EXPECT_TRUE(handle.hasValue()); - auto subscribeID = handle.value()->subscribeID(); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); // TODO: this no-ops right now so there's nothing to verify clientSession->fetchCancel({subscribeID}); clientSession->close(); @@ -418,19 +435,32 @@ TEST_F(MoQSessionTest, FetchBadLength) { setupMoQSession(); auto f = [](std::shared_ptr session) mutable -> folly::coro::Task { - auto handle = co_await session->fetch( + auto fetchCallback = + std::make_shared>(); + auto res = co_await session->fetch( {SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 1}, - {}}); - EXPECT_TRUE(handle.hasValue()); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); // FETCH_OK comes but the FETCH stream is reset and we timeout waiting // for a new object. + auto contract = folly::coro::makePromiseContract(); + ON_CALL( + *fetchCallback, + beginObject(testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillByDefault([&] { + contract.first.setValue(); + return folly::Expected(folly::unit); + }); EXPECT_THROW( - co_await handle.value()->objects().next(), folly::FutureTimeout); + co_await folly::coro::timeout( + std::move(contract.second), std::chrono::milliseconds(100)), + folly::FutureTimeout); session->close(); }; EXPECT_CALL(serverControl, onFetch(testing::_)) @@ -461,6 +491,12 @@ TEST_F(MoQSessionTest, FetchOverLimit) { setupMoQSession(); auto f = [](std::shared_ptr session) mutable -> folly::coro::Task { + auto fetchCallback1 = + std::make_shared>(); + auto fetchCallback2 = + std::make_shared>(); + auto fetchCallback3 = + std::make_shared>(); Fetch fetch{ SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, @@ -469,10 +505,10 @@ TEST_F(MoQSessionTest, FetchOverLimit) { AbsoluteLocation{0, 0}, AbsoluteLocation{0, 1}, {}}; - auto handle = co_await session->fetch(fetch); - handle = co_await session->fetch(fetch); - handle = co_await session->fetch(fetch); - EXPECT_TRUE(handle.hasError()); + auto res = co_await session->fetch(fetch, fetchCallback1); + res = co_await session->fetch(fetch, fetchCallback2); + res = co_await session->fetch(fetch, fetchCallback3); + EXPECT_TRUE(res.hasError()); }; EXPECT_CALL(serverControl, onFetch(testing::_)) .WillOnce(testing::Invoke([this](Fetch fetch) { @@ -615,7 +651,13 @@ TEST_F(MoQSessionTest, MaxSubscribeID) { folly::none, folly::none, {}}; - auto res = co_await clientSession->subscribe(sub); + auto trackPublisher1 = + std::make_shared>(); + auto trackPublisher2 = + std::make_shared>(); + auto trackPublisher3 = + std::make_shared>(); + auto res = co_await clientSession->subscribe(sub, trackPublisher1); co_await folly::coro::co_reschedule_on_current_executor; // This is true because initial is 2 in this test case and we grant credit // every 50%. @@ -623,15 +665,21 @@ TEST_F(MoQSessionTest, MaxSubscribeID) { EXPECT_EQ(serverSession->maxSubscribeID(), expectedSubId); // subscribe again but this time we get a DONE - res = co_await clientSession->subscribe(sub); + EXPECT_CALL(*trackPublisher2, subscribeDone(testing::_)) + .WillOnce(testing::Return(folly::unit)); + res = co_await clientSession->subscribe(sub, trackPublisher2); co_await folly::coro::co_reschedule_on_current_executor; expectedSubId++; EXPECT_EQ(serverSession->maxSubscribeID(), expectedSubId); - // subscribe three more times, last one should fail - res = co_await clientSession->subscribe(sub); - res = co_await clientSession->subscribe(sub); - res = co_await clientSession->subscribe(sub); + // subscribe three more times, last one should fail, the first two will get + // subscribeDone via the session closure + EXPECT_CALL(*trackPublisher3, subscribeDone(testing::_)) + .WillOnce(testing::Return(folly::unit)) + .WillOnce(testing::Return(folly::unit)); + res = co_await clientSession->subscribe(sub, trackPublisher3); + res = co_await clientSession->subscribe(sub, trackPublisher3); + res = co_await clientSession->subscribe(sub, trackPublisher3); EXPECT_TRUE(res.hasError()); }(clientSession_, serverSession_) .scheduleOn(&eventBase_) @@ -671,6 +719,5 @@ TEST_F(MoQSessionTest, MaxSubscribeID) { {}}); })); - EXPECT_CALL(clientControl, onSubscribeDone(testing::_)); eventBase_.loop(); } diff --git a/moxygen/test/Mocks.h b/moxygen/test/Mocks.h index fc0b464..9918284 100644 --- a/moxygen/test/Mocks.h +++ b/moxygen/test/Mocks.h @@ -14,16 +14,6 @@ class MockMoQCodecCallback : public MoQControlCodec::ControlCallback, MOCK_METHOD(void, onFrame, (FrameType /*frameType*/)); MOCK_METHOD(void, onClientSetup, (ClientSetup clientSetup)); MOCK_METHOD(void, onServerSetup, (ServerSetup serverSetup)); - MOCK_METHOD(void, onObjectHeader, (ObjectHeader objectHeader)); - MOCK_METHOD( - void, - onObjectPayload, - (TrackIdentifier trackIdentifier, - uint64_t groupID, - uint64_t id, - std::unique_ptr payload, - bool eom)); - MOCK_METHOD(void, onFetchHeader, (uint64_t subscribeID)); MOCK_METHOD(void, onSubscribe, (SubscribeRequest subscribeRequest)); MOCK_METHOD(void, onSubscribeUpdate, (SubscribeUpdate subscribeUpdate)); MOCK_METHOD(void, onSubscribeOk, (SubscribeOk subscribeOk)); @@ -57,6 +47,19 @@ class MockMoQCodecCallback : public MoQControlCodec::ControlCallback, MOCK_METHOD(void, onTrackStatus, (TrackStatus trackStatus)); MOCK_METHOD(void, onGoaway, (Goaway goaway)); MOCK_METHOD(void, onConnectionError, (ErrorCode error)); + + MOCK_METHOD(void, onFetchHeader, (SubscribeID)); + MOCK_METHOD(void, onSubgroup, (TrackAlias, uint64_t, uint64_t, uint8_t)); + MOCK_METHOD( + void, + onObjectBegin, + (uint64_t, uint64_t, uint64_t, uint64_t, Payload, bool, bool)); + MOCK_METHOD( + void, + onObjectStatus, + (uint64_t, uint64_t, uint64_t, ObjectStatus)); + MOCK_METHOD(void, onObjectPayload, (Payload, bool)); + MOCK_METHOD(void, onEndOfStream, ()); }; class MockTrackConsumer : public TrackConsumer {