Home > OS >  async call in async void EventHandler leads to a deadlock
async call in async void EventHandler leads to a deadlock

Time:04-05

Is there a way to call SendAsync in OnConnect without causing a deadlock? I'm not using .Wait() or .Result, yet it still deadlocks.

Edit:

The actual problem is that SendAsync is called twice: once in OnConnect and once in Main. If I put an await Task.Delay(10000) before the second call in Main, it works. How can I fix this properly? Without the delay, it hangs on await tcs.Task.ConfigureAwait(false). SendAsync is called twice, and async void OnConnect is effectively "fire and forget", so nothing waits for the first SendAsync to complete before the second call is made.

// Program.cs
var client = new Client(key, secret);

await client.StartAsync().ConfigureAwait(false);

// Timing workaround from the question: gives the authentication started in Client.OnConnect time to finish
// before the requests below are sent. It only hides the ordering problem and breaks again whenever connecting
// plus authenticating takes longer than the delay.
await Task.Delay(3000);

// TResponse appears only in SendAsync's return type, so the compiler cannot infer it; it must be given explicitly.
// NOTE(review): JsonRpcResponse is assumed here — use the response type each request actually expects.
await client.SendAsync<JsonRpcResponse>(request).ConfigureAwait(false);
await client.SendAsync<JsonRpcResponse>(request2).ConfigureAwait(false);

Console.ReadLine();

// Client.cs
/// <summary>
/// JSON-RPC style client on top of <see cref="WebSocketClient"/>: each request is matched to its response by id.
/// Authentication is sent automatically whenever the socket reports it is connected.
/// </summary>
public class Client
{
    private static long _nextId;
    private readonly WebSocketClient _webSocket;

    // Pending requests keyed by request id; completed (and removed) by OnMessageReceived.
    private readonly ConcurrentDictionary<long, TaskCompletionSource<string>> _outstandingRequests = new();

    // Completed when the authentication request sent from OnConnect has succeeded, faulted if it failed.
    // SendAsync awaits it so caller requests are never sent before, or concurrently with, authentication.
    // NOTE(review): it is never reset, so after a reconnect requests are not held back while re-authenticating,
    // and a failed first authentication keeps failing every later SendAsync — revisit if reconnects matter.
    private readonly TaskCompletionSource _authenticated = new(TaskCreationOptions.RunContinuationsAsynchronously);

    ...

    public event EventHandler<ConnectEventArgs>? Connected;
    public event EventHandler<MessageReceivedEventArgs>? MessageReceived;

    /// <summary>
    /// Subscribes to the web socket's connect/message events and starts it.
    /// </summary>
    public ValueTask StartAsync()
    {
        // Events declared in another class can only be subscribed to with +=; plain assignment does not compile.
        _webSocket.Connected += OnConnect;
        _webSocket.MessageReceived += OnMessageReceived;

        return _webSocket.StartAsync();  // there is a long-running `Task.Run` inside it, which keeps the web socket connection and its pipelines open.
    }

    // async void is only acceptable because this is an event handler. Everything is caught here: an exception
    // escaping an async void method cannot be observed by anyone and would crash the process.
    private async void OnConnect(object? sender, ConnectEventArgs e)
    {
        try
        {
            await AuthAsync(...).ConfigureAwait(false);
            _authenticated.TrySetResult();
        }
        catch (Exception ex)
        {
            // Surface the failure to every waiting SendAsync instead of leaving them hanging.
            _authenticated.TrySetException(ex);
        }
    }

    // Completes the pending request whose id matches the response; responses with no pending request are ignored.
    private void OnMessageReceived(object? sender, MessageReceivedEventArgs e)
    {
        ... deserialization stuff

        if (_outstandingRequests.TryRemove(response.Id, out var tcs))
        {
            tcs.TrySetResult(message);
        }
    }

    /// <summary>
    /// Sends <paramref name="request"/> once authentication has completed and waits for the response with the same id.
    /// </summary>
    /// <exception cref="InvalidOperationException">A request with the same id is already pending.</exception>
    public async ValueTask<TResponse?> SendAsync<TResponse>(JsonRpcRequest request)
    {
        // Queue behind the authentication performed in OnConnect; rethrows if authentication failed.
        await _authenticated.Task.ConfigureAwait(false);
        return await SendCoreAsync<TResponse>(request).ConfigureAwait(false);
    }

