Skip to content

Commit

Permalink
Add UniTask.WhenEach
Browse files Browse the repository at this point in the history
  • Loading branch information
neuecc committed Sep 30, 2024
1 parent 87e164e commit 0826b7e
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 21 deletions.
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ UniTask provides three pattern of extension methods.
> Note: AssetBundleRequest has `asset` and `allAssets`, default await returns `asset`. If you want to get `allAssets`, you can use `AwaitForAllAssets()` method.
The type of `UniTask` can use utilities like `UniTask.WhenAll`, `UniTask.WhenAny`. They are like `Task.WhenAll`/`Task.WhenAny` but the return type is more useful. They return value tuples so you can deconstruct each result and pass multiple types.
The type of `UniTask` can use utilities like `UniTask.WhenAll`, `UniTask.WhenAny`, `UniTask.WhenEach`. They are like `Task.WhenAll`/`Task.WhenAny` but the return type is more useful. They return value tuples so you can deconstruct each result and pass multiple types.

```csharp
public async UniTaskVoid LoadManyAsync()
Expand Down Expand Up @@ -716,6 +716,19 @@ await UniTaskAsyncEnumerable.EveryUpdate().ForEachAsync(_ =>
}, token);
```

`UniTask.WhenEach` that is similar to .NET 9's `Task.WhenEach` can consume new way for await multiple tasks.

```csharp
await foreach (var result in UniTask.WhenEach(task1, task2, task3))
{
// The result is of type WhenEachResult<T>.
// It contains either `T Result` or `Exception Exception`.
// You can check `IsCompletedSuccessfully` or `IsFaulted` to determine whether to access `.Result` or `.Exception`.
// If you want to throw an exception when `IsFaulted` and retrieve the result when successful, use `GetResult()`.
Debug.Log(result.GetResult());
}
```

UniTaskAsyncEnumerable implements asynchronous LINQ, similar to LINQ in `IEnumerable<T>` or Rx in `IObservable<T>`. All standard LINQ query operators can be applied to asynchronous streams. For example, the following code shows how to apply a Where filter to a button-click asynchronous stream that runs once every two clicks.

