diff --git a/moxygen/MoQClient.cpp b/moxygen/MoQClient.cpp index 0dd8949..42ecb54 100644 --- a/moxygen/MoQClient.cpp +++ b/moxygen/MoQClient.cpp @@ -190,7 +190,8 @@ void MoQClient::HTTPHandler::onError( void MoQClient::onSessionEnd(folly::Optional) { // TODO: cleanup? XLOG(DBG1) << "resetting moqSession_"; - moqSession_.reset(); + auto moqSession = std::move(moqSession_); + moqSession.reset(); CHECK(!moqSession_); } diff --git a/moxygen/MoQCodec.cpp b/moxygen/MoQCodec.cpp index eb3554e..bcd418a 100644 --- a/moxygen/MoQCodec.cpp +++ b/moxygen/MoQCodec.cpp @@ -103,12 +103,13 @@ void MoQCodec::onIngressEnd( void MoQObjectStreamCodec::onIngress( std::unique_ptr data, - bool eom) { + bool endOfStream) { onIngressStart(std::move(data)); folly::io::Cursor cursor(ingress_.front()); + bool isFetch = std::get_if(&curObjectHeader_.trackIdentifier); while (!connError_ && ((ingress_.chainLength() > 0 && !cursor.isAtEnd())/* || - (eom && parseState_ == ParseState::OBJECT_PAYLOAD_NO_LENGTH)*/)) { + (endOfStream && parseState_ == ParseState::OBJECT_PAYLOAD_NO_LENGTH)*/)) { switch (parseState_) { case ParseState::STREAM_HEADER_TYPE: { auto newCursor = cursor; @@ -146,7 +147,12 @@ void MoQObjectStreamCodec::onIngress( connError_ = res.error(); break; } - curObjectHeader_.trackIdentifier = SubscribeID(res.value()); + auto subscribeID = SubscribeID(res.value()); + curObjectHeader_.trackIdentifier = subscribeID; + isFetch = true; + if (callback_) { + callback_->onFetchHeader(subscribeID); + } parseState_ = ParseState::MULTI_OBJECT_HEADER; cursor = newCursor; break; @@ -160,6 +166,16 @@ void MoQObjectStreamCodec::onIngress( break; } curObjectHeader_ = res.value(); + auto trackAlias = + std::get_if(&curObjectHeader_.trackIdentifier); + XCHECK(trackAlias); + if (callback_) { + callback_->onSubgroup( + *trackAlias, + curObjectHeader_.group, + curObjectHeader_.subgroup, + curObjectHeader_.priority); + } parseState_ = ParseState::MULTI_OBJECT_HEADER; cursor = newCursor; [[fallthrough]]; @@ -174,20 +190,60 @@ void MoQObjectStreamCodec::onIngress( break; } curObjectHeader_ = res.value(); - if (callback_) { - callback_->onObjectHeader(std::move(res.value())); - } cursor = newCursor; if (curObjectHeader_.status == ObjectStatus::NORMAL) { - parseState_ = ParseState::OBJECT_PAYLOAD; + XLOG(DBG2) << "Parsing object with length, need=" + << *curObjectHeader_.length + << " have=" << cursor.totalLength(); + std::unique_ptr payload; + auto chunkLen = cursor.cloneAtMost(payload, *curObjectHeader_.length); + auto endOfObject = chunkLen == *curObjectHeader_.length; + if (endOfStream && !endOfObject) { + connError_ = ErrorCode::PARSE_ERROR; + XLOG(DBG4) << __func__ << " " << uint32_t(*connError_); + break; + } + if (callback_) { + callback_->onObjectBegin( + curObjectHeader_.group, + curObjectHeader_.subgroup, + curObjectHeader_.id, + *curObjectHeader_.length, + std::move(payload), + endOfObject, + endOfStream && cursor.isAtEnd()); + } + *curObjectHeader_.length -= chunkLen; + if (endOfObject) { + if (endOfStream && cursor.isAtEnd()) { + parseState_ = ParseState::STREAM_FIN_DELIVERED; + } else { + parseState_ = ParseState::MULTI_OBJECT_HEADER; + } + break; + } else { + parseState_ = ParseState::OBJECT_PAYLOAD; + } } else { - parseState_ = ParseState::MULTI_OBJECT_HEADER; + if (callback_) { + callback_->onObjectStatus( + curObjectHeader_.group, + curObjectHeader_.subgroup, + curObjectHeader_.id, + curObjectHeader_.status); + } + if (curObjectHeader_.status == ObjectStatus::END_OF_TRACK_AND_GROUP || + (!isFetch && + curObjectHeader_.status == ObjectStatus::END_OF_GROUP)) { + parseState_ = ParseState::STREAM_FIN_DELIVERED; + } else { + parseState_ = ParseState::MULTI_OBJECT_HEADER; + } break; } [[fallthrough]]; } case ParseState::OBJECT_PAYLOAD: { - auto newCursor = cursor; // need to check for bufLen == 0? std::unique_ptr payload; // TODO: skip clone and do split @@ -195,63 +251,41 @@ void MoQObjectStreamCodec::onIngress( XCHECK(curObjectHeader_.length); XLOG(DBG2) << "Parsing object with length, need=" << *curObjectHeader_.length; - if (ingress_.chainLength() > 0 && newCursor.canAdvance(1)) { - chunkLen = newCursor.cloneAtMost(payload, *curObjectHeader_.length); + if (ingress_.chainLength() > 0 && cursor.canAdvance(1)) { + chunkLen = cursor.cloneAtMost(payload, *curObjectHeader_.length); } *curObjectHeader_.length -= chunkLen; - if (eom && *curObjectHeader_.length != 0) { + if (endOfStream && *curObjectHeader_.length != 0) { connError_ = ErrorCode::PARSE_ERROR; XLOG(DBG4) << __func__ << " " << uint32_t(*connError_); break; } bool endOfObject = (*curObjectHeader_.length == 0); if (callback_ && (payload || endOfObject)) { - callback_->onObjectPayload( - curObjectHeader_.trackIdentifier, - curObjectHeader_.group, - curObjectHeader_.id, - std::move(payload), - endOfObject); + callback_->onObjectPayload(std::move(payload), endOfObject); } if (endOfObject) { parseState_ = ParseState::MULTI_OBJECT_HEADER; } - cursor = newCursor; break; } -#if 0 -// This code is no longer reachable, but I'm leaving it here in case -// the wire format changes back - case ParseState::OBJECT_PAYLOAD_NO_LENGTH: { - auto newCursor = cursor; - // need to check for bufLen == 0? - std::unique_ptr payload; - // TODO: skip clone and do split - if (ingress_.chainLength() > 0 && newCursor.canAdvance(1)) { - newCursor.cloneAtMost(payload, std::numeric_limits::max()); - } - XCHECK(!curObjectHeader_.length); - if (callback_ && (payload || eom)) { - callback_->onObjectPayload( - curObjectHeader_.trackIdentifier, - curObjectHeader_.group, - curObjectHeader_.id, - std::move(payload), - eom); - } - if (eom) { - parseState_ = ParseState::FRAME_HEADER_TYPE; - } - cursor = newCursor; + case ParseState::STREAM_FIN_DELIVERED: { + XLOG(DBG2) << "Bytes=" << cursor.totalLength() + << " remaining in STREAM_FIN_DELIVERED"; + connError_ = ErrorCode::PARSE_ERROR; + break; } -#endif } } size_t remainingLength = 0; - if (!eom && !cursor.isAtEnd()) { + if (!endOfStream && !cursor.isAtEnd()) { remainingLength = cursor.totalLength(); // must be less than 1 message } - onIngressEnd(remainingLength, eom, callback_); + if (endOfStream && parseState_ != ParseState::STREAM_FIN_DELIVERED && + !connError_ && callback_) { + callback_->onEndOfStream(); + } + onIngressEnd(remainingLength, endOfStream, callback_); } folly::Expected MoQControlCodec::parseFrame( diff --git a/moxygen/MoQCodec.h b/moxygen/MoQCodec.h index b8acf03..63f88d5 100644 --- a/moxygen/MoQCodec.h +++ b/moxygen/MoQCodec.h @@ -141,15 +141,27 @@ class MoQObjectStreamCodec : public MoQCodec { public: ~ObjectCallback() override = default; - virtual void onFetchHeader(uint64_t subscribeID) = 0; - virtual void onObjectHeader(ObjectHeader objectHeader) = 0; - - virtual void onObjectPayload( - TrackIdentifier trackIdentifier, - uint64_t groupID, - uint64_t id, - std::unique_ptr payload, - bool eom) = 0; + virtual void onFetchHeader(SubscribeID subscribeID) = 0; + virtual void onSubgroup( + TrackAlias alias, + uint64_t group, + uint64_t subgroup, + uint8_t priority) = 0; + virtual void onObjectBegin( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + uint64_t length, + Payload initialPayload, + bool objectComplete, + bool subgroupComplete) = 0; + virtual void onObjectStatus( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + ObjectStatus status) = 0; + virtual void onObjectPayload(Payload payload, bool objectComplete) = 0; + virtual void onEndOfStream() = 0; }; MoQObjectStreamCodec(ObjectCallback* callback) : callback_(callback) {} @@ -160,10 +172,6 @@ class MoQObjectStreamCodec : public MoQCodec { void onIngress(std::unique_ptr data, bool eom) override; - TrackIdentifier getTrackIdentifier() const { - return curObjectHeader_.trackIdentifier; - } - private: enum class ParseState { STREAM_HEADER_TYPE, @@ -171,6 +179,7 @@ class MoQObjectStreamCodec : public MoQCodec { FETCH_HEADER, MULTI_OBJECT_HEADER, OBJECT_PAYLOAD, + STREAM_FIN_DELIVERED, // OBJECT_PAYLOAD_NO_LENGTH }; ParseState parseState_{ParseState::STREAM_HEADER_TYPE}; diff --git a/moxygen/MoQServer.cpp b/moxygen/MoQServer.cpp index b25bf96..f443d0e 100644 --- a/moxygen/MoQServer.cpp +++ b/moxygen/MoQServer.cpp @@ -92,12 +92,6 @@ void MoQServer::ControlVisitor::operator()(Fetch fetch) const { XLOG(INFO) << "Fetch id=" << fetch.subscribeID; } -void MoQServer::ControlVisitor::operator()(SubscribeDone subscribeDone) const { - XLOG(INFO) << "SubscribeDone id=" << subscribeDone.subscribeID - << " code=" << folly::to_underlying(subscribeDone.statusCode) - << " reason=" << subscribeDone.reasonPhrase; -} - void MoQServer::ControlVisitor::operator()(Unsubscribe unsubscribe) const { XLOG(INFO) << "Unsubscribe id=" << unsubscribe.subscribeID; } diff --git a/moxygen/MoQServer.h b/moxygen/MoQServer.h index c774e8c..ef26744 100644 --- a/moxygen/MoQServer.h +++ b/moxygen/MoQServer.h @@ -43,7 +43,6 @@ class MoQServer : public MoQSession::ServerSetupCallback { void operator()(AnnounceCancel announceCancel) const override; void operator()(SubscribeAnnounces subscribeAnnounces) const override; void operator()(UnsubscribeAnnounces unsubscribeAnnounces) const override; - void operator()(SubscribeDone subscribeDone) const override; void operator()(Unsubscribe unsubscribe) const override; void operator()(TrackStatusRequest trackStatusRequest) const override; void operator()(TrackStatus trackStatus) const override; diff --git a/moxygen/MoQSession.cpp b/moxygen/MoQSession.cpp index 6a98d11..8980cbb 100644 --- a/moxygen/MoQSession.cpp +++ b/moxygen/MoQSession.cpp @@ -5,7 +5,7 @@ */ #include "moxygen/MoQSession.h" -#include +#include #include #include @@ -728,15 +728,18 @@ void MoQSession::cleanup() { pubTracks_.clear(); for (auto& subTrack : subTracks_) { subTrack.second->subscribeError( - {/*TrackHandle fills in subId*/ 0, 500, "session closed", folly::none}); + {/*TrackReceiveState fills in subId*/ 0, + 500, + "session closed", + folly::none}); } subTracks_.clear(); for (auto& fetch : fetches_) { - // TODO: there needs to be a way to queue an error in TrackHandle, both - // from here, when close races the FETCH stream, and from readLoop + // TODO: there needs to be a way to queue an error in TrackReceiveState, + // both from here, when close races the FETCH stream, and from readLoop // where we get a reset. fetch.second->fetchError( - {/*TrackHandle fills in subId*/ 0, 500, "session closed"}); + {/*TrackReceiveState fills in subId*/ 0, 500, "session closed"}); } fetches_.clear(); for (auto& [trackNamespace, pendingAnn] : pendingAnnounce_) { @@ -776,7 +779,7 @@ void MoQSession::start() { .start(); co_withCancellation( cancellationSource_.getToken(), - readLoop(StreamType::CONTROL, controlStream.readHandle)) + controlReadLoop(controlStream.readHandle)) .scheduleOn(evb_) .start(); } @@ -947,167 +950,516 @@ MoQSession::controlMessages() { } } -folly::coro::Task MoQSession::readLoop( - StreamType streamType, +folly::coro::Task MoQSession::controlReadLoop( proxygen::WebTransport::StreamReadHandle* readHandle) { XLOG(DBG1) << __func__ << " sess=" << this; auto g = folly::makeGuard([func = __func__, this] { XLOG(DBG1) << "exit " << func << " sess=" << this; }); co_await folly::coro::co_safe_point; - std::unique_ptr codec; - MoQObjectStreamCodec* objCodec = nullptr; - if (streamType == StreamType::CONTROL) { - codec = std::make_unique(dir_, this); - } else { - auto res = std::make_unique(this); - objCodec = res.get(); - codec = std::move(res); - } + MoQControlCodec codec(dir_, this); auto streamId = readHandle->getID(); - codec->setStreamId(streamId); + codec.setStreamId(streamId); - // TODO: disallow OBJECT on control streams and non-object on non-control bool fin = false; auto token = co_await folly::coro::co_current_cancellation_token; - std::shared_ptr track; - folly::CancellationSource cs; while (!fin && !token.isCancellationRequested()) { auto streamData = co_await folly::coro::co_awaitTry( readHandle->readStreamData().via(evb_)); if (streamData.hasException()) { XLOG(ERR) << streamData.exception().what() << " id=" << streamId << " sess=" << this; - // TODO: possibly erase fetch - cs.requestCancellation(); break; } else { if (streamData->data || streamData->fin) { - codec->onIngress(std::move(streamData->data), streamData->fin); - } - if (!track && objCodec) { - // TODO: this might not be set - auto trackId = objCodec->getTrackIdentifier(); - if (auto subscribeID = std::get_if(&trackId)) { - // it's fetch - track = getTrack(trackId); - track->mergeReadCancelToken( - folly::CancellationToken::merge(cs.getToken(), token)); + try { + codec.onIngress(std::move(streamData->data), streamData->fin); + } catch (const std::exception& ex) { + XLOG(FATAL) << "exception thrown from onIngress ex=" << ex.what(); } } fin = streamData->fin; XLOG_IF(DBG3, fin) << "End of stream id=" << streamId << " sess=" << this; } } - if (track) { - track->fin(); - track->setAllDataReceived(); - if (track->fetchOkReceived()) { - fetches_.erase(track->subscribeID()); - checkForCloseOnDrain(); - } + // TODO: close session on control exit +} + +MoQSession::SubscribeCallback MoQSession::getSubscribeCallback( + TrackAlias trackAlias) { + auto trackIt = subTracks_.find(trackAlias); + if (trackIt == subTracks_.end()) { + // received an object for unknown track alias + XLOG(ERR) << "unknown track alias=" << trackAlias << " sess=" << this; + return SubscribeCallback(nullptr); } + return SubscribeCallback(trackIt->second); } -std::shared_ptr MoQSession::getTrack( - TrackIdentifier trackIdentifier) { - // This can be cached in the handling stream - std::shared_ptr track; - auto alias = std::get_if(&trackIdentifier); - if (alias) { - auto trackIt = subTracks_.find(*alias); - if (trackIt == subTracks_.end()) { - // received an object for unknown track alias - XLOG(ERR) << "unknown track alias=" << alias->value << " sess=" << this; - return nullptr; - } - track = trackIt->second; - } else { - auto subscribeID = std::get(trackIdentifier); - XLOG(DBG3) << "getTrack subID=" << subscribeID; - auto trackIt = fetches_.find(subscribeID); - if (trackIt == fetches_.end()) { - // received an object for unknown subscribe ID - XLOG(ERR) << "unknown subscribe ID=" << subscribeID << " sess=" << this; - return nullptr; - } - track = trackIt->second; +MoQSession::FetchCallback MoQSession::getFetchCallback( + SubscribeID subscribeID) { + XLOG(DBG3) << "getTrack subID=" << subscribeID; + auto trackIt = fetches_.find(subscribeID); + if (trackIt == fetches_.end()) { + // received an object for unknown subscribe ID + XLOG(ERR) << "unknown subscribe ID=" << subscribeID << " sess=" << this; + return FetchCallback(nullptr); } - return track; + return FetchCallback(trackIt->second); } -void MoQSession::onObjectHeader(ObjectHeader objHeader) { - XLOG(DBG1) << "MoQSession::" << __func__ << " " << objHeader - << " sess=" << this; - auto track = getTrack(objHeader.trackIdentifier); - if (track) { - track->onObjectHeader(std::move(objHeader)); +std::shared_ptr MoQSession::SubscribeCallback::get() const { + return trackReceiveState_ ? trackReceiveState_->callback_ : nullptr; +} + +void MoQSession::SubscribeCallback::reset() { + if (trackReceiveState_) { + trackReceiveState_->callback_.reset(); } } -void MoQSession::onObjectPayload( - TrackIdentifier trackIdentifier, - uint64_t groupID, - uint64_t id, - std::unique_ptr payload, - bool eom) { - XLOG(DBG1) << __func__ << " sess=" << this; - auto track = getTrack(trackIdentifier); - if (track) { - track->onObjectPayload(groupID, id, std::move(payload), eom); - } -} - -void MoQSession::TrackHandle::onObjectHeader(ObjectHeader objHeader) { - XLOG(DBG1) << __func__ << " objHeader=" << objHeader - << " trackHandle=" << this; - uint64_t objectIdKey = objHeader.id; - auto status = objHeader.status; - if (status != ObjectStatus::NORMAL) { - objectIdKey |= (1ull << 63); - } - auto res = objects_.emplace( - std::piecewise_construct, - std::forward_as_tuple(std::make_pair(objHeader.group, objectIdKey)), - std::forward_as_tuple(std::make_shared())); - res.first->second->header = std::move(objHeader); - res.first->second->fullTrackName = fullTrackName_; - res.first->second->cancelToken = cancelToken_; - if (status != ObjectStatus::NORMAL) { - res.first->second->payloadQueue.enqueue(nullptr); - } - // TODO: objects_ accumulates the headers of all objects for the life of the - // track. Remove an entry from objects when returning the end of the payload, - // or the object itself for non-normal. - newObjects_.enqueue(res.first->second); -} - -void MoQSession::TrackHandle::fin() { - newObjects_.enqueue(nullptr); -} - -void MoQSession::TrackHandle::onObjectPayload( - uint64_t groupId, - uint64_t id, - std::unique_ptr payload, - bool eom) { - XLOG(DBG1) << __func__ << " g=" << groupId << " o=" << id - << " len=" << (payload ? payload->computeChainDataLength() : 0) - << " eom=" << uint64_t(eom) << " trackHandle=" << this; - auto objIt = objects_.find(std::make_pair(groupId, id)); - if (objIt == objects_.end()) { - // error; - XLOG(ERR) << "unknown object gid=" << groupId << " seq=" << id - << " trackHandle=" << this; - return; +folly::CancellationToken MoQSession::CallbackBase::getCancelToken() const { + return (trackReceiveState_) ? trackReceiveState_->getCancelToken() + : folly::CancellationToken(); +} + +std::shared_ptr MoQSession::FetchCallback::get() const { + return trackReceiveState_ ? trackReceiveState_->fetchCallback_ : nullptr; +} + +void MoQSession::FetchCallback::reset() { + if (trackReceiveState_) { + trackReceiveState_->fetchCallback_.reset(); + } +} + +namespace { +class SubgroupCodecCallback : public MoQObjectStreamCodec::ObjectCallback { + public: + explicit SubgroupCodecCallback(std::shared_ptr session) + : session_(std::move(session)) {} + + void setCallback(MoQSession::SubscribeCallback cb) { + callback_ = std::move(cb); + } + + void onSubgroup( + TrackAlias, + uint64_t group, + uint64_t subgroup, + Priority priority) override { + auto callback = callback_.get(); + if (!callback) { + return; + } + auto res = callback->beginSubgroup(group, subgroup, priority); + if (res.hasValue()) { + subgroupCallback_ = *res; + } else { + error_ = std::move(res.error()); + } + } + + void onFetchHeader(SubscribeID /*subscribeID*/) override { + XLOG(FATAL) << "unreachable"; + } + + void onObjectBegin( + uint64_t /*group*/, + uint64_t /*subgroup*/, + uint64_t objectID, + uint64_t length, + Payload initialPayload, + bool objectComplete, + bool subgroupComplete) override { + auto callback = callback_.get(); + if (!subgroupCallback_ || !callback) { + return; + } + + folly::Expected res{folly::unit}; + if (objectComplete) { + res = subgroupCallback_->object( + objectID, std::move(initialPayload), subgroupComplete); + if (subgroupComplete) { + subgroupCallback_.reset(); + } + } else { + res = subgroupCallback_->beginObject( + objectID, length, std::move(initialPayload)); + } + if (!res) { + error_ = std::move(res.error()); + } + } + + void onObjectPayload(Payload payload, bool objectComplete) override { + auto callback = callback_.get(); + if (!subgroupCallback_ || !!callback) { + return; + } + + auto res = + subgroupCallback_->objectPayload(std::move(payload), objectComplete); + if (!res) { + error_ = std::move(res.error()); + } + } + + void onObjectStatus( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + ObjectStatus status) override { + auto callback = callback_.get(); + if (!subgroupCallback_ || !callback) { + return; + } + + folly::Expected res{folly::unit}; + switch (status) { + case ObjectStatus::NORMAL: + XLOG(FATAL) << "Unreachable"; + break; + case ObjectStatus::OBJECT_NOT_EXIST: + res = subgroupCallback_->objectNotExists(objectID); + break; + case ObjectStatus::GROUP_NOT_EXIST: + res = callback->groupNotExists(group, subgroup, true); + subgroupCallback_.reset(); + break; + case ObjectStatus::END_OF_GROUP: + res = subgroupCallback_->endOfGroup(objectID); + subgroupCallback_.reset(); + break; + case ObjectStatus::END_OF_TRACK_AND_GROUP: + res = subgroupCallback_->endOfTrackAndGroup(objectID); + subgroupCallback_.reset(); + break; + case ObjectStatus::END_OF_SUBGROUP: + res = subgroupCallback_->endOfSubgroup(); + subgroupCallback_.reset(); + break; + } + if (!res) { + error_ = std::move(res.error()); + } + } + + void onEndOfStream() override { + auto callback = callback_.get(); + if (subgroupCallback_ && callback) { + auto res = subgroupCallback_->endOfSubgroup(); + if (!res) { + error_ = std::move(res.error()); + } + subgroupCallback_.reset(); + } + } + + void onConnectionError(ErrorCode error) override { + XLOG(ERR) << "Parse error=" << folly::to_underlying(error); + session_->close(SessionCloseErrorCode::PROTOCOL_VIOLATION); + } + + // Called by read loop on read error (eg: RESET_STREAM) + bool reset(ResetStreamErrorCode error) { + if (subgroupCallback_) { + auto callback = callback_.get(); + if (callback) { + // ignoring error from reset? + subgroupCallback_->reset(error); + } + subgroupCallback_.reset(); + return true; + } + return false; + } + + folly::Optional error() const { + return error_; + } + + private: + std::shared_ptr session_; + MoQSession::SubscribeCallback callback_; + std::shared_ptr subgroupCallback_; + folly::Optional error_; +}; + +class FetchCodecCallback : public MoQObjectStreamCodec::ObjectCallback { + public: + explicit FetchCodecCallback(std::shared_ptr session) + : session_(std::move(session)) {} + + void setCallback(MoQSession::FetchCallback cb) { + callback_ = std::move(cb); + } + + void onSubgroup(TrackAlias, uint64_t, uint64_t, Priority) override { + XLOG(FATAL) << "unreachable"; + } + + void onFetchHeader(SubscribeID) override {} + + void onObjectBegin( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + uint64_t length, + Payload initialPayload, + bool objectComplete, + bool fetchComplete) override { + auto callback = callback_.get(); + if (!callback) { + return; + } + folly::Expected res{folly::unit}; + if (objectComplete) { + res = callback->object( + group, subgroup, objectID, std::move(initialPayload), fetchComplete); + } else { + res = callback->beginObject( + group, subgroup, objectID, length, std::move(initialPayload)); + } + if (!res) { + error_ = std::move(res.error()); + } + } + + void onObjectPayload(Payload payload, bool objectComplete) override { + auto callback = callback_.get(); + if (!callback) { + return; + } + auto res = callback->objectPayload( + std::move(payload), + /*finSubgroup=*/false); + if (!res) { + error_ = std::move(res.error()); + } else { + XCHECK_EQ(objectComplete, res.value() == ObjectPublishStatus::DONE); + } } - if (payload) { - XLOG(DBG1) << "payload enqueued trackHandle=" << this; - objIt->second->payloadQueue.enqueue(std::move(payload)); + + void onObjectStatus( + uint64_t group, + uint64_t subgroup, + uint64_t objectID, + ObjectStatus status) override { + auto callback = callback_.get(); + if (!callback) { + return; + } + folly::Expected res{folly::unit}; + switch (status) { + case ObjectStatus::NORMAL: + break; + case ObjectStatus::OBJECT_NOT_EXIST: + res = callback->objectNotExists(group, subgroup, objectID); + break; + case ObjectStatus::GROUP_NOT_EXIST: + res = callback->groupNotExists(group, subgroup, false); + break; + case ObjectStatus::END_OF_GROUP: + res = callback->endOfGroup(group, subgroup, objectID, false); + break; + case ObjectStatus::END_OF_TRACK_AND_GROUP: + res = callback->endOfTrackAndGroup(group, subgroup, objectID); + callback_.reset(); + break; + case ObjectStatus::END_OF_SUBGROUP: + break; + } + if (!res) { + error_ = std::move(res.error()); + } } - if (eom) { - XLOG(DBG1) << "eom enqueued trackHandle=" << this; - objIt->second->payloadQueue.enqueue(nullptr); + + void onEndOfStream() override { + auto callback = callback_.get(); + if (callback) { + callback->endOfFetch(); + callback_.reset(); + } + } + + void onConnectionError(ErrorCode error) override { + XLOG(ERR) << "Parse error=" << folly::to_underlying(error); + session_->close(SessionCloseErrorCode::PROTOCOL_VIOLATION); + } + + bool reset(ResetStreamErrorCode error) { + if (callback_) { + auto callback = callback_.get(); + if (callback) { + callback->reset(error); + callback_.reset(); + } + return true; + } + return false; + } + + folly::Optional error() const { + return error_; + } + + private: + std::shared_ptr session_; + MoQSession::FetchCallback callback_; + folly::Optional error_; +}; + +class DispatchCallback : public MoQObjectStreamCodec::ObjectCallback { + public: + DispatchCallback( + std::shared_ptr session, + MoQObjectStreamCodec& codec, + folly::CancellationToken& token) + : session_(session), + subgroupCallback_(session), + fetchCallback_(session), + codec_(codec), + token_(token) {} + + void onSubgroup( + TrackAlias alias, + uint64_t group, + uint64_t subgroup, + Priority priority) override { + auto callback = session_->getSubscribeCallback(alias); + if (!callback) { + error_ = MoQPublishError( + MoQPublishError::CANCELLED, "Subgroup for unknown track"); + return; + } + token_ = folly::CancellationToken::merge(token_, callback.getCancelToken()); + codec_.setCallback(&subgroupCallback_); + subgroupCallback_.setCallback(std::move(callback)); + subgroupCallback_.onSubgroup(alias, group, subgroup, priority); + } + + void onFetchHeader(SubscribeID subscribeID) override { + auto callback = session_->getFetchCallback(subscribeID); + + if (!callback) { + error_ = MoQPublishError( + MoQPublishError::CANCELLED, "Fetch response for unknown track"); + return; + } + token_ = folly::CancellationToken::merge(token_, callback.getCancelToken()); + codec_.setCallback(&fetchCallback_); + fetchCallback_.setCallback(std::move(callback)); + fetchCallback_.onFetchHeader(subscribeID); + } + + void + onObjectBegin(uint64_t, uint64_t, uint64_t, uint64_t, Payload, bool, bool) + override { + XCHECK(error_); + } + + void onObjectPayload(Payload, bool) override { + XCHECK(error_); + } + + void onObjectStatus(uint64_t, uint64_t, uint64_t, ObjectStatus) override { + XCHECK(error_); + } + + void onEndOfStream() override { + XCHECK(error_); + } + + void onConnectionError(ErrorCode error) override { + subgroupCallback_.onConnectionError(error); + } + + bool reset(ResetStreamErrorCode error) { + return (subgroupCallback_.reset(error) || fetchCallback_.reset(error)); + } + + folly::Optional error() const { + if (error_) { + return error_; + } + auto err = subgroupCallback_.error(); + if (err) { + return err; + } else { + return fetchCallback_.error(); + } + } + + private: + std::shared_ptr session_; + SubgroupCodecCallback subgroupCallback_; + FetchCodecCallback fetchCallback_; + MoQObjectStreamCodec& codec_; + folly::CancellationToken& token_; + folly::Optional error_; +}; +} // namespace + +folly::coro::Task MoQSession::unidirectionalReadLoop( + std::shared_ptr session, + proxygen::WebTransport::StreamReadHandle* readHandle) { + XLOG(DBG1) << __func__ << " sess=" << this; + auto g = folly::makeGuard([func = __func__, this] { + XLOG(DBG1) << "exit " << func << " sess=" << this; + }); + co_await folly::coro::co_safe_point; + auto token = co_await folly::coro::co_current_cancellation_token; + MoQObjectStreamCodec codec(nullptr); + DispatchCallback dcb(session, codec, /*by ref*/ token); + codec.setCallback(&dcb); + auto id = readHandle->getID(); + codec.setStreamId(id); + + bool fin = false; + while (!fin && !token.isCancellationRequested()) { + auto streamData = + co_await folly::coro::co_awaitTry(folly::coro::co_withCancellation( + token, + folly::coro::toTaskInterruptOnCancel( + readHandle->readStreamData().via(evb_)))); + if (streamData.hasException()) { + XLOG(ERR) << streamData.exception().what() << " id=" << id + << " sess=" << this; + ResetStreamErrorCode errorCode{ResetStreamErrorCode::INTERNAL_ERROR}; + auto wtEx = + streamData.tryGetExceptionObject(); + if (wtEx) { + errorCode = ResetStreamErrorCode(wtEx->error); + } else { + XLOG(ERR) << streamData.exception().what(); + } + if (!dcb.reset(errorCode)) { + XLOG(ERR) << __func__ << " terminating for unknown " + << "stream id=" << id << " sess=" << this; + } + break; + } else { + if (streamData->data || streamData->fin) { + fin = streamData->fin; + folly::Optional err; + try { + codec.onIngress(std::move(streamData->data), streamData->fin); + err = dcb.error(); + } catch (const std::exception& ex) { + err = MoQPublishError(MoQPublishError::CANCELLED, ex.what()); + } + XLOG_IF(DBG3, fin) << "End of stream id=" << id << " sess=" << this; + if (err) { + XLOG(ERR) << "Error parsing stream, err=" << err->what(); + if (!fin) { + readHandle->stopSending(/*error=*/0); + break; + } + } + } // else empty read + } } } @@ -1163,7 +1515,7 @@ void MoQSession::onSubscribeUpdate(SubscribeUpdate subscribeUpdate) { void MoQSession::onUnsubscribe(Unsubscribe unsubscribe) { XLOG(DBG1) << __func__ << " sess=" << this; // How does this impact pending subscribes? - // and open TrackHandles + // and open TrackReceiveStates controlMessages_.enqueue(std::move(unsubscribe)); } @@ -1176,8 +1528,12 @@ void MoQSession::onSubscribeOk(SubscribeOk subOk) { << " sess=" << this; return; } - subTracks_[trackAliasIt->second]->subscribeOK( - subTracks_[trackAliasIt->second], subOk.groupOrder, subOk.latest); + auto trackReceiveStateIt = subTracks_.find(trackAliasIt->second); + if (trackReceiveStateIt != subTracks_.end()) { + trackReceiveStateIt->second->subscribeOK(std::move(subOk)); + } else { + XLOG(ERR) << "Missing subTracks_ entry for alias=" << trackAliasIt->second; + } } void MoQSession::onSubscribeError(SubscribeError subErr) { @@ -1189,10 +1545,16 @@ void MoQSession::onSubscribeError(SubscribeError subErr) { << " sess=" << this; return; } - subTracks_[trackAliasIt->second]->subscribeError(std::move(subErr)); - subTracks_.erase(trackAliasIt->second); - subIdToTrackAlias_.erase(trackAliasIt); - checkForCloseOnDrain(); + + auto trackReceiveStateIt = subTracks_.find(trackAliasIt->second); + if (trackReceiveStateIt != subTracks_.end()) { + trackReceiveStateIt->second->subscribeError(std::move(subErr)); + subTracks_.erase(trackReceiveStateIt); + subIdToTrackAlias_.erase(trackAliasIt); + checkForCloseOnDrain(); + } else { + XLOG(ERR) << "Missing subTracks_ entry for alias=" << trackAliasIt->second; + } } void MoQSession::onSubscribeDone(SubscribeDone subscribeDone) { @@ -1211,17 +1573,16 @@ void MoQSession::onSubscribeDone(SubscribeDone subscribeDone) { // TODO: there could still be objects in flight. Removing from maps now // will prevent their delivery. I think the only way to handle this is with // timeouts. - auto trackHandleIt = subTracks_.find(trackAliasIt->second); - if (trackHandleIt != subTracks_.end()) { - auto trackHandle = trackHandleIt->second; - subTracks_.erase(trackHandleIt); - trackHandle->fin(); + auto trackReceiveStateIt = subTracks_.find(trackAliasIt->second); + if (trackReceiveStateIt != subTracks_.end()) { + auto trackReceiveState = trackReceiveStateIt->second; + subTracks_.erase(trackReceiveStateIt); + trackReceiveState->subscribeDone(std::move(subscribeDone)); } else { - XLOG(DFATAL) << "trackAliasIt but no trackHandleIt for id=" + XLOG(DFATAL) << "trackAliasIt but no trackReceiveStateIt for id=" << subscribeDone.subscribeID << " sess=" << this; } subIdToTrackAlias_.erase(trackAliasIt); - controlMessages_.enqueue(std::move(subscribeDone)); checkForCloseOnDrain(); } @@ -1294,10 +1655,10 @@ void MoQSession::onFetchOk(FetchOk fetchOk) { << " sess=" << this; return; } - auto trackHandle = fetchIt->second; - trackHandle->fetchOK(trackHandle); - if (trackHandle->allDataReceived()) { - fetches_.erase(trackHandle->subscribeID()); + auto trackReceiveState = fetchIt->second; + trackReceiveState->fetchOK(); + if (trackReceiveState->allDataReceived()) { + fetches_.erase(fetchIt); } } @@ -1412,8 +1773,11 @@ void MoQSession::onGoaway(Goaway goaway) { void MoQSession::onConnectionError(ErrorCode error) { XLOG(DBG1) << __func__ << " sess=" << this; - XLOG(ERR) << "err=" << folly::to_underlying(error); - // TODO + XLOG(ERR) << "MoQCodec control stream parse error err=" + << folly::to_underlying(error); + // TODO: This error is coming from MoQCodec -- do we need a better + // error code? + close(SessionCloseErrorCode::PROTOCOL_VIOLATION); } folly::coro::Task> @@ -1505,37 +1869,9 @@ void MoQSession::subscribeAnnouncesError( controlWriteEvent_.signal(); } -folly::coro::AsyncGenerator< - std::shared_ptr> -MoQSession::TrackHandle::objects() { - XLOG(DBG1) << __func__ << " trackHandle=" << this; - auto g = - folly::makeGuard([func = __func__] { XLOG(DBG1) << "exit " << func; }); - auto cancelToken = co_await folly::coro::co_current_cancellation_token; - auto mergeToken = folly::CancellationToken::merge(cancelToken, cancelToken_); - folly::EventBaseThreadTimekeeper tk(*evb_); - while (!cancelToken.isCancellationRequested()) { - auto optionalObj = newObjects_.try_dequeue(); - std::shared_ptr obj; - if (optionalObj) { - obj = *optionalObj; - } else { - obj = co_await folly::coro::co_withCancellation( - mergeToken, - folly::coro::timeout(newObjects_.dequeue(), objectTimeout_, &tk)); - } - if (!obj) { - XLOG(DBG3) << "Out of objects for trackHandle=" << this - << " id=" << subscribeID_; - break; - } - co_yield obj; - } -} - -folly::coro::Task< - folly::Expected, SubscribeError>> -MoQSession::subscribe(SubscribeRequest sub) { +folly::coro::Task MoQSession::subscribe( + SubscribeRequest sub, + std::shared_ptr callback) { XLOG(DBG1) << __func__ << " sess=" << this; auto fullTrackName = sub.fullTrackName; if (nextSubscribeID_ >= peerMaxSubscribeID_) { @@ -1557,17 +1893,20 @@ MoQSession::subscribe(SubscribeRequest sub) { controlWriteEvent_.signal(); auto res = subIdToTrackAlias_.emplace(subID, trackAlias); XCHECK(res.second) << "Duplicate subscribe ID"; - auto subTrack = subTracks_.emplace( - std::piecewise_construct, - std::forward_as_tuple(trackAlias), - std::forward_as_tuple(std::make_shared( - fullTrackName, subID, evb_, cancellationSource_.getToken()))); - - auto trackHandle = subTrack.first->second; - auto res2 = co_await trackHandle->ready(); - XLOG(DBG1) << "Subscribe ready trackHandle=" << trackHandle + auto trackReceiveState = std::make_shared( + fullTrackName, subID, callback, nullptr); + auto subTrack = subTracks_.try_emplace(trackAlias, trackReceiveState); + XCHECK(subTrack.second) << "Track alias already in use alias=" << trackAlias + << " sess=" << this; + + auto res2 = co_await trackReceiveState->ready(); + XLOG(DBG1) << "Subscribe ready trackReceiveState=" << trackReceiveState << " subscribeID=" << subID; - co_return res2; + if (res2.hasValue()) { + co_return std::move(res2.value()); + } else { + co_return folly::makeUnexpected(res2.error()); + } } std::shared_ptr MoQSession::subscribeOk(SubscribeOk subOk) { @@ -1636,6 +1975,7 @@ void MoQSession::unsubscribe(Unsubscribe unsubscribe) { << " sess=" << this; // if there are open streams for this subscription, we should STOP_SENDING // them? + trackIt->second->removeCallback(); auto res = writeUnsubscribe(controlWriteBuf_, std::move(unsubscribe)); if (!res) { XLOG(ERR) << "writeUnsubscribe failed sess=" << this; @@ -1684,7 +2024,7 @@ void MoQSession::retireSubscribeId(bool signalWriteLoop) { } } -void MoQSession::sendMaxSubscribeID(bool signal) { +void MoQSession::sendMaxSubscribeID(bool signalWriteLoop) { XLOG(DBG1) << "Issuing new maxSubscribeID=" << maxSubscribeID_ << " sess=" << this; auto res = @@ -1693,7 +2033,7 @@ void MoQSession::sendMaxSubscribeID(bool signal) { XLOG(ERR) << "writeMaxSubscribeId failed sess=" << this; return; } - if (signal) { + if (signalWriteLoop) { controlWriteEvent_.signal(); } } @@ -1740,9 +2080,9 @@ void MoQSession::subscribeUpdate(SubscribeUpdate subUpdate) { controlWriteEvent_.signal(); } -folly::coro::Task< - folly::Expected, FetchError>> -MoQSession::fetch(Fetch fetch) { +folly::coro::Task> MoQSession::fetch( + Fetch fetch, + std::shared_ptr fetchCallback) { XLOG(DBG1) << __func__ << " sess=" << this; auto g = folly::makeGuard([func = __func__] { XLOG(DBG1) << "exit " << func; }); @@ -1762,16 +2102,15 @@ MoQSession::fetch(Fetch fetch) { FetchError({subID, 500, "local write failed"})); } controlWriteEvent_.signal(); - auto subTrack = fetches_.emplace( - std::piecewise_construct, - std::forward_as_tuple(subID), - std::forward_as_tuple(std::make_shared( - fullTrackName, subID, evb_, cancellationSource_.getToken()))); - - auto trackHandle = subTrack.first->second; - trackHandle->setNewObjectTimeout(std::chrono::seconds(2)); - auto res = co_await trackHandle->fetchReady(); - XLOG(DBG1) << __func__ << " fetchReady trackHandle=" << trackHandle; + auto trackReceiveState = std::make_shared( + fullTrackName, subID, nullptr, fetchCallback); + auto fetchTrack = fetches_.try_emplace(subID, trackReceiveState); + XCHECK(fetchTrack.second) + << "SubscribeID already in use id=" << subID << " sess=" << this; + trackReceiveState->setNewObjectTimeout(std::chrono::seconds(2)); + auto res = co_await trackReceiveState->fetchReady(); + XLOG(DBG1) << __func__ + << " fetchReady trackReceiveState=" << trackReceiveState; co_return res; } @@ -1835,6 +2174,7 @@ void MoQSession::fetchCancel(FetchCancel fetchCan) { << " sess=" << this; return; } + trackIt->second->removeCallback(); auto res = writeFetchCancel(controlWriteBuf_, std::move(fetchCan)); if (!res) { XLOG(ERR) << "writeFetchCancel failed sess=" << this; @@ -1853,7 +2193,7 @@ void MoQSession::onNewUniStream(proxygen::WebTransport::StreamReadHandle* rh) { // maybe not STREAM_HEADER_SUBGROUP, but at least not control co_withCancellation( cancellationSource_.getToken(), - readLoop(StreamType::STREAM_HEADER_SUBGROUP, rh)) + unidirectionalReadLoop(shared_from_this(), rh)) .scheduleOn(evb_) .start(); } @@ -1868,8 +2208,7 @@ void MoQSession::onNewBidiStream(proxygen::WebTransport::BidiStreamHandle bh) { } else { bh.writeHandle->setPriority(0, 0, false); co_withCancellation( - cancellationSource_.getToken(), - readLoop(StreamType::CONTROL, bh.readHandle)) + cancellationSource_.getToken(), controlReadLoop(bh.readHandle)) .scheduleOn(evb_) .start(); auto mergeToken = folly::CancellationToken::merge( @@ -1908,12 +2247,9 @@ void MoQSession::onDatagram(std::unique_ptr datagram) { readBuf.trimStart(dgLength - remainingLength); auto alias = std::get_if(&res->trackIdentifier); XCHECK(alias); - auto track = getTrack(*alias); - if (track) { - auto groupID = res->group; - auto objID = res->id; - track->onObjectHeader(std::move(*res)); - track->onObjectPayload(groupID, objID, readBuf.move(), true); + auto callback = getSubscribeCallback(*alias).get(); + if (callback) { + callback->datagram(std::move(*res), readBuf.move()); } } diff --git a/moxygen/MoQSession.h b/moxygen/MoQSession.h index a2c15bb..ab328f3 100644 --- a/moxygen/MoQSession.h +++ b/moxygen/MoQSession.h @@ -24,8 +24,11 @@ namespace moxygen { class MoQSession : public MoQControlCodec::ControlCallback, - public MoQObjectStreamCodec::ObjectCallback, - public proxygen::WebTransportHandler { + public proxygen::WebTransportHandler, + public std::enable_shared_from_this { + private: + class TrackReceiveState; + public: class ServerSetupCallback { public: @@ -85,7 +88,6 @@ class MoQSession : public MoQControlCodec::ControlCallback, SubscribeRequest, SubscribeUpdate, Unsubscribe, - SubscribeDone, Fetch, TrackStatusRequest, TrackStatus, @@ -136,10 +138,6 @@ class MoQSession : public MoQControlCodec::ControlCallback, XLOG(INFO) << "SubscribeUpdate subID=" << subscribeUpdate.subscribeID; } - virtual void operator()(SubscribeDone subscribeDone) const { - XLOG(INFO) << "SubscribeDone subID=" << subscribeDone.subscribeID; - } - virtual void operator()(Unsubscribe unsubscribe) const { XLOG(INFO) << "Unsubscribe subID=" << unsubscribe.subscribeID; } @@ -187,202 +185,22 @@ class MoQSession : public MoQControlCodec::ControlCallback, return subOrder == GroupOrder::Default ? pubOrder : subOrder; } - class TrackHandle { - public: - TrackHandle( - FullTrackName fullTrackName, - SubscribeID subscribeID, - folly::EventBase* evb, - folly::CancellationToken token) - : fullTrackName_(std::move(fullTrackName)), - subscribeID_(subscribeID), - evb_(evb), - cancelToken_(std::move(token)) { - auto contract = folly::coro::makePromiseContract< - folly::Expected, SubscribeError>>(); - promise_ = std::move(contract.first); - future_ = std::move(contract.second); - auto contract2 = folly::coro::makePromiseContract< - folly::Expected, FetchError>>(); - fetchPromise_ = std::move(contract2.first); - fetchFuture_ = std::move(contract2.second); - } - - void setTrackName(FullTrackName trackName) { - fullTrackName_ = std::move(trackName); - } - - [[nodiscard]] const FullTrackName& fullTrackName() const { - return fullTrackName_; - } - - SubscribeID subscribeID() const { - return subscribeID_; - } - - void setNewObjectTimeout(std::chrono::milliseconds objectTimeout) { - objectTimeout_ = objectTimeout; - } - - [[nodiscard]] folly::CancellationToken getCancelToken() const { - return cancelToken_; - } - - void mergeReadCancelToken(folly::CancellationToken readToken) { - cancelToken_ = folly::CancellationToken::merge(cancelToken_, readToken); - } - - void fin(); - - folly::coro::Task< - folly::Expected, SubscribeError>> - ready() { - co_return co_await std::move(future_); - } - void subscribeOK( - std::shared_ptr self, - GroupOrder order, - folly::Optional latest) { - XCHECK_EQ(self.get(), this); - groupOrder_ = order; - latest_ = std::move(latest); - promise_.setValue(std::move(self)); - } - void subscribeError(SubscribeError subErr) { - if (!promise_.isFulfilled()) { - subErr.subscribeID = subscribeID_; - promise_.setValue(folly::makeUnexpected(std::move(subErr))); - } - } - - folly::coro::Task, FetchError>> - fetchReady() { - co_return co_await std::move(fetchFuture_); - } - void fetchOK(std::shared_ptr self) { - XCHECK_EQ(self.get(), this); - XLOG(DBG1) << __func__ << " trackHandle=" << this; - fetchPromise_.setValue(std::move(self)); - } - void fetchError(FetchError fetchErr) { - if (!promise_.isFulfilled()) { - fetchErr.subscribeID = subscribeID_; - fetchPromise_.setValue(folly::makeUnexpected(std::move(fetchErr))); - } - } - - struct ObjectSource { - ObjectHeader header; - FullTrackName fullTrackName; - folly::CancellationToken cancelToken; - - folly::coro::UnboundedQueue, true, true> - payloadQueue; - - folly::coro::Task> payload() { - if (header.status != ObjectStatus::NORMAL) { - co_return nullptr; - } - folly::IOBufQueue payloadBuf{folly::IOBufQueue::cacheChainLength()}; - auto curCancelToken = - co_await folly::coro::co_current_cancellation_token; - auto mergeToken = - folly::CancellationToken::merge(curCancelToken, cancelToken); - while (!curCancelToken.isCancellationRequested()) { - std::unique_ptr buf; - auto optionalBuf = payloadQueue.try_dequeue(); - if (optionalBuf) { - buf = std::move(*optionalBuf); - } else { - buf = co_await folly::coro::co_withCancellation( - cancelToken, payloadQueue.dequeue()); - } - if (!buf) { - break; - } - payloadBuf.append(std::move(buf)); - } - co_return payloadBuf.move(); - } - }; - - void onObjectHeader(ObjectHeader objHeader); - void onObjectPayload( - uint64_t groupId, - uint64_t id, - std::unique_ptr payload, - bool eom); - - folly::coro::AsyncGenerator> objects(); - - GroupOrder groupOrder() const { - return groupOrder_; - } - - folly::Optional latest() { - return latest_; - } - - void setAllDataReceived() { - allDataReceived_ = true; - } - - bool allDataReceived() const { - return allDataReceived_; - } - - bool fetchOkReceived() const { - return fetchPromise_.isFulfilled(); - } - - private: - FullTrackName fullTrackName_; - SubscribeID subscribeID_; - folly::EventBase* evb_; - using SubscribeResult = - folly::Expected, SubscribeError>; - folly::coro::Promise promise_; - folly::coro::Future future_; - using FetchResult = - folly::Expected, FetchError>; - folly::coro::Promise fetchPromise_; - folly::coro::Future fetchFuture_; - folly:: - F14FastMap, std::shared_ptr> - objects_; - folly::coro::UnboundedQueue, true, true> - newObjects_; - GroupOrder groupOrder_; - folly::Optional latest_; - folly::CancellationToken cancelToken_; - std::chrono::milliseconds objectTimeout_{std::chrono::hours(24)}; - bool allDataReceived_{false}; - }; - - folly::coro::Task< - folly::Expected, SubscribeError>> - subscribe(SubscribeRequest sub); + using SubscribeResult = folly::Expected; + folly::coro::Task subscribe( + SubscribeRequest sub, + std::shared_ptr callback); std::shared_ptr subscribeOk(SubscribeOk subOk); void subscribeError(SubscribeError subErr); void unsubscribe(Unsubscribe unsubscribe); void subscribeUpdate(SubscribeUpdate subUpdate); - folly::coro::Task, FetchError>> - fetch(Fetch fetch); + folly::coro::Task> fetch( + Fetch fetch, + std::shared_ptr fetchCallback); std::shared_ptr fetchOk(FetchOk fetchOk); void fetchError(FetchError fetchError); void fetchCancel(FetchCancel fetchCancel); - class WebTransportException : public std::runtime_error { - public: - explicit WebTransportException( - proxygen::WebTransport::ErrorCode error, - const std::string& msg) - : std::runtime_error(msg), errorCode(error) {} - - proxygen::WebTransport::ErrorCode errorCode; - }; - class PublisherImpl { public: PublisherImpl( @@ -440,28 +258,60 @@ class MoQSession : public MoQControlCodec::ControlCallback, close(); } + // The following wrapper classes allow implementation details in the anonymous + // namespace can use parts of TrackReceiveState without making the entire + // class public. + class CallbackBase { + public: + CallbackBase() = default; + explicit CallbackBase(std::shared_ptr trackReceiveState) + : trackReceiveState_(std::move(trackReceiveState)) {} + operator bool() const { + return bool(trackReceiveState_); + } + folly::CancellationToken getCancelToken() const; + + protected: + std::shared_ptr trackReceiveState_; + }; + + class SubscribeCallback : public CallbackBase { + public: + SubscribeCallback() = default; + explicit SubscribeCallback( + std::shared_ptr trackReceiveState) + : CallbackBase(std::move(trackReceiveState)) {} + std::shared_ptr get() const; + void reset(); + }; + + class FetchCallback : public CallbackBase { + public: + FetchCallback() = default; + explicit FetchCallback(std::shared_ptr trackReceiveState) + : CallbackBase(std::move(trackReceiveState)) {} + std::shared_ptr get() const; + void reset(); + }; + + SubscribeCallback getSubscribeCallback(TrackAlias trackAlias); + FetchCallback getFetchCallback(SubscribeID subscribeID); + private: void cleanup(); folly::coro::Task controlWriteLoop( proxygen::WebTransport::StreamWriteHandle* writeHandle); - folly::coro::Task readLoop( - StreamType streamType, + folly::coro::Task controlReadLoop( + proxygen::WebTransport::StreamReadHandle* readHandle); + folly::coro::Task unidirectionalReadLoop( + std::shared_ptr session, proxygen::WebTransport::StreamReadHandle* readHandle); - std::shared_ptr getTrack(TrackIdentifier trackidentifier); void subscribeDone(SubscribeDone subDone); void onClientSetup(ClientSetup clientSetup) override; void onServerSetup(ServerSetup setup) override; - void onObjectHeader(ObjectHeader objectHeader) override; - void onObjectPayload( - TrackIdentifier trackIdentifier, - uint64_t groupID, - uint64_t id, - std::unique_ptr payload, - bool eom) override; - void onFetchHeader(uint64_t) override {} void onSubscribe(SubscribeRequest subscribeRequest) override; void onSubscribeUpdate(SubscribeUpdate subscribeUpdate) override; void onSubscribeOk(SubscribeOk subscribeOk) override; @@ -514,12 +364,135 @@ class MoQSession : public MoQControlCodec::ControlCallback, folly::coro::UnboundedQueue controlMessages_; // Subscriber State + class TrackReceiveState { + public: + TrackReceiveState( + FullTrackName fullTrackName, + SubscribeID subscribeID, + std::shared_ptr callback, + std::shared_ptr fetchCallback) + : callback_(std::move(callback)), + fetchCallback_(std::move(fetchCallback)), + fullTrackName_(std::move(fullTrackName)), + subscribeID_(subscribeID) { + auto contract = folly::coro::makePromiseContract(); + promise_ = std::move(contract.first); + future_ = std::move(contract.second); + auto contract2 = folly::coro::makePromiseContract< + folly::Expected>(); + fetchPromise_ = std::move(contract2.first); + fetchFuture_ = std::move(contract2.second); + } + + void setTrackName(FullTrackName trackName) { + fullTrackName_ = std::move(trackName); + } + + [[nodiscard]] const FullTrackName& fullTrackName() const { + return fullTrackName_; + } + + [[nodiscard]] SubscribeID subscribeID() const { + return subscribeID_; + } + + void setNewObjectTimeout(std::chrono::milliseconds objectTimeout) { + objectTimeout_ = objectTimeout; + } + + folly::CancellationToken getCancelToken() { + return cancelSource_.getToken(); + } + + folly::coro::Task ready() { + co_return co_await std::move(future_); + } + + void removeCallback() { + callback_.reset(); + fetchCallback_.reset(); + cancelSource_.requestCancellation(); + } + void subscribeOK(SubscribeOk subscribeOK) { + groupOrder_ = subscribeOK.groupOrder; + promise_.setValue(std::move(subscribeOK)); + } + void subscribeError(SubscribeError subErr) { + XLOG(DBG1) << __func__ << " trackReceiveState=" << this; + if (!promise_.isFulfilled()) { + subErr.subscribeID = subscribeID_; + promise_.setValue(folly::makeUnexpected(std::move(subErr))); + } else { + subscribeDone( + {subscribeID_, + SubscribeDoneStatusCode::INTERNAL_ERROR, + "closed locally", + folly::none}); + } + } + + void subscribeDone(SubscribeDone subDone) { + XLOG(DBG1) << __func__ << " trackReceiveState=" << this; + if (callback_) { + callback_->subscribeDone(std::move(subDone)); + } // else, unsubscribe raced with subscribeDone and callback was removed + } + + folly::coro::Task> fetchReady() { + co_return co_await std::move(fetchFuture_); + } + void fetchOK() { + XLOG(DBG1) << __func__ << " trackReceiveState=" << this; + fetchPromise_.setValue(subscribeID_); + } + void fetchError(FetchError fetchErr) { + if (!promise_.isFulfilled()) { + fetchErr.subscribeID = subscribeID_; + fetchPromise_.setValue(folly::makeUnexpected(std::move(fetchErr))); + } // there's likely a missing case here from shutdown + } + + void setAllDataReceived() { + allDataReceived_ = true; + } + + bool allDataReceived() const { + return allDataReceived_; + } + + bool fetchOkReceived() const { + return fetchPromise_.isFulfilled(); + } + + // Accessed By SubscribeCallback/FetchCallback + std::shared_ptr callback_; + std::shared_ptr fetchCallback_; + + private: + FullTrackName fullTrackName_; + SubscribeID subscribeID_; + folly::coro::Promise promise_; + folly::coro::Future future_; + using FetchResult = folly::Expected; + folly::coro::Promise fetchPromise_; + folly::coro::Future fetchFuture_; + GroupOrder groupOrder_; + std::chrono::milliseconds objectTimeout_{std::chrono::hours(24)}; + folly::CancellationSource cancelSource_; + bool allDataReceived_{false}; + }; + // Track Alias -> Track Handle - folly::F14FastMap, TrackAlias::hash> + folly::F14FastMap< + TrackAlias, + std::shared_ptr, + TrackAlias::hash> subTracks_; - folly:: - F14FastMap, SubscribeID::hash> - fetches_; + folly::F14FastMap< + SubscribeID, + std::shared_ptr, + SubscribeID::hash> + fetches_; folly::F14FastMap subIdToTrackAlias_; diff --git a/moxygen/relay/MoQRelay.cpp b/moxygen/relay/MoQRelay.cpp index e0a9bf8..088617c 100644 --- a/moxygen/relay/MoQRelay.cpp +++ b/moxygen/relay/MoQRelay.cpp @@ -180,25 +180,22 @@ folly::coro::Task MoQRelay::onSubscribe( // TODO: we only subscribe with the downstream locations. subReq.priority = 1; subReq.groupOrder = GroupOrder::Default; - auto subRes = co_await upstreamSession->subscribe(subReq); + forwarder = + std::make_shared(subReq.fullTrackName, folly::none); + // TODO: there's a race condition that the forwarder gets upstream objects + // before we add the downstream subscriber to it, below + auto subRes = co_await upstreamSession->subscribe(subReq, forwarder); if (subRes.hasError()) { session->subscribeError({subReq.subscribeID, 502, "subscribe failed"}); co_return; } - forwarder = std::make_shared( - subReq.fullTrackName, subRes.value()->latest()); - forwarder->setGroupOrder(subRes.value()->groupOrder()); - RelaySubscription rsub( - {forwarder, - upstreamSession, - (*subRes)->subscribeID(), - folly::CancellationSource()}); - auto token = rsub.cancellationSource.getToken(); + auto latest = subRes->latest; + if (latest) { + forwarder->updateLatest(latest->group, latest->object); + } + forwarder->setGroupOrder(subRes->groupOrder); + RelaySubscription rsub({forwarder, upstreamSession, subRes->subscribeID}); subscriptions_[subReq.fullTrackName] = std::move(rsub); - folly::coro::co_withCancellation( - token, forwardTrack(subRes.value(), forwarder)) - .scheduleOn(upstreamSession->getEventBase()) - .start(); } else { forwarder = subscriptionIt->second.forwarder; } @@ -224,52 +221,6 @@ folly::coro::Task MoQRelay::onSubscribe( } } -folly::coro::Task MoQRelay::forwardTrack( - std::shared_ptr track, - std::shared_ptr forwarder) { - while (auto obj = co_await track->objects().next()) { - XLOG(DBG1) << __func__ << " new object t=" << obj.value()->fullTrackName - << " g=" << obj.value()->header.group - << " o=" << obj.value()->header.id; - folly::IOBufQueue payloadBuf{folly::IOBufQueue::cacheChainLength()}; - bool eom = false; - // TODO: this is wrong - we're publishing each object in it's own subgroup - // stream now - auto res = forwarder->beginSubgroup( - obj.value()->header.group, - obj.value()->header.subgroup, - obj.value()->header.priority); - if (!res) { - XLOG(ERR) << "Failed to begin forwarding subgroup"; - // TODO: error - } - auto subgroupPub = std::move(res.value()); - subgroupPub->beginObject( - obj.value()->header.id, *obj.value()->header.length, nullptr); - while (!eom) { - auto payload = co_await obj.value()->payloadQueue.dequeue(); - if (payload) { - payloadBuf.append(std::move(payload)); - XLOG(DBG1) << __func__ - << " object bytes, buflen now=" << payloadBuf.chainLength(); - } else { - XLOG(DBG1) << __func__ - << " object eom, buflen now=" << payloadBuf.chainLength(); - eom = true; - } - auto payloadLength = payloadBuf.chainLength(); - if (eom || payloadLength > 1280) { - subgroupPub->objectPayload(payloadBuf.move(), eom); - } else { - XLOG(DBG1) << __func__ - << " Not publishing yet payloadLength=" << payloadLength - << " eom=" << uint64_t(eom); - } - } - subgroupPub.reset(); - } -} - void MoQRelay::onUnsubscribe( Unsubscribe unsub, std::shared_ptr session) { @@ -286,7 +237,6 @@ void MoQRelay::onUnsubscribe( subscription.forwarder->latest()}); if (subscription.forwarder->empty()) { XLOG(INFO) << "Removed last subscriber for " << subscriptionIt->first; - subscription.cancellationSource.requestCancellation(); subscription.upstream->unsubscribe({subscription.subscribeID}); subscriptionIt = subscriptions_.erase(subscriptionIt); } else { @@ -343,7 +293,6 @@ void MoQRelay::removeSession(const std::shared_ptr& session) { SubscribeDoneStatusCode::SUBSCRIPTION_ENDED, "upstream disconnect", subscription.forwarder->latest()}); - subscription.cancellationSource.requestCancellation(); } else { subscription.forwarder->removeSession(session); } diff --git a/moxygen/relay/MoQRelay.h b/moxygen/relay/MoQRelay.h index c5a9a06..601c867 100644 --- a/moxygen/relay/MoQRelay.h +++ b/moxygen/relay/MoQRelay.h @@ -54,11 +54,7 @@ class MoQRelay { std::shared_ptr forwarder; std::shared_ptr upstream; SubscribeID subscribeID; - folly::CancellationSource cancellationSource; }; - folly::coro::Task forwardTrack( - std::shared_ptr track, - std::shared_ptr forwarder); TrackNamespace allowedNamespacePrefix_; folly::F14FastMap diff --git a/moxygen/samples/chat/MoQChatClient.cpp b/moxygen/samples/chat/MoQChatClient.cpp index 227bec0..49d2924 100644 --- a/moxygen/samples/chat/MoQChatClient.cpp +++ b/moxygen/samples/chat/MoQChatClient.cpp @@ -5,6 +5,7 @@ */ #include "moxygen/samples/chat/MoQChatClient.h" +#include "moxygen/ObjectReceiver.h" #include #include @@ -112,12 +113,6 @@ folly::coro::Task MoQChatClient::controlReadLoop() { latest}); } - void operator()(SubscribeDone subDone) const override { - XLOG(INFO) << "SubscribeDone is=" << subDone.subscribeID; - client_.subscribeDone(std::move(subDone)); - // TODO: should be handled in session - } - void operator()(Unsubscribe unsubscribe) const override { XLOG(INFO) << "Unsubscribe id=" << unsubscribe.subscribeID; if (client_.chatSubscribeID_ && @@ -224,6 +219,43 @@ folly::coro::Task MoQChatClient::subscribeToUser( &userTracks.emplace_back(UserTrack({deviceId, timestamp, 0})); } // now subscribe and update timestamp. + class ChatObjectHandler : public ObjectReceiverCallback { + public: + explicit ChatObjectHandler(MoQChatClient& client, std::string username) + : client_(client), username_(username) {} + ~ChatObjectHandler() override = default; + FlowControlState onObject(const ObjectHeader&, Payload payload) override { + if (payload) { + std::cout << username_ << ": "; + payload->coalesce(); + std::cout << payload->moveToFbString() << std::endl; + } + return FlowControlState::UNBLOCKED; + } + void onObjectStatus(const ObjectHeader&) override {} + void onEndOfStream() override {} + void onError(ResetStreamErrorCode error) override { + std::cout << "Stream Error=" << folly::to_underlying(error) << std::endl; + } + + void onSubscribeDone(SubscribeDone subDone) override { + XLOG(INFO) << "SubscribeDone: " << subDone.reasonPhrase; + if (subDone.statusCode != SubscribeDoneStatusCode::UNSUBSCRIBED && + client_.moqClient_.moqSession_) { + client_.moqClient_.moqSession_->unsubscribe({subDone.subscribeID}); + } + client_.subscribeDone(std::move(subDone)); + baton.post(); + } + + folly::coro::Baton baton; + + private: + MoQChatClient& client_; + std::string username_; + }; + ChatObjectHandler handler(*this, username); + auto track = co_await co_awaitTry(moqClient_.moqSession_->subscribe( {0, 0, @@ -233,7 +265,8 @@ folly::coro::Task MoQChatClient::subscribeToUser( LocationType::LatestGroup, folly::none, folly::none, - {}})); + {}}, + std::make_shared(ObjectReceiver::SUBSCRIBE, &handler))); if (track.hasException()) { // subscribe failed XLOG(ERR) << track.exception(); @@ -246,17 +279,9 @@ folly::coro::Task MoQChatClient::subscribeToUser( co_return; } - userTrackPtr->subscribeId = track->value()->subscribeID(); + userTrackPtr->subscribeId = track->value().subscribeID; userTrackPtr->timestamp = timestamp; - while (auto obj = co_await track->value()->objects().next()) { - // how to cancel this loop - auto payload = co_await obj.value()->payload(); - if (payload) { - std::cout << username << ": "; - payload->coalesce(); - std::cout << payload->moveToFbString() << std::endl; - } - } + co_await handler.baton; } void MoQChatClient::subscribeDone(SubscribeDone subDone) { diff --git a/moxygen/samples/flv_streamer_client/MoQFlvStreamerClient.cpp b/moxygen/samples/flv_streamer_client/MoQFlvStreamerClient.cpp index 91bcf35..79acba0 100644 --- a/moxygen/samples/flv_streamer_client/MoQFlvStreamerClient.cpp +++ b/moxygen/samples/flv_streamer_client/MoQFlvStreamerClient.cpp @@ -256,11 +256,6 @@ class MoQFlvStreamerClient { return; } - void operator()(SubscribeDone /* subscribeDone */) const override { - // Not expecxted to receive this - XLOG(INFO) << "SubscribeDone"; - } - void operator()(Unsubscribe unSubs) const override { XLOG(INFO) << "Unsubscribe"; // Delete subscribe diff --git a/moxygen/samples/text-client/MoQTextClient.cpp b/moxygen/samples/text-client/MoQTextClient.cpp index 57042f7..f903e3e 100644 --- a/moxygen/samples/text-client/MoQTextClient.cpp +++ b/moxygen/samples/text-client/MoQTextClient.cpp @@ -7,6 +7,7 @@ #include #include "moxygen/MoQClient.h" #include "moxygen/MoQLocation.h" +#include "moxygen/ObjectReceiver.h" #include #include @@ -66,6 +67,31 @@ SubParams flags2params() { return result; } +class TextHandler : public ObjectReceiverCallback { + public: + ~TextHandler() override = default; + FlowControlState onObject(const ObjectHeader&, Payload payload) override { + if (payload) { + std::cout << payload->moveToFbString() << std::endl; + } + return FlowControlState::UNBLOCKED; + } + void onObjectStatus(const ObjectHeader& objHeader) override { + std::cout << "ObjectStatus=" << uint32_t(objHeader.status) << std::endl; + } + void onEndOfStream() override {} + void onError(ResetStreamErrorCode error) override { + std::cout << "Stream Error=" << folly::to_underlying(error) << std::endl; + } + + void onSubscribeDone(SubscribeDone) override { + std::cout << __func__ << std::endl; + baton.post(); + } + + folly::coro::Baton baton; +}; + class MoQTextClient { public: MoQTextClient(folly::EventBase* evb, proxygen::URL url, FullTrackName ftn) @@ -92,11 +118,14 @@ class MoQTextClient { sub.locType = LocationType::LatestObject; sub.start = folly::none; sub.end = folly::none; - auto track = co_await moqClient_.moqSession_->subscribe(sub); + subTextHandler_ = std::make_shared( + ObjectReceiver::SUBSCRIBE, &textHandler_); + auto track = + co_await moqClient_.moqSession_->subscribe(sub, subTextHandler_); if (track.hasValue()) { - subscribeID_ = track.value()->subscribeID(); + subscribeID_ = track->subscribeID; XLOG(DBG1) << "subscribeID=" << subscribeID_; - auto latest = track.value()->latest(); + auto latest = track->latest; if (latest) { XLOG(INFO) << "Latest={" << latest->group << ", " << latest->object << "}"; @@ -122,6 +151,8 @@ class MoQTextClient { XLOG(DBG1) << "FETCH start={" << range.start.group << "," << range.start.object << "} end={" << fetchEnd.group << "," << fetchEnd.object << "}"; + fetchTextHandler_ = std::make_shared( + ObjectReceiver::FETCH, &textHandler_); auto fetchTrack = co_await moqClient_.moqSession_->fetch( {SubscribeID(0), sub.fullTrackName, @@ -129,16 +160,13 @@ class MoQTextClient { sub.groupOrder, range.start, fetchEnd, - {}}); + {}}, + fetchTextHandler_); if (fetchTrack.hasError()) { XLOG(ERR) << "Fetch failed err=" << fetchTrack.error().errorCode << " reason=" << fetchTrack.error().reasonPhrase; } else { - XLOG(DBG1) << "subscribeID=" - << fetchTrack.value()->subscribeID(); - readTrack(std::move(fetchTrack.value())) - .scheduleOn(exec) - .start(); + XLOG(DBG1) << "subscribeID=" << fetchTrack.value(); } } } // else we started from current or no content - nothing to FETCH @@ -154,7 +182,6 @@ class MoQTextClient { sub.params}); } } - co_await readTrack(std::move(track.value())); } else { XLOG(INFO) << "SubscribeError id=" << track.error().subscribeID << " code=" << track.error().errorCode @@ -165,10 +192,13 @@ class MoQTextClient { XLOG(ERR) << ex.what(); co_return; } + co_await textHandler_.baton; XLOG(INFO) << __func__ << " done"; } void stop() { + textHandler_.baton.post(); + // TODO: maybe need fetchCancel + fetchTextHandler_.baton.post() moqClient_.moqSession_->unsubscribe({subscribeID_}); moqClient_.moqSession_->close(); } @@ -191,10 +221,6 @@ class MoQTextClient { {subscribeReq.subscribeID, 404, "don't care"}); } - void operator()(SubscribeDone) const override { - XLOG(INFO) << "SubscribeDone"; - } - void operator()(Goaway) const override { XLOG(INFO) << "Goaway"; client_.moqClient_.moqSession_->unsubscribe({client_.subscribeID_}); @@ -214,22 +240,12 @@ class MoQTextClient { } } - folly::coro::Task readTrack( - std::shared_ptr track) { - XLOG(INFO) << __func__; - auto g = - folly::makeGuard([func = __func__] { XLOG(INFO) << "exit " << func; }); - // TODO: check track.value()->getCancelToken() - while (auto obj = co_await track->objects().next()) { - auto payload = co_await obj.value()->payload(); - if (payload) { - std::cout << payload->moveToFbString() << std::endl; - } - } - } MoQClient moqClient_; FullTrackName fullTrackName_; SubscribeID subscribeID_{0}; + TextHandler textHandler_; + std::shared_ptr subTextHandler_; + std::shared_ptr fetchTextHandler_; }; } // namespace diff --git a/moxygen/test/MoQCodecTest.cpp b/moxygen/test/MoQCodecTest.cpp index da8f707..3edd145 100644 --- a/moxygen/test/MoQCodecTest.cpp +++ b/moxygen/test/MoQCodecTest.cpp @@ -62,15 +62,28 @@ TEST(MoQCodec, All) { TEST(MoQCodec, AllObject) { auto allMsgs = moxygen::test::writeAllObjectMessages(); - testing::NiceMock callback; + testing::StrictMock callback; MoQObjectStreamCodec codec(&callback); - EXPECT_CALL(callback, onObjectHeader(testing::_)).Times(2); + EXPECT_CALL( + callback, onSubgroup(testing::_, testing::_, testing::_, testing::_)); + EXPECT_CALL( + callback, + onObjectBegin( + testing::_, + testing::_, + testing::_, + testing::_, + testing::_, + true, + false)); EXPECT_CALL( callback, - onObjectPayload( - testing::_, testing::_, testing::_, testing::_, testing::_)) - .Times(1); + onObjectStatus( + testing::_, + testing::_, + testing::_, + ObjectStatus::END_OF_TRACK_AND_GROUP)); codec.onIngress(std::move(allMsgs), true); } @@ -129,11 +142,19 @@ TEST(MoQCodec, UnderflowObjects) { folly::IOBufQueue readBuf{folly::IOBufQueue::cacheChainLength()}; readBuf.append(std::move(allMsgs)); - EXPECT_CALL(callback, onObjectHeader(testing::_)).Times(2); + EXPECT_CALL( + callback, onSubgroup(testing::_, testing::_, testing::_, testing::_)); EXPECT_CALL( callback, - onObjectPayload( - testing::_, testing::_, testing::_, testing::_, testing::_)) + onObjectBegin( + testing::_, + testing::_, + testing::_, + testing::_, + testing::_, + testing::_, + testing::_)); + EXPECT_CALL(callback, onObjectPayload(testing::_, testing::_)) .Times(strlen("hello world")); while (!readBuf.empty()) { codec.onIngress(readBuf.split(1), false); @@ -141,6 +162,29 @@ TEST(MoQCodec, UnderflowObjects) { codec.onIngress(nullptr, true); } +TEST(MoQCodec, ObjectStreamPayloadFin) { + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + writeSingleObjectStream( + writeBuf, + {TrackAlias(1), + 2, + 3, + 4, + 5, + ForwardPreference::Subgroup, + ObjectStatus::NORMAL, + 11}, + folly::IOBuf::copyBuffer("hello world")); + testing::StrictMock callback; + MoQObjectStreamCodec codec(&callback); + + EXPECT_CALL(callback, onSubgroup(TrackAlias(1), 2, 3, 5)); + EXPECT_CALL( + callback, onObjectBegin(2, 3, 4, testing::_, testing::_, true, true)); + + codec.onIngress(writeBuf.move(), true); +} + TEST(MoQCodec, ObjectStreamPayload) { folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; writeSingleObjectStream( @@ -157,14 +201,12 @@ TEST(MoQCodec, ObjectStreamPayload) { testing::NiceMock callback; MoQObjectStreamCodec codec(&callback); - EXPECT_CALL(callback, onObjectHeader(testing::_)); + EXPECT_CALL(callback, onSubgroup(TrackAlias(1), 2, 3, 5)); EXPECT_CALL( - callback, - onObjectPayload( - testing::_, testing::_, testing::_, testing::_, testing::_)) - .Times(1); + callback, onObjectBegin(2, 3, 4, testing::_, testing::_, true, false)); codec.onIngress(writeBuf.move(), false); + EXPECT_CALL(callback, onEndOfStream()); codec.onIngress(std::unique_ptr(), true); } @@ -184,8 +226,10 @@ TEST(MoQCodec, EmptyObjectPayload) { testing::NiceMock callback; MoQObjectStreamCodec codec(&callback); - EXPECT_CALL(callback, onObjectHeader(testing::_)); - + EXPECT_CALL(callback, onSubgroup(TrackAlias(1), 2, 3, 5)); + EXPECT_CALL( + callback, onObjectStatus(2, 3, 4, ObjectStatus::OBJECT_NOT_EXIST)); + EXPECT_CALL(callback, onEndOfStream()); // extra coverage of underflow in header codec.onIngress(writeBuf.split(3), false); codec.onIngress(writeBuf.move(), false); @@ -221,11 +265,130 @@ TEST(MoQCodec, TruncatedObject) { testing::NiceMock callback; MoQObjectStreamCodec codec(&callback); - EXPECT_CALL(callback, onObjectHeader(testing::_)); + EXPECT_CALL( + callback, onSubgroup(testing::_, testing::_, testing::_, testing::_)); + EXPECT_CALL(callback, onConnectionError(testing::_)); codec.onIngress(writeBuf.move(), true); } +TEST(MoQCodec, TruncatedObjectPayload) { + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + auto res = writeStreamHeader( + writeBuf, + ObjectHeader({ + TrackAlias(1), + 2, + 3, + 4, + 5, + ForwardPreference::Subgroup, + ObjectStatus::NORMAL, + folly::none, + })); + res = writeObject( + writeBuf, + ObjectHeader( + {TrackAlias(1), + 2, + 3, + 4, + 5, + ForwardPreference::Subgroup, + ObjectStatus::NORMAL, + 11}), + nullptr); + testing::NiceMock callback; + MoQObjectStreamCodec codec(&callback); + + EXPECT_CALL( + callback, onSubgroup(testing::_, testing::_, testing::_, testing::_)); + + EXPECT_CALL( + callback, onObjectBegin(2, 3, 4, testing::_, testing::_, false, false)); + codec.onIngress(writeBuf.move(), false); + EXPECT_CALL(callback, onConnectionError(testing::_)); + writeBuf.append(std::string("hello")); + codec.onIngress(writeBuf.move(), true); +} + +TEST(MoQCodec, StreamTypeUnderflow) { + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + uint8_t big = 0xff; + writeBuf.append(&big, 1); + testing::NiceMock callback; + MoQObjectStreamCodec codec(&callback); + + EXPECT_CALL(callback, onConnectionError(ErrorCode::PARSE_UNDERFLOW)); + codec.onIngress(writeBuf.move(), true); +} + +TEST(MoQCodec, UnknownStreamType) { + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + uint8_t bad = 0x12; + writeBuf.append(&bad, 1); + testing::NiceMock callback; + MoQObjectStreamCodec codec(&callback); + + EXPECT_CALL(callback, onConnectionError(ErrorCode::PARSE_ERROR)); + codec.onIngress(writeBuf.move(), true); +} + +TEST(MoQCodec, Fetch) { + testing::StrictMock callback; + MoQObjectStreamCodec codec(&callback); + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + ObjectHeader obj{ + SubscribeID(1), + 2, + 3, + 4, + 5, + ForwardPreference::Fetch, + ObjectStatus::NORMAL, + folly::none, + }; + auto res = writeStreamHeader(writeBuf, ObjectHeader(obj)); + obj.length = 5; + res = writeObject(writeBuf, obj, folly::IOBuf::copyBuffer("hello")); + obj.id++; + obj.status = ObjectStatus::END_OF_TRACK_AND_GROUP; + obj.length = 0; + res = writeObject(writeBuf, obj, nullptr); + obj.id++; + obj.status = ObjectStatus::GROUP_NOT_EXIST; + obj.length = 0; + res = writeObject(writeBuf, obj, nullptr); + + EXPECT_CALL(callback, onFetchHeader(testing::_)); + EXPECT_CALL(callback, onObjectBegin(2, 3, 4, 5, testing::_, true, false)); + EXPECT_CALL( + callback, onObjectStatus(2, 3, 5, ObjectStatus::END_OF_TRACK_AND_GROUP)); + // object after terminal status + EXPECT_CALL(callback, onConnectionError(ErrorCode::PARSE_ERROR)); + codec.onIngress(writeBuf.move(), false); +} + +TEST(MoQCodec, FetchHeaderUnderflow) { + testing::StrictMock callback; + MoQObjectStreamCodec codec(&callback); + folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; + ObjectHeader obj{ + SubscribeID(0xffffffffffffff), + 2, + 3, + 4, + 5, + ForwardPreference::Fetch, + ObjectStatus::NORMAL, + folly::none, + }; + writeStreamHeader(writeBuf, ObjectHeader(obj)); + // only deliver first byte of fetch header + EXPECT_CALL(callback, onConnectionError(ErrorCode::PARSE_UNDERFLOW)); + codec.onIngress(writeBuf.splitAtMost(2), true); +} + TEST(MoQCodec, InvalidFrame) { folly::IOBufQueue writeBuf{folly::IOBufQueue::cacheChainLength()}; writeBuf.append(std::string(" ")); diff --git a/moxygen/test/MoQSessionTest.cpp b/moxygen/test/MoQSessionTest.cpp index fceb755..ebc6285 100644 --- a/moxygen/test/MoQSessionTest.cpp +++ b/moxygen/test/MoQSessionTest.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include using namespace moxygen; @@ -24,7 +25,6 @@ class MockControlVisitorBase { virtual ~MockControlVisitorBase() = default; virtual void onSubscribe(SubscribeRequest subscribeRequest) const = 0; virtual void onSubscribeUpdate(SubscribeUpdate subscribeUpdate) const = 0; - virtual void onSubscribeDone(SubscribeDone subscribeDone) const = 0; virtual void onUnsubscribe(Unsubscribe unsubscribe) const = 0; virtual void onFetch(Fetch fetch) const = 0; virtual void onAnnounce(Announce announce) const = 0; @@ -80,11 +80,6 @@ class MockControlVisitor : public MoQSession::ControlVisitor, onSubscribeUpdate(subscribeUpdate); } - MOCK_METHOD(void, onSubscribeDone, (SubscribeDone), (const)); - void operator()(SubscribeDone subscribeDone) const override { - onSubscribeDone(subscribeDone); - } - MOCK_METHOD(void, onUnsubscribe, (Unsubscribe), (const)); void operator()(Unsubscribe unsubscribe) const override { onUnsubscribe(unsubscribe); @@ -205,28 +200,34 @@ TEST_F(MoQSessionTest, Setup) { clientSession_->close(); } +MATCHER_P(HasChainDataLengthOf, n, "") { + return arg->computeChainDataLength() == n; +} + TEST_F(MoQSessionTest, Fetch) { setupMoQSession(); auto f = [](std::shared_ptr session) mutable -> folly::coro::Task { - auto handle = co_await session->fetch( + auto fetchCallback = + std::make_shared>(); + folly::coro::Baton baton; + EXPECT_CALL( + *fetchCallback, object(0, 0, 0, HasChainDataLengthOf(100), true)) + .WillOnce(testing::Invoke([&] { + baton.post(); + return folly::unit; + })); + auto res = co_await session->fetch( {SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 1}, - {}}); - EXPECT_TRUE(handle.hasValue()); - auto obj = co_await handle.value()->objects().next(); - EXPECT_NE(obj.value(), nullptr); - EXPECT_EQ( - *std::get_if(&obj.value()->header.trackIdentifier), - SubscribeID(0)); - auto payload = co_await obj.value()->payload(); - EXPECT_EQ(payload->computeChainDataLength(), 100); - obj = co_await handle.value()->objects().next(); - EXPECT_FALSE(obj.has_value()); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); + co_await baton; session->close(); }; EXPECT_CALL(serverControl, onFetch(testing::_)) @@ -258,15 +259,18 @@ TEST_F(MoQSessionTest, FetchCleanupFromStreamFin) { std::shared_ptr serverSession, std::shared_ptr& fetchPub) mutable -> folly::coro::Task { - auto handle = co_await session->fetch( + auto fetchCallback = + std::make_shared>(); + auto res = co_await session->fetch( {SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 1}, - {}}); - EXPECT_TRUE(handle.hasValue()); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); // publish here now we know FETCH_OK has been received at client XCHECK(fetchPub); fetchPub->object( @@ -275,16 +279,14 @@ TEST_F(MoQSessionTest, FetchCleanupFromStreamFin) { /*objectID=*/0, moxygen::test::makeBuf(100), /*finFetch=*/true); - - auto obj = co_await handle.value()->objects().next(); - EXPECT_NE(obj.value(), nullptr); - EXPECT_EQ( - *std::get_if(&obj.value()->header.trackIdentifier), - SubscribeID(0)); - auto payload = co_await obj.value()->payload(); - EXPECT_EQ(payload->computeChainDataLength(), 100); - obj = co_await handle.value()->objects().next(); - EXPECT_FALSE(obj.has_value()); + folly::coro::Baton baton; + EXPECT_CALL( + *fetchCallback, object(0, 0, 0, HasChainDataLengthOf(100), true)) + .WillOnce(testing::Invoke([&] { + baton.post(); + return folly::unit; + })); + co_await baton; session->close(); }; EXPECT_CALL(serverControl, onFetch(testing::_)) @@ -307,17 +309,20 @@ TEST_F(MoQSessionTest, FetchError) { setupMoQSession(); auto f = [](std::shared_ptr session) mutable -> folly::coro::Task { - auto handle = co_await session->fetch( + auto fetchCallback = + std::make_shared>(); + auto res = co_await session->fetch( {SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 1}, AbsoluteLocation{0, 0}, - {}}); - EXPECT_TRUE(handle.hasError()); + {}}, + fetchCallback); + EXPECT_TRUE(res.hasError()); EXPECT_EQ( - handle.error().errorCode, + res.error().errorCode, folly::to_underlying(FetchErrorCode::INVALID_RANGE)); session->close(); }; @@ -332,29 +337,38 @@ TEST_F(MoQSessionTest, FetchCancel) { std::shared_ptr serverSession, std::shared_ptr& fetchPub) mutable -> folly::coro::Task { - auto handle = co_await clientSession->fetch( - {SubscribeID(0), + auto fetchCallback = + std::make_shared>(); + SubscribeID subscribeID(0); + EXPECT_CALL( + *fetchCallback, object(0, 0, 0, HasChainDataLengthOf(100), false)) + .WillOnce(testing::Return(folly::unit)); + // TODO: fetchCancel removes the callback - should it also deliver a + // reset() call to the callback? + // EXPECT_CALL(*fetchCallback, reset(ResetStreamErrorCode::CANCELLED)); + auto res = co_await clientSession->fetch( + {subscribeID, FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 2}, - {}}); - EXPECT_TRUE(handle.hasValue()); - auto subscribeID = handle.value()->subscribeID(); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); clientSession->fetchCancel({subscribeID}); co_await folly::coro::co_reschedule_on_current_executor; co_await folly::coro::co_reschedule_on_current_executor; co_await folly::coro::co_reschedule_on_current_executor; XCHECK(fetchPub); - auto res = fetchPub->object( + auto res2 = fetchPub->object( /*groupID=*/0, /*subgroupID=*/0, /*objectID=*/1, moxygen::test::makeBuf(100), /*finFetch=*/true); // publish after fetchCancel fails - EXPECT_TRUE(res.hasError()); + EXPECT_TRUE(res2.hasError()); clientSession->close(); }; EXPECT_CALL(serverControl, onFetch(testing::_)) @@ -373,7 +387,7 @@ TEST_F(MoQSessionTest, FetchCancel) { /*subgroupID=*/0, fetch.start.object, moxygen::test::makeBuf(100), - true); + false); // published 1 object })); f(clientSession_, serverSession_, fetchPub).scheduleOn(&eventBase_).start(); @@ -384,16 +398,19 @@ TEST_F(MoQSessionTest, FetchEarlyCancel) { setupMoQSession(); auto f = [](std::shared_ptr clientSession) mutable -> folly::coro::Task { - auto handle = co_await clientSession->fetch( - {SubscribeID(0), + auto fetchCallback = + std::make_shared>(); + SubscribeID subscribeID(0); + auto res = co_await clientSession->fetch( + {subscribeID, FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 2}, - {}}); - EXPECT_TRUE(handle.hasValue()); - auto subscribeID = handle.value()->subscribeID(); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); // TODO: this no-ops right now so there's nothing to verify clientSession->fetchCancel({subscribeID}); clientSession->close(); @@ -418,19 +435,32 @@ TEST_F(MoQSessionTest, FetchBadLength) { setupMoQSession(); auto f = [](std::shared_ptr session) mutable -> folly::coro::Task { - auto handle = co_await session->fetch( + auto fetchCallback = + std::make_shared>(); + auto res = co_await session->fetch( {SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, 0, GroupOrder::OldestFirst, AbsoluteLocation{0, 0}, AbsoluteLocation{0, 1}, - {}}); - EXPECT_TRUE(handle.hasValue()); + {}}, + fetchCallback); + EXPECT_FALSE(res.hasError()); // FETCH_OK comes but the FETCH stream is reset and we timeout waiting // for a new object. + auto contract = folly::coro::makePromiseContract(); + ON_CALL( + *fetchCallback, + beginObject(testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillByDefault([&] { + contract.first.setValue(); + return folly::Expected(folly::unit); + }); EXPECT_THROW( - co_await handle.value()->objects().next(), folly::FutureTimeout); + co_await folly::coro::timeout( + std::move(contract.second), std::chrono::milliseconds(100)), + folly::FutureTimeout); session->close(); }; EXPECT_CALL(serverControl, onFetch(testing::_)) @@ -461,6 +491,12 @@ TEST_F(MoQSessionTest, FetchOverLimit) { setupMoQSession(); auto f = [](std::shared_ptr session) mutable -> folly::coro::Task { + auto fetchCallback1 = + std::make_shared>(); + auto fetchCallback2 = + std::make_shared>(); + auto fetchCallback3 = + std::make_shared>(); Fetch fetch{ SubscribeID(0), FullTrackName{TrackNamespace{{"foo"}}, "bar"}, @@ -469,10 +505,10 @@ TEST_F(MoQSessionTest, FetchOverLimit) { AbsoluteLocation{0, 0}, AbsoluteLocation{0, 1}, {}}; - auto handle = co_await session->fetch(fetch); - handle = co_await session->fetch(fetch); - handle = co_await session->fetch(fetch); - EXPECT_TRUE(handle.hasError()); + auto res = co_await session->fetch(fetch, fetchCallback1); + res = co_await session->fetch(fetch, fetchCallback2); + res = co_await session->fetch(fetch, fetchCallback3); + EXPECT_TRUE(res.hasError()); }; EXPECT_CALL(serverControl, onFetch(testing::_)) .WillOnce(testing::Invoke([this](Fetch fetch) { @@ -615,7 +651,13 @@ TEST_F(MoQSessionTest, MaxSubscribeID) { folly::none, folly::none, {}}; - auto res = co_await clientSession->subscribe(sub); + auto trackPublisher1 = + std::make_shared>(); + auto trackPublisher2 = + std::make_shared>(); + auto trackPublisher3 = + std::make_shared>(); + auto res = co_await clientSession->subscribe(sub, trackPublisher1); co_await folly::coro::co_reschedule_on_current_executor; // This is true because initial is 2 in this test case and we grant credit // every 50%. @@ -623,15 +665,21 @@ TEST_F(MoQSessionTest, MaxSubscribeID) { EXPECT_EQ(serverSession->maxSubscribeID(), expectedSubId); // subscribe again but this time we get a DONE - res = co_await clientSession->subscribe(sub); + EXPECT_CALL(*trackPublisher2, subscribeDone(testing::_)) + .WillOnce(testing::Return(folly::unit)); + res = co_await clientSession->subscribe(sub, trackPublisher2); co_await folly::coro::co_reschedule_on_current_executor; expectedSubId++; EXPECT_EQ(serverSession->maxSubscribeID(), expectedSubId); - // subscribe three more times, last one should fail - res = co_await clientSession->subscribe(sub); - res = co_await clientSession->subscribe(sub); - res = co_await clientSession->subscribe(sub); + // subscribe three more times, last one should fail, the first two will get + // subscribeDone via the session closure + EXPECT_CALL(*trackPublisher3, subscribeDone(testing::_)) + .WillOnce(testing::Return(folly::unit)) + .WillOnce(testing::Return(folly::unit)); + res = co_await clientSession->subscribe(sub, trackPublisher3); + res = co_await clientSession->subscribe(sub, trackPublisher3); + res = co_await clientSession->subscribe(sub, trackPublisher3); EXPECT_TRUE(res.hasError()); }(clientSession_, serverSession_) .scheduleOn(&eventBase_) @@ -671,6 +719,5 @@ TEST_F(MoQSessionTest, MaxSubscribeID) { {}}); })); - EXPECT_CALL(clientControl, onSubscribeDone(testing::_)); eventBase_.loop(); } diff --git a/moxygen/test/Mocks.h b/moxygen/test/Mocks.h index fc0b464..9918284 100644 --- a/moxygen/test/Mocks.h +++ b/moxygen/test/Mocks.h @@ -14,16 +14,6 @@ class MockMoQCodecCallback : public MoQControlCodec::ControlCallback, MOCK_METHOD(void, onFrame, (FrameType /*frameType*/)); MOCK_METHOD(void, onClientSetup, (ClientSetup clientSetup)); MOCK_METHOD(void, onServerSetup, (ServerSetup serverSetup)); - MOCK_METHOD(void, onObjectHeader, (ObjectHeader objectHeader)); - MOCK_METHOD( - void, - onObjectPayload, - (TrackIdentifier trackIdentifier, - uint64_t groupID, - uint64_t id, - std::unique_ptr payload, - bool eom)); - MOCK_METHOD(void, onFetchHeader, (uint64_t subscribeID)); MOCK_METHOD(void, onSubscribe, (SubscribeRequest subscribeRequest)); MOCK_METHOD(void, onSubscribeUpdate, (SubscribeUpdate subscribeUpdate)); MOCK_METHOD(void, onSubscribeOk, (SubscribeOk subscribeOk)); @@ -57,6 +47,19 @@ class MockMoQCodecCallback : public MoQControlCodec::ControlCallback, MOCK_METHOD(void, onTrackStatus, (TrackStatus trackStatus)); MOCK_METHOD(void, onGoaway, (Goaway goaway)); MOCK_METHOD(void, onConnectionError, (ErrorCode error)); + + MOCK_METHOD(void, onFetchHeader, (SubscribeID)); + MOCK_METHOD(void, onSubgroup, (TrackAlias, uint64_t, uint64_t, uint8_t)); + MOCK_METHOD( + void, + onObjectBegin, + (uint64_t, uint64_t, uint64_t, uint64_t, Payload, bool, bool)); + MOCK_METHOD( + void, + onObjectStatus, + (uint64_t, uint64_t, uint64_t, ObjectStatus)); + MOCK_METHOD(void, onObjectPayload, (Payload, bool)); + MOCK_METHOD(void, onEndOfStream, ()); }; class MockTrackConsumer : public TrackConsumer {