Skip to content

Commit

Permalink
Check exposure/enabled for HTTP verb match in management endpoints
Browse files Browse the repository at this point in the history
Instead of mapping endpoints in ASP.NET routing per verb, always map all verbs and then filter inside middleware.

Advantages:
- Returns 404 when not exposed/enabled with invalid verb, instead of 405
- Can change verbs at runtime

Disadvantages:
- Verb information is lost in route mappings endpoint (always shows all verbs)
- Entries in route mappings with all verbs disabled are no longer hidden
  • Loading branch information
bart-vmware committed Nov 21, 2024
1 parent 2d25b80 commit b03df23
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 33 deletions.
37 changes: 14 additions & 23 deletions src/Management/src/Endpoint/ActuatorEndpointMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,39 +46,30 @@ public void Map(IEndpointRouteBuilder endpointRouteBuilder, ActuatorConventionBu
ArgumentNullException.ThrowIfNull(endpointRouteBuilder);
ArgumentNullException.ThrowIfNull(actuatorConventionBuilder);

InnerMap(middleware => endpointRouteBuilder.CreateApplicationBuilder().UseMiddleware(middleware.GetType()).Build(),
(middleware, requestPath, pipeline) =>
{
HashSet<string> allowedVerbs = middleware.EndpointOptions.GetSafeAllowedVerbs();

if (allowedVerbs.Count > 0)
{
IEndpointConventionBuilder endpointConventionBuilder = endpointRouteBuilder.MapMethods(requestPath, allowedVerbs, pipeline);
InnerMap(middleware => endpointRouteBuilder.CreateApplicationBuilder().UseMiddleware(middleware.GetType()).Build(), (requestPath, pipeline) =>
{
IEndpointConventionBuilder endpointConventionBuilder = endpointRouteBuilder.Map(requestPath, pipeline);

foreach (Action<IEndpointConventionBuilder> configureAction in _conventionOptionsMonitor.CurrentValue.ConfigureActions)
{
configureAction(endpointConventionBuilder);
}
foreach (Action<IEndpointConventionBuilder> configureAction in _conventionOptionsMonitor.CurrentValue.ConfigureActions)
{
configureAction(endpointConventionBuilder);
}

actuatorConventionBuilder.TrackTarget(endpointConventionBuilder);
}
});
actuatorConventionBuilder.TrackTarget(endpointConventionBuilder);
});
}

