diff --git a/include/oneapi/dpl/internal/async_impl/async_impl_hetero.h b/include/oneapi/dpl/internal/async_impl/async_impl_hetero.h index 1558919b20f..506287d09ad 100644 --- a/include/oneapi/dpl/internal/async_impl/async_impl_hetero.h +++ b/include/oneapi/dpl/internal/async_impl/async_impl_hetero.h @@ -65,11 +65,10 @@ __pattern_walk2_async(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _For auto __keep2 = oneapi::dpl::__ranges::__get_sycl_range<__acc_mode2, _ForwardIterator2>(); auto __buf2 = __keep2(__first2, __first2 + __n); - auto __future = oneapi::dpl::__par_backend_hetero::__parallel_for( - _BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), - unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n, __buf1.all_view(), __buf2.all_view()); - - return __future.__make_future(__first2 + __n); + return oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), + unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, + __n, __buf1.all_view(), __buf2.all_view()) + .__make_future(__first2 + __n); } template , _ExecutionPolicy&& __exec, _For oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write, _ForwardIterator3>(); auto __buf3 = __keep3(__first3, __first3 + __n); - auto __future = - oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), - unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n, - __buf1.all_view(), __buf2.all_view(), __buf3.all_view()); - - return __future.__make_future(__first3 + __n); + return oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), + unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, + __n, __buf1.all_view(), __buf2.all_view(), + __buf3.all_view()) + .__make_future(__first3 + __n); } template , _ExecutionPolicy& auto __keep2 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write, _Iterator2>(); auto __buf2 = __keep2(__result, __result + __n); - auto __res = oneapi::dpl::__par_backend_hetero::__parallel_transform_scan( - _BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), __buf1.all_view(), __buf2.all_view(), __n, __unary_op, - __init, __binary_op, _Inclusive{}); - return __res.__make_future(__result + __n); + return oneapi::dpl::__par_backend_hetero::__parallel_transform_scan( + _BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), __buf1.all_view(), __buf2.all_view(), __n, + __unary_op, __init, __binary_op, _Inclusive{}) + .__make_future(__result + __n); } template } }; +template +inline __make_future(sycl::event&& __event, _TData&& __data) +{ + return __future(std::move(__event), std::forward<_TData>(__data)); +} + // Invoke a callable and pass a compile-time integer based on a provided run-time integer. // The compile-time integer that will be provided to the callable is defined as the smallest // value in the integer_sequence not less than the run-time integer. For example: