From 56c3d6262981dac0141aaa95d960becdd838b12a Mon Sep 17 00:00:00 2001 From: Jorge Rangel <102122018+jorgerangel-msft@users.noreply.github.com> Date: Thu, 3 Oct 2024 18:57:45 -0500 Subject: [PATCH] Fix: Models with nested Discriminators (#4596) fixes: https://github.com/microsoft/typespec/issues/4597 --- .../MrwSerializationTypeDefinition.cs | 19 ++++- .../MrwSerializationTypeDefinitionTests.cs | 34 ++++++++ .../src/Providers/ModelProvider.cs | 81 ++++++++++++------- .../ModelFactoryProviderTests.cs | 5 +- .../ModelProviders/DiscriminatorTests.cs | 38 ++++++++- 5 files changed, 140 insertions(+), 37 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 533deeb3b2..c67842c302 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -735,9 +735,22 @@ private List BuildDeserializePropertiesStatements(ScopedApi CreateDeserializeAdditionalPropsValueKindCheck(jsonProperty, additionalPropsValueKindBodyStatements)); } - // deserialize the raw binary data for the model - var rawBinaryData = _rawDataField - ?? _model.BaseModelProvider?.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); + // deserialize the raw binary data for the model by searching for the raw binary data field in the model and any base models. + var rawBinaryData = _rawDataField; + if (rawBinaryData == null) + { + var baseModelProvider = _model.BaseModelProvider; + while (baseModelProvider != null) + { + var field = baseModelProvider.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); + if (field != null) + { + rawBinaryData = field; + break; + } + baseModelProvider = baseModelProvider.BaseModelProvider; + } + } if (_additionalBinaryDataProperty != null) { diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/MrwSerializationTypeDefinitionTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/MrwSerializationTypeDefinitionTests.cs index cd22d2c7a7..744283ab88 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/MrwSerializationTypeDefinitionTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/MrwSerializationTypeDefinitionTests.cs @@ -531,6 +531,40 @@ public void TestBuildDeserializationMethod() Assert.IsNotNull(methodBody); } + [Test] + public void TestBuildDeserializationMethodNestedSARD() + { + var baseModel = InputFactory.Model("BaseModel"); + var nestedModel = InputFactory.Model("NestedModel", baseModel: baseModel); + var inputModel = InputFactory.Model("mockInputModel", baseModel: nestedModel); + var (baseModelProvider, baseSerialization) = CreateModelAndSerialization(baseModel); + var (nestedModelProvider, nestedSerialization) = CreateModelAndSerialization(nestedModel); + var (model, serialization) = CreateModelAndSerialization(inputModel); + + Assert.AreEqual(0, model.Fields.Count); + Assert.AreEqual(0, nestedModelProvider.Fields.Count); + Assert.AreEqual(1, baseModelProvider.Fields.Count); + + var deserializationMethod = serialization.BuildDeserializationMethod(); + Assert.IsNotNull(deserializationMethod); + + var signature = deserializationMethod?.Signature; + Assert.IsNotNull(signature); + Assert.AreEqual($"Deserialize{model.Name}", signature?.Name); + Assert.AreEqual(2, signature?.Parameters.Count); + Assert.AreEqual(new CSharpType(typeof(JsonElement)), signature?.Parameters[0].Type); + Assert.AreEqual(new CSharpType(typeof(ModelReaderWriterOptions)), signature?.Parameters[1].Type); + Assert.AreEqual(model.Type, signature?.ReturnType); + Assert.AreEqual(MethodSignatureModifiers.Internal | MethodSignatureModifiers.Static, signature?.Modifiers); + + var methodBody = deserializationMethod?.BodyStatements; + Assert.IsNotNull(methodBody); + // validate that only one SARD variable is created. + var methodBodyString = methodBody!.ToDisplayString(); + var sardDeclaration = "global::System.Collections.Generic.IDictionary additionalBinaryDataProperties"; + Assert.AreEqual(1, methodBodyString.Split(sardDeclaration).Length - 1); + } + [Test] public void TestBuildImplicitToBinaryContent() { diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelProvider.cs index a72f12edaa..7828063ac7 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelProvider.cs @@ -9,7 +9,6 @@ using Microsoft.Generator.CSharp.Input; using Microsoft.Generator.CSharp.Primitives; using Microsoft.Generator.CSharp.Snippets; -using Microsoft.Generator.CSharp.SourceInput; using Microsoft.Generator.CSharp.Statements; using static Microsoft.Generator.CSharp.Snippets.Snippet; @@ -331,12 +330,12 @@ protected override PropertyProvider[] BuildProperties() var properties = new List(propertiesCount + 1); Dictionary baseProperties = _inputModel.BaseModel?.Properties.ToDictionary(p => p.Name) ?? []; - + var baseModelDiscriminator = _inputModel.BaseModel?.DiscriminatorProperty; for (int i = 0; i < propertiesCount; i++) { var property = _inputModel.Properties[i]; - if (property.IsDiscriminator && Type.BaseType is not null) + if (property.IsDiscriminator && property.Name == baseModelDiscriminator?.Name) continue; var outputProperty = CodeModelPlugin.Instance.TypeFactory.CreateProperty(property, this); @@ -458,9 +457,8 @@ private ConstructorProvider BuildFullConstructor() if (isPrimaryConstructor) { - baseProperties = _inputModel.GetAllBaseModels() - .Reverse() - .SelectMany(model => CodeModelPlugin.Instance.TypeFactory.CreateModel(model)?.Properties ?? []); + // the primary ctor should only include the properties of the direct base model + baseProperties = BaseModelProvider?.Properties ?? []; } else if (BaseModelProvider?.FullConstructor.Signature != null) { @@ -515,24 +513,7 @@ p.Property is null var type = discriminator.Type; if (IsUnknownDiscriminatorModel) { - var discriminatorExpression = discriminator.AsParameter.AsExpression; - if (!type.IsFrameworkType && type.IsEnum) - { - if (type.IsStruct) - { - /* kind != default ? kind : "unknown" */ - return new TernaryConditionalExpression(discriminatorExpression.NotEqual(Default), discriminatorExpression, Literal(_inputModel.DiscriminatorValue)); - } - else - { - return discriminatorExpression; - } - } - else - { - /* kind ?? "unknown" */ - return discriminatorExpression.NullCoalesce(Literal(_inputModel.DiscriminatorValue)); - } + return GetUnknownDiscriminatorExpression(discriminator); } else { @@ -558,10 +539,16 @@ p.Property is null private ValueExpression GetExpressionForCtor(ParameterProvider parameter, HashSet overriddenProperties, bool isPrimaryConstructor) { - if (parameter.Property is not null && parameter.Property.IsDiscriminator && _inputModel.DiscriminatorValue != null && - (isPrimaryConstructor || !isPrimaryConstructor && IsUnknownDiscriminatorModel)) + if (parameter.Property is not null && parameter.Property.IsDiscriminator && _inputModel.DiscriminatorValue != null) { - return DiscriminatorValueExpression ?? throw new InvalidOperationException($"invalid discriminator {_inputModel.DiscriminatorValue} for property {parameter.Property.Name}"); + if (isPrimaryConstructor) + { + return DiscriminatorValueExpression ?? throw new InvalidOperationException($"invalid discriminator {_inputModel.DiscriminatorValue} for property {parameter.Property.Name}"); + } + else if (IsUnknownDiscriminatorModel) + { + return GetUnknownDiscriminatorExpression(parameter.Property) ?? throw new InvalidOperationException($"invalid discriminator {_inputModel.DiscriminatorValue} for property {parameter.Property.Name}"); + } } var paramToUse = parameter.Property is not null && overriddenProperties.Contains(parameter.Property) ? Properties.First(p => p.Name == parameter.Property.Name).AsParameter : parameter; @@ -569,6 +556,35 @@ private ValueExpression GetExpressionForCtor(ParameterProvider parameter, HashSe return paramToUse.Property is not null ? GetConversion(paramToUse.Property) : paramToUse; } + private ValueExpression? GetUnknownDiscriminatorExpression(PropertyProvider property) + { + if (!property.IsDiscriminator || _inputModel.DiscriminatorValue == null) + { + return null; + } + + var discriminatorExpression = property.AsParameter.AsExpression; + var type = property.Type; + + if (!type.IsFrameworkType && type.IsEnum) + { + if (type.IsStruct) + { + /* kind != default ? kind : "unknown" */ + return new TernaryConditionalExpression(discriminatorExpression.NotEqual(Default), discriminatorExpression, Literal(_inputModel.DiscriminatorValue)); + } + else + { + return discriminatorExpression; + } + } + else + { + /* kind ?? "unknown" */ + return discriminatorExpression.NullCoalesce(Literal(_inputModel.DiscriminatorValue)); + } + } + private static void AddInitializationParameterForCtor( List parameters, PropertyProvider property, @@ -697,10 +713,15 @@ private ValueExpression GetConversion(PropertyProvider property) /// The constructed if the model should generate the field. private FieldProvider? BuildRawDataField() { - // check if there is a raw data field on my base, if so, we do not have to have one here - if (BaseModelProvider?.RawDataField != null) + // check if there is a raw data field on any of the base models, if so, we do not have to have one here. + var baseModelProvider = BaseModelProvider; + while (baseModelProvider != null) { - return null; + if (baseModelProvider.RawDataField != null) + { + return null; + } + baseModelProvider = baseModelProvider.BaseModelProvider; } var modifiers = FieldModifiers.Private; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoryProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoryProviderTests.cs index d27a68693a..80be39ccff 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoryProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoryProviderTests.cs @@ -89,10 +89,9 @@ public void DiscriminatorEnumParamShape() Assert.IsNotNull(method); foreach (var property in model!.Properties.Where(p => p.Type.IsEnum)) { + // enum discriminator properties are not included in the factory method var parameter = method!.Signature.Parameters.FirstOrDefault(p => p.Name == property.Name.ToVariableName()); - Assert.IsNotNull(parameter); - Assert.IsTrue(parameter!.Type.IsFrameworkType); - Assert.AreEqual(typeof(int), parameter!.Type.FrameworkType); + Assert.IsNull(parameter); } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/DiscriminatorTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/DiscriminatorTests.cs index 903b0c22a0..dfc920763e 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/DiscriminatorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/DiscriminatorTests.cs @@ -25,10 +25,21 @@ public class DiscriminatorTests InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), InputFactory.Property("likesBones", InputPrimitiveType.Boolean, isRequired: true) ]); + + private static readonly InputModelType _anotherAnimal = InputFactory.Model("anotherAnimal", discriminatedKind: "dog", properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("other", InputPrimitiveType.String, isRequired: true, isDiscriminator: true) + ]); private static readonly InputModelType _baseModel = InputFactory.Model( "pet", properties: [InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true)], - discriminatedModels: new Dictionary() { { "cat", _catModel }, { "dog", _dogModel } }); + discriminatedModels: new Dictionary() + { + { "cat", _catModel }, + { "dog", _dogModel }, + { "otherAnimal", _anotherAnimal } + }); private static readonly InputEnumType _petEnum = InputFactory.Enum("pet", InputPrimitiveType.String, isExtensible: true, values: [ @@ -206,5 +217,30 @@ public void DerivedHasNoKindProperty() var kindProperty = catModel!.Properties.FirstOrDefault(p => p.Name == "Kind"); Assert.IsNull(kindProperty); } + + [Test] + public void ModelWithNestedDiscriminators() + { + MockHelpers.LoadMockPlugin(inputModelTypes: [_baseEnumModel, _dogEnumModel, _anotherAnimal]); + var outputLibrary = CodeModelPlugin.Instance.OutputLibrary; + var anotherDogModel = outputLibrary.TypeProviders.OfType().FirstOrDefault(t => t.Name == "AnotherAnimal"); + Assert.IsNotNull(anotherDogModel); + + var serializationCtor = anotherDogModel!.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Internal)); + Assert.IsNotNull(serializationCtor); + Assert.AreEqual(3, serializationCtor!.Signature.Parameters.Count); + + // ensure both discriminators are present + var kindParam = serializationCtor!.Signature.Parameters.FirstOrDefault(p => p.Name == "kind"); + Assert.IsNotNull(kindParam); + var otherParam = serializationCtor!.Signature.Parameters.FirstOrDefault(p => p.Name == "other"); + Assert.IsNotNull(otherParam); + + // the primary ctor should only have the model's own discriminator + var publicCtor = anotherDogModel.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public)); + Assert.IsNotNull(publicCtor); + Assert.AreEqual(1, publicCtor!.Signature.Parameters.Count); + Assert.AreEqual("other", publicCtor.Signature.Parameters[0].Name); + } } }