public void Map(IRouteBuilder routeBuilder)
{
ArgumentNullException.ThrowIfNull(routeBuilder);

InnerMap(middleware => routeBuilder.ApplicationBuilder.New().UseMiddleware(middleware.GetType()).Build(), (middleware, requestPath, pipeline) =>
InnerMap(middleware => routeBuilder.ApplicationBuilder.New().UseMiddleware(middleware.GetType()).Build(), (requestPath, pipeline) =>
{
foreach (string verb in middleware.EndpointOptions.GetSafeAllowedVerbs())
{
routeBuilder.MapVerb(verb, requestPath, pipeline);
}
routeBuilder.MapRoute(requestPath, pipeline);
});
}

private void InnerMap(Func<IEndpointMiddleware, RequestDelegate> createPipeline, Action<IEndpointMiddleware, string, RequestDelegate> applyMapping)
private void InnerMap(Func<IEndpointMiddleware, RequestDelegate> createPipeline, Action<string, RequestDelegate> applyMapping)
{
var collection = new HashSet<string>();

Expand All @@ -95,7 +86,7 @@ private void InnerMap(Func<IEndpointMiddleware, RequestDelegate> createPipeline,
}

private void MapEndpoints(HashSet<string> collection, string? baseRequestPath, IEnumerable<IEndpointMiddleware> middlewares,
Func<IEndpointMiddleware, RequestDelegate> createPipeline, Action<IEndpointMiddleware, string, RequestDelegate> applyMapping)
Func<IEndpointMiddleware, RequestDelegate> createPipeline, Action<string, RequestDelegate> applyMapping)
{
foreach (IEndpointMiddleware middleware in middlewares)
{
Expand All @@ -105,7 +96,7 @@ private void MapEndpoints(HashSet<string> collection, string? baseRequestPath, I

if (collection.Add(requestPath))
{
applyMapping(middleware, requestPath, pipeline);
applyMapping(requestPath, pipeline);
}
else
{
Expand Down
22 changes: 19 additions & 3 deletions src/Management/src/Endpoint/Middleware/EndpointMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,32 @@ public async Task InvokeAsync(HttpContext context, RequestDelegate? next)
{
ArgumentNullException.ThrowIfNull(context);

bool notFound = true;
bool verbNotAllowed = false;

if (ShouldInvoke(context.Request.Path))
{
TResult result = await InvokeEndpointHandlerAsync(context, context.RequestAborted);
await WriteResponseAsync(result, context, context.RequestAborted);
HashSet<string> allowedVerbs = EndpointOptions.GetSafeAllowedVerbs();

notFound = allowedVerbs.Count == 0;
verbNotAllowed = !allowedVerbs.Contains(context.Request.Method);
}
else

if (notFound)
{
// Terminal middleware
context.Response.StatusCode = (int)HttpStatusCode.NotFound;
}
else if (verbNotAllowed)
{
// Terminal middleware
context.Response.StatusCode = (int)HttpStatusCode.MethodNotAllowed;
}
else
{
TResult result = await InvokeEndpointHandlerAsync(context, context.RequestAborted);
await WriteResponseAsync(result, context, context.RequestAborted);
}
}

protected abstract Task<TResult> InvokeEndpointHandlerAsync(HttpContext context, CancellationToken cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,67 @@ public async Task Allows_only_POST_requests()
var requestUri = new Uri("/actuator/refresh", UriKind.Relative);

HttpResponseMessage getResponse = await httpClient.GetAsync(requestUri);
getResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);
getResponse.StatusCode.Should().Be(HttpStatusCode.MethodNotAllowed);

HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.OK);
}

[Fact]
public async Task Can_be_configured_to_unexposed()
{
var appSettings = new Dictionary<string, string?>
{
["Management:Endpoints:Actuator:Exposure:Include:0"] = string.Empty
};

WebApplicationBuilder builder = TestWebApplicationBuilderFactory.Create();
builder.Configuration.AddInMemoryCollection(appSettings);
builder.Services.AddControllersWithViews(options => options.EnableEndpointRouting = false);
builder.Services.AddRefreshActuator();

await using WebApplication app = builder.Build();
app.UseMvc(routes => routes.MapRoute("default", "{controller=Home}/{action=Index}/{id?}"));
await app.StartAsync();

using HttpClient httpClient = app.GetTestClient();
var requestUri = new Uri("/actuator/refresh", UriKind.Relative);

HttpResponseMessage getResponse = await httpClient.GetAsync(requestUri);
getResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);

HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);
}

[Fact]
public async Task Can_be_configured_to_disabled()
{
var appSettings = new Dictionary<string, string?>
{
["Management:Endpoints:Actuator:Exposure:Include:0"] = "refresh",
["Management:Endpoints:Refresh:Enabled"] = "false"
};

WebApplicationBuilder builder = TestWebApplicationBuilderFactory.Create();
builder.Configuration.AddInMemoryCollection(appSettings);
builder.Services.AddControllersWithViews(options => options.EnableEndpointRouting = false);
builder.Services.AddRefreshActuator();

await using WebApplication app = builder.Build();
app.UseMvc(routes => routes.MapRoute("default", "{controller=Home}/{action=Index}/{id?}"));
await app.StartAsync();

using HttpClient httpClient = app.GetTestClient();
var requestUri = new Uri("/actuator/refresh", UriKind.Relative);

HttpResponseMessage getResponse = await httpClient.GetAsync(requestUri);
getResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);

HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);
}

[Fact]
public async Task Can_be_configured_to_allow_no_verbs()
{
Expand Down Expand Up @@ -94,7 +149,7 @@ public async Task Can_be_configured_to_allow_only_GET_requests()
getResponse.StatusCode.Should().Be(HttpStatusCode.OK);

HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);
postResponse.StatusCode.Should().Be(HttpStatusCode.MethodNotAllowed);
}

[Fact]
Expand Down Expand Up @@ -125,4 +180,68 @@ public async Task Can_be_configured_to_allow_both_GET_and_POST_requests()
HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.OK);
}

[Fact]
public async Task Can_change_allowed_verbs_at_runtime()
{
const string fileName = "appsettings.json";
MemoryFileProvider fileProvider = new();

fileProvider.IncludeFile(fileName, """
{
"Management": {
"Endpoints": {
"Actuator": {
"Exposure": {
"Include": ["refresh"]
}
}
}
}
}
""");

WebApplicationBuilder builder = TestWebApplicationBuilderFactory.Create();
builder.Configuration.AddJsonFile(fileProvider, fileName, false, true);
builder.Services.AddControllersWithViews(options => options.EnableEndpointRouting = false);
builder.Services.AddRefreshActuator();

await using WebApplication app = builder.Build();
app.UseMvc(routes => routes.MapRoute("default", "{controller=Home}/{action=Index}/{id?}"));
await app.StartAsync();

using HttpClient httpClient = app.GetTestClient();
var requestUri = new Uri("/actuator/refresh", UriKind.Relative);

HttpResponseMessage getResponse = await httpClient.GetAsync(requestUri);
getResponse.StatusCode.Should().Be(HttpStatusCode.MethodNotAllowed);

HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.OK);

