Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom Headers Test 4 #2524

Merged
merged 11 commits into from
Jan 5, 2024
8 changes: 5 additions & 3 deletions API/Data/Repositories/UserRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public interface IUserRepository
Task<AppUser?> GetUserByIdAsync(int userId, AppUserIncludes includeFlags = AppUserIncludes.None);
Task<int> GetUserIdByUsernameAsync(string username);
Task<IList<AppUserBookmark>> GetAllBookmarksByIds(IList<int> bookmarkIds);
Task<AppUser?> GetUserByEmailAsync(string email);
Task<AppUser?> GetUserByEmailAsync(string email, AppUserIncludes includes = AppUserIncludes.None);
Task<IEnumerable<AppUserPreferences>> GetAllPreferencesByThemeAsync(int themeId);
Task<bool> HasAccessToLibrary(int libraryId, int userId);
Task<bool> HasAccessToSeries(int userId, int seriesId);
Expand Down Expand Up @@ -240,10 +240,12 @@ public async Task<IList<AppUserBookmark>> GetAllBookmarksByIds(IList<int> bookma
.ToListAsync();
}

public async Task<AppUser?> GetUserByEmailAsync(string email)
public async Task<AppUser?> GetUserByEmailAsync(string email, AppUserIncludes includes = AppUserIncludes.None)
{
var lowerEmail = email.ToLower();
return await _context.AppUser.SingleOrDefaultAsync(u => u.Email != null && u.Email.ToLower().Equals(lowerEmail));
return await _context.AppUser
.Includes(includes)
.FirstOrDefaultAsync(u => u.Email != null && u.Email.ToLower().Equals(lowerEmail));
}


Expand Down
101 changes: 101 additions & 0 deletions API/Middleware/CustomAuthHeaderMiddleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
using System;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
using API.Data;
using API.Services;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace API.Middleware;

public class CustomAuthHeaderMiddleware(RequestDelegate next)
{
// Hardcoded list of allowed IP addresses in CIDR format
private readonly string[] allowedIpAddresses = { "192.168.1.0/24", "2001:db8::/32", "116.202.233.5", "104.21.81.112" };


public async Task Invoke(HttpContext context, IUnitOfWork unitOfWork, ILogger<CustomAuthHeaderMiddleware> logger, ITokenService tokenService)
{
// Extract user information from the custom header
string remoteUser = context.Request.Headers["Remote-User"];

// If header missing or user already authenticated, move on
if (string.IsNullOrEmpty(remoteUser) || context.User.Identity is {IsAuthenticated: true})
{
await next(context);
return;
}

// Validate IP address
if (IsValidIpAddress(context.Connection.RemoteIpAddress))
{
// Perform additional authentication logic if needed
// For now, you can log the authenticated user
var user = await unitOfWork.UserRepository.GetUserByEmailAsync(remoteUser);
if (user == null)
{
// Tell security log maybe?
context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
return;
}
// Check if the RemoteUser has an account on the server
// if (!context.Request.Path.Equals("/login", StringComparison.OrdinalIgnoreCase))
// {
// // Attach the Auth header and allow it to pass through
// var token = await tokenService.CreateToken(user);
// context.Request.Headers.Add("Authorization", $"Bearer {token}");
// //context.Response.Redirect($"/login?apiKey={user.ApiKey}");
// return;
// }
// Attach the Auth header and allow it to pass through
var token = await tokenService.CreateToken(user);
context.Request.Headers.Append("Authorization", $"Bearer {token}");
await next(context);
return;
}

context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
await next(context);
}

private bool IsValidIpAddress(IPAddress ipAddress)
{
// Check if the IP address is in the whitelist
return allowedIpAddresses.Any(ipRange => IpAddressRange.Parse(ipRange).Contains(ipAddress));
}
}

// Helper class for IP address range parsing
public class IpAddressRange
{
private readonly uint _startAddress;
private readonly uint _endAddress;

private IpAddressRange(uint startAddress, uint endAddress)
{
_startAddress = startAddress;
_endAddress = endAddress;
}

public bool Contains(IPAddress address)
{
var ipAddressBytes = address.GetAddressBytes();
var ipAddress = BitConverter.ToUInt32(ipAddressBytes.Reverse().ToArray(), 0);
return ipAddress >= _startAddress && ipAddress <= _endAddress;
}

public static IpAddressRange Parse(string ipRange)
{
var parts = ipRange.Split('/');
var ipAddress = IPAddress.Parse(parts[0]);
var maskBits = int.Parse(parts[1]);

var ipBytes = ipAddress.GetAddressBytes().Reverse().ToArray();
var startAddress = BitConverter.ToUInt32(ipBytes, 0);
var endAddress = startAddress | (uint.MaxValue >> maskBits);

return new IpAddressRange(startAddress, endAddress);
}
}
16 changes: 3 additions & 13 deletions API/Middleware/ExceptionMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,17 @@

namespace API.Middleware;

