Skip to content

Commit

Permalink
OpenAPI code gen collections fix (zio#2620)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jan 20, 2024
1 parent 49d4aa4 commit 2072383
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 35 deletions.
58 changes: 40 additions & 18 deletions zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package zio.http.gen.openapi

import scala.annotation.tailrec

import zio.Chunk

import zio.http.Method
import zio.http.endpoint.openapi.OpenAPI.ReferenceOr
import zio.http.endpoint.openapi.{JsonSchema, OpenAPI}
import zio.http.gen.scala.Code
import zio.http.gen.scala.Code.ScalaType

import scala.annotation.tailrec

object EndpointGen {

private object Inline {
Expand All @@ -21,6 +20,7 @@ object EndpointGen {
private val DataImports =
List(
Code.Import("zio.schema._"),
Code.Import("zio._"),
)

private val RequestBodyRef = "#/components/requestBodies/(.*)".r
Expand Down Expand Up @@ -156,11 +156,17 @@ final case class EndpointGen() {
mt.schema match {
case ReferenceOr.Or(s) =>
s.withoutAnnotations match {
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case schema if schema.isPrimitive =>
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) =>
s"Chunk[$ref]"
case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive =>
s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]"
case JsonSchema.ArrayType(None) =>
"Chunk[String]"
case schema if schema.isPrimitive =>
schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString
case schema =>
case schema =>
val code = schemaToCode(schema, openAPI, Inline.RequestBodyType, Chunk.empty)
.getOrElse(
throw new Exception(s"Could not generate code for request body $schema"),
Expand Down Expand Up @@ -195,11 +201,17 @@ final case class EndpointGen() {
mt.schema match {
case ReferenceOr.Or(s) =>
s.withoutAnnotations match {
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case schema if schema.isPrimitive =>
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) =>
s"Chunk[$ref]"
case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive =>
s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]"
case JsonSchema.ArrayType(None) =>
"Chunk[String]"
case schema if schema.isPrimitive =>
schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString
case schema =>
case schema =>
val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty)
.getOrElse(
throw new Exception(s"Could not generate code for request body $schema"),
Expand Down Expand Up @@ -238,11 +250,17 @@ final case class EndpointGen() {
mt.schema match {
case ReferenceOr.Or(s) =>
s.withoutAnnotations match {
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case schema if schema.isPrimitive =>
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case JsonSchema.ArrayType(Some(JsonSchema.RefSchema(SchemaRef(ref)))) =>
s"Chunk[$ref]"
case JsonSchema.ArrayType(Some(schema)) if schema.isPrimitive =>
s"Chunk[${schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString}]"
case JsonSchema.ArrayType(None) =>
"Chunk[String]"
case schema if schema.isPrimitive =>
schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString
case schema =>
case schema =>
val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty)
.getOrElse(
throw new Exception(s"Could not generate code for request body $schema"),
Expand Down Expand Up @@ -636,9 +654,10 @@ final case class EndpointGen() {
)
case JsonSchema.Number(_) => None
case JsonSchema.ArrayType(None) => None
case JsonSchema.ArrayType(Some(schema)) => schemaToCode(schema, openAPI, name, annotations)
case JsonSchema.ArrayType(Some(schema)) =>
schemaToCode(schema, openAPI, name, annotations)
// TODO use additionalProperties
case JsonSchema.Object(properties, additionalProperties, required) =>
case JsonSchema.Object(properties, additionalProperties, required) =>
val fields = properties.map { case (name, schema) =>
val field = schemaToField(schema, openAPI, name, annotations)
.getOrElse(
Expand All @@ -648,7 +667,10 @@ final case class EndpointGen() {
if (required.contains(name)) field else field.copy(fieldType = field.fieldType.opt)
}.toList
val nested = properties.collect {
case (name, schema) if !schema.isInstanceOf[JsonSchema.RefSchema] && !schema.isPrimitive =>
case (name, schema)
if !schema.isInstanceOf[JsonSchema.RefSchema]
&& !schema.isPrimitive
&& !schema.isCollection =>
schemaToCode(schema, openAPI, name.capitalize, Chunk.empty)
.getOrElse(
throw new Exception(s"Could not generate code for field $name of object $name"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ object CodeGen {
case col: Code.Collection =>
col match {
case Code.Collection.Seq(elementType) =>
s"Seq[${render(basePackage)(elementType)}]"
s"Chunk[${render(basePackage)(elementType)}]"
case Code.Collection.Set(elementType) =>
s"Set[${render(basePackage)(elementType)}]"
case Code.Collection.Map(elementType) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package zio.http.gen.model

import zio.Chunk
import zio.schema._

case class UserNameArray(id: Int, name: Chunk[String])
object UserNameArray {
implicit val codec: Schema[UserNameArray] = DeriveSchema.gen[UserNameArray]
}
132 changes: 128 additions & 4 deletions zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
package zio.http.gen.openapi

import java.nio.file._

import zio._
import zio.test._

import zio.http._
import zio.http.codec.HeaderCodec
import zio.http.codec.HttpCodec.{query, queryInt}
Expand All @@ -13,6 +9,9 @@ import zio.http.endpoint.openapi.JsonSchema.SchemaStyle.Inline
import zio.http.endpoint.openapi.{OpenAPI, OpenAPIGen}
import zio.http.gen.model._
import zio.http.gen.scala.Code
import zio.test._

import java.nio.file._

object EndpointGenSpec extends ZIOSpecDefault {
override def spec: Spec[TestEnvironment with Scope, Any] =
Expand Down Expand Up @@ -594,6 +593,68 @@ object EndpointGenSpec extends ZIOSpecDefault {
)
assertTrue(scala.files.head == expected)
},
test("seq request") {
val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").in[Chunk[User]]
val openAPI = OpenAPIGen.fromEndpoints(endpoint)
val scala = EndpointGen.fromOpenAPI(openAPI)
val expected = Code.File(
List("api", "v1", "Users.scala"),
pkgPath = List("api", "v1"),
imports = List(Code.Import.FromBase(path = "component._")),
objects = List(
Code.Object(
"Users",
Map(
Code.Field("get") -> Code.EndpointCode(
Method.GET,
Code.PathPatternCode(segments =
List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")),
),
queryParamsCode = Set.empty,
headersCode = Code.HeadersCode.empty,
inCode = Code.InCode("Chunk[User]"),
outCodes = Nil,
errorsCode = Nil,
),
),
),
),
caseClasses = Nil,
enums = Nil,
)
assertTrue(scala.files.head == expected)
},
test("seq response") {
val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").out[Chunk[User]]
val openAPI = OpenAPIGen.fromEndpoints(endpoint)
val scala = EndpointGen.fromOpenAPI(openAPI)
val expected = Code.File(
List("api", "v1", "Users.scala"),
pkgPath = List("api", "v1"),
imports = List(Code.Import.FromBase(path = "component._")),
objects = List(
Code.Object(
"Users",
Map(
Code.Field("get") -> Code.EndpointCode(
Method.GET,
Code.PathPatternCode(segments =
List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")),
),
queryParamsCode = Set.empty,
headersCode = Code.HeadersCode.empty,
inCode = Code.InCode("Unit"),
outCodes = List(Code.OutCode.json("Chunk[User]", Status.Ok)),
errorsCode = Nil,
),
),
),
),
caseClasses = Nil,
enums = Nil,
)
assertTrue(scala.files.head == expected)
},
),
suite("data gen spec")(
test("generates case class, companion object and schema") {
Expand Down Expand Up @@ -929,6 +990,69 @@ object EndpointGenSpec extends ZIOSpecDefault {

assertTrue(scala.files.head == expected)
},
test("generates case class with seq field for request") {
val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[UserNameArray].out[User]
val openAPI = OpenAPIGen.fromEndpoints("", "", Inline, endpoint)
val scala = EndpointGen.fromOpenAPI(openAPI)
val fields = List(
Code.Field("id", Code.Primitive.ScalaInt),
Code.Field("name", Code.Primitive.ScalaString),
)
val expected = Code.File(
List("api", "v1", "Users.scala"),
pkgPath = List("api", "v1"),
imports = List(Code.Import.FromBase(path = "component._")),
objects = List(
Code.Object(
"Users",
schema = false,
endpoints = Map(
Code.Field("post") -> Code.EndpointCode(
Method.POST,
Code.PathPatternCode(segments =
List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")),
),
queryParamsCode = Set.empty,
headersCode = Code.HeadersCode.empty,
inCode = Code.InCode("POST.RequestBody"),
outCodes = List(Code.OutCode.json("POST.ResponseBody", Status.Ok)),
errorsCode = Nil,
),
),
objects = List(
Code.Object(
"POST",
schema = false,
endpoints = Map.empty,
objects = Nil,
caseClasses = List(
Code
.CaseClass(
"RequestBody",
fields = List(
Code.Field("id", Code.Primitive.ScalaInt),
Code.Field("name", Code.Primitive.ScalaString.seq),
),
companionObject = Some(Code.Object.schemaCompanion("RequestBody")),
),
Code.CaseClass(
"ResponseBody",
fields = fields,
companionObject = Some(Code.Object.schemaCompanion("ResponseBody")),
),
),
enums = Nil,
),
),
caseClasses = Nil,
enums = Nil,
),
),
caseClasses = Nil,
enums = Nil,
)
assertTrue(scala.files.head == expected)
},
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ sealed trait JsonSchema extends Product with Serializable { self =>
case _ => false
}

def isCollection: Boolean = self match {
case _: JsonSchema.ArrayType => true
case _ => false
}

}

object JsonSchema {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
package zio.http.endpoint.openapi

import java.util.UUID

import scala.annotation.tailrec
import scala.collection.{immutable, mutable}

import zio.Chunk
import zio.http._
import zio.http.codec.HttpCodec.Metadata
import zio.http.codec._
import zio.http.endpoint._
import zio.http.endpoint.openapi.JsonSchema.SchemaStyle
import zio.json.EncoderOps
import zio.json.ast.Json

import zio.schema.Schema.Record
import zio.schema.codec.JsonCodec
import zio.schema.{Schema, TypeId}

import zio.http._
import zio.http.codec.HttpCodec.Metadata
import zio.http.codec._
import zio.http.endpoint._
import zio.http.endpoint.openapi.JsonSchema.SchemaStyle
import java.util.UUID
import scala.annotation.tailrec
import scala.collection.{immutable, mutable}

object OpenAPIGen {
private val PathWildcard = "pathWildcard"
Expand Down Expand Up @@ -662,12 +659,43 @@ object OpenAPIGen {
(endpoint.input.alternatives.map(_._1).map(AtomizedMetaCodecs.flatten(_)).flatMap(_.content)
++ endpoint.error.alternatives.map(_._1).map(AtomizedMetaCodecs.flatten(_)).flatMap(_.content)
++ endpoint.output.alternatives.map(_._1).map(AtomizedMetaCodecs.flatten(_)).flatMap(_.content)).collect {
case MetaCodec(HttpCodec.Content(schema, _, _, _), _) if nominal(schema, referenceType).isDefined =>
case MetaCodec(HttpCodec.Content(schema, _, _, _), _) if nominal(schema, referenceType).isDefined =>
val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType)
schemas.children.map { case (key, schema) =>
OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema)
} + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get ->
OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema))))

case MetaCodec(HttpCodec.Content(setSchema, _, _, _), _)
if setSchema.isInstanceOf[Schema.Set[_]]
&& nominal(setSchema.asInstanceOf[Schema.Set[_]].elementSchema, referenceType).isDefined =>
val schema = setSchema.asInstanceOf[Schema.Set[_]].elementSchema
val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType)
schemas.children.map { case (key, schema) =>
OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema)
} + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get ->
OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema))))

case MetaCodec(HttpCodec.Content(seqSchema, _, _, _), _)
if seqSchema.isInstanceOf[Schema.Sequence[_, _, _]]
&& nominal(seqSchema.asInstanceOf[Schema.Sequence[_, _, _]].elementSchema, referenceType).isDefined =>
val schema = seqSchema.asInstanceOf[Schema.Sequence[_, _, _]].elementSchema
val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType)
schemas.children.map { case (key, schema) =>
OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema)
} + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get ->
OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema))))

case MetaCodec(HttpCodec.Content(mapSchema, _, _, _), _)
if mapSchema.isInstanceOf[Schema.Map[_, _]]
&& nominal(mapSchema.asInstanceOf[Schema.Map[_, _]].valueSchema, referenceType).isDefined =>
val schema = mapSchema.asInstanceOf[Schema.Map[_, _]].valueSchema
val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType)
schemas.children.map { case (key, schema) =>
OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema)
} + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get ->
OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema))))

case MetaCodec(HttpCodec.ContentStream(schema, _, _, _), _) if nominal(schema, referenceType).isDefined =>
val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType)
schemas.children.map { case (key, schema) =>
Expand Down

0 comments on commit 2072383

Please sign in to comment.