Skip to content

Commit

Permalink
Implement PriorityQueue.Remove (#93994)
Browse files Browse the repository at this point in the history
* Implement PriorityQueue.Remove

* Update src/libraries/System.Collections/src/System/Collections/Generic/PriorityQueue.cs

* Update src/libraries/System.Collections/src/System/Collections/Generic/PriorityQueue.cs

Co-authored-by: Dan Moseley <[email protected]>

* Update src/libraries/System.Collections/src/System/Collections/Generic/PriorityQueue.cs

* Update src/libraries/System.Collections/src/System/Collections/Generic/PriorityQueue.cs

* Update src/libraries/System.Collections/src/System/Collections/Generic/PriorityQueue.cs

Co-authored-by: Stephen Toub <[email protected]>

* Address feedback.

* Address feedback

* Add a Dijkstra smoke test.

* Alias distance type

---------

Co-authored-by: Dan Moseley <[email protected]>
Co-authored-by: Stephen Toub <[email protected]>
  • Loading branch information
3 people authored Oct 30, 2023
1 parent 44a5abd commit c1f4341
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/libraries/System.Collections/ref/System.Collections.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ public void EnqueueRange(System.Collections.Generic.IEnumerable<(TElement Elemen
public void EnqueueRange(System.Collections.Generic.IEnumerable<TElement> elements, TPriority priority) { }
public int EnsureCapacity(int capacity) { throw null; }
public TElement Peek() { throw null; }
public bool Remove(TElement element, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TElement removedElement, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TPriority priority, System.Collections.Generic.IEqualityComparer<TElement>? equalityComparer = null) { throw null; }
public void TrimExcess() { }
public bool TryDequeue([System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TElement element, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TPriority priority) { throw null; }
public bool TryPeek([System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TElement element, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TPriority priority) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,59 @@ public void EnqueueRange(IEnumerable<TElement> elements, TPriority priority)
}
}

/// <summary>
/// Removes the first occurrence that equals the specified parameter.
/// </summary>
/// <param name="element">The element to try to remove.</param>
/// <param name="removedElement">The actual element that got removed from the queue.</param>
/// <param name="priority">The priority value associated with the removed element.</param>
/// <param name="equalityComparer">The equality comparer governing element equality.</param>
/// <returns><see langword="true"/> if matching entry was found and removed, <see langword="false"/> otherwise.</returns>
/// <remarks>
/// The method performs a linear-time scan of every element in the heap, removing the first value found to match the <paramref name="element"/> parameter.
/// In case of duplicate entries, what entry does get removed is non-deterministic and does not take priority into account.
///
/// If no <paramref name="equalityComparer"/> is specified, <see cref="EqualityComparer{TElement}.Default"/> will be used instead.
/// </remarks>
public bool Remove(
TElement element,
[MaybeNullWhen(false)] out TElement removedElement,
[MaybeNullWhen(false)] out TPriority priority,
IEqualityComparer<TElement>? equalityComparer = null)
{
int index = FindIndex(element, equalityComparer);
if (index < 0)
{
removedElement = default;
priority = default;
return false;
}

(TElement Element, TPriority Priority)[] nodes = _nodes;
(removedElement, priority) = nodes[index];
int newSize = --_size;

if (index < newSize)
{
// We're removing an element from the middle of the heap.
// Pop the last element in the collection and sift downward from the removed index.
(TElement Element, TPriority Priority) lastNode = nodes[newSize];

if (_comparer == null)
{
MoveDownDefaultComparer(lastNode, index);
}
else
{
MoveDownCustomComparer(lastNode, index);
}
}

nodes[newSize] = default;
_version++;
return true;
}

/// <summary>
/// Removes all items from the <see cref="PriorityQueue{TElement, TPriority}"/>.
/// </summary>
Expand Down Expand Up @@ -809,6 +862,41 @@ private void MoveDownCustomComparer((TElement Element, TPriority Priority) node,
nodes[nodeIndex] = node;
}

/// <summary>
/// Scans the heap for the first index containing an element equal to the specified parameter.
/// </summary>
private int FindIndex(TElement element, IEqualityComparer<TElement>? equalityComparer)
{
equalityComparer ??= EqualityComparer<TElement>.Default;
ReadOnlySpan<(TElement Element, TPriority Priority)> nodes = _nodes.AsSpan(0, _size);

// Currently the JIT doesn't optimize direct EqualityComparer<T>.Default.Equals
// calls for reference types, so we want to cache the comparer instance instead.
// TODO https://github.com/dotnet/runtime/issues/10050: Update if this changes in the future.
if (typeof(TElement).IsValueType && equalityComparer == EqualityComparer<TElement>.Default)
{
for (int i = 0; i < nodes.Length; i++)
{
if (EqualityComparer<TElement>.Default.Equals(element, nodes[i].Element))
{
return i;
}
}
}
else
{
for (int i = 0; i < nodes.Length; i++)
{
if (equalityComparer.Equals(element, nodes[i].Element))
{
return i;
}
}
}

return -1;
}

/// <summary>
/// Initializes the custom comparer to be used internally by the heap.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public void PriorityQueue_EnumerableConstructor_ShouldContainAllElements(int cou

#endregion

#region Enqueue, Dequeue, Peek, EnqueueDequeue, DequeueEnqueue
#region Enqueue, Dequeue, Peek, EnqueueDequeue, DequeueEnqueue, Remove

[Theory]
[MemberData(nameof(ValidCollectionSizes))]
Expand Down Expand Up @@ -246,6 +246,35 @@ public void PriorityQueue_DequeueEnqueue(int count)
AssertExtensions.CollectionEqual(expectedItems, queue.UnorderedItems, EqualityComparer<(TElement, TPriority)>.Default);
}

[Theory]
[MemberData(nameof(ValidCollectionSizes))]
public void PriorityQueue_Remove_AllElements(int count)
{
bool result;
TElement removedElement;
TPriority removedPriority;

PriorityQueue<TElement, TPriority> queue = CreatePriorityQueue(count, count, out List<(TElement element, TPriority priority)> generatedItems);

for (int i = count - 1; i >= 0; i--)
{
(TElement element, TPriority priority) = generatedItems[i];

result = queue.Remove(element, out removedElement, out removedPriority);

Assert.True(result);
Assert.Equal(element, removedElement);
Assert.Equal(priority, removedPriority);
Assert.Equal(i, queue.Count);
}

result = queue.Remove(default, out removedElement, out removedPriority);

Assert.False(result);
Assert.Equal(default, removedElement);
Assert.Equal(default, removedPriority);
}

#endregion

#region Clear
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Linq;
using Xunit;
using NodeId = int;
using Distance = int;

namespace System.Collections.Tests
{
public partial class PriorityQueue_NonGeneric_Tests
{
public record struct Graph(Edge[][] nodes);
public record struct Edge(NodeId neighbor, Distance weight);

[Fact]
public static void PriorityQueue_DijkstraSmokeTest()
{
var graph = new Graph([
[new Edge(1, 7), new Edge(2, 9), new Edge(5, 14)],
[new Edge(0, 7), new Edge(2, 10), new Edge(3, 15)],
[new Edge(0, 9), new Edge(1, 10), new Edge(3, 11), new Edge(5, 2)],
[new Edge(1, 15), new Edge(2, 11), new Edge(4, 6)],
[new Edge(3, 6), new Edge(5, 9)],
[new Edge(0, 14), new Edge(2, 2), new Edge(4, 9)],
]);

NodeId startNode = 0;

(NodeId node, Distance distance)[] expectedDistances =
[
(0, 0),
(1, 7),
(2, 9),
(3, 20),
(4, 20),
(5, 11),
];

(NodeId node, Distance distance)[] actualDistances = RunDijkstra(graph, startNode);

Assert.Equal(expectedDistances, actualDistances);
}

public static (NodeId node, Distance distance)[] RunDijkstra(Graph graph, NodeId startNode)
{
Distance[] distances = Enumerable.Repeat(int.MaxValue, graph.nodes.Length).ToArray();
var queue = new PriorityQueue<NodeId, Distance>();

distances[startNode] = 0;
queue.Enqueue(startNode, 0);

do
{
NodeId nodeId = queue.Dequeue();
Distance nodeDistance = distances[nodeId];

foreach (Edge edge in graph.nodes[nodeId])
{
Distance distance = distances[edge.neighbor];
Distance newDistance = nodeDistance + edge.weight;
if (newDistance < distance)
{
distances[edge.neighbor] = newDistance;
// Simulate priority update by attempting to remove the entry
// before re-inserting it with the new distance.
queue.Remove(edge.neighbor, out _, out _);
queue.Enqueue(edge.neighbor, newDistance);
}
}
}
while (queue.Count > 0);

return distances.Select((distance, nodeId) => (nodeId, distance)).ToArray();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace System.Collections.Tests
{
public class PriorityQueue_NonGeneric_Tests : TestBase
public partial class PriorityQueue_NonGeneric_Tests : TestBase
{
protected PriorityQueue<string, int> CreateSmallPriorityQueue(out HashSet<(string, int)> items)
{
Expand Down Expand Up @@ -167,6 +167,55 @@ public void PriorityQueue_Generic_EnqueueRange_Null()
Assert.Equal("not null", queue.Dequeue());
}

[Fact]
public void PriorityQueue_Generic_Remove_MatchingElement()
{
PriorityQueue<string, int> queue = new PriorityQueue<string, int>();
queue.EnqueueRange([("value0", 0), ("value1", 1), ("value2", 2)]);

Assert.True(queue.Remove("value1", out string removedElement, out int removedPriority));
Assert.Equal("value1", removedElement);
Assert.Equal(1, removedPriority);
Assert.Equal(2, queue.Count);
}

[Fact]
public void PriorityQueue_Generic_Remove_MismatchElement()
{
PriorityQueue<string, int> queue = new PriorityQueue<string, int>();
queue.EnqueueRange([("value0", 0), ("value1", 1), ("value2", 2)]);

Assert.False(queue.Remove("value4", out string removedElement, out int removedPriority));
Assert.Null(removedElement);
Assert.Equal(0, removedPriority);
Assert.Equal(3, queue.Count);
}

[Fact]
public void PriorityQueue_Generic_Remove_DuplicateElement()
{
PriorityQueue<string, int> queue = new PriorityQueue<string, int>();
queue.EnqueueRange([("value0", 0), ("value1", 1), ("value0", 2)]);

Assert.True(queue.Remove("value0", out string removedElement, out int removedPriority));
Assert.Equal("value0", removedElement);
Assert.True(removedPriority is 0 or 2);
Assert.Equal(2, queue.Count);
}

[Fact]
public void PriorityQueue_Generic_Remove_CustomEqualityComparer()
{
PriorityQueue<string, int> queue = new PriorityQueue<string, int>();
queue.EnqueueRange([("value0", 0), ("value1", 1), ("value2", 2)]);
EqualityComparer<string> equalityComparer = EqualityComparer<string>.Create((left, right) => left[^1] == right[^1]);

Assert.True(queue.Remove("someOtherValue1", out string removedElement, out int removedPriority, equalityComparer));
Assert.Equal("value1", removedElement);
Assert.Equal(1, removedPriority);
Assert.Equal(2, queue.Count);
}

[Fact]
public void PriorityQueue_Constructor_int_Negative_ThrowsArgumentOutOfRangeException()
{
Expand Down Expand Up @@ -207,6 +256,16 @@ public void PriorityQueue_EmptyCollection_Peek_ShouldReturnFalse()
Assert.Throws<InvalidOperationException>(() => queue.Peek());
}

[Fact]
public void PriorityQueue_EmptyCollection_Remove_ShouldReturnFalse()
{
var queue = new PriorityQueue<string, string>();

Assert.False(queue.Remove(element: "element", out string removedElement, out string removedPriority));
Assert.Null(removedElement);
Assert.Null(removedPriority);
}

#region EnsureCapacity, TrimExcess

[Fact]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>$(NetCoreAppCurrent)</TargetFramework>
<TestRuntime>true</TestRuntime>
Expand Down Expand Up @@ -106,6 +106,7 @@
<Compile Include="Generic\PriorityQueue\PriorityQueue.Generic.Tests.cs" />
<Compile Include="Generic\PriorityQueue\PriorityQueue.PropertyTests.cs" />
<Compile Include="Generic\PriorityQueue\PriorityQueue.Tests.cs" />
<Compile Include="Generic\PriorityQueue\PriorityQueue.Tests.Dijkstra.cs" />
<Compile Include="Generic\Queue\Queue.Generic.cs" />
<Compile Include="Generic\Queue\Queue.Generic.Tests.cs" />
<Compile Include="Generic\Queue\Queue.Tests.cs" />
Expand Down

0 comments on commit c1f4341

Please sign in to comment.