Skip to content

Commit

Permalink
refactor code and finish test cases for token credential
Browse files Browse the repository at this point in the history
  • Loading branch information
ArcturusZhang committed Dec 3, 2024
1 parent 8f62878 commit 2136cec
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class ClientProvider : TypeProvider
private readonly FieldProvider? _tokenCredentialScopesField;
#endregion
private FieldProvider? _apiVersionField;
private readonly List<ParameterProvider> _subClientInternalConstructorParams;
private readonly Lazy<IReadOnlyList<ParameterProvider>> _subClientInternalConstructorParams;
private IReadOnlyList<Lazy<ClientProvider>>? _subClients;
private ParameterProvider? _clientOptionsParameter;
private ClientOptionsProvider? _clientOptions;
Expand Down Expand Up @@ -130,10 +130,6 @@ public ClientProvider(InputClient inputClient)
body: new AutoPropertyBody(false),
enclosingType: this);

_subClientInternalConstructorParams = _apiKeyAuthField != null
? [PipelineProperty.AsParameter, _apiKeyAuthField.AsParameter, _endpointParameter]
: [PipelineProperty.AsParameter, _endpointParameter];

if (_inputClient.Parent != null)
{
// _clientCachingField will only have subClients (children)
Expand All @@ -146,23 +142,43 @@ public ClientProvider(InputClient inputClient)
}

_endpointParameterName = new(GetEndpointParameterName);
_additionalClientFields = new Lazy<List<FieldProvider>>(() => BuildAdditionalClientFields());
_additionalClientFields = new(BuildAdditionalClientFields);
_allClientParameters = _inputClient.Parameters.Concat(_inputClient.Operations.SelectMany(op => op.Parameters).Where(p => p.Kind == InputOperationParameterKind.Client)).DistinctBy(p => p.Name).ToArray();
_subClientInternalConstructorParams = new(GetSubClientInternalConstructorParameters);
_clientParameters = new(GetClientParameters);
}

foreach (var field in _additionalClientFields.Value)
private IReadOnlyList<ParameterProvider> GetSubClientInternalConstructorParameters()
{
var subClientParameters = new List<ParameterProvider>
{
PipelineProperty.AsParameter
};

if (_apiKeyAuthField != null)
{
subClientParameters.Add(_apiKeyAuthField.AsParameter);
}
if (_tokenCredentialField != null)
{
_subClientInternalConstructorParams.Add(field.AsParameter);
subClientParameters.Add(_tokenCredentialField.AsParameter);
}
subClientParameters.Add(_endpointParameter);
subClientParameters.AddRange(ClientParameters);

return subClientParameters;
}

private List<ParameterProvider>? _clientParameters;
internal IReadOnlyList<ParameterProvider> GetClientParameters()
private Lazy<IReadOnlyList<ParameterProvider>> _clientParameters;
internal IReadOnlyList<ParameterProvider> ClientParameters => _clientParameters.Value;
private IReadOnlyList<ParameterProvider> GetClientParameters()
{
if (_clientParameters is null)
var parameters = new List<ParameterProvider>(_additionalClientFields.Value.Count);
foreach (var field in _additionalClientFields.Value)
{
_ = Constructors;
parameters.Add(field.AsParameter);
}
return _clientParameters ?? [];
return parameters;
}

