Skip to content

Commit

Permalink
Fixed #247
Browse files Browse the repository at this point in the history
  • Loading branch information
sakno committed Jul 14, 2024
1 parent 767a82a commit 3d2471f
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 32 deletions.
69 changes: 69 additions & 0 deletions src/DotNext.Tests/Collections/Generic/CollectionTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Runtime.InteropServices;

namespace DotNext.Collections.Generic;

Expand Down Expand Up @@ -79,6 +81,20 @@ public static void ForEachTest()
array2.ForEach(counter.Accept);
Equal(5, counter.value);
}

[Fact]
public static async Task ForEachTestAsync()
{
IList<int> list = new List<int> { 1, 10, 20 };
var counter = new Counter<int>();
await list.ForEachAsync(counter.AcceptAsync);
Equal(3, counter.value);
counter.value = 0;

var array2 = new int[] { 1, 2, 10, 11, 15 };
await array2.ForEachAsync(counter.AcceptAsync);
Equal(5, counter.value);
}

[Fact]
public static void ElementAtIndex()
Expand Down Expand Up @@ -275,4 +291,57 @@ public static void CopyString()
using var copy = "abcd".Copy();
Equal("abcd", copy.Memory.ToString());
}

[Fact]
public static void FirstOrNone()
{
Equal(5, new[] { 5, 6 }.FirstOrNone());
Equal(5, new List<int> { 5, 6 }.FirstOrNone());
Equal(5, new LinkedList<int>([5, 6]).FirstOrNone());
Equal('5', "56".FirstOrNone());
Equal(5, ImmutableArray.Create([5, 6]).FirstOrNone());
Equal(5, GetValues().FirstOrNone());

Equal(Optional<int>.None, Array.Empty<int>().FirstOrNone());
Equal(Optional<int>.None, new List<int>().FirstOrNone());
Equal(Optional<int>.None, new LinkedList<int>().FirstOrNone());
Equal(Optional<char>.None, string.Empty.FirstOrNone());
Equal(Optional<int>.None, ImmutableArray<int>.Empty.FirstOrNone());
Equal(Optional<int>.None, EmptyEnumerable<int>().FirstOrNone());

static IEnumerable<int> GetValues()
{
yield return 5;
yield return 6;
}
}

[Fact]
public static void LastOrNone()
{
Equal(6, new[] { 5, 6 }.LastOrNone());
Equal(6, new List<int> { 5, 6 }.LastOrNone());
Equal(6, new LinkedList<int>([5, 6]).LastOrNone());
Equal('6', "56".LastOrNone());
Equal(6, ImmutableArray.Create([5, 6]).LastOrNone());
Equal(6, GetValues().LastOrNone());

Equal(Optional<int>.None, Array.Empty<int>().LastOrNone());
Equal(Optional<int>.None, new List<int>().LastOrNone());
Equal(Optional<int>.None, new LinkedList<int>().LastOrNone());
Equal(Optional<char>.None, string.Empty.LastOrNone());
Equal(Optional<int>.None, ImmutableArray<int>.Empty.LastOrNone());
Equal(Optional<int>.None, EmptyEnumerable<int>().LastOrNone());

static IEnumerable<int> GetValues()
{
yield return 5;
yield return 6;
}
}