fileProvider.ReplaceFile(fileName, """
{
"Management": {
"Endpoints": {
"Actuator": {
"Exposure": {
"Include": ["refresh"]
}
},
"Refresh": {
"AllowedVerbs": ["GET"]
}
}
}
}
""");

fileProvider.NotifyChanged();

getResponse = await httpClient.GetAsync(requestUri);
getResponse.StatusCode.Should().Be(HttpStatusCode.OK);

postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.MethodNotAllowed);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,61 @@ public async Task Allows_only_POST_requests()
postResponse.StatusCode.Should().Be(HttpStatusCode.OK);
}

[Fact]
public async Task Can_be_configured_to_unexposed()
{
var appSettings = new Dictionary<string, string?>
{
["Management:Endpoints:Actuator:Exposure:Include:0"] = string.Empty
};

WebApplicationBuilder builder = TestWebApplicationBuilderFactory.Create();
builder.Configuration.AddInMemoryCollection(appSettings);
builder.Services.AddControllersWithViews();
builder.Services.AddRefreshActuator();

await using WebApplication app = builder.Build();
app.MapDefaultControllerRoute();
await app.StartAsync();

using HttpClient httpClient = app.GetTestClient();
var requestUri = new Uri("/actuator/refresh", UriKind.Relative);

HttpResponseMessage getResponse = await httpClient.GetAsync(requestUri);
getResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);

HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);
}

[Fact]
public async Task Can_be_configured_to_disabled()
{
var appSettings = new Dictionary<string, string?>
{
["Management:Endpoints:Actuator:Exposure:Include:0"] = "refresh",
["Management:Endpoints:Refresh:Enabled"] = "false"
};

WebApplicationBuilder builder = TestWebApplicationBuilderFactory.Create();
builder.Configuration.AddInMemoryCollection(appSettings);
builder.Services.AddControllersWithViews();
builder.Services.AddRefreshActuator();

await using WebApplication app = builder.Build();
app.MapDefaultControllerRoute();
await app.StartAsync();

using HttpClient httpClient = app.GetTestClient();
var requestUri = new Uri("/actuator/refresh", UriKind.Relative);

HttpResponseMessage getResponse = await httpClient.GetAsync(requestUri);
getResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);

HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.NotFound);
}

[Fact]
public async Task Can_be_configured_to_allow_no_verbs()
{
Expand Down Expand Up @@ -125,4 +180,68 @@ public async Task Can_be_configured_to_allow_both_GET_and_POST_requests()
HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.OK);
}

[Fact]
public async Task Can_change_allowed_verbs_at_runtime()
{
const string fileName = "appsettings.json";
MemoryFileProvider fileProvider = new();

fileProvider.IncludeFile(fileName, """
{
"Management": {
"Endpoints": {
"Actuator": {
"Exposure": {
"Include": ["refresh"]
}
}
}
}
}
""");

WebApplicationBuilder builder = TestWebApplicationBuilderFactory.Create();
builder.Configuration.AddJsonFile(fileProvider, fileName, false, true);
builder.Services.AddControllersWithViews();
builder.Services.AddRefreshActuator();

await using WebApplication app = builder.Build();
app.MapDefaultControllerRoute();
await app.StartAsync();

using HttpClient httpClient = app.GetTestClient();
var requestUri = new Uri("/actuator/refresh", UriKind.Relative);

HttpResponseMessage getResponse = await httpClient.GetAsync(requestUri);
getResponse.StatusCode.Should().Be(HttpStatusCode.MethodNotAllowed);

HttpResponseMessage postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.OK);

fileProvider.ReplaceFile(fileName, """
{
"Management": {
"Endpoints": {
"Actuator": {
"Exposure": {
"Include": ["refresh"]
}
},
"Refresh": {
"AllowedVerbs": ["GET"]
}
}
}
}
""");

fileProvider.NotifyChanged();

getResponse = await httpClient.GetAsync(requestUri);
getResponse.StatusCode.Should().Be(HttpStatusCode.OK);

postResponse = await httpClient.PostAsync(requestUri, null);
postResponse.StatusCode.Should().Be(HttpStatusCode.MethodNotAllowed);
}
}
Loading

0 comments on commit b03df23

Please sign in to comment.