diff --git a/include/unifex/at_coroutine_exit.hpp b/include/unifex/at_coroutine_exit.hpp index 47bacf42c..102e94dcf 100644 --- a/include/unifex/at_coroutine_exit.hpp +++ b/include/unifex/at_coroutine_exit.hpp @@ -318,6 +318,8 @@ namespace _at_coroutine_exit { } at_coroutine_exit{}; } // namespace _at_coroutine_exit +// TODO: verify that `at_coroutine_exit()` can't be used to break scheduler +// affinity by running an async task that reschedules using _at_coroutine_exit::at_coroutine_exit; } // namespace unifex diff --git a/include/unifex/task.hpp b/include/unifex/task.hpp index cc2268ca4..13cf08f72 100644 --- a/include/unifex/task.hpp +++ b/include/unifex/task.hpp @@ -23,17 +23,21 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include #include +#include #include #include #include +#include #if UNIFEX_NO_COROUTINES # error "Coroutine support is required to use this header" @@ -47,6 +51,9 @@ namespace unifex { namespace _task { using namespace _util; +/** + * An RAII owner of a coroutine_handle<>. + */ struct coro_holder { explicit coro_holder(coro::coroutine_handle<> h) noexcept : coro_(std::move(h)) {} @@ -69,9 +76,13 @@ struct coro_holder { coro::coroutine_handle<> coro_; }; +/** + * An RAII owner of a coroutine_handle<> that steals a low bit for an extra + * flag. + */ struct tagged_coro_holder { explicit tagged_coro_holder(coro::coroutine_handle<> h) noexcept - : coro_((std::uintptr_t)h.address()) { + : coro_(reinterpret_cast(h.address())) { UNIFEX_ASSERT(coro_); } @@ -96,40 +107,58 @@ struct tagged_coro_holder { template struct _task { + /** + * The "public facing" task<> type. + */ struct [[nodiscard]] type; }; -struct _promise_base { - struct _final_suspend_awaiter_base { - bool await_ready() noexcept { return false; } - void await_resume() noexcept {} - - friend constexpr auto tag_invoke( - tag_t, const _final_suspend_awaiter_base&) noexcept { - return blocking_kind::always_inline; - } - }; +template +struct _sa_task final { + /** + * A "scheduler-affine" task that's used as an implementation detail to mark a + * given task<> as running with the scheduler-affinity invariant maintained by + * the caller. + */ + struct [[nodiscard]] type; +}; - void transform_schedule_sender_impl_(any_scheduler newSched); +template +struct _sr_thunk_task final { + /** + * A special coroutine type that gets interposed between a task<> and its + * caller to guarantee that stop requests are delivered to the task<> on the + * task<>'s scheduler. + * + * "sr thunk" refers to "stop request thunk". + */ + struct [[nodiscard]] type; +}; - coro::suspend_always initial_suspend() noexcept { return {}; } +/** + * A base class for both task<> and sr_thunk_task<>'s promises' final-suspend + * awaitable. + */ +struct _final_suspend_awaiter_base { + bool await_ready() noexcept { return false; } - coro::coroutine_handle<> unhandled_done() noexcept { - return continuation_.done(); - } + void await_resume() noexcept {} -#ifdef UNIFEX_ENABLE_CONTINUATION_VISITATIONS - template - friend void - tag_invoke(tag_t, const _promise_base& p, Func&& func) { - visit_continuations(p.continuation_, (Func &&) func); + // TODO: we need to address always-inline awaitables + friend constexpr auto tag_invoke( + tag_t, const _final_suspend_awaiter_base&) noexcept { + return blocking_kind::always_inline; } -#endif +}; - friend inplace_stop_token - tag_invoke(tag_t, const _promise_base& p) noexcept { - return p.stoken_; - } +/** + * Common behaviour and data for task<> and sr_thunk_task<>'s promise types. + */ +struct _promise_base { + /** + * Our coroutine types are lazy so initial_suspend() returns suspend_always. + */ + coro::suspend_always initial_suspend() noexcept { return {}; } friend any_scheduler tag_invoke(tag_t, const _promise_base& p) noexcept { @@ -140,54 +169,125 @@ struct _promise_base { const tag_t&, _promise_base& p, continuation_handle<> action) noexcept { - return std::exchange(p.continuation_, (continuation_handle<> &&) action); + return std::exchange(p.continuation_, std::move(action)); } +#ifdef UNIFEX_ENABLE_CONTINUATION_VISITATIONS + template + friend void + tag_invoke(tag_t, const _promise_base& p, Func&& func) { + visit_continuations(p.continuation_, static_cast(func)); + } +#endif + inline static constexpr inline_scheduler _default_scheduler{}; + // the coroutine awaiting our completion continuation_handle<> continuation_; - inplace_stop_token stoken_; + // the scheduler we run on any_scheduler sched_{_default_scheduler}; + // a stop token from our receiver, possibly adapted through an adapter + inplace_stop_token stoken_; +}; + +/** + * The parts of a task's promise that don't depend on T. + */ +struct _task_promise_base : _promise_base { + // the implementation of the magic of co_await schedule(s); this is to be + // ripped out and replaced with something more explicit + void transform_schedule_sender_impl_(any_scheduler newSched); + + coro::coroutine_handle<> unhandled_done() noexcept { + return continuation_.done(); + } + + void register_stop_callback() noexcept {} + + friend inplace_stop_token + tag_invoke(tag_t, const _task_promise_base& p) noexcept { + return p.stoken_; + } + + // has this task<> been rescheduled onto a new scheduler? bool rescheduled_ = false; }; template -struct _return_value_or_void { +struct _result_and_unhandled_exception final { + /** + * Storage for a task or sr_thunk_task's result, plus handling for + * unhandled exceptions. + * + * This is used to share the implementation of result handling between + * type-specific promise types. + */ struct type { - template(typename Value = T)( - requires convertible_to AND constructible_from< - T, - Value>) void return_value(Value&& - value) noexcept(std:: - is_nothrow_constructible_v< - T, - Value>) { + void unhandled_exception() noexcept { expected_.reset_value(); - unifex::activate_union_member(expected_.value_, (Value &&) value); - expected_.state_ = _state::value; + unifex::activate_union_member( + expected_.exception_, std::current_exception()); + expected_.state_ = _state::exception; } + + decltype(auto) result() { + if (expected_.state_ == _state::exception) { + std::rethrow_exception(std::move(expected_.exception_).get()); + } + return std::move(expected_.value_).get(); + } + _expected expected_; }; }; +template +struct _return_value_or_void { + /** + * Provides a type-specific return_value() method to meet a promise type's + * requirements. + */ + struct type : _result_and_unhandled_exception::type { + template(typename Value = T) // + (requires convertible_to AND constructible_from) // + void return_value(Value&& value) noexcept( + std::is_nothrow_constructible_v) { + this->expected_.reset_value(); + unifex::activate_union_member( + this->expected_.value_, static_cast(value)); + this->expected_.state_ = _state::value; + } + }; +}; + template <> struct _return_value_or_void { - struct type { + /** + * Provides a return_void() method to meet a promise type's requirements. + */ + struct type : _result_and_unhandled_exception::type { void return_void() noexcept { expected_.reset_value(); unifex::activate_union_member(expected_.value_); expected_.state_ = _state::value; } - _expected expected_; }; }; +/** + * A marker type that task<> inherits from. I'd like to deprecate and remove + * this but, Hyrum's law. + */ struct _task_base {}; template -struct _promise { - struct type - : _promise_base +struct _promise final { + /** + * The promise_type for task<>; inherits _task_promise_base for common + * funcitonality, and _return_value_or_void for result handling. + */ + struct type final + : _task_promise_base , _return_value_or_void::type { using result_type = T; @@ -197,7 +297,7 @@ struct _promise { } auto final_suspend() noexcept { - struct awaiter : _final_suspend_awaiter_base { + struct awaiter final : _final_suspend_awaiter_base { #if (defined(_MSC_VER) && !defined(__clang__)) || defined(__EMSCRIPTEN__) // MSVC doesn't seem to like symmetric transfer in this final awaiter // and the Emscripten (WebAssembly) compiler doesn't support tail-calls @@ -213,29 +313,28 @@ struct _promise { return awaiter{}; } - void unhandled_exception() noexcept { - this->expected_.reset_value(); - unifex::activate_union_member( - this->expected_.exception_, std::current_exception()); - this->expected_.state_ = _state::exception; - } - template decltype(auto) await_transform(Value&& value) { - if constexpr (derived_from, _task_base>) { - // We are co_await-ing a unifex::task, which completes inline because of - // task scheduler affinity. We don't need an additional transition. - return unifex::await_transform(*this, (Value &&) value); + if constexpr (is_sender_for_v, schedule>) { + // TODO: rip this out and replace it with something more explicit + + // If we are co_await'ing a sender that is the result of calling + // schedule, do something special + return transform_schedule_sender_(static_cast(value)); + } else if constexpr (unifex::sender) { + return unifex::await_transform( + *this, + with_scheduler_affinity(static_cast(value), this->sched_)); } else if constexpr ( tag_invocable, type&, Value> || detail::_awaitable) { // Either await_transform has been customized or Value is an awaitable. // Either way, we can dispatch to the await_transform CPO, then insert a // transition back to the correct execution context if necessary. - return transform_awaitable_( - unifex::await_transform(*this, (Value &&) value)); - } else if constexpr (unifex::sender) { - return transform_sender_((Value &&) value); + return with_scheduler_affinity( + *this, + unifex::await_transform(*this, static_cast(value)), + this->sched_); } else { // Otherwise, we don't know how to await this type. Just return it and // let the compiler issue a diagnostic. @@ -243,41 +342,6 @@ struct _promise { } } - template - decltype(auto) transform_awaitable_(Awaitable&& awaitable) { - using blocking_t = decltype(blocking(awaitable)); - - if constexpr ( - !same_as && - (blocking_kind::always_inline == blocking_t{})) { - return Awaitable{(Awaitable &&) awaitable}; - } else { - return unifex::await_transform( - *this, - finally( - as_sender((Awaitable &&) awaitable), - unstoppable(schedule(this->sched_)))); - } - } - - template - decltype(auto) transform_sender_(Sender&& sndr) { - if constexpr (sender_traits< - remove_cvref_t>::is_always_scheduler_affine) { - return unifex::await_transform(*this, (Sender &&) sndr); - } else if constexpr (is_sender_for_v, schedule>) { - // If we are co_await'ing a sender that is the result of calling - // schedule, do something special - return transform_schedule_sender_((Sender &&) sndr); - } else { - // Otherwise, append a transition to the correct execution context and - // wrap the result in an awaiter: - return unifex::await_transform( - *this, - finally((Sender &&) sndr, unstoppable(schedule(this->sched_)))); - } - } - // co_await schedule(sched) is magical. It does the following: // - transitions execution context // - updates the coroutine's current scheduler @@ -295,18 +359,165 @@ struct _promise { // Return the inner sender, appropriately wrapped in an awaitable: return unifex::await_transform(*this, std::move(snd).base()); } + }; +}; - decltype(auto) result() { - if (this->expected_.state_ == _state::exception) { - std::rethrow_exception(std::move(this->expected_.exception_).get()); +struct _sr_thunk_promise_base : _promise_base { + coro::coroutine_handle<> unhandled_done() noexcept { + callback_.destruct(); + + if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + return continuation_.done(); + } else { + return coro::noop_coroutine(); + } + } + + friend inplace_stop_token + tag_invoke(tag_t, const _sr_thunk_promise_base& p) noexcept { + return p.stopSource_.get_token(); + } + + mutable inplace_stop_source stopSource_; + + struct deferred_stop_request final { + _sr_thunk_promise_base* self; + + auto operator()() noexcept -> decltype(unstoppable(on( + self->sched_, + just(&self->stopSource_) | then(&inplace_stop_source::request_stop)))) { + return unstoppable(on( + self->sched_, + just(&self->stopSource_) | then(&inplace_stop_source::request_stop))); + } + }; + + using sender_t = + decltype(unifex::defer(UNIFEX_DECLVAL(deferred_stop_request))); + + struct receiver_t { + _sr_thunk_promise_base* self; + + void set_value(bool) noexcept { + if (self->refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + self->continuation_.handle().resume(); } - return std::move(this->expected_.value_).get(); + } + void set_error(std::exception_ptr) noexcept { std::terminate(); } + void set_done() noexcept { std::terminate(); } + }; + + using op_t = connect_result_t; + + op_t stopOperation_{ + connect(unifex::defer(deferred_stop_request{this}), receiver_t{this})}; + + struct stop_callback { + _sr_thunk_promise_base* self; + + void operator()() noexcept { + if (self->refCount_.fetch_add(1, std::memory_order_relaxed) == 0) { + return; + } + + unifex::start(self->stopOperation_); + } + }; + + using stop_callback_t = + typename inplace_stop_token::callback_type; + + manual_lifetime callback_; + + std::atomic refCount_{1}; + + void register_stop_callback() noexcept { + callback_.construct(stoken_, stop_callback{this}); + } +}; + +template +struct _sr_thunk_promise final { + /** + * The promise_type for an sr_thunk_task. + * + * This type has two main responsibilities: + * - register a stop callback on our receiver's stop token that, when + * invoked, executes an async operation that forwards the request to the + * nested stop source on the correct scheduler; and + * - ensure that, if the async stop request is ever started, we wait for + * *both* the async stop request *and* the nested operation to complete + * before continuing our continuation. + * + * The async stop request delivery is handled in _sr_thunk_promise_base (our + * base class), and our final-awaiter handles coordinating who continues our + * continuation. + */ + struct type final + : _sr_thunk_promise_base + , _return_value_or_void::type { + using result_type = T; + + typename _sr_thunk_task::type get_return_object() noexcept { + return typename _sr_thunk_task::type{ + coro::coroutine_handle::from_promise(*this)}; + } + + auto final_suspend() noexcept { + struct awaiter final : _final_suspend_awaiter_base { +#if (defined(_MSC_VER) && !defined(__clang__)) || defined(__EMSCRIPTEN__) + // MSVC doesn't seem to like symmetric transfer in this final awaiter + // and the Emscripten (WebAssembly) compiler doesn't support tail-calls + void await_suspend(coro::coroutine_handle h) noexcept { + auto& p = h.promise(); + + p.callback_.destruct(); + + // if we're last to complete, continue our continuation; otherwise do + // nothing and wait for the async stop request to do it + if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + return h.promise().continuation_.handle().resume(); + } + // nothing + } +#else + coro::coroutine_handle<> + await_suspend(coro::coroutine_handle h) noexcept { + auto& p = h.promise(); + + p.callback_.destruct(); + + // if we're last to complete, continue our continuation; otherwise do + // nothing and wait for the async stop request to do it + if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + return h.promise().continuation_.handle(); + } else { + return coro::noop_coroutine(); + } + } +#endif + }; + + return awaiter{}; + } + + template + decltype(auto) await_transform(Value&& value) { + return unifex::await_transform(*this, static_cast(value)); } }; }; template -struct _awaiter { +struct _awaiter final { + /** + * An awaitable type that knows how to await a task<>, sa_task<>, or + * sr_thunk_task<> from a coroutine whose promise_type is OtherPromise. + * + * We inherit tagged_coro_holder to be able to distinguish whether or not the + * awaited task has been started, and thus whether we need to clean it up in + * our destructor. + */ struct type : tagged_coro_holder { using result_type = typename ThisPromise::result_type; @@ -349,6 +560,9 @@ struct _awaiter { } else { promise.stoken_ = get_stop_token(h.promise()); } + + promise.register_stop_callback(); + return thisCoro; } @@ -391,6 +605,40 @@ struct _awaiter { }; }; +/** + * The coroutine type that ensures stop requests are delivered to nested task<>s + * on the right scheduler. + */ +template +struct _sr_thunk_task::type final : coro_holder { + using promise_type = typename _sr_thunk_promise::type; + friend promise_type; + +private: + template + using awaiter = typename _awaiter::type; + + explicit type(coro::coroutine_handle h) noexcept + : coro_holder(h) {} + + template + friend awaiter + tag_invoke(tag_t, Promise&, type&& t) noexcept { + return awaiter{std::exchange(t.coro_, {})}; + } +}; + +/** + * Await the given sa_task<> in a context that will deliver stop requests from + * the receiver on the expected scheduler. + */ +template +typename _sr_thunk_task::type +inject_stop_request_thunk(typename _sa_task::type awaitable) { + // I wonder if we could do better than hopping through this extra coroutine + co_return co_await std::move(awaitable); +} + template struct _task::type : _task_base @@ -431,12 +679,65 @@ struct _task::type } private: - template - using awaiter = typename _awaiter::type; - explicit type(coro::coroutine_handle h) noexcept : coro_holder(h) {} + template + friend auto tag_invoke(tag_t, Promise& p, type&& t) { + // we don't know whether our consumer will enforce the scheduler-affinity + // invariants so we need to ensure that stop requests are delivered on the + // right scheduler + return unifex::await_transform( + p, inject_stop_request_thunk(std::move(t))); + } + + template + friend auto tag_invoke(tag_t, type&& t, Receiver&& r) { + using stoken_t = stop_token_type_t; + + if constexpr (is_stop_never_possible_v) { + // NOTE: we *don't* need to worry about stop requests if the receiver's + // stop token can't make such requests! + using sa_task = typename _sa_task::type; + + return connect(sa_task{std::move(t)}, static_cast(r)); + } else { + // connect_awaitable will get the awaitable to connect by invoking + // await_transform so we can guarantee stop requests are delivered + // on the right scheduler by relying on await_transform to do that + return connect_awaitable(std::move(t), static_cast(r)); + } + } + + template + friend typename _sa_task::type tag_invoke( + tag_t, type&& task, Scheduler&&) noexcept { + return {std::move(task)}; + } +}; + +/** + * A "sheduler-affine" task<>; an sa_task<> is the same as a task<> except that + * it expects its consumer to maintain the scheduler-affinity invariant and so + * it can avoid the overhead required to establish the invariant itself. + * + * The main difference is that await_transform doesn't indirect through + * inject_stop_request_thunk. + */ +template +struct _sa_task::type final : public _task::type { + using base = typename _task::type; + + type(base&& t) noexcept : base(std::move(t)) {} + + template + using awaiter = + typename _awaiter::type; + + // given that we're awaited in a scheduler-affine context, we are ourselves + // scheduler-affine + static constexpr bool is_always_scheduler_affine = true; + template friend awaiter tag_invoke(tag_t, Promise&, type&& t) noexcept { @@ -444,8 +745,10 @@ struct _task::type } template - friend auto tag_invoke(tag_t, type&& t, Receiver&& r) { - return unifex::connect_awaitable((type &&) t, (Receiver &&) r); + friend auto + tag_invoke(tag_t, type&& t, Receiver&& r) noexcept( + noexcept(connect_awaitable(std::move(t), static_cast(r)))) { + return connect_awaitable(std::move(t), static_cast(r)); } }; diff --git a/include/unifex/with_scheduler_affinity.hpp b/include/unifex/with_scheduler_affinity.hpp new file mode 100644 index 000000000..63de21ce6 --- /dev/null +++ b/include/unifex/with_scheduler_affinity.hpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace unifex { +namespace _wsa { + +template +struct _wsa_sender_wrapper final { + class type; +}; + +template +static auto +_make_sender(Sender&& sender, Scheduler&& scheduler) noexcept(noexcept(finally( + static_cast(sender), + unstoppable(schedule(static_cast(scheduler)))))) { + return finally( + static_cast(sender), + unstoppable(schedule(static_cast(scheduler)))); +} + +template +using wsa_sender_wrapper = + typename _wsa_sender_wrapper::type; + +template +class _wsa_sender_wrapper::type final { + using sender_t = + decltype(_make_sender(UNIFEX_DECLVAL(Sender), UNIFEX_DECLVAL(Scheduler))); + + sender_t sender_; + +public: + template < + template + typename Variant, + template + typename Tuple> + using value_types = sender_value_types_t; + + template