Skip to content

Commit

Permalink
Stability improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
erikvk committed May 9, 2024
1 parent 5ed35f3 commit 6d50923
Show file tree
Hide file tree
Showing 16 changed files with 143 additions and 32 deletions.
9 changes: 5 additions & 4 deletions src/RESTable.AspNetCore/AspNetCoreInputMessageStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using WebSocket = System.Net.WebSockets.WebSocket;

namespace RESTable.AspNetCore;

Expand Down Expand Up @@ -46,12 +47,12 @@ CancellationToken webSocketCancelledToken

public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = new())
{
WebSocketCancelledToken.ThrowIfCancellationRequested();
using var combinedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(WebSocketCancelledToken, cancellationToken);
combinedTokenSource.Token.ThrowIfCancellationRequested();
if (EndOfMessage) return 0;
var result = await WebSocket.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false);
var result = await WebSocket.ReceiveAsync(buffer, combinedTokenSource.Token).ConfigureAwait(false);
if (result.MessageType is WebSocketMessageType.Close)
{
await WebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", cancellationToken).ConfigureAwait(false);
throw new OperationCanceledException();
}
if (result.MessageType != MessageType)
Expand All @@ -72,7 +73,7 @@ public override async ValueTask DisposeAsync()
while (!EndOfMessage)
{
// Read the rest of the message
var _ = await ReadAsync(memory).ConfigureAwait(false);
var _ = await ReadAsync(memory, WebSocketCancelledToken).ConfigureAwait(false);
}
}
finally
Expand Down
1 change: 1 addition & 0 deletions src/RESTable.AspNetCore/AspNetCoreMessageStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.IO;
using System.Net.WebSockets;
using System.Threading;
using WebSocket = System.Net.WebSockets.WebSocket;

namespace RESTable.AspNetCore;

Expand Down
25 changes: 20 additions & 5 deletions src/RESTable.AspNetCore/AspNetCoreOutputMessageStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using RESTable.WebSockets;
using WebSocket = System.Net.WebSockets.WebSocket;

namespace RESTable.AspNetCore;

Expand All @@ -10,8 +12,8 @@ internal sealed class AspNetCoreOutputMessageStream : AspNetCoreMessageStream, I
public override bool CanRead => false;
public override bool CanWrite => true;

private WebSocketMessageStreamMode Mode { get; }
private SemaphoreSlim WriteSemaphore { get; }

private bool SemaphoreOpen { get; set; }

public override long Position
Expand All @@ -20,19 +22,32 @@ public override long Position
set => throw new NotSupportedException();
}

public AspNetCoreOutputMessageStream(WebSocket webSocket, WebSocketMessageType messageType, SemaphoreSlim writeSemaphore, CancellationToken webSocketCancelledToken)
public AspNetCoreOutputMessageStream
(
WebSocket webSocket,
WebSocketMessageStreamMode mode,
WebSocketMessageType messageType,
SemaphoreSlim writeSemaphore,
CancellationToken webSocketCancelledToken
)
: base(webSocket, messageType, webSocketCancelledToken)
{
WriteSemaphore = writeSemaphore;
Mode = mode;
}

public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = new())
{
if (IsDisposed)
throw new InvalidOperationException("Cannot write to a closed WebSocket message stream");
if (WebSocket.State is not WebSocketState.Open)
throw new OperationCanceledException();
var combinedToken = CancellationTokenSource.CreateLinkedTokenSource(WebSocketCancelledToken, cancellationToken).Token;
{
if (Mode is WebSocketMessageStreamMode.Strict)
throw new OperationCanceledException();
return;
}
using var combindTokenSource = CancellationTokenSource.CreateLinkedTokenSource(WebSocketCancelledToken, cancellationToken);
var combinedToken = combindTokenSource.Token;
combinedToken.ThrowIfCancellationRequested();
if (!SemaphoreOpen)
{
Expand Down Expand Up @@ -82,7 +97,7 @@ await WebSocket.SendAsync
}
}

public override void Write(ReadOnlySpan<byte> buffer) => WriteAsync(buffer.ToArray(), CancellationToken.None).AsTask().Wait();
public override void Write(ReadOnlySpan<byte> buffer) => WriteAsync(buffer.ToArray(), WebSocketCancelledToken).AsTask().Wait();

public override void Write(byte[] buffer, int offset, int count) => Write(buffer.AsSpan(offset, count));

