Skip to content

Commit

Permalink
[#260] WebSocketTransportInitiator.ConnectAsync can result in Unobser…
Browse files Browse the repository at this point in the history
…vedTaskException - 2
  • Loading branch information
xinchen10 committed Jun 13, 2024
1 parent c48e6ab commit afe233a
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
namespace Microsoft.Azure.Amqp.Transport
{
using System;
using System.Threading;
using System.Threading.Tasks;
using Windows.Networking.Sockets;

sealed class WebSocketTransportInitiator : TransportInitiator
Expand All @@ -20,31 +22,39 @@ public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs c
StreamWebSocket sws = new StreamWebSocket();
sws.Control.SupportedProtocols.Add(this.settings.SubProtocol);

var task = sws.ConnectAsync(this.settings.Uri).AsTask().WithTimeout(timeout, () => "timeout");
var cts = new CancellationTokenSource(timeout);
var task = sws.ConnectAsync(this.settings.Uri).AsTask(cts.Token);
if (task.IsCompleted)
{
callbackArgs.Transport = new WebSocketTransport(sws, this.settings.Uri);
this.OnConnect(callbackArgs, task, sws, cts);
return false;
}

task.ContinueWith(t =>
{
if (t.IsFaulted)
{
callbackArgs.Exception = t.Exception.InnerException;
}
else if (t.IsCanceled)
{
callbackArgs.Exception = new OperationCanceledException();
}
else
{
callbackArgs.Transport = new WebSocketTransport(sws, this.settings.Uri);
}
this.OnConnect(callbackArgs, t, sws, cts);
callbackArgs.CompletedCallback(callbackArgs);
});
return true;
}

void OnConnect(TransportAsyncCallbackArgs callbackArgs, Task t, StreamWebSocket sws, CancellationTokenSource cts)
{
cts.Dispose();
if (t.IsFaulted)
{
sws.Dispose();
callbackArgs.Exception = t.Exception.InnerException;
}
else if (t.IsCanceled)
{
sws.Dispose();
callbackArgs.Exception = new OperationCanceledException();
}
else
{
callbackArgs.Transport = new WebSocketTransport(sws, this.settings.Uri);
}
}
}
}
83 changes: 45 additions & 38 deletions Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,25 +72,13 @@ public sealed override bool WriteAsync(TransportAsyncCallbackArgs args)
Task task = this.webSocket.SendAsync(buffer, WebSocketMessageType.Binary, true, CancellationToken.None);
if (task.IsCompleted)
{
this.OnWriteComplete(args, buffer, mergedBuffer, startTime);
this.OnWriteComplete(args, task, buffer, mergedBuffer, startTime);
return false;
}

task.ContinueWith(t =>
{
if (t.IsFaulted)
{
args.Exception = t.Exception.InnerException;
}
else if (t.IsCanceled)
{
args.Exception = new OperationCanceledException();
}
else
{
this.OnWriteComplete(args, buffer, mergedBuffer, startTime);
}
this.OnWriteComplete(args, t, buffer, mergedBuffer, startTime);
args.CompletedCallback(args);
});
return true;
Expand All @@ -103,25 +91,13 @@ public sealed override bool ReadAsync(TransportAsyncCallbackArgs args)
Task<WebSocketReceiveResult> task = this.webSocket.ReceiveAsync(buffer, CancellationToken.None);
if (task.IsCompleted)
{
this.OnReadComplete(args, task.Result.Count, startTime);
this.OnReadComplete(args, task, startTime);
return false;
}

task.ContinueWith(t =>
{
if (t.IsFaulted)
{
args.Exception = t.Exception.InnerException;
}
else if (t.IsCanceled)
{
args.Exception = new OperationCanceledException();
}
else
{
this.OnReadComplete(args, t.Result.Count, startTime);
}
this.OnReadComplete(args, t, startTime);
args.CompletedCallback(args);
});
return true;
Expand All @@ -137,6 +113,15 @@ protected override bool CloseInternal()
Task task = webSocket.CloseAsync(WebSocketCloseStatus.Empty, string.Empty, CancellationToken.None);
if (task.IsCompleted)
{
if (task.IsFaulted)
{
ExceptionDispatcher.Throw(task.Exception.InnerException);
}
else if (task.IsCanceled)
{
throw new OperationCanceledException();
}

return true;
}

Expand Down Expand Up @@ -165,26 +150,48 @@ internal static bool MatchScheme(string scheme)
string.Equals(scheme, WebSocketTransportSettings.SecureWebSockets, StringComparison.OrdinalIgnoreCase);
}

void OnWriteComplete(TransportAsyncCallbackArgs args, ArraySegment<byte> buffer, ByteBuffer byteBuffer, DateTime startTime)
void OnWriteComplete(TransportAsyncCallbackArgs args, Task t, ArraySegment<byte> buffer, ByteBuffer byteBuffer, DateTime startTime)
{
args.BytesTransfered = buffer.Count;
if (byteBuffer != null)
if (t.IsFaulted)
{
byteBuffer.Dispose();
args.Exception = t.Exception.InnerException;
}
else if (t.IsCanceled)
{
args.Exception = new OperationCanceledException();
}
else
{
args.BytesTransfered = buffer.Count;
if (this.usageMeter != null)
{
this.usageMeter.OnTransportWrite(0, buffer.Count, 0, DateTime.UtcNow.Subtract(startTime).Ticks);
}
}

if (this.usageMeter != null)
if (byteBuffer != null)
{
this.usageMeter.OnTransportWrite(0, buffer.Count, 0, DateTime.UtcNow.Subtract(startTime).Ticks);
byteBuffer.Dispose();
}
}

