Skip to content

Commit

Permalink
feat: Proper imports for generated openapi code (#2718)
Browse files Browse the repository at this point in the history
* feat: load scalaFmt only once

When generating multiple files, that are all formatted with the same scalafmt.conf, load it only once.

In addition, writeFiles now returns a list of written paths. This will allow to use it in a source generation step e.g. in sbt.

* feat: Proper imports for generated openapi code

Extend the code generator, such that each generation step
can add imports, which are then rendered at the file or object level.

This fixes a few issues where imports were missing for generated case
classes, that had `Chunk` fields. It also unifies the handling of UUIDs
and opens the dor for additional future configuration options. For
example the collection type can be made configurable, or user defined
types can be added to the imports.

`writeFiles` now returns a list of written files. This allows for use
in sbt code generation tasks. I will add an sbt plugin in a later PR.
  • Loading branch information
runtologist authored Mar 10, 2024
1 parent 91bcf2a commit ec7f24a
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 139 deletions.
112 changes: 59 additions & 53 deletions zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,24 +92,28 @@ final case class EndpointGen() {
case s => Code.PathSegmentCode(s, Code.CodecType.Literal)
}

val (imports, endpoints) =
List(
pathItem.delete.map(op => fieldName(op, "delete") -> endpoint(segments, op, openAPI, Method.DELETE)),
pathItem.get.map(op => fieldName(op, "get") -> endpoint(segments, op, openAPI, Method.GET)),
pathItem.head.map(op => fieldName(op, "head") -> endpoint(segments, op, openAPI, Method.HEAD)),
pathItem.options.map(op => fieldName(op, "options") -> endpoint(segments, op, openAPI, Method.OPTIONS)),
pathItem.post.map(op => fieldName(op, "post") -> endpoint(segments, op, openAPI, Method.POST)),
pathItem.put.map(op => fieldName(op, "put") -> endpoint(segments, op, openAPI, Method.PUT)),
pathItem.patch.map(op => fieldName(op, "patch") -> endpoint(segments, op, openAPI, Method.PATCH)),
pathItem.trace.map(op => fieldName(op, "trace") -> endpoint(segments, op, openAPI, Method.TRACE)),
).flatten.map { case (name, (imports, endpoint)) =>
(imports, name -> endpoint)
}.unzip
Code.File(
packageName.split('.').toList :+ s"$className.scala",
pkgPath = packageName.split('.').toList,
imports = List(Code.Import.FromBase("component._")),
imports = (Code.Import.FromBase("component._") :: imports.flatten).distinct,
objects = List(
Code.Object(
className,
schema = false,
endpoints = List(
pathItem.delete.map(op => fieldName(op, "delete") -> endpoint(segments, op, openAPI, Method.DELETE)),
pathItem.get.map(op => fieldName(op, "get") -> endpoint(segments, op, openAPI, Method.GET)),
pathItem.head.map(op => fieldName(op, "head") -> endpoint(segments, op, openAPI, Method.HEAD)),
pathItem.options.map(op => fieldName(op, "options") -> endpoint(segments, op, openAPI, Method.OPTIONS)),
pathItem.post.map(op => fieldName(op, "post") -> endpoint(segments, op, openAPI, Method.POST)),
pathItem.put.map(op => fieldName(op, "put") -> endpoint(segments, op, openAPI, Method.PUT)),
pathItem.patch.map(op => fieldName(op, "patch") -> endpoint(segments, op, openAPI, Method.PATCH)),
pathItem.trace.map(op => fieldName(op, "trace") -> endpoint(segments, op, openAPI, Method.TRACE)),
).flatten.toMap,
endpoints = endpoints.toMap,
objects = anonymousTypes.values.toList,
caseClasses = Nil,
enums = Nil,
Expand All @@ -129,26 +133,26 @@ final case class EndpointGen() {
op: OpenAPI.Operation,
openAPI: OpenAPI,
method: Method,
) = {
): (List[Code.Import], Code.EndpointCode) = {

val params = op.parameters.map {
val params = op.parameters.map {
case OpenAPI.ReferenceOr.Or(param: OpenAPI.Parameter) => param
case OpenAPI.ReferenceOr.Reference(ParameterRef(key), _, _) => resolveParameterRef(openAPI, key)
case other => throw new Exception(s"Unexpected parameter definition: $other")
}
// TODO: Resolve query and header parameters from components
val queryParams = params.collect {
val queryParams = params.collect {
case p if p.in == "query" =>
schemaToQueryParamCodec(
p.schema.get.asInstanceOf[ReferenceOr.Or[JsonSchema]].value,
openAPI,
p.name,
)
}
val headers = params.collect { case p if p.in == "header" => Code.HeaderCode(p.name) }.toList
val inType =
val headers = params.collect { case p if p.in == "header" => Code.HeaderCode(p.name) }.toList
val (inImports, inType) =
op.requestBody.flatMap {
case OpenAPI.ReferenceOr.Reference(RequestBodyRef(key), _, _) => Some(key)
case OpenAPI.ReferenceOr.Reference(RequestBodyRef(key), _, _) => Some(Nil -> key)
case OpenAPI.ReferenceOr.Or(body: OpenAPI.RequestBody) =>
body.content
.get("application/json")
Expand All @@ -157,9 +161,9 @@ final case class EndpointGen() {
case ReferenceOr.Or(s) =>
s.withoutAnnotations match {
case JsonSchema.Null =>
Inline.Null
Nil -> Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) =>
ref
Nil -> ref
case schema if schema.isPrimitive || schema.isCollection =>
CodeGen.render("")(schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType)
case schema =>
Expand All @@ -176,29 +180,29 @@ final case class EndpointGen() {
caseClasses = code.caseClasses,
enums = code.enums,
)
s"$method.${Inline.RequestBodyType}"
Nil -> s"$method.${Inline.RequestBodyType}"
}
case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => ref
case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => Nil -> ref
case other => throw new Exception(s"Unexpected request body schema: $other")
}
}
case other => throw new Exception(s"Unexpected request body definition: $other")
}.getOrElse("Unit")
}.getOrElse(Nil -> "Unit")

val outCodes: Iterable[Code.OutCode] =
val (outImports: Iterable[List[Code.Import]], outCodes: Iterable[Code.OutCode]) =
// TODO: ignore default for now. Not sure how to handle it
op.responses.collect {
case (OpenAPI.StatusOrDefault.StatusValue(status), OpenAPI.ReferenceOr.Reference(ResponseRef(key), _, _)) =>
val response = resolveResponseRef(openAPI, key)
Code.OutCode(
outType = response.content
val response = resolveResponseRef(openAPI, key)
val (imports, code) =
response.content
.get("application/json")
.map { mt =>
mt.schema match {
case ReferenceOr.Or(s) =>
s.withoutAnnotations match {
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case JsonSchema.Null => Nil -> Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => Nil -> ref
case schema if schema.isPrimitive || schema.isCollection =>
CodeGen.render("")(schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType)
case schema =>
Expand All @@ -221,27 +225,30 @@ final case class EndpointGen() {
enums = obj.enums ++ code.enums,
)
}
s"$method.${Inline.ResponseBodyType}"
Nil -> s"$method.${Inline.ResponseBodyType}"
}
case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => ref
case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => Nil -> ref
case other => throw new Exception(s"Unexpected response body schema: $other")
}
}
.getOrElse("Unit"),
status = status,
mediaType = Some("application/json"),
doc = None,
)
.getOrElse(Nil -> "Unit")
imports ->
Code.OutCode(
outType = code,
status = status,
mediaType = Some("application/json"),
doc = None,
)
case (OpenAPI.StatusOrDefault.StatusValue(status), OpenAPI.ReferenceOr.Or(response: OpenAPI.Response)) =>
Code.OutCode(
outType = response.content
val (imports, code) =
response.content
.get("application/json")
.map { mt =>
mt.schema match {
case ReferenceOr.Or(s) =>
s.withoutAnnotations match {
case JsonSchema.Null => Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => ref
case JsonSchema.Null => Nil -> Inline.Null
case JsonSchema.RefSchema(SchemaRef(ref)) => Nil -> ref
case schema if schema.isPrimitive || schema.isCollection =>
CodeGen.render("")(schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType)
case schema =>
Expand All @@ -264,20 +271,23 @@ final case class EndpointGen() {
enums = obj.enums ++ code.enums,
)
}
s"$method.${Inline.ResponseBodyType}"
Nil -> s"$method.${Inline.ResponseBodyType}"
}
case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => ref
case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => Nil -> ref
case other => throw new Exception(s"Unexpected response body schema: $other")
}
}
.getOrElse("Unit"),
.getOrElse(Nil -> "Unit")
imports -> Code.OutCode(
outType = code,
status = status,
mediaType = Some("application/json"),
doc = None,
)
}
}.unzip

Code.EndpointCode(
val imports = inImports ++ outImports.flatten
val code = Code.EndpointCode(
method = method,
pathPatternCode = Code.PathPatternCode(segments),
queryParamsCode = queryParams,
Expand All @@ -286,7 +296,7 @@ final case class EndpointGen() {
outCodes = outCodes.filterNot(_.status.isError).toList,
errorsCode = outCodes.filter(_.status.isError).toList,
)

imports -> code
}

private def parameterToPathCodec(openAPI: OpenAPI, param: OpenAPI.Parameter): Code.PathSegmentCode = {
Expand Down Expand Up @@ -520,7 +530,7 @@ final case class EndpointGen() {
Code.File(
List("component", name.capitalize + ".scala"),
pkgPath = List("component"),
imports = dataImports(caseClasses.flatMap(_.fields)) ++
imports = DataImports ++
(if (noDiscriminator || caseNames.nonEmpty) List(Code.Import("zio.schema.annotation._")) else Nil),
objects = Nil,
caseClasses = Nil,
Expand Down Expand Up @@ -567,7 +577,7 @@ final case class EndpointGen() {
Code.File(
List("component", name.capitalize + ".scala"),
pkgPath = List("component"),
imports = dataImports(fields),
imports = DataImports,
objects = Nil,
caseClasses = List(
Code.CaseClass(
Expand Down Expand Up @@ -622,7 +632,7 @@ final case class EndpointGen() {
Code.File(
List("component", name.capitalize + ".scala"),
pkgPath = List("component"),
imports = dataImports(caseClasses.flatMap(_.fields)) ++
imports = DataImports ++
(if (noDiscriminator || caseNames.nonEmpty) List(Code.Import("zio.schema.annotation._")) else Nil),
objects = Nil,
caseClasses = Nil,
Expand Down Expand Up @@ -669,7 +679,7 @@ final case class EndpointGen() {
Code.File(
List("component", name.capitalize + ".scala"),
pkgPath = List("component"),
imports = dataImports(fields),
imports = DataImports,
objects = nestedObjects.toList,
caseClasses = List(
Code.CaseClass(
Expand Down Expand Up @@ -780,8 +790,4 @@ final case class EndpointGen() {
}
}

private def dataImports(fields: Iterable[Code.Field]) = {
if (fields.exists(_.fieldType.isInstanceOf[Code.Collection.Seq])) List(Code.Import("zio._"))
else Nil
} ++ DataImports
}
Loading

0 comments on commit ec7f24a

Please sign in to comment.