diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 28df810885..d15b1ab6eb 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -1388,10 +1388,17 @@ private MethodBodyStatement CreateListSerializationStatement( ScopedApi array, SerializationFormat serializationFormat) { + // Handle ReadOnlyMemory serialization + bool isReadOnlySpan = array.Type.ElementType.IsFrameworkType && array.Type.ElementType.FrameworkType == typeof(ReadOnlySpan<>); + CSharpType itemType = isReadOnlySpan ? array.Type.ElementType.Arguments[0] : array.Type.Arguments[0]; + var collection = isReadOnlySpan + ? array.NullableStructValue(array.Type.ElementType).Property(nameof(ReadOnlyMemory.Span)) + : array; + return new[] { _utf8JsonWriterSnippet.WriteStartArray(), - new ForeachStatement("item", array, out VariableExpression item) + new ForeachStatement(itemType, "item", collection, false, out VariableExpression item) { TypeRequiresNullCheckInSerialization(item.Type) ? new IfStatement(item.Equal(Null)) { _utf8JsonWriterSnippet.WriteNullValue(), Continue } : MethodBodyStatement.Empty, @@ -1583,8 +1590,9 @@ private MethodBodyStatement SerializeBinaryData(Type valueType, SerializationFor private static ScopedApi GetEnumerableExpression(ValueExpression expression, CSharpType enumerableType) { - CSharpType itemType = enumerableType.IsReadOnlyMemory ? new CSharpType(typeof(ReadOnlySpan<>), enumerableType.Arguments[0]) : - enumerableType.ElementType; + CSharpType itemType = enumerableType.IsReadOnlyMemory + ? new CSharpType(typeof(ReadOnlySpan<>), enumerableType.IsNullable, enumerableType.Arguments[0]) + : enumerableType.ElementType; return expression.As(new CSharpType(typeof(IEnumerable<>), itemType)); } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/ClientProviderCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/ClientProviderCustomizationTests.cs index b4fae5cbc3..c5f7f8996d 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/ClientProviderCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/ClientProviderCustomizationTests.cs @@ -114,5 +114,83 @@ public async Task CanAddMethodSameName() Assert.IsNull(customMethods[0].BodyExpression); Assert.AreEqual(string.Empty, customMethods[0].BodyStatements!.ToDisplayString()); } + + [Test] + public async Task CanReplaceOpMethod() + { + var inputOperation = InputFactory.Operation("HelloAgain", parameters: + [ + InputFactory.Parameter("p1", InputFactory.Array(InputPrimitiveType.String)) + ]); + var inputClient = InputFactory.Client("TestClient", operations: [inputOperation]); + var plugin = await MockHelpers.LoadMockPluginAsync( + clients: () => [inputClient], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + + // Find the client provider + var clientProvider = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ClientProvider); + Assert.IsNotNull(clientProvider); + + // The client provider method should not have a protocol method + var clientProviderMethods = clientProvider!.Methods; + Assert.AreEqual(3, clientProviderMethods.Count); + + bool hasBinaryContentParameter = clientProviderMethods + .Any(m => m.Signature.Name == "HelloAgain" && m.Signature.Parameters + .Any(p => p.Type.Equals(typeof(BinaryContent)))); + Assert.IsFalse(hasBinaryContentParameter); + + // The custom code view should contain the method + var customCodeView = clientProvider.CustomCodeView; + Assert.IsNotNull(customCodeView); + var customMethods = customCodeView!.Methods; + Assert.AreEqual(1, customMethods.Count); + Assert.AreEqual("HelloAgain", customMethods[0].Signature.Name); + Assert.IsNull(customMethods[0].BodyExpression); + Assert.AreEqual(string.Empty, customMethods[0].BodyStatements!.ToDisplayString()); + + } + + // Validates that a method with a struct parameter can be replaced + [Test] + public async Task CanReplaceStructMethod() + { + var inputOperation = InputFactory.Operation("HelloAgain", parameters: + [ + InputFactory.Parameter("p1", InputFactory.Model("myStruct", modelAsStruct: true), isRequired: false) + ]); + var inputClient = InputFactory.Client("TestClient", operations: [inputOperation]); + var plugin = await MockHelpers.LoadMockPluginAsync( + clients: () => [inputClient], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + + // Find the client provider + var clientProvider = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ClientProvider); + Assert.IsNotNull(clientProvider); + + // The client provider method should not have a protocol method + var clientProviderMethods = clientProvider!.Methods; + Assert.AreEqual(3, clientProviderMethods.Count); + + bool hasStructParam = clientProviderMethods + .Any(m => m.Signature.Name == "HelloAgain" && m.Signature.Parameters + .Any(p => p.Type.IsStruct)); + Assert.IsFalse(hasStructParam); + + // The custom code view should contain the method + var customCodeView = clientProvider.CustomCodeView; + Assert.IsNotNull(customCodeView); + + var customMethods = customCodeView!.Methods; + Assert.AreEqual(1, customMethods.Count); + Assert.AreEqual("HelloAgain", customMethods[0].Signature.Name); + + var customMethodParams = customMethods[0].Signature.Parameters; + Assert.AreEqual(1, customMethodParams.Count); + Assert.AreEqual("p1", customMethodParams[0].Name); + Assert.AreEqual("MyStruct", customMethodParams[0].Type.Name); + Assert.IsTrue(customMethodParams[0].Type.IsStruct); + Assert.IsTrue(customMethodParams[0].Type.IsNullable); + } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderCustomizationTests/CanReplaceOpMethod/CanReplaceOpMethod.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderCustomizationTests/CanReplaceOpMethod/CanReplaceOpMethod.cs new file mode 100644 index 0000000000..fa3576d8aa --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderCustomizationTests/CanReplaceOpMethod/CanReplaceOpMethod.cs @@ -0,0 +1,22 @@ +// + +#nullable disable + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Sample +{ + /// + public partial class TestClient + { + public virtual ClientResult HelloAgain(BinaryContent content, RequestOptions options) + { + + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderCustomizationTests/CanReplaceStructMethod/CanReplaceStructMethod.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderCustomizationTests/CanReplaceStructMethod/CanReplaceStructMethod.cs new file mode 100644 index 0000000000..fd2b551fb5 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/ClientProviders/TestData/ClientProviderCustomizationTests/CanReplaceStructMethod/CanReplaceStructMethod.cs @@ -0,0 +1,22 @@ +// + +#nullable disable + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Sample +{ + /// + public partial class TestClient + { + public virtual ClientResult HelloAgain(MyStruct? p1) + { + + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Definitions/ClientPipelineExtensionsDefCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Definitions/ClientPipelineExtensionsDefCustomizationTests.cs new file mode 100644 index 0000000000..fe39159437 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Definitions/ClientPipelineExtensionsDefCustomizationTests.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Generator.CSharp.ClientModel.Providers; +using Microsoft.Generator.CSharp.Tests.Common; +using NUnit.Framework; + +namespace Microsoft.Generator.CSharp.ClientModel.Tests.Providers.Definitions +{ + public class ClientPipelineExtensionsDefCustomizationTests + { + [Test] + public async Task CanReplaceMethod() + { + var plugin = await MockHelpers.LoadMockPluginAsync(compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + + // Find the extension definition + var definition = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ClientPipelineExtensionsDefinition); + Assert.IsNotNull(definition); + + // The definitions should not have the custom method + var definitionMethods = definition!.Methods; + Assert.AreEqual(3, definitionMethods.Count); + Assert.IsFalse(definitionMethods.Any(m => m.Signature.Name == "ProcessMessageAsync")); + + // The custom code view should contain the method + var customCodeView = definition.CustomCodeView; + Assert.IsNotNull(customCodeView); + var customMethods = customCodeView!.Methods; + Assert.AreEqual(1, customMethods.Count); + Assert.AreEqual("ProcessMessageAsync", customMethods[0].Signature.Name); + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Definitions/TestData/ClientPipelineExtensionsDefCustomizationTests/CanReplaceMethod/CanReplaceMethod.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Definitions/TestData/ClientPipelineExtensionsDefCustomizationTests/CanReplaceMethod/CanReplaceMethod.cs new file mode 100644 index 0000000000..de170c6edd --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/Definitions/TestData/ClientPipelineExtensionsDefCustomizationTests/CanReplaceMethod/CanReplaceMethod.cs @@ -0,0 +1,25 @@ +#nullable disable + +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Threading.Tasks; + +namespace Sample +{ + internal static partial class ClientPipelineExtensions + { + public static async ValueTask ProcessMessageAsync(this ClientPipeline pipeline, PipelineMessage message, RequestOptions options) + { + await pipeline.SendAsync(message).ConfigureAwait(false); + + if (message.Response.IsError && (options?.ErrorOptions & ClientErrorBehaviors.NoThrow) != ClientErrorBehaviors.NoThrow) + { + // log instead of throw + Console.WriteLine("Error: " + message.Response); + } + + PipelineResponse response = message.BufferResponse ? message.Response : message.ExtractResponse(); + return response; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/EnumProvider/SerializationCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/EnumProvider/SerializationCustomizationTests.cs index 52c3c0bf76..b8751f67b0 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/EnumProvider/SerializationCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/EnumProvider/SerializationCustomizationTests.cs @@ -37,5 +37,37 @@ public async Task CanChangeEnumMemberName() var file = writer.Write(); Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content); } + + [Test] + public async Task CanReplaceMethod() + { + var enumValues = new[] + { + InputFactory.EnumMember.Int32("Red", 1), + InputFactory.EnumMember.Int32("Green", 2), + InputFactory.EnumMember.Int32("Blue", 3) + }; + var inputEnum = InputFactory.Enum("mockInputModel", underlyingType: InputPrimitiveType.String, values: enumValues); + var plugin = await MockHelpers.LoadMockPluginAsync( + inputEnums: () => [inputEnum], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + + var enumProvider = plugin.Object.OutputLibrary.TypeProviders.FirstOrDefault(t => t.IsEnum); + Assert.IsNotNull(enumProvider); + + var serializationProvider = enumProvider!.SerializationProviders.FirstOrDefault(); + Assert.IsNotNull(serializationProvider); + + var serializationProviderMethods = serializationProvider!.Methods; + Assert.AreEqual(1, serializationProviderMethods.Count); + Assert.IsFalse(serializationProviderMethods.Any(m => m.Signature.Name == "ToSerialString")); + + //The custom code view should contain the method + var customCodeView = serializationProvider.CustomCodeView; + Assert.IsNotNull(customCodeView); + var customMethods = customCodeView!.Methods; + Assert.AreEqual(1, customMethods.Count); + Assert.AreEqual("ToSerialString", customMethods[0].Signature.Name); + } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/EnumProvider/TestData/SerializationCustomizationTests/CanReplaceMethod/MockInputModel.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/EnumProvider/TestData/SerializationCustomizationTests/CanReplaceMethod/MockInputModel.cs new file mode 100644 index 0000000000..7d2a2d73a7 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/EnumProvider/TestData/SerializationCustomizationTests/CanReplaceMethod/MockInputModel.cs @@ -0,0 +1,16 @@ +// + +#nullable disable + +using System; + +namespace Sample.Models +{ + internal static partial class MockInputModelExtensions + { + public static string ToSerialString(this MockInputModel value) => value switch + { + throw new ArgumentOutOfRangeException(nameof(value), value, "Unknown MockInputModel value.") + }; + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SerializationCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SerializationCustomizationTests.cs index 0ccd027663..71de185ae2 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SerializationCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SerializationCustomizationTests.cs @@ -22,7 +22,33 @@ public async Task CanChangePropertyName() InputFactory.Property("Prop1", InputFactory.Array(InputPrimitiveType.String)) }; - var inputModel = InputFactory.Model("mockInputModel", properties: props, usage: InputModelTypeUsage.Json); + var inputModel = InputFactory.Model("Model", properties: props, usage: InputModelTypeUsage.Json); + var plugin = await MockHelpers.LoadMockPluginAsync( + inputModels: () => [inputModel], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + + var modelProvider = plugin.Object.OutputLibrary.TypeProviders.Single(t => t is ModelProvider); + var serializationProvider = modelProvider.SerializationProviders.Single(t => t is MrwSerializationTypeDefinition); + Assert.IsNotNull(serializationProvider); + Assert.AreEqual(0, serializationProvider!.Fields.Count); + + // validate the methods use the custom member name + var writer = new TypeProviderWriter(serializationProvider); + var file = writer.Write(); + Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content); + } + + [Test] + public async Task ReadOnlyMemPropertyType() + { + var props = new[] + { + InputFactory.Property("Prop1", InputFactory.Array(InputPrimitiveType.String)), + InputFactory.Property("Prop2", InputFactory.Array(InputPrimitiveType.String)) + + }; + + var inputModel = InputFactory.Model("Model", properties: props, usage: InputModelTypeUsage.Json); var plugin = await MockHelpers.LoadMockPluginAsync( inputModels: () => [inputModel], compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName.cs index a31206c319..74a77a1b95 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName.cs @@ -12,9 +12,9 @@ namespace Sample.Models { /// - public partial class MockInputModel : global::System.ClientModel.Primitives.IJsonModel + public partial class Model : global::System.ClientModel.Primitives.IJsonModel { - void global::System.ClientModel.Primitives.IJsonModel.Write(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + void global::System.ClientModel.Primitives.IJsonModel.Write(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) { writer.WriteStartObject(); this.JsonModelWriteCore(writer, options); @@ -25,16 +25,16 @@ public partial class MockInputModel : global::System.ClientModel.Primitives.IJso /// The client options for reading and writing models. protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) { - string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; if ((format != "J")) { - throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support writing '{format}' format."); + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Model)} does not support writing '{format}' format."); } if (global::Sample.Optional.IsDefined(Prop2)) { writer.WritePropertyName("prop1"u8); writer.WriteStartArray(); - foreach (var item in Prop2) + foreach (string item in Prop2) { if ((item == null)) { @@ -62,22 +62,22 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite } } - global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.JsonModelCreateCore(ref reader, options)); + global::Sample.Models.Model global::System.ClientModel.Primitives.IJsonModel.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.Model)this.JsonModelCreateCore(ref reader, options)); /// The JSON reader. /// The client options for reading and writing models. - protected virtual global::Sample.Models.MockInputModel JsonModelCreateCore(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + protected virtual global::Sample.Models.Model JsonModelCreateCore(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) { - string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; if ((format != "J")) { - throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support reading '{format}' format."); + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Model)} does not support reading '{format}' format."); } using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.ParseValue(ref reader); - return global::Sample.Models.MockInputModel.DeserializeMockInputModel(document.RootElement, options); + return global::Sample.Models.Model.DeserializeModel(document.RootElement, options); } - internal static global::Sample.Models.MockInputModel DeserializeMockInputModel(global::System.Text.Json.JsonElement element, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + internal static global::Sample.Models.Model DeserializeModel(global::System.Text.Json.JsonElement element, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) { if ((element.ValueKind == global::System.Text.Json.JsonValueKind.Null)) { @@ -113,57 +113,57 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite additionalBinaryDataProperties.Add(prop.Name, global::System.BinaryData.FromString(prop.Value.GetRawText())); } } - return new global::Sample.Models.MockInputModel(prop2, additionalBinaryDataProperties); + return new global::Sample.Models.Model(prop2, additionalBinaryDataProperties); } - global::System.BinaryData global::System.ClientModel.Primitives.IPersistableModel.Write(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelWriteCore(options); + global::System.BinaryData global::System.ClientModel.Primitives.IPersistableModel.Write(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelWriteCore(options); /// The client options for reading and writing models. protected virtual global::System.BinaryData PersistableModelWriteCore(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) { - string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; switch (format) { case "J": return global::System.ClientModel.Primitives.ModelReaderWriter.Write(this, options); default: - throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support writing '{options.Format}' format."); + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Model)} does not support writing '{options.Format}' format."); } } - global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.PersistableModelCreateCore(data, options)); + global::Sample.Models.Model global::System.ClientModel.Primitives.IPersistableModel.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.Model)this.PersistableModelCreateCore(data, options)); /// The data to parse. /// The client options for reading and writing models. - protected virtual global::Sample.Models.MockInputModel PersistableModelCreateCore(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + protected virtual global::Sample.Models.Model PersistableModelCreateCore(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) { - string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; switch (format) { case "J": using (global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(data)) { - return global::Sample.Models.MockInputModel.DeserializeMockInputModel(document.RootElement, options); + return global::Sample.Models.Model.DeserializeModel(document.RootElement, options); } default: - throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support reading '{options.Format}' format."); + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Model)} does not support reading '{options.Format}' format."); } } - string global::System.ClientModel.Primitives.IPersistableModel.GetFormatFromOptions(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => "J"; + string global::System.ClientModel.Primitives.IPersistableModel.GetFormatFromOptions(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => "J"; - /// The to serialize into . - public static implicit operator BinaryContent(global::Sample.Models.MockInputModel mockInputModel) + /// The to serialize into . + public static implicit operator BinaryContent(global::Sample.Models.Model model) { - return global::System.ClientModel.BinaryContent.Create(mockInputModel, global::Sample.ModelSerializationExtensions.WireOptions); + return global::System.ClientModel.BinaryContent.Create(model, global::Sample.ModelSerializationExtensions.WireOptions); } - /// The to deserialize the from. - public static explicit operator MockInputModel(global::System.ClientModel.ClientResult result) + /// The to deserialize the from. + public static explicit operator Model(global::System.ClientModel.ClientResult result) { using global::System.ClientModel.Primitives.PipelineResponse response = result.GetRawResponse(); using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(response.Content); - return global::Sample.Models.MockInputModel.DeserializeMockInputModel(document.RootElement, global::Sample.ModelSerializationExtensions.WireOptions); + return global::Sample.Models.Model.DeserializeModel(document.RootElement, global::Sample.ModelSerializationExtensions.WireOptions); } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName/MockInputModel.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName/Model.cs similarity index 82% rename from packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName/MockInputModel.cs rename to packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName/Model.cs index b19814f8bf..ac986f73f0 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName/MockInputModel.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/CanChangePropertyName/Model.cs @@ -4,7 +4,7 @@ namespace Sample.Models { - public partial class MockInputModel + public partial class Model { [CodeGenMember("Prop1")] public string[] Prop2 { get; set; } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/ReadOnlyMemPropertyType.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/ReadOnlyMemPropertyType.cs new file mode 100644 index 0000000000..ba713810ca --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/ReadOnlyMemPropertyType.cs @@ -0,0 +1,179 @@ +// + +#nullable disable + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; +using Sample; + +namespace Sample.Models +{ + /// + public partial class Model : global::System.ClientModel.Primitives.IJsonModel + { + void global::System.ClientModel.Primitives.IJsonModel.Write(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + writer.WriteStartObject(); + this.JsonModelWriteCore(writer, options); + writer.WriteEndObject(); + } + + /// The JSON writer. + /// The client options for reading and writing models. + protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if ((format != "J")) + { + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Model)} does not support writing '{format}' format."); + } + writer.WritePropertyName("prop1"u8); + writer.WriteStartArray(); + foreach (byte item in NewProp1.Span) + { + writer.WriteNumberValue(item); + } + writer.WriteEndArray(); + if (global::Sample.Optional.IsDefined(NewProp2)) + { + writer.WritePropertyName("prop2"u8); + writer.WriteStartArray(); + foreach (byte item in NewProp2.Value.Span) + { + writer.WriteNumberValue(item); + } + writer.WriteEndArray(); + } + if (((options.Format != "W") && (_additionalBinaryDataProperties != null))) + { + foreach (var item in _additionalBinaryDataProperties) + { + writer.WritePropertyName(item.Key); +#if NET6_0_OR_GREATER + writer.WriteRawValue(item.Value); +#else + using (global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(item.Value)) + { + global::System.Text.Json.JsonSerializer.Serialize(writer, document.RootElement); + } +#endif + } + } + } + + global::Sample.Models.Model global::System.ClientModel.Primitives.IJsonModel.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.Model)this.JsonModelCreateCore(ref reader, options)); + + /// The JSON reader. + /// The client options for reading and writing models. + protected virtual global::Sample.Models.Model JsonModelCreateCore(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if ((format != "J")) + { + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Model)} does not support reading '{format}' format."); + } + using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.ParseValue(ref reader); + return global::Sample.Models.Model.DeserializeModel(document.RootElement, options); + } + + internal static global::Sample.Models.Model DeserializeModel(global::System.Text.Json.JsonElement element, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + if ((element.ValueKind == global::System.Text.Json.JsonValueKind.Null)) + { + return null; + } + global::System.ReadOnlyMemory newProp1 = default; + global::System.ReadOnlyMemory? newProp2 = default; + global::System.Collections.Generic.IDictionary additionalBinaryDataProperties = new global::Sample.ChangeTrackingDictionary(); + foreach (var prop in element.EnumerateObject()) + { + if (prop.NameEquals("prop1"u8)) + { + if ((prop.Value.ValueKind == global::System.Text.Json.JsonValueKind.Null)) + { + continue; + } + global::System.Collections.Generic.List array = new global::System.Collections.Generic.List(); + foreach (var item in prop.Value.EnumerateArray()) + { + array.Add(item.GetByte()); + } + newProp1 = array; + continue; + } + if (prop.NameEquals("prop2"u8)) + { + if ((prop.Value.ValueKind == global::System.Text.Json.JsonValueKind.Null)) + { + continue; + } + global::System.Collections.Generic.List array = new global::System.Collections.Generic.List(); + foreach (var item in prop.Value.EnumerateArray()) + { + array.Add(item.GetByte()); + } + newProp2 = array; + continue; + } + if ((options.Format != "W")) + { + additionalBinaryDataProperties.Add(prop.Name, global::System.BinaryData.FromString(prop.Value.GetRawText())); + } + } + return new global::Sample.Models.Model(newProp1, newProp2, additionalBinaryDataProperties); + } + + global::System.BinaryData global::System.ClientModel.Primitives.IPersistableModel.Write(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelWriteCore(options); + + /// The client options for reading and writing models. + protected virtual global::System.BinaryData PersistableModelWriteCore(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + switch (format) + { + case "J": + return global::System.ClientModel.Primitives.ModelReaderWriter.Write(this, options); + default: + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Model)} does not support writing '{options.Format}' format."); + } + } + + global::Sample.Models.Model global::System.ClientModel.Primitives.IPersistableModel.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.Model)this.PersistableModelCreateCore(data, options)); + + /// The data to parse. + /// The client options for reading and writing models. + protected virtual global::Sample.Models.Model PersistableModelCreateCore(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + switch (format) + { + case "J": + using (global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(data)) + { + return global::Sample.Models.Model.DeserializeModel(document.RootElement, options); + } + default: + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Model)} does not support reading '{options.Format}' format."); + } + } + + string global::System.ClientModel.Primitives.IPersistableModel.GetFormatFromOptions(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => "J"; + + /// The to serialize into . + public static implicit operator BinaryContent(global::Sample.Models.Model model) + { + return global::System.ClientModel.BinaryContent.Create(model, global::Sample.ModelSerializationExtensions.WireOptions); + } + + /// The to deserialize the from. + public static explicit operator Model(global::System.ClientModel.ClientResult result) + { + using global::System.ClientModel.Primitives.PipelineResponse response = result.GetRawResponse(); + using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(response.Content); + return global::Sample.Models.Model.DeserializeModel(document.RootElement, global::Sample.ModelSerializationExtensions.WireOptions); + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/ReadOnlyMemPropertyType/ReadOnlyMemPropertyType.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/ReadOnlyMemPropertyType/ReadOnlyMemPropertyType.cs new file mode 100644 index 0000000000..a613a22a14 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/SerializationCustomizationTests/ReadOnlyMemPropertyType/ReadOnlyMemPropertyType.cs @@ -0,0 +1,16 @@ +#nullable disable + +using Microsoft.Generator.CSharp.Customization; +using System; +using System.Collections.Generic; + +namespace Sample.Models +{ + public partial class Model + { + [CodeGenMember("Prop1")] + public ReadOnlyMemory NewProp1 { get; set; } + [CodeGenMember("Prop2")] + public ReadOnlyMemory? NewProp2 { get; set; } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/NamedTypeSymbolProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/NamedTypeSymbolProvider.cs index d60dbc1b06..09aa2f1c2e 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/NamedTypeSymbolProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/NamedTypeSymbolProvider.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Xml; using System.Xml.Linq; @@ -15,6 +14,7 @@ namespace Microsoft.Generator.CSharp.Providers { internal sealed class NamedTypeSymbolProvider : TypeProvider { + private const string GlobalPrefix = "global::"; private INamedTypeSymbol _namedTypeSymbol; public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol) @@ -286,14 +286,11 @@ private CSharpType GetCSharpType(ITypeSymbol typeSymbol) Type? type = System.Type.GetType(fullyQualifiedName); if (type is null) { - if (typeSymbol.TypeKind == TypeKind.Error) - throw new InvalidOperationException($"Unable to convert ITypeSymbol: {fullyQualifiedName} to a CSharpType in {Name}"); - return ConstructCSharpTypeFromSymbol(typeSymbol, fullyQualifiedName, namedTypeSymbol); } CSharpType result = new CSharpType(type); - if (namedTypeSymbol is not null && namedTypeSymbol.IsGenericType) + if (namedTypeSymbol is not null && namedTypeSymbol.IsGenericType && !result.IsNullable) { return result.MakeGenericType([.. namedTypeSymbol.TypeArguments.Select(GetCSharpType)]); } @@ -306,19 +303,34 @@ private CSharpType ConstructCSharpTypeFromSymbol( string fullyQualifiedName, INamedTypeSymbol? namedTypeSymbol) { + var typeArg = namedTypeSymbol?.TypeArguments.FirstOrDefault(); bool isValueType = typeSymbol.IsValueType; bool isEnum = typeSymbol.TypeKind == TypeKind.Enum; - var pieces = fullyQualifiedName.Split('.'); + bool isNullable = typeSymbol.NullableAnnotation == NullableAnnotation.Annotated; + bool isNullableUnknownType = isNullable && typeArg?.TypeKind == TypeKind.Error; + string name = isNullableUnknownType ? fullyQualifiedName : typeSymbol.Name; + string[] pieces = fullyQualifiedName.Split('.'); + + // handle nullables + if (isNullable) + { + // System.Nullable`1[T] -> T + name = typeArg != null ? GetFullyQualifiedName(typeArg) : fullyQualifiedName; + pieces = name.Split('.'); + } + return new CSharpType( - typeSymbol.Name, + name, string.Join('.', pieces.Take(pieces.Length - 1)), isValueType, - typeSymbol.NullableAnnotation == NullableAnnotation.Annotated, + isNullable, typeSymbol.ContainingType is not null ? GetCSharpType(typeSymbol.ContainingType) : null, - namedTypeSymbol is not null ? [.. namedTypeSymbol.TypeArguments.Select(GetCSharpType)] : [], + namedTypeSymbol is not null && !isNullableUnknownType ? [.. namedTypeSymbol.TypeArguments.Select(GetCSharpType)] : [], typeSymbol.DeclaredAccessibility == Accessibility.Public, isValueType && !isEnum, - baseType: typeSymbol.BaseType is not null && typeSymbol.BaseType.TypeKind != TypeKind.Error ? GetCSharpType(typeSymbol.BaseType) : null, + baseType: typeSymbol.BaseType is not null && typeSymbol.BaseType.TypeKind != TypeKind.Error && !isNullableUnknownType + ? GetCSharpType(typeSymbol.BaseType) + : null, underlyingEnumType: namedTypeSymbol is not null && namedTypeSymbol.EnumUnderlyingType is not null ? GetCSharpType(namedTypeSymbol.EnumUnderlyingType).FrameworkType : null); @@ -368,13 +380,34 @@ private static string GetFullyQualifiedName(ITypeSymbol typeSymbol) // Handle array types if (typeSymbol is IArrayTypeSymbol arrayTypeSymbol) { - var elementType = GetFullyQualifiedName(arrayTypeSymbol.ElementType); - return elementType + "[]"; + return GetFullyQualifiedName(arrayTypeSymbol.ElementType) + "[]"; } // Handle generic types if (typeSymbol is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType) { + // Handle nullable types + if (typeSymbol.NullableAnnotation == NullableAnnotation.Annotated && !IsCollectionType(namedTypeSymbol)) + { + const string nullableTypeName = "System.Nullable"; + var argTypeSymbol = namedTypeSymbol.TypeArguments.FirstOrDefault(); + + if (argTypeSymbol != null) + { + if (argTypeSymbol.TypeKind == TypeKind.Error) + { + return GetFullyQualifiedName(argTypeSymbol); + } + + string[] typeArguments = [.. namedTypeSymbol.TypeArguments.Select(arg => "[" + GetFullyQualifiedName(arg) + "]")]; + return $"{nullableTypeName}`{namedTypeSymbol.TypeArguments.Length}[{string.Join(", ", typeArguments)}]"; + } + } + else if (namedTypeSymbol.TypeArguments.Length > 0 && !IsCollectionType(namedTypeSymbol)) + { + return GetNonNullableGenericTypeName(namedTypeSymbol); + } + var typeNameSpan = namedTypeSymbol.ConstructedFrom.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).AsSpan(); var start = typeNameSpan.IndexOf(':') + 2; var end = typeNameSpan.IndexOf('<'); @@ -386,11 +419,37 @@ private static string GetFullyQualifiedName(ITypeSymbol typeSymbol) return GetFullyQualifiedNameFromDisplayString(typeSymbol); } + private static string GetNonNullableGenericTypeName(INamedTypeSymbol namedTypeSymbol) + { + string[] typeArguments = [.. namedTypeSymbol.TypeArguments.Select(GetFullyQualifiedName)]; + var fullName = namedTypeSymbol.ConstructedFrom.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + // Remove the type arguments from the fully qualified name + var typeArgumentStartIndex = fullName.IndexOf('<'); + var genericTypeName = typeArgumentStartIndex >= 0 ? fullName.Substring(0, typeArgumentStartIndex) : fullName; + + // Remove global:: prefix + if (genericTypeName.StartsWith(GlobalPrefix, StringComparison.Ordinal)) + { + genericTypeName = genericTypeName.Substring(GlobalPrefix.Length); + } + + return $"{genericTypeName}`{namedTypeSymbol.TypeArguments.Length}[{string.Join(", ", typeArguments)}]"; + } + + private static bool IsCollectionType(INamedTypeSymbol typeSymbol) + { + // Check if the type implements IEnumerable, ICollection, or IEnumerable + return typeSymbol.AllInterfaces.Any(i => + i.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T || + i.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_ICollection_T || + i.OriginalDefinition.SpecialType == SpecialType.System_Collections_IEnumerable); + } + private static string GetFullyQualifiedNameFromDisplayString(ISymbol typeSymbol) { - const string globalPrefix = "global::"; var fullyQualifiedName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - return fullyQualifiedName.StartsWith(globalPrefix, StringComparison.Ordinal) ? fullyQualifiedName.Substring(globalPrefix.Length) : fullyQualifiedName; + return fullyQualifiedName.StartsWith(GlobalPrefix, StringComparison.Ordinal) ? fullyQualifiedName.Substring(GlobalPrefix.Length) : fullyQualifiedName; } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs index d70019b3d3..24d2ed519d 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs @@ -424,7 +424,7 @@ private static bool IsMatch(TypeProvider enclosingType, MethodSignatureBase sign private static bool IsMatch(MethodSignatureBase customMethod, MethodSignatureBase method) { - if (customMethod.Parameters.Count != method.Parameters.Count || customMethod.Name != method.Name) + if (customMethod.Parameters.Count != method.Parameters.Count || !customMethod.Name.EndsWith(method.Name)) { return false; } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs new file mode 100644 index 0000000000..2d7e173abb --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Generator.CSharp.Input; +using Microsoft.Generator.CSharp.Providers; +using Microsoft.Generator.CSharp.Tests.Common; +using NUnit.Framework; + +namespace Microsoft.Generator.CSharp.Tests.Providers.ModelFactories +{ + public class ModelFactoriesCustomizationTests + { + [Test] + public async Task CanReplaceModelMethod() + { + var plugin = await MockHelpers.LoadMockPluginAsync( + inputModelTypes: [ + InputFactory.Model( + "mockInputModel", + properties: + [ + InputFactory.Property("Prop1", InputPrimitiveType.String), + InputFactory.Property("OptionalBool", InputPrimitiveType.Boolean, isRequired: false) + ]), + InputFactory.Model( + "otherModel", + properties: [InputFactory.Property("Prop2", InputPrimitiveType.String)]), + ], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + var csharpGen = new CSharpGen(); + + await csharpGen.ExecuteAsync(); + + // Find the model factory provider + var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); + Assert.IsNotNull(modelFactory); + + // The model factory method should be replaced + var modelFactoryMethods = modelFactory!.Methods; + Assert.AreEqual(1, modelFactoryMethods.Count); + Assert.AreEqual("OtherModel", modelFactoryMethods[0].Signature.Name); + + var customCodeView = modelFactory.CustomCodeView; + Assert.IsNotNull(customCodeView); + + // The custom code view should contain the method + var customMethods = customCodeView!.Methods; + Assert.AreEqual(1, customMethods.Count); + Assert.AreEqual("MockInputModel", customMethods[0].Signature.Name); + Assert.IsNull(customMethods[0].BodyExpression); + Assert.AreEqual(string.Empty, customMethods[0].BodyStatements!.ToDisplayString()); + } + + [Test] + public async Task DoesNotReplaceMethodIfNotCustomized() + { + var plugin = MockHelpers.LoadMockPlugin( + inputModelTypes: [ + InputFactory.Model( + "mockInputModel", + // specify a different property to ensure the method is not replaced + properties: [InputFactory.Property("Prop2", InputPrimitiveType.String)]) + ]); + var csharpGen = new CSharpGen(); + + await csharpGen.ExecuteAsync(); + + // Find the model factory provider + var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); + Assert.IsNotNull(modelFactory); + + // The model factory method should not be replaced + var modelFactoryMethods = modelFactory!.Methods; + Assert.AreEqual(1, modelFactoryMethods.Count); + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanReplaceModelMethod/SampleNamespaceModelFactory.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanReplaceModelMethod/SampleNamespaceModelFactory.cs new file mode 100644 index 0000000000..b952d5c2fb --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanReplaceModelMethod/SampleNamespaceModelFactory.cs @@ -0,0 +1,13 @@ +using Microsoft.Generator.CSharp.Customization; +using System; + +namespace Sample.Models +{ + public static partial class SampleNamespaceModelFactory + { + public static MockInputModel MockInputModel(string prop1 = default, bool? optionalBool = default) + { + + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/NamedTypeSymbolProviders/NamedTypeSymbolProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/NamedTypeSymbolProviders/NamedTypeSymbolProviderTests.cs index 3e526c0533..89cd6e5d37 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/NamedTypeSymbolProviders/NamedTypeSymbolProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/NamedTypeSymbolProviders/NamedTypeSymbolProviderTests.cs @@ -20,11 +20,11 @@ public class NamedTypeSymbolProviderTests public NamedTypeSymbolProviderTests() { - var compilation = CompilationHelper.LoadCompilation([new NamedSymbol(), new PropertyType()]); + _namedSymbol = new NamedSymbol(); + var compilation = CompilationHelper.LoadCompilation([_namedSymbol, new PropertyType()]); var iNamedSymbol = GetSymbol(compilation.Assembly.Modules.First().GlobalNamespace, "NamedSymbol"); _namedTypeSymbolProvider = new NamedTypeSymbolProvider(iNamedSymbol!); - _namedSymbol = new NamedSymbol(); } [Test] @@ -66,6 +66,63 @@ public void ValidateProperties() } } + [TestCase(typeof(int))] + [TestCase(typeof(string))] + [TestCase(typeof(double?))] + [TestCase(typeof(float?))] + [TestCase(typeof(PropertyType))] + [TestCase(typeof(IList))] + [TestCase(typeof(IList))] + [TestCase(typeof(IList))] + [TestCase(typeof(ReadOnlyMemory?))] + [TestCase(typeof(ReadOnlyMemory))] + [TestCase(typeof(ReadOnlyMemory))] + [TestCase(typeof(IEnumerable))] + [TestCase(typeof(IEnumerable))] + [TestCase(typeof(IEnumerable))] + [TestCase(typeof(string[]))] + [TestCase(typeof(IDictionary))] + [TestCase(typeof(BinaryData))] + public void ValidatePropertyTypes(Type propertyType) + { + // setup + var namedSymbol = new NamedSymbol(propertyType); + _namedSymbol = namedSymbol; + var compilation = CompilationHelper.LoadCompilation([namedSymbol, new PropertyType()]); + var iNamedSymbol = GetSymbol(compilation.Assembly.Modules.First().GlobalNamespace, "NamedSymbol"); + + _namedTypeSymbolProvider = new NamedTypeSymbolProvider(iNamedSymbol!); + + Assert.AreEqual(_namedSymbol.Properties.Count, _namedTypeSymbolProvider.Properties.Count); + + var property = _namedTypeSymbolProvider.Properties.FirstOrDefault(); + Assert.IsNotNull(property); + + bool isNullable = Nullable.GetUnderlyingType(propertyType) != null; + var expectedType = new CSharpType(propertyType, isNullable); + var propertyCSharpType = property!.Type; + + Assert.AreEqual(expectedType.Name, propertyCSharpType.Name); + Assert.AreEqual(expectedType.IsNullable, propertyCSharpType.IsNullable); + Assert.AreEqual(expectedType.IsList, propertyCSharpType.IsList); + Assert.AreEqual(expectedType.Arguments.Count, propertyCSharpType.Arguments.Count); + Assert.AreEqual(expectedType.IsCollection, propertyCSharpType.IsCollection); + + for (var i = 0; i < expectedType.Arguments.Count; i++) + { + Assert.AreEqual(expectedType.Arguments[i].Name, propertyCSharpType.Arguments[i].Name); + Assert.AreEqual(expectedType.Arguments[i].IsNullable, propertyCSharpType.Arguments[i].IsNullable); + } + + // validate the underlying types aren't nullable + if (isNullable && expectedType.IsFrameworkType) + { + var underlyingType = propertyCSharpType.FrameworkType; + Assert.IsTrue(Nullable.GetUnderlyingType(underlyingType) == null); + } + + } + [Test] public void ValidateMethods() { @@ -134,12 +191,18 @@ public void ValidateFields() private class NamedSymbol : TypeProvider { + private readonly Type? _propertyType; protected override string BuildRelativeFilePath() => "."; protected override string BuildName() => "NamedSymbol"; protected override string GetNamespace() => CodeModelPlugin.Instance.Configuration.ModelNamespace; + public NamedSymbol(Type? propertyType = null) : base() + { + _propertyType = propertyType; + } + protected override FieldProvider[] BuildFields() { return @@ -153,12 +216,20 @@ protected override FieldProvider[] BuildFields() protected override PropertyProvider[] BuildProperties() { + if (_propertyType == null) + { + return + [ + new PropertyProvider($"IntProperty property", MethodSignatureModifiers.Public, typeof(int), "IntProperty", new AutoPropertyBody(true), this), + new PropertyProvider($"StringProperty property no setter", MethodSignatureModifiers.Public, typeof(string), "StringProperty", new AutoPropertyBody(false), this), + new PropertyProvider($"InternalStringProperty property no setter", MethodSignatureModifiers.Public, typeof(string), "InternalStringProperty", new AutoPropertyBody(false), this), + new PropertyProvider($"PropertyTypeProperty property", MethodSignatureModifiers.Public, new PropertyType().Type, "PropertyTypeProperty", new AutoPropertyBody(true), this), + ]; + } + return [ - new PropertyProvider($"IntProperty property", MethodSignatureModifiers.Public, typeof(int), "IntProperty", new AutoPropertyBody(true), this), - new PropertyProvider($"StringProperty property no setter", MethodSignatureModifiers.Public, typeof(string), "StringProperty", new AutoPropertyBody(false), this), - new PropertyProvider($"InternalStringProperty property no setter", MethodSignatureModifiers.Public, typeof(string), "InternalStringProperty", new AutoPropertyBody(false), this), - new PropertyProvider($"PropertyTypeProperty property", MethodSignatureModifiers.Public, new PropertyType().Type, "PropertyTypeProperty", new AutoPropertyBody(true), this), + new PropertyProvider($"p1", MethodSignatureModifiers.Public, _propertyType, "P1", new AutoPropertyBody(true), this) ]; } diff --git a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/Models/RoundTripModel.Serialization.cs b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/Models/RoundTripModel.Serialization.cs index e4c8f52445..157a5f4103 100644 --- a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/Models/RoundTripModel.Serialization.cs +++ b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/Models/RoundTripModel.Serialization.cs @@ -40,7 +40,7 @@ protected virtual void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWrit writer.WriteStringValue(RequiredInt.ToString()); writer.WritePropertyName("requiredCollection"u8); writer.WriteStartArray(); - foreach (var item in RequiredCollection) + foreach (StringFixedEnum item in RequiredCollection) { writer.WriteStringValue(item.ToSerialString()); } @@ -64,7 +64,7 @@ protected virtual void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWrit { writer.WritePropertyName("intExtensibleEnumCollection"u8); writer.WriteStartArray(); - foreach (var item in IntExtensibleEnumCollection) + foreach (IntExtensibleEnum item in IntExtensibleEnumCollection) { writer.WriteNumberValue(item.ToSerialInt32()); } @@ -84,7 +84,7 @@ protected virtual void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWrit { writer.WritePropertyName("floatExtensibleEnumCollection"u8); writer.WriteStartArray(); - foreach (var item in FloatExtensibleEnumCollection) + foreach (FloatExtensibleEnum item in FloatExtensibleEnumCollection) { writer.WriteNumberValue(item.ToSerialSingle()); } @@ -104,7 +104,7 @@ protected virtual void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWrit { writer.WritePropertyName("floatFixedEnumCollection"u8); writer.WriteStartArray(); - foreach (var item in FloatFixedEnumCollection) + foreach (FloatFixedEnum item in FloatFixedEnumCollection) { writer.WriteNumberValue(item.ToSerialSingle()); } @@ -119,7 +119,7 @@ protected virtual void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWrit { writer.WritePropertyName("intFixedEnumCollection"u8); writer.WriteStartArray(); - foreach (var item in IntFixedEnumCollection) + foreach (IntFixedEnum item in IntFixedEnumCollection) { writer.WriteNumberValue((int)item); } diff --git a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/Models/Thing.Serialization.cs b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/Models/Thing.Serialization.cs index 2d140b1a4a..4e411543c5 100644 --- a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/Models/Thing.Serialization.cs +++ b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/Models/Thing.Serialization.cs @@ -81,7 +81,7 @@ protected virtual void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWrit { writer.WritePropertyName("optionalNullableList"u8); writer.WriteStartArray(); - foreach (var item in OptionalNullableList) + foreach (int item in OptionalNullableList) { writer.WriteNumberValue(item); } @@ -96,7 +96,7 @@ protected virtual void JsonModelWriteCore(Utf8JsonWriter writer, ModelReaderWrit { writer.WritePropertyName("requiredNullableList"u8); writer.WriteStartArray(); - foreach (var item in RequiredNullableList) + foreach (int item in RequiredNullableList) { writer.WriteNumberValue(item); }