Skip to content

Commit

Permalink
fix receive timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
Scretch9 committed Dec 18, 2023
1 parent 276ea83 commit 5d0da86
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void Connect(IConnection connection)

public async Task<bool> InitializeConnection(GenericPointState state, bool observeAbilityToMove, CancellationToken cancellationToken)
{
if (await ReceiveMessage<PointPdiVersionCheckCommand>(cancellationToken) == null)
if (await ReceiveMessageWithTimeout<PointPdiVersionCheckCommand>(cancellationToken) == null)
{
_logger.LogError("Unexpected message.");
return false;
Expand All @@ -61,7 +61,7 @@ public async Task<bool> InitializeConnection(GenericPointState state, bool obser
var versionCheckResponse = new PointPdiVersionCheckMessage(_localId, _remoteId, PointPdiVersionCheckMessageResultPdiVersionCheck.PDIVersionsFromReceiverAndSenderDoMatch, /* TODO */ 0, 0, Array.Empty<byte>());
await SendMessage(versionCheckResponse);

if (await ReceiveMessage<PointInitialisationRequestCommand>(cancellationToken) == null)
if (await ReceiveMessageWithTimeout<PointInitialisationRequestCommand>(cancellationToken) == null)
{
_logger.LogError("Unexpected message.");
return false;
Expand Down Expand Up @@ -123,12 +123,19 @@ private async Task SendMessage(byte[] message)
await CurrentConnection.SendAsync(message);
}

private async Task<T> ReceiveMessage<T>(CancellationToken cancellationToken) where T : Message
private async Task<T> ReceiveMessageWithTimeout<T>(CancellationToken cancellationToken) where T : Message
{
if (CurrentConnection == null) throw new InvalidOperationException("Connection is null. Did you call Connect()?");
ResetTimeout();
var token = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _timeout.Token).Token;
return await ReceiveMessage<T>(token);
}

private async Task<T> ReceiveMessage<T>(CancellationToken cancellationToken) where T : Message
{
if (CurrentConnection == null) throw new InvalidOperationException("Connection is null. Did you call Connect()?");

var message = Message.FromBytes(await CurrentConnection.ReceiveAsync(CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _timeout.Token).Token));
var message = Message.FromBytes(await CurrentConnection.ReceiveAsync(cancellationToken));
if (message is T tMessage) return tMessage;
_logger.LogError("Unexpected message: {}", message);
throw new InvalidOperationException("Unexpected message.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void Connect(IConnection connection)

public async Task<bool> InitializeConnection(GenericPointState state, bool observeAbilityToMove, CancellationToken cancellationToken)
{
if (await ReceiveMessage<PointPdiVersionCheckCommand>(cancellationToken) == null)
if (await ReceiveMessageWithTimeout<PointPdiVersionCheckCommand>(cancellationToken) == null)
{
_logger.LogError("Unexpected message.");
return false;
Expand All @@ -60,7 +60,7 @@ public async Task<bool> InitializeConnection(GenericPointState state, bool obser
var versionCheckResponse = new PointPdiVersionCheckMessage(_localId, _remoteId, PointPdiVersionCheckMessageResultPdiVersionCheck.PDIVersionsFromReceiverAndSenderDoMatch, /* TODO */ 0, 0, new byte[] { });
await SendMessage(versionCheckResponse);

if (await ReceiveMessage<PointInitialisationRequestCommand>(cancellationToken) == null)
if (await ReceiveMessageWithTimeout<PointInitialisationRequestCommand>(cancellationToken) == null)
{
_logger.LogError("Unexpected message.");
return false;
Expand Down Expand Up @@ -122,12 +122,20 @@ private async Task SendMessage(byte[] message)
await CurrentConnection.SendAsync(message);
}

private async Task<T> ReceiveMessageWithTimeout<T>(CancellationToken cancellationToken) where T : Message
{
if (CurrentConnection == null) throw new InvalidOperationException("Connection is null. Did you call Connect()?");
ResetTimeout();
var token = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _timeout.Token).Token;
return await ReceiveMessage<T>(token);
}

private async Task<T> ReceiveMessage<T>(CancellationToken cancellationToken) where T : Message
{
if (CurrentConnection == null) throw new InvalidOperationException("Connection is null. Did you call Connect()?");
ResetTimeout();

var message = Message.FromBytes(await CurrentConnection.ReceiveAsync(CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _timeout.Token).Token));
var message = Message.FromBytes(await CurrentConnection.ReceiveAsync(cancellationToken));
if (message is T tMessage) return tMessage;
_logger.LogError("Unexpected message: {}", message);
throw new InvalidOperationException("Unexpected message.");
Expand Down

0 comments on commit 5d0da86

Please sign in to comment.