diff --git a/EmbedIO.sln b/EmbedIO.sln
index dc415bcf2..84da6b6f0 100644
--- a/EmbedIO.sln
+++ b/EmbedIO.sln
@@ -1,13 +1,12 @@
Microsoft Visual Studio Solution File, Format Version 12.00
-# Visual Studio 15
-VisualStudioVersion = 15.0.26730.16
+# Visual Studio Version 16
+VisualStudioVersion = 16.0.29609.76
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{97BC259A-4E78-4BA8-8F4D-2656BC78BB34}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{73F25F81-0412-412E-89C9-BAD33E9BCCDE}"
ProjectSection(SolutionItems) = preProject
- .travis.yml = .travis.yml
appveyor.yml = appveyor.yml
LICENSE = LICENSE
README.md = README.md
diff --git a/src/EmbedIO.Samples/Program.cs b/src/EmbedIO.Samples/Program.cs
index f0021e580..89adcb1b9 100644
--- a/src/EmbedIO.Samples/Program.cs
+++ b/src/EmbedIO.Samples/Program.cs
@@ -1,10 +1,12 @@
using System;
+using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using EmbedIO.Actions;
using EmbedIO.Files;
+using EmbedIO.Security;
using EmbedIO.WebApi;
using Swan;
using Swan.Logging;
@@ -60,6 +62,13 @@ private static WebServer CreateWebServer(string url)
var server = new WebServer(o => o
.WithUrlPrefix(url)
.WithMode(HttpListenerMode.EmbedIO))
+ .WithIPBanning(o => o
+ .WithWhitelist(
+ "",
+ "172.16.16.124",
+ "172.16.17.1/24",
+ "192.168.1-2.2-5")
+ .WithRules("(404 Not Found)+"), 5,5)
.WithLocalSessionManager()
.WithCors(
// Origins, separated by comma without last slash
diff --git a/src/EmbedIO/Security/BannedInfo.cs b/src/EmbedIO/Security/BannedInfo.cs
new file mode 100644
index 000000000..afe250b51
--- /dev/null
+++ b/src/EmbedIO/Security/BannedInfo.cs
@@ -0,0 +1,25 @@
+using System.Net;
+
+namespace EmbedIO.Security
+{
+ ///
+ /// Represents the info af a banned IP address.
+ ///
+ public class BannedInfo
+ {
+ ///
+ /// Gets or sets the banned IP address.
+ ///
+ public IPAddress IPAddress { get; set; }
+
+ ///
+ /// Gets or sets until when the IP will remain ban.
+ ///
+ public long BanUntil { get; set; }
+
+ ///
+ /// Gets or sets a value indicating whether this instance was explicitly banned by user.
+ ///
+ public bool IsExplicit { get; set; }
+ }
+}
diff --git a/src/EmbedIO/Security/IPBanningModule.cs b/src/EmbedIO/Security/IPBanningModule.cs
new file mode 100644
index 000000000..bcdc13f5c
--- /dev/null
+++ b/src/EmbedIO/Security/IPBanningModule.cs
@@ -0,0 +1,301 @@
+using EmbedIO.Utilities;
+using Swan;
+using Swan.Logging;
+using Swan.Threading;
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Text.RegularExpressions;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace EmbedIO.Security
+{
+ ///
+ /// A module for ban IPs that show the malicious signs, based on scanning log messages.
+ ///
+ ///
+ public class IPBanningModule : WebModuleBase, ILogger
+ {
+ ///
+ /// The default ban time, in minutes.
+ ///
+ public const int DefaultBanTime = 30;
+
+ ///
+ /// The default maximum retries per minute.
+ ///
+ public const int DefaultMaxRetry = 10;
+
+ private static readonly ConcurrentDictionary> AccessAttempts = new ConcurrentDictionary>();
+ private static readonly ConcurrentDictionary Blacklist = new ConcurrentDictionary();
+ private static readonly ConcurrentDictionary FailRegex = new ConcurrentDictionary();
+ private static readonly PeriodicTask? Purger;
+
+ private readonly List _whitelist = new List();
+ private readonly int _banTime;
+ private readonly int _maxRetry;
+ private bool _disposedValue;
+
+ static IPBanningModule()
+ {
+ Purger = new PeriodicTask(TimeSpan.FromMinutes(1), ct =>
+ {
+ PurgeBlackList();
+ PurgeAccessAttempts();
+
+ return Task.CompletedTask;
+ });
+ }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The base route.
+ /// A collection of regex to match the log messages against.
+ /// The time that an IP will remain ban, in minutes.
+ /// The maximum number of failed attempts before banning an IP.
+ public IPBanningModule(string baseRoute,
+ IEnumerable failRegex,
+ int banTime = DefaultBanTime,
+ int maxRetry = DefaultMaxRetry)
+ : this(baseRoute, failRegex, null, banTime, maxRetry)
+ {
+ }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The base route.
+ /// A collection of regex to match the log messages against.
+ /// A collection of valid IPs that never will be banned.
+ /// The time that an IP will remain ban, in minutes.
+ /// The maximum number of failed attempts before banning an IP.
+ public IPBanningModule(string baseRoute,
+ IEnumerable? failRegex = null,
+ IEnumerable? whitelist = null,
+ int banTime = DefaultBanTime,
+ int maxRetry = DefaultMaxRetry)
+ : base(baseRoute)
+ {
+ if (failRegex != null)
+ AddRules(failRegex);
+
+ _banTime = banTime;
+ _maxRetry = maxRetry;
+ AddToWhitelist(whitelist);
+ Logger.RegisterLogger(this);
+ }
+
+ ///
+ public override bool IsFinalHandler => false;
+
+ ///
+ public LogLevel LogLevel => LogLevel.Trace;
+
+ private IPAddress? ClientAddress { get; set; }
+
+ ///
+ /// Gets the list of current banned IPs.
+ ///
+ /// A collection of in the blacklist.
+ public static IEnumerable GetBannedIPs() =>
+ Blacklist.Values.ToList();
+
+ ///
+ /// Tries to ban an IP explicitly.
+ ///
+ /// The IP address to ban.
+ /// The time in minutes that the IP will remain ban.
+ /// if set to true [is explicit].
+ ///
+ /// true if the IP was added to the blacklist; otherwise, false.
+ ///
+ public static bool TryBanIP(IPAddress address, int minutes, bool isExplicit = true) =>
+ TryBanIP(address, DateTime.Now.AddMinutes(minutes), isExplicit);
+
+ ///
+ /// Tries to ban an IP explicitly.
+ ///
+ /// The IP address to ban.
+ /// An that sets the time the IP will remain ban.
+ /// if set to true [is explicit].
+ ///
+ /// true if the IP was added to the blacklist; otherwise, false.
+ ///
+ public static bool TryBanIP(IPAddress address, TimeSpan banTime, bool isExplicit = true) =>
+ TryBanIP(address, DateTime.Now.Add(banTime), isExplicit);
+
+ ///
+ /// Tries to ban an IP explicitly.
+ ///
+ /// The IP address to ban.
+ /// A that sets until when the IP will remain ban.
+ /// if set to true [is explicit].
+ ///
+ /// true if the IP was added to the blacklist; otherwise, false.
+ ///
+ public static bool TryBanIP(IPAddress address, DateTime banUntil, bool isExplicit = true)
+ {
+ if (Blacklist.ContainsKey(address))
+ {
+ var bannedInfo = Blacklist[address];
+ bannedInfo.BanUntil = banUntil.Ticks;
+ bannedInfo.IsExplicit = isExplicit;
+
+ return true;
+ }
+
+ return Blacklist.TryAdd(address, new BannedInfo()
+ {
+ IPAddress = address,
+ BanUntil = banUntil.Ticks,
+ IsExplicit = isExplicit,
+ });
+ }
+
+ ///
+ /// Tries to unban an IP explicitly.
+ ///
+ /// The IP address.
+ ///
+ /// true if the IP was removed from the blacklist; otherwise, false.
+ ///
+ public static bool TryUnbanIP(IPAddress address) =>
+ Blacklist.TryRemove(address, out _);
+
+ ///
+ public void Log(LogMessageReceivedEventArgs logEvent)
+ {
+ // Process Log
+ if (string.IsNullOrWhiteSpace(logEvent.Message) ||
+ ClientAddress == null ||
+ !FailRegex.Any() ||
+ _whitelist.Contains(ClientAddress) ||
+ Blacklist.ContainsKey(ClientAddress))
+ return;
+
+ foreach (var regex in FailRegex.Values)
+ {
+ try
+ {
+ if (!regex.IsMatch(logEvent.Message)) continue;
+
+ // Add to list
+ AddAccessAttempt(ClientAddress);
+ UpdateBlackList();
+ break;
+ }
+ catch (RegexMatchTimeoutException ex)
+ {
+ $"Timeout trying to match '{ex.Input}' with pattern '{ex.Pattern}'.".Error(nameof(IPBanningModule));
+ }
+ }
+ }
+
+ ///
+ public void Dispose() =>
+ Dispose(true);
+
+ internal void AddRules(IEnumerable patterns)
+ {
+ foreach (var pattern in patterns)
+ AddRule(pattern);
+ }
+
+ internal void AddRule(string pattern)
+ {
+ try
+ {
+ FailRegex.TryAdd(pattern, new Regex(pattern, RegexOptions.Compiled | RegexOptions.CultureInvariant, TimeSpan.FromMilliseconds(500)));
+ }
+ catch (Exception ex)
+ {
+ ex.Log(nameof(IPBanningModule), $"Invalid regex - '{pattern}'.");
+ }
+ }
+
+ internal void AddToWhitelist(IEnumerable whitelist) =>
+ AddToWhitelistAsync(whitelist).GetAwaiter().GetResult();
+
+ internal async Task AddToWhitelistAsync(IEnumerable whitelist)
+ {
+ if (whitelist?.Any() != true)
+ return;
+
+ foreach (var address in whitelist)
+ {
+ var addressees = await IPParser.Parse(address).ConfigureAwait(false);
+ _whitelist.AddRange(addressees.Where(x => !_whitelist.Contains(x)));
+ }
+ }
+
+ ///
+ protected override Task OnRequestAsync(IHttpContext context)
+ {
+ ClientAddress = context.Request.RemoteEndPoint.Address;
+ if (!Blacklist.ContainsKey(ClientAddress))
+ return Task.CompletedTask;
+
+ context.SetHandled();
+ throw HttpException.Forbidden();
+ }
+
+ ///
+ /// Releases unmanaged and - optionally - managed resources.
+ ///
+ /// true to release both managed and unmanaged resources; false to release only unmanaged resources.
+ protected virtual void Dispose(bool disposing)
+ {
+ if (_disposedValue) return;
+ if (disposing)
+ {
+ _whitelist.Clear();
+ }
+
+ _disposedValue = true;
+ }
+
+ private static void AddAccessAttempt(IPAddress address)
+ {
+ if (AccessAttempts.ContainsKey(address))
+ AccessAttempts[address].Add(DateTime.Now.Ticks);
+ else
+ AccessAttempts.TryAdd(address, new ConcurrentBag() { DateTime.Now.Ticks });
+ }
+
+ private static void PurgeBlackList()
+ {
+ foreach (var k in Blacklist.Keys)
+ {
+ if (DateTime.Now.Ticks > Blacklist[k].BanUntil)
+ Blacklist.TryRemove(k, out _);
+ }
+ }
+
+ private static void PurgeAccessAttempts()
+ {
+ var banDate = DateTime.Now.AddMinutes(-1).Ticks;
+
+ foreach (var k in AccessAttempts.Keys)
+ {
+ var recentAttempts = new ConcurrentBag(AccessAttempts[k].Where(x => x >= banDate));
+ if (!recentAttempts.Any())
+ AccessAttempts.TryRemove(k, out _);
+ else
+ Interlocked.Exchange(ref recentAttempts, AccessAttempts[k]);
+ }
+ }
+
+ private void UpdateBlackList()
+ {
+ var time = DateTime.Now.AddMinutes(-1).Ticks;
+ if ((AccessAttempts[ClientAddress]?.Where(x => x >= time).Count() >= _maxRetry))
+ {
+ TryBanIP(ClientAddress, _banTime, false);
+ }
+ }
+ }
+}
diff --git a/src/EmbedIO/Security/IPBanningModuleExtensions.cs b/src/EmbedIO/Security/IPBanningModuleExtensions.cs
new file mode 100644
index 000000000..250aaba83
--- /dev/null
+++ b/src/EmbedIO/Security/IPBanningModuleExtensions.cs
@@ -0,0 +1,40 @@
+namespace EmbedIO.Security
+{
+ ///
+ /// Provides extension methods for and derived classes.
+ ///
+ public static class IPBanningModuleExtensions
+ {
+ ///
+ /// Adds a collection of valid IPs that never will be banned.
+ ///
+ /// The type of the module.
+ /// The module on which this method is called.
+ /// A collection of valid IPs that never will be banned.
+ ///
+ /// with its whitelist configured.
+ ///
+ public static TModule WithWhitelist(this TModule @this, params string[] value)
+ where TModule : IPBanningModule
+ {
+ @this.AddToWhitelist(value);
+ return @this;
+ }
+
+ ///
+ /// Add a collection of regex to match the log messages against.
+ ///
+ /// The type of the module.
+ /// The module on which this method is called.
+ /// A collection of regex to match the log messages against.
+ ///
+ /// with its fail regex configured.
+ ///
+ public static TModule WithRules(this TModule @this, params string[] value)
+ where TModule : IPBanningModule
+ {
+ @this.AddRules(value);
+ return @this;
+ }
+ }
+}
diff --git a/src/EmbedIO/Utilities/IPParser.cs b/src/EmbedIO/Utilities/IPParser.cs
new file mode 100644
index 000000000..0570e53a7
--- /dev/null
+++ b/src/EmbedIO/Utilities/IPParser.cs
@@ -0,0 +1,203 @@
+using System.Collections.Generic;
+using System.Globalization;
+using System.Linq;
+using System.Net;
+using System.Threading.Tasks;
+
+namespace EmbedIO.Utilities
+{
+ ///
+ /// Provides standard methods to parse IP address strings.
+ ///
+ public static class IPParser
+ {
+ ///
+ /// Parses the specified IP address.
+ ///
+ /// The IP address.
+ /// A collection of parsed correctly from .
+ public static async Task> Parse(string address)
+ {
+ if (address == null)
+ return Enumerable.Empty();
+
+ if (IPAddress.TryParse(address, out var ip))
+ return new List { ip };
+
+ try
+ {
+ return await Dns.GetHostAddressesAsync(address).ConfigureAwait(false);
+ }
+ catch
+ {
+ // Ignore
+ }
+
+ if (IsCIDRNotation(address))
+ return ParseCIDRNotation(address);
+
+ return IsSimpleIPRange(address) ? TryParseSimpleIPRange(address) : Enumerable.Empty();
+ }
+
+ ///
+ /// Determines whether the IP-range string is in CIDR notation.
+ ///
+ /// The IP-range string.
+ ///
+ /// true if the IP-range string is CIDR notation; otherwise, false.
+ ///
+ public static bool IsCIDRNotation(string range)
+ {
+ if (string.IsNullOrWhiteSpace(range))
+ return false;
+
+ var parts = range.Split('/');
+ if (parts.Length != 2)
+ return false;
+
+ var prefix = parts[0];
+ var prefixLen = parts[1];
+
+ var prefixParts = prefix.Split('.');
+ if (prefixParts.Length != 4)
+ return false;
+
+ return byte.TryParse(prefixLen, out var len) && len >= 0 && len <= 32;
+ }
+
+ ///
+ /// Parse IP-range string in CIDR notation. For example "12.15.0.0/16".
+ ///
+ /// The IP-range string.
+ /// A collection of parsed correctly from .
+ public static IEnumerable ParseCIDRNotation(string range)
+ {
+ if (!IsCIDRNotation(range))
+ return Enumerable.Empty();
+
+ var parts = range.Split('/');
+ var prefix = parts[0];
+
+ if (!byte.TryParse(parts[1], out var prefixLen))
+ return Enumerable.Empty();
+
+ var prefixParts = prefix.Split('.');
+ if (prefixParts.Select(x => byte.TryParse(x, out _)).Any(x => !x))
+ return Enumerable.Empty();
+
+ uint ip = 0;
+ for (var i = 0; i < 4; i++)
+ {
+ ip <<= 8;
+ ip += uint.Parse(prefixParts[i], NumberFormatInfo.InvariantInfo);
+ }
+
+ var shiftBits = (byte)(32 - prefixLen);
+ var ip1 = (ip >> shiftBits) << shiftBits;
+
+ if ((ip1 & ip) != ip1) // Check correct subnet address
+ return Enumerable.Empty();
+
+ var ip2 = ip1 >> shiftBits;
+ for (var k = 0; k < shiftBits; k++)
+ {
+ ip2 = (ip2 << 1) + 1;
+ }
+
+ var beginIP = new byte[4];
+ var endIP = new byte[4];
+
+ for (var i = 0; i < 4; i++)
+ {
+ beginIP[i] = (byte)((ip1 >> ((3 - i) * 8)) & 255);
+ endIP[i] = (byte)((ip2 >> ((3 - i) * 8)) & 255);
+ }
+
+ return GetAllIP(beginIP, endIP);
+ }
+
+ ///
+ /// Determines whether the IP-range string is in simple IP range notation.
+ ///
+ /// The IP-range string.
+ ///
+ /// true if the IP-range string is in simple IP range notation; otherwise, false.
+ ///
+ public static bool IsSimpleIPRange(string range)
+ {
+ if (string.IsNullOrWhiteSpace(range))
+ return false;
+
+ var parts = range.Split('.');
+ if (parts.Length != 4)
+ return false;
+
+ foreach (var part in parts)
+ {
+ var rangeParts = part.Split('-');
+ if (rangeParts.Length < 1 || rangeParts.Length > 2)
+ return false;
+
+ if (!byte.TryParse(rangeParts[0], out _) ||
+ (rangeParts.Length > 1 && !byte.TryParse(rangeParts[1], out _)))
+ return false;
+ }
+
+ return true;
+ }
+
+ ///
+ /// Tries Parse IP-range string "12.15-16.1-30.10-255"
+ ///
+ /// The IP-range string.
+ /// A collection of parsed correctly from .
+ public static IEnumerable TryParseSimpleIPRange(string range)
+ {
+ if (!IsSimpleIPRange(range))
+ return Enumerable.Empty();
+
+ var beginIP = new byte[4];
+ var endIP = new byte[4];
+
+ var parts = range.Split('.');
+ for (var i = 0; i < 4; i++)
+ {
+ var rangeParts = parts[i].Split('-');
+ beginIP[i] = byte.Parse(rangeParts[0], NumberFormatInfo.InvariantInfo);
+ endIP[i] = (rangeParts.Length == 1) ? beginIP[i] : byte.Parse(rangeParts[1], NumberFormatInfo.InvariantInfo);
+ }
+
+ return GetAllIP(beginIP, endIP);
+ }
+
+ private static IEnumerable GetAllIP(byte[] beginIP, byte[] endIP)
+ {
+ for (var i = 0; i < 4; i++)
+ {
+ if (endIP[i] < beginIP[i])
+ return Enumerable.Empty();
+ }
+
+ var capacity = 1;
+ for (var i = 0; i < 4; i++)
+ capacity *= endIP[i] - beginIP[i] + 1;
+
+ var ips = new List(capacity);
+ for (int i0 = beginIP[0]; i0 <= endIP[0]; i0++)
+ {
+ for (int i1 = beginIP[1]; i1 <= endIP[1]; i1++)
+ {
+ for (int i2 = beginIP[2]; i2 <= endIP[2]; i2++)
+ {
+ for (int i3 = beginIP[3]; i3 <= endIP[3]; i3++)
+ {
+ ips.Add(new IPAddress(new[] { (byte)i0, (byte)i1, (byte)i2, (byte)i3 }));
+ }
+ }
+ }
+ }
+
+ return ips;
+ }
+ }
+}
diff --git a/src/EmbedIO/WebModuleContainerExtensions-Security.cs b/src/EmbedIO/WebModuleContainerExtensions-Security.cs
new file mode 100644
index 000000000..1159fa040
--- /dev/null
+++ b/src/EmbedIO/WebModuleContainerExtensions-Security.cs
@@ -0,0 +1,68 @@
+using EmbedIO.Security;
+using System;
+using System.Collections.Generic;
+
+namespace EmbedIO
+{
+ partial class WebModuleContainerExtensions
+ {
+ ///
+ /// Withes the ip banning.
+ ///
+ /// The type of the container.
+ /// The on which this method is called.
+ /// A callback used to configure the module.
+ /// The time that an IP will remain ban, in minutes.
+ /// The maximum number of failed attempts before banning an IP.
+ /// with a added.
+ public static TContainer WithIPBanning(this TContainer @this,
+ Action? configure = null,
+ int banTime = IPBanningModule.DefaultBanTime,
+ int maxRetry = IPBanningModule.DefaultMaxRetry)
+ where TContainer : class, IWebModuleContainer =>
+ SetModule(@this, null, null, banTime, maxRetry, configure);
+
+ ///
+ /// Creates an instance of and adds it to a module container.
+ ///
+ /// The type of the module container.
+ /// The on which this method is called.
+ /// A collection of regex to match the log messages against.
+ /// The time that an IP will remain ban, in minutes.
+ /// The maximum number of failed attempts before banning an IP.
+ /// with a added.
+ public static TContainer WithIPBanning(this TContainer @this,
+ IEnumerable failRegex,
+ int banTime = IPBanningModule.DefaultBanTime,
+ int maxRetry = IPBanningModule.DefaultMaxRetry)
+ where TContainer : class, IWebModuleContainer =>
+ WithIPBanning(@this, failRegex, null, banTime, maxRetry);
+
+ ///
+ /// Creates an instance of and adds it to a module container.
+ ///
+ /// The type of the module container.
+ /// The on which this method is called.
+ /// A collection of regex to match the log messages against.
+ /// A collection of valid IPs that never will be banned.
+ /// The time that an IP will remain ban, in minutes.
+ /// The maximum number of failed attempts before banning an IP.
+ /// with a added.
+ public static TContainer WithIPBanning(this TContainer @this,
+ IEnumerable? failRegex,
+ IEnumerable? whitelist = null,
+ int banTime = IPBanningModule.DefaultBanTime,
+ int maxRetry = IPBanningModule.DefaultMaxRetry)
+ where TContainer : class, IWebModuleContainer =>
+ SetModule(@this, failRegex, whitelist, banTime, maxRetry);
+
+ private static TContainer SetModule(this TContainer @this,
+ IEnumerable? failRegex,
+ IEnumerable? whitelist = null,
+ int banTime = IPBanningModule.DefaultBanTime,
+ int maxRetry = IPBanningModule.DefaultMaxRetry,
+ Action? configure = null)
+ where TContainer : class, IWebModuleContainer =>
+ WithModule(@this, new IPBanningModule("/", failRegex, whitelist, banTime, maxRetry), configure);
+ }
+}
diff --git a/test/EmbedIO.Tests/IPBanningModuleTest.cs b/test/EmbedIO.Tests/IPBanningModuleTest.cs
new file mode 100644
index 000000000..b5c714f19
--- /dev/null
+++ b/test/EmbedIO.Tests/IPBanningModuleTest.cs
@@ -0,0 +1,100 @@
+using EmbedIO.Security;
+using EmbedIO.Tests.TestObjects;
+using System.Net;
+using System.Net.Http;
+using System.Threading.Tasks;
+using NUnit.Framework;
+using System;
+
+namespace EmbedIO.Tests
+{
+ [TestFixture]
+ public class IPBanningModuleTest : EndToEndFixtureBase
+ {
+ protected override void OnSetUp()
+ {
+ Server
+ .WithIPBanning(o => o
+ .WithRules("(404)+")
+ .WithRules("(401)+"), 30, 2)
+ .WithWebApi("/api", m => m.RegisterController());
+ }
+
+ private HttpRequestMessage GetNotFoundRequest() =>
+ new HttpRequestMessage(HttpMethod.Get, $"{WebServerUrl}/api/notFound");
+
+ private HttpRequestMessage GetUnauthorizedRequest() =>
+ new HttpRequestMessage(HttpMethod.Get, $"{WebServerUrl}/api/unauthorized");
+
+ private IPAddress LocalHost { get; } = IPAddress.Parse("127.0.0.1");
+
+ [Test]
+ public async Task RequestFailRegex_ReturnsForbidden()
+ {
+ _ = await Client.SendAsync(GetNotFoundRequest());
+ _ = await Client.SendAsync(GetUnauthorizedRequest());
+ var response = await Client.SendAsync(GetNotFoundRequest());
+
+ Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode, "Status Code Forbidden");
+ }
+
+ [Test]
+ public async Task BanIpMinutes_ReturnsForbidden()
+ {
+ IPBanningModule.TryUnbanIP(LocalHost);
+
+ var response = await Client.SendAsync(GetNotFoundRequest());
+ Assert.AreEqual(HttpStatusCode.NotFound, response.StatusCode, "Status Code NotFound");
+
+ IPBanningModule.TryBanIP(LocalHost, 10);
+
+ response = await Client.SendAsync(GetNotFoundRequest());
+ Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode, "Status Code Forbidden");
+ }
+
+ [Test]
+ public async Task BanIpTimeSpan_ReturnsForbidden()
+ {
+ IPBanningModule.TryUnbanIP(LocalHost);
+
+ var response = await Client.SendAsync(GetNotFoundRequest());
+ Assert.AreEqual(HttpStatusCode.NotFound, response.StatusCode, "Status Code NotFound");
+
+ IPBanningModule.TryBanIP(LocalHost, TimeSpan.FromMinutes(10));
+
+ response = await Client.SendAsync(GetNotFoundRequest());
+ Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode, "Status Code Forbidden");
+ }
+
+ [Test]
+ public async Task BanIpDateTime_ReturnsForbidden()
+ {
+ IPBanningModule.TryUnbanIP(LocalHost);
+
+ var response = await Client.SendAsync(GetNotFoundRequest());
+ Assert.AreEqual(HttpStatusCode.NotFound, response.StatusCode, "Status Code NotFound");
+
+ IPBanningModule.TryBanIP(LocalHost, DateTime.Now.AddMinutes(10));
+
+ response = await Client.SendAsync(GetNotFoundRequest());
+ Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode, "Status Code Forbidden");
+ }
+
+ [Test]
+ public async Task RequestFailRegex_UnbanIp_ReturnsNotFound()
+ {
+ _ = await Client.SendAsync(GetNotFoundRequest());
+ _ = await Client.SendAsync(GetNotFoundRequest());
+ var response = await Client.SendAsync(GetNotFoundRequest());
+
+ Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode, "Status Code Forbidden");
+
+ var bannedIps = IPBanningModule.GetBannedIPs();
+ foreach (var address in bannedIps)
+ IPBanningModule.TryUnbanIP(address.IPAddress);
+
+ response = await Client.SendAsync(GetNotFoundRequest());
+ Assert.AreEqual(HttpStatusCode.NotFound, response.StatusCode, "Status Code NotFound");
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/EmbedIO.Tests/TestObjects/TestController.cs b/test/EmbedIO.Tests/TestObjects/TestController.cs
index 2406daeb8..212248c21 100644
--- a/test/EmbedIO.Tests/TestObjects/TestController.cs
+++ b/test/EmbedIO.Tests/TestObjects/TestController.cs
@@ -63,6 +63,14 @@ public Person GetOptionalPerson(string skill, int? age = null)
[Route(HttpVerbs.Get, "/" + QueryFieldTestPath)]
public string TestQueryField([QueryField] string id) => id;
+ [Route(HttpVerbs.Get, "/notFound")]
+ public void GetNotFound() =>
+ throw HttpException.NotFound();
+
+ [Route(HttpVerbs.Get, "/unauthorized")]
+ public void GetUnauthorized() =>
+ throw HttpException.Unauthorized();
+
private static Person CheckPerson(int id)
=>PeopleRepository.Database.FirstOrDefault(p => p.Key == id)
?? throw HttpException.NotFound();
diff --git a/test/EmbedIO.Tests/Utilities/IPParserTest.cs b/test/EmbedIO.Tests/Utilities/IPParserTest.cs
new file mode 100644
index 000000000..815600f3a
--- /dev/null
+++ b/test/EmbedIO.Tests/Utilities/IPParserTest.cs
@@ -0,0 +1,75 @@
+using EmbedIO.Utilities;
+using NUnit.Framework;
+using System.Linq;
+using System.Threading.Tasks;
+
+namespace EmbedIO.Tests.Utilities
+{
+ public class IPParserTest
+ {
+ [TestCase(null, false)]
+ [TestCase("", false)]
+ [TestCase("192.168.1.52", false)]
+ [TestCase("192.168.1.52/", false)]
+ [TestCase("192.168.1.52-2", false)]
+ [TestCase("192.168.1.52/48", false)]
+ [TestCase("192.168.152/48", false)]
+ [TestCase("192.168.1.52/256", false)]
+ [TestCase("192.168.1.52/24.1", false)]
+ [TestCase("192.168.1.52/24", true)]
+ public void IsCIDRNotation_ReturnsCorrectValue(string address, bool expectedResult)
+ => Assert.AreEqual(expectedResult, IPParser.IsCIDRNotation(address));
+
+ [TestCase(null, false)]
+ [TestCase("", false)]
+ [TestCase("192.168.152", false)]
+ [TestCase("192.168.152.1.", false)]
+ [TestCase("192.168.1.52-", false)]
+ [TestCase("192.168.1-.52", false)]
+ [TestCase("192.168.1-2.52/1", false)]
+ [TestCase("192.168.1-2.52-1", true)]
+ [TestCase("192.168-169.1-2.52-53", true)]
+ [TestCase("192-193.168-169.1-2.52-53", true)]
+ public void IsSimpleIPRange_ReturnsCorrectValue(string address, bool expectedResult)
+ => Assert.AreEqual(expectedResult, IPParser.IsSimpleIPRange(address));
+
+ [TestCase(null)]
+ [TestCase("192.168.1.52/")]
+ [TestCase("192.168.1.52-2")]
+ [TestCase("192.168.1.52/48")]
+ [TestCase("192.168.152/24")]
+ [TestCase("192.168.1.52/256")]
+ [TestCase("192.168.1.52/24.1")]
+ [TestCase("192.168.152.1.")]
+ [TestCase("192.168.1.52-")]
+ [TestCase("192.168.1-.52")]
+ [TestCase("192.168.1-2.52/1")]
+ [TestCase("192.168.1-2/3.52/1")]
+ [TestCase("192.168.1-x.52/1")]
+ [TestCase("192.168.1-2.52-1")]
+ [TestCase("192.168.2-1.52-1")]
+ public async Task IpParseEmpty_ReturnsCorrectValue(string address)
+ => CollectionAssert.IsEmpty(await IPParser.Parse(address));
+
+ [TestCase("")]
+ [TestCase("192")]
+ [TestCase("192.168")]
+ [TestCase("192.168.152")]
+ [TestCase("192.168.1.52/24")]
+ [TestCase("192.168-169.1-2.52-53")]
+ [TestCase("192-193.168-169.1-2.52-53")]
+ public async Task IpParseNotEmpty_ReturnsCorrectValue(string address)
+ => CollectionAssert.IsNotEmpty(await IPParser.Parse(address));
+
+ [TestCase("192", 1)]
+ [TestCase("192.168", 1)]
+ [TestCase("192.168.152", 1)]
+ [TestCase("192.168.1.1/24", 256)]
+ [TestCase("192.168.1.50-53", 4)]
+ [TestCase("192.168.1-2.50-53", 8)]
+ [TestCase("192.168-169.1-2.50-53", 16)]
+ [TestCase("192-193.168-169.1-2.50-53", 32)]
+ public async Task IpParseCount_ReturnsCorrectValue(string address, int count)
+ => Assert.AreEqual(count, (await IPParser.Parse(address)).Count());
+ }
+}