Skip to content

Commit

Permalink
[http-client-csharp] Remove incorrect casting for unknown discriminat…
Browse files Browse the repository at this point in the history
…ed subtype models (#4963)

fixes: #4958
  • Loading branch information
jorgerangel-msft authored Nov 4, 2024
1 parent 3b17396 commit 49c0527
Show file tree
Hide file tree
Showing 25 changed files with 127 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,22 @@ internal MethodProvider BuildPersistableModelCreateCoreMethod()
/// </summary>
internal MethodProvider BuildJsonModelCreateMethod()
{
ValueExpression createCoreInvocation = This.Invoke(JsonModelCreateCoreMethodName, [_utf8JsonReaderParameter, _serializationOptionsParameter]);
var createCoreReturnType = _model.Type.RootType;

// If the return type of the create core method is not the same as the interface type, cast it to the interface type since
// the Core methods will always return the root type of the model. The interface type will be the model type unless the model
// is an unknown discriminated model.
if (createCoreReturnType != _jsonModelTInterface.Arguments[0])
{
createCoreInvocation = createCoreInvocation.CastTo(_model.Type);
}

// T IJsonModel<T>.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) => JsonModelCreateCore(ref reader, options);
return new MethodProvider
(
new MethodSignature(nameof(IJsonModel<object>.Create), null, MethodSignatureModifiers.None, _jsonModelTInterface.Arguments[0], null, [_utf8JsonReaderParameter, _serializationOptionsParameter], ExplicitInterface: _jsonModelTInterface),
This.Invoke(JsonModelCreateCoreMethodName, [_utf8JsonReaderParameter, _serializationOptionsParameter]).CastTo(_model.Type),
createCoreInvocation,
this
);
}
Expand Down Expand Up @@ -514,11 +525,21 @@ internal MethodProvider BuildPersistableModelWriteMethod()
internal MethodProvider BuildPersistableModelCreateMethod()
{
ParameterProvider dataParameter = new("data", $"The data to parse.", typeof(BinaryData));
ValueExpression createCoreInvocation = This.Invoke(PersistableModelCreateCoreMethodName, [dataParameter, _serializationOptionsParameter]);
var createCoreReturnType = _model.Type.RootType;

// If the return type of the create core method is not the same as the interface type, cast it to the interface type since
// the Core methods will always return the root type of the model. The interface type will be the model type unless the model
// is an unknown discriminated model.
if (createCoreReturnType != _persistableModelTInterface.Arguments[0])
{
createCoreInvocation = createCoreInvocation.CastTo(_model.Type);
}
// IPersistableModel<T>.Create(BinaryData data, ModelReaderWriterOptions options) => PersistableModelCreateCore(data, options);
return new MethodProvider
(
new MethodSignature(nameof(IPersistableModel<object>.Create), null, MethodSignatureModifiers.None, _persistableModelTInterface.Arguments[0], null, [dataParameter, _serializationOptionsParameter], ExplicitInterface: _persistableModelTInterface),
This.Invoke(PersistableModelCreateCoreMethodName, [dataParameter, _serializationOptionsParameter]).CastTo(_model.Type),
createCoreInvocation,
this
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Generator.CSharp.ClientModel.Providers;
using Microsoft.Generator.CSharp.Input;
using Microsoft.Generator.CSharp.Primitives;
using Microsoft.Generator.CSharp.Providers;
Expand Down Expand Up @@ -200,5 +201,58 @@ public void DiscriminatorDeserializationUsesCorrectDiscriminatorPropName()
Assert.IsTrue(deserializationMethod?.BodyStatements!.ToDisplayString().Contains(
$"if (element.TryGetProperty(\"foo\"u8, out global::System.Text.Json.JsonElement discriminator))"));
}

[Test]
public void TestBuildJsonModelCreateMethodProperlyCastsForDiscriminatedType()
{
MockHelpers.LoadMockPlugin(inputModels: () => [_baseModel, _catModel]);
var outputLibrary = ClientModelPlugin.Instance.OutputLibrary;
var model = outputLibrary.TypeProviders.OfType<ModelProvider>().FirstOrDefault(t => t.Name == "Cat");
Assert.IsNotNull(model);

var serialization = model!.SerializationProviders.FirstOrDefault() as MrwSerializationTypeDefinition;
Assert.IsNotNull(serialization);
var method = serialization!.Methods.FirstOrDefault(m => m.Signature.Name == "Create" && m.Signature.ExplicitInterface?.Name == "IJsonModel");

Assert.IsNotNull(method);

var expectedJsonInterface = new CSharpType(typeof(IJsonModel<>), model!.Type);
var methodSignature = method?.Signature;
Assert.IsNotNull(methodSignature);

var expectedReturnType = expectedJsonInterface.Arguments[0];
Assert.AreEqual(expectedReturnType, methodSignature?.ReturnType);

var invocationExpression = method!.BodyExpression;
Assert.IsNotNull(invocationExpression);
Assert.AreEqual(
"((global::Sample.Models.Cat)this.JsonModelCreateCore(ref reader, options))",
invocationExpression!.ToDisplayString());
}

[Test]
public void TestBuildJsonModelCreateMethodProperlyDoesNotCastForUnknown()
{
MockHelpers.LoadMockPlugin(inputModels: () => [_baseModel, _catModel]);
var outputLibrary = ClientModelPlugin.Instance.OutputLibrary;
var model = outputLibrary.TypeProviders.OfType<ModelProvider>().FirstOrDefault(t => t.Name == "UnknownPet");
Assert.IsNotNull(model);

var serialization = model!.SerializationProviders.FirstOrDefault() as MrwSerializationTypeDefinition;
Assert.IsNotNull(serialization);
var method = serialization!.Methods.FirstOrDefault(m => m.Signature.Name == "Create" && m.Signature.ExplicitInterface?.Name == "IJsonModel");

Assert.IsNotNull(method);

var methodSignature = method?.Signature;
Assert.IsNotNull(methodSignature);
Assert.AreEqual("Pet", methodSignature?.ReturnType!.Name);

var invocationExpression = method!.BodyExpression;
Assert.IsNotNull(invocationExpression);
Assert.AreEqual(
"this.JsonModelCreateCore(ref reader, options)",
invocationExpression!.ToDisplayString());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ public void TestBuildJsonModelCreateMethod()
Assert.AreEqual(2, methodSignature?.Parameters.Count);
var expectedReturnType = expectedJsonInterface.Arguments[0];
Assert.AreEqual(expectedReturnType, methodSignature?.ReturnType);

var invocationExpression = method!.BodyExpression;
Assert.IsNotNull(invocationExpression);
Assert.AreEqual("this.JsonModelCreateCore(ref reader, options)", invocationExpression!.ToDisplayString());
}

// This test validates the json model serialization create core method is built correctly
Expand Down Expand Up @@ -353,13 +357,9 @@ public void TestBuildPersistableModelCreateMethod()
// Validate body
var methodBody = method?.BodyStatements;
Assert.IsNull(methodBody);
var bodyExpression = method?.BodyExpression as CastExpression;
var bodyExpression = method?.BodyExpression;
Assert.IsNotNull(bodyExpression);
var invocationExpression = bodyExpression?.Inner as InvokeMethodExpression;
Assert.IsNotNull(invocationExpression);
Assert.AreEqual("PersistableModelCreateCore", invocationExpression?.MethodName);
Assert.IsNotNull(invocationExpression?.InstanceReference);
Assert.AreEqual(2, invocationExpression?.Arguments.Count);
Assert.AreEqual("this.PersistableModelCreateCore(data, options)", bodyExpression!.ToDisplayString());
}

// This test validates the persistable model serialization create core method is built correctly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.JsonModelCreateCore(ref reader, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.JsonModelCreateCore(ref reader, options);

/// <param name="reader"> The JSON reader. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down Expand Up @@ -109,7 +109,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.PersistableModelCreateCore(data, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelCreateCore(data, options);

/// <param name="data"> The data to parse. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.JsonModelCreateCore(ref reader, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.JsonModelCreateCore(ref reader, options);

/// <param name="reader"> The JSON reader. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down Expand Up @@ -109,7 +109,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.PersistableModelCreateCore(data, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelCreateCore(data, options);

/// <param name="data"> The data to parse. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.JsonModelCreateCore(ref reader, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.JsonModelCreateCore(ref reader, options);

/// <param name="reader"> The JSON reader. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down Expand Up @@ -106,7 +106,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.PersistableModelCreateCore(data, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelCreateCore(data, options);

/// <param name="data"> The data to parse. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.JsonModelCreateCore(ref reader, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.JsonModelCreateCore(ref reader, options);

/// <param name="reader"> The JSON reader. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down Expand Up @@ -113,7 +113,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.PersistableModelCreateCore(data, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelCreateCore(data, options);

/// <param name="data"> The data to parse. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.JsonModelCreateCore(ref reader, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.JsonModelCreateCore(ref reader, options);

/// <param name="reader"> The JSON reader. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down Expand Up @@ -118,7 +118,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.PersistableModelCreateCore(data, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelCreateCore(data, options);

/// <param name="data"> The data to parse. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.JsonModelCreateCore(ref reader, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.MockInputModel>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.JsonModelCreateCore(ref reader, options);

/// <param name="reader"> The JSON reader. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down Expand Up @@ -106,7 +106,7 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
}

global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.MockInputModel)this.PersistableModelCreateCore(data, options));
global::Sample.Models.MockInputModel global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelCreateCore(data, options);

/// <param name="data"> The data to parse. </param>
/// <param name="options"> The client options for reading and writing models. </param>
Expand Down
Loading

0 comments on commit 49c0527

Please sign in to comment.