private Lazy<string?> _endpointParameterName;
Expand Down Expand Up @@ -272,20 +288,19 @@ protected override ConstructorProvider[] BuildConstructors()
// handle sub-client constructors
if (ClientOptionsParameter is null)
{
_clientParameters = _subClientInternalConstructorParams;
List<MethodBodyStatement> body = new(3) { EndpointField.Assign(_endpointParameter).Terminate() };
foreach (var p in _subClientInternalConstructorParams)
{
var assignment = p.Field?.Assign(p).Terminate() ?? p.Property?.Assign(p).Terminate();
if (assignment != null)
List<MethodBodyStatement> body = new(3) { EndpointField.Assign(_endpointParameter).Terminate() };
foreach (var p in _subClientInternalConstructorParams.Value)
{
body.Add(assignment);
var assignment = p.Field?.Assign(p).Terminate() ?? p.Property?.Assign(p).Terminate();
if (assignment != null)
{
body.Add(assignment);
}
}
}
var subClientConstructor = new ConstructorProvider(
new ConstructorSignature(Type, _publicCtorDescription, MethodSignatureModifiers.Internal, _subClientInternalConstructorParams),
body,
this);
var subClientConstructor = new ConstructorProvider(
new ConstructorSignature(Type, _publicCtorDescription, MethodSignatureModifiers.Internal, _subClientInternalConstructorParams.Value),
body,
this);

return [mockingConstructor, subClientConstructor];
}
Expand Down Expand Up @@ -338,7 +353,6 @@ void AppendConstructors(FieldProvider? authField, List<ConstructorProvider> prim
private IReadOnlyList<ParameterProvider> GetRequiredParameters(FieldProvider? authField)
{
List<ParameterProvider> requiredParameters = [];
_clientParameters = [];

ParameterProvider? currentParam = null;
foreach (var parameter in _allClientParameters)
Expand All @@ -349,7 +363,6 @@ private IReadOnlyList<ParameterProvider> GetRequiredParameters(FieldProvider? au
currentParam = CreateParameter(parameter);
requiredParameters.Add(currentParam);
}
_clientParameters.Add(currentParam ?? CreateParameter(parameter));
}

if (authField is not null)
Expand Down Expand Up @@ -482,7 +495,7 @@ protected override MethodProvider[] BuildMethods()
List<ValueExpression> subClientConstructorArgs = new(3);

// Populate constructor arguments
foreach (var param in subClientInstance._subClientInternalConstructorParams)
foreach (var param in subClientInstance._subClientInternalConstructorParams.Value)
{
if (parentClientProperties.TryGetValue(param.Name, out var parentProperty))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ private MethodProvider BuildCreateRequestMethod(InputOperation operation)
[.. parameters, options]);
var paramMap = new Dictionary<string, ParameterProvider>(signature.Parameters.ToDictionary(p => p.Name));

foreach (var param in ClientProvider.GetClientParameters())
foreach (var param in ClientProvider.ClientParameters)
{
paramMap[param.Name] = param;
}
Expand Down Expand Up @@ -386,7 +386,7 @@ private void AddUriSegments(
/* when the parameter is in operation.uri, it is client parameter
* It is not operation parameter and not in inputParamHash list.
*/
var isClientParameter = ClientProvider.GetClientParameters().Any(p => p.Name == paramName);
var isClientParameter = ClientProvider.ClientParameters.Any(p => p.Name == paramName);
CSharpType? type;
string? format;
ValueExpression valueExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,34 +38,17 @@ public void SetUp()
{
MockHelpers.LoadMockPlugin(
oauth2Auth: () => new InputOAuth2Auth(["mock"]),
clients: () => [_animalClient, _dogClient, _huskyClient]);
clients: () => [_animalClient, _dogClient, _huskyClient],
clientPipelineApi: TestClientPipelineApi.Instance);
}
else
{
MockHelpers.LoadMockPlugin(oauth2Auth: () => new InputOAuth2Auth(["mock"]));
MockHelpers.LoadMockPlugin(
oauth2Auth: () => new InputOAuth2Auth(["mock"]),
clientPipelineApi: TestClientPipelineApi.Instance);
}
}

[Test]
public void TestBuildProperties()
{
var client = InputFactory.Client(TestClientName);
var clientProvider = new ClientProvider(client);

Assert.IsNotNull(clientProvider);

// validate the properties
var properties = clientProvider.Properties;
Assert.IsTrue(properties.Count > 0);
// there should be a pipeline property
Assert.AreEqual(1, properties.Count);

var pipelineProperty = properties[0];
Assert.AreEqual(typeof(ClientPipeline), pipelineProperty.Type.FrameworkType);
Assert.AreEqual("Pipeline", pipelineProperty.Name);
Assert.AreEqual(MethodSignatureModifiers.Public, pipelineProperty.Modifiers);
}

