I am trying to implement an asynchronous method that takes an array of ChannelReader<T>s and takes a value from any of the channels that has an item available. Its functionality is similar to the BlockingCollection<T>.TakeFromAny method, which has this signature:
public static int TakeFromAny(BlockingCollection<T>[] collections, out T item,
    CancellationToken cancellationToken);
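For comparison, here is a minimal sketch of the synchronous API in action. Only the second collection holds an item, so TakeFromAny removes it and reports index 1. The class name TakeFromAnyDemo is made up for illustration, and the collections are disposed at the end because BlockingCollection<T> is IDisposable.

```csharp
using System.Collections.Concurrent;

public static class TakeFromAnyDemo
{
    // Puts one item in the second collection only, then calls TakeFromAny.
    // Returns the index that TakeFromAny reports and the item it removed.
    public static (int Index, string Item) Run()
    {
        var collections = new[]
        {
            new BlockingCollection<string>(),
            new BlockingCollection<string>(),
        };
        try
        {
            collections[1].Add("hello");

            // Blocks until some collection has an item; only index 1 has one here.
            int index = BlockingCollection<string>.TakeFromAny(collections, out string item);
            return (index, item);
        }
        finally
        {
            foreach (var collection in collections)
            {
                collection.Dispose();
            }
        }
    }
}
```

Calling `var (index, item) = TakeFromAnyDemo.Run();` yields index 1 and item "hello".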
This method returns the index in the collections array from which the item was removed. An async method cannot have out parameters, so the API that I am trying to implement is this:
public static Task<(T Item, int Index)> TakeFromAnyAsync<T>(
    ChannelReader<T>[] channelReaders,
    CancellationToken cancellationToken = default);
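Because out parameters are not allowed in async methods, the named tuple carries both results, and the caller can deconstruct it much like it would read out parameters. A minimal sketch of that shape, using a hypothetical ReadWithIndexAsync helper that wraps a single ChannelReader<T> (it is not the requested TakeFromAnyAsync, just an illustration of the return type):

```csharp
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public static class TupleReturnDemo
{
    // Hypothetical helper: reads one item and reports the index it was given.
    // Both values travel back together as a named tuple instead of an out parameter.
    public static async Task<(T Item, int Index)> ReadWithIndexAsync<T>(
        ChannelReader<T> reader, int index, CancellationToken cancellationToken = default)
    {
        T item = await reader.ReadAsync(cancellationToken);
        return (item, index);
    }

    public static async Task<string> RunAsync()
    {
        var channel = Channel.CreateUnbounded<int>();
        channel.Writer.TryWrite(42);

        // Deconstruct the tuple at the call site.
        var (item, index) = await ReadWithIndexAsync(channel.Reader, 3);
        return $"{item}@{index}";
    }
}
```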
The TakeFromAnyAsync<T> method should asynchronously read an item and return the consumed item along with the index of the associated channel in the channelReaders array. If all the channels are completed (either successfully or with an error), or all of them become completed during the await, the method should asynchronously throw a ChannelClosedException.
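For reference, a single ChannelReader<T> reports "completed successfully and empty" in two different ways, depending on the method used: WaitToReadAsync returns false, while ReadAsync throws ChannelClosedException. A minimal sketch showing both, assuming an unbounded channel for simplicity (CompletionDemo is a made-up name):

```csharp
using System.Threading.Channels;
using System.Threading.Tasks;

public static class CompletionDemo
{
    // Completes an empty channel, then observes how each read method reports it.
    public static async Task<(bool CanRead, bool Threw)> RunAsync()
    {
        var channel = Channel.CreateUnbounded<int>();
        channel.Writer.Complete(); // completed successfully, nothing buffered

        // Returns false: no more data will ever be available.
        bool canRead = await channel.Reader.WaitToReadAsync();

        bool threw = false;
        try
        {
            await channel.Reader.ReadAsync();
        }
        catch (ChannelClosedException)
        {
            threw = true; // ReadAsync signals completion by throwing
        }
        return (canRead, threw);
    }
}
```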
My question is: how can I implement the TakeFromAnyAsync<T> method? The implementation looks quite tricky. Obviously, under no circumstances should the method consume more than one item from the channels. It also should not leave behind fire-and-forget tasks or undisposed disposable resources. The method will typically be called in a loop, so it should also be reasonably efficient. Its complexity should be no worse than O(n), where n is the number of channels.
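The "consume at most one item" requirement hinges on the difference between waiting and reading. WaitToReadAsync only signals that an item may be available and removes nothing, so it can be started on every reader at once; only TryRead (or ReadAsync) actually consumes. A minimal sketch of that distinction, assuming an unbounded channel (WaitThenTryReadDemo is a made-up name). Note that with several consumers, another reader could take the item between the two calls, so TryRead can still fail after a successful wait.

```csharp
using System.Threading.Channels;
using System.Threading.Tasks;

public static class WaitThenTryReadDemo
{
    // WaitToReadAsync leaves the item in the channel; TryRead removes it.
    public static async Task<(bool Waited, bool FirstRead, int Item, bool SecondRead)> RunAsync()
    {
        var channel = Channel.CreateUnbounded<int>();
        channel.Writer.TryWrite(7);

        bool waited = await channel.Reader.WaitToReadAsync(); // item still buffered
        bool firstRead = channel.Reader.TryRead(out int item); // consumes the 7
        bool secondRead = channel.Reader.TryRead(out _);       // nothing left
        return (waited, firstRead, item, secondRead);
    }
}
```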
To get an idea of where this method can be useful, you could take a look at the select statement of the Go language. From the tour:
The select statement lets a goroutine wait on multiple communication operations. A select blocks until one of its cases can run, then it executes that case. It chooses one at random if multiple are ready.
select {
case msg1 := <-c1:
    fmt.Println("received", msg1)
case msg2 := <-c2:
    fmt.Println("received", msg2)
}
In the above example, either a value will be taken from the channel c1 and assigned to the variable msg1, or a value will be taken from the channel c2 and assigned to the variable msg2. The Go select statement is not restricted to reading from channels. It can include multiple heterogeneous cases, like writing to bounded channels, waiting for timers, etc. Replicating the full functionality of the Go select statement is beyond the scope of this question.
Answer:
I came up with something like this:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public static async Task<(T Item, int Index)> TakeFromAnyAsync<T>(
    ChannelReader<T>[] channelReaders,
    CancellationToken cancellationToken = default)
{
    if (channelReaders == null)
    {
        throw new ArgumentNullException(nameof(channelReaders));
    }
    if (channelReaders.Length == 0)
    {
        throw new ArgumentException("The list cannot be empty.", nameof(channelReaders));
    }
    if (channelReaders.Length == 1)
    {
        return (await channelReaders[0].ReadAsync(cancellationToken), 0);
    }
    // First attempt to read an item synchronously
    for (int i = 0; i < channelReaders.Length; i++)
    {
        if (channelReaders[i].TryRead(out var item))
        {
            return (item, i);
        }
    }
    using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
    {
        var waitToReadTasks = channelReaders
            .Select(it => it.WaitToReadAsync(cts.Token).AsTask())
            .ToArray();
        var pendingTasks = new List<Task<bool>>(waitToReadTasks);
        while (pendingTasks.Count > 1)
        {
            var t = await Task.WhenAny(pendingTasks);
            if (t.IsCompletedSuccessfully && t.Result)
            {
                int index = Array.IndexOf(waitToReadTasks, t);
                var reader = channelReaders[index];
                // Attempt to read an item synchronously
                if (reader.TryRead(out var item))
                {
                    if (pendingTasks.Count > 1)
                    {
                        // Cancel the pending "wait to read" operations on the remaining
                        // readers, then wait for them to complete
                        try
                        {
                            cts.Cancel();
                            await Task.WhenAll((IEnumerable<Task>)pendingTasks);
                        }
                        catch { }
                    }
                    return (item, index);
                }
                // Due to a race condition, the item is no longer available...
                if (!reader.Completion.IsCompleted)
                {
                    // ...but the channel still appears to be open, so retry
                    var waitToReadTask = reader.WaitToReadAsync(cts.Token).AsTask();
                    waitToReadTasks[index] = waitToReadTask;
                    pendingTasks.Add(waitToReadTask);
                }
            }
            // Remove the awaited task and all completed tasks that cannot yield an item
            pendingTasks.RemoveAll(tt => tt == t ||
                tt.IsCompletedSuccessfully && !tt.Result ||
                tt.IsFaulted || tt.IsCanceled);
        }
        // At most one reader is left. If none is left, all channels are completed,
        // and ReadAsync on a completed channel throws ChannelClosedException.
        int lastIndex = 0;
        if (pendingTasks.Count > 0)
        {
            lastIndex = Array.IndexOf(waitToReadTasks, pendingTasks[0]);
            await pendingTasks[0];
        }
        var lastItem = await channelReaders[lastIndex].ReadAsync(cancellationToken);
        return (lastItem, lastIndex);
    }
}
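The cancellation step in the answer (cancel the linked token source, then await the remaining "wait to read" tasks and swallow the cancellation) is what keeps the method from leaving fire-and-forget tasks behind. Below is a stripped-down sketch of just that pattern, assuming three unbounded channels of which only the third receives an item; CancelPendingWaitsDemo is a made-up name.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public static class CancelPendingWaitsDemo
{
    // Starts WaitToReadAsync on every reader with a linked token, picks the first
    // reader that becomes ready, then cancels and awaits the rest.
    public static async Task<(int Index, bool AllCompleted)> RunAsync(
        CancellationToken cancellationToken = default)
    {
        var channels = Enumerable.Range(0, 3)
            .Select(_ => Channel.CreateUnbounded<int>())
            .ToArray();

        using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        Task<bool>[] waits = channels
            .Select(c => c.Reader.WaitToReadAsync(cts.Token).AsTask())
            .ToArray();

        channels[2].Writer.TryWrite(1); // only the third channel becomes ready

        Task<bool> first = await Task.WhenAny(waits);
        int index = Array.IndexOf(waits, first);

        cts.Cancel();
        try
        {
            await Task.WhenAll(waits);
        }
        catch (OperationCanceledException)
        {
            // Expected: the other two waits were canceled.
        }

        // Every task has finished, so nothing keeps running after we return.
        return (index, waits.All(t => t.IsCompleted));
    }
}
```

The sketch catches only OperationCanceledException because its channels are never completed with an error. The answer's bare `catch { }` also swallows such faults, which seems intended there, since an item has already been taken by that point.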