From 9454da522e3d8570dbf80a017a6c7dcc29c93aa8 Mon Sep 17 00:00:00 2001 From: Dapeng Zhang Date: Mon, 16 Dec 2024 16:43:00 +0800 Subject: [PATCH] introduce the abstraction for tokencredential type and keycredential type (#5231) Fixes https://github.com/microsoft/typespec/issues/5235 Fixes https://github.com/microsoft/typespec/issues/4410 The ClientProviderTests class contains both general client provider features (such as service operations, query parameters, api-versions) and auth specific features. My PR is trying to separate them - therefore some test cases are moved to a new class. --- .../src/Primitives/ScmKnownParameters.cs | 2 - .../Abstractions/ClientPipelineApi.cs | 5 +- .../Abstractions/IClientPipelineApi.cs | 3 + .../src/Providers/ClientPipelineProvider.cs | 22 +- .../src/Providers/ClientProvider.cs | 242 ++++-- .../src/Providers/RestClientProvider.cs | 4 +- .../src/ScmTypeFactory.cs | 9 - .../OutputTypes/ScmKnownParametersTests.cs | 10 - .../Abstractions/ClientPipelineApiTests.cs | 11 +- .../ClientProviders/ClientProviderTests.cs | 705 ++++++++++++++---- ...yConstructor(WithDefault,False,False,0).cs | 6 + ...ryConstructor(WithDefault,False,True,0).cs | 8 + ...ryConstructor(WithDefault,True,False,0).cs | 8 + ...aryConstructor(WithDefault,True,True,0).cs | 8 + ...aryConstructor(WithDefault,True,True,1).cs | 8 + ...Constructor(WithRequired,False,False,0).cs | 6 + ...yConstructor(WithRequired,False,True,0).cs | 8 + ...yConstructor(WithRequired,True,False,0).cs | 8 + ...ryConstructor(WithRequired,True,True,0).cs | 8 + ...ryConstructor(WithRequired,True,True,1).cs | 8 + .../test/TestHelpers/MockHelpers.cs | 9 +- .../src/Generated/UnbrandedTypeSpecClient.cs | 2 +- 22 files changed, 819 insertions(+), 281 deletions(-) create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,False,False,0).cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,False,True,0).cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,False,0).cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,True,0).cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,True,1).cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,False,False,0).cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,False,True,0).cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,False,0).cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,True,0).cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,True,1).cs diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Primitives/ScmKnownParameters.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Primitives/ScmKnownParameters.cs index fe2737a59a..00fa78fc93 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Primitives/ScmKnownParameters.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Primitives/ScmKnownParameters.cs @@ -30,8 +30,6 @@ internal static class ScmKnownParameters public static readonly ParameterProvider Data = new("data", FormattableStringHelpers.Empty, typeof(BinaryData)); public static ParameterProvider ClientOptions(CSharpType clientOptionsType) => new("options", $"The options for configuring the client.", clientOptionsType.WithNullable(true), initializationValue: New.Instance(clientOptionsType.WithNullable(true))); - public static readonly ParameterProvider KeyAuth = new("keyCredential", $"The token credential to copy", ClientModelPlugin.Instance.TypeFactory.KeyCredentialType); - public static readonly ParameterProvider MatchConditionsParameter = new("matchConditions", $"The content to send as the request conditions of the request.", ClientModelPlugin.Instance.TypeFactory.MatchConditionsType, DefaultOf(ClientModelPlugin.Instance.TypeFactory.MatchConditionsType)); public static readonly ParameterProvider OptionalRequestOptions = new( ClientModelPlugin.Instance.TypeFactory.HttpRequestOptionsApi.ParameterName, $"The request options, which can override default behaviors of the client pipeline on a per-call basis.", diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/Abstractions/ClientPipelineApi.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/Abstractions/ClientPipelineApi.cs index 627f5c90d5..a9928848d3 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/Abstractions/ClientPipelineApi.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/Abstractions/ClientPipelineApi.cs @@ -14,6 +14,8 @@ public abstract record ClientPipelineApi : ScopedApi, IClientPipelineApi public abstract CSharpType ClientPipelineType { get; } public abstract CSharpType ClientPipelineOptionsType { get; } public abstract CSharpType PipelinePolicyType { get; } + public abstract CSharpType? KeyCredentialType { get; } + public abstract CSharpType? TokenCredentialType { get; } protected ClientPipelineApi(Type type, ValueExpression original) : base(type, original) { @@ -26,7 +28,8 @@ protected ClientPipelineApi(Type type, ValueExpression original) : base(type, or public abstract ValueExpression Create(ValueExpression options, ValueExpression perRetryPolicies); - public abstract ValueExpression AuthorizationPolicy(params ValueExpression[] arguments); + public abstract ValueExpression KeyAuthorizationPolicy(ValueExpression credential, ValueExpression headerName, ValueExpression? keyPrefix = null); + public abstract ValueExpression TokenAuthorizationPolicy(ValueExpression credential, ValueExpression scopes); public abstract ClientPipelineApi FromExpression(ValueExpression expression); public abstract ClientPipelineApi ToExpression(); } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/Abstractions/IClientPipelineApi.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/Abstractions/IClientPipelineApi.cs index 85f51aa1ee..dca1ebc635 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/Abstractions/IClientPipelineApi.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/Abstractions/IClientPipelineApi.cs @@ -10,5 +10,8 @@ public interface IClientPipelineApi : IExpressionApi CSharpType ClientPipelineType { get; } CSharpType ClientPipelineOptionsType { get; } CSharpType PipelinePolicyType { get; } + + CSharpType? KeyCredentialType { get; } + CSharpType? TokenCredentialType { get; } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientPipelineProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientPipelineProvider.cs index 283861da85..5cf9279ea5 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientPipelineProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientPipelineProvider.cs @@ -1,11 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; +using System.ClientModel; using System.ClientModel.Primitives; using Microsoft.Generator.CSharp.Expressions; using Microsoft.Generator.CSharp.Primitives; -using Microsoft.Generator.CSharp.Statements; using Microsoft.Generator.CSharp.Snippets; +using Microsoft.Generator.CSharp.Statements; using static Microsoft.Generator.CSharp.Snippets.Snippet; namespace Microsoft.Generator.CSharp.ClientModel.Providers @@ -25,6 +27,10 @@ public ClientPipelineProvider(ValueExpression original) : base(typeof(ClientPipe public override CSharpType PipelinePolicyType => typeof(PipelinePolicy); + public override CSharpType KeyCredentialType => typeof(ApiKeyCredential); + + public override CSharpType? TokenCredentialType => null; // Scm library does not support token credentials yet. + public override ValueExpression Create(ValueExpression options, ValueExpression perRetryPolicies) => Static().Invoke(nameof(ClientPipeline.Create), [options, New.Array(ClientModelPlugin.Instance.TypeFactory.ClientPipelineApi.PipelinePolicyType), perRetryPolicies, New.Array(ClientModelPlugin.Instance.TypeFactory.ClientPipelineApi.PipelinePolicyType)]).As(); @@ -34,8 +40,18 @@ public override ValueExpression CreateMessage(HttpRequestOptionsApi requestOptio public override ClientPipelineApi FromExpression(ValueExpression expression) => new ClientPipelineProvider(expression); - public override ValueExpression AuthorizationPolicy(params ValueExpression[] arguments) - => Static().Invoke(nameof(ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy), arguments).As(); + public override ValueExpression KeyAuthorizationPolicy(ValueExpression credential, ValueExpression headerName, ValueExpression? keyPrefix = null) + { + ValueExpression[] arguments = keyPrefix == null ? [credential, headerName] : [credential, headerName, keyPrefix]; + return Static().Invoke(nameof(ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy), arguments).As(); + } + + public override ValueExpression TokenAuthorizationPolicy(ValueExpression credential, ValueExpression scopes) + { + // Scm library does not support token credentials yet. The throw here is intentional. + // For a plugin that supports token credentials, they could override this implementation as well as the above TokenCredentialType property. + throw new NotImplementedException(); + } public override ClientPipelineApi ToExpression() => this; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientProvider.cs index 6ea60345e8..c050d10981 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientProvider.cs @@ -19,9 +19,15 @@ namespace Microsoft.Generator.CSharp.ClientModel.Providers { public class ClientProvider : TypeProvider { + private record AuthFields(FieldProvider AuthField); + private record ApiKeyFields(FieldProvider AuthField, FieldProvider AuthorizationHeaderField, FieldProvider? AuthorizationApiKeyPrefixField) : AuthFields(AuthField); + private record OAuth2Fields(FieldProvider AuthField, FieldProvider AuthorizationScopesField) : AuthFields(AuthField); + private const string AuthorizationHeaderConstName = "AuthorizationHeader"; private const string AuthorizationApiKeyPrefixConstName = "AuthorizationApiKeyPrefix"; private const string ApiKeyCredentialFieldName = "_keyCredential"; + private const string TokenCredentialScopesFieldName = "AuthorizationScopes"; + private const string TokenCredentialFieldName = "_tokenCredential"; private const string EndpointFieldName = "_endpoint"; private const string ClientSuffix = "Client"; private readonly FormattableString _publicCtorDescription; @@ -29,11 +35,12 @@ public class ClientProvider : TypeProvider private readonly InputAuth? _inputAuth; private readonly ParameterProvider _endpointParameter; private readonly FieldProvider? _clientCachingField; - private readonly FieldProvider? _apiKeyAuthField; - private readonly FieldProvider? _authorizationHeaderConstant; - private readonly FieldProvider? _authorizationApiKeyPrefixConstant; + + private readonly ApiKeyFields? _apiKeyAuthFields; + private readonly OAuth2Fields? _oauth2Fields; + private FieldProvider? _apiVersionField; - private readonly List _subClientInternalConstructorParams; + private readonly Lazy> _subClientInternalConstructorParams; private IReadOnlyList>? _subClients; private ParameterProvider? _clientOptionsParameter; private ClientOptionsProvider? _clientOptions; @@ -61,24 +68,50 @@ public ClientProvider(InputClient inputClient) _publicCtorDescription = $"Initializes a new instance of {Name}."; var apiKey = _inputAuth?.ApiKey; - _apiKeyAuthField = apiKey != null ? new FieldProvider( - FieldModifiers.Private | FieldModifiers.ReadOnly, - ClientModelPlugin.Instance.TypeFactory.KeyCredentialType, - ApiKeyCredentialFieldName, - this, - description: $"A credential used to authenticate to the service.") : null; - _authorizationHeaderConstant = apiKey?.Name != null ? new( - FieldModifiers.Private | FieldModifiers.Const, - typeof(string), - AuthorizationHeaderConstName, - this, - initializationValue: Literal(apiKey.Name)) : null; - _authorizationApiKeyPrefixConstant = apiKey?.Prefix != null ? new( - FieldModifiers.Private | FieldModifiers.Const, - typeof(string), - AuthorizationApiKeyPrefixConstName, - this, - initializationValue: Literal(apiKey.Prefix)) : null; + var keyCredentialType = ClientModelPlugin.Instance.TypeFactory.ClientPipelineApi.KeyCredentialType; + if (apiKey != null && keyCredentialType != null) + { + var apiKeyAuthField = new FieldProvider( + FieldModifiers.Private | FieldModifiers.ReadOnly, + keyCredentialType, + ApiKeyCredentialFieldName, + this, + description: $"A credential used to authenticate to the service."); + var authorizationHeaderField = new FieldProvider( + FieldModifiers.Private | FieldModifiers.Const, + typeof(string), + AuthorizationHeaderConstName, + this, + initializationValue: Literal(apiKey.Name)); + var authorizationApiKeyPrefixField = apiKey.Prefix != null ? + new FieldProvider( + FieldModifiers.Private | FieldModifiers.Const, + typeof(string), + AuthorizationApiKeyPrefixConstName, + this, + initializationValue: Literal(apiKey.Prefix)) : + null; + _apiKeyAuthFields = new(apiKeyAuthField, authorizationHeaderField, authorizationApiKeyPrefixField); + } + // in this plugin, the type of TokenCredential is null therefore these code will never be executed, but it should be invoked in other plugins that could support it. + var tokenAuth = _inputAuth?.OAuth2; + var tokenCredentialType = ClientModelPlugin.Instance.TypeFactory.ClientPipelineApi.TokenCredentialType; + if (tokenAuth != null && tokenCredentialType != null) + { + var tokenCredentialField = new FieldProvider( + FieldModifiers.Private | FieldModifiers.ReadOnly, + tokenCredentialType, + TokenCredentialFieldName, + this, + description: $"A credential used to authenticate to the service."); + var tokenCredentialScopesField = new FieldProvider( + FieldModifiers.Private | FieldModifiers.Static | FieldModifiers.ReadOnly, + typeof(string[]), + TokenCredentialScopesFieldName, + this, + initializationValue: New.Array(typeof(string), tokenAuth.Scopes.Select(Literal).ToArray())); + _oauth2Fields = new(tokenCredentialField, tokenCredentialScopesField); + } EndpointField = new( FieldModifiers.Private | FieldModifiers.ReadOnly, typeof(Uri), @@ -92,10 +125,6 @@ public ClientProvider(InputClient inputClient) body: new AutoPropertyBody(false), enclosingType: this); - _subClientInternalConstructorParams = _apiKeyAuthField != null - ? [PipelineProperty.AsParameter, _apiKeyAuthField.AsParameter, _endpointParameter] - : [PipelineProperty.AsParameter, _endpointParameter]; - if (_inputClient.Parent != null) { // _clientCachingField will only have subClients (children) @@ -108,23 +137,43 @@ public ClientProvider(InputClient inputClient) } _endpointParameterName = new(GetEndpointParameterName); - _additionalClientFields = new Lazy>(() => BuildAdditionalClientFields()); + _additionalClientFields = new(BuildAdditionalClientFields); _allClientParameters = _inputClient.Parameters.Concat(_inputClient.Operations.SelectMany(op => op.Parameters).Where(p => p.Kind == InputOperationParameterKind.Client)).DistinctBy(p => p.Name).ToArray(); + _subClientInternalConstructorParams = new(GetSubClientInternalConstructorParameters); + _clientParameters = new(GetClientParameters); + } - foreach (var field in _additionalClientFields.Value) + private IReadOnlyList GetSubClientInternalConstructorParameters() + { + var subClientParameters = new List + { + PipelineProperty.AsParameter + }; + + if (_apiKeyAuthFields != null) + { + subClientParameters.Add(_apiKeyAuthFields.AuthField.AsParameter); + } + if (_oauth2Fields != null) { - _subClientInternalConstructorParams.Add(field.AsParameter); + subClientParameters.Add(_oauth2Fields.AuthField.AsParameter); } + subClientParameters.Add(_endpointParameter); + subClientParameters.AddRange(ClientParameters); + + return subClientParameters; } - private List? _clientParameters; - internal IReadOnlyList GetClientParameters() + private Lazy> _clientParameters; + internal IReadOnlyList ClientParameters => _clientParameters.Value; + private IReadOnlyList GetClientParameters() { - if (_clientParameters is null) + var parameters = new List(_additionalClientFields.Value.Count); + foreach (var field in _additionalClientFields.Value) { - _ = Constructors; + parameters.Add(field.AsParameter); } - return _clientParameters ?? []; + return parameters; } private Lazy _endpointParameterName; @@ -159,17 +208,22 @@ protected override FieldProvider[] BuildFields() { List fields = [EndpointField]; - if (_apiKeyAuthField != null && _authorizationHeaderConstant != null) + if (_apiKeyAuthFields != null) { - fields.Add(_authorizationHeaderConstant); - fields.Add(_apiKeyAuthField); - - if (_authorizationApiKeyPrefixConstant != null) + fields.Add(_apiKeyAuthFields.AuthField); + fields.Add(_apiKeyAuthFields.AuthorizationHeaderField); + if (_apiKeyAuthFields.AuthorizationApiKeyPrefixField != null) { - fields.Add(_authorizationApiKeyPrefixConstant); + fields.Add(_apiKeyAuthFields.AuthorizationApiKeyPrefixField); } } + if (_oauth2Fields != null) + { + fields.Add(_oauth2Fields.AuthField); + fields.Add(_oauth2Fields.AuthorizationScopesField); + } + fields.AddRange(_additionalClientFields.Value); // add sub-client caching fields @@ -225,9 +279,8 @@ protected override ConstructorProvider[] BuildConstructors() // handle sub-client constructors if (ClientOptionsParameter is null) { - _clientParameters = _subClientInternalConstructorParams; List body = new(3) { EndpointField.Assign(_endpointParameter).Terminate() }; - foreach (var p in _subClientInternalConstructorParams) + foreach (var p in _subClientInternalConstructorParams.Value) { var assignment = p.Field?.Assign(p).Terminate() ?? p.Property?.Assign(p).Terminate(); if (assignment != null) @@ -236,35 +289,61 @@ protected override ConstructorProvider[] BuildConstructors() } } var subClientConstructor = new ConstructorProvider( - new ConstructorSignature(Type, _publicCtorDescription, MethodSignatureModifiers.Internal, _subClientInternalConstructorParams), + new ConstructorSignature(Type, _publicCtorDescription, MethodSignatureModifiers.Internal, _subClientInternalConstructorParams.Value), body, this); return [mockingConstructor, subClientConstructor]; } - var requiredParameters = GetRequiredParameters(); - ParameterProvider[] primaryConstructorParameters = [_endpointParameter, .. requiredParameters, ClientOptionsParameter]; - var primaryConstructor = new ConstructorProvider( - new ConstructorSignature(Type, _publicCtorDescription, MethodSignatureModifiers.Public, primaryConstructorParameters), - BuildPrimaryConstructorBody(primaryConstructorParameters), - this); + // we need to construct two sets of constructors for both auth we supported if any. + var primaryConstructors = new List(); + var secondaryConstructors = new List(); - // If the endpoint parameter contains an initialization value, it is not required. - ParameterProvider[] secondaryConstructorParameters = _endpointParameter.InitializationValue is null - ? [_endpointParameter, .. requiredParameters] - : [.. requiredParameters]; - var secondaryConstructor = BuildSecondaryConstructor(secondaryConstructorParameters, primaryConstructorParameters); - var shouldIncludeMockingConstructor = secondaryConstructorParameters.Length > 0 || _apiKeyAuthField != null; + // if there is key auth + if (_apiKeyAuthFields != null) + { + AppendConstructors(_apiKeyAuthFields, primaryConstructors, secondaryConstructors); + } + // if there is oauth2 auth + if (_oauth2Fields!= null) + { + AppendConstructors(_oauth2Fields, primaryConstructors, secondaryConstructors); + } + // if there is no auth + if (_apiKeyAuthFields == null && _oauth2Fields == null) + { + AppendConstructors(null, primaryConstructors, secondaryConstructors); + } + var shouldIncludeMockingConstructor = secondaryConstructors.All(c => c.Signature.Parameters.Count > 0); return shouldIncludeMockingConstructor - ? [ConstructorProviderHelper.BuildMockingConstructor(this), secondaryConstructor, primaryConstructor] - : [secondaryConstructor, primaryConstructor]; + ? [ConstructorProviderHelper.BuildMockingConstructor(this), .. secondaryConstructors, .. primaryConstructors] + : [.. secondaryConstructors, .. primaryConstructors]; + + void AppendConstructors(AuthFields? authFields, List primaryConstructors, List secondaryConstructors) + { + var requiredParameters = GetRequiredParameters(authFields?.AuthField); + ParameterProvider[] primaryConstructorParameters = [_endpointParameter, .. requiredParameters, ClientOptionsParameter]; + var primaryConstructor = new ConstructorProvider( + new ConstructorSignature(Type, _publicCtorDescription, MethodSignatureModifiers.Public, primaryConstructorParameters), + BuildPrimaryConstructorBody(primaryConstructorParameters, authFields), + this); + + primaryConstructors.Add(primaryConstructor); + + // If the endpoint parameter contains an initialization value, it is not required. + ParameterProvider[] secondaryConstructorParameters = _endpointParameter.InitializationValue is null + ? [_endpointParameter, .. requiredParameters] + : [.. requiredParameters]; + var secondaryConstructor = BuildSecondaryConstructor(secondaryConstructorParameters, primaryConstructorParameters); + + secondaryConstructors.Add(secondaryConstructor); + } } - private IReadOnlyList GetRequiredParameters() + private IReadOnlyList GetRequiredParameters(FieldProvider? authField) { List requiredParameters = []; - _clientParameters = []; ParameterProvider? currentParam = null; foreach (var parameter in _allClientParameters) @@ -275,11 +354,10 @@ private IReadOnlyList GetRequiredParameters() currentParam = CreateParameter(parameter); requiredParameters.Add(currentParam); } - _clientParameters.Add(currentParam ?? CreateParameter(parameter)); } - if (_apiKeyAuthField is not null) - requiredParameters.Add(_apiKeyAuthField.AsParameter); + if (authField is not null) + requiredParameters.Add(authField.AsParameter); return requiredParameters; } @@ -291,7 +369,7 @@ private ParameterProvider CreateParameter(InputParameter parameter) return param; } - private MethodBodyStatement[] BuildPrimaryConstructorBody(IReadOnlyList primaryConstructorParameters) + private MethodBodyStatement[] BuildPrimaryConstructorBody(IReadOnlyList primaryConstructorParameters, AuthFields? authFields) { if (ClientOptions is null || ClientOptionsParameter is null) { @@ -314,14 +392,19 @@ private MethodBodyStatement[] BuildPrimaryConstructorBody(IReadOnlyList().AuthorizationPolicy(authorizationPolicyArgs)); + case ApiKeyFields keyAuthFields: + ValueExpression? keyPrefixExpression = keyAuthFields.AuthorizationApiKeyPrefixField != null ? (ValueExpression)keyAuthFields.AuthorizationApiKeyPrefixField : null; + perRetryPolicies = New.Array(ClientModelPlugin.Instance.TypeFactory.ClientPipelineApi.PipelinePolicyType, isInline: true, This.ToApi().KeyAuthorizationPolicy(keyAuthFields.AuthField, keyAuthFields.AuthorizationHeaderField, keyPrefixExpression)); + break; + case OAuth2Fields oauth2AuthFields: + perRetryPolicies = New.Array(ClientModelPlugin.Instance.TypeFactory.ClientPipelineApi.PipelinePolicyType, isInline: true, This.ToApi().TokenAuthorizationPolicy(oauth2AuthFields.AuthField, oauth2AuthFields.AuthorizationScopesField)); + break; + default: + perRetryPolicies = New.Array(ClientModelPlugin.Instance.TypeFactory.ClientPipelineApi.PipelinePolicyType); + break; } body.Add(PipelineProperty.Assign(This.ToApi().Create(ClientOptionsParameter, perRetryPolicies)).Terminate()); @@ -329,18 +412,13 @@ private MethodBodyStatement[] BuildPrimaryConstructorBody(IReadOnlyList p.Name.ToCleanName()); foreach (var f in Fields) { - if (f != _apiKeyAuthField - && f != EndpointField - && !f.Modifiers.HasFlag(FieldModifiers.Const)) + if (f == _apiVersionField && ClientOptions.VersionProperty != null) { - if (f == _apiVersionField && ClientOptions.VersionProperty != null) - { - body.Add(f.Assign(ClientOptionsParameter.Property(ClientOptions.VersionProperty.Name)).Terminate()); - } - else if (clientOptionsPropertyDict.TryGetValue(f.Name.ToCleanName(), out var optionsProperty)) - { - clientOptionsPropertyDict.TryGetValue(f.Name.ToCleanName(), out optionsProperty); - } + body.Add(f.Assign(ClientOptionsParameter.Property(ClientOptions.VersionProperty.Name)).Terminate()); + } + else if (clientOptionsPropertyDict.TryGetValue(f.Name.ToCleanName(), out var optionsProperty)) + { + clientOptionsPropertyDict.TryGetValue(f.Name.ToCleanName(), out optionsProperty); } } @@ -410,7 +488,7 @@ protected override MethodProvider[] BuildMethods() List subClientConstructorArgs = new(3); // Populate constructor arguments - foreach (var param in subClientInstance._subClientInternalConstructorParams) + foreach (var param in subClientInstance._subClientInternalConstructorParams.Value) { if (parentClientProperties.TryGetValue(param.Name, out var parentProperty)) { diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs index 7c597bf4c8..0376b6d5eb 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs @@ -97,7 +97,7 @@ private MethodProvider BuildCreateRequestMethod(InputOperation operation) [.. parameters, options]); var paramMap = new Dictionary(signature.Parameters.ToDictionary(p => p.Name)); - foreach (var param in ClientProvider.GetClientParameters()) + foreach (var param in ClientProvider.ClientParameters) { paramMap[param.Name] = param; } @@ -356,7 +356,7 @@ private void AddUriSegments( /* when the parameter is in operation.uri, it is client parameter * It is not operation parameter and not in inputParamHash list. */ - var isClientParameter = ClientProvider.GetClientParameters().Any(p => p.Name == paramName); + var isClientParameter = ClientProvider.ClientParameters.Any(p => p.Name == paramName); CSharpType? type; string? format; ValueExpression valueExpression; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/ScmTypeFactory.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/ScmTypeFactory.cs index 18364b2d46..3a63e2614b 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/ScmTypeFactory.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/ScmTypeFactory.cs @@ -2,21 +2,16 @@ // Licensed under the MIT License. using System; -using System.ClientModel; using System.ClientModel.Primitives; using System.Collections.Generic; -using System.IO; -using System.Net; using System.Text.Json; using Microsoft.Generator.CSharp.ClientModel.Providers; -using Microsoft.Generator.CSharp.ClientModel.Snippets; using Microsoft.Generator.CSharp.Expressions; using Microsoft.Generator.CSharp.Input; using Microsoft.Generator.CSharp.Primitives; using Microsoft.Generator.CSharp.Providers; using Microsoft.Generator.CSharp.Snippets; using Microsoft.Generator.CSharp.Statements; -using static Microsoft.Generator.CSharp.Snippets.Snippet; namespace Microsoft.Generator.CSharp.ClientModel { @@ -27,10 +22,6 @@ public class ScmTypeFactory : TypeFactory public virtual CSharpType MatchConditionsType => typeof(PipelineMessageClassifier); - public virtual CSharpType KeyCredentialType => typeof(ApiKeyCredential); - - public virtual CSharpType TokenCredentialType => throw new NotImplementedException("Token credential is not supported in Scm libraries yet"); - public virtual IClientResponseApi ClientResponseApi => ClientResultProvider.Instance; public virtual IHttpResponseApi HttpResponseApi => PipelineResponseProvider.Instance; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/OutputTypes/ScmKnownParametersTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/OutputTypes/ScmKnownParametersTests.cs index 091102c8de..063aa6c659 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/OutputTypes/ScmKnownParametersTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/OutputTypes/ScmKnownParametersTests.cs @@ -8,16 +8,6 @@ namespace Microsoft.Generator.CSharp.ClientModel.Tests.OutputTypes { internal class ScmKnownParametersTests { - [Test] - public void TestTokenAuth() - { - MockHelpers.LoadMockPlugin(keyCredentialType: () => typeof(int)); - - var result = ClientModelPlugin.Instance.TypeFactory.KeyCredentialType; - Assert.IsNotNull(result); - Assert.AreEqual(new CSharpType(typeof(int)), result); - } - [TestCase] public void TestMatchConditionsParameter() { diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Abstractions/ClientPipelineApiTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Abstractions/ClientPipelineApiTests.cs index f7e5bd92c9..5e968b87b0 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Abstractions/ClientPipelineApiTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Abstractions/ClientPipelineApiTests.cs @@ -55,6 +55,10 @@ public TestClientPipelineApi(ValueExpression original) : base(typeof(string), or public override CSharpType PipelinePolicyType => typeof(string); + public override CSharpType KeyCredentialType => typeof(object); + + public override CSharpType TokenCredentialType => typeof(object); + public override ValueExpression Create(ValueExpression options, ValueExpression perRetryPolicies) => Original.Invoke("GetFakeCreate", [options, perRetryPolicies]); @@ -64,8 +68,11 @@ public override ValueExpression CreateMessage(HttpRequestOptionsApi requestOptio public override ClientPipelineApi FromExpression(ValueExpression expression) => new TestClientPipelineApi(expression); - public override ValueExpression AuthorizationPolicy(params ValueExpression[] arguments) - => Original.Invoke("GetFakeAuthorizationPolicy", arguments); + public override ValueExpression KeyAuthorizationPolicy(ValueExpression credential, ValueExpression headerName, ValueExpression? keyPrefix = null) + => Original.Invoke("GetFakeAuthorizationPolicy", keyPrefix == null ? [credential, headerName] : [credential, headerName, keyPrefix]); + + public override ValueExpression TokenAuthorizationPolicy(ValueExpression credential, ValueExpression scopes) + => Original.Invoke("GetFakeTokenAuthorizationPolicy", [credential, scopes]); public override ClientPipelineApi ToExpression() => this; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/ClientProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/ClientProviderTests.cs index 68b84320f2..f2bf2050cc 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/ClientProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/ClientProviderTests.cs @@ -6,6 +6,7 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using Microsoft.Generator.CSharp.ClientModel.Providers; using Microsoft.Generator.CSharp.Expressions; using Microsoft.Generator.CSharp.Input; @@ -22,6 +23,8 @@ namespace Microsoft.Generator.CSharp.ClientModel.Tests.Providers.ClientProviders public class ClientProviderTests { private const string SubClientsCategory = "WithSubClients"; + private const string KeyAuthCategory = "WithKeyAuth"; + private const string OAuth2Category = "WithOAuth2"; private const string TestClientName = "TestClient"; private static readonly InputClient _animalClient = new("animal", "", "AnimalClient description", [], [], TestClientName); private static readonly InputClient _dogClient = new("dog", "", "DogClient description", [], [], _animalClient.Name); @@ -34,22 +37,30 @@ public class ClientProviderTests InputFactory.Property("p1", InputPrimitiveType.String, isRequired: true), ]); + private bool _containsSubClients; + private bool _hasKeyAuth; + private bool _hasOAuth2; + private bool _hasAuth; + [SetUp] public void SetUp() { var categories = TestContext.CurrentContext.Test?.Properties["Category"]; - bool containsSubClients = categories?.Contains(SubClientsCategory) ?? false; - - if (containsSubClients) - { - MockHelpers.LoadMockPlugin( - apiKeyAuth: () => new InputApiKeyAuth("mock", null), - clients: () => [_animalClient, _dogClient, _huskyClient]); - } - else - { - MockHelpers.LoadMockPlugin(apiKeyAuth: () => new InputApiKeyAuth("mock", null)); - } + _containsSubClients = categories?.Contains(SubClientsCategory) ?? false; + _hasKeyAuth = categories?.Contains(KeyAuthCategory) ?? false; + _hasOAuth2 = categories?.Contains(OAuth2Category) ?? false; + _hasAuth = _hasKeyAuth || _hasOAuth2; + + Func>? clients = _containsSubClients ? + () => [_animalClient, _dogClient, _huskyClient] : + null; + Func? apiKeyAuth = _hasKeyAuth ? () => new InputApiKeyAuth("mock", null) : null; + Func? oauth2Auth = _hasOAuth2 ? () => new InputOAuth2Auth(["mock"]) : null; + MockHelpers.LoadMockPlugin( + apiKeyAuth: apiKeyAuth, + oauth2Auth: oauth2Auth, + clients: clients, + clientPipelineApi: TestClientPipelineApi.Instance); } [Test] @@ -73,81 +84,133 @@ public void TestBuildProperties() } [TestCaseSource(nameof(BuildFieldsTestCases))] - public void TestBuildFields(List inputParameters, bool containsAdditionalParams) + public void TestBuildFields(List inputParameters, List expectedFields) { var client = InputFactory.Client(TestClientName, parameters: [.. inputParameters]); var clientProvider = new ClientProvider(client); Assert.IsNotNull(clientProvider); - // validate the fields - var fields = clientProvider.Fields; - if (containsAdditionalParams) - { - Assert.AreEqual(6, fields.Count); + AssertHasFields(clientProvider, expectedFields); + } - } - else + [TestCaseSource(nameof(BuildAuthFieldsTestCases), Category = KeyAuthCategory)] + [TestCaseSource(nameof(BuildAuthFieldsTestCases), Category = OAuth2Category)] + [TestCaseSource(nameof(BuildAuthFieldsTestCases), Category = $"{KeyAuthCategory},{OAuth2Category}")] + public void TestBuildAuthFields_WithAuth(List inputParameters) + { + var client = InputFactory.Client(TestClientName, parameters: [.. inputParameters]); + var clientProvider = new ClientProvider(client); + + Assert.IsNotNull(clientProvider); + + if (_hasKeyAuth) { - Assert.AreEqual(4, fields.Count); + // key auth should have the following fields: AuthorizationHeader, _keyCredential + AssertHasFields(clientProvider, new List + { + new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string)), "AuthorizationHeader"), + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(ApiKeyCredential)), "_keyCredential") + }); } - - // validate the endpoint field - if (inputParameters.Any(p => p.IsEndpoint)) + if (_hasOAuth2) { - var endpointField = fields.FirstOrDefault(f => f.Name == "_endpoint"); - Assert.IsNotNull(endpointField); - Assert.AreEqual(new CSharpType(typeof(Uri)), endpointField?.Type); + // oauth2 auth should have the following fields: AuthorizationScopes, _tokenCredential + AssertHasFields(clientProvider, new List + { + new(FieldModifiers.Private | FieldModifiers.Static | FieldModifiers.ReadOnly, new CSharpType(typeof(string[])), "AuthorizationScopes"), + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(FakeTokenCredential)), "_tokenCredential"), + }); } + } - // validate other parameters as fields - if (containsAdditionalParams) - { - var optionalParamField = fields.FirstOrDefault(f => f.Name == "_optionalNullableParam"); - Assert.IsNotNull(optionalParamField); - Assert.AreEqual(new CSharpType(typeof(string), isNullable: true), optionalParamField?.Type); + [TestCaseSource(nameof(BuildAuthFieldsTestCases))] + public void TestBuildAuthFields_NoAuth(List inputParameters) + { + var client = InputFactory.Client(TestClientName, parameters: [.. inputParameters]); + var clientProvider = new ClientProvider(client); - var requiredParam2Field = fields.FirstOrDefault(f => f.Name == "_requiredParam2"); - Assert.IsNotNull(requiredParam2Field); - Assert.AreEqual(new CSharpType(typeof(string), isNullable: false), requiredParam2Field?.Type); + Assert.IsNotNull(clientProvider); - var requiredParam3Field = fields.FirstOrDefault(f => f.Name == "_requiredParam3"); - Assert.IsNotNull(requiredParam3Field); - Assert.AreEqual(new CSharpType(typeof(long), isNullable: false), requiredParam3Field?.Type); + // fields here should not have anything related with auth + bool authFieldFound = false; + foreach (var field in clientProvider.Fields) + { + if (field.Name.EndsWith("Credential") || field.Name.Contains("Authorization")) + { + authFieldFound = true; + } } + + Assert.IsFalse(authFieldFound); } // validates the fields are built correctly when a client has sub-clients - [TestCaseSource(nameof(SubClientTestCases), Category = SubClientsCategory)] - public void TestBuildFields_WithSubClients(InputClient client, bool hasSubClients) + [TestCaseSource(nameof(SubClientFieldsTestCases), Category = SubClientsCategory)] + public void TestBuildFields_WithSubClients(InputClient client, List expectedFields) { var clientProvider = new ClientProvider(client); Assert.IsNotNull(clientProvider); - // validate the fields - var fields = clientProvider.Fields; + AssertHasFields(clientProvider, expectedFields); + } - // validate the endpoint field - var endpointField = fields.FirstOrDefault(f => f.Name == "_endpoint"); - Assert.IsNotNull(endpointField); - Assert.AreEqual(new CSharpType(typeof(Uri)), endpointField?.Type); + // validates the credential fields are built correctly when a client has sub-clients + [TestCaseSource(nameof(SubClientAuthFieldsTestCases), Category = SubClientsCategory)] + public void TestBuildAuthFields_WithSubClients_NoAuth(InputClient client) + { + var clientProvider = new ClientProvider(client); - // there should be n number of caching client fields for every direct sub-client + endpoint field + auth fields - if (hasSubClients) + Assert.IsNotNull(clientProvider); + + // fields here should not have anything related with auth + bool authFieldFound = false; + foreach (var field in clientProvider.Fields) { - Assert.AreEqual(4, fields.Count); - var cachedClientFields = fields.Where(f => f.Name.StartsWith("_cached")); - Assert.AreEqual(1, cachedClientFields.Count()); + if (field.Name.EndsWith("Credential") || field.Name.Contains("Authorization")) + { + authFieldFound = true; + } } - else + + Assert.IsFalse(authFieldFound); + } + + // validates the credential fields are built correctly when a client has sub-clients + [TestCaseSource(nameof(SubClientAuthFieldsTestCases), Category = $"{SubClientsCategory},{KeyAuthCategory}")] + [TestCaseSource(nameof(SubClientAuthFieldsTestCases), Category = $"{SubClientsCategory},{OAuth2Category}")] + [TestCaseSource(nameof(SubClientAuthFieldsTestCases), Category = $"{SubClientsCategory},{KeyAuthCategory},{OAuth2Category}")] + public void TestBuildAuthFields_WithSubClients_WithAuth(InputClient client) + { + var clientProvider = new ClientProvider(client); + + Assert.IsNotNull(clientProvider); + + if (_hasKeyAuth) { - // The 3 fields are _endpoint, AuthorizationHeader, and _keyCredential - Assert.AreEqual(3, fields.Count); + // key auth should have the following fields: AuthorizationHeader, _keyCredential + AssertHasFields(clientProvider, new List + { + new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string)), "AuthorizationHeader"), + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(ApiKeyCredential)), "_keyCredential") + }); + } + if (_hasOAuth2) + { + // oauth2 auth should have the following fields: AuthorizationScopes, _tokenCredential + AssertHasFields(clientProvider, new List + { + new(FieldModifiers.Private | FieldModifiers.Static | FieldModifiers.ReadOnly, new CSharpType(typeof(string[])), "AuthorizationScopes"), + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(FakeTokenCredential)), "_tokenCredential"), + }); } } [TestCaseSource(nameof(BuildConstructorsTestCases))] + [TestCaseSource(nameof(BuildConstructorsTestCases), Category = KeyAuthCategory)] + [TestCaseSource(nameof(BuildConstructorsTestCases), Category = OAuth2Category)] + [TestCaseSource(nameof(BuildConstructorsTestCases), Category = $"{KeyAuthCategory},{OAuth2Category}")] public void TestBuildConstructors_PrimaryConstructor(List inputParameters) { var client = InputFactory.Client(TestClientName, parameters: [.. inputParameters]); @@ -156,14 +219,25 @@ public void TestBuildConstructors_PrimaryConstructor(List inputP Assert.IsNotNull(clientProvider); var constructors = clientProvider.Constructors; - Assert.AreEqual(3, constructors.Count); - var primaryPublicConstructor = constructors.FirstOrDefault( - c => c.Signature?.Initializer == null && c.Signature?.Modifiers == MethodSignatureModifiers.Public); - ValidatePrimaryConstructor(primaryPublicConstructor, inputParameters); + var primaryPublicConstructors = constructors.Where( + c => c.Signature?.Initializer == null && c.Signature?.Modifiers == MethodSignatureModifiers.Public).ToArray(); + + // for no auth or one auth case, this should be 1 + // for both auth case, this should be 2 + var expectedPrimaryCtorCount = _hasKeyAuth && _hasOAuth2 ? 2 : 1; + Assert.AreEqual(expectedPrimaryCtorCount, primaryPublicConstructors.Length); + + for (int i = 0; i < primaryPublicConstructors.Length; i++) + { + ValidatePrimaryConstructor(primaryPublicConstructors[i], inputParameters, i); + } } [TestCaseSource(nameof(BuildConstructorsTestCases))] + [TestCaseSource(nameof(BuildConstructorsTestCases), Category = KeyAuthCategory)] + [TestCaseSource(nameof(BuildConstructorsTestCases), Category = OAuth2Category)] + [TestCaseSource(nameof(BuildConstructorsTestCases), Category = $"{KeyAuthCategory},{OAuth2Category}")] public void TestBuildConstructors_SecondaryConstructor(List inputParameters) { var client = InputFactory.Client(TestClientName, parameters: [.. inputParameters]); @@ -173,19 +247,46 @@ public void TestBuildConstructors_SecondaryConstructor(List inpu var constructors = clientProvider.Constructors; - Assert.AreEqual(3, constructors.Count); - var primaryPublicConstructor = constructors.FirstOrDefault( - c => c.Signature?.Initializer == null && c.Signature?.Modifiers == MethodSignatureModifiers.Public); + var primaryPublicConstructors = constructors.Where( + c => c.Signature?.Initializer == null && c.Signature?.Modifiers == MethodSignatureModifiers.Public).ToArray(); + var secondaryPublicConstructors = constructors.Where( + c => c.Signature?.Initializer != null && c.Signature?.Modifiers == MethodSignatureModifiers.Public).ToArray(); + + // for no auth or one auth case, this should be 1 + // for both auth case, this should be 2 + var expectedSecondaryCtorCount = _hasKeyAuth && _hasOAuth2 ? 2 : 1; + Assert.AreEqual(expectedSecondaryCtorCount, secondaryPublicConstructors.Length); + foreach (var secondaryPublicConstructor in secondaryPublicConstructors) + { + ValidateSecondaryConstructor(primaryPublicConstructors, secondaryPublicConstructor, inputParameters); + } + } - Assert.IsNotNull(primaryPublicConstructor); + [TestCase] + public void TestBuildConstructors_ForSubClient_NoAuth() + { + var clientProvider = new ClientProvider(_animalClient); - var secondaryPublicConstructor = constructors.FirstOrDefault( - c => c.Signature?.Initializer != null && c.Signature?.Modifiers == MethodSignatureModifiers.Public); - ValidateSecondaryConstructor(primaryPublicConstructor, secondaryPublicConstructor, inputParameters); + Assert.IsNotNull(clientProvider); + + var constructors = clientProvider.Constructors; + + Assert.AreEqual(2, constructors.Count); + var internalConstructor = constructors.FirstOrDefault( + c => c.Signature?.Modifiers == MethodSignatureModifiers.Internal); + Assert.IsNotNull(internalConstructor); + // in the no auth case, the ctor no longer has the credentail parameter therefore here we expect 2 parameters. + var ctorParams = internalConstructor?.Signature?.Parameters; + Assert.AreEqual(2, ctorParams?.Count); + + var mockingConstructor = constructors.FirstOrDefault( + c => c.Signature?.Modifiers == MethodSignatureModifiers.Protected); + Assert.IsNotNull(mockingConstructor); } - [Test] - public void TestBuildConstructors_ForSubClient() + [TestCase(Category = KeyAuthCategory)] + [TestCase(Category = OAuth2Category)] + public void TestBuildConstructors_ForSubClient_KeyAuthOrOAuth2Auth() { var clientProvider = new ClientProvider(_animalClient); @@ -197,6 +298,7 @@ public void TestBuildConstructors_ForSubClient() var internalConstructor = constructors.FirstOrDefault( c => c.Signature?.Modifiers == MethodSignatureModifiers.Internal); Assert.IsNotNull(internalConstructor); + // when there is only one approach of auth, we have 3 parameters in the ctor. var ctorParams = internalConstructor?.Signature?.Parameters; Assert.AreEqual(3, ctorParams?.Count); @@ -205,22 +307,46 @@ public void TestBuildConstructors_ForSubClient() Assert.IsNotNull(mockingConstructor); } - private static void ValidatePrimaryConstructor( - ConstructorProvider? primaryPublicConstructor, - List inputParameters) + [TestCase(Category = $"{KeyAuthCategory},{OAuth2Category}")] + public void TestBuildConstructors_ForSubClient_BothAuth() { - Assert.IsNotNull(primaryPublicConstructor); + var clientProvider = new ClientProvider(_animalClient); + + Assert.IsNotNull(clientProvider); + + var constructors = clientProvider.Constructors; + + Assert.AreEqual(2, constructors.Count); + var internalConstructor = constructors.FirstOrDefault( + c => c.Signature?.Modifiers == MethodSignatureModifiers.Internal); + Assert.IsNotNull(internalConstructor); + // when we have both auths, we have 4 parameters in the ctor, because now we should have two credential parameters + var ctorParams = internalConstructor?.Signature?.Parameters; + Assert.AreEqual(4, ctorParams?.Count); + var mockingConstructor = constructors.FirstOrDefault( + c => c.Signature?.Modifiers == MethodSignatureModifiers.Protected); + Assert.IsNotNull(mockingConstructor); + } + + private void ValidatePrimaryConstructor( + ConstructorProvider primaryPublicConstructor, + List inputParameters, + int ctorIndex, + [CallerMemberName] string method = "", + [CallerFilePath] string filePath = "") + { var primaryCtorParams = primaryPublicConstructor?.Signature?.Parameters; - var expectedPrimaryCtorParamCount = 3; + // in no auth case, the ctor only have two parameters: endpoint and options + // in other cases, the ctor should have three parameters: endpoint, credential, options + // specifically, in both auth cases, we should have two ctors corresponding to each credential type as the second parameter + var expectedPrimaryCtorParamCount = !_hasKeyAuth && !_hasOAuth2 ? 2 : 3; Assert.AreEqual(expectedPrimaryCtorParamCount, primaryCtorParams?.Count); - // validate the order of the parameters (endpoint, credential, client options) + // the first should be endpoint var endpointParam = primaryCtorParams?[0]; Assert.AreEqual(KnownParameters.Endpoint.Name, endpointParam?.Name); - Assert.AreEqual("keyCredential", primaryCtorParams?[1].Name); - Assert.AreEqual("options", primaryCtorParams?[2].Name); if (endpointParam?.DefaultValue != null) { @@ -229,41 +355,79 @@ private static void ValidatePrimaryConstructor( Assert.AreEqual(Literal(parsedValue), endpointParam?.InitializationValue); } + // the last parameter should be the options + var optionsParam = primaryCtorParams?[^1]; + Assert.AreEqual("options", optionsParam?.Name); + + if (_hasAuth) + { + // when there is any auth, the second should be auth parameter + var authParam = primaryCtorParams?[1]; + Assert.IsNotNull(authParam); + if (authParam?.Name == "keyCredential") + { + Assert.AreEqual(new CSharpType(typeof(ApiKeyCredential)), authParam?.Type); + } + else if (authParam?.Name == "tokenCredential") + { + Assert.AreEqual(new CSharpType(typeof(FakeTokenCredential)), authParam?.Type); + } + else + { + Assert.Fail("Unexpected auth parameter"); + } + } + // validate the body of the primary ctor + var caseName = TestContext.CurrentContext.Test.Properties.Get("caseName"); + var expected = Helpers.GetExpectedFromFile($"{caseName},{_hasKeyAuth},{_hasOAuth2},{ctorIndex}", method, filePath); var primaryCtorBody = primaryPublicConstructor?.BodyStatements; Assert.IsNotNull(primaryCtorBody); + Assert.AreEqual(expected, primaryCtorBody?.ToDisplayString()); } private void ValidateSecondaryConstructor( - ConstructorProvider? primaryConstructor, - ConstructorProvider? secondaryPublicConstructor, + IReadOnlyList primaryConstructors, + ConstructorProvider secondaryPublicConstructor, List inputParameters) { - Assert.IsNotNull(secondaryPublicConstructor); - var ctorParams = secondaryPublicConstructor?.Signature?.Parameters; + var ctorParams = secondaryPublicConstructor.Signature?.Parameters; - // secondary ctor should consist of all required parameters + auth parameter + // secondary ctor should consist of all required parameters + auth parameter (when present) var requiredParams = inputParameters.Where(p => p.IsRequired).ToList(); - Assert.AreEqual(requiredParams.Count + 1, ctorParams?.Count); + var authParameterCount = _hasAuth ? 1 : 0; + Assert.AreEqual(requiredParams.Count + authParameterCount, ctorParams?.Count); var endpointParam = ctorParams?.FirstOrDefault(p => p.Name == KnownParameters.Endpoint.Name); if (requiredParams.Count == 0) { - // auth should be the only parameter if endpoint is optional - Assert.AreEqual("keyCredential", ctorParams?[0].Name); + // auth should be the only parameter if endpoint is optional when there is auth + if (_hasAuth) + { + Assert.IsTrue(ctorParams?[0].Name.EndsWith("Credential")); + } + else + { + // when there is no auth, the ctor should not have parameters + Assert.AreEqual(0, ctorParams?.Count); + } } else { // otherwise, it should only consist of the auth parameter Assert.AreEqual(KnownParameters.Endpoint.Name, ctorParams?[0].Name); - Assert.AreEqual("keyCredential", ctorParams?[1].Name); + if (_hasAuth) + { + Assert.IsTrue(ctorParams?[1].Name.EndsWith("Credential")); + } } Assert.AreEqual(MethodBodyStatement.Empty, secondaryPublicConstructor?.BodyStatements); // validate the initializer var initializer = secondaryPublicConstructor?.Signature?.Initializer; - Assert.AreEqual(primaryConstructor?.Signature?.Parameters?.Count, initializer?.Arguments?.Count); + Assert.NotNull(initializer); + Assert.IsTrue(primaryConstructors.Any(pc => pc.Signature.Parameters.Count == initializer?.Arguments.Count)); } [TestCaseSource(nameof(EndpointParamInitializationValueTestCases))] @@ -309,7 +473,7 @@ public void TestGetClientOptions(bool isSubClient) } } - [TestCaseSource(nameof(SubClientTestCases), Category = SubClientsCategory)] + [TestCaseSource(nameof(SubClientFactoryMethodTestCases), Category = SubClientsCategory)] public void TestSubClientAccessorFactoryMethods(InputClient client, bool hasSubClients) { var clientProvider = new ClientProvider(client); @@ -341,7 +505,6 @@ public void TestSubClientAccessorFactoryMethods(InputClient client, bool hasSubC { Assert.AreEqual(0, subClientAccessorFactoryMethods.Count); } - } [Test] @@ -426,7 +589,6 @@ public void ValidateClientWithSpread(InputClient inputClient) Assert.AreEqual(new CSharpType(typeof(string)), convenienceMethods[0].Signature.Parameters[0].Type); Assert.AreEqual("p1", convenienceMethods[0].Signature.Parameters[0].Name); - } [TestCaseSource(nameof(RequestOptionsParameterInSignatureTestCases))] @@ -597,6 +759,59 @@ protected override MethodProvider[] BuildMethods() protected override PropertyProvider[] BuildProperties() => []; } + public static IEnumerable BuildAuthFieldsTestCases + { + get + { + yield return new TestCaseData(new List + { + InputFactory.Parameter( + "optionalParam", + InputPrimitiveType.String, + location: RequestLocation.None, + kind: InputOperationParameterKind.Client), + InputFactory.Parameter( + KnownParameters.Endpoint.Name, + InputPrimitiveType.String, + location:RequestLocation.None, + kind: InputOperationParameterKind.Client, + isEndpoint: true) + }); + yield return new TestCaseData(new List + { + // have to explicitly set isRequired because we now call CreateParameter in buildFields + InputFactory.Parameter( + "optionalNullableParam", + InputPrimitiveType.String, + location: RequestLocation.None, + defaultValue: InputFactory.Constant.String("someValue"), + kind: InputOperationParameterKind.Client, + isRequired: false), + InputFactory.Parameter( + "requiredParam2", + InputPrimitiveType.String, + location: RequestLocation.None, + defaultValue: InputFactory.Constant.String("someValue"), + kind: InputOperationParameterKind.Client, + isRequired: true), + InputFactory.Parameter( + "requiredParam3", + InputPrimitiveType.Int64, + location: RequestLocation.None, + defaultValue: InputFactory.Constant.Int64(2), + kind: InputOperationParameterKind.Client, + isRequired: true), + InputFactory.Parameter( + KnownParameters.Endpoint.Name, + InputPrimitiveType.String, + location: RequestLocation.None, + defaultValue: null, + kind: InputOperationParameterKind.Client, + isEndpoint: true) + }); + } + } + public static IEnumerable BuildFieldsTestCases { get @@ -614,7 +829,13 @@ public static IEnumerable BuildFieldsTestCases location:RequestLocation.None, kind: InputOperationParameterKind.Client, isEndpoint: true) - }, false); + }, + new List + { + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"), + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(string), true), "_optionalParam") + } + ); yield return new TestCaseData(new List { // have to explicitly set isRequired because we now call CreateParameter in buildFields @@ -646,18 +867,51 @@ public static IEnumerable BuildFieldsTestCases defaultValue: null, kind: InputOperationParameterKind.Client, isEndpoint: true) - }, true); + }, + new List + { + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"), + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(string), true), "_optionalNullableParam"), + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(string), false), "_requiredParam2"), + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(long), false), "_requiredParam3") + }); } } - public static IEnumerable SubClientTestCases + public static IEnumerable SubClientAuthFieldsTestCases { get { - yield return new TestCaseData(InputFactory.Client(TestClientName), true); - yield return new TestCaseData(_animalClient, true); - yield return new TestCaseData(_dogClient, true); - yield return new TestCaseData(_huskyClient, false); + yield return new TestCaseData(InputFactory.Client(TestClientName)); + yield return new TestCaseData(_animalClient); + yield return new TestCaseData(_dogClient); + yield return new TestCaseData(_huskyClient); + } + } + + public static IEnumerable SubClientFieldsTestCases + { + get + { + yield return new TestCaseData(InputFactory.Client(TestClientName), new List + { + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"), + new(FieldModifiers.Private, new ExpectedCSharpType("Animal", "Sample", false), "_cachedAnimal"), + }); + yield return new TestCaseData(_animalClient, new List + { + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"), + new(FieldModifiers.Private, new ExpectedCSharpType("Dog", "Sample", false), "_cachedDog"), + }); + yield return new TestCaseData(_dogClient, new List + { + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"), + new(FieldModifiers.Private, new ExpectedCSharpType("Husky", "Sample", false), "_cachedHusky"), + }); + yield return new TestCaseData(_huskyClient, new List + { + new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint") + }); } } @@ -684,6 +938,17 @@ public static IEnumerable ValidateClientWithSpreadTestCases } } + public static IEnumerable SubClientFactoryMethodTestCases + { + get + { + yield return new TestCaseData(InputFactory.Client(TestClientName), true); + yield return new TestCaseData(_animalClient, true); + yield return new TestCaseData(_dogClient, true); + yield return new TestCaseData(_huskyClient, false); + } + } + public static IEnumerable BuildConstructorsTestCases { get @@ -702,7 +967,7 @@ public static IEnumerable BuildConstructorsTestCases defaultValue: InputFactory.Constant.String("someValue"), kind: InputOperationParameterKind.Client, isEndpoint: true) - }); + }).SetProperty("caseName", "WithDefault"); // scenario where endpoint is required yield return new TestCaseData(new List { @@ -718,7 +983,7 @@ public static IEnumerable BuildConstructorsTestCases InputPrimitiveType.String, location: RequestLocation.None, kind: InputOperationParameterKind.Client) - }); + }).SetProperty("caseName", "WithRequired"); } } @@ -841,83 +1106,199 @@ public static IEnumerable RequestOptionsParameterInSignatureTestCa } } - private static IEnumerable EndpointParamInitializationValueTestCases() + private static IEnumerable EndpointParamInitializationValueTestCases + { + get + { + // string primitive type + yield return new TestCaseData( + InputFactory.Parameter( + "param", + InputPrimitiveType.String, + location: RequestLocation.None, + kind: InputOperationParameterKind.Client, + isEndpoint: true, + defaultValue: InputFactory.Constant.String("mockValue")), + New.Instance(KnownParameters.Endpoint.Type, Literal("mockvalue"))); + } + } + + private static IEnumerable ValidateApiVersionPathParameterTestCases { - // string primitive type - yield return new TestCaseData( - InputFactory.Parameter( - "param", + get + { + InputParameter endpointParameter = InputFactory.Parameter( + "endpoint", InputPrimitiveType.String, - location: RequestLocation.None, + location: RequestLocation.Uri, + isRequired: true, kind: InputOperationParameterKind.Client, isEndpoint: true, - defaultValue: InputFactory.Constant.String("mockValue")), - New.Instance(KnownParameters.Endpoint.Type, Literal("mockvalue"))); - } - - private static IEnumerable ValidateApiVersionPathParameterTestCases() - { - InputParameter endpointParameter = InputFactory.Parameter( - "endpoint", - InputPrimitiveType.String, - location: RequestLocation.Uri, - isRequired: true, - kind: InputOperationParameterKind.Client, - isEndpoint: true, - isApiVersion: false); - - InputParameter stringApiVersionParameter = InputFactory.Parameter( - "apiVersion", - InputPrimitiveType.String, - location: RequestLocation.Uri, - isRequired: true, - kind: InputOperationParameterKind.Client, - isApiVersion: true); - - InputParameter enumApiVersionParameter = InputFactory.Parameter( - "apiVersion", - InputFactory.Enum( - "InputEnum", + isApiVersion: false); + + InputParameter stringApiVersionParameter = InputFactory.Parameter( + "apiVersion", InputPrimitiveType.String, - usage: InputModelTypeUsage.Input, - isExtensible: true, - values: - [ - InputFactory.EnumMember.String("value1", "value1"), + location: RequestLocation.Uri, + isRequired: true, + kind: InputOperationParameterKind.Client, + isApiVersion: true); + + InputParameter enumApiVersionParameter = InputFactory.Parameter( + "apiVersion", + InputFactory.Enum( + "InputEnum", + InputPrimitiveType.String, + usage: InputModelTypeUsage.Input, + isExtensible: true, + values: + [ + InputFactory.EnumMember.String("value1", "value1"), InputFactory.EnumMember.String("value2", "value2") - ]), - location: RequestLocation.Uri, - isRequired: true, - kind: InputOperationParameterKind.Client, - isApiVersion: true); - - yield return new TestCaseData( - InputFactory.Client( - "TestClient", - operations: - [ - InputFactory.Operation( + ]), + location: RequestLocation.Uri, + isRequired: true, + kind: InputOperationParameterKind.Client, + isApiVersion: true); + + yield return new TestCaseData( + InputFactory.Client( + "TestClient", + operations: + [ + InputFactory.Operation( "TestOperation", uri: "{endpoint}/{apiVersion}") - ], - parameters: [ - endpointParameter, + ], + parameters: [ + endpointParameter, stringApiVersionParameter - ])); + ])); - yield return new TestCaseData( - InputFactory.Client( - "TestClient", - operations: - [ - InputFactory.Operation( + yield return new TestCaseData( + InputFactory.Client( + "TestClient", + operations: + [ + InputFactory.Operation( "TestOperation", uri: "{endpoint}/{apiVersion}") - ], - parameters: [ - endpointParameter, + ], + parameters: [ + endpointParameter, enumApiVersionParameter - ])); + ])); + } + } + + // TODO -- this is temporary here before System.ClientModel officially supports OAuth2 auth + private record TestClientPipelineApi : ClientPipelineProvider + { + private static ClientPipelineApi? _instance; + internal new static ClientPipelineApi Instance => _instance ??= new TestClientPipelineApi(Empty); + + public TestClientPipelineApi(ValueExpression original) : base(original) + { + } + + public override CSharpType TokenCredentialType => typeof(FakeTokenCredential); + + public override ClientPipelineApi FromExpression(ValueExpression expression) + => new TestClientPipelineApi(expression); + + public override ValueExpression TokenAuthorizationPolicy(ValueExpression credential, ValueExpression scopes) + => Original.Invoke("GetFakeTokenAuthorizationPolicy", [credential, scopes]); + + public override ClientPipelineApi ToExpression() => this; + } + + internal class FakeTokenCredential { } + + public record ExpectedCSharpType + { + public string Name { get; } + + public string Namespace { get; } + + public bool IsFrameworkType { get; } + + public Type FrameworkType => _frameworkType ?? throw new InvalidOperationException(); + + public bool IsNullable { get; } + + private readonly Type? _frameworkType; + + public ExpectedCSharpType(Type frameworkType, bool isNullable) + { + _frameworkType = frameworkType; + IsFrameworkType = true; + IsNullable = isNullable; + Name = frameworkType.Name; + Namespace = frameworkType.Namespace!; + } + + public ExpectedCSharpType(string name, string ns, bool isNullable) + { + IsFrameworkType = false; + IsNullable = isNullable; + Name = name; + Namespace = ns; + } + + public static implicit operator ExpectedCSharpType(CSharpType type) + { + if (type.IsFrameworkType) + { + return new(type.FrameworkType, type.IsNullable); + } + else + { + return new(type.Name, type.Namespace, type.IsNullable); + } + } + } + + public record ExpectedFieldProvider(FieldModifiers Modifiers, ExpectedCSharpType Type, string Name); + + private static void AssertCSharpTypeAreEqual(ExpectedCSharpType expected, CSharpType type) + { + if (expected.IsFrameworkType) + { + Assert.IsTrue(type.IsFrameworkType); + Assert.AreEqual(expected.FrameworkType, type.FrameworkType); + } + else + { + Assert.IsFalse(type.IsFrameworkType); + Assert.AreEqual(expected.Name, type.Name); + Assert.AreEqual(expected.Namespace, type.Namespace); + } + Assert.AreEqual(expected.IsNullable, type.IsNullable); + } + + private static void AssertFieldAreEqual(ExpectedFieldProvider expected, FieldProvider field) + { + Assert.AreEqual(expected.Name, field.Name); + AssertCSharpTypeAreEqual(expected.Type, field.Type); + Assert.AreEqual(expected.Modifiers, field.Modifiers); + } + + private static void AssertHasFields(TypeProvider provider, IReadOnlyList expectedFields) + { + var fields = provider.Fields; + + // validate the length of the result + Assert.GreaterOrEqual(fields.Count, expectedFields.Count); + + // validate each of them + var fieldDict = fields.ToDictionary(f => f.Name); + for (int i = 0; i < expectedFields.Count; i++) + { + var expected = expectedFields[i]; + + Assert.IsTrue(fieldDict.TryGetValue(expected.Name, out var actual), $"Field {expected.Name} not present"); + AssertFieldAreEqual(expected, actual!); + } } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,False,False,0).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,False,False,0).cs new file mode 100644 index 0000000000..3fcf6856cf --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,False,False,0).cs @@ -0,0 +1,6 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), Array.Empty(), Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,False,True,0).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,False,True,0).cs new file mode 100644 index 0000000000..71d1e06f4d --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,False,True,0).cs @@ -0,0 +1,8 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); +global::Sample.Argument.AssertNotNull(tokenCredential, nameof(tokenCredential)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +_tokenCredential = tokenCredential; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), new global::System.ClientModel.Primitives.PipelinePolicy[] { this.GetFakeTokenAuthorizationPolicy(_tokenCredential, AuthorizationScopes) }, Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,False,0).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,False,0).cs new file mode 100644 index 0000000000..227fc6726b --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,False,0).cs @@ -0,0 +1,8 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); +global::Sample.Argument.AssertNotNull(keyCredential, nameof(keyCredential)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +_keyCredential = keyCredential; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), new global::System.ClientModel.Primitives.PipelinePolicy[] { global::System.ClientModel.Primitives.ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy(_keyCredential, AuthorizationHeader) }, Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,True,0).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,True,0).cs new file mode 100644 index 0000000000..227fc6726b --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,True,0).cs @@ -0,0 +1,8 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); +global::Sample.Argument.AssertNotNull(keyCredential, nameof(keyCredential)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +_keyCredential = keyCredential; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), new global::System.ClientModel.Primitives.PipelinePolicy[] { global::System.ClientModel.Primitives.ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy(_keyCredential, AuthorizationHeader) }, Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,True,1).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,True,1).cs new file mode 100644 index 0000000000..71d1e06f4d --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithDefault,True,True,1).cs @@ -0,0 +1,8 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); +global::Sample.Argument.AssertNotNull(tokenCredential, nameof(tokenCredential)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +_tokenCredential = tokenCredential; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), new global::System.ClientModel.Primitives.PipelinePolicy[] { this.GetFakeTokenAuthorizationPolicy(_tokenCredential, AuthorizationScopes) }, Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,False,False,0).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,False,False,0).cs new file mode 100644 index 0000000000..3fcf6856cf --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,False,False,0).cs @@ -0,0 +1,6 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), Array.Empty(), Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,False,True,0).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,False,True,0).cs new file mode 100644 index 0000000000..71d1e06f4d --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,False,True,0).cs @@ -0,0 +1,8 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); +global::Sample.Argument.AssertNotNull(tokenCredential, nameof(tokenCredential)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +_tokenCredential = tokenCredential; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), new global::System.ClientModel.Primitives.PipelinePolicy[] { this.GetFakeTokenAuthorizationPolicy(_tokenCredential, AuthorizationScopes) }, Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,False,0).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,False,0).cs new file mode 100644 index 0000000000..227fc6726b --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,False,0).cs @@ -0,0 +1,8 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); +global::Sample.Argument.AssertNotNull(keyCredential, nameof(keyCredential)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +_keyCredential = keyCredential; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), new global::System.ClientModel.Primitives.PipelinePolicy[] { global::System.ClientModel.Primitives.ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy(_keyCredential, AuthorizationHeader) }, Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,True,0).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,True,0).cs new file mode 100644 index 0000000000..227fc6726b --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,True,0).cs @@ -0,0 +1,8 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); +global::Sample.Argument.AssertNotNull(keyCredential, nameof(keyCredential)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +_keyCredential = keyCredential; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), new global::System.ClientModel.Primitives.PipelinePolicy[] { global::System.ClientModel.Primitives.ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy(_keyCredential, AuthorizationHeader) }, Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,True,1).cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,True,1).cs new file mode 100644 index 0000000000..71d1e06f4d --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderTests/TestBuildConstructors_PrimaryConstructor(WithRequired,True,True,1).cs @@ -0,0 +1,8 @@ +global::Sample.Argument.AssertNotNull(endpoint, nameof(endpoint)); +global::Sample.Argument.AssertNotNull(tokenCredential, nameof(tokenCredential)); + +options ??= new global::Sample.TestClientOptions(); + +_endpoint = endpoint; +_tokenCredential = tokenCredential; +Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty(), new global::System.ClientModel.Primitives.PipelinePolicy[] { this.GetFakeTokenAuthorizationPolicy(_tokenCredential, AuthorizationScopes) }, Array.Empty()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/TestHelpers/MockHelpers.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/TestHelpers/MockHelpers.cs index dc80039dca..d699ff9921 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/TestHelpers/MockHelpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/TestHelpers/MockHelpers.cs @@ -47,9 +47,9 @@ public static Mock LoadMockPlugin( Func>? createSerializationsCore = null, Func? createCSharpTypeCore = null, Func? matchConditionsType = null, - Func? keyCredentialType = null, Func? createParameterCore = null, Func? apiKeyAuth = null, + Func? oauth2Auth = null, Func>? apiVersions = null, Func>? inputEnums = null, Func>? inputModels = null, @@ -65,7 +65,7 @@ public static Mock LoadMockPlugin( IReadOnlyList inputNsEnums = inputEnums?.Invoke() ?? []; IReadOnlyList inputNsClients = clients?.Invoke() ?? []; IReadOnlyList inputNsModels = inputModels?.Invoke() ?? []; - InputAuth inputNsAuth = apiKeyAuth != null ? new InputAuth(apiKeyAuth(), null) : new InputAuth(); + InputAuth inputNsAuth = new InputAuth(apiKeyAuth?.Invoke(), oauth2Auth?.Invoke()); var mockTypeFactory = new Mock() { CallBase = true }; var mockInputNs = new Mock( string.Empty, @@ -82,11 +82,6 @@ public static Mock LoadMockPlugin( mockTypeFactory.Setup(p => p.MatchConditionsType).Returns(matchConditionsType); } - if (keyCredentialType is not null) - { - mockTypeFactory.Setup(p => p.KeyCredentialType).Returns(keyCredentialType); - } - if (createParameterCore is not null) { mockTypeFactory.Protected().Setup("CreateParameterCore", ItExpr.IsAny()).Returns(createParameterCore); diff --git a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecClient.cs b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecClient.cs index 278862d6b0..0173263461 100644 --- a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecClient.cs +++ b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecClient.cs @@ -17,9 +17,9 @@ namespace UnbrandedTypeSpec public partial class UnbrandedTypeSpecClient { private readonly Uri _endpoint; - private const string AuthorizationHeader = "my-api-key"; /// A credential used to authenticate to the service. private readonly ApiKeyCredential _keyCredential; + private const string AuthorizationHeader = "my-api-key"; private readonly string _apiVersion; /// Initializes a new instance of UnbrandedTypeSpecClient for mocking.