When our connection drops, ReceiveAsync
is throwing WebSocketException
(ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely
).
The issue is that it's not handled by Polly for some reason. I believe it doesn't handle it, because it's in a separate Task
, although I'm doing Task.WhenAny
.
The expected behavior is to trigger the reconnect if WebSocketException
is thrown.
public sealed class ChannelWebSocketClient : IDisposable
{
private readonly Uri _uri;
private readonly ILogger<ChannelWebSocketClient> _logger;
private readonly Channel<string> _output;
private CancellationTokenSource? _cancellationTokenSource;
public ChannelWebSocketClient(Uri uri, ILoggerFactory loggerFactory)
{
_uri = uri ?? throw new ArgumentNullException(nameof(uri));
_logger = loggerFactory.CreateLogger<ChannelWebSocketClient>();
_output = Channel.CreateUnbounded<string>(new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = false
});
}
public void Dispose()
{
_output.Writer.TryComplete();
}
public Task StartAsync()
{
return Policy.Handle<Exception>(ex => ex is not (TaskCanceledException or OperationCanceledException))
.WaitAndRetryForeverAsync(
(_, _) => TimeSpan.FromSeconds(5),
(ex, retryCount, calculatedWaitDuration, _) => { _logger.LogError(ex, "Unable to connect to the web socket server. Retry count: {RetryCount} | Retry in {Seconds} seconds", retryCount, calculatedWaitDuration.TotalSeconds); })
.ExecuteAsync(ConnectAsync);
}
public void Stop()
{
_cancellationTokenSource?.Cancel();
}
private async Task ConnectAsync()
{
_logger.LogDebug("Connecting");
using var ws = new ClientWebSocket();
// WebSocketException, TaskCanceledException
await ws.ConnectAsync(_uri, CancellationToken.None).ConfigureAwait(false);
_logger.LogDebug("Connected to {Host}", _uri.AbsoluteUri);
_cancellationTokenSource = new CancellationTokenSource();
var receiving = ReceiveLoopAsync(ws, _cancellationTokenSource.Token);
var sending = SendLoopAsync(ws, _cancellationTokenSource.Token);
var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);
if (trigger == receiving)
{
_cancellationTokenSource?.Cancel();
await sending.ConfigureAwait(false);
}
_logger.LogDebug("END");
}
public async Task SendAsync(string message)
{
await _output.Writer.WriteAsync(message, CancellationToken.None).ConfigureAwait(false);
}
private async Task SendLoopAsync(WebSocket webSocket, CancellationToken cancellationToken)
{
_logger.LogDebug("SendLoopAsync BEGIN");
try
{
while (await _output.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
{
while (_output.Reader.TryRead(out var message))
{
// WebSocketException, TaskCanceledException, ObjectDisposedException
await webSocket.SendAsync(new ArraySegment<byte>(Encoding.UTF8.GetBytes(message)),
WebSocketMessageType.Text, true, cancellationToken).ConfigureAwait(false);
}
}
}
catch (OperationCanceledException)
{
}
finally
{
_logger.LogDebug("SendLoopAsync END");
}
}
private async Task ReceiveLoopAsync(WebSocket webSocket, CancellationToken cancellationToken)
{
_logger.LogDebug("ReceiveLoopAsync BEGIN");
try
{
while (true)
{
ValueWebSocketReceiveResult receiveResult;
using var buffer = MemoryPool<byte>.Shared.Rent(4096);
await using var ms = new MemoryStream(buffer.Memory.Length);
do
{
// WebSocketException, TaskCanceledException, ObjectDisposedException
receiveResult = await webSocket.ReceiveAsync(buffer.Memory, cancellationToken).ConfigureAwait(false);
if (receiveResult.MessageType == WebSocketMessageType.Close)
{
break;
}
await ms.WriteAsync(buffer.Memory[..receiveResult.Count], cancellationToken).ConfigureAwait(false);
} while (!receiveResult.EndOfMessage);
ms.Seek(0, SeekOrigin.Begin);
if (receiveResult.MessageType == WebSocketMessageType.Text)
{
using var reader = new StreamReader(ms, Encoding.UTF8);
var message = await reader.ReadToEndAsync().ConfigureAwait(false);
_logger.LogInformation("Message received: {Message}", message);
}
else if (receiveResult.MessageType == WebSocketMessageType.Close)
{
break;
}
}
}
catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
{
_logger.LogError(ex, "");
throw;
}
finally
{
_logger.LogDebug("ReceiveLoopAsync END");
}
}
}
CodePudding user response:
The Task.WhenAll
works differently than Task.WhenAny
.
- Former throws exception is any of the tasks failed with an exception
- Latter does not throw exception even if all of the tasks fail
So either you use call two twice the .GetAwaiter().GetResult()
since WhenAny
returns a Task<Task>
Task.WhenAny(receiving, sending).ConfigureAwait(false)
.GetAwaiter().GetResult()
.GetAwaiter().GetResult();
Or you can re-throw the exception
var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);
if (trigger.Exception != null)
{
throw trigger.Exception;
}
None of these solutions is perfect, but they will trigger your policy.
UPDATE #1
As Monsieur Merso pointed out you can call twice await
await await Task.WhenAny(receiving, sending).ConfigureAwait(false);
This is much better than the above two approaches.
UPDATE #2
If you want to
- trigger the policy if faster task failed
- or want to know which one has finished sooner with success
then you can "avoid" the double await
var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);
await trigger; //Throws exception if the faster Task has failed
if (trigger == receiving) //Determines which Task finished sooner
{
}