Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for cooperative cancellation #144

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/Websocket.Client/IWebsocketClient.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Websocket.Client
Expand Down Expand Up @@ -95,28 +96,28 @@ public interface IWebsocketClient : IDisposable
/// In case of connection error it doesn't throw an exception.
/// Only streams a message via 'DisconnectionHappened' and logs it.
/// </summary>
Task Start();
Task Start(CancellationToken cancellation = default);

/// <summary>
/// Start listening to the websocket stream on the background thread.
/// In case of connection error it throws an exception.
/// Fail fast approach.
/// </summary>
Task StartOrFail();
Task StartOrFail(CancellationToken cancellation = default);

/// <summary>
/// Stop/close websocket connection with custom close code.
/// Method doesn't throw exception, only logs it and mark client as closed.
/// </summary>
/// <returns>Returns true if close was initiated successfully</returns>
Task<bool> Stop(WebSocketCloseStatus status, string statusDescription);
Task<bool> Stop(WebSocketCloseStatus status, string statusDescription, CancellationToken cancellation = default);

/// <summary>
/// Stop/close websocket connection with custom close code.
/// Method could throw exceptions, but client is marked as closed anyway.
/// </summary>
/// <returns>Returns true if close was initiated successfully</returns>
Task<bool> StopOrFail(WebSocketCloseStatus status, string statusDescription);
Task<bool> StopOrFail(WebSocketCloseStatus status, string statusDescription, CancellationToken cancellation = default);

/// <summary>
/// Send message to the websocket channel.
Expand Down Expand Up @@ -149,7 +150,8 @@ public interface IWebsocketClient : IDisposable
/// on the full .NET Framework platform
/// </summary>
/// <param name="message">Message to be sent</param>
Task SendInstant(string message);
/// <param name="cancellation">Cancellation token</param>
Task SendInstant(string message, CancellationToken cancellation = default);

/// <summary>
/// Send binary message to the websocket channel.
Expand All @@ -158,7 +160,8 @@ public interface IWebsocketClient : IDisposable
/// on the full .NET Framework platform
/// </summary>
/// <param name="message">Message to be sent</param>
Task SendInstant(byte[] message);
/// <param name="cancellation">Cancellation token</param>
Task SendInstant(byte[] message, CancellationToken cancellation = default);

/// <summary>
/// Send already converted text message to the websocket channel.
Expand All @@ -183,14 +186,14 @@ public interface IWebsocketClient : IDisposable
/// Closes current websocket stream and perform a new connection to the server.
/// In case of connection error it doesn't throw an exception, but tries to reconnect indefinitely.
/// </summary>
Task Reconnect();
Task Reconnect(CancellationToken cancellation = default);

/// <summary>
/// Force reconnection.
/// Closes current websocket stream and perform a new connection to the server.
/// In case of connection error it throws an exception and doesn't perform any other reconnection try.
/// </summary>
Task ReconnectOrFail();
Task ReconnectOrFail(CancellationToken cancellation = default);

