-
-
Notifications
You must be signed in to change notification settings - Fork 856
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
310 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
using Cysharp.Threading.Tasks; | ||
using FluentAssertions; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
using System.Threading.Tasks; | ||
using Xunit; | ||
|
||
namespace NetCoreTests | ||
{ | ||
public class WhenEachTest | ||
{ | ||
[Fact] | ||
public async Task Each() | ||
{ | ||
var a = Delay(1, 3000); | ||
var b = Delay(2, 1000); | ||
var c = Delay(3, 2000); | ||
|
||
var l = new List<int>(); | ||
await foreach (var item in UniTask.WhenEach(a, b, c)) | ||
{ | ||
l.Add(item.Result); | ||
} | ||
|
||
l.Should().Equal(2, 3, 1); | ||
} | ||
|
||
[Fact] | ||
public async Task Error() | ||
{ | ||
var a = Delay2(1, 3000); | ||
var b = Delay2(2, 1000); | ||
var c = Delay2(3, 2000); | ||
|
||
var l = new List<WhenEachResult<int>>(); | ||
await foreach (var item in UniTask.WhenEach(a, b, c)) | ||
{ | ||
l.Add(item); | ||
} | ||
|
||
l[0].IsCompletedSuccessfully.Should().BeTrue(); | ||
l[0].IsFaulted.Should().BeFalse(); | ||
l[0].Result.Should().Be(2); | ||
|
||
l[1].IsCompletedSuccessfully.Should().BeFalse(); | ||
l[1].IsFaulted.Should().BeTrue(); | ||
l[1].Exception.Message.Should().Be("ERROR"); | ||
|
||
l[2].IsCompletedSuccessfully.Should().BeTrue(); | ||
l[2].IsFaulted.Should().BeFalse(); | ||
l[2].Result.Should().Be(1); | ||
} | ||
|
||
async UniTask<int> Delay(int id, int sleep) | ||
{ | ||
await Task.Delay(sleep); | ||
return id; | ||
} | ||
|
||
async UniTask<int> Delay2(int id, int sleep) | ||
{ | ||
await Task.Delay(sleep); | ||
if (id == 3) throw new Exception("ERROR"); | ||
return id; | ||
} | ||
} | ||
} |
183 changes: 183 additions & 0 deletions
183
src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WhenEach.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
using Cysharp.Threading.Tasks.Internal; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Runtime.ExceptionServices; | ||
using System.Threading; | ||
|
||
namespace Cysharp.Threading.Tasks | ||
{ | ||
public partial struct UniTask | ||
{ | ||
public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(IEnumerable<UniTask<T>> tasks) | ||
{ | ||
return new WhenEachEnumerable<T>(tasks); | ||
} | ||
|
||
public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(params UniTask<T>[] tasks) | ||
{ | ||
return new WhenEachEnumerable<T>(tasks); | ||
} | ||
} | ||
|
||
public readonly struct WhenEachResult<T> | ||
{ | ||
public T Result { get; } | ||
public Exception Exception { get; } | ||
|
||
//[MemberNotNullWhen(false, nameof(Exception))] | ||
public bool IsCompletedSuccessfully => Exception == null; | ||
|
||
//[MemberNotNullWhen(true, nameof(Exception))] | ||
public bool IsFaulted => Exception != null; | ||
|
||
public WhenEachResult(T result) | ||
{ | ||
this.Result = result; | ||
this.Exception = null; | ||
} | ||
|
||
public WhenEachResult(Exception exception) | ||
{ | ||
if (exception == null) throw new ArgumentNullException(nameof(exception)); | ||
this.Result = default!; | ||
this.Exception = exception; | ||
} | ||
|
||
public void TryThrow() | ||
{ | ||
if (IsFaulted) | ||
{ | ||
ExceptionDispatchInfo.Capture(Exception).Throw(); | ||
} | ||
} | ||
|
||
public T GetResult() | ||
{ | ||
if (IsFaulted) | ||
{ | ||
ExceptionDispatchInfo.Capture(Exception).Throw(); | ||
} | ||
return Result; | ||
} | ||
|
||
public override string ToString() | ||
{ | ||
if (IsCompletedSuccessfully) | ||
{ | ||
return Result?.ToString() ?? ""; | ||
} | ||
else | ||
{ | ||
return $"Exception{{{Exception.Message}}}"; | ||
} | ||
} | ||
} | ||
|
||
internal enum WhenEachState : byte | ||
{ | ||
NotRunning, | ||
Running, | ||
Completed | ||
} | ||
|
||
internal sealed class WhenEachEnumerable<T> : IUniTaskAsyncEnumerable<WhenEachResult<T>> | ||
{ | ||
IEnumerable<UniTask<T>> source; | ||
|
||
public WhenEachEnumerable(IEnumerable<UniTask<T>> source) | ||
{ | ||
this.source = source; | ||
} | ||
|
||
public IUniTaskAsyncEnumerator<WhenEachResult<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default) | ||
{ | ||
return new Enumerator(source, cancellationToken); | ||
} | ||
|
||
sealed class Enumerator : IUniTaskAsyncEnumerator<WhenEachResult<T>> | ||
{ | ||
readonly IEnumerable<UniTask<T>> source; | ||
CancellationToken cancellationToken; | ||
|
||
Channel<WhenEachResult<T>> channel; | ||
IUniTaskAsyncEnumerator<WhenEachResult<T>> channelEnumerator; | ||
int completeCount; | ||
WhenEachState state; | ||
|
||
public Enumerator(IEnumerable<UniTask<T>> source, CancellationToken cancellationToken) | ||
{ | ||
this.source = source; | ||
this.cancellationToken = cancellationToken; | ||
} | ||
|
||
public WhenEachResult<T> Current => channelEnumerator.Current; | ||
|
||
public UniTask<bool> MoveNextAsync() | ||
{ | ||
cancellationToken.ThrowIfCancellationRequested(); | ||
|
||
if (state == WhenEachState.NotRunning) | ||
{ | ||
state = WhenEachState.Running; | ||
channel = Channel.CreateSingleConsumerUnbounded<WhenEachResult<T>>(); | ||
channelEnumerator = channel.Reader.ReadAllAsync().GetAsyncEnumerator(cancellationToken); | ||
|
||
if (source is UniTask<T>[] array) | ||
{ | ||
ConsumeAll(this, array, array.Length); | ||
} | ||
else | ||
{ | ||
using (var rentArray = ArrayPoolUtil.Materialize(source)) | ||
{ | ||
ConsumeAll(this, rentArray.Array, rentArray.Length); | ||
} | ||
} | ||
} | ||
|
||
return channelEnumerator.MoveNextAsync(); | ||
} | ||
|
||
static void ConsumeAll(Enumerator self, UniTask<T>[] array, int length) | ||
{ | ||
for (int i = 0; i < length; i++) | ||
{ | ||
RunWhenEachTask(self, array[i], length).Forget(); | ||
} | ||
|
||
static async UniTaskVoid RunWhenEachTask(Enumerator self, UniTask<T> task, int length) | ||
{ | ||
try | ||
{ | ||
var result = await task; | ||
self.channel.Writer.TryWrite(new WhenEachResult<T>(result)); | ||
} | ||
catch (Exception ex) | ||
{ | ||
self.channel.Writer.TryWrite(new WhenEachResult<T>(ex)); | ||
} | ||
|
||
if (Interlocked.Increment(ref self.completeCount) == length) | ||
{ | ||
self.state = WhenEachState.Completed; | ||
self.channel.Writer.TryComplete(); | ||
} | ||
} | ||
} | ||
|
||
public async UniTask DisposeAsync() | ||
{ | ||
if (channelEnumerator != null) | ||
{ | ||
await channelEnumerator.DisposeAsync(); | ||
} | ||
|
||
if (state != WhenEachState.Completed) | ||
{ | ||
state = WhenEachState.Completed; | ||
channel.Writer.TryComplete(new OperationCanceledException()); | ||
} | ||
} | ||
} | ||
} | ||
} |
2 changes: 2 additions & 0 deletions
2
src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WhenEach.cs.meta
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.