Skip to content

Commit

Permalink
Handle repeatability header params in client (#4203)
Browse files Browse the repository at this point in the history
Fixes: #3936
  • Loading branch information
jorgerangel-msft authored Aug 19, 2024
1 parent da540e1 commit 261fa52
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,15 @@ public static ParameterProvider ClientOptions(CSharpType clientOptionsType)
public static readonly ParameterProvider MatchConditionsParameter = new("matchConditions", $"The content to send as the request conditions of the request.", ClientModelPlugin.Instance.TypeFactory.MatchConditionsType(), DefaultOf(ClientModelPlugin.Instance.TypeFactory.MatchConditionsType()));
public static readonly ParameterProvider RequestOptions = new("options", $"The request options, which can override default behaviors of the client pipeline on a per-call basis.", typeof(RequestOptions));
public static readonly ParameterProvider BinaryContent = new("content", $"The content to send as the body of the request.", typeof(BinaryContent)) { Validation = ParameterValidationType.AssertNotNull };

// Known header parameters
public static readonly ParameterProvider RepeatabilityRequestId = new("repeatabilityRequestId", FormattableStringHelpers.Empty, typeof(Guid))
{
DefaultValue = Static(typeof(Guid)).Invoke(nameof(Guid.NewGuid)).Invoke(nameof(string.ToString))
};
public static readonly ParameterProvider RepeatabilityFirstSent = new("repeatabilityFirstSent", FormattableStringHelpers.Empty, typeof(DateTimeOffset))
{
DefaultValue = Static(typeof(DateTimeOffset)).Property(nameof(DateTimeOffset.Now))
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Net.Http;
Expand All @@ -21,6 +22,13 @@ namespace Microsoft.Generator.CSharp.ClientModel.Providers
{
public class RestClientProvider : TypeProvider
{
private const string RepeatabilityRequestIdHeader = "Repeatability-Request-ID";
private const string RepeatabilityFirstSentHeader = "Repeatability-First-Sent";
private static readonly Dictionary<string, ParameterProvider> _knownSpecialHeaderParams = new(StringComparer.OrdinalIgnoreCase)
{
{ RepeatabilityRequestIdHeader, ScmKnownParameters.RepeatabilityRequestId },
{ RepeatabilityFirstSentHeader, ScmKnownParameters.RepeatabilityFirstSent }
};
private Dictionary<InputOperation, MethodProvider>? _methodCache;
private Dictionary<InputOperation, MethodProvider> MethodCache => _methodCache ??= [];

Expand Down Expand Up @@ -186,8 +194,6 @@ private PropertyProvider GetClassifier(InputOperation operation)

private IEnumerable<MethodBodyStatement> AppendHeaderParameters(ScopedApi<PipelineRequest> request, InputOperation operation, Dictionary<string, ParameterProvider> paramMap)
{
//TODO handle special headers like Repeatability-First-Sent which shouldn't be params but sent as DateTimeOffset.Now.ToString("R")
//https://github.com/microsoft/typespec/issues/3936
List<MethodBodyStatement> statements = new(operation.Parameters.Count);

foreach (var inputParameter in operation.Parameters)
Expand Down Expand Up @@ -328,6 +334,11 @@ private static void GetParamInfo(Dictionary<string, ParameterProvider> paramMap,
valueExpression = Literal((inputParam.Type as InputLiteralType)?.Value);
format = ClientModelPlugin.Instance.TypeFactory.GetSerializationFormat(inputParam.Type).ToFormatSpecifier();
}
else if (TryGetSpecialHeaderParam(inputParam, out var parameterProvider))
{
valueExpression = parameterProvider.DefaultValue!;
format = ClientModelPlugin.Instance.TypeFactory.GetSerializationFormat(inputParam.Type).ToFormatSpecifier();
}
else
{
var paramProvider = paramMap[inputParam.Name];
Expand All @@ -345,6 +356,17 @@ private static void GetParamInfo(Dictionary<string, ParameterProvider> paramMap,
}
}

private static bool TryGetSpecialHeaderParam(InputParameter inputParameter, [NotNullWhen(true)] out ParameterProvider? parameterProvider)
{
if (inputParameter.Location == RequestLocation.Header)
{
return _knownSpecialHeaderParams.TryGetValue(inputParameter.NameInRequest, out parameterProvider);
}

parameterProvider = null;
return false;
}

internal MethodProvider GetCreateRequestMethod(InputOperation operation)
{
_ = Methods; // Ensure methods are built
Expand All @@ -356,7 +378,7 @@ internal static List<ParameterProvider> GetMethodParameters(InputOperation opera
List<ParameterProvider> methodParameters = new();
foreach (InputParameter inputParam in operation.Parameters)
{
if (inputParam.Kind != InputOperationParameterKind.Method)
if (inputParam.Kind != InputOperationParameterKind.Method || TryGetSpecialHeaderParam(inputParam, out var _))
continue;

ParameterProvider? parameter = ClientModelPlugin.Instance.TypeFactory.CreateParameter(inputParam);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Generator.CSharp.ClientModel.Providers;
Expand All @@ -10,7 +11,7 @@
using Microsoft.Generator.CSharp.Tests.Common;
using NUnit.Framework;

namespace Microsoft.Generator.CSharp.ClientModel.Tests.Providers
namespace Microsoft.Generator.CSharp.ClientModel.Tests.Providers.ClientProviders
{
public class RestClientProviderTests
{
Expand All @@ -35,7 +36,15 @@ public void TestRestClientMethods(InputOperation inputOperation)

var parameters = signature.Parameters;
Assert.IsNotNull(parameters);
Assert.AreEqual(inputOperation.Parameters.Count + 1, parameters.Count);
var specialHeaderParamCount = inputOperation.Parameters.Count(p => p.Location == RequestLocation.Header);
Assert.AreEqual(inputOperation.Parameters.Count - specialHeaderParamCount + 1, parameters.Count);

if (specialHeaderParamCount > 0)
{
Assert.IsFalse(parameters.Any(p =>
p.Name.Equals("repeatabilityFirstSent", StringComparison.OrdinalIgnoreCase) &&
p.Name.Equals("repeatabilityRequestId", StringComparison.OrdinalIgnoreCase)));
}
}

[Test]
Expand Down Expand Up @@ -97,10 +106,48 @@ public void ValidateProperties()
Assert.IsFalse(pipelineMessageClassifier2xxAnd4xx.Body.HasSetter);
}

[TestCaseSource(nameof(GetMethodParametersTestCases))]
public void TestGetMethodParameters(InputOperation inputOperation)
{
var methodParameters = RestClientProvider.GetMethodParameters(inputOperation);

Assert.IsTrue(methodParameters.Count > 0);

if (inputOperation.Parameters.Any(p => p.Location == RequestLocation.Header))
{
// validate no special header parameters are in the method parameters
Assert.IsFalse(methodParameters.Any(p =>
p.Name.Equals("repeatabilityFirstSent", StringComparison.OrdinalIgnoreCase) &&
p.Name.Equals("repeatabilityRequestId", StringComparison.OrdinalIgnoreCase)));
}
}

[Test]
public void ValidateClientWithSpecialHeaders()
{
var clientProvider = new ClientProvider(SingleOpInputClient);
var restClientProvider = new MockClientProvider(SingleOpInputClient, clientProvider);
var writer = new TypeProviderWriter(restClientProvider);
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

private readonly static InputOperation BasicOperation = InputFactory.Operation(
"CreateMessage",
parameters:
[
InputFactory.Parameter(
"repeatabilityFirstSent",
new InputDateTimeType(DateTimeKnownEncoding.Rfc7231, "utcDateTime", "TypeSpec.utcDateTime", InputPrimitiveType.String),
nameInRequest: "repeatability-first-sent",
location: RequestLocation.Header,
isRequired: false),
InputFactory.Parameter(
"repeatabilityRequestId",
InputPrimitiveType.String,
nameInRequest: "repeatability-request-ID",
location: RequestLocation.Header,
isRequired: false),
InputFactory.Parameter("message", InputPrimitiveType.Boolean, isRequired: true)
]);

Expand All @@ -110,5 +157,26 @@ public void ValidateProperties()
[
new TestCaseData(BasicOperation)
];

private static IEnumerable<TestCaseData> GetMethodParametersTestCases =>
[
new TestCaseData(BasicOperation)
];

private class MockClientProvider : RestClientProvider
{
public MockClientProvider(InputClient inputClient, ClientProvider clientProvider) : base(inputClient, clientProvider) { }

protected override MethodProvider[] BuildMethods()
{
return [.. base.BuildMethods()];
}

protected override FieldProvider[] BuildFields() => [];
protected override ConstructorProvider[] BuildConstructors() => [];
protected override PropertyProvider[] BuildProperties() => [];

protected override TypeProvider[] BuildNestedTypes() => [];
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;

namespace sample.namespace
{
/// <summary></summary>
public partial class TestClient
{
internal global::System.ClientModel.Primitives.PipelineMessage CreateCreateMessageRequest(global::System.ClientModel.BinaryContent content, global::System.ClientModel.Primitives.RequestOptions options)
{
global::System.ClientModel.Primitives.PipelineMessage message = Pipeline.CreateMessage();
message.ResponseClassifier = PipelineMessageClassifier200;
global::System.ClientModel.Primitives.PipelineRequest request = message.Request;
request.Method = "GET";
global::sample.namespace.ClientUriBuilder uri = new global::sample.namespace.ClientUriBuilder();
uri.Reset(_endpoint);
request.Uri = uri.ToUri();
request.Headers.Set("repeatability-first-sent", global::System.DateTimeOffset.Now.ToString("R"));
request.Headers.Set("repeatability-request-ID", global::System.Guid.NewGuid().ToString());
request.Content = content;
message.Apply(options);
return message;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using Microsoft.Generator.CSharp.ClientModel.Primitives;
using Microsoft.Generator.CSharp.Primitives;
using NUnit.Framework;
using static Microsoft.Generator.CSharp.Snippets.Snippet;

namespace Microsoft.Generator.CSharp.ClientModel.Tests
{
public class ScmKnownParametersTests
{
[Test]
public void BinaryDataParametersHasValidation()
[OneTimeSetUp]
public void Setup()
{
MockHelpers.LoadMockPlugin();

}

[Test]
public void BinaryDataParameterHasValidation()
{
var parameter = ScmKnownParameters.BinaryContent;
Assert.AreEqual(ParameterValidationType.AssertNotNull, parameter.Validation);
}

[Test]
public void RepeatabilityRequestIdParamHasDefaultValue()
{
var parameter = ScmKnownParameters.RepeatabilityRequestId;
var expectedDefaultValue = Static(typeof(Guid)).Invoke(nameof(Guid.NewGuid)).Invoke(nameof(string.ToString));
Assert.AreEqual(expectedDefaultValue, parameter.DefaultValue);
}

[Test]
public void RepeatabilityFirstSentParamHasDefaultValue()
{
var parameter = ScmKnownParameters.RepeatabilityFirstSent;
var expectedDefaultValue = Static(typeof(DateTimeOffset)).Property(nameof(DateTimeOffset.Now));
Assert.AreEqual(expectedDefaultValue, parameter.DefaultValue);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public static InputConstant Int64(long value)
public static InputParameter Parameter(
string name,
InputType type,
string? nameInRequest = null,
InputConstant? defaultValue = null,
RequestLocation location = RequestLocation.Body,
bool isRequired = false,
Expand All @@ -64,7 +65,7 @@ public static InputParameter Parameter(
{
return new InputParameter(
name,
name,
nameInRequest ?? name,
$"{name} description",
type,
location,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ internal PipelineMessage CreateFriendlyModelRequest(RequestOptions options)
return message;
}

internal PipelineMessage CreateAddTimeHeaderRequest(DateTimeOffset repeatabilityFirstSent, RequestOptions options)
internal PipelineMessage CreateAddTimeHeaderRequest(RequestOptions options)
{
PipelineMessage message = Pipeline.CreateMessage();
message.ResponseClassifier = PipelineMessageClassifier204;
Expand All @@ -220,7 +220,7 @@ internal PipelineMessage CreateAddTimeHeaderRequest(DateTimeOffset repeatability
uri.Reset(_endpoint);
uri.AppendPath("/", false);
request.Uri = uri.ToUri();
request.Headers.Set("Repeatability-First-Sent", repeatabilityFirstSent.ToString("R"));
request.Headers.Set("Repeatability-First-Sent", DateTimeOffset.Now.ToString("R"));
message.Apply(options);
return message;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -737,13 +737,12 @@ public virtual async Task<ClientResult<Friend>> FriendlyModelAsync()
/// </item>
/// </list>
/// </summary>
/// <param name="repeatabilityFirstSent"></param>
/// <param name="options"> The request options, which can override default behaviors of the client pipeline on a per-call basis. </param>
/// <exception cref="ClientResultException"> Service returned a non-success status code. </exception>
/// <returns> The response returned from the service. </returns>
public virtual ClientResult AddTimeHeader(DateTimeOffset repeatabilityFirstSent, RequestOptions options)
public virtual ClientResult AddTimeHeader(RequestOptions options)
{
using PipelineMessage message = CreateAddTimeHeaderRequest(repeatabilityFirstSent, options);
using PipelineMessage message = CreateAddTimeHeaderRequest(options);
return ClientResult.FromResponse(Pipeline.ProcessMessage(message, options));
}

Expand All @@ -755,30 +754,27 @@ public virtual ClientResult AddTimeHeader(DateTimeOffset repeatabilityFirstSent,
/// </item>
/// </list>
/// </summary>
/// <param name="repeatabilityFirstSent"></param>
/// <param name="options"> The request options, which can override default behaviors of the client pipeline on a per-call basis. </param>
/// <exception cref="ClientResultException"> Service returned a non-success status code. </exception>
/// <returns> The response returned from the service. </returns>
public virtual async Task<ClientResult> AddTimeHeaderAsync(DateTimeOffset repeatabilityFirstSent, RequestOptions options)
public virtual async Task<ClientResult> AddTimeHeaderAsync(RequestOptions options)
{
using PipelineMessage message = CreateAddTimeHeaderRequest(repeatabilityFirstSent, options);
using PipelineMessage message = CreateAddTimeHeaderRequest(options);
return ClientResult.FromResponse(await Pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false));
}

/// <summary> addTimeHeader. </summary>
/// <param name="repeatabilityFirstSent"></param>
/// <exception cref="ClientResultException"> Service returned a non-success status code. </exception>
public virtual ClientResult AddTimeHeader(DateTimeOffset repeatabilityFirstSent)
public virtual ClientResult AddTimeHeader()
{
return AddTimeHeader(repeatabilityFirstSent, null);
return AddTimeHeader(null);
}

/// <summary> addTimeHeader. </summary>
/// <param name="repeatabilityFirstSent"></param>
/// <exception cref="ClientResultException"> Service returned a non-success status code. </exception>
public virtual async Task<ClientResult> AddTimeHeaderAsync(DateTimeOffset repeatabilityFirstSent)
public virtual async Task<ClientResult> AddTimeHeaderAsync()
{
return await AddTimeHeaderAsync(repeatabilityFirstSent, null).ConfigureAwait(false);
return await AddTimeHeaderAsync(null).ConfigureAwait(false);
}

/// <summary>
Expand Down

0 comments on commit 261fa52

Please sign in to comment.