public class ExceptionMiddleware
public class ExceptionMiddleware(RequestDelegate next, ILogger<ExceptionMiddleware> logger)
{
private readonly RequestDelegate _next;
private readonly ILogger<ExceptionMiddleware> _logger;


public ExceptionMiddleware(RequestDelegate next, ILogger<ExceptionMiddleware> logger)
{
_next = next;
_logger = logger;
}

public async Task InvokeAsync(HttpContext context)
{
try
{
await _next(context); // downstream middlewares or http call
await next(context); // downstream middlewares or http call
}
catch (Exception ex)
{
_logger.LogError(ex, "There was an exception");
logger.LogError(ex, "There was an exception");
context.Response.ContentType = "application/json";
context.Response.StatusCode = (int) HttpStatusCode.InternalServerError;

Expand Down
24 changes: 8 additions & 16 deletions API/Middleware/JWTRevocationMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,16 @@ namespace API.Middleware;
/// <summary>
/// Responsible for maintaining an in-memory. Not in use
/// </summary>
public class JwtRevocationMiddleware
public class JwtRevocationMiddleware(
RequestDelegate next,
IEasyCachingProviderFactory cacheFactory,
ILogger<JwtRevocationMiddleware> logger)
{
private readonly RequestDelegate _next;
private readonly IEasyCachingProviderFactory _cacheFactory;
private readonly ILogger<JwtRevocationMiddleware> _logger;

public JwtRevocationMiddleware(RequestDelegate next, IEasyCachingProviderFactory cacheFactory, ILogger<JwtRevocationMiddleware> logger)
{
_next = next;
_cacheFactory = cacheFactory;
_logger = logger;
}

public async Task InvokeAsync(HttpContext context)
{
if (context.User.Identity is {IsAuthenticated: false})
{
await _next(context);
await next(context);
return;
}

Expand All @@ -37,18 +29,18 @@ public async Task InvokeAsync(HttpContext context)
// Check if the token is revoked
if (await IsTokenRevoked(token))
{
_logger.LogWarning("Revoked token detected: {Token}", token);
logger.LogWarning("Revoked token detected: {Token}", token);
context.Response.StatusCode = StatusCodes.Status401Unauthorized;
return;
}

await _next(context);
await next(context);
}

private async Task<bool> IsTokenRevoked(string token)
{
// Check if the token exists in the revocation list stored in the cache
var isRevoked = await _cacheFactory.GetCachingProvider(EasyCacheProfiles.RevokedJwt)
var isRevoked = await cacheFactory.GetCachingProvider(EasyCacheProfiles.RevokedJwt)
.GetAsync<string>(token);


Expand Down
27 changes: 9 additions & 18 deletions API/Middleware/SecurityMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,29 @@
using Kavita.Common;
using Microsoft.AspNetCore.Http;
using Serilog;
using ILogger = Serilog.ILogger;
using ILogger = Serilog.Core.Logger;

namespace API.Middleware;

public class SecurityEventMiddleware
public class SecurityEventMiddleware(RequestDelegate next)
{
private readonly RequestDelegate _next;
private readonly ILogger _logger;

public SecurityEventMiddleware(RequestDelegate next)
{
_next = next;

_logger = new LoggerConfiguration()
.MinimumLevel.Debug()
.WriteTo.File(Path.Join(Directory.GetCurrentDirectory(), "config/logs/", "security.log"), rollingInterval: RollingInterval.Day)
.CreateLogger();
}
private readonly ILogger _logger = new LoggerConfiguration()
.MinimumLevel.Debug()
.WriteTo.File(Path.Join(Directory.GetCurrentDirectory(), "config/logs/", "security.log"), rollingInterval: RollingInterval.Day)
.CreateLogger();

public async Task InvokeAsync(HttpContext context)
{
try
{
await _next(context);
await next(context);
}
catch (KavitaUnauthenticatedUserException ex)
{
var ipAddress = context.Connection.RemoteIpAddress?.ToString();
var requestMethod = context.Request.Method;
var requestPath = context.Request.Path;
var userAgent = context.Request.Headers["User-Agent"];
var userAgent = context.Request.Headers.UserAgent;
var securityEvent = new
{
IpAddress = ipAddress,
Expand All @@ -57,8 +49,7 @@ public async Task InvokeAsync(HttpContext context)

var options = new JsonSerializerOptions
{
PropertyNamingPolicy =
JsonNamingPolicy.CamelCase
PropertyNamingPolicy = JsonNamingPolicy.CamelCase
};

var json = JsonSerializer.Serialize(response, options);
Expand Down
1 change: 1 addition & 0 deletions API/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ public void Configure(IApplicationBuilder app, IBackgroundJobClient backgroundJo

app.UseMiddleware<ExceptionMiddleware>();
app.UseMiddleware<SecurityEventMiddleware>();
app.UseMiddleware<CustomAuthHeaderMiddleware>();


if (env.IsDevelopment())
Expand Down
2 changes: 1 addition & 1 deletion Kavita.Common/Kavita.Common.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<TargetFramework>net8.0</TargetFramework>
<Company>kavitareader.com</Company>
<Product>Kavita</Product>
<AssemblyVersion>0.7.11.7</AssemblyVersion>
<AssemblyVersion>0.7.11.10</AssemblyVersion>
<NeutralLanguage>en</NeutralLanguage>
<TieredPGO>true</TieredPGO>
</PropertyGroup>
Expand Down
2 changes: 1 addition & 1 deletion openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"name": "GPL-3.0",
"url": "https://github.com/Kareadita/Kavita/blob/develop/LICENSE"
},
"version": "0.7.11.6"
"version": "0.7.11.10"
},
"servers": [
{
Expand Down
Loading