/// <summary>
/// Stream/publish fake message (via 'MessageReceived' observable).
Expand Down
4 changes: 2 additions & 2 deletions src/Websocket.Client/Threading/WebsocketAsyncLock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ public IDisposable Lock()
/// <summary>
/// Use inside 'using' block with await
/// </summary>
public Task<IDisposable> LockAsync()
public Task<IDisposable> LockAsync(CancellationToken cancellation = default)
{
var waitTask = _semaphore.WaitAsync();
var waitTask = _semaphore.WaitAsync(cancellation);
return waitTask.IsCompleted
? _releaserTask
: waitTask.ContinueWith(
Expand Down
26 changes: 14 additions & 12 deletions src/Websocket.Client/WebsocketClient.Reconnecting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ public partial class WebsocketClient
/// Closes current websocket stream and perform a new connection to the server.
/// In case of connection error it doesn't throw an exception, but tries to reconnect indefinitely.
/// </summary>
public Task Reconnect()
public Task Reconnect(CancellationToken cancellation = default)
{
return ReconnectInternal(false);
return ReconnectInternal(false, cancellation);
}

/// <summary>
/// Force reconnection.
/// Closes current websocket stream and perform a new connection to the server.
/// In case of connection error it throws an exception and doesn't perform any other reconnection try.
/// </summary>
public Task ReconnectOrFail()
public Task ReconnectOrFail(CancellationToken cancellation = default)
{
return ReconnectInternal(true);
return ReconnectInternal(true, cancellation);
}

private async Task ReconnectInternal(bool failFast)
private async Task ReconnectInternal(bool failFast, CancellationToken cancellation)
{
if (!IsStarted)
{
Expand All @@ -38,23 +38,23 @@ private async Task ReconnectInternal(bool failFast)

try
{
await ReconnectSynchronized(ReconnectionType.ByUser, failFast, null).ConfigureAwait(false);
await ReconnectSynchronized(ReconnectionType.ByUser, failFast, null, cancellation).ConfigureAwait(false);
}
finally
{
_reconnecting = false;
}
}

private async Task ReconnectSynchronized(ReconnectionType type, bool failFast, Exception? causedException)
private async Task ReconnectSynchronized(ReconnectionType type, bool failFast, Exception? causedException, CancellationToken cancellation)
{
using (await _locker.LockAsync())
using (await _locker.LockAsync(cancellation))
{
await Reconnect(type, failFast, causedException);
await Reconnect(type, failFast, causedException, cancellation);
}
}

private async Task Reconnect(ReconnectionType type, bool failFast, Exception? causedException)
private async Task Reconnect(ReconnectionType type, bool failFast, Exception? causedException, CancellationToken cancellation)
{
IsRunning = false;
if (_disposing || !IsStarted)
Expand Down Expand Up @@ -98,7 +98,9 @@ private async Task Reconnect(ReconnectionType type, bool failFast, Exception? ca

_logger.LogDebug(L("Reconnecting..."), Name);
_cancellation = new CancellationTokenSource();
await StartClient(_url, _cancellation.Token, type, failFast).ConfigureAwait(false);

using var cts = CancellationTokenSource.CreateLinkedTokenSource(_cancellation.Token, cancellation);
await StartClient(_url, cts.Token, type, failFast).ConfigureAwait(false);
_reconnecting = false;
}

Expand Down Expand Up @@ -130,7 +132,7 @@ private void LastChance(object? state)
_logger.LogDebug(L("Last message received more than {timeoutMs} ms ago. Hard restart.."), Name, timeoutMs.ToString("F"));

DeactivateLastChance();
_ = ReconnectSynchronized(ReconnectionType.NoMessageReceived, false, null);
_ = ReconnectSynchronized(ReconnectionType.NoMessageReceived, false, null, _cancellation?.Token ?? CancellationToken.None);
}
}
}
Expand Down
35 changes: 19 additions & 16 deletions src/Websocket.Client/WebsocketClient.Sending.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,12 @@ public bool Send(ArraySegment<byte> message)
/// on the full .NET Framework platform
/// </summary>
/// <param name="message">Message to be sent</param>
public Task SendInstant(string message)
/// <param name="cancellation">Cancellation token</param>
public Task SendInstant(string message, CancellationToken cancellation)
{
Validations.Validations.ValidateInput(message, nameof(message));

return SendInternalSynchronized(new RequestTextMessage(message));
return SendInternalSynchronized(new RequestTextMessage(message), cancellation);
}

/// <summary>
Expand All @@ -82,9 +83,10 @@ public Task SendInstant(string message)
/// on the full .NET Framework platform
/// </summary>
/// <param name="message">Message to be sent</param>
public Task SendInstant(byte[] message)
/// <param name="cancellation">Cancellation token</param>
public Task SendInstant(byte[] message, CancellationToken cancellation)
{
return SendInternalSynchronized(new ArraySegment<byte>(message));
return SendInternalSynchronized(new ArraySegment<byte>(message), cancellation);
}

/// <summary>
Expand Down Expand Up @@ -138,7 +140,7 @@ private async Task SendTextFromQueue()
{
try
{
await SendInternalSynchronized(message).ConfigureAwait(false);
await SendInternalSynchronized(message, _cancellationTotal?.Token ?? CancellationToken.None).ConfigureAwait(false);
}
catch (Exception e)
{
Expand Down Expand Up @@ -179,7 +181,7 @@ private async Task SendBinaryFromQueue()
{
try
{
await SendInternalSynchronized(message).ConfigureAwait(false);
await SendInternalSynchronized(message, _cancellationTotal?.Token ?? CancellationToken.None).ConfigureAwait(false);
}
catch (Exception e)
{
Expand Down Expand Up @@ -220,15 +222,15 @@ private void StartBackgroundThreadForSendingBinary()
_ = Task.Factory.StartNew(_ => SendBinaryFromQueue(), TaskCreationOptions.LongRunning, _cancellationTotal?.Token ?? CancellationToken.None);
}

private async Task SendInternalSynchronized(RequestMessage message)
private async Task SendInternalSynchronized(RequestMessage message, CancellationToken cancellation)
{
using (await _locker.LockAsync())
using (await _locker.LockAsync(cancellation))
{
await SendInternal(message);
await SendInternal(message, cancellation);
}
}

private async Task SendInternal(RequestMessage message)
private async Task SendInternal(RequestMessage message, CancellationToken cancellation)
{
if (!IsClientConnected())
{
Expand All @@ -255,20 +257,21 @@ private async Task SendInternal(RequestMessage message)
throw new ArgumentException($"Unknown message type: {message.GetType()}");
}

using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellation, _cancellation?.Token ?? CancellationToken.None);
await _client!
.SendAsync(payload, WebSocketMessageType.Text, true, _cancellation?.Token ?? CancellationToken.None)
.SendAsync(payload, WebSocketMessageType.Text, true, cts.Token)
.ConfigureAwait(false);
}

private async Task SendInternalSynchronized(ArraySegment<byte> message)
private async Task SendInternalSynchronized(ArraySegment<byte> message, CancellationToken cancellation)
{
using (await _locker.LockAsync())
using (await _locker.LockAsync(cancellation))
{
await SendInternal(message);
await SendInternal(message, cancellation);
}
}

private async Task SendInternal(ArraySegment<byte> payload)
private async Task SendInternal(ArraySegment<byte> payload, CancellationToken cancellation)
{
if (!IsClientConnected())
{
Expand All @@ -279,7 +282,7 @@ private async Task SendInternal(ArraySegment<byte> payload)
_logger.LogTrace(L("Sending binary, length: {length}"), Name, payload.Count);

await _client!
.SendAsync(payload, WebSocketMessageType.Binary, true, _cancellation?.Token ?? CancellationToken.None)
.SendAsync(payload, WebSocketMessageType.Binary, true, cancellation)
.ConfigureAwait(false);
}
}
Expand Down
29 changes: 15 additions & 14 deletions src/Websocket.Client/WebsocketClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -237,33 +237,33 @@ public void Dispose()
/// In case of connection error it doesn't throw an exception.
/// Only streams a message via 'DisconnectionHappened' and logs it.
/// </summary>
public Task Start()
public Task Start(CancellationToken cancellation = default)
{
return StartInternal(false);
return StartInternal(false, cancellation);
}

/// <summary>
/// Start listening to the websocket stream on the background thread.
/// In case of connection error it throws an exception.
/// Fail fast approach.
/// </summary>
public Task StartOrFail()
public Task StartOrFail(CancellationToken cancellation = default)
{
return StartInternal(true);
return StartInternal(true, cancellation);
}

/// <summary>
/// Stop/close websocket connection with custom close code.
/// Method doesn't throw exception, only logs it and mark client as closed.
/// </summary>
/// <returns>Returns true if close was initiated successfully</returns>
public async Task<bool> Stop(WebSocketCloseStatus status, string statusDescription)
public async Task<bool> Stop(WebSocketCloseStatus status, string statusDescription, CancellationToken cancellation = default)
{
var result = await StopInternal(
_client,
status,
statusDescription,
null,
cancellation,
false,
false).ConfigureAwait(false);
_disconnectedSubject.OnNext(DisconnectionInfo.Create(DisconnectionType.ByUser, _client, null));
Expand All @@ -275,13 +275,13 @@ public async Task<bool> Stop(WebSocketCloseStatus status, string statusDescripti
/// Method could throw exceptions, but client is marked as closed anyway.
/// </summary>
/// <returns>Returns true if close was initiated successfully</returns>
public async Task<bool> StopOrFail(WebSocketCloseStatus status, string statusDescription)
public async Task<bool> StopOrFail(WebSocketCloseStatus status, string statusDescription, CancellationToken cancellation = default)
{
var result = await StopInternal(
_client,
status,
statusDescription,
null,
cancellation,
true,
false).ConfigureAwait(false);
_disconnectedSubject.OnNext(DisconnectionInfo.Create(DisconnectionType.ByUser, _client, null));
Expand All @@ -301,7 +301,7 @@ public async Task<bool> StopOrFail(WebSocketCloseStatus status, string statusDes
});
}

private async Task StartInternal(bool failFast)
private async Task StartInternal(bool failFast, CancellationToken cancellation)
{
if (_disposing)
{
Expand All @@ -320,7 +320,8 @@ private async Task StartInternal(bool failFast)
_cancellation = new CancellationTokenSource();
_cancellationTotal = new CancellationTokenSource();

await StartClient(_url, _cancellation.Token, ReconnectionType.Initial, failFast).ConfigureAwait(false);
using var cts = CancellationTokenSource.CreateLinkedTokenSource(_cancellation.Token, cancellation);
await StartClient(_url, cts.Token, ReconnectionType.Initial, failFast).ConfigureAwait(false);

StartBackgroundThreadForSendingText();
StartBackgroundThreadForSendingBinary();
Expand Down Expand Up @@ -393,7 +394,7 @@ private async Task StartClient(Uri uri, CancellationToken token, ReconnectionTyp
try
{
_client = await _connectionFactory(uri, token).ConfigureAwait(false);
_ = Listen(_client, token);
_ = Listen(_client, _cancellation?.Token ?? CancellationToken.None);
IsRunning = true;
IsStarted = true;
_reconnectionSubject.OnNext(ReconnectionInfo.Create(type));
Expand Down Expand Up @@ -438,7 +439,7 @@ private async Task StartClient(Uri uri, CancellationToken token, ReconnectionTyp
private void ReconnectOnError(object? state)
{
// await Task.Delay(timeout, token).ConfigureAwait(false);
_ = Reconnect(ReconnectionType.Error, false, state as Exception).ConfigureAwait(false);
_ = Reconnect(ReconnectionType.Error, false, state as Exception, _cancellation?.Token ?? CancellationToken.None).ConfigureAwait(false);
}

private bool IsClientConnected()
Expand Down Expand Up @@ -508,7 +509,7 @@ await StopInternal(client, WebSocketCloseStatus.NormalClosure, "Closing",
// reconnect if enabled
if (IsReconnectionEnabled && !ShouldIgnoreReconnection(client))
{
_ = ReconnectSynchronized(ReconnectionType.Lost, false, null);
_ = ReconnectSynchronized(ReconnectionType.Lost, false, null, token);
}

return;
Expand Down Expand Up @@ -571,7 +572,7 @@ await StopInternal(client, WebSocketCloseStatus.NormalClosure, "Closing",
}

// listening thread is lost, we have to reconnect
_ = ReconnectSynchronized(ReconnectionType.Lost, false, causedException);
_ = ReconnectSynchronized(ReconnectionType.Lost, false, causedException, token);
}

private bool ShouldIgnoreReconnection(WebSocket client)
Expand Down
Loading