Skip to content

Commit

Permalink
In the code generation from OpenAPI schema, added the handling for th…
Browse files Browse the repository at this point in the history
…e string formats `date`, `date-time`, `time` and `duration` to be generated as `LocalDate`, `Instant`, `LocalTime` and `Duration`, and also to configure own formart to type mapping.
  • Loading branch information
gregor-rayman committed Dec 11, 2024
1 parent 3728252 commit eef8adc
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ final case class Config(
commonFieldsOnSuperType: Boolean,
generateSafeTypeAliases: Boolean,
fieldNamesNormalization: NormalizeFields,
stringFormatTypes: Map[String, String],
)
object Config {

Expand Down Expand Up @@ -73,11 +74,15 @@ object Config {
enableAutomatic = false,
manualOverrides = Map.empty,
),
stringFormatTypes = Map.empty,
)

def config: zio.Config[Config] = (
zio.Config.boolean("common-fields-on-super-type").withDefault(Config.default.commonFieldsOnSuperType) ++
zio.Config.boolean("generate-safe-type-aliases").withDefault(Config.default.generateSafeTypeAliases) ++
NormalizeFields.config.nested("fields-normalization")
NormalizeFields.config.nested("fields-normalization") ++ zio.Config.table(
"string-format-types",
zio.Config.string,
)
).to[Config]
}
26 changes: 26 additions & 0 deletions zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,14 @@ final case class EndpointGen(config: Config) {
Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Long)
case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _, _, _) =>
Code.PathSegmentCode(name = name, segmentType = Code.CodecType.UUID)
case JsonSchema.String(Some(JsonSchema.StringFormat.Date), _, _, _) =>
Code.PathSegmentCode(name = name, segmentType = Code.CodecType.LocalDate)
case JsonSchema.String(Some(JsonSchema.StringFormat.DateTime), _, _, _) =>
Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Instant)
case JsonSchema.String(Some(JsonSchema.StringFormat.Time), _, _, _) =>
Code.PathSegmentCode(name = name, segmentType = Code.CodecType.LocalTime)
case JsonSchema.String(Some(JsonSchema.StringFormat.Duration), _, _, _) =>
Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Duration)
case JsonSchema.String(_, _, _, _) =>
Code.PathSegmentCode(name = name, segmentType = Code.CodecType.String)
case JsonSchema.Boolean =>
Expand Down Expand Up @@ -719,6 +727,14 @@ final case class EndpointGen(config: Config) {
Code.QueryParamCode(name = name, queryType = Code.CodecType.Long)
case JsonSchema.Integer(JsonSchema.IntegerFormat.Timestamp, _, _, _, _, _) =>
Code.QueryParamCode(name = name, queryType = Code.CodecType.Long)
case JsonSchema.String(Some(JsonSchema.StringFormat.Date), _, _, _) =>
Code.QueryParamCode(name = name, queryType = Code.CodecType.LocalDate)
case JsonSchema.String(Some(JsonSchema.StringFormat.DateTime), _, _, _) =>
Code.QueryParamCode(name = name, queryType = Code.CodecType.Instant)
case JsonSchema.String(Some(JsonSchema.StringFormat.Duration), _, _, _) =>
Code.QueryParamCode(name = name, queryType = Code.CodecType.Duration)
case JsonSchema.String(Some(JsonSchema.StringFormat.Time), _, _, _) =>
Code.QueryParamCode(name = name, queryType = Code.CodecType.LocalTime)
case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _, _, _) =>
Code.QueryParamCode(name = name, queryType = Code.CodecType.UUID)
case JsonSchema.String(_, _, _, _) =>
Expand Down Expand Up @@ -1237,9 +1253,19 @@ final case class EndpointGen(config: Config) {
val annotations = addNumericValidations[Long](exclusiveMin, exclusiveMax)
Some(Code.Field(name, Code.Primitive.ScalaLong, annotations, config.fieldNamesNormalization))

case JsonSchema.String(Some(format), _, _, _) if config.stringFormatTypes.contains(format.value) =>
Some(Code.Field(name, Code.TypeRef(config.stringFormatTypes(format.value)), config.fieldNamesNormalization))
case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _, maxLength, minLength) =>
val annotations = addStringValidations(minLength, maxLength)
Some(Code.Field(name, Code.Primitive.ScalaUUID, annotations, config.fieldNamesNormalization))
case JsonSchema.String(Some(JsonSchema.StringFormat.Date), _, _, _) =>
Some(Code.Field(name, Code.Primitive.ScalaLocalDate, config.fieldNamesNormalization))
case JsonSchema.String(Some(JsonSchema.StringFormat.DateTime), _, _, _) =>
Some(Code.Field(name, Code.Primitive.ScalaInstant, config.fieldNamesNormalization))
case JsonSchema.String(Some(JsonSchema.StringFormat.Time), _, _, _) =>
Some(Code.Field(name, Code.Primitive.ScalaTime, config.fieldNamesNormalization))
case JsonSchema.String(Some(JsonSchema.StringFormat.Duration), _, _, _) =>
Some(Code.Field(name, Code.Primitive.ScalaDuration, config.fieldNamesNormalization))
case JsonSchema.String(_, _, maxLength, minLength) =>
val annotations = addStringValidations(minLength, maxLength)
Some(Code.Field(name, Code.Primitive.ScalaString, annotations, config.fieldNamesNormalization))
Expand Down
30 changes: 19 additions & 11 deletions zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,17 +217,21 @@ object Code {
sealed trait Primitive extends ScalaType

object Primitive {
case object ScalaInt extends Primitive
case object ScalaLong extends Primitive
case object ScalaDouble extends Primitive
case object ScalaFloat extends Primitive
case object ScalaChar extends Primitive
case object ScalaByte extends Primitive
case object ScalaShort extends Primitive
case object ScalaBoolean extends Primitive
case object ScalaUnit extends Primitive
case object ScalaUUID extends Primitive
case object ScalaString extends Primitive
case object ScalaInt extends Primitive
case object ScalaLong extends Primitive
case object ScalaDouble extends Primitive
case object ScalaFloat extends Primitive
case object ScalaChar extends Primitive
case object ScalaByte extends Primitive
case object ScalaShort extends Primitive
case object ScalaBoolean extends Primitive
case object ScalaUnit extends Primitive
case object ScalaUUID extends Primitive
case object ScalaLocalDate extends Primitive
case object ScalaInstant extends Primitive
case object ScalaTime extends Primitive
case object ScalaDuration extends Primitive
case object ScalaString extends Primitive
}

final case class EndpointCode(
Expand All @@ -253,6 +257,10 @@ object Code {
case object Long extends CodecType
case object String extends CodecType
case object UUID extends CodecType
case object LocalDate extends CodecType
case object LocalTime extends CodecType
case object Duration extends CodecType
case object Instant extends CodecType
case class Aliased(underlying: CodecType, newtypeName: String) extends CodecType
}
final case class QueryParamCode(name: String, queryType: CodecType)
Expand Down
61 changes: 37 additions & 24 deletions zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import java.nio.charset.StandardCharsets
import java.nio.file.StandardOpenOption._
import java.nio.file._

import scala.util.matching.Regex

object CodeGen {

private val EndpointImports =
Expand Down Expand Up @@ -230,18 +232,21 @@ object CodeGen {
val multipleAnnotationsAboveContent = if (annotationValues.size > 1) "\n" + content else content
allImports -> annotationValues.mkString("", "\n", multipleAnnotationsAboveContent)

case Code.Primitive.ScalaBoolean => Nil -> "Boolean"
case Code.Primitive.ScalaByte => Nil -> "Byte"
case Code.Primitive.ScalaChar => Nil -> "Char"
case Code.Primitive.ScalaDouble => Nil -> "Double"
case Code.Primitive.ScalaFloat => Nil -> "Float"
case Code.Primitive.ScalaInt => Nil -> "Int"
case Code.Primitive.ScalaLong => Nil -> "Long"
case Code.Primitive.ScalaShort => Nil -> "Short"
case Code.Primitive.ScalaString => Nil -> "String"
case Code.Primitive.ScalaUnit => Nil -> "Unit"
case Code.Primitive.ScalaUUID => List(Code.Import("java.util.UUID")) -> "UUID"
case Code.ScalaType.Inferred => Nil -> ""
case Code.Primitive.ScalaBoolean => Nil -> "Boolean"
case Code.Primitive.ScalaByte => Nil -> "Byte"
case Code.Primitive.ScalaChar => Nil -> "Char"
case Code.Primitive.ScalaDouble => Nil -> "Double"
case Code.Primitive.ScalaFloat => Nil -> "Float"
case Code.Primitive.ScalaInt => Nil -> "Int"
case Code.Primitive.ScalaLong => Nil -> "Long"
case Code.Primitive.ScalaShort => Nil -> "Short"
case Code.Primitive.ScalaString => Nil -> "String"
case Code.Primitive.ScalaUnit => Nil -> "Unit"
case Code.Primitive.ScalaUUID => List(Code.Import("java.util.UUID")) -> "UUID"
case Code.Primitive.ScalaLocalDate => List(Code.Import("java.time.LocalDate")) -> "LocalDate"
case Code.Primitive.ScalaInstant => List(Code.Import("java.time.Instant")) -> "Instant"
case Code.Primitive.ScalaTime => List(Code.Import("java.time.LocalTime")) -> "LocalTime"
case Code.ScalaType.Inferred => Nil -> ""

case Code.EndpointCode(method, pathPatternCode, queryParamsCode, headersCode, inCode, outCodes, errorsCode) =>
val (queryImports, queryContent) = queryParamsCode.map(renderQueryCode).unzip
Expand All @@ -266,12 +271,16 @@ object CodeGen {

def renderSegmentType(name: String, segmentType: Code.CodecType): (String, List[Code.Import]) =
segmentType match {
case Code.CodecType.Boolean => s"""bool("$name")""" -> Nil
case Code.CodecType.Int => s"""int("$name")""" -> Nil
case Code.CodecType.Long => s"""long("$name")""" -> Nil
case Code.CodecType.String => s"""string("$name")""" -> Nil
case Code.CodecType.UUID => s"""uuid("$name")""" -> Nil
case Code.CodecType.Literal => s""""$name"""" -> Nil
case Code.CodecType.Boolean => s"""bool("$name")""" -> Nil
case Code.CodecType.Int => s"""int("$name")""" -> Nil
case Code.CodecType.Long => s"""long("$name")""" -> Nil
case Code.CodecType.String => s"""string("$name")""" -> Nil
case Code.CodecType.UUID => s"""uuid("$name")""" -> Nil
case Code.CodecType.LocalDate => s"""date("$name")""" -> Nil
case Code.CodecType.LocalTime => s"""time("$name")""" -> Nil
case Code.CodecType.Instant => s"""date-time("$name")""" -> Nil
case Code.CodecType.Duration => s"""duration("$name")""" -> Nil
case Code.CodecType.Literal => s""""$name"""" -> Nil
case Code.CodecType.Aliased(underlying, newtypeName) =>
val sb = new StringBuilder()
val (code, imports) = renderSegmentType(name, underlying)
Expand Down Expand Up @@ -379,12 +388,16 @@ object CodeGen {
def renderQueryCode(queryCode: Code.QueryParamCode): (List[Code.Import], String) = queryCode match {
case Code.QueryParamCode(name, queryType) =>
val (imports, tpe) = queryType match {
case Code.CodecType.Boolean => Nil -> "Boolean"
case Code.CodecType.Int => Nil -> "Int"
case Code.CodecType.Long => Nil -> "Long"
case Code.CodecType.String => Nil -> "String"
case Code.CodecType.UUID => List(Code.Import("java.util.UUID")) -> "UUID"
case Code.CodecType.Literal => throw new Exception("Literal query params are not supported")
case Code.CodecType.Boolean => Nil -> "Boolean"
case Code.CodecType.Int => Nil -> "Int"
case Code.CodecType.Long => Nil -> "Long"
case Code.CodecType.String => Nil -> "String"
case Code.CodecType.UUID => List(Code.Import("java.util.UUID")) -> "UUID"
case Code.CodecType.LocalDate => List(Code.Import("java.time.LocalDate")) -> "LocalDate"
case Code.CodecType.LocalTime => List(Code.Import("java.time.LocalTime")) -> "LocalTime"
case Code.CodecType.Instant => List(Code.Import("java.time.Instant")) -> "Instant"
case Code.CodecType.Duration => List(Code.Import("java.time.Duration")) -> "Duration"
case Code.CodecType.Literal => throw new Exception("Literal query params are not supported")
case Code.CodecType.Aliased(underlying, newtypeName) =>
val (imports, _) = renderQueryCode(Code.QueryParamCode(name, underlying))
(Code.Import.FromBase(s"components.$newtypeName") :: imports) -> (newtypeName + ".Type")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ object Entries {
object POST {
import zio.schema.annotation.validate
import zio.schema.validation.Validation
import java.util.UUID
import java.time.Instant
import java.time.LocalTime
import java.time.LocalDate

case class RequestBody(
id: Int,
Expand All @@ -23,8 +27,12 @@ object Entries {
implicit val codec: Schema[RequestBody] = DeriveSchema.gen[RequestBody]
}
case class ResponseBody(
id: Int,
@validate[String](Validation.maxLength(255) && Validation.minLength(1)) name: String,
uuid: Option[UUID],
deadline: Option[Instant],
id: Int,
time: Option[LocalTime],
day: LocalDate,
)
object ResponseBody {
implicit val codec: Schema[ResponseBody] = DeriveSchema.gen[ResponseBody]
Expand Down
21 changes: 20 additions & 1 deletion zio-http-gen/src/test/resources/inline_schema_minmaxlength.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,32 @@
"string",
"minLength" : 1,
"maxLength" : 255
},
"day" : {
"type" :
"string",
"format": "date"
},
"deadline": {
"type": "string",
"format": "date-time"
},
"time": {
"type": "string",
"format": "time"
},
"uuid" : {
"type" :
"string",
"format" : "uuid"
}
},
"additionalProperties" :
true,
"required" : [
"id",
"name"
"name",
"day"
]
}

Expand Down

0 comments on commit eef8adc

Please sign in to comment.