Expand Down
57 changes: 52 additions & 5 deletions src/RESTable.AspNetCore/AspNetCoreWebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using RESTable.Requests;
using RESTable.WebSockets;
using static System.Net.WebSockets.WebSocketMessageType;
using WebSocket = RESTable.WebSockets.WebSocket;

Expand Down Expand Up @@ -62,11 +63,12 @@ protected override async Task SendBuffered(ReadOnlyMemory<byte> data, bool asTex
}
}

protected override Stream GetOutgoingMessageStream(bool asText, CancellationToken cancellationToken)
protected override Stream GetOutgoingMessageStream(bool asText, WebSocketMessageStreamMode mode, CancellationToken cancellationToken)
{
return new AspNetCoreOutputMessageStream
(
WebSocket!,
mode,
asText ? Text : Binary,
SendMessageSemaphore,
cancellationToken
Expand Down Expand Up @@ -98,8 +100,15 @@ protected override async Task InitMessageReceiveListener(CancellationToken cance
await WebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", cancellationToken).ConfigureAwait(false);
return;
}
var nextMessage = new AspNetCoreInputMessageStream(
WebSocket, messageType, endOfMessage, byteCount, ArrayPool, WebSocketBufferSize, cancellationToken
var nextMessage = new AspNetCoreInputMessageStream
(
webSocket: WebSocket,
messageType: messageType,
endOfMessage: endOfMessage,
initialByteCount: byteCount,
arrayPool: ArrayPool,
bufferSize: WebSocketBufferSize,
webSocketCancelledToken: cancellationToken
);
await using var nextMessageDisposable = nextMessage.ConfigureAwait(false);
if (messageType is Binary)
Expand Down Expand Up @@ -131,7 +140,8 @@ protected override async Task InitMessageReceiveListener(CancellationToken cance
}
finally
{
await WebSocketClosingSource.CancelAsync().ConfigureAwait(false);
try { await WebSocketClosingSource.CancelAsync().ConfigureAwait(false); }
catch (ObjectDisposedException) { }
}
}

Expand All @@ -143,7 +153,44 @@ protected override async Task TryClose(string description, CancellationToken can
{
if (WebSocket.State is WebSocketState.Open or WebSocketState.CloseReceived or WebSocketState.CloseSent)
{
await WebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, description, cancellationToken).ConfigureAwait(false);
using var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
var closeTask = WebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, description, cancellationTokenSource.Token);
var delayTask = Task.Delay(TimeSpan.FromSeconds(20), cancellationTokenSource.Token);
var completedTask = await Task.WhenAny(closeTask, delayTask).ConfigureAwait(false);
if (completedTask == delayTask)
{
// Cancel the closeTask and abort the WebSocket.
await delayTask.ConfigureAwait(false);
await cancellationTokenSource.CancelAsync().ConfigureAwait(false);
try
{
await closeTask.ConfigureAwait(false);
}
catch (OperationCanceledException)
{
// The close task was cancelled
}
WebSocket.Abort();
}
else
{
// Closed properly
await closeTask.ConfigureAwait(false);
await cancellationTokenSource.CancelAsync().ConfigureAwait(false);
try
{
await delayTask.ConfigureAwait(false);
}
catch (OperationCanceledException)
{
// The delay task was cancelled
}
}
}
else
{
// If the WebSocket is not in a state that allows graceful closure, abort directly.
WebSocket.Abort();
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/RESTable.AspNetCore/HttpRequestHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ public async Task HandleRequest(string rootUri, Method method, HttpContext aspNe
{
await webSocket.LifetimeTask.ConfigureAwait(false);
}
break;
}
break;
}
default:
{
Expand Down
2 changes: 1 addition & 1 deletion src/RESTable/Resources/Templates/CommandTerminal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace RESTable.Resources.Templates;
/// </summary>
public abstract class CommandTerminal : Terminal
{
public CommandTerminal()
protected CommandTerminal()
{
Commands = new Dictionary<string, Command>(StringComparer.OrdinalIgnoreCase);
}
Expand Down
2 changes: 1 addition & 1 deletion src/RESTable/Results/Forbidden.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace RESTable.Results;
/// </summary>
public abstract class Forbidden : Error
{
public Forbidden(ErrorCodes code, string info) : base(code, info)
protected Forbidden(ErrorCodes code, string info) : base(code, info)
{
StatusCode = HttpStatusCode.Forbidden;
StatusDescription = "Forbidden";
Expand Down
3 changes: 2 additions & 1 deletion src/RESTable/Shell.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ public override async Task HandleTextInput(string input, CancellationToken cance
if (!string.IsNullOrWhiteSpace(tail) && double.TryParse(tail, out var timeOutSeconds))
timeoutCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(timeOutSeconds));
else timeoutCancellationTokenSource = new CancellationTokenSource();
var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCancellationTokenSource.Token);
using var _ = timeoutCancellationTokenSource;
using var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCancellationTokenSource.Token);
var _cancellationToken = cancellationTokenSource.Token;
var acceptProvider = WebSocket.GetOutputContentTypeProvider();

Expand Down
8 changes: 7 additions & 1 deletion src/RESTable/WebSockets/AwaitingWebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,18 @@ public AwaitingWebSocket(IWebSocketInternal webSocket, Task waitTask)
await WebSocket.Send(data, asText, cancellationToken).ConfigureAwait(false);
}

public async ValueTask<Stream> GetMessageStream(bool asText, CancellationToken cancellationToken = new())
public async ValueTask<Stream> GetMessageStream(bool asText, CancellationToken cancellationToken = new CancellationToken())
{
await WaitTask.ConfigureAwait(false);
return await WebSocket.GetMessageStream(asText, cancellationToken).ConfigureAwait(false);
}

public async ValueTask<Stream> GetMessageStream(bool asText, WebSocketMessageStreamMode mode, CancellationToken cancellationToken = new())
{
await WaitTask.ConfigureAwait(false);
return await WebSocket.GetMessageStream(asText, mode, cancellationToken).ConfigureAwait(false);
}

public async Task SendException(Exception exception, CancellationToken cancellationToken = new())
{
await WaitTask.ConfigureAwait(false);
Expand Down
6 changes: 6 additions & 0 deletions src/RESTable/WebSockets/IWebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ public interface IWebSocket : ITraceable, IProtocolHolder, IAsyncDisposable
/// Returns a stream that, when written to, writes data over the websocket over a single message until the stream is
/// disposed
/// </summary>
ValueTask<Stream> GetMessageStream(bool asText, WebSocketMessageStreamMode mode, CancellationToken cancellationToken = new());

/// <summary>
/// Returns a stream that, when written to, writes data over the websocket over a single message until the stream is
/// disposed. Uses the scrict mode.
/// </summary>
ValueTask<Stream> GetMessageStream(bool asText, CancellationToken cancellationToken = new());

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/RESTable/WebSockets/UnknownWebSocketIdException.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace RESTable.WebSockets;

internal class UnknownWebSocketIdException : RESTableException
public class UnknownWebSocketIdException : RESTableException
{
public UnknownWebSocketIdException(string info) : base(ErrorCodes.UnknownWebSocketId, info) { }
}
26 changes: 17 additions & 9 deletions src/RESTable/WebSockets/WebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ protected WebSocket(string webSocketId, RESTableContext context)
Id = webSocketId;
var applicationStopping = context.GetRequiredService<IHostApplicationLifetime>().ApplicationStopping;
WebSocketClosingSource = CancellationTokenSource.CreateLinkedTokenSource(applicationStopping);
WebSocketClosing = WebSocketClosingSource.Token;
Status = WebSocketStatus.Waiting;
Context = context;
JsonProvider = context.GetRequiredService<IJsonProvider>();
Expand Down Expand Up @@ -66,9 +67,14 @@ protected WebSocket(string webSocketId, RESTableContext context)
return SendBufferedInternal(data, asText, cancellationToken);
}

public ValueTask<Stream> GetMessageStream(bool asText, CancellationToken cancellationToken)
public ValueTask<Stream> GetMessageStream(bool asText, CancellationToken cancellationToken = new CancellationToken())
{
return new ValueTask<Stream>(GetOutgoingMessageStream(asText, cancellationToken));
return GetMessageStream(asText, WebSocketMessageStreamMode.Strict, cancellationToken);
}

public ValueTask<Stream> GetMessageStream(bool asText, WebSocketMessageStreamMode mode, CancellationToken cancellationToken)
{
return new ValueTask<Stream>(GetOutgoingMessageStream(asText, mode, cancellationToken));
}

/// <inheritdoc />
Expand Down Expand Up @@ -235,6 +241,9 @@ internal AppProfile GetAppProfile()

protected CancellationTokenSource WebSocketClosingSource { get; }

/// <inheritdoc />
public CancellationToken WebSocketClosing { get; }

/// <summary>
/// The ID of the WebSocket
/// </summary>
Expand Down Expand Up @@ -290,9 +299,6 @@ void IWebSocketInternal.SetStatus(WebSocketStatus status)
/// </summary>
public Task LifetimeTask { get; private set; }

/// <inheritdoc />
public CancellationToken WebSocketClosing => WebSocketClosingSource.Token;

#endregion

#region Delegating members
Expand Down Expand Up @@ -382,7 +388,7 @@ private async Task OpenClientWebSocket(CancellationToken cancellationToken, bool
{
case WebSocketStatus.Waiting:
{
var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, WebSocketClosing);
using var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, WebSocketClosing);
if (acceptIncomingMessages)
LifetimeTask = InitMessageReceiveListener(cancellationTokenSource.Token);
await ConnectUnderlyingWebSocket(cancellationToken).ConfigureAwait(false);
Expand All @@ -406,7 +412,7 @@ private async Task OpenServerWebSocket(CancellationToken cancellationToken, bool
case WebSocketStatus.Waiting:
{
await ConnectUnderlyingWebSocket(cancellationToken).ConfigureAwait(false);
var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, WebSocketClosing);
using var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, WebSocketClosing);
if (acceptIncomingMessages)
LifetimeTask = InitMessageReceiveListener(cancellationTokenSource.Token);
Status = WebSocketStatus.Open;
Expand Down Expand Up @@ -447,7 +453,9 @@ public async ValueTask DisposeAsync()
var terminalName = TerminalConnection?.Resource?.Name;
await ReleaseTerminal().ConfigureAwait(false);
await TryClose(CloseDescription, CancellationToken.None).ConfigureAwait(false);
await WebSocketClosingSource.CancelAsync().ConfigureAwait(false);
try { await WebSocketClosingSource.CancelAsync().ConfigureAwait(false); }
catch (ObjectDisposedException) { }
WebSocketClosingSource.Dispose();
Status = WebSocketStatus.Closed;
ClosedAt = DateTime.Now;
if (terminalName != Console.TypeName)
Expand Down Expand Up @@ -498,7 +506,7 @@ protected Task HandleBinaryInput(Stream binaryInput, CancellationToken cancellat
/// Returns a stream that, when written to, writes data over the websocket as a single
/// message until dispose, as either binary or text.
/// </summary>
protected abstract Stream GetOutgoingMessageStream(bool asText, CancellationToken cancellationToken);
protected abstract Stream GetOutgoingMessageStream(bool asText, WebSocketMessageStreamMode mode, CancellationToken cancellationToken);

/// <summary>
/// Sends the WebSocket upgrade and initiates the actual underlying WebSocket connection
Expand Down
9 changes: 7 additions & 2 deletions src/RESTable/WebSockets/WebSocketCombination.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,21 @@ public WebSocketCombination(IWebSocket[] webSockets)
return DoForAll(ws => ws.DirectToShell(assignments, cancellationToken));
}

public async ValueTask<Stream> GetMessageStream(bool asText, CancellationToken cancellationToken = new())
public async ValueTask<Stream> GetMessageStream(bool asText, WebSocketMessageStreamMode mode, CancellationToken cancellationToken = new CancellationToken())
{
var streams = new Stream[WebSockets.Length];
for (var i = 0; i < WebSockets.Length; i += 1)
{
streams[i] = await WebSockets[i].GetMessageStream(asText, cancellationToken).ConfigureAwait(false);
streams[i] = await WebSockets[i].GetMessageStream(asText, mode, cancellationToken).ConfigureAwait(false);
}
return new CombinedWebSocketsMessageStream(streams, asText, cancellationToken);
}

public ValueTask<Stream> GetMessageStream(bool asText, CancellationToken cancellationToken = new())
{
return GetMessageStream(asText, WebSocketMessageStreamMode.Strict, cancellationToken);
}

public Task DirectTo<T>(ITerminalResource<T> terminalResource, ICollection<Condition<T>>? assignments = null, CancellationToken cancellationToken = new())
where T : Terminal
{
Expand Down
6 changes: 6 additions & 0 deletions src/RESTable/WebSockets/WebSocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ internal void Unsuspend()
return WebSocket.Send(data, asText, cancellationToken);
}

/// <inheritdoc />
public ValueTask<Stream> GetMessageStream(bool asText, WebSocketMessageStreamMode mode, CancellationToken cancellationToken = new CancellationToken())
{
return WebSocket.GetMessageStream(asText, mode, cancellationToken);
}

/// <inheritdoc />
public ValueTask<Stream> GetMessageStream(bool asText, CancellationToken cancellationToken = new())
{
Expand Down
Loading

0 comments on commit 6d50923

Please sign in to comment.