From 20723837d54de195b38e7b3fefe7f5ea56b66367 Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Sun, 21 Jan 2024 00:36:57 +0100 Subject: [PATCH] OpenAPI code gen collections fix (#2620) --- .../zio/http/gen/openapi/EndpointGen.scala | 58 +++++--- .../scala/zio/http/gen/scala/CodeGen.scala | 2 +- .../zio/http/gen/model/UserNameArray.scala | 9 ++ .../http/gen/openapi/EndpointGenSpec.scala | 132 +++++++++++++++++- .../http/endpoint/openapi/JsonSchema.scala | 5 + .../http/endpoint/openapi/OpenAPIGen.scala | 52 +++++-- 6 files changed, 223 insertions(+), 35 deletions(-) create mode 100644 zio-http-gen/src/test/scala/zio/http/gen/model/UserNameArray.scala diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala index 972b30190a..2ff57082c6 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala @@ -1,15 +1,14 @@ package zio.http.gen.openapi -import scala.annotation.tailrec - import zio.Chunk - import zio.http.Method import zio.http.endpoint.openapi.OpenAPI.ReferenceOr import zio.http.endpoint.openapi.{JsonSchema, OpenAPI} import zio.http.gen.scala.Code import zio.http.gen.scala.Code.ScalaType +import scala.annotation.tailrec + object EndpointGen { private object Inline { @@ -21,6 +20,7 @@ object EndpointGen { private val DataImports = List( Code.Import("zio.schema._"), + Code.Import("zio._"), ) private val RequestBodyRef = "#/components/requestBodies/(.*)".r @@ -156,11 +156,17 @@ final case class EndpointGen() { mt.schema match { case ReferenceOr.Or(s) => s.withoutAnnotations match { - case JsonSchema.Null => Inline.Null - case JsonSchema.RefSchema(SchemaRef(ref)) => ref - case schema if schema.isPrimitive => + case JsonSchema.Null => Inline.Null + case JsonSchema.RefSchema(SchemaRef(ref)) => ref + case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) => + s"Chunk[$ref]" + case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive => + s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]" + case JsonSchema.ArrayType(None) => + "Chunk[String]" + case schema if schema.isPrimitive => schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString - case schema => + case schema => val code = schemaToCode(schema, openAPI, Inline.RequestBodyType, Chunk.empty) .getOrElse( throw new Exception(s"Could not generate code for request body $schema"), @@ -195,11 +201,17 @@ final case class EndpointGen() { mt.schema match { case ReferenceOr.Or(s) => s.withoutAnnotations match { - case JsonSchema.Null => Inline.Null - case JsonSchema.RefSchema(SchemaRef(ref)) => ref - case schema if schema.isPrimitive => + case JsonSchema.Null => Inline.Null + case JsonSchema.RefSchema(SchemaRef(ref)) => ref + case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) => + s"Chunk[$ref]" + case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive => + s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]" + case JsonSchema.ArrayType(None) => + "Chunk[String]" + case schema if schema.isPrimitive => schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString - case schema => + case schema => val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty) .getOrElse( throw new Exception(s"Could not generate code for request body $schema"), @@ -238,11 +250,17 @@ final case class EndpointGen() { mt.schema match { case ReferenceOr.Or(s) => s.withoutAnnotations match { - case JsonSchema.Null => Inline.Null - case JsonSchema.RefSchema(SchemaRef(ref)) => ref - case schema if schema.isPrimitive => + case JsonSchema.Null => Inline.Null + case JsonSchema.RefSchema(SchemaRef(ref)) => ref + case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) => + s"Chunk[$ref]" + case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive => + s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]" + case JsonSchema.ArrayType(None) => + "Chunk[String]" + case schema if schema.isPrimitive => schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString - case schema => + case schema => val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty) .getOrElse( throw new Exception(s"Could not generate code for request body $schema"), @@ -636,9 +654,10 @@ final case class EndpointGen() { ) case JsonSchema.Number(_) => None case JsonSchema.ArrayType(None) => None - case JsonSchema.ArrayType(Some(schema)) => schemaToCode(schema, openAPI, name, annotations) + case JsonSchema.ArrayType(Some(schema)) => + schemaToCode(schema, openAPI, name, annotations) // TODO use additionalProperties - case JsonSchema.Object(properties, additionalProperties, required) => + case JsonSchema.Object(properties, additionalProperties, required) => val fields = properties.map { case (name, schema) => val field = schemaToField(schema, openAPI, name, annotations) .getOrElse( @@ -648,7 +667,10 @@ final case class EndpointGen() { if (required.contains(name)) field else field.copy(fieldType = field.fieldType.opt) }.toList val nested = properties.collect { - case (name, schema) if !schema.isInstanceOf[JsonSchema.RefSchema] && !schema.isPrimitive => + case (name, schema) + if !schema.isInstanceOf[JsonSchema.RefSchema] + && !schema.isPrimitive + && !schema.isCollection => schemaToCode(schema, openAPI, name.capitalize, Chunk.empty) .getOrElse( throw new Exception(s"Could not generate code for field $name of object $name"), diff --git a/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala index 082fc981fb..4760289be4 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala @@ -97,7 +97,7 @@ object CodeGen { case col: Code.Collection => col match { case Code.Collection.Seq(elementType) => - s"Seq[${render(basePackage)(elementType)}]" + s"Chunk[${render(basePackage)(elementType)}]" case Code.Collection.Set(elementType) => s"Set[${render(basePackage)(elementType)}]" case Code.Collection.Map(elementType) => diff --git a/zio-http-gen/src/test/scala/zio/http/gen/model/UserNameArray.scala b/zio-http-gen/src/test/scala/zio/http/gen/model/UserNameArray.scala new file mode 100644 index 0000000000..4a17fe8abe --- /dev/null +++ b/zio-http-gen/src/test/scala/zio/http/gen/model/UserNameArray.scala @@ -0,0 +1,9 @@ +package zio.http.gen.model + +import zio.Chunk +import zio.schema._ + +case class UserNameArray(id: Int, name: Chunk[String]) +object UserNameArray { + implicit val codec: Schema[UserNameArray] = DeriveSchema.gen[UserNameArray] +} diff --git a/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala index 126bfbc6e1..42fa303b52 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala @@ -1,10 +1,6 @@ package zio.http.gen.openapi -import java.nio.file._ - import zio._ -import zio.test._ - import zio.http._ import zio.http.codec.HeaderCodec import zio.http.codec.HttpCodec.{query, queryInt} @@ -13,6 +9,9 @@ import zio.http.endpoint.openapi.JsonSchema.SchemaStyle.Inline import zio.http.endpoint.openapi.{OpenAPI, OpenAPIGen} import zio.http.gen.model._ import zio.http.gen.scala.Code +import zio.test._ + +import java.nio.file._ object EndpointGenSpec extends ZIOSpecDefault { override def spec: Spec[TestEnvironment with Scope, Any] = @@ -594,6 +593,68 @@ object EndpointGenSpec extends ZIOSpecDefault { ) assertTrue(scala.files.head == expected) }, + test("seq request") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").in[Chunk[User]] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Chunk[User]"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("seq response") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").out[Chunk[User]] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Unit"), + outCodes = List(Code.OutCode.json("Chunk[User]", Status.Ok)), + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, ), suite("data gen spec")( test("generates case class, companion object and schema") { @@ -929,6 +990,69 @@ object EndpointGenSpec extends ZIOSpecDefault { assertTrue(scala.files.head == expected) }, + test("generates case class with seq field for request") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[UserNameArray].out[User] + val openAPI = OpenAPIGen.fromEndpoints("", "", Inline, endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val fields = List( + Code.Field("id", Code.Primitive.ScalaInt), + Code.Field("name", Code.Primitive.ScalaString), + ) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + schema = false, + endpoints = Map( + Code.Field("post") -> Code.EndpointCode( + Method.POST, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("POST.RequestBody"), + outCodes = List(Code.OutCode.json("POST.ResponseBody", Status.Ok)), + errorsCode = Nil, + ), + ), + objects = List( + Code.Object( + "POST", + schema = false, + endpoints = Map.empty, + objects = Nil, + caseClasses = List( + Code + .CaseClass( + "RequestBody", + fields = List( + Code.Field("id", Code.Primitive.ScalaInt), + Code.Field("name", Code.Primitive.ScalaString.seq), + ), + companionObject = Some(Code.Object.schemaCompanion("RequestBody")), + ), + Code.CaseClass( + "ResponseBody", + fields = fields, + companionObject = Some(Code.Object.schemaCompanion("ResponseBody")), + ), + ), + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, ), ) diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala index 04d37a4d89..52b0edf8e3 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala @@ -191,6 +191,11 @@ sealed trait JsonSchema extends Product with Serializable { self => case _ => false } + def isCollection: Boolean = self match { + case _: JsonSchema.ArrayType => true + case _ => false + } + } object JsonSchema { diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index c2460e942d..407f875de1 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -1,23 +1,20 @@ package zio.http.endpoint.openapi -import java.util.UUID - -import scala.annotation.tailrec -import scala.collection.{immutable, mutable} - import zio.Chunk +import zio.http._ +import zio.http.codec.HttpCodec.Metadata +import zio.http.codec._ +import zio.http.endpoint._ +import zio.http.endpoint.openapi.JsonSchema.SchemaStyle import zio.json.EncoderOps import zio.json.ast.Json - import zio.schema.Schema.Record import zio.schema.codec.JsonCodec import zio.schema.{Schema, TypeId} -import zio.http._ -import zio.http.codec.HttpCodec.Metadata -import zio.http.codec._ -import zio.http.endpoint._ -import zio.http.endpoint.openapi.JsonSchema.SchemaStyle +import java.util.UUID +import scala.annotation.tailrec +import scala.collection.{immutable, mutable} object OpenAPIGen { private val PathWildcard = "pathWildcard" @@ -662,12 +659,43 @@ object OpenAPIGen { (endpoint.input.alternatives.map(_._1).map(AtomizedMetaCodecs.flatten(_)).flatMap(_.content) ++ endpoint.error.alternatives.map(_._1).map(AtomizedMetaCodecs.flatten(_)).flatMap(_.content) ++ endpoint.output.alternatives.map(_._1).map(AtomizedMetaCodecs.flatten(_)).flatMap(_.content)).collect { - case MetaCodec(HttpCodec.Content(schema, _, _, _), _) if nominal(schema, referenceType).isDefined => + case MetaCodec(HttpCodec.Content(schema, _, _, _), _) if nominal(schema, referenceType).isDefined => val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType) schemas.children.map { case (key, schema) => OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema) } + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get -> OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema)))) + + case MetaCodec(HttpCodec.Content(setSchema, _, _, _), _) + if setSchema.isInstanceOf[Schema.Set[_]] + && nominal(setSchema.asInstanceOf[Schema.Set[_]].elementSchema, referenceType).isDefined => + val schema = setSchema.asInstanceOf[Schema.Set[_]].elementSchema + val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType) + schemas.children.map { case (key, schema) => + OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema) + } + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get -> + OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema)))) + + case MetaCodec(HttpCodec.Content(seqSchema, _, _, _), _) + if seqSchema.isInstanceOf[Schema.Sequence[_, _, _]] + && nominal(seqSchema.asInstanceOf[Schema.Sequence[_, _, _]].elementSchema, referenceType).isDefined => + val schema = seqSchema.asInstanceOf[Schema.Sequence[_, _, _]].elementSchema + val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType) + schemas.children.map { case (key, schema) => + OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema) + } + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get -> + OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema)))) + + case MetaCodec(HttpCodec.Content(mapSchema, _, _, _), _) + if mapSchema.isInstanceOf[Schema.Map[_, _]] + && nominal(mapSchema.asInstanceOf[Schema.Map[_, _]].valueSchema, referenceType).isDefined => + val schema = mapSchema.asInstanceOf[Schema.Map[_, _]].valueSchema + val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType) + schemas.children.map { case (key, schema) => + OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema) + } + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get -> + OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema)))) + case MetaCodec(HttpCodec.ContentStream(schema, _, _, _), _) if nominal(schema, referenceType).isDefined => val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType) schemas.children.map { case (key, schema) =>