Skip to content

Commit

Permalink
Support replacing methods (#4547)
Browse files Browse the repository at this point in the history
This PR adds support for replacing generated methods with custom code.
It also fixes several issues with constructing nullable types from
custom code.

fixes: #4472
  • Loading branch information
jorgerangel-msft authored Oct 2, 2024
1 parent f549f60 commit 7e5e3a6
Show file tree
Hide file tree
Showing 20 changed files with 742 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1388,10 +1388,17 @@ private MethodBodyStatement CreateListSerializationStatement(
ScopedApi array,
SerializationFormat serializationFormat)
{
// Handle ReadOnlyMemory<T> serialization
bool isReadOnlySpan = array.Type.ElementType.IsFrameworkType && array.Type.ElementType.FrameworkType == typeof(ReadOnlySpan<>);
CSharpType itemType = isReadOnlySpan ? array.Type.ElementType.Arguments[0] : array.Type.Arguments[0];
var collection = isReadOnlySpan
? array.NullableStructValue(array.Type.ElementType).Property(nameof(ReadOnlyMemory<byte>.Span))
: array;

return new[]
{
_utf8JsonWriterSnippet.WriteStartArray(),
new ForeachStatement("item", array, out VariableExpression item)
new ForeachStatement(itemType, "item", collection, false, out VariableExpression item)
{
TypeRequiresNullCheckInSerialization(item.Type) ?
new IfStatement(item.Equal(Null)) { _utf8JsonWriterSnippet.WriteNullValue(), Continue } : MethodBodyStatement.Empty,
Expand Down Expand Up @@ -1583,8 +1590,9 @@ private MethodBodyStatement SerializeBinaryData(Type valueType, SerializationFor

private static ScopedApi GetEnumerableExpression(ValueExpression expression, CSharpType enumerableType)
{
CSharpType itemType = enumerableType.IsReadOnlyMemory ? new CSharpType(typeof(ReadOnlySpan<>), enumerableType.Arguments[0]) :
enumerableType.ElementType;
CSharpType itemType = enumerableType.IsReadOnlyMemory
? new CSharpType(typeof(ReadOnlySpan<>), enumerableType.IsNullable, enumerableType.Arguments[0])
: enumerableType.ElementType;

return expression.As(new CSharpType(typeof(IEnumerable<>), itemType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,83 @@ public async Task CanAddMethodSameName()
Assert.IsNull(customMethods[0].BodyExpression);
Assert.AreEqual(string.Empty, customMethods[0].BodyStatements!.ToDisplayString());
}

[Test]
public async Task CanReplaceOpMethod()
{
var inputOperation = InputFactory.Operation("HelloAgain", parameters:
[
InputFactory.Parameter("p1", InputFactory.Array(InputPrimitiveType.String))
]);
var inputClient = InputFactory.Client("TestClient", operations: [inputOperation]);
var plugin = await MockHelpers.LoadMockPluginAsync(
clients: () => [inputClient],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

// Find the client provider
var clientProvider = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ClientProvider);
Assert.IsNotNull(clientProvider);

// The client provider method should not have a protocol method
var clientProviderMethods = clientProvider!.Methods;
Assert.AreEqual(3, clientProviderMethods.Count);

bool hasBinaryContentParameter = clientProviderMethods
.Any(m => m.Signature.Name == "HelloAgain" && m.Signature.Parameters
.Any(p => p.Type.Equals(typeof(BinaryContent))));
Assert.IsFalse(hasBinaryContentParameter);

// The custom code view should contain the method
var customCodeView = clientProvider.CustomCodeView;
Assert.IsNotNull(customCodeView);
var customMethods = customCodeView!.Methods;
Assert.AreEqual(1, customMethods.Count);
Assert.AreEqual("HelloAgain", customMethods[0].Signature.Name);
Assert.IsNull(customMethods[0].BodyExpression);
Assert.AreEqual(string.Empty, customMethods[0].BodyStatements!.ToDisplayString());

}

// Validates that a method with a struct parameter can be replaced
[Test]
public async Task CanReplaceStructMethod()
{
var inputOperation = InputFactory.Operation("HelloAgain", parameters:
[
InputFactory.Parameter("p1", InputFactory.Model("myStruct", modelAsStruct: true), isRequired: false)
]);
var inputClient = InputFactory.Client("TestClient", operations: [inputOperation]);
var plugin = await MockHelpers.LoadMockPluginAsync(
clients: () => [inputClient],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

// Find the client provider
var clientProvider = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ClientProvider);
Assert.IsNotNull(clientProvider);

// The client provider method should not have a protocol method
var clientProviderMethods = clientProvider!.Methods;
Assert.AreEqual(3, clientProviderMethods.Count);

bool hasStructParam = clientProviderMethods
.Any(m => m.Signature.Name == "HelloAgain" && m.Signature.Parameters
.Any(p => p.Type.IsStruct));
Assert.IsFalse(hasStructParam);

// The custom code view should contain the method
var customCodeView = clientProvider.CustomCodeView;
Assert.IsNotNull(customCodeView);

var customMethods = customCodeView!.Methods;
Assert.AreEqual(1, customMethods.Count);
Assert.AreEqual("HelloAgain", customMethods[0].Signature.Name);

var customMethodParams = customMethods[0].Signature.Parameters;
Assert.AreEqual(1, customMethodParams.Count);
Assert.AreEqual("p1", customMethodParams[0].Name);
Assert.AreEqual("MyStruct", customMethodParams[0].Type.Name);
Assert.IsTrue(customMethodParams[0].Type.IsStruct);
Assert.IsTrue(customMethodParams[0].Type.IsNullable);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace Sample
{
/// <summary></summary>
public partial class TestClient
{
public virtual ClientResult HelloAgain(BinaryContent content, RequestOptions options)
{

}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace Sample
{
/// <summary></summary>
public partial class TestClient
{
public virtual ClientResult HelloAgain(MyStruct? p1)
{

}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Linq;
using System.Threading.Tasks;
using Microsoft.Generator.CSharp.ClientModel.Providers;
using Microsoft.Generator.CSharp.Tests.Common;
using NUnit.Framework;

namespace Microsoft.Generator.CSharp.ClientModel.Tests.Providers.Definitions
{
public class ClientPipelineExtensionsDefCustomizationTests
{
[Test]
public async Task CanReplaceMethod()
{
var plugin = await MockHelpers.LoadMockPluginAsync(compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

// Find the extension definition
var definition = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ClientPipelineExtensionsDefinition);
Assert.IsNotNull(definition);

// The definitions should not have the custom method
var definitionMethods = definition!.Methods;
Assert.AreEqual(3, definitionMethods.Count);
Assert.IsFalse(definitionMethods.Any(m => m.Signature.Name == "ProcessMessageAsync"));

// The custom code view should contain the method
var customCodeView = definition.CustomCodeView;
Assert.IsNotNull(customCodeView);
var customMethods = customCodeView!.Methods;
Assert.AreEqual(1, customMethods.Count);
Assert.AreEqual("ProcessMessageAsync", customMethods[0].Signature.Name);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#nullable disable

using System.ClientModel;
using System.ClientModel.Primitives;
using System.Threading.Tasks;

namespace Sample
{
internal static partial class ClientPipelineExtensions
{
public static async ValueTask<PipelineResponse> ProcessMessageAsync(this ClientPipeline pipeline, PipelineMessage message, RequestOptions options)
{
await pipeline.SendAsync(message).ConfigureAwait(false);

if (message.Response.IsError && (options?.ErrorOptions & ClientErrorBehaviors.NoThrow) != ClientErrorBehaviors.NoThrow)
{
// log instead of throw
Console.WriteLine("Error: " + message.Response);
}

PipelineResponse response = message.BufferResponse ? message.Response : message.ExtractResponse();
return response;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,37 @@ public async Task CanChangeEnumMemberName()
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

[Test]
public async Task CanReplaceMethod()
{
var enumValues = new[]
{
InputFactory.EnumMember.Int32("Red", 1),
InputFactory.EnumMember.Int32("Green", 2),
InputFactory.EnumMember.Int32("Blue", 3)
};
var inputEnum = InputFactory.Enum("mockInputModel", underlyingType: InputPrimitiveType.String, values: enumValues);
var plugin = await MockHelpers.LoadMockPluginAsync(
inputEnums: () => [inputEnum],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var enumProvider = plugin.Object.OutputLibrary.TypeProviders.FirstOrDefault(t => t.IsEnum);
Assert.IsNotNull(enumProvider);

var serializationProvider = enumProvider!.SerializationProviders.FirstOrDefault();
Assert.IsNotNull(serializationProvider);

var serializationProviderMethods = serializationProvider!.Methods;
Assert.AreEqual(1, serializationProviderMethods.Count);
Assert.IsFalse(serializationProviderMethods.Any(m => m.Signature.Name == "ToSerialString"));

//The custom code view should contain the method
var customCodeView = serializationProvider.CustomCodeView;
Assert.IsNotNull(customCodeView);
var customMethods = customCodeView!.Methods;
Assert.AreEqual(1, customMethods.Count);
Assert.AreEqual("ToSerialString", customMethods[0].Signature.Name);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// <auto-generated/>

#nullable disable

using System;

namespace Sample.Models
{
internal static partial class MockInputModelExtensions
{
public static string ToSerialString(this MockInputModel value) => value switch
{
throw new ArgumentOutOfRangeException(nameof(value), value, "Unknown MockInputModel value.")
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,33 @@ public async Task CanChangePropertyName()
InputFactory.Property("Prop1", InputFactory.Array(InputPrimitiveType.String))
};

var inputModel = InputFactory.Model("mockInputModel", properties: props, usage: InputModelTypeUsage.Json);
var inputModel = InputFactory.Model("Model", properties: props, usage: InputModelTypeUsage.Json);
var plugin = await MockHelpers.LoadMockPluginAsync(
inputModels: () => [inputModel],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var modelProvider = plugin.Object.OutputLibrary.TypeProviders.Single(t => t is ModelProvider);
var serializationProvider = modelProvider.SerializationProviders.Single(t => t is MrwSerializationTypeDefinition);
Assert.IsNotNull(serializationProvider);
Assert.AreEqual(0, serializationProvider!.Fields.Count);

// validate the methods use the custom member name
var writer = new TypeProviderWriter(serializationProvider);
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

[Test]
public async Task ReadOnlyMemPropertyType()
{
var props = new[]
{
InputFactory.Property("Prop1", InputFactory.Array(InputPrimitiveType.String)),
InputFactory.Property("Prop2", InputFactory.Array(InputPrimitiveType.String))

};

var inputModel = InputFactory.Model("Model", properties: props, usage: InputModelTypeUsage.Json);
var plugin = await MockHelpers.LoadMockPluginAsync(
inputModels: () => [inputModel],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());
Expand Down
Loading

0 comments on commit 7e5e3a6

Please sign in to comment.