Skip to content

Commit

Permalink
1846 server mqttserverstopasync doesnt indicate correct reason (#1872)
Browse files Browse the repository at this point in the history
  • Loading branch information
chkr1011 authored Nov 25, 2023
1 parent b431a5d commit ae84aa4
Show file tree
Hide file tree
Showing 25 changed files with 521 additions and 147 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/ReleaseNotes.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
* [Server] Fixed not working _UpdateRetainedMessageAsync_ public api (#1858, thanks to @kimdiego2098).
* [Client] Added support for custom CA chain validation (#1851, thanks to @rido-min).
* [Client] Added support for custom CA chain validation (#1851, thanks to @rido-min).
* [Client] Fixed handling of unobserved tasks exceptions (#1871).
* [Client] Fixed not specified ReasonCode when using _SendExtendedAuthenticationExchangeDataAsync_ (#1882, thanks to @rido-min).
* [Client] Fixed not specified ReasonCode when using _SendExtendedAuthenticationExchangeDataAsync_ (#1882, thanks to @rido-min).
* [Server] Fixed not working _UpdateRetainedMessageAsync_ public api (#1858, thanks to @kimdiego2098).
* [Server] Added support for custom DISCONNECT packets when stopping the server or disconnect a client (BREAKING CHANGE!, #1846).
* [Server] Added new property to stop the server from accepting new connections even if it is running (#1846).
20 changes: 14 additions & 6 deletions Source/MQTTnet.AspnetCore/MqttHostedServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -14,20 +15,27 @@ namespace MQTTnet.AspNetCore
{
public sealed class MqttHostedServer : MqttServer, IHostedService
{
public MqttHostedServer(MqttServerOptions options, IEnumerable<IMqttServerAdapter> adapters, IMqttNetLogger logger)
: base(options, adapters, logger)
readonly MqttFactory _mqttFactory;

public MqttHostedServer(MqttFactory mqttFactory, MqttServerOptions options, IEnumerable<IMqttServerAdapter> adapters, IMqttNetLogger logger) : base(
options,
adapters,
logger)
{
_mqttFactory = mqttFactory ?? throw new ArgumentNullException(nameof(mqttFactory));
}

public Task StartAsync(CancellationToken cancellationToken)
public async Task StartAsync(CancellationToken cancellationToken)
{
_ = StartAsync();
return Task.CompletedTask;
// The yield makes sure that the hosted service is considered up and running.
await Task.Yield();

_ = StartAsync();
}

public Task StopAsync(CancellationToken cancellationToken)
{
return StopAsync();
return StopAsync(_mqttFactory.CreateMqttServerStopOptionsBuilder().Build());
}
}
}
110 changes: 68 additions & 42 deletions Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.DependencyInjection.Extensions;
using MQTTnet.Adapter;
using MQTTnet.Diagnostics;
using MQTTnet.Implementations;
Expand All @@ -14,76 +14,102 @@ namespace MQTTnet.AspNetCore
{
public static class ServiceCollectionExtensions
{
public static IServiceCollection AddMqttServer(this IServiceCollection serviceCollection, Action<MqttServerOptionsBuilder> configure = null)
public static IServiceCollection AddHostedMqttServer(this IServiceCollection services, MqttServerOptions options)
{
if (serviceCollection is null)
if (services == null)
{
throw new ArgumentNullException(nameof(serviceCollection));
throw new ArgumentNullException(nameof(services));
}

serviceCollection.AddMqttConnectionHandler();
serviceCollection.AddHostedMqttServer(configure);

return serviceCollection;
}

public static IServiceCollection AddHostedMqttServer(this IServiceCollection services, MqttServerOptions options)
{
if (options == null) throw new ArgumentNullException(nameof(options));
if (options == null)
{
throw new ArgumentNullException(nameof(options));
}

services.AddSingleton(options);

services.AddHostedMqttServer();

return services;
}

public static IServiceCollection AddHostedMqttServer(this IServiceCollection services, Action<MqttServerOptionsBuilder> configure = null)
public static IServiceCollection AddHostedMqttServer(this IServiceCollection services, Action<MqttServerOptionsBuilder> configure)
{
services.AddSingleton(s =>
if (services == null)
{
var serverOptionsBuilder = new MqttServerOptionsBuilder();
configure?.Invoke(serverOptionsBuilder);
return serverOptionsBuilder.Build();
});
throw new ArgumentNullException(nameof(services));
}

services.AddHostedMqttServer();
if (configure == null)
{
throw new ArgumentNullException(nameof(configure));
}

return services;
var serverOptionsBuilder = new MqttServerOptionsBuilder();
configure.Invoke(serverOptionsBuilder);
var options = serverOptionsBuilder.Build();

return AddHostedMqttServer(services, options);
}

public static void AddHostedMqttServer(this IServiceCollection services)
{
// The user may have these services already registered.
services.TryAddSingleton<IMqttNetLogger>(MqttNetNullLogger.Instance);
services.TryAddSingleton(new MqttFactory());

services.AddSingleton<MqttHostedServer>();
services.AddHostedService<MqttHostedServer>();
}

public static IServiceCollection AddHostedMqttServerWithServices(this IServiceCollection services, Action<AspNetMqttServerOptionsBuilder> configure)
{
services.AddSingleton(s =>
if (services == null)
{
var builder = new AspNetMqttServerOptionsBuilder(s);
configure(builder);
return builder.Build();
});
throw new ArgumentNullException(nameof(services));
}

services.AddSingleton(
s =>
{
var builder = new AspNetMqttServerOptionsBuilder(s);
configure(builder);
return builder.Build();
});

services.AddHostedMqttServer();

return services;
}

static IServiceCollection AddHostedMqttServer(this IServiceCollection services)
public static IServiceCollection AddMqttConnectionHandler(this IServiceCollection services)
{
var logger = new MqttNetEventLogger();

services.AddSingleton<IMqttNetLogger>(logger);
services.AddSingleton<MqttHostedServer>();
services.AddSingleton<IHostedService>(s => s.GetService<MqttHostedServer>());
services.AddSingleton<MqttServer>(s => s.GetService<MqttHostedServer>());
services.AddSingleton<MqttConnectionHandler>();
services.AddSingleton<IMqttServerAdapter>(s => s.GetService<MqttConnectionHandler>());

return services;
}

public static IServiceCollection AddMqttWebSocketServerAdapter(this IServiceCollection services)
public static void AddMqttLogger(this IServiceCollection services, IMqttNetLogger logger)
{
services.AddSingleton<MqttWebSocketServerAdapter>();
services.AddSingleton<IMqttServerAdapter>(s => s.GetService<MqttWebSocketServerAdapter>());
if (services == null)
{
throw new ArgumentNullException(nameof(services));
}

return services;
services.AddSingleton(logger);
}

public static IServiceCollection AddMqttServer(this IServiceCollection serviceCollection, Action<MqttServerOptionsBuilder> configure = null)
{
if (serviceCollection is null)
{
throw new ArgumentNullException(nameof(serviceCollection));
}

serviceCollection.AddMqttConnectionHandler();
serviceCollection.AddHostedMqttServer(configure);

return serviceCollection;
}

public static IServiceCollection AddMqttTcpServerAdapter(this IServiceCollection services)
Expand All @@ -94,12 +120,12 @@ public static IServiceCollection AddMqttTcpServerAdapter(this IServiceCollection
return services;
}

public static IServiceCollection AddMqttConnectionHandler(this IServiceCollection services)
public static IServiceCollection AddMqttWebSocketServerAdapter(this IServiceCollection services)
{
services.AddSingleton<MqttConnectionHandler>();
services.AddSingleton<IMqttServerAdapter>(s => s.GetService<MqttConnectionHandler>());
services.AddSingleton<MqttWebSocketServerAdapter>();
services.AddSingleton<IMqttServerAdapter>(s => s.GetService<MqttWebSocketServerAdapter>());

return services;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using MQTTnet.LowLevelClient;
using MQTTnet.Packets;
using MQTTnet.Protocol;
using MQTTnet.Server;

namespace MQTTnet.Tests.Clients.LowLevelMqttClient
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Protocol;
using MQTTnet.Server;
using MQTTnet.Tests.Mockups;

// ReSharper disable InconsistentNaming
Expand Down
1 change: 1 addition & 0 deletions Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Threading.Tasks;
using MQTTnet.Internal;
using MQTTnet.Protocol;
using MQTTnet.Server;

namespace MQTTnet.Tests.MQTTv5
{
Expand Down
8 changes: 4 additions & 4 deletions Source/MQTTnet.Tests/Server/General.cs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ public async Task Intercept_Message()
var server = await testEnvironment.StartServer();
server.InterceptingPublishAsync += e =>
{
e.ApplicationMessage.Payload = Encoding.ASCII.GetBytes("extended");
e.ApplicationMessage.PayloadSegment = new ArraySegment<byte>(Encoding.ASCII.GetBytes("extended"));
return CompletedTask.Instance;
};

Expand All @@ -314,7 +314,7 @@ public async Task Intercept_Message()
var isIntercepted = false;
c2.ApplicationMessageReceivedAsync += e =>
{
isIntercepted = string.Compare("extended", Encoding.UTF8.GetString(e.ApplicationMessage.Payload), StringComparison.Ordinal) == 0;
isIntercepted = string.Compare("extended", Encoding.UTF8.GetString(e.ApplicationMessage.PayloadSegment.ToArray()), StringComparison.Ordinal) == 0;
return CompletedTask.Instance;
};

Expand Down Expand Up @@ -425,7 +425,7 @@ await server.InjectApplicationMessage(
new MqttApplicationMessage
{
Topic = "/test/1",
Payload = Encoding.UTF8.GetBytes("true"),
PayloadSegment = new ArraySegment<byte>(Encoding.UTF8.GetBytes("true")),
QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce
})
{
Expand Down Expand Up @@ -780,7 +780,7 @@ public async Task Send_Long_Body()
var client1 = await testEnvironment.ConnectClient();
client1.ApplicationMessageReceivedAsync += e =>
{
receivedBody = e.ApplicationMessage.Payload;
receivedBody = e.ApplicationMessage.PayloadSegment.ToArray();
return CompletedTask.Instance;
};

Expand Down
1 change: 1 addition & 0 deletions Source/MQTTnet.Tests/Server/Publishing_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using MQTTnet.Formatter;
using MQTTnet.Internal;
using MQTTnet.Protocol;
using MQTTnet.Server;

namespace MQTTnet.Tests.Server
{
Expand Down
2 changes: 1 addition & 1 deletion Source/MQTTnet.Tests/Server/Session_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ public async Task Set_Session_Item()

server.InterceptingPublishAsync += e =>
{
e.ApplicationMessage.Payload = Encoding.UTF8.GetBytes(e.SessionItems["default_payload"] as string ?? string.Empty);
e.ApplicationMessage.PayloadSegment = new ArraySegment<byte>(Encoding.UTF8.GetBytes(e.SessionItems["default_payload"] as string ?? string.Empty));
return CompletedTask.Instance;
};

Expand Down
20 changes: 9 additions & 11 deletions Source/MQTTnet.Tests/Server/Tls_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ await firstClient.PublishAsync(
new MqttApplicationMessage
{
Topic = "TestTopic1",
Payload = new byte[] { 1, 2, 3, 4 }
PayloadSegment = new ArraySegment<byte>(new byte[] { 1, 2, 3, 4 })
});

await testEnvironment.Server.InjectApplicationMessage(
new InjectedMqttApplicationMessage(
new MqttApplicationMessage
{
Topic = "TestTopic1",
Payload = new byte[] { 1, 2, 3, 4 }
PayloadSegment = new ArraySegment<byte>(new byte[] { 1, 2, 3, 4 })
}));

certificateProvider.CurrentCertificate = CreateCertificate(secondOid);
Expand Down Expand Up @@ -137,31 +137,31 @@ await firstClient.PublishAsync(
new MqttApplicationMessage
{
Topic = "TestTopic2",
Payload = new byte[] { 1, 2, 3, 4 }
PayloadSegment = new ArraySegment<byte>(new byte[] { 1, 2, 3, 4 })
});

await testEnvironment.Server.InjectApplicationMessage(
new InjectedMqttApplicationMessage(
new MqttApplicationMessage
{
Topic = "TestTopic2",
Payload = new byte[] { 1, 2, 3, 4 }
PayloadSegment = new ArraySegment<byte>(new byte[] { 1, 2, 3, 4 })
}));

// Ensure first client still works
await firstClient.PublishAsync(
new MqttApplicationMessage
{
Topic = "TestTopic1",
Payload = new byte[] { 1, 2, 3, 4 }
PayloadSegment = new ArraySegment<byte>(new byte[] { 1, 2, 3, 4 })
});

await testEnvironment.Server.InjectApplicationMessage(
new InjectedMqttApplicationMessage(
new MqttApplicationMessage
{
Topic = "TestTopic1",
Payload = new byte[] { 1, 2, 3, 4 }
PayloadSegment = new ArraySegment<byte>(new byte[] { 1, 2, 3, 4 })
}));

await Task.Delay(1000);
Expand All @@ -178,12 +178,10 @@ static async Task<IMqttClient> ConnectClientAsync(TestEnvironment testEnvironmen
var clientOptionsBuilder = testEnvironment.Factory.CreateClientOptionsBuilder();
clientOptionsBuilder.WithClientId(Guid.NewGuid().ToString())
.WithTcpServer("localhost", 8883)
.WithTls(
tls =>
.WithTlsOptions(
o =>
{
tls.UseTls = true;
tls.SslProtocol = SslProtocols.Tls12;
tls.CertificateValidationHandler = certValidator;
o.WithSslProtocols(SslProtocols.Tls12).WithCertificateValidationHandler(certValidator);
});

var clientOptions = clientOptionsBuilder.Build();
Expand Down
Loading

0 comments on commit ae84aa4

Please sign in to comment.