    /// <summary>
    /// Sends the authentication request and waits for its response.
    /// </summary>
    public ValueTask<JsonRpcResponse?> AuthAsync(JsonRpcRequest request)
    {
        // Must not wait on _authenticated: this is the request that completes it.
        return SendCoreAsync<JsonRpcResponse>(request);
    }

    // Registers a pending response for request.Id, writes the request and waits for the matching response.
    private async ValueTask<TResponse?> SendCoreAsync<TResponse>(JsonRpcRequest request)
    {
        var tcs = new TaskCompletionSource<string>(TaskCreationOptions.RunContinuationsAsynchronously);

        // With a duplicate id the response would complete the other request and this one would wait forever.
        if (!_outstandingRequests.TryAdd(request.Id, tcs))
        {
            throw new InvalidOperationException($"A request with id {request.Id} is already pending.");
        }

        try
        {
            var message = JsonSerializer.Serialize(request);
            await _webSocket.SendAsync(message).ConfigureAwait(false);
        }
        catch
        {
            // The write failed; do not leave an entry behind that may never be completed.
            _outstandingRequests.TryRemove(request.Id, out _);
            throw;
        }

        var response = await tcs.Task.ConfigureAwait(false);

        return JsonSerializer.Deserialize<TResponse>(response);
    }

    // Returns the next id from a static counter shared by all Client instances; safe to call concurrently.
    private static long NextId()
    {
        return Interlocked.Increment(ref _nextId);
    }
}
/// <summary>
/// Keeps a SignalR <c>WebSocketPipe</c> connected (reconnecting on <see cref="WebSocketException"/>) and sends
/// string messages over it.
/// </summary>
public sealed class WebSocketClient
{
    // Send gate: set once a connection is established, reset by the retry policy before each reconnect attempt.
    private readonly AsyncManualResetEvent _sendSemaphore = new(false); // Nito.AsyncEx
    // Serializes writes to the transport's PipeWriter, which must not be written by several callers at once.
    private readonly SemaphoreSlim _writeLock = new(1, 1);
    private readonly WebSocketPipe _webSocket; // SignalR Web Socket Pipe

    ...

    public event EventHandler<ConnectEventArgs>? Connected;
    public event EventHandler<DisconnectEventArgs>? Disconnected;
    public event EventHandler<MessageReceivedEventArgs>? MessageReceived;

    /// <summary>
    /// Starts connecting on a background task and returns immediately; the connection (and its receive loop)
    /// is kept alive by the retry policy.
    /// </summary>
    public ValueTask StartAsync()
    {
        _ = Task.Run(async () =>
        {
            try
            {
                await CreatePolicy()
                    .ExecuteAsync(async () =>
                    {
                        await _webSocket.StartAsync(new Uri(_url), CancellationToken.None).ConfigureAwait(false);

                        // Open the send gate BEFORE raising Connected. Handlers typically send right away
                        // (e.g. to authenticate); with the old order a handler that waited on such a send
                        // could never finish, because the gate was only opened after the handler returned.
                        _sendSemaphore.Set();

                        Connected?.Invoke(this, new ConnectEventArgs());

                        await ReceiveLoopAsync().ConfigureAwait(false);
                    })
                    .ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                // WebSocketException is retried forever by the policy, so only other exceptions end up here.
                Disconnected?.Invoke(this, new DisconnectEventArgs(ex));
            }
        });

        return ValueTask.CompletedTask;
    }

    /// <summary>
    /// Waits until a connection is available, then writes <paramref name="message"/> as UTF-8 to the transport.
    /// Safe to call from several callers at once.
    /// </summary>
    public async ValueTask SendAsync(string message)
    {
        await _sendSemaphore.WaitAsync().ConfigureAwait(false);

        var encoded = Encoding.UTF8.GetBytes(message);

        // Every caller parked on the gate is released by the same Set(); without this lock they would all
        // write to the PipeWriter concurrently, which it does not support.
        await _writeLock.WaitAsync().ConfigureAwait(false);
        try
        {
            await _webSocket.Transport!.Output
                .WriteAsync(new ArraySegment<byte>(encoded, 0, encoded.Length), CancellationToken.None)
                .ConfigureAwait(false);
        }
        finally
        {
            _writeLock.Release();
        }
    }