static IEnumerable<T> EmptyEnumerable<T>()
{
yield break;
}
}
33 changes: 16 additions & 17 deletions src/DotNext/Collections/Generic/AsyncEnumerable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace DotNext.Collections.Generic;
public static partial class AsyncEnumerable
{
/// <summary>
/// Applies specified action to each collection element asynchronously.
/// Applies specified action to each element of the collection asynchronously.
/// </summary>
/// <typeparam name="T">Type of elements in the collection.</typeparam>
/// <param name="collection">A collection to enumerate. Cannot be <see langword="null"/>.</param>
Expand All @@ -23,7 +23,7 @@ public static async ValueTask ForEachAsync<T>(this IAsyncEnumerable<T> collectio
}

/// <summary>
/// Applies specified action to each collection element asynchronously.
/// Applies the specified action to each element of the collection asynchronously.
/// </summary>
/// <typeparam name="T">Type of elements in the collection.</typeparam>
/// <param name="collection">A collection to enumerate. Cannot be <see langword="null"/>.</param>
Expand All @@ -38,8 +38,8 @@ public static async ValueTask ForEachAsync<T>(this IAsyncEnumerable<T> collectio
}

/// <summary>
/// Obtains first value type in the sequence; or <see langword="null"/>
/// if sequence is empty.
/// Obtains the first value of a sequence; or <see langword="null"/>
/// if the sequence is empty.
/// </summary>
/// <typeparam name="T">Type of elements in the sequence.</typeparam>
/// <param name="seq">A sequence to check. Cannot be <see langword="null"/>.</param>
Expand All @@ -55,8 +55,8 @@ public static async ValueTask ForEachAsync<T>(this IAsyncEnumerable<T> collectio
}

/// <summary>
/// Obtains the last value type in the sequence; or <see langword="null"/>
/// if sequence is empty.
/// Obtains the last value of a sequence; or <see langword="null"/>
/// if the sequence is empty.
/// </summary>
/// <typeparam name="T">Type of elements in the sequence.</typeparam>
/// <param name="seq">A sequence to check. Cannot be <see langword="null"/>.</param>
Expand All @@ -74,8 +74,8 @@ public static async ValueTask ForEachAsync<T>(this IAsyncEnumerable<T> collectio
}

/// <summary>
/// Obtains first element in the sequence; or <see cref="Optional{T}.None"/>
/// if sequence is empty.
/// Obtains the first element of a sequence; or <see cref="Optional{T}.None"/>
/// if the sequence is empty.
/// </summary>
/// <typeparam name="T">Type of elements in the sequence.</typeparam>
/// <param name="seq">A sequence to check. Cannot be <see langword="null"/>.</param>
Expand All @@ -90,8 +90,8 @@ public static async ValueTask<Optional<T>> FirstOrNoneAsync<T>(this IAsyncEnumer
}

/// <summary>
/// Obtains the last element in the sequence; or <see cref="Optional{T}.None"/>
/// if sequence is empty.
/// Obtains the last element of a sequence; or <see cref="Optional{T}.None"/>
/// if the sequence is empty.
/// </summary>
/// <typeparam name="T">Type of elements in the sequence.</typeparam>
/// <param name="seq">A sequence to check. Cannot be <see langword="null"/>.</param>
Expand All @@ -118,6 +118,8 @@ public static async ValueTask<Optional<T>> LastOrNoneAsync<T>(this IAsyncEnumera
/// <exception cref="OperationCanceledException">The operation has been canceled.</exception>
public static async ValueTask<Optional<T>> FirstOrNoneAsync<T>(this IAsyncEnumerable<T> seq, Predicate<T> filter, CancellationToken token = default)
{
ArgumentNullException.ThrowIfNull(filter);

await foreach (var item in seq.WithCancellation(token).ConfigureAwait(false))
{
if (filter(item))
Expand All @@ -137,11 +139,10 @@ public static async ValueTask<Optional<T>> FirstOrNoneAsync<T>(this IAsyncEnumer
/// <exception cref="OperationCanceledException">The operation has been canceled.</exception>
public static async ValueTask<bool> SkipAsync<T>(this IAsyncEnumerator<T> enumerator, int count)
{
while (count > 0)
for (; count > 0; count--)
{
if (!await enumerator.MoveNextAsync().ConfigureAwait(false))
return false;
count--;
}

return true;
Expand All @@ -161,11 +162,9 @@ public static async ValueTask<Optional<T>> ElementAtAsync<T>(this IAsyncEnumerab
var enumerator = collection.GetAsyncEnumerator(token);
await using (enumerator.ConfigureAwait(false))
{
await enumerator.SkipAsync(index).ConfigureAwait(false);

return await enumerator.MoveNextAsync().ConfigureAwait(false) ?
enumerator.Current :
Optional<T>.None;
return await enumerator.SkipAsync(index).ConfigureAwait(false) && await enumerator.MoveNextAsync().ConfigureAwait(false)
? enumerator.Current
: Optional<T>.None;
}
}

Expand Down
73 changes: 68 additions & 5 deletions src/DotNext/Collections/Generic/Collection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,67 @@ public static async ValueTask ForEachAsync<T>(this IEnumerable<T> collection, Fu
await action.Invoke(item, token).ConfigureAwait(false);
}

/// <summary>
/// Obtains the first element of a sequence; or <see cref="Optional{T}.None"/>
/// if the sequence is empty.
/// </summary>
/// <param name="collection">The collection to return the first element of.</param>
/// <typeparam name="T">The type of the element of a collection.</typeparam>
/// <returns>The first element; or <see cref="Optional{T}.None"/></returns>
public static Optional<T> FirstOrNone<T>(this IEnumerable<T> collection)
{
return collection switch
{
null => throw new ArgumentNullException(nameof(collection)),
List<T> list => Span.FirstOrNone<T>(CollectionsMarshal.AsSpan(list)),
T[] array => Span.FirstOrNone<T>(array),
string str => Unsafe.BitCast<Optional<char>, Optional<T>>(Span.FirstOrNone<char>(str)),
LinkedList<T> list => list.First is { } first ? first.Value : Optional<T>.None,
IList<T> list => list.Count > 0 ? list[0] : Optional<T>.None,
IReadOnlyList<T> readOnlyList => readOnlyList.Count > 0 ? readOnlyList[0] : Optional<T>.None,
_ => FirstOrNoneSlow(collection),
};

static Optional<T> FirstOrNoneSlow(IEnumerable<T> collection)
{
using var enumerator = collection.GetEnumerator();
return enumerator.MoveNext() ? enumerator.Current : Optional<T>.None;
}
}

/// <summary>
/// Obtains the last element of a sequence; or <see cref="Optional{T}.None"/>
/// if the sequence is empty.
/// </summary>
/// <param name="collection">The collection to return the first element of.</param>
/// <typeparam name="T">The type of the element of a collection.</typeparam>
/// <returns>The first element; or <see cref="Optional{T}.None"/></returns>
public static Optional<T> LastOrNone<T>(this IEnumerable<T> collection)
{
return collection switch
{
null => throw new ArgumentNullException(nameof(collection)),
List<T> list => Span.LastOrNone<T>(CollectionsMarshal.AsSpan(list)),
T[] array => Span.LastOrNone<T>(array),
string str => Unsafe.BitCast<Optional<char>, Optional<T>>(Span.LastOrNone<char>(str)),
LinkedList<T> list => list.Last is { } last ? last.Value : Optional<T>.None,
IList<T> list => list.Count > 0 ? list[^1] : Optional<T>.None,
IReadOnlyList<T> readOnlyList => readOnlyList.Count > 0 ? readOnlyList[^1] : Optional<T>.None,
_ => LastOrNoneSlow(collection),
};

static Optional<T> LastOrNoneSlow(IEnumerable<T> collection)
{
var result = Optional.None<T>();
foreach (var item in collection)
{
result = item;
}

return result;
}
}

/// <summary>
/// Obtains element at the specified index in the sequence.
/// </summary>
Expand All @@ -181,6 +242,7 @@ public static bool ElementAt<T>(this IEnumerable<T> collection, int index, [Mayb
{
return collection switch
{
null => throw new ArgumentNullException(nameof(collection)),
List<T> list => Span.ElementAt(CollectionsMarshal.AsSpan(list), index, out element),
T[] array => Span.ElementAt(array, index, out element),
LinkedList<T> list => NodeValueAt(list, index, out element),
Expand Down Expand Up @@ -211,14 +273,15 @@ static bool NodeValueAt(LinkedList<T> list, int matchIndex, [MaybeNullWhen(false
static bool ElementAtSlow(IEnumerable<T> collection, int index, [MaybeNullWhen(false)] out T element)
{
using var enumerator = collection.GetEnumerator();
enumerator.Skip(index);
if (enumerator.MoveNext())

// enumerator.Skip(index + 1) may overflow, replace it with two calls
if (enumerator.Skip(index) && enumerator.MoveNext())
{
element = enumerator.Current;
return true;
}

element = default!;
element = default;
return false;
}

Expand All @@ -230,7 +293,7 @@ static bool ListElementAt(IList<T> list, int index, [MaybeNullWhen(false)] out T
return true;
}

element = default!;
element = default;
return false;
}

Expand All @@ -242,7 +305,7 @@ static bool ReadOnlyListElementAt(IReadOnlyList<T> list, int index, [MaybeNullWh
return true;
}

element = default!;
element = default;
return false;
}
}
Expand Down
8 changes: 2 additions & 6 deletions src/DotNext/Collections/Generic/Enumerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ public static partial class Enumerator
/// <returns><see langword="true"/>, if current element is available; otherwise, <see langword="false"/>.</returns>
public static bool Skip<T>(this IEnumerator<T> enumerator, int count)
{
while (count > 0)
for (; count > 0; count--)
{
if (!enumerator.MoveNext())
return false;

count--;
}

return true;
Expand All @@ -41,12 +39,10 @@ public static bool Skip<T>(this IEnumerator<T> enumerator, int count)
public static bool Skip<TEnumerator, T>(this ref TEnumerator enumerator, int count)
where TEnumerator : struct, IEnumerator<T>
{
while (count > 0)
for (; count > 0; count--)
{
if (!enumerator.MoveNext())
return false;

count--;
}

return true;
Expand Down
14 changes: 10 additions & 4 deletions src/DotNext/Span.cs
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,10 @@ public static void CopyTo<T>(this Span<T> source, Span<T> destination, out int w
/// <param name="span">The source span.</param>
/// <param name="filter">A function to test each element for a condition.</param>
/// <returns>The first element in the span that matches to the specified filter; or <see cref="Optional{T}.None"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="filter"/> is <see langword="null"/>.</exception>
public static Optional<T> FirstOrNone<T>(this ReadOnlySpan<T> span, Predicate<T> filter)
public static Optional<T> FirstOrNone<T>(this ReadOnlySpan<T> span, Predicate<T>? filter = null)
{
ArgumentNullException.ThrowIfNull(filter);

filter ??= Predicate.Constant<T>(true);
for (var i = 0; i < span.Length; i++)
{
var item = span[i];
Expand All @@ -434,6 +433,13 @@ public static Optional<T> FirstOrNone<T>(this ReadOnlySpan<T> span, Predicate<T>
return Optional<T>.None;
}

internal static Optional<T> LastOrNone<T>(ReadOnlySpan<T> span)
{
ref var elementRef = ref MemoryMarshal.GetReference(span);
var length = span.Length;
return length > 0 ? Unsafe.Add(ref elementRef, length - 1) : Optional<T>.None;
}

internal static bool ElementAt<T>(ReadOnlySpan<T> span, int index, [MaybeNullWhen(false)] out T element)
{
if ((uint)index < (uint)span.Length)
Expand Down

0 comments on commit 3d2471f

Please sign in to comment.