Skip to content

Commit

Permalink
feat: Added Cohere AllOf inheritance/polymorphism support.
Browse files Browse the repository at this point in the history
  • Loading branch information
HavenDV committed Dec 17, 2024
1 parent b8da9f5 commit ff2004d
Show file tree
Hide file tree
Showing 2,203 changed files with 427 additions and 23,265 deletions.
30 changes: 25 additions & 5 deletions src/libs/AutoSDK/Models/ModelData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ public readonly record struct ModelData(
ImmutableArray<PropertyData> Properties,
ImmutableArray<PropertyData> EnumValues,
string Summary,
bool IsDeprecated
bool IsDeprecated,
string BaseClass,
bool IsBaseClass,
bool IsDerivedClass,
string DiscriminatorPropertyName,
EquatableArray<(string ClassName, string Discriminator)> DerivedTypes
)
{
public static ModelData FromSchemaContext(
Expand Down Expand Up @@ -41,16 +46,31 @@ public static ModelData FromSchemaContext(
Namespace: context.Settings.Namespace,
Style: context.Schema.IsEnum() ? ModelStyle.Enumeration : context.Settings.ModelStyle,
Settings: context.Settings,
Properties: !context.Schema.IsEnum()
? context.Children
Properties: context.IsDerivedClass
? context.DerivedClassContext.Children
.Where(x => x is { IsProperty: true, PropertyData: not null })
.SelectMany(x => x.ComputedProperties)
.ToImmutableArray() : [],
.ToImmutableArray()
: !context.Schema.IsEnum()
? context.Children
.Where(x => x is { IsProperty: true, PropertyData: not null })
.SelectMany(x => x.ComputedProperties)
.ToImmutableArray()
: [],
EnumValues: context.Schema.IsEnum()
? context.ComputeEnum().Values.ToImmutableArray()
: [],
Summary: context.Schema.GetSummary(),
IsDeprecated: context.Schema.Deprecated
IsDeprecated: context.Schema.Deprecated,
BaseClass: context.IsDerivedClass
? context.BaseClassContext.Id
: string.Empty,
IsBaseClass: context.IsBaseClass,
IsDerivedClass: context.IsDerivedClass,
DiscriminatorPropertyName: context.Schema.Discriminator?.PropertyName ?? string.Empty,
DerivedTypes: context.Schema.Discriminator?.Mapping?
.Select(x => (ClassName: x.Value.Replace("#/components/schemas/", string.Empty), Discriminator: x.Key))
.ToImmutableArray() ?? []
);
}

Expand Down
35 changes: 31 additions & 4 deletions src/libs/AutoSDK/Models/SchemaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ public class SchemaContext
public IList<SchemaContext> Children { get; set; } = [];

public required Settings Settings { get; init; }
public required OpenApiSchema Schema { get; init; }
public required OpenApiSchema Schema { get; set; }
public required string Id { get; set; }
public required string Type { get; init; }
public required string Type { get; set; }

public string? ReferenceId { get; init; }
public bool IsReference => ReferenceId != null;
Expand Down Expand Up @@ -52,7 +52,9 @@ public class SchemaContext

public TypeData TypeData { get; set; } = TypeData.Default;

public bool IsClass => Type == "class";// || ResolvedReference?.IsClass == true;
public bool IsClass =>
Type == "class" ||
IsDerivedClass;// || ResolvedReference?.IsClass == true;
//public ModelData? ClassData { get; set; }
public ModelData? ClassData => IsClass
? //IsReference
Expand All @@ -78,7 +80,32 @@ public class SchemaContext

public bool IsAnyOf => Schema.IsAnyOf();
public bool IsOneOf => Schema.IsOneOf();
public bool IsAllOf => Schema.IsAllOf();
public bool IsAllOf => Schema.IsAllOf() && !IsDerivedClass;
public bool IsBaseClass => this is { IsComponent: true, Schema.Discriminator.Mapping: not null };
public bool IsDerivedClass => Schema.IsAllOf() &&
Schema.AllOf is { Count: 2 } allOf &&
(allOf[0].Reference != null &&
allOf[0].ResolveIfRequired().Discriminator?.Mapping != null ||
allOf[1].Reference != null &&
allOf[1].ResolveIfRequired().Discriminator?.Mapping != null);
public SchemaContext DerivedClassContext =>
Schema.IsAllOf() &&
Schema.AllOf is { Count: 2 } allOf
? allOf[0].Reference != null &&
allOf[0].ResolveIfRequired().Discriminator?.Mapping != null
? Children.First(x => x.ReferenceId == allOf[1].Reference?.Id)
: Children.First(x => x.ReferenceId == allOf[0].Reference?.Id)
: throw new InvalidOperationException("Schema is not derived class.");

public SchemaContext BaseClassContext =>
Schema.IsAllOf() &&
Schema.AllOf is { Count: 2 } allOf
? allOf[0].Reference != null &&
allOf[0].ResolveIfRequired().Discriminator?.Mapping != null
? Children.First(x => x.ReferenceId == allOf[0].Reference?.Id)
: Children.First(x => x.ReferenceId == allOf[1].Reference?.Id)
: throw new InvalidOperationException("Schema is not derived class.");

public bool IsAnyOfLikeStructure => IsAnyOf || IsOneOf || IsAllOf;
public bool IsNamedAnyOfLike => IsAnyOfLikeStructure &&
(IsComponent || Schema.Discriminator != null);
Expand Down
1 change: 1 addition & 0 deletions src/libs/AutoSDK/Models/TypeData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ public static string GetCSharpType(SchemaContext context)
$"{context.Children.FirstOrDefault(x => x.Hint == Hint.ArrayItem)?.TypeData.CSharpTypeWithoutNullability}".AsArray(),

(_, _) when context.IsNamedAnyOfLike => $"global::{context.Settings.Namespace}.{context.Id}",
(_, _) when context.IsDerivedClass => $"global::{context.Settings.Namespace}.{context.Id}",

(_, _) when context.Schema.IsAnyOf() => $"global::{context.Settings.Namespace}.AnyOf<{string.Join(", ", context.Children.Where(x => x.Hint == Hint.AnyOf).Select(x => x.TypeData.CSharpTypeWithNullabilityForValueTypes))}>",
(_, _) when context.Schema.IsOneOf() => $"global::{context.Settings.Namespace}.OneOf<{string.Join(", ", context.Children.Where(x => x.Hint == Hint.OneOf).Select(x => x.TypeData.CSharpTypeWithNullabilityForValueTypes))}>",
Expand Down
45 changes: 31 additions & 14 deletions src/libs/AutoSDK/Sources/Sources.Models.Json.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@ public static string GenerateClassFromToJsonMethods(
ModelData modelData,
CancellationToken cancellationToken = default)
{
if (modelData.IsDerivedClass)
{
return string.Empty;
}

return GenerateModelFromToJsonMethods(
@namespace: modelData.Namespace,
className: modelData.ClassName,
settings: modelData.Settings,
isValueType: false,
baseClassName: modelData.BaseClass,
isBaseClass: modelData.IsBaseClass,
cancellationToken);
}

Expand All @@ -32,6 +39,8 @@ public static string GenerateAnyOfFromToJsonMethods(
className: className,
settings: anyOfData.Settings,
isValueType: true,
baseClassName: string.Empty,
isBaseClass: false,
cancellationToken);
}