    // Retries forever on WebSocketException, waiting ReconnectInterval between attempts; before each retry it
    // closes the send gate and raises Reconnecting.
    private IAsyncPolicy CreatePolicy()
    {
        var retryPolicy = Policy
            .Handle<WebSocketException>()
            .WaitAndRetryForeverAsync(_ => ReconnectInterval,
                (exception, retryCount, calculatedWaitDuration) =>
                {
                    _sendSemaphore.Reset();

                    Reconnecting?.Invoke(this, new ReconnectingEventArgs(exception, retryCount, calculatedWaitDuration));

                    return Task.CompletedTask;
                });

        return retryPolicy;
    }

    // Continuously reads from the transport's input pipe.
    private async Task ReceiveLoopAsync()
    {
        while (true)
        {
            var result = await _webSocket.Transport!.Input.ReadAsync(CancellationToken.None).ConfigureAwait(false);
            var buffer = result.Buffer;

            ...
        }
    }
}

User response:

As mentioned in the comments, those web socket wrappers use System.IO.Pipelines, which is the wrong abstraction. System.IO.Pipelines models a stream of bytes, so it's appropriate for (non-web) sockets. A web socket is a stream of messages, so something like System.Threading.Channels is a better fit.

You could try something like this, which I just typed up and haven't even run:

/// <summary>
/// Wraps a <see cref="WebSocket"/> in two bounded channels of whole messages: messages received from the socket
/// are read from <see cref="Input"/>, and messages written to <see cref="Output"/> are sent.
/// Call <see cref="Start"/> to run the receive and send loops.
/// </summary>
public sealed class ChannelWebSocket : IDisposable
{
    private readonly WebSocket _webSocket;
    private readonly Channel<Message> _input;
    private readonly Channel<Message> _output;

    /// <summary>Creates the input and output channels with the capacities, full modes and drop callbacks from <paramref name="options"/>.</summary>
    public ChannelWebSocket(WebSocket webSocket, Options options)
    {
        _webSocket = webSocket;
        _input = Channel.CreateBounded(new BoundedChannelOptions(options.InputCapacity)
        {
            FullMode = options.InputFullMode,
        }, options.InputMessageDropped);
        _output = Channel.CreateBounded(new BoundedChannelOptions(options.OutputCapacity)
        {
            FullMode = options.OutputFullMode,
        }, options.OutputMessageDropped);
    }

    /// <summary>Received messages; completed when a Close frame arrives, or with the error if a loop fails.</summary>
    public ChannelReader<Message> Input => _input.Reader;

    /// <summary>Messages to send; completing this writer makes the send loop close the socket normally.</summary>
    public ChannelWriter<Message> Output => _output.Writer;

    public void Dispose() => _webSocket.Dispose();

    /// <summary>
    /// Starts the receive and send loops. If the first loop to finish has failed, the socket is closed with
    /// <see cref="WebSocketCloseStatus.InternalServerError"/> and both channels are completed with the exception.
    /// </summary>
    public async void Start()
    {
        // NOTE(review): only the loop that finishes first is inspected; a later failure of the other loop is not observed.
        var inputTask = InputLoopAsync(default);
        var outputTask = OutputLoopAsync(default);

        var completedTask = await Task.WhenAny(inputTask, outputTask);

        if (completedTask.Exception != null)
        {
            try { await _webSocket.CloseAsync(WebSocketCloseStatus.InternalServerError, statusDescription: null, default); } catch { /* ignore */ }
            try { _input.Writer.Complete(completedTask.Exception); } catch { /* ignore */ }
            try { _output.Writer.Complete(completedTask.Exception); } catch { /* ignore */ }
        }
    }

    /// <summary>
    /// One complete web socket message. Whoever takes it from a channel owns <see cref="Payload"/>: the send loop
    /// disposes it after sending, and consumers of <see cref="Input"/> should dispose it to return its pooled buffers.
    /// </summary>
    public sealed class Message
    {
        public WebSocketMessageType MessageType { get; set; }
        public OwnedMemorySequence<byte> Payload { get; set; } = null!;
    }

    // Receives messages until the peer sends a Close frame, writing each complete message to the input channel.
    // NOTE(review): the Close frame is not answered here; the close handshake is only completed by the send loop
    // (CloseAsync) once Output has been completed.
    private async Task InputLoopAsync(CancellationToken cancellationToken)
    {
        while (true)
        {
            var payload = new OwnedMemorySequence<byte>();
            WebSocketMessageType messageType;
            try
            {
                messageType = await ReceiveMessageAsync(payload, cancellationToken);
            }
            catch
            {
                payload.Dispose();
                throw;
            }

            if (messageType == WebSocketMessageType.Close)
            {
                // Fragments received before the Close frame (if any) are discarded.
                payload.Dispose();
                _input.Writer.Complete();
                return;
            }

            try
            {
                await _input.Writer.WriteAsync(new Message
                {
                    MessageType = messageType,
                    Payload = payload,
                }, cancellationToken);
            }
            catch
            {
                // The message never reached the channel, so nothing else will return its buffers to the pool.
                payload.Dispose();
                throw;
            }
        }
    }