void OnReadComplete(TransportAsyncCallbackArgs args, int count, DateTime startTime)
void OnReadComplete(TransportAsyncCallbackArgs args, Task<WebSocketReceiveResult> t, DateTime startTime)
{
args.BytesTransfered = count;
if (this.usageMeter != null)
if (t.IsFaulted)
{
this.usageMeter.OnTransportRead(0, count, 0, DateTime.UtcNow.Subtract(startTime).Ticks);
args.Exception = t.Exception.InnerException;
}
else if (t.IsCanceled)
{
args.Exception = new OperationCanceledException();
}
else
{
args.BytesTransfered = t.Result.Count;
if (this.usageMeter != null)
{
this.usageMeter.OnTransportRead(0, args.BytesTransfered, 0, DateTime.UtcNow.Subtract(startTime).Ticks);
}
}
}
}
Expand Down
40 changes: 23 additions & 17 deletions Microsoft.Azure.Amqp/Amqp/Transport/WebSocketTransportInitiator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,39 @@ public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs c
cws.Options.SetBuffer(this.settings.ReceiveBufferSize, this.settings.SendBufferSize);
#endif

Task task = cws.ConnectAsync(this.settings.Uri, CancellationToken.None).WithTimeout(timeout, () => "Client WebSocket connect timed out");
var cts = new CancellationTokenSource(timeout);
Task task = cws.ConnectAsync(this.settings.Uri, cts.Token);
if (task.IsCompleted)
{
callbackArgs.Transport = new WebSocketTransport(cws, this.settings.Uri);
this.OnConnect(callbackArgs, task, cws, cts);
return false;
}

task.ContinueWith(t =>
{
if (t.IsFaulted)
{
cws.Abort();
callbackArgs.Exception = t.Exception.InnerException;
}
else if (t.IsCanceled)
{
cws.Abort();
callbackArgs.Exception = new OperationCanceledException();
}
else
{
callbackArgs.Transport = new WebSocketTransport(cws, this.settings.Uri);
}
this.OnConnect(callbackArgs, t, cws, cts);
callbackArgs.CompletedCallback(callbackArgs);
});
return true;
}

void OnConnect(TransportAsyncCallbackArgs callbackArgs, Task t, ClientWebSocket cws, CancellationTokenSource cts)
{
cts.Dispose();
if (t.IsFaulted)
{
cws.Dispose();
callbackArgs.Exception = t.Exception.InnerException;
}
else if (t.IsCanceled)
{
cws.Dispose();
callbackArgs.Exception = new OperationCanceledException();
}
else
{
callbackArgs.Transport = new WebSocketTransport(cws, this.settings.Uri);
}
}
}
}
60 changes: 0 additions & 60 deletions Microsoft.Azure.Amqp/TaskHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ namespace Microsoft.Azure.Amqp
{
using System;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

static class TaskHelpers
Expand Down Expand Up @@ -58,17 +57,6 @@ public static Task<T> CreateTask<T>(Func<AsyncCallback, object, IAsyncResult> be
return retval;
}

public static void Fork(this Task thisTask)
{
Fork(thisTask, "TaskExtensions.Fork");
}

public static void Fork(this Task thisTask, string tracingInfo)
{
Fx.Assert(thisTask != null, "task is required!");
thisTask.ContinueWith(t => AmqpTrace.Provider.AmqpHandleException(t.Exception, tracingInfo), TaskContinuationOptions.OnlyOnFaulted);
}

public static IAsyncResult ToAsyncResult(this Task task, AsyncCallback callback, object state)
{
if (task.AsyncState == state)
Expand Down Expand Up @@ -167,54 +155,6 @@ public static TResult EndAsyncResult<TResult>(IAsyncResult asyncResult)
return task.GetAwaiter().GetResult();
}

public static Task WithTimeout(this Task task, TimeSpan timeout, Func<string> errorMessage)
{
return WithTimeout(task, timeout, errorMessage, CancellationToken.None);
}

public static async Task WithTimeout(this Task task, TimeSpan timeout, Func<string> errorMessage, CancellationToken token)
{
if (timeout == TimeSpan.MaxValue)
{
timeout = Timeout.InfiniteTimeSpan;
}
else if (timeout.TotalMilliseconds > Int32.MaxValue)
{
timeout = TimeSpan.FromMilliseconds(Int32.MaxValue);
}

if (task.IsCompleted || (timeout == Timeout.InfiniteTimeSpan && token == CancellationToken.None))
{
await task.ConfigureAwait(false);
return;
}

using (var cts = CancellationTokenSource.CreateLinkedTokenSource(token))
{
if (task == await Task.WhenAny(task, CreateDelayTask(timeout, cts.Token)).ConfigureAwait(false))
{
cts.Cancel();
await task.ConfigureAwait(false);
return;
}
}

throw new TimeoutException(errorMessage());
}

static async Task CreateDelayTask(TimeSpan timeout, CancellationToken token)
{
try
{
await Task.Delay(timeout, token).ConfigureAwait(false);
}
catch (TaskCanceledException)
{
// No need to throw. Caller is responsible for detecting
// which task completed and throwing appropriate Timeout Exception
}
}

[StructLayout(LayoutKind.Sequential, Size = 1)]
internal struct VoidTaskResult
{
Expand Down

0 comments on commit afe233a

Please sign in to comment.