Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement backpressure logic. #73

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 38 additions & 14 deletions src/GraphQL.AspNetCore3/WebSockets/AsyncMessagePump.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ public AsyncMessagePump(Action<T> callback)
{
if (callback == null)
throw new ArgumentNullException(nameof(callback));
_callback = message => {
_callback = message =>
{
callback(message);
return Task.CompletedTask;
};
Expand All @@ -46,51 +47,75 @@ public AsyncMessagePump(Action<T> callback)
/// <summary>
/// Posts the specified message to the message queue.
/// </summary>
public void Post(T message)
=> Post(new ValueTask<T>(message));
public void Post(T message) => Post(new ValueTask<T>(message));

/// <summary>
/// Posts the result of an asynchronous operation to the message queue.
/// </summary>
public void Post(ValueTask<T> messageTask)
{
bool attach = false;
lock (_queue) {
lock (_queue)
{
_queue.Enqueue(messageTask);
attach = _queue.Count == 1;
}

if (attach) {
if (attach)
{
CompleteAsync();
}
}

/// <summary>
/// Returns the number of messages waiting in the queue.
/// Includes the message currently being processed, if any.
/// </summary>
public int Count
{
get
{
lock (_queue)
{
return _queue.Count;
}
}
}

/// <summary>
/// Processes message in the queue until it is empty.
/// </summary>
private async void CompleteAsync()
{
// grab the message at the start of the queue, but don't remove it from the queue
ValueTask<T> messageTask;
lock (_queue) {
lock (_queue)
{
// should always successfully peek from the queue here
#pragma warning disable CA2012 // Use ValueTasks correctly
messageTask = _queue.Peek();
#pragma warning restore CA2012 // Use ValueTasks correctly
}
while (true) {
while (true)
{
// process the message
try {
try
{
var message = await messageTask.ConfigureAwait(false);
await _callback(message).ConfigureAwait(false);
} catch (Exception ex) {
try {
}
catch (Exception ex)
{
try
{
await HandleErrorAsync(ex);
} catch { }
}
catch { }
}

// once the message has been passed along, dequeue it
lock (_queue) {
lock (_queue)
{
#pragma warning disable CA2012 // Use ValueTasks correctly
_ = _queue.Dequeue();
#pragma warning restore CA2012 // Use ValueTasks correctly
Expand All @@ -105,6 +130,5 @@ private async void CompleteAsync()
/// <summary>
/// Handles exceptions that occur within the asynchronous message delegate or the callback.
/// </summary>
protected virtual Task HandleErrorAsync(Exception exception)
=> Task.CompletedTask;
protected virtual Task HandleErrorAsync(Exception exception) => Task.CompletedTask;
}
7 changes: 7 additions & 0 deletions src/GraphQL.AspNetCore3/WebSockets/GraphQLWebSocketOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,11 @@ public class GraphQLWebSocketOptions
/// Disconnects a subscription from the client there are any GraphQL errors during a subscription.
/// </summary>
public bool DisconnectAfterAnyError { get; set; }

/// <summary>
/// To help prevent backpressure from slower internet speeds, this will prevent the queue from expanding
/// beyond the max length.
/// The default is null (no limit). Value must be greater than 0.
/// </summary>
public int? MaxSendQueueThreshold { get; set; }
}
5 changes: 5 additions & 0 deletions src/GraphQL.AspNetCore3/WebSockets/IWebSocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ public interface IWebSocketConnection : IDisposable
/// </summary>
Task SendMessageAsync(OperationMessage message);

/// <summary>
/// Sends a message. Option to ignoreMaxSendQueueThreshold and force a message.
/// </summary>
Task SendMessageAsync(OperationMessage message, bool ignoreMaxSendQueueThreshold = false);

/// <summary>
/// Closes the WebSocket connection, and
/// prevents further incoming messages from being dispatched through <see cref="IOperationMessageProcessor"/>.
Expand Down
135 changes: 103 additions & 32 deletions src/GraphQL.AspNetCore3/WebSockets/WebSocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class WebSocketConnection : IWebSocketConnection
private readonly WebSocketWriterStream _stream;
private readonly TaskCompletionSource<bool> _outputClosed = new();
private readonly int _closeTimeoutMs;
private readonly int? _maxSendQueueThreshold;
private volatile bool _closeRequested;
private int _executed;

Expand All @@ -44,18 +45,39 @@ public class WebSocketConnection : IWebSocketConnection
/// <summary>
/// Initializes an instance with the specified parameters.
/// </summary>
public WebSocketConnection(HttpContext httpContext, WebSocket webSocket, IGraphQLSerializer serializer, GraphQLWebSocketOptions options, CancellationToken cancellationToken)
public WebSocketConnection(
HttpContext httpContext,
WebSocket webSocket,
IGraphQLSerializer serializer,
GraphQLWebSocketOptions options,
CancellationToken cancellationToken
)
{
HttpContext = httpContext ?? throw new ArgumentNullException(nameof(httpContext));
if (options == null)
throw new ArgumentNullException(nameof(options));
if (options.DisconnectionTimeout.HasValue) {
if ((options.DisconnectionTimeout.Value != Timeout.InfiniteTimeSpan && options.DisconnectionTimeout.Value.TotalMilliseconds < 0) || options.DisconnectionTimeout.Value.TotalMilliseconds > int.MaxValue)
if (options.DisconnectionTimeout.HasValue)
{
if (
(
options.DisconnectionTimeout.Value != Timeout.InfiniteTimeSpan
&& options.DisconnectionTimeout.Value.TotalMilliseconds < 0
)
|| options.DisconnectionTimeout.Value.TotalMilliseconds > int.MaxValue
)
#pragma warning disable CA2208 // Instantiate argument exceptions correctly
throw new ArgumentOutOfRangeException(nameof(options) + "." + nameof(GraphQLWebSocketOptions.DisconnectionTimeout));
throw new ArgumentOutOfRangeException(
nameof(options) + "." + nameof(GraphQLWebSocketOptions.DisconnectionTimeout)
);
#pragma warning restore CA2208 // Instantiate argument exceptions correctly
}
_closeTimeoutMs = (int)(options.DisconnectionTimeout ?? DefaultDisconnectionTimeout).TotalMilliseconds;
_closeTimeoutMs = (int)
(options.DisconnectionTimeout ?? DefaultDisconnectionTimeout).TotalMilliseconds;
_maxSendQueueThreshold = options.MaxSendQueueThreshold;
if (_maxSendQueueThreshold != null && _maxSendQueueThreshold <= 0)
throw new ArgumentOutOfRangeException(
nameof(options) + "." + nameof(GraphQLWebSocketOptions.MaxSendQueueThreshold)
);
_webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
_stream = new(webSocket);
_serializer = serializer ?? throw new ArgumentNullException(nameof(serializer));
Expand All @@ -64,8 +86,7 @@ public WebSocketConnection(HttpContext httpContext, WebSocket webSocket, IGraphQ
}

/// <inheritdoc/>
public virtual void Dispose()
=> GC.SuppressFinalize(this);
public virtual void Dispose() => GC.SuppressFinalize(this);

/// <summary>
/// Listens to incoming messages on the WebSocket specified in the constructor,
Expand All @@ -77,8 +98,11 @@ public virtual async Task ExecuteAsync(IOperationMessageProcessor operationMessa
if (operationMessageProcessor == null)
throw new ArgumentNullException(nameof(operationMessageProcessor));
if (Interlocked.Exchange(ref _executed, 1) == 1)
throw new InvalidOperationException($"{nameof(ExecuteAsync)} may only be called once per instance.");
try {
throw new InvalidOperationException(
$"{nameof(ExecuteAsync)} may only be called once per instance."
);
try
{
await operationMessageProcessor.InitializeConnectionAsync();
// set up a buffer in case a message is longer than one block
var receiveStream = new MemoryStream();
Expand All @@ -93,19 +117,23 @@ public virtual async Task ExecuteAsync(IOperationMessageProcessor operationMessa
// prep a reader stream
var bufferStream = new ReusableMemoryReaderStream(buffer);
// read messages until an exception occurs, the cancellation token is signaled, or a 'close' message is received
while (!RequestAborted.IsCancellationRequested) {
while (!RequestAborted.IsCancellationRequested)
{
var result = await _webSocket.ReceiveAsync(bufferMemory, RequestAborted);
if (result.MessageType == WebSocketMessageType.Close) {
if (result.MessageType == WebSocketMessageType.Close)
{
// prevent any more messages from being queued
operationMessageProcessor.Dispose();
// send a close request if none was sent yet
if (!_outputClosed.Task.IsCompleted) {
if (!_outputClosed.Task.IsCompleted)
{
// queue the closure
_ = CloseAsync();
// wait until the close has been sent
await Task.WhenAny(
_outputClosed.Task,
Task.Delay(_closeTimeoutMs, RequestAborted));
Task.Delay(_closeTimeoutMs, RequestAborted)
);
}
// quit
return;
Expand All @@ -114,39 +142,53 @@ await Task.WhenAny(
if (_closeRequested)
continue;
// if this is the last block terminating a message
if (result.EndOfMessage) {
if (result.EndOfMessage)
{
// if only one block of data was sent for this message
if (receiveStream.Length == 0) {
if (receiveStream.Length == 0)
{
// if the message is empty, skip to the next message
if (result.Count == 0)
continue;
// read the message
bufferStream.ResetLength(result.Count);
var message = await _serializer.ReadAsync<OperationMessage>(bufferStream, RequestAborted);
var message = await _serializer.ReadAsync<OperationMessage>(
bufferStream,
RequestAborted
);
// dispatch the message
if (message != null)
await OnDispatchMessageAsync(operationMessageProcessor, message);
} else {
}
else
{
// if there is any data in this block, add it to the buffer
if (result.Count > 0)
receiveStream.Write(buffer, 0, result.Count);
// read the message from the buffer
receiveStream.Position = 0;
var message = await _serializer.ReadAsync<OperationMessage>(receiveStream, RequestAborted);
var message = await _serializer.ReadAsync<OperationMessage>(
receiveStream,
RequestAborted
);
// clear the buffer
receiveStream.SetLength(0);
// dispatch the message
if (message != null)
await OnDispatchMessageAsync(operationMessageProcessor, message);
}
} else {
}
else
{
// if there is any data in this block, add it to the buffer
if (result.Count > 0)
receiveStream.Write(buffer, 0, result.Count);
}
}
} catch (WebSocketException) {
} finally {
}
catch (WebSocketException) { }
finally
{
// prevent any more messages from being sent
_outputClosed.TrySetResult(false);
// prevent any more messages from attempting to send
Expand All @@ -155,21 +197,39 @@ await Task.WhenAny(
}

/// <inheritdoc/>
public Task CloseAsync()
=> CloseAsync(1000, null);
public Task CloseAsync() => CloseAsync(1000, null);

/// <inheritdoc/>
public Task CloseAsync(int eventId, string? description)
{
_closeRequested = true;
_pump.Post(new Message { CloseStatus = (WebSocketCloseStatus)eventId, CloseDescription = description });
_pump.Post(
new Message
{
CloseStatus = (WebSocketCloseStatus)eventId,
CloseDescription = description
}
);
return Task.CompletedTask;
}

/// <inheritdoc/>
public Task SendMessageAsync(OperationMessage message)
{
_pump.Post(new Message { OperationMessage = message });
if (_maxSendQueueThreshold == null || _maxSendQueueThreshold.Value > _pump.Count)
_pump.Post(new Message { OperationMessage = message });
return Task.CompletedTask;
}

/// <inheritdoc/>
public Task SendMessageAsync(OperationMessage message, bool ignoreMaxSendQueueThreshold = false)
{
if (
ignoreMaxSendQueueThreshold
|| _maxSendQueueThreshold == null
|| _maxSendQueueThreshold.Value > _pump.Count
)
_pump.Post(new Message { OperationMessage = message });
return Task.CompletedTask;
}

Expand All @@ -186,9 +246,12 @@ private async Task HandleMessageAsync(Message message)
if (_outputClosed.Task.IsCompleted)
return;
LastMessageSentAt = DateTime.UtcNow;
if (message.OperationMessage != null) {
if (message.OperationMessage != null)
{
await OnSendMessageAsync(message.OperationMessage);
} else {
}
else
{
await OnCloseOutputAsync(message.CloseStatus, message.CloseDescription);
_outputClosed.TrySetResult(true);
}
Expand All @@ -200,8 +263,10 @@ private async Task HandleMessageAsync(Message message)
/// <br/><br/>
/// This method is synchronized and will wait until completion before dispatching another message.
/// </summary>
protected virtual Task OnDispatchMessageAsync(IOperationMessageProcessor operationMessageProcessor, OperationMessage message)
=> operationMessageProcessor.OnMessageReceivedAsync(message);
protected virtual Task OnDispatchMessageAsync(
IOperationMessageProcessor operationMessageProcessor,
OperationMessage message
) => operationMessageProcessor.OnMessageReceivedAsync(message);

/// <summary>
/// Sends the specified message to the underlying <see cref="WebSocket"/>.
Expand All @@ -221,14 +286,20 @@ protected virtual async Task OnSendMessageAsync(OperationMessage message)
/// <br/><br/>
/// This method is synchronized and will wait until completion before sending another message or closing the output stream.
/// </summary>
protected virtual Task OnCloseOutputAsync(WebSocketCloseStatus closeStatus, string? closeDescription)
=> _webSocket.CloseOutputAsync(closeStatus, closeDescription, RequestAborted);
protected virtual Task OnCloseOutputAsync(
WebSocketCloseStatus closeStatus,
string? closeDescription
) => _webSocket.CloseOutputAsync(closeStatus, closeDescription, RequestAborted);

/// <summary>
/// A queue entry; see <see cref="HandleMessageAsync(Message)"/>.
/// </summary>
/// <param name="OperationMessage">The message to send, if set; if it is null then this is a closure message.</param>
/// <param name="CloseStatus">The close status.</param>
/// <param name="CloseDescription">The close description.</param>
private record struct Message(OperationMessage? OperationMessage, WebSocketCloseStatus CloseStatus, string? CloseDescription);
private record struct Message(
OperationMessage? OperationMessage,
WebSocketCloseStatus CloseStatus,
string? CloseDescription
);
}
Loading