From 80f5c4c35ff0de844d1f38b66ac1e03919628110 Mon Sep 17 00:00:00 2001 From: Jorge Rangel <102122018+jorgerangel-msft@users.noreply.github.com> Date: Wed, 30 Oct 2024 17:39:19 -0500 Subject: [PATCH] [http-client-csharp] fix: use correct custom ctor in model factory (#4921) This PR fixes an issue where the model factory method for a model was using the incorrect full constructor when the full constructor was suppressed and customized. fixes: https://github.com/microsoft/typespec/issues/4830 --- .../src/Providers/CanonicalTypeProvider.cs | 6 +- .../src/Providers/ModelFactoryProvider.cs | 56 +++++++++++++++---- .../src/Snippets/Snippet.cs | 2 + .../ModelFactories/DiscriminatorTests.cs | 2 +- .../ModelFactoriesCustomizationTests.cs | 40 +++++++++++++ .../MockInputModel.cs | 21 +++++++ .../CanCustomizePropertyIntoReadOnlyMemory.cs | 2 +- .../UnbrandedTypeSpecModelFactory.cs | 12 ++-- 8 files changed, 121 insertions(+), 20 deletions(-) create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanCustomizeModelFullConstructor/MockInputModel.cs diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/CanonicalTypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/CanonicalTypeProvider.cs index e483af5416..e723d90fa1 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/CanonicalTypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/CanonicalTypeProvider.cs @@ -44,7 +44,11 @@ public CanonicalTypeProvider(TypeProvider generatedTypeProvider, InputType? inpu private protected override CanonicalTypeProvider GetCanonicalView() => this; - // TODO - Implement BuildMethods, BuildConstructors, etc as needed + // TODO - Implement BuildMethods, etc as needed + protected override ConstructorProvider[] BuildConstructors() + { + return [.. _generatedTypeProvider.Constructors, .. _generatedTypeProvider.CustomCodeView?.Constructors ?? []]; + } protected override PropertyProvider[] BuildProperties() { diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelFactoryProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelFactoryProvider.cs index d44fa6fb0f..c39eb1d43e 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelFactoryProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelFactoryProvider.cs @@ -79,14 +79,39 @@ protected override MethodProvider[] BuildMethods() if (typeToInstantiate is null) continue; - var modelCtor = modelProvider.FullConstructor; + var fullConstructor = modelProvider.FullConstructor; + var binaryDataParam = fullConstructor.Signature.Parameters.FirstOrDefault(p => p.Name.Equals(AdditionalBinaryDataParameterName)); + + // Use a custom constructor if the generated full constructor was suppressed or customized + if (!modelProvider.Constructors.Contains(fullConstructor)) + { + foreach (var constructor in modelProvider.CanonicalView.Constructors) + { + var customCtorParamCount = constructor.Signature.Parameters.Count; + var fullCtorParamCount = fullConstructor.Signature.Parameters.Count; + + if (constructor.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Internal) + && customCtorParamCount >= fullCtorParamCount) + { + binaryDataParam = constructor.Signature.Parameters + .FirstOrDefault(p => p?.Type.Equals(typeof(IDictionary)) == true, binaryDataParam); + + if (customCtorParamCount > fullCtorParamCount) + { + fullConstructor = constructor; + break; + } + } + } + } + var signature = new MethodSignature( modelProvider.Name, null, MethodSignatureModifiers.Static | MethodSignatureModifiers.Public, modelProvider.Type, $"A new {modelProvider.Type:C} instance for mocking.", - GetParameters(modelProvider)); + GetParameters(modelProvider, fullConstructor)); var docs = new XmlDocProvider(); docs.Summary = modelProvider.XmlDocs?.Summary; @@ -100,7 +125,7 @@ protected override MethodProvider[] BuildMethods() [ .. GetCollectionInitialization(signature), MethodBodyStatement.EmptyLine, - Return(New.Instance(typeToInstantiate.Type, [.. GetCtorArgs(modelProvider, signature)])) + Return(New.Instance(typeToInstantiate.Type, [.. GetCtorArgs(modelProvider, signature, fullConstructor, binaryDataParam)])) ]); methods.Add(new MethodProvider(signature, statements, this, docs)); @@ -110,9 +135,11 @@ .. GetCollectionInitialization(signature), private static IReadOnlyList GetCtorArgs( ModelProvider modelProvider, - MethodSignature factoryMethodSignature) + MethodSignature factoryMethodSignature, + ConstructorProvider fullConstructor, + ParameterProvider? binaryDataParameter) { - var modelCtorFullSignature = modelProvider.FullConstructor.Signature; + var modelCtorFullSignature = fullConstructor.Signature; var expressions = new List(modelCtorFullSignature.Parameters.Count); for (int i = 0; i < modelCtorFullSignature.Parameters.Count; i++) @@ -153,10 +180,9 @@ private static IReadOnlyList GetCtorArgs( } } - if (modelCtorFullSignature.Parameters.Any(p => p.Name.Equals(AdditionalBinaryDataParameterName)) && - !modelProvider.SupportsBinaryDataAdditionalProperties) + if (binaryDataParameter != null && !modelProvider.SupportsBinaryDataAdditionalProperties) { - expressions.Add(Null); + expressions.Add(binaryDataParameter.PositionalReference(Null)); } return [.. expressions]; @@ -175,14 +201,22 @@ private IReadOnlyList GetCollectionInitialization(MethodSig return [.. statements]; } - private static IReadOnlyList GetParameters(ModelProvider modelProvider) + private static IReadOnlyList GetParameters( + ModelProvider modelProvider, + ConstructorProvider fullConstructor) { - var modelCtorParams = modelProvider.FullConstructor.Signature.Parameters; + var modelCtorParams = fullConstructor.Signature.Parameters; var parameters = new List(modelCtorParams.Count); + bool isCustomConstructor = fullConstructor != modelProvider.FullConstructor; + foreach (var param in modelCtorParams) { - if (param.Name.Equals(AdditionalBinaryDataParameterName) && !modelProvider.SupportsBinaryDataAdditionalProperties) + bool isBinaryDataParam = param.Name.Equals(AdditionalBinaryDataParameterName) + || (isCustomConstructor && param.Type.Equals(typeof(IDictionary))); + + if (isBinaryDataParam && !modelProvider.SupportsBinaryDataAdditionalProperties) continue; + // skip discriminator parameters if the model has a discriminator value as those shouldn't be exposed in the factory methods if (param.Property?.IsDiscriminator == true && modelProvider.DiscriminatorValue != null) continue; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Snippets/Snippet.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Snippets/Snippet.cs index 59cfc32452..ed1280b2ae 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Snippets/Snippet.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Snippets/Snippet.cs @@ -21,6 +21,8 @@ public static partial class Snippet public static ValueExpression NullConditional(this ParameterProvider parameter) => new NullConditionalExpression(parameter); public static ValueExpression NullCoalesce(this ParameterProvider parameter, ValueExpression value) => parameter.AsExpression.NullCoalesce(value); + public static ValueExpression PositionalReference(this ParameterProvider parameter, ValueExpression value) + => new PositionalParameterReferenceExpression(parameter.Name, value); public static DictionaryExpression AsDictionary(this FieldProvider field, CSharpType keyType, CSharpType valueType) => new(new KeyValuePairType(keyType, valueType), field); public static DictionaryExpression AsDictionary(this ParameterProvider parameter, CSharpType keyType, CSharpType valueType) => new(new KeyValuePairType(keyType, valueType), parameter); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/DiscriminatorTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/DiscriminatorTests.cs index 1f1307b001..eaa48bcb5d 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/DiscriminatorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/DiscriminatorTests.cs @@ -77,7 +77,7 @@ public void DerivedShouldUseItsDiscriminatorValueInModelFactory() // ensure the signature is correct and includes the base discriminator value // and the cat model's discriminator with literal value Assert.IsTrue(birdModelMethod!.BodyStatements!.ToDisplayString() - .Contains("return new global::Sample.Models.Bird(\"red\", \"bird\", name, null);")); + .Contains("return new global::Sample.Models.Bird(\"red\", \"bird\", name, additionalBinaryDataProperties: null);")); } private static ModelFactoryProvider SetupModelFactory() diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs index 5ab6e3b992..f15b0e5278 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs @@ -122,5 +122,45 @@ public async Task CanChangeAccessibilityOfModelFactory() Assert.IsFalse(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); ValidateModelFactoryCommon(modelFactory); } + + [Test] + public async Task CanCustomizeModelFullConstructor() + { + var plugin = await MockHelpers.LoadMockPluginAsync( + inputModelTypes: [ + InputFactory.Model( + "mockInputModel", + properties: + [ + InputFactory.Property("Prop1", InputPrimitiveType.String, isRequired: true), + ]) + ], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + var csharpGen = new CSharpGen(); + + await csharpGen.ExecuteAsync(); + + // Find the model factory provider + var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); + Assert.IsNotNull(modelFactory); + + // The model factory should be public + Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + ValidateModelFactoryCommon(modelFactory); + + // The model factory method should be replaced + var modelFactoryMethods = modelFactory!.Methods; + Assert.AreEqual(1, modelFactoryMethods.Count); + + var modelFactoryMethod = modelFactoryMethods[0]; + Assert.AreEqual("MockInputModel", modelFactoryMethod.Signature.Name); + + Assert.AreEqual(2, modelFactoryMethod.Signature.Parameters.Count); + Assert.AreEqual("data", modelFactoryMethod.Signature.Parameters[0].Name); + Assert.AreEqual("prop1", modelFactoryMethod.Signature.Parameters[1].Name); + + Assert.IsTrue(modelFactoryMethod.BodyStatements!.ToDisplayString() + .Contains("return new global::Sample.Models.MockInputModel(data?.ToList(), prop1, additionalBinaryDataProperties: null);")); + } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanCustomizeModelFullConstructor/MockInputModel.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanCustomizeModelFullConstructor/MockInputModel.cs new file mode 100644 index 0000000000..824a0339f6 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanCustomizeModelFullConstructor/MockInputModel.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using Microsoft.Generator.CSharp.Customization; + +#nullable disable + +namespace Sample.Models; + +[CodeGenSuppress("MockInputModel", typeof(string), typeof(IDictionary))] +public partial class MockInputModel +{ + private readonly IReadOnlyList _data; + + internal MockInputModel(IReadOnlyList data, string prop1, IDictionary additionalBinaryDataProperties) + { + Prop1 = prop1; + _data = data; + _additionalBinaryDataProperties = additionalBinaryDataProperties; + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ModelCustomizationTests/CanCustomizePropertyIntoReadOnlyMemory.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ModelCustomizationTests/CanCustomizePropertyIntoReadOnlyMemory.cs index bcf6e2dd7e..d7973c6c7f 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ModelCustomizationTests/CanCustomizePropertyIntoReadOnlyMemory.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ModelCustomizationTests/CanCustomizePropertyIntoReadOnlyMemory.cs @@ -15,7 +15,7 @@ public static partial class SampleNamespaceModelFactory public static global::Sample.Models.MockInputModel MockInputModel(global::System.ReadOnlyMemory prop1 = default) { - return new global::Sample.Models.MockInputModel(prop1, null); + return new global::Sample.Models.MockInputModel(prop1, additionalBinaryDataProperties: null); } } } diff --git a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecModelFactory.cs index f84f4639a2..a505e2db17 100644 --- a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecModelFactory.cs @@ -46,7 +46,7 @@ public static Thing Thing(string name = default, BinaryData requiredUnion = defa requiredBadDescription, optionalNullableList?.ToList(), requiredNullableList?.ToList(), - null); + additionalBinaryDataProperties: null); } /// this is a roundtrip model. @@ -113,7 +113,7 @@ public static RoundTripModel RoundTripModel(string requiredString = default, int readOnlyOptionalRecordUnknown, modelWithRequiredNullable, requiredBytes, - null); + additionalBinaryDataProperties: null); } /// A model with a few required nullable properties. @@ -124,7 +124,7 @@ public static RoundTripModel RoundTripModel(string requiredString = default, int public static ModelWithRequiredNullableProperties ModelWithRequiredNullableProperties(int? requiredNullablePrimitive = default, StringExtensibleEnum? requiredExtensibleEnum = default, StringFixedEnum? requiredFixedEnum = default) { - return new ModelWithRequiredNullableProperties(requiredNullablePrimitive, requiredExtensibleEnum, requiredFixedEnum, null); + return new ModelWithRequiredNullableProperties(requiredNullablePrimitive, requiredExtensibleEnum, requiredFixedEnum, additionalBinaryDataProperties: null); } /// this is not a friendly model but with a friendly name. @@ -133,7 +133,7 @@ public static ModelWithRequiredNullableProperties ModelWithRequiredNullablePrope public static Friend Friend(string name = default) { - return new Friend(name, null); + return new Friend(name, additionalBinaryDataProperties: null); } /// this is a model with a projected name. @@ -142,7 +142,7 @@ public static Friend Friend(string name = default) public static ProjectedModel ProjectedModel(string name = default) { - return new ProjectedModel(name, null); + return new ProjectedModel(name, additionalBinaryDataProperties: null); } /// The ReturnsAnonymousModelResponse. @@ -150,7 +150,7 @@ public static ProjectedModel ProjectedModel(string name = default) public static ReturnsAnonymousModelResponse ReturnsAnonymousModelResponse() { - return new ReturnsAnonymousModelResponse(null); + return new ReturnsAnonymousModelResponse(additionalBinaryDataProperties: null); } } }