    // Receives the frames of one message into payload and returns its type, or Close if a Close frame arrived.
    private async ValueTask<WebSocketMessageType> ReceiveMessageAsync(OwnedMemorySequence<byte> payload, CancellationToken cancellationToken)
    {
        ValueWebSocketReceiveResult result;
        do
        {
            // A fresh pooled buffer per receive: the payload keeps every appended buffer, so reusing one buffer
            // would let later fragments overwrite earlier ones (and dispose the same owner several times).
            var buffer = MemoryPool<byte>.Shared.Rent();
            try
            {
                result = await _webSocket.ReceiveAsync(buffer.Memory, cancellationToken);
            }
            catch
            {
                buffer.Dispose();
                throw;
            }

            if (result.MessageType == WebSocketMessageType.Close)
            {
                buffer.Dispose();
                return WebSocketMessageType.Close;
            }

            // From here on the payload owns the buffer and returns it to the pool when disposed.
            payload.Append(buffer.Slice(0, result.Count));
        } while (!result.EndOfMessage);

        return result.MessageType;
    }

    // Sends each queued message (one fragment per non-empty buffer); once Output is completed and drained,
    // closes the socket normally.
    private async Task OutputLoopAsync(CancellationToken cancellationToken)
    {
        await foreach (var message in _output.Reader.ReadAllAsync(cancellationToken))
        {
            // Disposed at the end of every iteration, including skipped empty messages and failed sends.
            using var payload = message.Payload;

            var sequence = payload.ReadOnlySequence;
            if (sequence.IsEmpty)
                continue;

            // Enumerate segments instead of repeatedly slicing by First.Length: an empty leading segment would
            // make that slice a no-op and spin forever. Each fragment is held back one step so the last non-empty
            // one can be sent with endOfMessage: true.
            ReadOnlyMemory<byte> pending = default;
            foreach (var segment in sequence)
            {
                if (segment.IsEmpty)
                    continue;

                if (!pending.IsEmpty)
                    await _webSocket.SendAsync(pending, message.MessageType, endOfMessage: false, cancellationToken);

                pending = segment;
            }

            // sequence is non-empty, so at least one non-empty segment was found.
            await _webSocket.SendAsync(pending, message.MessageType, endOfMessage: true, cancellationToken);
        }

        await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, statusDescription: null, cancellationToken);
    }

    /// <summary>Capacity, full mode and dropped-message callback of each channel (defaults: 16, Wait, none).</summary>
    public sealed class Options
    {
        public int InputCapacity { get; set; } = 16;
        public BoundedChannelFullMode InputFullMode { get; set; } = BoundedChannelFullMode.Wait;
        public Action<Message>? InputMessageDropped { get; set; }

        public int OutputCapacity { get; set; } = 16;
        public BoundedChannelFullMode OutputFullMode { get; set; } = BoundedChannelFullMode.Wait;
        public Action<Message>? OutputMessageDropped { get; set; }
    }
}

It uses this type for building a memory sequence:

/// <summary>
/// Builds a <see cref="ReadOnlySequence{T}"/> by chaining appended memory blocks as linked segments.
/// The blocks are referenced, not copied.
/// </summary>
public sealed class MemorySequence<T>
{
    private MemorySegment? _head;
    private MemorySegment? _tail;

    /// <summary>Appends <paramref name="buffer"/> as a new segment at the end of the sequence.</summary>
    /// <returns>This instance, for chaining.</returns>
    public MemorySequence<T> Append(ReadOnlyMemory<T> buffer)
    {
        if (_tail == null)
            _head = _tail = new MemorySegment(buffer, runningIndex: 0);
        else
            _tail = _tail.Append(buffer);
        return this;
    }

    /// <summary>Every appended buffer, first to last; empty if nothing was appended.</summary>
    public ReadOnlySequence<T> ReadOnlySequence => CreateReadOnlySequence(0, _tail?.Memory.Length ?? 0);

