Home > Software design >  Elegant way to get a task for async code without running the task immediately
Elegant way to get a task for async code without running the task immediately

Time:01-30

I have the following code that does what I want but I had to resort to using .GetAwaiter().GetResult() in the middle of asynchronous code to get it. I am wondering if there is an elegant way to achieve this without resorting to such hacks.

This is a simplified version of the code I have.

public async Task<string[]> GetValues(int[] keys)
{
    List<int> keysNotYetActivelyRequested = null;
    // don't start the task at this point because the
    // keysNotYetActivelyRequested is not yet populated
    var taskToCreateWithoutStarting = new Task<Dictionary<int, string>>(
        () => GetValuesFromApi(keysNotYetActivelyRequested.ToArray())
        .GetAwaiter().GetResult() /*not the best idea*/);
    (var allTasksToAwait, keysNotYetActivelyRequested) = GetAllTasksToAwait(
        keys, taskToCreateWithoutStarting);

    if (keysNotYetActivelyRequested.Any())
    {
        // keysNotYetActivelyRequested will be empty when all keys
        // are already part of another active request
        taskToCreateWithoutStarting.Start(TaskScheduler.Current);
    }
    var allResults = await Task.WhenAll(allTasksToAwait);
    var theReturn = new string[keys.Length];
    for (int i = 0; i < keys.Length; i  )
    {
        foreach (var result in allResults)
        {
            if (result.TryGetValue(keys[i], out var value))
            {
                theReturn[i] = value;
            }
        }
    }
    if (keysNotYetActivelyRequested.Any())
    {
        taskToCreateWithoutStarting.Dispose();
    }
    return theReturn;
}

// all active requests indexed by the key, used to avoid generating
// multiple requests for the same key
private Dictionary<int, Task<Dictionary<int, string>>> _activeRequests = new();
private (HashSet<Task<Dictionary<int, string>>> allTasksToAwait,
    List<int> keysNotYetActivelyRequested) GetAllTasksToAwait(
    int[] keys, Task<Dictionary<int, string>> taskToCreateWithoutStarting)
{
    var keysNotYetActivelyRequested = new List<int>();
    // a HashSet because each task will have multiple keys hence _activeRequests
    // will have the same task multiple times
    var allTasksToAwait = new HashSet<Task<Dictionary<int, string>>>();

    // add cleanup to the task to remove the requested keys from _activeRequests
    // once it completes
    var taskWithCleanup = taskToCreateWithoutStarting.ContinueWith(_ =>
    {
        lock (_activeRequests)
        {
            foreach (var key in keysNotYetActivelyRequested)
            {
                _activeRequests.Remove(key);
            }
        }
    });
    lock (_activeRequests)
    {
        foreach (var key in keys)
        {
            // use CollectionsMarshal to avoid a lookup for the same key twice
            ref var refToTask = ref CollectionsMarshal.GetValueRefOrAddDefault(
                _activeRequests, key, out var exists);
            if (exists)
            {
                allTasksToAwait.Add(refToTask);
            }
            else
            {
                refToTask = taskToCreateWithoutStarting;
                allTasksToAwait.Add(taskToCreateWithoutStarting);
                keysNotYetActivelyRequested.Add(key);
            }
        }
    }
    return (allTasksToAwait, keysNotYetActivelyRequested);
}

// not the actual code
private async Task<Dictionary<int, string>> GetValuesFromApi(int[] keys)
{
    // request duration dependent on the number of keys
    await Task.Delay(keys.Length);
    return keys.ToDictionary(k => k, k => k.ToString());
}

And a test method:

[Test]
public void TestGetValues()
{
    var random = new Random();
    var allTasks = new Task[10];
    for (int i = 0; i < 10; i  )
    {
        var arrayofRandomInts = Enumerable.Repeat(random, random.Next(1, 100))
            .Select(r => r.Next(1, 100)).ToArray();
        allTasks[i] = GetValues(arrayofRandomInts);
    }
    Assert.DoesNotThrowAsync(() => Task.WhenAll(allTasks));
    Assert.That(_activeRequests.Count, Is.EqualTo(0));
}

CodePudding user response:

Instead of:

Task<Something> coldTask = new(() => GetAsync().GetAwaiter().GetResult());

You can do it like this:

Task<Task<Something>> coldTaskTask = new(() => GetAsync());
Task<Something> proxyTask = coldTaskTask.Unwrap();

The nested task coldTaskTask is the task that you will later Start (or RunSynchronously).

The unwrapped task proxyTask is a proxy that represents both the invocation of the GetAsync method, as well as the completion of the Task<Something> that this method generates.

  • Related