diff --git a/API/Data/Repositories/UserRepository.cs b/API/Data/Repositories/UserRepository.cs index 0b604e59ea..9f52b9efc0 100644 --- a/API/Data/Repositories/UserRepository.cs +++ b/API/Data/Repositories/UserRepository.cs @@ -72,7 +72,7 @@ public interface IUserRepository Task GetUserByIdAsync(int userId, AppUserIncludes includeFlags = AppUserIncludes.None); Task GetUserIdByUsernameAsync(string username); Task> GetAllBookmarksByIds(IList bookmarkIds); - Task GetUserByEmailAsync(string email); + Task GetUserByEmailAsync(string email, AppUserIncludes includes = AppUserIncludes.None); Task> GetAllPreferencesByThemeAsync(int themeId); Task HasAccessToLibrary(int libraryId, int userId); Task HasAccessToSeries(int userId, int seriesId); @@ -240,10 +240,12 @@ public async Task> GetAllBookmarksByIds(IList bookma .ToListAsync(); } - public async Task GetUserByEmailAsync(string email) + public async Task GetUserByEmailAsync(string email, AppUserIncludes includes = AppUserIncludes.None) { var lowerEmail = email.ToLower(); - return await _context.AppUser.SingleOrDefaultAsync(u => u.Email != null && u.Email.ToLower().Equals(lowerEmail)); + return await _context.AppUser + .Includes(includes) + .FirstOrDefaultAsync(u => u.Email != null && u.Email.ToLower().Equals(lowerEmail)); } diff --git a/API/Middleware/CustomAuthHeaderMiddleware.cs b/API/Middleware/CustomAuthHeaderMiddleware.cs new file mode 100644 index 0000000000..74d95a6dc9 --- /dev/null +++ b/API/Middleware/CustomAuthHeaderMiddleware.cs @@ -0,0 +1,101 @@ +using System; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using API.Data; +using API.Services; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace API.Middleware; + +public class CustomAuthHeaderMiddleware(RequestDelegate next) +{ + // Hardcoded list of allowed IP addresses in CIDR format + private readonly string[] allowedIpAddresses = { "192.168.1.0/24", "2001:db8::/32", "116.202.233.5", "104.21.81.112" }; + + + public async Task Invoke(HttpContext context, IUnitOfWork unitOfWork, ILogger logger, ITokenService tokenService) + { + // Extract user information from the custom header + string remoteUser = context.Request.Headers["Remote-User"]; + + // If header missing or user already authenticated, move on + if (string.IsNullOrEmpty(remoteUser) || context.User.Identity is {IsAuthenticated: true}) + { + await next(context); + return; + } + + // Validate IP address + if (IsValidIpAddress(context.Connection.RemoteIpAddress)) + { + // Perform additional authentication logic if needed + // For now, you can log the authenticated user + var user = await unitOfWork.UserRepository.GetUserByEmailAsync(remoteUser); + if (user == null) + { + // Tell security log maybe? + context.Response.StatusCode = (int)HttpStatusCode.Unauthorized; + return; + } + // Check if the RemoteUser has an account on the server + // if (!context.Request.Path.Equals("/login", StringComparison.OrdinalIgnoreCase)) + // { + // // Attach the Auth header and allow it to pass through + // var token = await tokenService.CreateToken(user); + // context.Request.Headers.Add("Authorization", $"Bearer {token}"); + // //context.Response.Redirect($"/login?apiKey={user.ApiKey}"); + // return; + // } + // Attach the Auth header and allow it to pass through + var token = await tokenService.CreateToken(user); + context.Request.Headers.Append("Authorization", $"Bearer {token}"); + await next(context); + return; + } + + context.Response.StatusCode = (int)HttpStatusCode.Unauthorized; + await next(context); + } + + private bool IsValidIpAddress(IPAddress ipAddress) + { + // Check if the IP address is in the whitelist + return allowedIpAddresses.Any(ipRange => IpAddressRange.Parse(ipRange).Contains(ipAddress)); + } +} + +// Helper class for IP address range parsing +public class IpAddressRange +{ + private readonly uint _startAddress; + private readonly uint _endAddress; + + private IpAddressRange(uint startAddress, uint endAddress) + { + _startAddress = startAddress; + _endAddress = endAddress; + } + + public bool Contains(IPAddress address) + { + var ipAddressBytes = address.GetAddressBytes(); + var ipAddress = BitConverter.ToUInt32(ipAddressBytes.Reverse().ToArray(), 0); + return ipAddress >= _startAddress && ipAddress <= _endAddress; + } + + public static IpAddressRange Parse(string ipRange) + { + var parts = ipRange.Split('/'); + var ipAddress = IPAddress.Parse(parts[0]); + var maskBits = int.Parse(parts[1]); + + var ipBytes = ipAddress.GetAddressBytes().Reverse().ToArray(); + var startAddress = BitConverter.ToUInt32(ipBytes, 0); + var endAddress = startAddress | (uint.MaxValue >> maskBits); + + return new IpAddressRange(startAddress, endAddress); + } +} diff --git a/API/Middleware/ExceptionMiddleware.cs b/API/Middleware/ExceptionMiddleware.cs index 98c3c6aec2..0b2b308c9d 100644 --- a/API/Middleware/ExceptionMiddleware.cs +++ b/API/Middleware/ExceptionMiddleware.cs @@ -9,27 +9,17 @@ namespace API.Middleware; -public class ExceptionMiddleware +public class ExceptionMiddleware(RequestDelegate next, ILogger logger) { - private readonly RequestDelegate _next; - private readonly ILogger _logger; - - - public ExceptionMiddleware(RequestDelegate next, ILogger logger) - { - _next = next; - _logger = logger; - } - public async Task InvokeAsync(HttpContext context) { try { - await _next(context); // downstream middlewares or http call + await next(context); // downstream middlewares or http call } catch (Exception ex) { - _logger.LogError(ex, "There was an exception"); + logger.LogError(ex, "There was an exception"); context.Response.ContentType = "application/json"; context.Response.StatusCode = (int) HttpStatusCode.InternalServerError; diff --git a/API/Middleware/JWTRevocationMiddleware.cs b/API/Middleware/JWTRevocationMiddleware.cs index 2bcba34252..65ea9e80f0 100644 --- a/API/Middleware/JWTRevocationMiddleware.cs +++ b/API/Middleware/JWTRevocationMiddleware.cs @@ -10,24 +10,16 @@ namespace API.Middleware; /// /// Responsible for maintaining an in-memory. Not in use /// -public class JwtRevocationMiddleware +public class JwtRevocationMiddleware( + RequestDelegate next, + IEasyCachingProviderFactory cacheFactory, + ILogger logger) { - private readonly RequestDelegate _next; - private readonly IEasyCachingProviderFactory _cacheFactory; - private readonly ILogger _logger; - - public JwtRevocationMiddleware(RequestDelegate next, IEasyCachingProviderFactory cacheFactory, ILogger logger) - { - _next = next; - _cacheFactory = cacheFactory; - _logger = logger; - } - public async Task InvokeAsync(HttpContext context) { if (context.User.Identity is {IsAuthenticated: false}) { - await _next(context); + await next(context); return; } @@ -37,18 +29,18 @@ public async Task InvokeAsync(HttpContext context) // Check if the token is revoked if (await IsTokenRevoked(token)) { - _logger.LogWarning("Revoked token detected: {Token}", token); + logger.LogWarning("Revoked token detected: {Token}", token); context.Response.StatusCode = StatusCodes.Status401Unauthorized; return; } - await _next(context); + await next(context); } private async Task IsTokenRevoked(string token) { // Check if the token exists in the revocation list stored in the cache - var isRevoked = await _cacheFactory.GetCachingProvider(EasyCacheProfiles.RevokedJwt) + var isRevoked = await cacheFactory.GetCachingProvider(EasyCacheProfiles.RevokedJwt) .GetAsync(token); diff --git a/API/Middleware/SecurityMiddleware.cs b/API/Middleware/SecurityMiddleware.cs index 5b2019594f..61ca1c75d1 100644 --- a/API/Middleware/SecurityMiddleware.cs +++ b/API/Middleware/SecurityMiddleware.cs @@ -7,37 +7,29 @@ using Kavita.Common; using Microsoft.AspNetCore.Http; using Serilog; -using ILogger = Serilog.ILogger; +using ILogger = Serilog.Core.Logger; namespace API.Middleware; -public class SecurityEventMiddleware +public class SecurityEventMiddleware(RequestDelegate next) { - private readonly RequestDelegate _next; - private readonly ILogger _logger; - - public SecurityEventMiddleware(RequestDelegate next) - { - _next = next; - - _logger = new LoggerConfiguration() - .MinimumLevel.Debug() - .WriteTo.File(Path.Join(Directory.GetCurrentDirectory(), "config/logs/", "security.log"), rollingInterval: RollingInterval.Day) - .CreateLogger(); - } + private readonly ILogger _logger = new LoggerConfiguration() + .MinimumLevel.Debug() + .WriteTo.File(Path.Join(Directory.GetCurrentDirectory(), "config/logs/", "security.log"), rollingInterval: RollingInterval.Day) + .CreateLogger(); public async Task InvokeAsync(HttpContext context) { try { - await _next(context); + await next(context); } catch (KavitaUnauthenticatedUserException ex) { var ipAddress = context.Connection.RemoteIpAddress?.ToString(); var requestMethod = context.Request.Method; var requestPath = context.Request.Path; - var userAgent = context.Request.Headers["User-Agent"]; + var userAgent = context.Request.Headers.UserAgent; var securityEvent = new { IpAddress = ipAddress, @@ -57,8 +49,7 @@ public async Task InvokeAsync(HttpContext context) var options = new JsonSerializerOptions { - PropertyNamingPolicy = - JsonNamingPolicy.CamelCase + PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; var json = JsonSerializer.Serialize(response, options); diff --git a/API/Startup.cs b/API/Startup.cs index 939bfb5867..b108d13f7b 100644 --- a/API/Startup.cs +++ b/API/Startup.cs @@ -261,6 +261,7 @@ public void Configure(IApplicationBuilder app, IBackgroundJobClient backgroundJo app.UseMiddleware(); app.UseMiddleware(); + app.UseMiddleware(); if (env.IsDevelopment()) diff --git a/Kavita.Common/Kavita.Common.csproj b/Kavita.Common/Kavita.Common.csproj index ddcbd37057..651a32ead9 100644 --- a/Kavita.Common/Kavita.Common.csproj +++ b/Kavita.Common/Kavita.Common.csproj @@ -4,7 +4,7 @@ net8.0 kavitareader.com Kavita - 0.7.11.7 + 0.7.11.10 en true diff --git a/openapi.json b/openapi.json index f813654093..0b8bc25ef4 100644 --- a/openapi.json +++ b/openapi.json @@ -7,7 +7,7 @@ "name": "GPL-3.0", "url": "https://github.com/Kareadita/Kavita/blob/develop/LICENSE" }, - "version": "0.7.11.6" + "version": "0.7.11.10" }, "servers": [ {