Polly doesn't handle an exception in a task because of Task.WhenAny

Time:04-28

When our connection drops, ReceiveAsync throws a WebSocketException with ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely.

The problem is that Polly doesn't handle it for some reason. I believe that's because the exception is thrown inside a separate Task, even though I'm awaiting Task.WhenAny.

The expected behavior is for Polly to trigger a reconnect whenever a WebSocketException is thrown.

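using System;
using System.Buffers;
using System.IO;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Polly;

public sealed class ChannelWebSocketClient : IDisposable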
{
    private readonly Uri _uri;
    private readonly ILogger<ChannelWebSocketClient> _logger;
    private readonly Channel<string> _output;
    private CancellationTokenSource? _cancellationTokenSource;

    public ChannelWebSocketClient(Uri uri, ILoggerFactory loggerFactory)
    {
        _uri = uri ?? throw new ArgumentNullException(nameof(uri));
        _logger = loggerFactory.CreateLogger<ChannelWebSocketClient>();

        _output = Channel.CreateUnbounded<string>(new UnboundedChannelOptions
        {
            SingleReader = true,
            SingleWriter = false
        });
    }

    public void Dispose()
    {
        _output.Writer.TryComplete();
    }

    public Task StartAsync()
    {
        return Policy.Handle<Exception>(ex => ex is not (TaskCanceledException or OperationCanceledException))
            .WaitAndRetryForeverAsync(
                (_, _) => TimeSpan.FromSeconds(5),
                (ex, retryCount, calculatedWaitDuration, _) => { _logger.LogError(ex, "Unable to connect to the web socket server. Retry count: {RetryCount} | Retry in {Seconds} seconds", retryCount, calculatedWaitDuration.TotalSeconds); })
            .ExecuteAsync(ConnectAsync);
    }

    public void Stop()
    {
        _cancellationTokenSource?.Cancel();
    }

    private async Task ConnectAsync()
    {
        _logger.LogDebug("Connecting");

        using var ws = new ClientWebSocket();

        // WebSocketException, TaskCanceledException
        await ws.ConnectAsync(_uri, CancellationToken.None).ConfigureAwait(false);

        _logger.LogDebug("Connected to {Host}", _uri.AbsoluteUri);

        _cancellationTokenSource = new CancellationTokenSource();

        var receiving = ReceiveLoopAsync(ws, _cancellationTokenSource.Token);
        var sending = SendLoopAsync(ws, _cancellationTokenSource.Token);

        var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);

        if (trigger == receiving)
        {
            _cancellationTokenSource?.Cancel();

            await sending.ConfigureAwait(false);
        }

        _logger.LogDebug("END");
    }

    public async Task SendAsync(string message)
    {
        await _output.Writer.WriteAsync(message, CancellationToken.None).ConfigureAwait(false);
    }

    private async Task SendLoopAsync(WebSocket webSocket, CancellationToken cancellationToken)
    {
        _logger.LogDebug("SendLoopAsync BEGIN");

        try
        {
            while (await _output.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
            {
                while (_output.Reader.TryRead(out var message))
                {
                    // WebSocketException, TaskCanceledException, ObjectDisposedException
                    await webSocket.SendAsync(new ArraySegment<byte>(Encoding.UTF8.GetBytes(message)),
                        WebSocketMessageType.Text, true, cancellationToken).ConfigureAwait(false);
                }
            }
        }
        catch (OperationCanceledException)
        {
        }
        finally
        {
            _logger.LogDebug("SendLoopAsync END");
        }
    }

    private async Task ReceiveLoopAsync(WebSocket webSocket, CancellationToken cancellationToken)
    {
        _logger.LogDebug("ReceiveLoopAsync BEGIN");

        try
        {
            while (true)
            {
                ValueWebSocketReceiveResult receiveResult;

                using var buffer = MemoryPool<byte>.Shared.Rent(4096);
                await using var ms = new MemoryStream(buffer.Memory.Length);
                do
                {
                    // WebSocketException, TaskCanceledException, ObjectDisposedException
                    receiveResult = await webSocket.ReceiveAsync(buffer.Memory, cancellationToken).ConfigureAwait(false);

                    if (receiveResult.MessageType == WebSocketMessageType.Close)
                    {
                        break;
                    }

                    await ms.WriteAsync(buffer.Memory[..receiveResult.Count], cancellationToken).ConfigureAwait(false);
                } while (!receiveResult.EndOfMessage);

                ms.Seek(0, SeekOrigin.Begin);

                if (receiveResult.MessageType == WebSocketMessageType.Text)
                {
                    using var reader = new StreamReader(ms, Encoding.UTF8);
                    var message = await reader.ReadToEndAsync().ConfigureAwait(false);

                    _logger.LogInformation("Message received: {Message}", message);
                }
                else if (receiveResult.MessageType == WebSocketMessageType.Close)
                {
                    break;
                }
            }
        }
        catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
        {
            _logger.LogError(ex, "");
            throw;
        }
        finally
        {
            _logger.LogDebug("ReceiveLoopAsync END");
        }
    }
}
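The behaviour can be reproduced without a socket or Polly. The following is a minimal sketch under stated assumptions: WhenAnyRepro, RetryAsync, FailingReceiveLoopAsync and ConnectLikeAsync are hypothetical names, and RetryAsync is a hand-rolled, bounded stand-in for Polly's WaitAndRetryForeverAsync that only retries when the delegate throws. Because the faulted receive task is only passed to Task.WhenAny and never awaited, the delegate completes normally and no retry happens.

```csharp
using System;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;

public static class WhenAnyRepro
{
    // Hypothetical stand-in for a retry policy: re-runs the delegate only when it throws.
    public static async Task<int> RetryAsync(Func<Task> action, int maxAttempts)
    {
        for (var attempt = 1; ; attempt++)
        {
            try
            {
                await action().ConfigureAwait(false);
                return attempt; // the delegate completed normally, so no retry happens
            }
            catch (Exception) when (attempt < maxAttempts)
            {
                // an exception escaped the delegate: loop around and try again
            }
        }
    }

    // Simulates ReceiveAsync failing after the connection drops.
    private static async Task FailingReceiveLoopAsync()
    {
        await Task.Yield();
        throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely);
    }

    // Same shape as ConnectAsync above, minus the socket.
    public static async Task ConnectLikeAsync()
    {
        using var cts = new CancellationTokenSource();

        var receiving = FailingReceiveLoopAsync();
        var sending = Task.Delay(Timeout.Infinite, cts.Token); // idle send loop

        var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);

        if (trigger == receiving)
        {
            cts.Cancel();
            try
            {
                await sending.ConfigureAwait(false);
            }
            catch (OperationCanceledException)
            {
                // SendLoopAsync swallows cancellation the same way
            }
        }

        // receiving is faulted, but it is never awaited, so its exception never escapes
    }
}
```

Calling WhenAnyRepro.RetryAsync(WhenAnyRepro.ConnectLikeAsync, 3) completes after a single attempt, which mirrors the policy never firing.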

Answer:

Task.WhenAll works differently from Task.WhenAny:

  • The former throws an exception if any of the tasks failed.
  • The latter does not throw even if all of the tasks fail; it simply returns the first task that completed, whether that task succeeded, faulted, or was canceled.
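A small sketch of that difference, assuming two tasks that both fail with an InvalidOperationException (WhenAnyVsWhenAll, FailAsync and AwaitThrowsAsync are hypothetical names used only for illustration). Awaiting WhenAll rethrows the failure, while awaiting WhenAny just yields the first finished task.

```csharp
using System;
using System.Threading.Tasks;

public static class WhenAnyVsWhenAll
{
    private static async Task FailAsync()
    {
        await Task.Yield();
        throw new InvalidOperationException("boom");
    }

    // Starts two failing tasks, awaits the combinator over them,
    // and reports whether that await threw.
    public static async Task<bool> AwaitThrowsAsync(Func<Task[], Task> combinator)
    {
        var tasks = new[] { FailAsync(), FailAsync() };
        try
        {
            await combinator(tasks).ConfigureAwait(false);
            return false;
        }
        catch (InvalidOperationException)
        {
            return true;
        }
    }
}
```

AwaitThrowsAsync(t => Task.WhenAll(t)) returns true, whereas AwaitThrowsAsync(t => Task.WhenAny(t)) returns false, because the Task<Task> returned by WhenAny completes successfully with the faulted task as its result.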

So you can either call .GetAwaiter().GetResult() twice, since WhenAny returns a Task<Task>:

Task.WhenAny(receiving, sending).ConfigureAwait(false)
   .GetAwaiter().GetResult()
   .GetAwaiter().GetResult();

Or you can rethrow the faulted task's exception yourself:

var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);
if (trigger.Exception != null)
{
    throw trigger.Exception;
}

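Neither solution is perfect (the first blocks the calling thread on an async operation, and the second throws an AggregateException that wraps the original WebSocketException), but both will trigger your policy.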
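If the AggregateException wrapper matters (for example, a narrower policy such as Policy.Handle<WebSocketException>() would not match it), one option is to rethrow the inner exception with ExceptionDispatchInfo, which keeps its original stack trace. This is a sketch only; RethrowOriginal, FailingReceiveLoopAsync and ConnectLikeAsync are hypothetical names, and the idle send loop is simulated with a TaskCompletionSource that never completes.

```csharp
using System;
using System.Net.WebSockets;
using System.Runtime.ExceptionServices;
using System.Threading.Tasks;

public static class RethrowOriginal
{
    // Simulates ReceiveAsync failing after the connection drops.
    private static async Task FailingReceiveLoopAsync()
    {
        await Task.Yield();
        throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely);
    }

    public static async Task ConnectLikeAsync()
    {
        var receiving = FailingReceiveLoopAsync();
        var sending = new TaskCompletionSource<bool>().Task; // idle send loop that never finishes

        var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);

        if (trigger.Exception is AggregateException aggregate)
        {
            // Rethrow the first inner exception (here the WebSocketException) with its
            // original stack trace, instead of the AggregateException wrapper.
            ExceptionDispatchInfo.Capture(aggregate.InnerException ?? aggregate).Throw();
        }
    }
}
```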


UPDATE #1

As Monsieur Merso pointed out, you can simply await twice:

await (await Task.WhenAny(receiving, sending).ConfigureAwait(false)).ConfigureAwait(false);

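This is much better than the two approaches above: it doesn't block a thread, and the outer await rethrows the original WebSocketException rather than an AggregateException wrapper.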
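A minimal sketch that checks the claim about the unwrapped exception, assuming a receive loop that fails with ConnectionClosedPrematurely and an idle send loop that never finishes (DoubleAwait, FailingReceiveLoopAsync and ConnectLikeAsync are hypothetical names). The exception that escapes ConnectLikeAsync is the WebSocketException itself.

```csharp
using System.Net.WebSockets;
using System.Threading.Tasks;

public static class DoubleAwait
{
    // Simulates ReceiveAsync failing after the connection drops.
    private static async Task FailingReceiveLoopAsync()
    {
        await Task.Yield();
        throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely);
    }

    public static async Task ConnectLikeAsync()
    {
        var receiving = FailingReceiveLoopAsync();
        var sending = new TaskCompletionSource<bool>().Task; // idle send loop that never finishes

        // Inner await: the first task to complete. Outer await: rethrows that task's
        // exception, unwrapped, if it faulted.
        await (await Task.WhenAny(receiving, sending).ConfigureAwait(false)).ConfigureAwait(false);
    }
}
```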


UPDATE #2

If you want to both

  • trigger the policy if the task that finished first has failed, and
  • know which task finished first when it succeeded,

then you can "avoid" the double await by splitting it into two statements:

var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);

await trigger.ConfigureAwait(false); // Throws if the task that finished first has failed

if (trigger == receiving) // Determines which task finished first
{
    // e.g. cancel the send loop and wait for it to stop
}
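Putting the pieces together, one possible shape for the WhenAny block is a small helper that waits for the first loop, cancels the other, waits for it to stop, and lets the first loop's exception escape to the policy. This is a sketch under assumptions, not a drop-in fix: LoopRunner and RunBothAsync are hypothetical names, and the helper links its own CancellationTokenSource to an optional outer token so that something like Stop() could still cancel both loops.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class LoopRunner
{
    // Hypothetical helper: starts both loops with a shared token, waits for the first one
    // to finish, stops the other, and rethrows the first loop's exception (if any).
    public static async Task RunBothAsync(
        Func<CancellationToken, Task> receiveLoop,
        Func<CancellationToken, Task> sendLoop,
        CancellationToken cancellationToken = default)
    {
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

        var receiving = receiveLoop(cts.Token);
        var sending = sendLoop(cts.Token);

        var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);
        var other = trigger == receiving ? sending : receiving;

        // Whichever loop finished first, the other one should stop too.
        cts.Cancel();

        try
        {
            // Throws (unwrapped) if the first loop faulted, so a retry policy can see it.
            await trigger.ConfigureAwait(false);
        }
        finally
        {
            try
            {
                // Wait for the other loop so it isn't left running after this method returns.
                await other.ConfigureAwait(false);
            }
            catch (OperationCanceledException)
            {
                // Expected: it was cancelled above.
            }
        }
    }
}
```

Inside ConnectAsync this could be called as await LoopRunner.RunBothAsync(ct => ReceiveLoopAsync(ws, ct), ct => SendLoopAsync(ws, ct), _cancellationTokenSource.Token).ConfigureAwait(false); in place of the WhenAny block. One caveat of this design: if the other loop also fails with something other than cancellation, its exception replaces the first one.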