Expand All @@ -40,12 +49,15 @@ public static string GenerateModelFromToJsonMethods(
string className,
Settings settings,
bool isValueType,
string baseClassName,
bool isBaseClass,
CancellationToken cancellationToken = default)
{
var typeName = $"global::{@namespace}.{className}";
var modifiers = isValueType
? "readonly partial struct"
: "sealed partial class";
: $"{(isBaseClass ? "" : "sealed ")}partial class";
var isDerivedClass = !string.IsNullOrWhiteSpace(baseClassName);

return settings.JsonSerializerType == JsonSerializerType.SystemTextJson
? @$"#nullable enable
Expand All @@ -60,7 +72,7 @@ public string ToJson(
{{
return global::System.Text.Json.JsonSerializer.Serialize(
this,
this.GetType(),
{(isDerivedClass ? $"typeof({baseClassName})" : "this.GetType()")},
jsonSerializerContext);
}}
Expand All @@ -74,45 +86,49 @@ public string ToJson(
{{
return global::System.Text.Json.JsonSerializer.Serialize(
this,
{(isDerivedClass ? $"typeof({baseClassName})," : string.Empty)}
jsonSerializerOptions);
}}
{"Deserializes a JSON string using the provided JsonSerializerContext.".ToXmlDocumentationSummary(level: 8)}
public static {typeName}? FromJson(
public static {typeName}? FromJson{(isDerivedClass ? "<T>" : string.Empty)}(
string json,
global::System.Text.Json.Serialization.JsonSerializerContext jsonSerializerContext)
{(isDerivedClass ? $"where T : {baseClassName}" : string.Empty)}
{{
return global::System.Text.Json.JsonSerializer.Deserialize(
json,
typeof({typeName}),
jsonSerializerContext) as {typeName}{(isValueType ? "?" : "")};
typeof({(isDerivedClass ? baseClassName : typeName)}),
jsonSerializerContext) as {(isDerivedClass ? "T" : typeName)}{(isValueType ? "?" : "")};
}}
{"Deserializes a JSON string using the provided JsonSerializerOptions.".ToXmlDocumentationSummary(level: 8)}
#if NET8_0_OR_GREATER
[global::System.Diagnostics.CodeAnalysis.RequiresUnreferencedCode(""JSON serialization and deserialization might require types that cannot be statically analyzed. Use the overload that takes a JsonTypeInfo or JsonSerializerContext, or make sure all of the required types are preserved."")]
[global::System.Diagnostics.CodeAnalysis.RequiresDynamicCode(""JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. Use System.Text.Json source generation for native AOT applications."")]
#endif
public static {typeName}? FromJson(
public static {typeName}? FromJson{(isDerivedClass ? "<T>" : string.Empty)}(
string json,
global::System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null)
{(isDerivedClass ? $"where T : {baseClassName}" : string.Empty)}
{{
return global::System.Text.Json.JsonSerializer.Deserialize<{typeName}>(
return global::System.Text.Json.JsonSerializer.Deserialize<{(isDerivedClass ? baseClassName : typeName)}>(
json,
jsonSerializerOptions);
jsonSerializerOptions){(isDerivedClass ? " as T" : string.Empty)};
}}
/// <summary>
/// Deserializes a JSON stream using the provided JsonSerializerContext.
/// </summary>
public static async global::System.Threading.Tasks.ValueTask<{typeName}?> FromJsonStreamAsync(
public static async global::System.Threading.Tasks.ValueTask<{typeName}?> FromJsonStreamAsync{(isDerivedClass ? "<T>" : string.Empty)}(
global::System.IO.Stream jsonStream,
global::System.Text.Json.Serialization.JsonSerializerContext jsonSerializerContext)
{(isDerivedClass ? $"where T : {baseClassName}" : string.Empty)}
{{
return (await global::System.Text.Json.JsonSerializer.DeserializeAsync(
jsonStream,
typeof({typeName}),
jsonSerializerContext).ConfigureAwait(false)) as {typeName}{(isValueType ? "?" : "")};
typeof({(isDerivedClass ? baseClassName : typeName)}),
jsonSerializerContext).ConfigureAwait(false)) as {(isDerivedClass ? "T" : typeName)}{(isValueType ? "?" : "")};
}}
/// <summary>
Expand All @@ -122,17 +138,18 @@ public string ToJson(
[global::System.Diagnostics.CodeAnalysis.RequiresUnreferencedCode(""JSON serialization and deserialization might require types that cannot be statically analyzed. Use the overload that takes a JsonTypeInfo or JsonSerializerContext, or make sure all of the required types are preserved."")]
[global::System.Diagnostics.CodeAnalysis.RequiresDynamicCode(""JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. Use System.Text.Json source generation for native AOT applications."")]
#endif
public static global::System.Threading.Tasks.ValueTask<{typeName}?> FromJsonStreamAsync(
public static global::System.Threading.Tasks.ValueTask<{typeName}?> FromJsonStreamAsync{(isDerivedClass ? "<T>" : string.Empty)}(
global::System.IO.Stream jsonStream,
global::System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null)
{(isDerivedClass ? $"where T : {baseClassName}" : string.Empty)}
{{
return global::System.Text.Json.JsonSerializer.DeserializeAsync<{typeName}?>(
jsonStream,
jsonSerializerOptions);
jsonSerializerOptions){(isDerivedClass ? " as T" : string.Empty)};
}}
}}
}}
"
".RemoveBlankLinesWhereOnlyWhitespaces()
: @$"#nullable enable
namespace {@namespace}
Expand Down
4 changes: 2 additions & 2 deletions src/libs/AutoSDK/Sources/Sources.Models.Validation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public static string GenerateClassValidationMethods(
ModelData modelData,
CancellationToken cancellationToken = default)
{
return GenerateModelFromToJsonMethods(
return GenerateModelValidationMethods(
@namespace: modelData.Namespace,
className: modelData.ClassName,
settings: modelData.Settings,
Expand All @@ -26,7 +26,7 @@ public static string GenerateAnyOfValidationMethods(
? $"{anyOfData.SubType}{types}"
: anyOfData.Name;

return GenerateModelFromToJsonMethods(
return GenerateModelValidationMethods(
@namespace: anyOfData.Namespace,
className: className,
settings: anyOfData.Settings,
Expand Down
34 changes: 24 additions & 10 deletions src/libs/AutoSDK/Sources/Sources.Models.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,23 @@ public static string GenerateClassModel(
var additionalPropertiesPostfix = modelData.ClassName == "AdditionalProperties"
? "2"
: string.Empty;
var properties = modelData.Properties.Where(x =>
!modelData.IsBaseClass ||
x.Id != modelData.DiscriminatorPropertyName).ToArray();

return $@"
{modelData.Summary.ToXmlDocumentationSummary(level: 4)}
{(modelData.IsDeprecated ? "[global::System.Obsolete(\"This model marked as deprecated.\")]" : " ")}
public sealed partial class {modelData.ClassName}
{(modelData.Settings.JsonSerializerType == JsonSerializerType.SystemTextJson && modelData.IsBaseClass ? @$"
[global::System.Text.Json.Serialization.JsonPolymorphic(
TypeDiscriminatorPropertyName = ""{modelData.DiscriminatorPropertyName}"",
IgnoreUnrecognizedTypeDiscriminators = true,
UnknownDerivedTypeHandling = global::System.Text.Json.Serialization.JsonUnknownDerivedTypeHandling.FallBackToBaseType)]
{modelData.DerivedTypes.Select(x => $@"
[global::System.Text.Json.Serialization.JsonDerivedType(typeof({modelData.Namespace}.{x.ClassName}), typeDiscriminator: ""{x.Discriminator}"")]").Inject()}" : " ")}
public{(modelData.IsBaseClass ? "" : " sealed")} partial class {modelData.ClassName}{(!string.IsNullOrWhiteSpace(modelData.BaseClass) ? $" : {modelData.BaseClass}" : "")}
{{
{modelData.Properties.Select(property => @$"
{properties.Select(property => @$"
{property.Summary.ToXmlDocumentationSummary(level: 8)}
{property.DefaultValue?.ClearForXml().ToXmlDocumentationDefault(level: 8)}
{property.Example?.ToXmlDocumentationExample(level: 8)}
Expand All @@ -98,31 +108,35 @@ public sealed partial class {modelData.ClassName}
public{(property.IsRequired ? requiredKeyword : "")} {property.Type.CSharpType} {property.Name} {{ get; set; }}{GetDefaultValue(property, isRequiredKeywordSupported)}
").Inject()}
{(!modelData.IsDerivedClass ? $@"
{"Additional properties that are not explicitly defined in the schema".ToXmlDocumentationSummary(level: 8)}
{jsonSerializer.GenerateExtensionDataAttribute()}
public global::System.Collections.Generic.IDictionary<string, object> AdditionalProperties{additionalPropertiesPostfix} {{ get; set; }} = new global::System.Collections.Generic.Dictionary<string, object>();
" : " ")}
{( properties.Any(static x => x.IsRequired || !x.IsDeprecated) ? $@"
/// <summary>
/// Initializes a new instance of the <see cref=""{modelData.ClassName}"" /> class.
/// </summary>
{modelData.Properties.Where(static x => x.IsRequired || !x.IsDeprecated).Select(x => $@"
{properties.Where(static x => x.IsRequired || !x.IsDeprecated).Select(x => $@"
{x.Summary.ToXmlDocumentationForParam(x.ParameterName, level: 8)}").Inject()}
{(modelData.Settings.TargetFramework.StartsWith("net8", StringComparison.OrdinalIgnoreCase) ? "[global::System.Diagnostics.CodeAnalysis.SetsRequiredMembers]" : " ")}
public {modelData.ClassName}(
{string.Join(",",
modelData.Properties.Where(static x => x.IsRequired).Select(x => $@"
properties.Where(static x => x.IsRequired).Select(x => $@"
{x.Type.CSharpType} {x.ParameterName}").Concat(
modelData.Properties.Where(static x => x is { IsRequired: false, IsDeprecated: false } && (x.Type.CSharpTypeNullability || string.IsNullOrWhiteSpace(x.DefaultValue))).Select(x => $@"
properties.Where(static x => x is { IsRequired: false, IsDeprecated: false } && (x.Type.CSharpTypeNullability || string.IsNullOrWhiteSpace(x.DefaultValue))).Select(x => $@"
{x.Type.CSharpType} {x.ParameterName}")).Concat(
modelData.Properties.Where(static x => x is { IsRequired: false, IsDeprecated: false } && !(x.Type.CSharpTypeNullability || string.IsNullOrWhiteSpace(x.DefaultValue))).Select(x => $@"
properties.Where(static x => x is { IsRequired: false, IsDeprecated: false } && !(x.Type.CSharpTypeNullability || string.IsNullOrWhiteSpace(x.DefaultValue))).Select(x => $@"
{x.Type.CSharpType} {x.ParameterName}{GetDefaultValue(x, isRequiredKeywordSupported).TrimEnd(';')}")))})
{{
{modelData.Properties.Where(static x => x.IsRequired).Select(x => $@"
{properties.Where(static x => x.IsRequired).Select(x => $@"
this.{x.Name} = {x.ParameterName}{(x.Type.IsValueType ? "" : $" ?? throw new global::System.ArgumentNullException(nameof({x.ParameterName}))")};").Inject()}
{modelData.Properties.Where(static x => x is { IsRequired: false, IsDeprecated: false }).Select(x => $@"
{properties.Where(static x => x is { IsRequired: false, IsDeprecated: false }).Select(x => $@"
this.{x.Name} = {x.ParameterName};").Inject()}
}}
{(modelData.Properties.Any(static x => !x.IsDeprecated) ? $@"
" : " ")}
{(properties.Any(static x => !x.IsDeprecated) ? $@"
/// <summary>
/// Initializes a new instance of the <see cref=""{modelData.ClassName}"" /> class.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,5 @@ public sealed partial class ChatCompletionMeta
/// </summary>
[global::Newtonsoft.Json.JsonExtensionData]
public global::System.Collections.Generic.IDictionary<string, object> AdditionalProperties { get; set; } = new global::System.Collections.Generic.Dictionary<string, object>();

/// <summary>
/// Initializes a new instance of the <see cref="ChatCompletionMeta" /> class.
/// </summary>
public ChatCompletionMeta(
)
{
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,5 @@ public sealed partial class ChatCompletionVllmStreamingMessageMeta
/// </summary>
[global::Newtonsoft.Json.JsonExtensionData]
public global::System.Collections.Generic.IDictionary<string, object> AdditionalProperties { get; set; } = new global::System.Collections.Generic.Dictionary<string, object>();

/// <summary>
/// Initializes a new instance of the <see cref="ChatCompletionVllmStreamingMessageMeta" /> class.
/// </summary>
public ChatCompletionVllmStreamingMessageMeta(
)
{
}
}
}
Loading

0 comments on commit ff2004d

Please sign in to comment.