Skip to content
This repository has been archived by the owner on Dec 14, 2018. It is now read-only.

Commit

Permalink
Refactor CORS support out of MVC Core
Browse files Browse the repository at this point in the history
  • Loading branch information
javiercn committed Aug 14, 2017
1 parent 2ef2648 commit f2a8c1c
Show file tree
Hide file tree
Showing 8 changed files with 475 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ public class HttpMethodActionConstraint : IActionConstraint

private readonly IReadOnlyList<string> _httpMethods;

private readonly string OriginHeader = "Origin";
private readonly string AccessControlRequestMethod = "Access-Control-Request-Method";
private readonly string PreflightHttpMethod = "OPTIONS";

// Empty collection means any method will be accepted.
public HttpMethodActionConstraint(IEnumerable<string> httpMethods)
{
Expand Down Expand Up @@ -46,7 +42,7 @@ public HttpMethodActionConstraint(IEnumerable<string> httpMethods)

public int Order => HttpMethodConstraintOrder;

public bool Accept(ActionConstraintContext context)
public virtual bool Accept(ActionConstraintContext context)
{
if (context == null)
{
Expand All @@ -61,18 +57,6 @@ public bool Accept(ActionConstraintContext context)
var request = context.RouteContext.HttpContext.Request;
var method = request.Method;

// Perf: Check http method before accessing the Headers collection.
if (string.Equals(method, PreflightHttpMethod, StringComparison.OrdinalIgnoreCase) &&
request.Headers.ContainsKey(OriginHeader))
{
// Update the http method if it is preflight request.
var accessControlRequestMethod = request.Headers[AccessControlRequestMethod];
if (!StringValues.IsNullOrEmpty(accessControlRequestMethod))
{
method = accessControlRequestMethod;
}
}

for (var i = 0; i < _httpMethods.Count; i++)
{
var supportedMethod = _httpMethods[i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using System.Linq;
using Microsoft.AspNetCore.Cors.Infrastructure;
using Microsoft.AspNetCore.Mvc.ApplicationModels;
using Microsoft.AspNetCore.Mvc.Internal;
using Microsoft.Extensions.Options;

namespace Microsoft.AspNetCore.Mvc.Cors.Internal
{
Expand All @@ -28,6 +30,9 @@ public void OnProvidersExecuting(ApplicationModelProviderContext context)
throw new ArgumentNullException(nameof(context));
}

var isCorsEnabledGlobally = context.Result.Filters.OfType<ICorsAuthorizationFilter>().Any() ||
context.Result.Filters.OfType<CorsAuthorizationFilterFactory>().Any();

foreach (var controllerModel in context.Result.Controllers)
{
var enableCors = controllerModel.Attributes.OfType<IEnableCorsAttribute>().FirstOrDefault();
Expand All @@ -42,6 +47,8 @@ public void OnProvidersExecuting(ApplicationModelProviderContext context)
controllerModel.Filters.Add(new DisableCorsAuthorizationFilter());
}

var corsOnController = enableCors != null || disableCors != null || controllerModel.Filters.OfType<ICorsAuthorizationFilter>().Any();

foreach (var actionModel in controllerModel.Actions)
{
enableCors = actionModel.Attributes.OfType<IEnableCorsAttribute>().FirstOrDefault();
Expand All @@ -55,6 +62,28 @@ public void OnProvidersExecuting(ApplicationModelProviderContext context)
{
actionModel.Filters.Add(new DisableCorsAuthorizationFilter());
}

var corsOnAction = enableCors != null || disableCors != null || actionModel.Filters.OfType<ICorsAuthorizationFilter>().Any();

if (isCorsEnabledGlobally || corsOnController || corsOnAction)
{
UpdateHttpMethodActionConstraint(actionModel);
}
}
}
}

private static void UpdateHttpMethodActionConstraint(ActionModel actionModel)
{
for (var i = 0; i < actionModel.Selectors.Count; i++)
{
var selectorModel = actionModel.Selectors[i];
for (var j = 0; j < selectorModel.ActionConstraints.Count; j++)
{
if (selectorModel.ActionConstraints[j] is HttpMethodActionConstraint httpConstraint)
{
selectorModel.ActionConstraints[j] = new CorsHttpMethodActionConstraint(httpConstraint);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using Microsoft.AspNetCore.Mvc.ActionConstraints;
using Microsoft.Extensions.Primitives;
using Microsoft.AspNetCore.Mvc.Internal;

namespace Microsoft.AspNetCore.Mvc.Cors.Internal
{
public class CorsHttpMethodActionConstraint : HttpMethodActionConstraint
{
private readonly string OriginHeader = "Origin";
private readonly string AccessControlRequestMethod = "Access-Control-Request-Method";
private readonly string PreflightHttpMethod = "OPTIONS";

public CorsHttpMethodActionConstraint(HttpMethodActionConstraint constraint)
: base(constraint.HttpMethods)
{
}

public override bool Accept(ActionConstraintContext context)
{
if (context == null)
{
throw new ArgumentNullException(nameof(context));
}

var methods = (ReadOnlyCollection<string>)HttpMethods;
if (methods.Count == 0)
{
return true;
}

var request = context.RouteContext.HttpContext.Request;
if (request.Headers.ContainsKey(OriginHeader) &&
string.Equals(request.Method, PreflightHttpMethod, StringComparison.OrdinalIgnoreCase) &&
request.Headers.TryGetValue(AccessControlRequestMethod, out var accessControlRequestMethod) &&
!StringValues.IsNullOrEmpty(accessControlRequestMethod))
{
for (var i = 0; i < methods.Count; i++)
{
var supportedMethod = methods[i];
if (string.Equals(supportedMethod, accessControlRequestMethod, StringComparison.OrdinalIgnoreCase))
{
return true;
}
}

return false;
}

return base.Accept(context);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class HttpMethodActionConstraintTest

[Theory]
[MemberData(nameof(AcceptCaseInsensitiveData))]
public void HttpMethodActionConstraint_Accept_Preflight_CaseInsensitive(IEnumerable<string> httpMethods, string accessControlMethod)
public void HttpMethodActionConstraint_IgnoresPreflightRequests(IEnumerable<string> httpMethods, string accessControlMethod)
{
// Arrange
var constraint = new HttpMethodActionConstraint(httpMethods);
Expand All @@ -37,7 +37,7 @@ public void HttpMethodActionConstraint_Accept_Preflight_CaseInsensitive(IEnumera
var result = constraint.Accept(context);

// Assert
Assert.True(result, "Request should have been accepted.");
Assert.False(result, "Request should have been rejected.");
}

[Theory]
Expand Down
Loading

0 comments on commit f2a8c1c

Please sign in to comment.