    /// <summary>
    /// Creates a sequence starting at <paramref name="firstBufferStartIndex"/> within the first appended buffer and
    /// ending at <paramref name="lastBufferEndIndex"/> (exclusive) within the last one.
    /// Returns an empty sequence (ignoring both indices) if nothing was appended.
    /// </summary>
    public ReadOnlySequence<T> CreateReadOnlySequence(int firstBufferStartIndex, int lastBufferEndIndex) =>
        _tail == null ? new ReadOnlySequence<T>(Array.Empty<T>()) :
        new ReadOnlySequence<T>(_head!, firstBufferStartIndex, _tail, lastBufferEndIndex);

    // One linked node. RunningIndex is the total length of all preceding segments, as ReadOnlySequence requires.
    private sealed class MemorySegment : ReadOnlySequenceSegment<T>
    {
        public MemorySegment(ReadOnlyMemory<T> memory, long runningIndex)
        {
            Memory = memory;
            RunningIndex = runningIndex;
        }

        public MemorySegment Append(ReadOnlyMemory<T> nextMemory)
        {
            // The next segment starts exactly where this one ends.
            var next = new MemorySegment(nextMemory, RunningIndex + Memory.Length);
            Next = next;
            return next;
        }
    }
}

and this type for building an owned memory sequence:

/// <summary>
/// A memory sequence that also takes ownership of each appended <see cref="IMemoryOwner{T}"/>; disposing the
/// sequence disposes the collected owners.
/// </summary>
public sealed class OwnedMemorySequence<T> : IDisposable
{
    // Every appended owner. NOTE(review): presumably Nito.Disposables.CollectionDisposable, which disposes all added items.
    private readonly CollectionDisposable _disposable = new();
    // Read-only view over the owners' memory, in append order.
    private readonly MemorySequence<T> _sequence = new();

    /// <summary>Takes ownership of <paramref name="memoryOwner"/> and appends its memory to the sequence.</summary>
    /// <returns>This instance, for chaining.</returns>
    public OwnedMemorySequence<T> Append(IMemoryOwner<T> memoryOwner)
    {
        // Register for disposal first, then expose the memory through the sequence.
        _disposable.Add(memoryOwner);
        _sequence.Append(memoryOwner.Memory);
        return this;
    }

    /// <summary>All appended memory as one sequence.</summary>
    public ReadOnlySequence<T> ReadOnlySequence
    {
        get { return _sequence.ReadOnlySequence; }
    }

    /// <summary>A sub-sequence bounded by offsets within the first and last appended buffers.</summary>
    public ReadOnlySequence<T> CreateReadOnlySequence(int firstBufferStartIndex, int lastBufferEndIndex)
    {
        return _sequence.CreateReadOnlySequence(firstBufferStartIndex, lastBufferEndIndex);
    }

    /// <summary>Disposes every appended memory owner.</summary>
    public void Dispose()
    {
        _disposable.Dispose();
    }
}

which depends on these IMemoryOwner<T> slicing extension methods, borrowed from another source (the original link was lost):

/// <summary>
/// Slicing helpers for <see cref="IMemoryOwner{T}"/>. A slice exposes a sub-range of the owner's memory, and
/// disposing the slice disposes the original owner.
/// </summary>
public static class MemoryOwnerSliceExtensions
{
    /// <summary>
    /// Returns an owner exposing <paramref name="length"/> elements starting at <paramref name="start"/>, or
    /// <paramref name="owner"/> itself when that range is the owner's whole memory.
    /// </summary>
    public static IMemoryOwner<T> Slice<T>(this IMemoryOwner<T> owner, int start, int length) =>
        start == 0 && length == owner.Memory.Length
            ? owner
            : new SliceOwner<T>(owner, owner.Memory.Slice(start, length));

    /// <summary>
    /// Returns an owner exposing everything from <paramref name="start"/> onward, or <paramref name="owner"/>
    /// itself when <paramref name="start"/> is 0.
    /// </summary>
    public static IMemoryOwner<T> Slice<T>(this IMemoryOwner<T> owner, int start) =>
        start == 0
            ? owner
            : new SliceOwner<T>(owner, owner.Memory.Slice(start));

    // Exposes a sub-range of another owner's memory; disposing it disposes that owner.
    private sealed class SliceOwner<T> : IMemoryOwner<T>
    {
        private readonly IMemoryOwner<T> _owner;

        public SliceOwner(IMemoryOwner<T> owner, Memory<T> memory)
        {
            _owner = owner;
            Memory = memory;
        }

        public Memory<T> Memory { get; }

        public void Dispose() => _owner.Dispose();
    }
}

This code is all completely untested; use at your own risk.

  • Related