[TestCaseSource(nameof(BuildFieldsTestCases))]
public void TestBuildFields(List<InputParameter> inputParameters, List<ExpectedFieldProvider> expectedFields)
{
Expand Down Expand Up @@ -160,7 +143,7 @@ private static void ValidatePrimaryConstructor(
// validate the order of the parameters (endpoint, credential, client options)
var endpointParam = primaryCtorParams?[0];
Assert.AreEqual(KnownParameters.Endpoint.Name, endpointParam?.Name);
Assert.AreEqual("keyCredential", primaryCtorParams?[1].Name);
Assert.AreEqual("tokenCredential", primaryCtorParams?[1].Name);
Assert.AreEqual("options", primaryCtorParams?[2].Name);

if (endpointParam?.DefaultValue != null)
Expand Down Expand Up @@ -191,13 +174,13 @@ private void ValidateSecondaryConstructor(
if (requiredParams.Count == 0)
{
// auth should be the only parameter if endpoint is optional
Assert.AreEqual("keyCredential", ctorParams?[0].Name);
Assert.AreEqual("tokenCredential", ctorParams?[0].Name);
}
else
{
// otherwise, it should only consist of the auth parameter
Assert.AreEqual(KnownParameters.Endpoint.Name, ctorParams?[0].Name);
Assert.AreEqual("keyCredential", ctorParams?[1].Name);
Assert.AreEqual("tokenCredential", ctorParams?[1].Name);
}

Assert.AreEqual(MethodBodyStatement.Empty, secondaryPublicConstructor?.BodyStatements);
Expand Down Expand Up @@ -282,8 +265,8 @@ public static IEnumerable<TestCaseData> BuildFieldsTestCases
new List<ExpectedFieldProvider>
{
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string)), "AuthorizationHeader"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(ApiKeyCredential)), "_keyCredential"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string[])), "AuthorizationScopes"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(TestTokenCredential)), "_tokenCredential"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(string), true), "_optionalParam")
}
);
Expand Down Expand Up @@ -322,8 +305,8 @@ public static IEnumerable<TestCaseData> BuildFieldsTestCases
new List<ExpectedFieldProvider>
{
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string)), "AuthorizationHeader"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(ApiKeyCredential)), "_keyCredential"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string[])), "AuthorizationScopes"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(TestTokenCredential)), "_tokenCredential"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(string), true), "_optionalNullableParam"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(string), false), "_requiredParam2"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(long), false), "_requiredParam3")
Expand All @@ -338,29 +321,29 @@ public static IEnumerable<TestCaseData> SubClientFieldsTestCases
yield return new TestCaseData(InputFactory.Client(TestClientName), new List<ExpectedFieldProvider>
{
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string)), "AuthorizationHeader"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(ApiKeyCredential)), "_keyCredential"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string[])), "AuthorizationScopes"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(TestTokenCredential)), "_tokenCredential"),
new(FieldModifiers.Private, new ExpectedCSharpType("Animal", "Sample", true), "_cachedAnimal"),
});
yield return new TestCaseData(_animalClient, new List<ExpectedFieldProvider>
{
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string)), "AuthorizationHeader"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(ApiKeyCredential)), "_keyCredential"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string[])), "AuthorizationScopes"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(TestTokenCredential)), "_tokenCredential"),
new(FieldModifiers.Private, new ExpectedCSharpType("Dog", "Sample", true), "_cachedDog"),
});
yield return new TestCaseData(_dogClient, new List<ExpectedFieldProvider>
{
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string)), "AuthorizationHeader"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(ApiKeyCredential)), "_keyCredential"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string[])), "AuthorizationScopes"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(TestTokenCredential)), "_tokenCredential"),
new(FieldModifiers.Private, new ExpectedCSharpType("Husky", "Sample", true), "_cachedHusky"),
});
yield return new TestCaseData(_huskyClient, new List<ExpectedFieldProvider>
{
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(Uri)), "_endpoint"),
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string)), "AuthorizationHeader"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(ApiKeyCredential)), "_keyCredential")
new(FieldModifiers.Private | FieldModifiers.Const, new CSharpType(typeof(string[])), "AuthorizationScopes"),
new(FieldModifiers.Private | FieldModifiers.ReadOnly, new CSharpType(typeof(TestTokenCredential)), "_tokenCredential")
});
}
}
Expand Down Expand Up @@ -427,5 +410,50 @@ private static IEnumerable<TestCaseData> EndpointParamInitializationValueTestCas
defaultValue: InputFactory.Constant.String("mockValue")),
New.Instance(KnownParameters.Endpoint.Type, Literal("mockvalue")));
}

private record TestClientPipelineApi : ClientPipelineApi
{
private static ClientPipelineApi? _instance;
internal static ClientPipelineApi Instance => _instance ??= new TestClientPipelineApi(Empty);

public TestClientPipelineApi(ValueExpression original) : base(typeof(string), original)
{
}

public override CSharpType ClientPipelineType => typeof(string);

public override CSharpType ClientPipelineOptionsType => typeof(string);

public override CSharpType PipelinePolicyType => typeof(string);

public override CSharpType? KeyCredentialType => null;

public override CSharpType TokenCredentialType => typeof(TestTokenCredential);

public override ValueExpression Create(ValueExpression options, ValueExpression perRetryPolicies)
=> Original.Invoke("GetFakeCreate", [options, perRetryPolicies]);

public override ValueExpression CreateMessage(HttpRequestOptionsApi requestOptions, ValueExpression responseClassifier)
=> Original.Invoke("GetFakeCreateMessage", [requestOptions, responseClassifier]);

public override ClientPipelineApi FromExpression(ValueExpression expression)
=> new TestClientPipelineApi(expression);

public override ValueExpression ConsumeKeyAuth(ValueExpression credential, ValueExpression headerName, ValueExpression? keyPrefix = null)
=> throw new InvalidOperationException("ApiKey is not supported in this test");

public override ValueExpression ConsumeOAuth2Auth(ValueExpression credential, ValueExpression scopes)
=> Original.Invoke("GetFakeTokenAuthorizationPolicy", [credential, scopes]);

public override ClientPipelineApi ToExpression() => this;

public override MethodBodyStatement[] ProcessMessage(HttpMessageApi message, HttpRequestOptionsApi options)
=> [Original.Invoke("GetFakeProcessMessage", [message, options]).Terminate()];

public override MethodBodyStatement[] ProcessMessageAsync(HttpMessageApi message, HttpRequestOptionsApi options)
=> [Original.Invoke("GetFakeProcessMessageAsync", [message, options]).Terminate()];
}

internal class TestTokenCredential { }
}
}

0 comments on commit 2136cec

Please sign in to comment.