Skip to content

Commit

Permalink
fix broken test and add support for .net 5 send function in the Heade…
Browse files Browse the repository at this point in the history
…rPreservingRedirectHandler
  • Loading branch information
aspriddell committed Dec 28, 2020
1 parent 0ec5488 commit bb00381
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License. Please refer to the LICENSE file at the root of this project for details

using System;
using System.Net;
using DragonFruit.Common.Data.Handlers;
using DragonFruit.Common.Data.Tests.Handlers.AuthPreservingHandler.Objects;
using NUnit.Framework;
Expand Down Expand Up @@ -33,7 +34,7 @@ public void TestHeaderPreservation()
redirectClient.Authorization = $"{auth.Type} {auth.AccessToken}";

// user lookups by username = 301. without our HeaderPreservingHandler we'd get a 401
redirectClient.Perform(new OrbitTestUserRequest());
Assert.AreEqual(redirectClient.Perform(new OrbitTestUserRequest()).StatusCode, HttpStatusCode.OK);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public class AuthRequest : ApiRequest
[FormParameter("client_secret")]
public string ClientSecret => GetEnvironmentVar("orbit_client_secret");

[FormParameter("scope")]
public string Scopes => "public";

private static string GetEnvironmentVar(string var)
{
var envVar = RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
Expand Down
70 changes: 44 additions & 26 deletions DragonFruit.Common.Data/Handlers/HeaderPreservingRedirectHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Net;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -35,6 +36,20 @@ public HeaderPreservingRedirectHandler(HttpMessageHandler innerHandler)
}
}

#if NET5_0
protected override HttpResponseMessage Send(HttpRequestMessage request, CancellationToken cancellationToken)
{
var response = base.Send(request, cancellationToken);

if (IsRedirect(response.StatusCode))
{
response = base.Send(CopyRequest(response), cancellationToken);
}

return response;
}
#endif

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<HttpResponseMessage>();
Expand All @@ -56,28 +71,9 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
};
}

if (response.StatusCode == HttpStatusCode.MovedPermanently
|| response.StatusCode == HttpStatusCode.Moved
|| response.StatusCode == HttpStatusCode.Redirect
|| response.StatusCode == HttpStatusCode.Found
|| response.StatusCode == HttpStatusCode.SeeOther
|| response.StatusCode == HttpStatusCode.RedirectKeepVerb
|| response.StatusCode == HttpStatusCode.TemporaryRedirect
|| (int)response.StatusCode == 308)
if (IsRedirect(response.StatusCode))
{
var newRequest = CopyRequest(response);

if (response.StatusCode == HttpStatusCode.Redirect
|| response.StatusCode == HttpStatusCode.Found
|| response.StatusCode == HttpStatusCode.SeeOther)
{
newRequest.Content = null;
newRequest.Method = HttpMethod.Get;
}

newRequest.RequestUri = response.Headers.Location;

base.SendAsync(newRequest, cancellationToken)
base.SendAsync(CopyRequest(response), cancellationToken)
.ContinueWith(t2 => tcs.SetResult(t2.Result), cancellationToken);
}
else
Expand All @@ -91,7 +87,7 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques

private static HttpRequestMessage CopyRequest(HttpResponseMessage response)
{
var oldRequest = response.RequestMessage;
var oldRequest = response.RequestMessage ?? throw new NullReferenceException("Request Message not found");

var newRequest = new HttpRequestMessage(oldRequest.Method, oldRequest.RequestUri);

Expand All @@ -116,7 +112,7 @@ private static HttpRequestMessage CopyRequest(HttpResponseMessage response)
#if NET5_0
foreach (var (key, value) in oldRequest.Options)
{
if (value == null || !(value is string s))
if (!(value is string s))
{
continue;
}
Expand All @@ -130,9 +126,7 @@ private static HttpRequestMessage CopyRequest(HttpResponseMessage response)
}
#endif

if (response.StatusCode == HttpStatusCode.Redirect
|| response.StatusCode == HttpStatusCode.Found
|| response.StatusCode == HttpStatusCode.SeeOther)
if (AlterMethod(response.StatusCode))
{
newRequest.Content = null;
newRequest.Method = HttpMethod.Get;
Expand All @@ -144,5 +138,29 @@ private static HttpRequestMessage CopyRequest(HttpResponseMessage response)

return newRequest;
}

#region Switches

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool IsRedirect(HttpStatusCode code) => code switch
{
HttpStatusCode.Moved => true,
HttpStatusCode.Redirect => true,
HttpStatusCode.SeeOther => true,
HttpStatusCode.RedirectKeepVerb => true,

_ => false
};

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool AlterMethod(HttpStatusCode code) => code switch
{
HttpStatusCode.Redirect => true,
HttpStatusCode.SeeOther => true,

_ => false
};

#endregion
}
}

0 comments on commit bb00381

Please sign in to comment.