```csharp
Expand Down Expand Up @@ -1026,6 +1039,7 @@ Use UniTask type.
| `Task.Run` | `UniTask.RunOnThreadPool` |
| `Task.WhenAll` | `UniTask.WhenAll` |
| `Task.WhenAny` | `UniTask.WhenAny` |
| `Task.WhenEach` | `UniTask.WhenEach` |
| `Task.CompletedTask` | `UniTask.CompletedTask` |
| `Task.FromException` | `UniTask.FromException` |
| `Task.FromResult` | `UniTask.FromResult` |
Expand Down
69 changes: 69 additions & 0 deletions src/UniTask.NetCoreTests/WhenEachTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using Cysharp.Threading.Tasks;
using FluentAssertions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Xunit;

namespace NetCoreTests
{
public class WhenEachTest
{
[Fact]
public async Task Each()
{
var a = Delay(1, 3000);
var b = Delay(2, 1000);
var c = Delay(3, 2000);

var l = new List<int>();
await foreach (var item in UniTask.WhenEach(a, b, c))
{
l.Add(item.Result);
}

l.Should().Equal(2, 3, 1);
}

[Fact]
public async Task Error()
{
var a = Delay2(1, 3000);
var b = Delay2(2, 1000);
var c = Delay2(3, 2000);

var l = new List<WhenEachResult<int>>();
await foreach (var item in UniTask.WhenEach(a, b, c))
{
l.Add(item);
}

l[0].IsCompletedSuccessfully.Should().BeTrue();
l[0].IsFaulted.Should().BeFalse();
l[0].Result.Should().Be(2);

l[1].IsCompletedSuccessfully.Should().BeFalse();
l[1].IsFaulted.Should().BeTrue();
l[1].Exception.Message.Should().Be("ERROR");

l[2].IsCompletedSuccessfully.Should().BeTrue();
l[2].IsFaulted.Should().BeFalse();
l[2].Result.Should().Be(1);
}

async UniTask<int> Delay(int id, int sleep)
{
await Task.Delay(sleep);
return id;
}

async UniTask<int> Delay2(int id, int sleep)
{
await Task.Delay(sleep);
if (id == 3) throw new Exception("ERROR");
return id;
}
}
}
183 changes: 183 additions & 0 deletions src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WhenEach.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Threading;

namespace Cysharp.Threading.Tasks
{
public partial struct UniTask
{
public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(IEnumerable<UniTask<T>> tasks)
{
return new WhenEachEnumerable<T>(tasks);
}

public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(params UniTask<T>[] tasks)
{
return new WhenEachEnumerable<T>(tasks);
}
}

public readonly struct WhenEachResult<T>
{
public T Result { get; }
public Exception Exception { get; }

//[MemberNotNullWhen(false, nameof(Exception))]
public bool IsCompletedSuccessfully => Exception == null;

//[MemberNotNullWhen(true, nameof(Exception))]
public bool IsFaulted => Exception != null;

public WhenEachResult(T result)
{
this.Result = result;
this.Exception = null;
}

public WhenEachResult(Exception exception)
{
if (exception == null) throw new ArgumentNullException(nameof(exception));
this.Result = default!;
this.Exception = exception;
}

public void TryThrow()
{
if (IsFaulted)
{
ExceptionDispatchInfo.Capture(Exception).Throw();
}
}

public T GetResult()
{
if (IsFaulted)
{
ExceptionDispatchInfo.Capture(Exception).Throw();
}
return Result;
}

public override string ToString()
{
if (IsCompletedSuccessfully)
{
return Result?.ToString() ?? "";
}
else
{
return $"Exception{{{Exception.Message}}}";
}
}
}

internal enum WhenEachState : byte
{
NotRunning,
Running,
Completed
}

internal sealed class WhenEachEnumerable<T> : IUniTaskAsyncEnumerable<WhenEachResult<T>>
{
IEnumerable<UniTask<T>> source;

public WhenEachEnumerable(IEnumerable<UniTask<T>> source)
{
this.source = source;
}

public IUniTaskAsyncEnumerator<WhenEachResult<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new Enumerator(source, cancellationToken);
}

sealed class Enumerator : IUniTaskAsyncEnumerator<WhenEachResult<T>>
{
readonly IEnumerable<UniTask<T>> source;
CancellationToken cancellationToken;

Channel<WhenEachResult<T>> channel;
IUniTaskAsyncEnumerator<WhenEachResult<T>> channelEnumerator;
int completeCount;
WhenEachState state;

public Enumerator(IEnumerable<UniTask<T>> source, CancellationToken cancellationToken)
{
this.source = source;
this.cancellationToken = cancellationToken;
}

public WhenEachResult<T> Current => channelEnumerator.Current;

public UniTask<bool> MoveNextAsync()
{
cancellationToken.ThrowIfCancellationRequested();

if (state == WhenEachState.NotRunning)
{
state = WhenEachState.Running;
channel = Channel.CreateSingleConsumerUnbounded<WhenEachResult<T>>();
channelEnumerator = channel.Reader.ReadAllAsync().GetAsyncEnumerator(cancellationToken);

if (source is UniTask<T>[] array)
{
ConsumeAll(this, array, array.Length);
}
else
{
using (var rentArray = ArrayPoolUtil.Materialize(source))
{
ConsumeAll(this, rentArray.Array, rentArray.Length);
}
}
}

return channelEnumerator.MoveNextAsync();
}

static void ConsumeAll(Enumerator self, UniTask<T>[] array, int length)
{
for (int i = 0; i < length; i++)
{
RunWhenEachTask(self, array[i], length).Forget();
}

static async UniTaskVoid RunWhenEachTask(Enumerator self, UniTask<T> task, int length)
{
try
{
var result = await task;
self.channel.Writer.TryWrite(new WhenEachResult<T>(result));
}
catch (Exception ex)
{
self.channel.Writer.TryWrite(new WhenEachResult<T>(ex));
}

if (Interlocked.Increment(ref self.completeCount) == length)
{
self.state = WhenEachState.Completed;
self.channel.Writer.TryComplete();
}
}
}

public async UniTask DisposeAsync()
{
if (channelEnumerator != null)
{
await channelEnumerator.DisposeAsync();
}

if (state != WhenEachState.Completed)
{
state = WhenEachState.Completed;
channel.Writer.TryComplete(new OperationCanceledException());
}
}
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions src/UniTask/Assets/Scenes/SandboxMain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,28 @@ public void Dispose()
connection.Dispose();
}
}
public class WhenEachTest
{
public async UniTask Each()
{
var a = Delay(1, 3000);
var b = Delay(2, 1000);
var c = Delay(3, 2000);

var l = new List<int>();
await foreach (var item in UniTask.WhenEach(a, b, c))
{
Debug.Log(item.Result);
}
}

async UniTask<int> Delay(int id, int sleep)
{
await UniTask.Delay(sleep);
return id;
}

}

public class SandboxMain : MonoBehaviour
{
Expand Down Expand Up @@ -147,6 +168,18 @@ async UniTask<int> FooAsync()

Debug.Log("Again");


// var foo = InstantiateAsync<SandboxMain>(this).ToUniTask();





// var tako = await foo;




return 10;
}

Expand Down Expand Up @@ -557,6 +590,7 @@ private static async UniTask TestAsync(CancellationToken ct)

async UniTaskVoid Start()
{
await new WhenEachTest().Each();


// UniTask.Delay(TimeSpan.FromSeconds(1)).TimeoutWithoutException
Expand Down
Loading

0 comments on commit 0826b7e

Please sign in to comment.