Skip to content

Commit

Permalink
Implement ICollection<T> on ToLookup (dotnet#98895)
Browse files Browse the repository at this point in the history
LINQ's ToLookup returns an ILookup. The object we return today on some platforms implements `IIListProvider`. Instead, we can just have the returned object always implement `ICollection<T>`. The provided CopyTo then enables the default ToArray/ToList implementations to still be efficient. Implementing ICollection enables other uses to also take advantage of implementations like CopyTo and Count.

As `Lookup<TKey, TElement>` is public, this doesn't implement the interface on that type and instead derives a sealed internal type that provides the interface implementation (which is minimal). If we want to augment the public API, we could do so subsequently and delete the internal derived type.
  • Loading branch information
stephentoub authored Feb 26, 2024
1 parent 8ccb71e commit 05f9a34
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 76 deletions.
28 changes: 8 additions & 20 deletions src/libraries/System.Linq/src/System/Linq/Grouping.SpeedOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,35 +31,23 @@ public int GetCount(bool onlyIfCheap) =>

internal sealed partial class GroupedEnumerable<TSource, TKey, TElement> : IIListProvider<IGrouping<TKey, TElement>>
{
public IGrouping<TKey, TElement>[] ToArray()
{
IIListProvider<IGrouping<TKey, TElement>> lookup = Lookup<TKey, TElement>.Create(_source, _keySelector, _elementSelector, _comparer);
return lookup.ToArray();
}
public IGrouping<TKey, TElement>[] ToArray() =>
Lookup<TKey, TElement>.Create(_source, _keySelector, _elementSelector, _comparer).ToArray();

public List<IGrouping<TKey, TElement>> ToList()
{
IIListProvider<IGrouping<TKey, TElement>> lookup = Lookup<TKey, TElement>.Create(_source, _keySelector, _elementSelector, _comparer);
return lookup.ToList();
}
public List<IGrouping<TKey, TElement>> ToList() =>
Lookup<TKey, TElement>.Create(_source, _keySelector, _elementSelector, _comparer).ToList();

public int GetCount(bool onlyIfCheap) =>
onlyIfCheap ? -1 : Lookup<TKey, TElement>.Create(_source, _keySelector, _elementSelector, _comparer).Count;
}

internal sealed partial class GroupedEnumerable<TSource, TKey> : IIListProvider<IGrouping<TKey, TSource>>
{
public IGrouping<TKey, TSource>[] ToArray()
{
IIListProvider<IGrouping<TKey, TSource>> lookup = Lookup<TKey, TSource>.Create(_source, _keySelector, _comparer);
return lookup.ToArray();
}
public IGrouping<TKey, TSource>[] ToArray() =>
Lookup<TKey, TSource>.Create(_source, _keySelector, _comparer).ToArray();

public List<IGrouping<TKey, TSource>> ToList()
{
IIListProvider<IGrouping<TKey, TSource>> lookup = Lookup<TKey, TSource>.Create(_source, _keySelector, _comparer);
return lookup.ToList();
}
public List<IGrouping<TKey, TSource>> ToList() =>
Lookup<TKey, TSource>.Create(_source, _keySelector, _comparer).ToList();

public int GetCount(bool onlyIfCheap) =>
onlyIfCheap ? -1 : Lookup<TKey, TSource>.Create(_source, _keySelector, _comparer).Count;
Expand Down
50 changes: 1 addition & 49 deletions src/libraries/System.Linq/src/System/Linq/Lookup.SpeedOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,8 @@

namespace System.Linq
{
public partial class Lookup<TKey, TElement> : IIListProvider<IGrouping<TKey, TElement>>
public partial class Lookup<TKey, TElement>
{
IGrouping<TKey, TElement>[] IIListProvider<IGrouping<TKey, TElement>>.ToArray()
{
IGrouping<TKey, TElement>[] array;
if (_count > 0)
{
array = new IGrouping<TKey, TElement>[_count];
Fill(_lastGrouping, array);
}
else
{
array = [];
}
return array;
}

internal TResult[] ToArray<TResult>(Func<TKey, IEnumerable<TElement>, TResult> resultSelector)
{
TResult[] array = new TResult[_count];
Expand All @@ -44,38 +29,5 @@ internal TResult[] ToArray<TResult>(Func<TKey, IEnumerable<TElement>, TResult> r

return array;
}

List<IGrouping<TKey, TElement>> IIListProvider<IGrouping<TKey, TElement>>.ToList()
{
var list = new List<IGrouping<TKey, TElement>>(_count);
if (_count > 0)
{
Fill(_lastGrouping, Enumerable.SetCountAndGetSpan(list, _count));
}

return list;
}

private static void Fill(Grouping<TKey, TElement>? lastGrouping, Span<IGrouping<TKey, TElement>> results)
{
int index = 0;
Grouping<TKey, TElement>? g = lastGrouping;
if (g != null)
{
do
{
g = g._next;
Debug.Assert(g != null);

results[index] = g;
++index;
}
while (g != lastGrouping);
}

Debug.Assert(index == results.Length, "All list elements were not initialized.");
}

int IIListProvider<IGrouping<TKey, TElement>>.GetCount(bool onlyIfCheap) => _count;
}
}
66 changes: 59 additions & 7 deletions src/libraries/System.Linq/src/System/Linq/Lookup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public partial class Lookup<TKey, TElement> : ILookup<TKey, TElement>
{
private readonly IEqualityComparer<TKey> _comparer;
private Grouping<TKey, TElement>[] _groupings;
private Grouping<TKey, TElement>? _lastGrouping;
private protected Grouping<TKey, TElement>? _lastGrouping;
private int _count;

internal static Lookup<TKey, TElement> Create<TSource>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey>? comparer)
Expand All @@ -85,7 +85,7 @@ internal static Lookup<TKey, TElement> Create<TSource>(IEnumerable<TSource> sour
Debug.Assert(keySelector != null);
Debug.Assert(elementSelector != null);

Lookup<TKey, TElement> lookup = new Lookup<TKey, TElement>(comparer);
var lookup = new CollectionLookup<TKey, TElement>(comparer);
foreach (TSource item in source)
{
lookup.GetGrouping(keySelector(item), create: true)!.Add(elementSelector(item));
Expand All @@ -99,7 +99,7 @@ internal static Lookup<TKey, TElement> Create(IEnumerable<TElement> source, Func
Debug.Assert(source != null);
Debug.Assert(keySelector != null);

Lookup<TKey, TElement> lookup = new Lookup<TKey, TElement>(comparer);
var lookup = new CollectionLookup<TKey, TElement>(comparer);
foreach (TElement item in source)
{
lookup.GetGrouping(keySelector(item), create: true)!.Add(item);
Expand All @@ -110,7 +110,7 @@ internal static Lookup<TKey, TElement> Create(IEnumerable<TElement> source, Func

internal static Lookup<TKey, TElement> CreateForJoin(IEnumerable<TElement> source, Func<TElement, TKey> keySelector, IEqualityComparer<TKey>? comparer)
{
Lookup<TKey, TElement> lookup = new Lookup<TKey, TElement>(comparer);
var lookup = new CollectionLookup<TKey, TElement>(comparer);
foreach (TElement item in source)
{
TKey key = keySelector(item);
Expand All @@ -123,7 +123,7 @@ internal static Lookup<TKey, TElement> CreateForJoin(IEnumerable<TElement> sourc
return lookup;
}

private Lookup(IEqualityComparer<TKey>? comparer)
private protected Lookup(IEqualityComparer<TKey>? comparer)
{
_comparer = comparer ?? EqualityComparer<TKey>.Default;
_groupings = new Grouping<TKey, TElement>[7];
Expand Down Expand Up @@ -259,16 +259,68 @@ private void Resize()
}
}

internal sealed class CollectionLookup<TKey, TElement> : Lookup<TKey, TElement>, ICollection<IGrouping<TKey, TElement>>, IReadOnlyCollection<IGrouping<TKey, TElement>>
{
internal CollectionLookup(IEqualityComparer<TKey>? comparer) : base(comparer) { }

void ICollection<IGrouping<TKey, TElement>>.CopyTo(IGrouping<TKey, TElement>[] array, int arrayIndex)
{
ArgumentNullException.ThrowIfNull(array);
ArgumentOutOfRangeException.ThrowIfNegative(arrayIndex);
ArgumentOutOfRangeException.ThrowIfGreaterThan(arrayIndex, array.Length);
ArgumentOutOfRangeException.ThrowIfLessThan(array.Length - arrayIndex, Count, nameof(arrayIndex));

Grouping<TKey, TElement>? g = _lastGrouping;
if (g != null)
{
do
{
g = g._next;
Debug.Assert(g != null);

array[arrayIndex] = g;
++arrayIndex;
}
while (g != _lastGrouping);
}
}

bool ICollection<IGrouping<TKey, TElement>>.Contains(IGrouping<TKey, TElement> item)
{
ArgumentNullException.ThrowIfNull(item);
return GetGrouping(item.Key, create: false) is { } grouping && grouping == item;
}

bool ICollection<IGrouping<TKey, TElement>>.IsReadOnly => true;
void ICollection<IGrouping<TKey, TElement>>.Add(IGrouping<TKey, TElement> item) => throw new NotSupportedException();
void ICollection<IGrouping<TKey, TElement>>.Clear() => throw new NotSupportedException();
bool ICollection<IGrouping<TKey, TElement>>.Remove(IGrouping<TKey, TElement> item) => throw new NotSupportedException();
}

[DebuggerDisplay("Count = 0")]
[DebuggerTypeProxy(typeof(SystemLinq_LookupDebugView<,>))]
internal sealed class EmptyLookup<TKey, TElement> : ILookup<TKey, TElement>
internal sealed class EmptyLookup<TKey, TElement> : ILookup<TKey, TElement>, ICollection<IGrouping<TKey, TElement>>, IReadOnlyCollection<IGrouping<TKey, TElement>>
{
public static readonly EmptyLookup<TKey, TElement> Instance = new();

public IEnumerable<TElement> this[TKey key] => [];
public int Count => 0;
public bool Contains(TKey key) => false;

public IEnumerator<IGrouping<TKey, TElement>> GetEnumerator() => Enumerable.Empty<IGrouping<TKey, TElement>>().GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

public bool Contains(TKey key) => false;
public bool Contains(IGrouping<TKey, TElement> item) => false;
public void CopyTo(IGrouping<TKey, TElement>[] array, int arrayIndex)
{
ArgumentNullException.ThrowIfNull(array);
ArgumentOutOfRangeException.ThrowIfNegative(arrayIndex);
ArgumentOutOfRangeException.ThrowIfGreaterThan(arrayIndex, array.Length);
}

public bool IsReadOnly => true;
public void Add(IGrouping<TKey, TElement> item) => throw new NotSupportedException();
public void Clear() => throw new NotSupportedException();
public bool Remove(IGrouping<TKey, TElement> item) => throw new NotSupportedException();
}
}
32 changes: 32 additions & 0 deletions src/libraries/System.Linq/tests/ToLookupTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,38 @@ public void ApplyResultSelectorForGroup(int enumType)
Assert.Equal(expected, result);
}

[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(10)]
public void LookupImplementsICollection(int count)
{
Assert.IsAssignableFrom<ICollection<IGrouping<string, int>>>(Enumerable.Range(0, count).ToLookup(i => i.ToString()));
Assert.IsAssignableFrom<ICollection<IGrouping<string, int>>>(Enumerable.Range(0, count).ToLookup(i => i.ToString(), StringComparer.OrdinalIgnoreCase));
Assert.IsAssignableFrom<ICollection<IGrouping<string, int>>>(Enumerable.Range(0, count).ToLookup(i => i.ToString(), i => i));
Assert.IsAssignableFrom<ICollection<IGrouping<string, int>>>(Enumerable.Range(0, count).ToLookup(i => i.ToString(), i => i, StringComparer.OrdinalIgnoreCase));

var collection = (ICollection<IGrouping<string, int>>)Enumerable.Range(0, count).ToLookup(i => i.ToString());
Assert.Equal(count, collection.Count);
Assert.Throws<NotSupportedException>(() => collection.Add(null));
Assert.Throws<NotSupportedException>(() => collection.Remove(null));
Assert.Throws<NotSupportedException>(() => collection.Clear());

if (count > 0)
{
IGrouping<string, int> first = collection.First();
IGrouping<string, int> last = collection.Last();
Assert.True(collection.Contains(first));
Assert.True(collection.Contains(last));
}

IGrouping<string, int>[] items = new IGrouping<string, int>[count];
collection.CopyTo(items, 0);
Assert.Equal(collection.Select(i => i), items);
Assert.Equal(items, Enumerable.Range(0, count).ToLookup(i => i.ToString()).ToArray());
Assert.Equal(items, Enumerable.Range(0, count).ToLookup(i => i.ToString()).ToList());
}

public class Membership
{
public int Id { get; set; }
Expand Down

0 comments on commit 05f9a34

Please sign in to comment.