diff --git a/.bazelrc b/.bazelrc index 8b9d0d7d..bfda4099 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1 +1 @@ -build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --incompatible_java_common_parameters=false --define=android_dexmerger_tool=d8_dexmerger --define=android_incremental_dexing_tool=d8_dexbuilder --nouse_workers_with_dexbuilder +build --cxxopt=-std=c++17 --cxxopt=-fcoroutines-ts --host_cxxopt=-std=c++17 --host_cxxopt=-fcoroutines-ts --incompatible_java_common_parameters=false --define=android_dexmerger_tool=d8_dexmerger --define=android_incremental_dexing_tool=d8_dexbuilder --nouse_workers_with_dexbuilder diff --git a/support-lib/cpp/Future.hpp b/support-lib/cpp/Future.hpp index f939aecb..71db25ab 100644 --- a/support-lib/cpp/Future.hpp +++ b/support-lib/cpp/Future.hpp @@ -370,7 +370,6 @@ class Future { return true; } - template struct PromiseTypeBase { Promise _promise; std::optional> _result{}; @@ -379,7 +378,9 @@ class Future { constexpr bool await_ready() const noexcept { return false; } - bool await_suspend(detail::CoroutineHandle finished) const noexcept { + template + bool await_suspend(detail::CoroutineHandle

finished) const noexcept { + static_assert(std::is_convertible_v); auto& promise_type = finished.promise(); if (*promise_type._result) { if constexpr (std::is_void_v) { @@ -406,7 +407,7 @@ class Future { } }; - struct PromiseType: PromiseTypeBase{ + struct PromiseType: PromiseTypeBase { template >> void return_value(V&& value) { this->_result.emplace(std::forward(value)); @@ -424,7 +425,7 @@ class Future { #if defined(DJINNI_FUTURE_HAS_COROUTINE_SUPPORT) template<> -struct Future::PromiseType : PromiseTypeBase { +struct Future::PromiseType : PromiseTypeBase { void return_void() { _result.emplace(); } diff --git a/support-lib/cpp/SharedFuture.hpp b/support-lib/cpp/SharedFuture.hpp new file mode 100644 index 00000000..938118ae --- /dev/null +++ b/support-lib/cpp/SharedFuture.hpp @@ -0,0 +1,162 @@ +/** + * Copyright 2021 Snap, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "Future.hpp" + +#if defined(DJINNI_FUTURE_HAS_COROUTINE_SUPPORT) + +#include +#include +#include +#include +#include + +namespace djinni { + +// SharedFuture is a wrapper around djinni::Future to allow multiple consumers (i.e. like std::shared_future) +// The API is designed to be similar to djinni::Future. +template +class SharedFuture { +public: + // Create SharedFuture from Future. Runtime error if the future is already consumed. + explicit SharedFuture(Future&& future); + + // Transform into Future. + Future toFuture() const { + if (await_ready()) { + co_return await_resume(); // return stored value directly + } + co_return co_await SharedFuture(*this); // retain copy during coroutine suspension + } + + void wait() const { + waitIgnoringExceptions().wait(); + } + + decltype(auto) get() const { + wait(); + return await_resume(); + } + + template + using ResultT = std::remove_cv_t&>>>; + + // Transform the result of this future into a new future. The behavior is same as Future::then except that + // it doesn't consume the future, and can be called multiple times. + template + Future> then(Func transform) const { + auto cpy = SharedFuture(*this); // retain copy during coroutine suspension + co_await cpy.waitIgnoringExceptions(); + co_return transform(cpy); + } + + // Same as above but returns SharedFuture. + template + SharedFuture> thenShared(Func transform) const { + return SharedFuture>(then(std::move(transform))); + } + + // -- coroutine support implementation only; not intended externally -- + + bool await_ready() const { + std::scoped_lock lock(_sharedStates->mutex); + return _sharedStates->storedValue.has_value(); + } + + decltype(auto) await_resume() const { + if (!*_sharedStates->storedValue) { + std::rethrow_exception(_sharedStates->storedValue->error()); + } + if constexpr (!std::is_void_v) { + return const_cast(_sharedStates->storedValue->value()); + } + } + + bool await_suspend(detail::CoroutineHandle<> h) const; + + struct Promise : public Future::promise_type { + SharedFuture get_return_object() noexcept { + return SharedFuture(Future::promise_type::get_return_object()); + } + }; + using promise_type = Promise; + +private: + Future waitIgnoringExceptions() const { + try { + co_await *this; + } catch (...) { + // Ignore exceptions. + } + } + + struct SharedStates { + std::recursive_mutex mutex; + std::optional> storedValue = std::nullopt; + std::vector> coroutineHandles; + }; + // Use a shared_ptr to allow copying SharedFuture. + std::shared_ptr _sharedStates = std::make_shared(); +}; + +// CTAD deduction guide to construct from Future directly. +template +SharedFuture(Future&&) -> SharedFuture; + +// ------------------ Implementation ------------------ + +template +SharedFuture::SharedFuture(Future&& future) { + // `future` will invoke all continuations when it is ready. + future.then([sharedStates = _sharedStates](auto futureResult) { + std::vector toCall = [&] { + std::scoped_lock lock(sharedStates->mutex); + try { + if constexpr (std::is_void_v) { + futureResult.get(); + sharedStates->storedValue.emplace(); + } else { + sharedStates->storedValue = futureResult.get(); + } + } catch (...) { + sharedStates->storedValue = make_unexpected(std::current_exception()); + } + return std::move(sharedStates->coroutineHandles); + }(); + for (auto& handle : toCall) { + handle(); + } + }); +} + +template +bool SharedFuture::await_suspend(detail::CoroutineHandle<> h) const { + { + std::unique_lock lock(_sharedStates->mutex); + if (!_sharedStates->storedValue) { + _sharedStates->coroutineHandles.push_back(std::move(h)); + return true; + } + } + h(); + return true; +} + +} // namespace djinni + +#endif diff --git a/test-suite/BUILD b/test-suite/BUILD index fa6cb8dd..d13f2f83 100644 --- a/test-suite/BUILD +++ b/test-suite/BUILD @@ -55,6 +55,7 @@ objc_library( copts = [ "-ObjC++", "-std=c++17", + "-fcoroutines-ts" ], srcs = glob([ "generated-src/objc/**/*.mm", diff --git a/test-suite/handwritten-src/objc/tests/DBSharedFutureTest.mm b/test-suite/handwritten-src/objc/tests/DBSharedFutureTest.mm new file mode 100644 index 00000000..c0b5ae2f --- /dev/null +++ b/test-suite/handwritten-src/objc/tests/DBSharedFutureTest.mm @@ -0,0 +1,91 @@ +#import +#import + +#include "../../../support-lib/cpp/SharedFuture.hpp" + +@interface DBSharedFutureTest : XCTestCase +@end + +@implementation DBSharedFutureTest + +#ifdef DJINNI_FUTURE_HAS_COROUTINE_SUPPORT + +- (void)setUp +{ + [super setUp]; +} + +- (void)tearDown +{ + [super tearDown]; +} + +- (void)testCreateFuture +{ + djinni::SharedFuture resolvedInt(djinni::Promise::resolve(42)); + XCTAssertEqual(resolvedInt.get(), 42); + + djinni::Promise strPromise; + djinni::SharedFuture futureString(strPromise.getFuture()); + + strPromise.setValue(@"foo"); + XCTAssertEqualObjects(futureString.get(), @"foo"); +} + +- (void)testThen +{ + djinni::Promise intPromise; + djinni::SharedFuture futureInt(intPromise.getFuture()); + + auto transformedInt = futureInt.thenShared([](const auto& resolved) { return 2 * resolved.get(); }); + + intPromise.setValue(42); + XCTAssertEqual(transformedInt.get(), 84); + + // Also verify multiple consumers and chaining. + auto transformedString = futureInt.thenShared([](const auto& resolved) { return std::to_string(resolved.get()); }); + auto futurePlusOneTimesTwo = futureInt.then([](auto resolved) { return resolved.get() + 1; }).then([](auto resolved) { + return 2 * resolved.get(); + }); + auto futureStringLen = transformedString.then([](auto resolved) { return resolved.get().length(); }); + + XCTAssertEqual(transformedString.get(), std::string("42")); + XCTAssertEqual(futurePlusOneTimesTwo.get(), (42 + 1) * 2); + XCTAssertEqual(futureStringLen.get(), 2); + + XCTAssertEqual(futureInt.get(), 42); + + auto voidFuture = transformedString.thenShared([](auto) {}); + voidFuture.wait(); + + auto intFuture2 = voidFuture.thenShared([](auto) { return 43; }); + XCTAssertEqual(intFuture2.get(), 43); +} + +- (void)testException +{ + // Also verify exception handling. + djinni::Promise intPromise; + djinni::SharedFuture futureInt(intPromise.getFuture()); + + intPromise.setException(std::runtime_error("mocked")); + + XCTAssertThrows(futureInt.get()); + + auto thenResult = futureInt.then([](auto resolved) { return resolved.get(); }); + XCTAssertThrows(thenResult.get()); + + auto withExceptionHandling = futureInt.thenShared([](const auto& resolved) { + try { + return resolved.get(); + } catch (...) { + return -1; + } + }); + withExceptionHandling.wait(); + XCTAssertEqual(withExceptionHandling.get(), -1); +} + +#endif + +@end