diff --git a/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs b/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs index 53712c42..5dae21d0 100644 --- a/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs +++ b/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs @@ -1,456 +1,481 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using MediatR; -using MediatR.NotificationPublishers; -using MediatR.Pipeline; -using MediatR.Registration; - -namespace Microsoft.Extensions.DependencyInjection; - -public class MediatRServiceConfiguration -{ - /// - /// Optional filter for types to register. Default value is a function returning true. - /// - public Func TypeEvaluator { get; set; } = t => true; - - /// - /// Mediator implementation type to register. Default is - /// - public Type MediatorImplementationType { get; set; } = typeof(Mediator); - - /// - /// Strategy for publishing notifications. Defaults to - /// - public INotificationPublisher NotificationPublisher { get; set; } = new ForeachAwaitPublisher(); - - /// - /// Type of notification publisher strategy to register. If set, overrides - /// - public Type? NotificationPublisherType { get; set; } - - /// - /// Service lifetime to register services under. Default value is - /// - public ServiceLifetime Lifetime { get; set; } = ServiceLifetime.Transient; - - /// - /// Request exception action processor strategy. Default value is - /// - public RequestExceptionActionProcessorStrategy RequestExceptionActionProcessorStrategy { get; set; } - = RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions; - - internal List AssembliesToRegister { get; } = new(); - - /// - /// List of behaviors to register in specific order - /// - public List BehaviorsToRegister { get; } = new(); - - /// - /// List of stream behaviors to register in specific order - /// - public List StreamBehaviorsToRegister { get; } = new(); - - /// - /// List of request pre processors to register in specific order - /// - public List RequestPreProcessorsToRegister { get; } = new(); - - /// - /// List of request post processors to register in specific order - /// - public List RequestPostProcessorsToRegister { get; } = new(); - - /// - /// Automatically register processors during assembly scanning - /// - public bool AutoRegisterRequestProcessors { get; set; } - - /// - /// Register various handlers from assembly containing given type - /// - /// Type from assembly to scan - /// This - public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining() - => RegisterServicesFromAssemblyContaining(typeof(T)); - - /// - /// Register various handlers from assembly containing given type - /// - /// Type from assembly to scan - /// This - public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining(Type type) - => RegisterServicesFromAssembly(type.Assembly); - - /// - /// Register various handlers from assembly - /// - /// Assembly to scan - /// This - public MediatRServiceConfiguration RegisterServicesFromAssembly(Assembly assembly) - { - AssembliesToRegister.Add(assembly); - - return this; - } - - /// - /// Register various handlers from assemblies - /// - /// Assemblies to scan - /// This - public MediatRServiceConfiguration RegisterServicesFromAssemblies( - params Assembly[] assemblies) - { - AssembliesToRegister.AddRange(assemblies); - - return this; - } - - /// - /// Register a closed behavior type - /// - /// Closed behavior interface type - /// Closed behavior implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - => AddBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); - - /// - /// Register a closed behavior type against all implementations - /// - /// Closed behavior implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - return AddBehavior(typeof(TImplementationType), serviceLifetime); - } - - /// - /// Register a closed behavior type against all implementations - /// - /// Closed behavior implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IPipelineBehavior<,>)).ToList(); - - if (implementedGenericInterfaces.Count == 0) - { - throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}"); - } - - foreach (var implementedBehaviorType in implementedGenericInterfaces) - { - BehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime)); - } - - return this; - } - - /// - /// Register a closed behavior type - /// - /// Closed behavior interface type - /// Closed behavior implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddBehavior(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - BehaviorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime)); - - return this; - } - - /// - /// Registers an open behavior type against the open generic interface type - /// - /// An open generic behavior type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddOpenBehavior(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - if (!openBehaviorType.IsGenericType) - { - throw new InvalidOperationException($"{openBehaviorType.Name} must be generic"); - } - - var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); - var implementedOpenBehaviorInterfaces = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IPipelineBehavior<,>))); - - if (implementedOpenBehaviorInterfaces.Count == 0) - { - throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}"); - } - - foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces) - { - BehaviorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime)); - } - - return this; - } - - /// - /// Register a closed stream behavior type - /// - /// Closed stream behavior interface type - /// Closed stream behavior implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddStreamBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - => AddStreamBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); - - /// - /// Register a closed stream behavior type - /// - /// Closed stream behavior interface type - /// Closed stream behavior implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddStreamBehavior(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - StreamBehaviorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime)); - - return this; - } - - /// - /// Register a closed stream behavior type against all implementations - /// - /// Closed stream behavior implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddStreamBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - => AddStreamBehavior(typeof(TImplementationType), serviceLifetime); - - /// - /// Register a closed stream behavior type against all implementations - /// - /// Closed stream behavior implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddStreamBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IStreamPipelineBehavior<,>)).ToList(); - - if (implementedGenericInterfaces.Count == 0) - { - throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}"); - } - - foreach (var implementedBehaviorType in implementedGenericInterfaces) - { - StreamBehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime)); - } - - return this; - } - - /// - /// Registers an open stream behavior type against the open generic interface type - /// - /// An open generic stream behavior type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddOpenStreamBehavior(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - if (!openBehaviorType.IsGenericType) - { - throw new InvalidOperationException($"{openBehaviorType.Name} must be generic"); - } - - var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); - var implementedOpenBehaviorInterfaces = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IStreamPipelineBehavior<,>))); - - if (implementedOpenBehaviorInterfaces.Count == 0) - { - throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}"); - } - - foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces) - { - StreamBehaviorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime)); - } - - return this; - } - - /// - /// Register a closed request pre processor type - /// - /// Closed request pre processor interface type - /// Closed request pre processor implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddRequestPreProcessor(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - => AddRequestPreProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); - - /// - /// Register a closed request pre processor type - /// - /// Closed request pre processor interface type - /// Closed request pre processor implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddRequestPreProcessor(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - RequestPreProcessorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime)); - - return this; - } - - /// - /// Register a closed request pre processor type against all implementations - /// - /// Closed request pre processor implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddRequestPreProcessor( - ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - => AddRequestPreProcessor(typeof(TImplementationType), serviceLifetime); - - /// - /// Register a closed request pre processor type against all implementations - /// - /// Closed request pre processor implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddRequestPreProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPreProcessor<>)).ToList(); - - if (implementedGenericInterfaces.Count == 0) - { - throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}"); - } - - foreach (var implementedPreProcessorType in implementedGenericInterfaces) - { - RequestPreProcessorsToRegister.Add(new ServiceDescriptor(implementedPreProcessorType, implementationType, serviceLifetime)); - } - - return this; - } - - /// - /// Registers an open request pre processor type against the open generic interface type - /// - /// An open generic request pre processor type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddOpenRequestPreProcessor(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - if (!openBehaviorType.IsGenericType) - { - throw new InvalidOperationException($"{openBehaviorType.Name} must be generic"); - } - - var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); - var implementedOpenBehaviorInterfaces = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>))); - - if (implementedOpenBehaviorInterfaces.Count == 0) - { - throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}"); - } - - foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces) - { - RequestPreProcessorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime)); - } - - return this; - } - - /// - /// Register a closed request post processor type - /// - /// Closed request post processor interface type - /// Closed request post processor implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddRequestPostProcessor(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - => AddRequestPostProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); - - /// - /// Register a closed request post processor type - /// - /// Closed request post processor interface type - /// Closed request post processor implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddRequestPostProcessor(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - RequestPostProcessorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime)); - - return this; - } - - /// - /// Register a closed request post processor type against all implementations - /// - /// Closed request post processor implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddRequestPostProcessor(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - => AddRequestPostProcessor(typeof(TImplementationType), serviceLifetime); - - /// - /// Register a closed request post processor type against all implementations - /// - /// Closed request post processor implementation type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddRequestPostProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPostProcessor<,>)).ToList(); - - if (implementedGenericInterfaces.Count == 0) - { - throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}"); - } - - foreach (var implementedPostProcessorType in implementedGenericInterfaces) - { - RequestPostProcessorsToRegister.Add(new ServiceDescriptor(implementedPostProcessorType, implementationType, serviceLifetime)); - } - return this; - } - - /// - /// Registers an open request post processor type against the open generic interface type - /// - /// An open generic request post processor type - /// Optional service lifetime, defaults to . - /// This - public MediatRServiceConfiguration AddOpenRequestPostProcessor(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) - { - if (!openBehaviorType.IsGenericType) - { - throw new InvalidOperationException($"{openBehaviorType.Name} must be generic"); - } - - var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); - var implementedOpenBehaviorInterfaces = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>))); - - if (implementedOpenBehaviorInterfaces.Count == 0) - { - throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}"); - } - - foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces) - { - RequestPostProcessorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime)); - } - - return this; - } - - +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using MediatR; +using MediatR.NotificationPublishers; +using MediatR.Pipeline; +using MediatR.Registration; + +namespace Microsoft.Extensions.DependencyInjection; + +public class MediatRServiceConfiguration +{ + /// + /// Optional filter for types to register. Default value is a function returning true. + /// + public Func TypeEvaluator { get; set; } = t => true; + + /// + /// Mediator implementation type to register. Default is + /// + public Type MediatorImplementationType { get; set; } = typeof(Mediator); + + /// + /// Strategy for publishing notifications. Defaults to + /// + public INotificationPublisher NotificationPublisher { get; set; } = new ForeachAwaitPublisher(); + + /// + /// Type of notification publisher strategy to register. If set, overrides + /// + public Type? NotificationPublisherType { get; set; } + + /// + /// Service lifetime to register services under. Default value is + /// + public ServiceLifetime Lifetime { get; set; } = ServiceLifetime.Transient; + + /// + /// Request exception action processor strategy. Default value is + /// + public RequestExceptionActionProcessorStrategy RequestExceptionActionProcessorStrategy { get; set; } + = RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions; + + internal List AssembliesToRegister { get; } = new(); + + /// + /// List of behaviors to register in specific order + /// + public List BehaviorsToRegister { get; } = new(); + + /// + /// List of stream behaviors to register in specific order + /// + public List StreamBehaviorsToRegister { get; } = new(); + + /// + /// List of request pre processors to register in specific order + /// + public List RequestPreProcessorsToRegister { get; } = new(); + + /// + /// List of request post processors to register in specific order + /// + public List RequestPostProcessorsToRegister { get; } = new(); + + /// + /// Automatically register processors during assembly scanning + /// + public bool AutoRegisterRequestProcessors { get; set; } + + /// + /// Configure the maximum number of type parameters that a generic request handler can have. To Disable this constraint, set the value to 0. + /// + public int MaxGenericTypeParameters { get; set; } = 10; + + /// + /// Configure the maximum number of types that can close a generic request type parameter constraint. To Disable this constraint, set the value to 0. + /// + public int MaxTypesClosing { get; set; } = 100; + + /// + /// Configure the Maximum Amount of Generic RequestHandler Types MediatR will try to register. To Disable this constraint, set the value to 0. + /// + public int MaxGenericTypeRegistrations { get; set; } = 125000; + + /// + /// Configure the Timeout in Milliseconds that the GenericHandler Registration Process will exit with error. To Disable this constraint, set the value to 0. + /// + public int RegistrationTimeout { get; set; } = 15000; + + /// + /// Flag that controlls whether MediatR will attempt to register handlers that containg generic type parameters. + /// + public bool RegisterGenericHandlers { get; set; } = true; + + /// + /// Register various handlers from assembly containing given type + /// + /// Type from assembly to scan + /// This + public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining() + => RegisterServicesFromAssemblyContaining(typeof(T)); + + /// + /// Register various handlers from assembly containing given type + /// + /// Type from assembly to scan + /// This + public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining(Type type) + => RegisterServicesFromAssembly(type.Assembly); + + /// + /// Register various handlers from assembly + /// + /// Assembly to scan + /// This + public MediatRServiceConfiguration RegisterServicesFromAssembly(Assembly assembly) + { + AssembliesToRegister.Add(assembly); + + return this; + } + + /// + /// Register various handlers from assemblies + /// + /// Assemblies to scan + /// This + public MediatRServiceConfiguration RegisterServicesFromAssemblies( + params Assembly[] assemblies) + { + AssembliesToRegister.AddRange(assemblies); + + return this; + } + + /// + /// Register a closed behavior type + /// + /// Closed behavior interface type + /// Closed behavior implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + => AddBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); + + /// + /// Register a closed behavior type against all implementations + /// + /// Closed behavior implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + return AddBehavior(typeof(TImplementationType), serviceLifetime); + } + + /// + /// Register a closed behavior type against all implementations + /// + /// Closed behavior implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IPipelineBehavior<,>)).ToList(); + + if (implementedGenericInterfaces.Count == 0) + { + throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}"); + } + + foreach (var implementedBehaviorType in implementedGenericInterfaces) + { + BehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime)); + } + + return this; + } + + /// + /// Register a closed behavior type + /// + /// Closed behavior interface type + /// Closed behavior implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddBehavior(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + BehaviorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime)); + + return this; + } + + /// + /// Registers an open behavior type against the open generic interface type + /// + /// An open generic behavior type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddOpenBehavior(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + if (!openBehaviorType.IsGenericType) + { + throw new InvalidOperationException($"{openBehaviorType.Name} must be generic"); + } + + var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); + var implementedOpenBehaviorInterfaces = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IPipelineBehavior<,>))); + + if (implementedOpenBehaviorInterfaces.Count == 0) + { + throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}"); + } + + foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces) + { + BehaviorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime)); + } + + return this; + } + + /// + /// Register a closed stream behavior type + /// + /// Closed stream behavior interface type + /// Closed stream behavior implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddStreamBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + => AddStreamBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); + + /// + /// Register a closed stream behavior type + /// + /// Closed stream behavior interface type + /// Closed stream behavior implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddStreamBehavior(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + StreamBehaviorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime)); + + return this; + } + + /// + /// Register a closed stream behavior type against all implementations + /// + /// Closed stream behavior implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddStreamBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + => AddStreamBehavior(typeof(TImplementationType), serviceLifetime); + + /// + /// Register a closed stream behavior type against all implementations + /// + /// Closed stream behavior implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddStreamBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IStreamPipelineBehavior<,>)).ToList(); + + if (implementedGenericInterfaces.Count == 0) + { + throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}"); + } + + foreach (var implementedBehaviorType in implementedGenericInterfaces) + { + StreamBehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime)); + } + + return this; + } + + /// + /// Registers an open stream behavior type against the open generic interface type + /// + /// An open generic stream behavior type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddOpenStreamBehavior(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + if (!openBehaviorType.IsGenericType) + { + throw new InvalidOperationException($"{openBehaviorType.Name} must be generic"); + } + + var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); + var implementedOpenBehaviorInterfaces = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IStreamPipelineBehavior<,>))); + + if (implementedOpenBehaviorInterfaces.Count == 0) + { + throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}"); + } + + foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces) + { + StreamBehaviorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime)); + } + + return this; + } + + /// + /// Register a closed request pre processor type + /// + /// Closed request pre processor interface type + /// Closed request pre processor implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddRequestPreProcessor(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + => AddRequestPreProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); + + /// + /// Register a closed request pre processor type + /// + /// Closed request pre processor interface type + /// Closed request pre processor implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddRequestPreProcessor(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + RequestPreProcessorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime)); + + return this; + } + + /// + /// Register a closed request pre processor type against all implementations + /// + /// Closed request pre processor implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddRequestPreProcessor( + ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + => AddRequestPreProcessor(typeof(TImplementationType), serviceLifetime); + + /// + /// Register a closed request pre processor type against all implementations + /// + /// Closed request pre processor implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddRequestPreProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPreProcessor<>)).ToList(); + + if (implementedGenericInterfaces.Count == 0) + { + throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}"); + } + + foreach (var implementedPreProcessorType in implementedGenericInterfaces) + { + RequestPreProcessorsToRegister.Add(new ServiceDescriptor(implementedPreProcessorType, implementationType, serviceLifetime)); + } + + return this; + } + + /// + /// Registers an open request pre processor type against the open generic interface type + /// + /// An open generic request pre processor type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddOpenRequestPreProcessor(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + if (!openBehaviorType.IsGenericType) + { + throw new InvalidOperationException($"{openBehaviorType.Name} must be generic"); + } + + var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); + var implementedOpenBehaviorInterfaces = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>))); + + if (implementedOpenBehaviorInterfaces.Count == 0) + { + throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}"); + } + + foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces) + { + RequestPreProcessorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime)); + } + + return this; + } + + /// + /// Register a closed request post processor type + /// + /// Closed request post processor interface type + /// Closed request post processor implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddRequestPostProcessor(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + => AddRequestPostProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); + + /// + /// Register a closed request post processor type + /// + /// Closed request post processor interface type + /// Closed request post processor implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddRequestPostProcessor(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + RequestPostProcessorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime)); + + return this; + } + + /// + /// Register a closed request post processor type against all implementations + /// + /// Closed request post processor implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddRequestPostProcessor(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + => AddRequestPostProcessor(typeof(TImplementationType), serviceLifetime); + + /// + /// Register a closed request post processor type against all implementations + /// + /// Closed request post processor implementation type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddRequestPostProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPostProcessor<,>)).ToList(); + + if (implementedGenericInterfaces.Count == 0) + { + throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}"); + } + + foreach (var implementedPostProcessorType in implementedGenericInterfaces) + { + RequestPostProcessorsToRegister.Add(new ServiceDescriptor(implementedPostProcessorType, implementationType, serviceLifetime)); + } + return this; + } + + /// + /// Registers an open request post processor type against the open generic interface type + /// + /// An open generic request post processor type + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddOpenRequestPostProcessor(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + if (!openBehaviorType.IsGenericType) + { + throw new InvalidOperationException($"{openBehaviorType.Name} must be generic"); + } + + var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); + var implementedOpenBehaviorInterfaces = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>))); + + if (implementedOpenBehaviorInterfaces.Count == 0) + { + throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}"); + } + + foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces) + { + RequestPostProcessorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime)); + } + + return this; + } + + } \ No newline at end of file diff --git a/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs b/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs index 50e9787f..6e211b27 100644 --- a/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs +++ b/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs @@ -47,7 +47,9 @@ public static IServiceCollection AddMediatR(this IServiceCollection services, throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers."); } - ServiceRegistrar.AddMediatRClasses(services, configuration); + ServiceRegistrar.SetGenericRequestHandlerRegistrationLimitations(configuration); + + ServiceRegistrar.AddMediatRClassesWithTimeout(services, configuration); ServiceRegistrar.AddRequiredServices(services, configuration); diff --git a/src/MediatR/Registration/ServiceRegistrar.cs b/src/MediatR/Registration/ServiceRegistrar.cs index 8fd5bf96..c4cdd008 100644 --- a/src/MediatR/Registration/ServiceRegistrar.cs +++ b/src/MediatR/Registration/ServiceRegistrar.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using System.Threading; using MediatR.Pipeline; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -9,13 +10,42 @@ namespace MediatR.Registration; public static class ServiceRegistrar -{ - public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration) - { +{ + private static int MaxGenericTypeParameters; + private static int MaxTypesClosing; + private static int MaxGenericTypeRegistrations; + private static int RegistrationTimeout; + + public static void SetGenericRequestHandlerRegistrationLimitations(MediatRServiceConfiguration configuration) + { + MaxGenericTypeParameters = configuration.MaxGenericTypeParameters; + MaxTypesClosing = configuration.MaxTypesClosing; + MaxGenericTypeRegistrations = configuration.MaxGenericTypeRegistrations; + RegistrationTimeout = configuration.RegistrationTimeout; + } + + public static void AddMediatRClassesWithTimeout(IServiceCollection services, MediatRServiceConfiguration configuration) + { + using(var cts = new CancellationTokenSource(RegistrationTimeout)) + { + try + { + AddMediatRClasses(services, configuration, cts.Token); + } + catch (OperationCanceledException) + { + throw new TimeoutException("The generic handler registration process timed out."); + } + } + } + + public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration, CancellationToken cancellationToken = default) + { + var assembliesToScan = configuration.AssembliesToRegister.Distinct().ToArray(); - ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration); - ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration); + ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration, cancellationToken); + ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration, cancellationToken); ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration); ConnectImplementationsToTypesClosing(typeof(IStreamRequestHandler<,>), services, assembliesToScan, false, configuration); ConnectImplementationsToTypesClosing(typeof(IRequestExceptionHandler<,,>), services, assembliesToScan, true, configuration); @@ -63,7 +93,8 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa IServiceCollection services, IEnumerable assembliesToScan, bool addIfAlreadyExists, - MediatRServiceConfiguration configuration) + MediatRServiceConfiguration configuration, + CancellationToken cancellationToken = default) { var concretions = new List(); var interfaces = new List(); @@ -72,9 +103,10 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa var types = assembliesToScan .SelectMany(a => a.DefinedTypes) + .Where(t => !t.ContainsGenericParameters || configuration.RegisterGenericHandlers) .Where(t => t.IsConcrete() && t.FindInterfacesThatClose(openRequestInterface).Any()) .Where(configuration.TypeEvaluator) - .ToList(); + .ToList(); foreach (var type in types) { @@ -131,7 +163,7 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa foreach (var @interface in genericInterfaces) { var exactMatches = genericConcretions.Where(x => x.CanBeCastTo(@interface)).ToList(); - AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan); + AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan, cancellationToken); } } @@ -174,7 +206,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(Type openRequestHandlerInterface, Type concreteGenericTRequest, Type openRequestHandlerImplementation) { - var closingType = concreteGenericTRequest.GetGenericArguments().First(); + var closingTypes = concreteGenericTRequest.GetGenericArguments(); var concreteTResponse = concreteGenericTRequest.GetInterfaces() .FirstOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IRequest<>)) @@ -187,17 +219,25 @@ private static (Type Service, Type Implementation) GetConcreteRegistrationTypes( typeDefinition.MakeGenericType(concreteGenericTRequest, concreteTResponse) : typeDefinition.MakeGenericType(concreteGenericTRequest); - return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingType)); + return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingTypes)); } - private static List? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable assembliesToScan) + private static List? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable assembliesToScan, CancellationToken cancellationToken) { - var constraints = openRequestHandlerImplementation.GetGenericArguments().First().GetGenericParameterConstraints(); - - var typesThatCanClose = assembliesToScan - .SelectMany(assembly => assembly.GetTypes()) - .Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type))) - .ToList(); + //request generic type constraints + var constraintsForEachParameter = openRequestHandlerImplementation + .GetGenericArguments() + .Select(x => x.GetGenericParameterConstraints()) + .ToList(); + + if (constraintsForEachParameter.Count > 2 && constraintsForEachParameter.Any(constraints => !constraints.Where(x => x.IsInterface || x.IsClass).Any())) + throw new ArgumentException($"Error registering the generic handler type: {openRequestHandlerImplementation.FullName}. When registering generic requests with more than two type parameters, each type parameter must have at least one constraint of type interface or class."); + + var typesThatCanCloseForEachParameter = constraintsForEachParameter + .Select(constraints => assembliesToScan + .SelectMany(assembly => assembly.GetTypes()) + .Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type))).ToList() + ).ToList(); var requestType = openRequestHandlerInterface.GenericTypeArguments.First(); @@ -205,15 +245,64 @@ private static (Type Service, Type Implementation) GetConcreteRegistrationTypes( return null; var requestGenericTypeDefinition = requestType.GetGenericTypeDefinition(); + + var combinations = GenerateCombinations(requestType, typesThatCanCloseForEachParameter, 0, cancellationToken); + + return combinations.Select(types => requestGenericTypeDefinition.MakeGenericType(types.ToArray())).ToList(); + } + + // Method to generate combinations recursively + public static List> GenerateCombinations(Type requestType, List> lists, int depth = 0, CancellationToken cancellationToken = default) + { + if (depth == 0) + { + // Initial checks + if (MaxGenericTypeParameters > 0 && lists.Count > MaxGenericTypeParameters) + throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The number of generic type parameters exceeds the maximum allowed ({MaxGenericTypeParameters})."); + + foreach (var list in lists) + { + if (MaxTypesClosing > 0 && list.Count > MaxTypesClosing) + throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. One of the generic type parameter's count of types that can close exceeds the maximum length allowed ({MaxTypesClosing})."); + } + + // Calculate the total number of combinations + long totalCombinations = 1; + foreach (var list in lists) + { + totalCombinations *= list.Count; + if (MaxGenericTypeParameters > 0 && totalCombinations > MaxGenericTypeRegistrations) + throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The total number of generic type registrations exceeds the maximum allowed ({MaxGenericTypeRegistrations})."); + } + } + + if (depth >= lists.Count) + return new List> { new List() }; + + cancellationToken.ThrowIfCancellationRequested(); - return typesThatCanClose.Select(type => requestGenericTypeDefinition.MakeGenericType(type)).ToList(); + var currentList = lists[depth]; + var childCombinations = GenerateCombinations(requestType, lists, depth + 1, cancellationToken); + var combinations = new List>(); + + foreach (var item in currentList) + { + foreach (var childCombination in childCombinations) + { + var currentCombination = new List { item }; + currentCombination.AddRange(childCombination); + combinations.Add(currentCombination); + } + } + + return combinations; } - private static void AddAllConcretionsThatClose(Type openRequestInterface, List concretions, IServiceCollection services, IEnumerable assembliesToScan) + private static void AddAllConcretionsThatClose(Type openRequestInterface, List concretions, IServiceCollection services, IEnumerable assembliesToScan, CancellationToken cancellationToken) { foreach (var concretion in concretions) - { - var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan); + { + var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan, cancellationToken); if (concreteRequests is null) continue; @@ -223,6 +312,7 @@ private static void AddAllConcretionsThatClose(Type openRequestInterface, List + { + cfg.RegisterServicesFromAssemblies(dynamicAssembly); + }); + + var provider = services.BuildServiceProvider(); + + var dynamicRequestType = dynamicAssembly.GetType("DynamicRequest")!; + + int expectedCombinations = CalculateTotalCombinations(numberOfClasses, numberOfInterfaces, numberOfTypeParameters); + + var testClasses = Enumerable.Range(1, numberOfClasses) + .Select(i => dynamicAssembly.GetType($"TestClass{i}")!) + .ToArray(); + + var combinations = GenerateCombinations(testClasses, numberOfInterfaces); + + foreach (var combination in combinations) + { + var concreteRequestType = dynamicRequestType.MakeGenericType(combination); + var requestHandlerInterface = typeof(IRequestHandler<>).MakeGenericType(concreteRequestType); + + var handler = provider.GetService(requestHandlerInterface); + handler.ShouldNotBeNull($"Handler for {concreteRequestType} should not be null"); + } + } + + [Theory] + [InlineData(9, 3, 3)] + [InlineData(10, 4, 4)] + [InlineData(1, 1, 1)] + [InlineData(50, 3, 3)] + public void ShouldRegisterTheCorrectAmountOfHandlers(int numberOfClasses, int numberOfInterfaces, int numberOfTypeParameters) + { + var dynamicAssembly = GenerateCombinationsTestAssembly(numberOfClasses, numberOfInterfaces, numberOfTypeParameters); + int expectedCombinations = CalculateTotalCombinations(numberOfClasses, numberOfInterfaces, numberOfTypeParameters); + var testClasses = Enumerable.Range(1, numberOfClasses) + .Select(i => dynamicAssembly.GetType($"TestClass{i}")!) + .ToArray(); + var combinations = GenerateCombinations(testClasses, numberOfInterfaces); + combinations.Count.ShouldBe(expectedCombinations, $"Should have tested all {expectedCombinations} combinations"); + } + + [Theory] + [InlineData(9, 3, 3)] + [InlineData(10, 4, 4)] + [InlineData(1, 1, 1)] + [InlineData(50, 3, 3)] + public void ShouldNotRegisterDuplicateHandlers(int numberOfClasses, int numberOfInterfaces, int numberOfTypeParameters) + { + var dynamicAssembly = GenerateCombinationsTestAssembly(numberOfClasses, numberOfInterfaces, numberOfTypeParameters); + int expectedCombinations = CalculateTotalCombinations(numberOfClasses, numberOfInterfaces, numberOfTypeParameters); + var testClasses = Enumerable.Range(1, numberOfClasses) + .Select(i => dynamicAssembly.GetType($"TestClass{i}")!) + .ToArray(); + var combinations = GenerateCombinations(testClasses, numberOfInterfaces); + var hasDuplicates = combinations + .Select(x => string.Join(", ", x.Select(y => y.Name))) + .GroupBy(x => x) + .Any(g => g.Count() > 1); + + hasDuplicates.ShouldBeFalse(); + } + + [Fact] + public void ShouldThrowExceptionWhenRegisterningHandlersWithNoConstraints() + { + IServiceCollection services = new ServiceCollection(); + services.AddSingleton(new Logger()); + + var assembly = GenerateMissingConstraintsAssembly(); + + Should.Throw(() => + { + services.AddMediatR(cfg => + { + cfg.RegisterServicesFromAssembly(assembly); + }); + }) + .Message.ShouldContain("When registering generic requests with more than two type parameters, each type parameter must have at least one constraint of type interface or class."); + } + + [Fact] + public void ShouldThrowExceptionWhenTypesClosingExceedsMaximum() + { + IServiceCollection services = new ServiceCollection(); + services.AddSingleton(new Logger()); + + var assembly = GenerateTypesClosingExceedsMaximumAssembly(); + + Should.Throw(() => + { + services.AddMediatR(cfg => + { + cfg.RegisterServicesFromAssembly(assembly); + }); + }) + .Message.ShouldContain("One of the generic type parameter's count of types that can close exceeds the maximum length allowed"); + } + + [Fact] + public void ShouldThrowExceptionWhenGenericHandlerRegistrationsExceedsMaximum() + { + IServiceCollection services = new ServiceCollection(); + services.AddSingleton(new Logger()); + + var assembly = GenerateHandlerRegistrationsExceedsMaximumAssembly(); + + Should.Throw(() => + { + services.AddMediatR(cfg => + { + cfg.RegisterServicesFromAssembly(assembly); + }); + }) + .Message.ShouldContain("The total number of generic type registrations exceeds the maximum allowed"); + } + + [Fact] + public void ShouldThrowExceptionWhenGenericTypeParametersExceedsMaximum() + { + IServiceCollection services = new ServiceCollection(); + services.AddSingleton(new Logger()); + + var assembly = GenerateGenericTypeParametersExceedsMaximumAssembly(); + + Should.Throw(() => + { + services.AddMediatR(cfg => + { + cfg.RegisterServicesFromAssembly(assembly); + }); + }) + .Message.ShouldContain("The number of generic type parameters exceeds the maximum allowed"); + } + + [Fact] + public void ShouldThrowExceptionWhenTimeoutOccurs() + { + IServiceCollection services = new ServiceCollection(); + services.AddSingleton(new Logger()); + + var assembly = GenerateTimeoutOccursAssembly(); + + Should.Throw(() => + { + services.AddMediatR(cfg => + { + cfg.MaxGenericTypeParameters = 0; + cfg.MaxGenericTypeRegistrations = 0; + cfg.MaxTypesClosing = 0; + cfg.RegistrationTimeout = 1000; + cfg.RegisterServicesFromAssembly(assembly); + }); + }) + .Message.ShouldBe("The generic handler registration process timed out."); + } + + [Fact] + public void ShouldNotRegisterGenericHandlersWhenOptingOut() + { + IServiceCollection services = new ServiceCollection(); + services.AddSingleton(new Logger()); + + var assembly = GenerateOptOutAssembly(); + services.AddMediatR(cfg => + { + //opt out flag set + cfg.RegisterGenericHandlers = false; + cfg.RegisterServicesFromAssembly(assembly); + }); + + var provider = services.BuildServiceProvider(); + var testClasses = Enumerable.Range(1, 2) + .Select(i => assembly.GetType($"TestClass{i}")!) + .ToArray(); + var requestType = assembly.GetType("OptOutRequest")!; + var combinations = GenerateCombinations(testClasses, 2); + + var concreteRequestType = requestType.MakeGenericType(combinations.First()); + var requestHandlerInterface = typeof(IRequestHandler<>).MakeGenericType(concreteRequestType); + + var handler = provider.GetService(requestHandlerInterface); + handler.ShouldBeNull($"Handler for {concreteRequestType} should be null"); + + + } + } +} diff --git a/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs b/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs index dbbcaefc..60575620 100644 --- a/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs +++ b/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs @@ -3,12 +3,10 @@ namespace MediatR.Extensions.Microsoft.DependencyInjection.Tests; using System; -using System.Collections.Generic; using System.Linq; -using System.Reflection; using Shouldly; -using Xunit; - +using Xunit; + public class AssemblyResolutionTests { private readonly IServiceProvider _provider; diff --git a/test/MediatR.Tests/MicrosoftExtensionsDI/BaseGenericRequestHandlerTests.cs b/test/MediatR.Tests/MicrosoftExtensionsDI/BaseGenericRequestHandlerTests.cs new file mode 100644 index 00000000..b87fc02b --- /dev/null +++ b/test/MediatR.Tests/MicrosoftExtensionsDI/BaseGenericRequestHandlerTests.cs @@ -0,0 +1,211 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection.Emit; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; + +namespace MediatR.Tests.MicrosoftExtensionsDI +{ + public abstract class BaseGenericRequestHandlerTests + { + + protected static Assembly GenerateMissingConstraintsAssembly() => + CreateAssemblyModuleBuilder("MissingConstraintsAssembly", 3, 3, CreateHandlerForMissingConstraintsTest); + + protected static Assembly GenerateTypesClosingExceedsMaximumAssembly() => + CreateAssemblyModuleBuilder("ExceedsMaximumTypesClosingAssembly", 201, 1, CreateHandlerForExceedsMaximumClassesTest); + + protected static Assembly GenerateHandlerRegistrationsExceedsMaximumAssembly() => + CreateAssemblyModuleBuilder("ExceedsMaximumHandlerRegistrationsAssembly", 500, 10, CreateHandlerForExceedsMaximumHandlerRegistrationsTest); + + protected static Assembly GenerateGenericTypeParametersExceedsMaximumAssembly() => + CreateAssemblyModuleBuilder("ExceedsMaximumGenericTypeParametersAssembly", 1, 1, CreateHandlerForExceedsMaximumGenericTypeParametersTest); + + protected static Assembly GenerateTimeoutOccursAssembly() => + CreateAssemblyModuleBuilder("TimeOutOccursAssembly", 400, 3, CreateHandlerForTimeoutOccursTest); + + protected static Assembly GenerateOptOutAssembly() => + CreateAssemblyModuleBuilder("OptOutAssembly", 2, 2, CreateHandlerForOptOutTest); + + protected static void CreateHandlerForOptOutTest(ModuleBuilder moduleBuilder) => + CreateRequestHandler(moduleBuilder, "OptOutRequest", 2); + + protected static void CreateHandlerForMissingConstraintsTest(ModuleBuilder moduleBuilder) => + CreateRequestHandler(moduleBuilder, "MissingConstraintsRequest", 3, 0, false); + + protected static void CreateHandlerForExceedsMaximumClassesTest(ModuleBuilder moduleBuilder) => + CreateRequestHandler(moduleBuilder, "ExceedsMaximumTypesClosingRequest", 1); + + protected static void CreateHandlerForExceedsMaximumHandlerRegistrationsTest(ModuleBuilder moduleBuilder) => + CreateRequestHandler(moduleBuilder, "ExceedsMaximumHandlerRegistrationsRequest", 4); + + protected static void CreateHandlerForExceedsMaximumGenericTypeParametersTest(ModuleBuilder moduleBuilder) => + CreateRequestHandler(moduleBuilder, "ExceedsMaximumGenericTypeParametersRequest", 11, 1); + + protected static void CreateHandlerForTimeoutOccursTest(ModuleBuilder moduleBuilder) => + CreateRequestHandler(moduleBuilder, "TimeoutOccursRequest", 3); + + protected static void CreateHandlerForCombinationsTest(ModuleBuilder moduleBuilder, int numberOfGenericParameters) => + CreateRequestHandler(moduleBuilder, "DynamicRequest", numberOfGenericParameters); + + protected static void CreateClass(ModuleBuilder moduleBuilder, string className, Type interfaceType) + { + TypeBuilder typeBuilder = moduleBuilder.DefineType(className, TypeAttributes.Public); + typeBuilder.AddInterfaceImplementation(interfaceType); + typeBuilder.CreateTypeInfo(); + } + + protected static Type CreateInterface(ModuleBuilder moduleBuilder, string interfaceName) + { + TypeBuilder interfaceBuilder = moduleBuilder.DefineType(interfaceName, TypeAttributes.Public | TypeAttributes.Interface | TypeAttributes.Abstract); + return interfaceBuilder.CreateTypeInfo().AsType(); + } + + protected static AssemblyBuilder CreateAssemblyModuleBuilder(string name, int classes, int interfaces, Action handlerCreation) + { + AssemblyName assemblyName = new AssemblyName(name); + AssemblyBuilder assemblyBuilder = AssemblyBuilder.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.Run); + ModuleBuilder moduleBuilder = assemblyBuilder.DefineDynamicModule("MainModule"); + + CreateTestClassesAndInterfaces(moduleBuilder, classes, interfaces); + handlerCreation.Invoke(moduleBuilder); + + return assemblyBuilder; + } + + protected static AssemblyBuilder GenerateCombinationsTestAssembly(int classes, int interfaces, int genericParameters) + { + AssemblyName assemblyName = new AssemblyName("DynamicAssembly"); + AssemblyBuilder assemblyBuilder = AssemblyBuilder.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.Run); + ModuleBuilder moduleBuilder = assemblyBuilder.DefineDynamicModule("MainModule"); + + CreateTestClassesAndInterfaces(moduleBuilder, classes, interfaces); + CreateHandlerForCombinationsTest(moduleBuilder, genericParameters); + + return assemblyBuilder; + } + + protected static string[] GetGenericParameterNames(int numberOfTypeParameters) => + Enumerable.Range(1, numberOfTypeParameters).Select(i => $"T{i}").ToArray(); + + protected static void CreateRequestHandler(ModuleBuilder moduleBuilder, string requestName, int numberOfTypeParameters, int numberOfInterfaces = 0, bool includeConstraints = true) + { + if(numberOfInterfaces == 0) + { + numberOfInterfaces = numberOfTypeParameters; + } + + // Define the dynamic request class + var handlerTypeBuilder = moduleBuilder!.DefineType($"{requestName}Handler", TypeAttributes.Public); + var requestTypeBuilder = moduleBuilder!.DefineType(requestName, TypeAttributes.Public); + + // Define the generic parameters + string[] genericParameterNames = GetGenericParameterNames(numberOfTypeParameters); + var handlerGenericParameters = handlerTypeBuilder.DefineGenericParameters(genericParameterNames); + var requestGenericParameters = requestTypeBuilder.DefineGenericParameters(genericParameterNames); + requestTypeBuilder.AddInterfaceImplementation(typeof(IRequest)); + + if(includeConstraints) + { + for (int i = 0; i < numberOfTypeParameters; i++) + { + int interfaceIndex = i % numberOfInterfaces + 1; + + var constraintType = moduleBuilder.Assembly.GetType($"ITestInterface{interfaceIndex}"); + handlerGenericParameters[i].SetInterfaceConstraints(constraintType!); + requestGenericParameters[i].SetInterfaceConstraints(constraintType!); + } + } + + var requestType = requestTypeBuilder.CreateTypeInfo().AsType(); + handlerTypeBuilder.AddInterfaceImplementation(typeof(IRequestHandler<>).MakeGenericType(requestType)); + + // Define the Handle method + MethodBuilder handleMethodBuilder = handlerTypeBuilder.DefineMethod( + "Handle", + MethodAttributes.Public | MethodAttributes.Virtual, + typeof(Task), + new[] { requestType, typeof(CancellationToken) }); + + ILGenerator ilGenerator = handleMethodBuilder.GetILGenerator(); + + ilGenerator.Emit(OpCodes.Ret); + + // Implement the interface method + handlerTypeBuilder.DefineMethodOverride(handleMethodBuilder, typeof(IRequestHandler<>).MakeGenericType(requestType).GetMethod("Handle")!); + + // Create the dynamic request class + handlerTypeBuilder.CreateTypeInfo(); + } + + protected static void CreateTestClassesAndInterfaces(ModuleBuilder moduleBuilder, int numberOfClasses, int numberOfInterfaces) + { + + Type[] interfaces = new Type[numberOfInterfaces]; + for (int i = 1; i <= numberOfInterfaces; i++) + { + string interfaceName = $"ITestInterface{i}"; + interfaces[i - 1] = CreateInterface(moduleBuilder, interfaceName); + } + + for (int i = 1; i <= numberOfClasses; i++) + { + string className = $"TestClass{i}"; + Type interfaceType = interfaces[(i - 1) % numberOfInterfaces]; + CreateClass(moduleBuilder, className, interfaceType); + } + } + + protected List GenerateCombinations(Type[] types, int interfaces) + { + var groups = new List[interfaces]; + for (int i = 0; i < interfaces; i++) + { + groups[i] = types.Where((t, index) => index % interfaces == i).ToList(); + } + + return GenerateCombinationsRecursive(groups, 0); + } + + protected List GenerateCombinationsRecursive(List[] groups, int currentGroup) + { + var result = new List(); + + if (currentGroup == groups.Length) + { + result.Add(Array.Empty()); + return result; + } + + foreach (var type in groups[currentGroup]) + { + foreach (var subCombination in GenerateCombinationsRecursive(groups, currentGroup + 1)) + { + result.Add(new[] { type }.Concat(subCombination).ToArray()); + } + } + + return result; + } + + protected static int CalculateTotalCombinations(int numberOfClasses, int numberOfInterfaces, int numberOfTypeParameters) + { + var testClasses = Enumerable.Range(1, numberOfClasses) + .Select(i => $"TestClass{i}") + .ToArray(); + + var groups = new List[numberOfInterfaces]; + for (int i = 0; i < numberOfInterfaces; i++) + { + groups[i] = testClasses.Where((t, index) => index % numberOfInterfaces == i).ToList(); + } + + return groups + .Take(numberOfTypeParameters) + .Select(group => group.Count) + .Aggregate(1, (a, b) => a * b); + } + } +} diff --git a/test/MediatR.Tests/SendTests.cs b/test/MediatR.Tests/SendTests.cs index b489a808..553c689c 100644 --- a/test/MediatR.Tests/SendTests.cs +++ b/test/MediatR.Tests/SendTests.cs @@ -1,19 +1,19 @@ using System.Threading; -namespace MediatR.Tests; - using System; using System.Threading.Tasks; using Shouldly; using Xunit; -using Microsoft.Extensions.DependencyInjection; - +using Microsoft.Extensions.DependencyInjection; +using System.Reflection; + +namespace MediatR.Tests; public class SendTests { private readonly IServiceProvider _serviceProvider; private Dependency _dependency; - private readonly IMediator _mediator; - + private readonly IMediator _mediator; + public SendTests() { _dependency = new Dependency(); @@ -22,7 +22,6 @@ public SendTests() services.AddSingleton(_dependency); _serviceProvider = services.BuildServiceProvider(); _mediator = _serviceProvider.GetService()!; - } public class Ping : IRequest @@ -50,6 +49,7 @@ public Task Handle(Ping request, CancellationToken cancellationToken) public class Dependency { public bool Called { get; set; } + public bool CalledSpecific { get; set; } } public class VoidPingHandler : IRequestHandler @@ -103,8 +103,58 @@ public Task Handle(VoidGenericPing request, CancellationToken cancellationTok return Task.CompletedTask; } + } + + public class PongExtension : Pong + { + + } + + public class TestClass1PingRequestHandler : IRequestHandler> + { + private readonly Dependency _dependency; + + public TestClass1PingRequestHandler(Dependency dependency) => _dependency = dependency; + + public Task Handle(VoidGenericPing request, CancellationToken cancellationToken) + { + _dependency.CalledSpecific = true; + return Task.CompletedTask; + } + } + + public interface ITestInterface1 { } + public interface ITestInterface2 { } + public interface ITestInterface3 { } + + public class TestClass1 : ITestInterface1 { } + public class TestClass2 : ITestInterface2 { } + public class TestClass3 : ITestInterface3 { } + + public class MultipleGenericTypeParameterRequest : IRequest + where T1 : ITestInterface1 + where T2 : ITestInterface2 + where T3 : ITestInterface3 + { + public int Foo { get; set; } } + public class MultipleGenericTypeParameterRequestHandler : IRequestHandler, int> + where T1 : ITestInterface1 + where T2 : ITestInterface2 + where T3 : ITestInterface3 + { + private readonly Dependency _dependency; + + public MultipleGenericTypeParameterRequestHandler(Dependency dependency) => _dependency = dependency; + + public Task Handle(MultipleGenericTypeParameterRequest request, CancellationToken cancellationToken) + { + _dependency.Called = true; + return Task.FromResult(1); + } + } + [Fact] public async Task Should_resolve_main_handler() { @@ -183,4 +233,49 @@ public async Task Should_resolve_generic_void_handler() _dependency.Called.ShouldBeTrue(); } + + [Fact] + public async Task Should_resolve_multiple_type_parameter_generic_handler() + { + var request = new MultipleGenericTypeParameterRequest(); + await _mediator.Send(request); + + _dependency.Called.ShouldBeTrue(); + } + + [Fact] + public async Task Should_resolve_closed_handler_if_defined() + { + var dependency = new Dependency(); + var services = new ServiceCollection(); + services.AddSingleton(dependency); + services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblies(Assembly.GetExecutingAssembly())); + services.AddTransient>,TestClass1PingRequestHandler>(); + var serviceProvider = services.BuildServiceProvider(); + var mediator = serviceProvider.GetService()!; + + var request = new VoidGenericPing(); + await mediator.Send(request); + + dependency.Called.ShouldBeFalse(); + dependency.CalledSpecific.ShouldBeTrue(); + } + + [Fact] + public async Task Should_resolve_open_handler_if_not_defined() + { + var dependency = new Dependency(); + var services = new ServiceCollection(); + services.AddSingleton(dependency); + services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblies(Assembly.GetExecutingAssembly())); + services.AddTransient>, TestClass1PingRequestHandler>(); + var serviceProvider = services.BuildServiceProvider(); + var mediator = serviceProvider.GetService()!; + + var request = new VoidGenericPing(); + await mediator.Send(request); + + dependency.Called.ShouldBeTrue(); + dependency.CalledSpecific.ShouldBeFalse(); + } } \ No newline at end of file