From eef8adcd4fdf072010242371cd0fcaf768415061 Mon Sep 17 00:00:00 2001 From: Gregor Rayman Date: Wed, 11 Dec 2024 11:57:28 +0100 Subject: [PATCH] In the code generation from OpenAPI schema, added the handling for the string formats `date`, `date-time`, `time` and `duration` to be generated as `LocalDate`, `Instant`, `LocalTime` and `Duration`, and also to configure own formart to type mapping. --- .../scala/zio/http/gen/openapi/Config.scala | 7 ++- .../zio/http/gen/openapi/EndpointGen.scala | 26 ++++++++ .../main/scala/zio/http/gen/scala/Code.scala | 30 +++++---- .../scala/zio/http/gen/scala/CodeGen.scala | 61 +++++++++++-------- ...equestResponseBodyInlineMinMaxLength.scala | 10 ++- .../resources/inline_schema_minmaxlength.json | 21 ++++++- 6 files changed, 117 insertions(+), 38 deletions(-) diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/Config.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/Config.scala index e4d90db78b..f9e62eda04 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/openapi/Config.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/Config.scala @@ -33,6 +33,7 @@ final case class Config( commonFieldsOnSuperType: Boolean, generateSafeTypeAliases: Boolean, fieldNamesNormalization: NormalizeFields, + stringFormatTypes: Map[String, String], ) object Config { @@ -73,11 +74,15 @@ object Config { enableAutomatic = false, manualOverrides = Map.empty, ), + stringFormatTypes = Map.empty, ) def config: zio.Config[Config] = ( zio.Config.boolean("common-fields-on-super-type").withDefault(Config.default.commonFieldsOnSuperType) ++ zio.Config.boolean("generate-safe-type-aliases").withDefault(Config.default.generateSafeTypeAliases) ++ - NormalizeFields.config.nested("fields-normalization") + NormalizeFields.config.nested("fields-normalization") ++ zio.Config.table( + "string-format-types", + zio.Config.string, + ) ).to[Config] } diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala index f92bdb144d..d1cb7a85a7 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala @@ -685,6 +685,14 @@ final case class EndpointGen(config: Config) { Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Long) case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _, _, _) => Code.PathSegmentCode(name = name, segmentType = Code.CodecType.UUID) + case JsonSchema.String(Some(JsonSchema.StringFormat.Date), _, _, _) => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.LocalDate) + case JsonSchema.String(Some(JsonSchema.StringFormat.DateTime), _, _, _) => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Instant) + case JsonSchema.String(Some(JsonSchema.StringFormat.Time), _, _, _) => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.LocalTime) + case JsonSchema.String(Some(JsonSchema.StringFormat.Duration), _, _, _) => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Duration) case JsonSchema.String(_, _, _, _) => Code.PathSegmentCode(name = name, segmentType = Code.CodecType.String) case JsonSchema.Boolean => @@ -719,6 +727,14 @@ final case class EndpointGen(config: Config) { Code.QueryParamCode(name = name, queryType = Code.CodecType.Long) case JsonSchema.Integer(JsonSchema.IntegerFormat.Timestamp, _, _, _, _, _) => Code.QueryParamCode(name = name, queryType = Code.CodecType.Long) + case JsonSchema.String(Some(JsonSchema.StringFormat.Date), _, _, _) => + Code.QueryParamCode(name = name, queryType = Code.CodecType.LocalDate) + case JsonSchema.String(Some(JsonSchema.StringFormat.DateTime), _, _, _) => + Code.QueryParamCode(name = name, queryType = Code.CodecType.Instant) + case JsonSchema.String(Some(JsonSchema.StringFormat.Duration), _, _, _) => + Code.QueryParamCode(name = name, queryType = Code.CodecType.Duration) + case JsonSchema.String(Some(JsonSchema.StringFormat.Time), _, _, _) => + Code.QueryParamCode(name = name, queryType = Code.CodecType.LocalTime) case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _, _, _) => Code.QueryParamCode(name = name, queryType = Code.CodecType.UUID) case JsonSchema.String(_, _, _, _) => @@ -1237,9 +1253,19 @@ final case class EndpointGen(config: Config) { val annotations = addNumericValidations[Long](exclusiveMin, exclusiveMax) Some(Code.Field(name, Code.Primitive.ScalaLong, annotations, config.fieldNamesNormalization)) + case JsonSchema.String(Some(format), _, _, _) if config.stringFormatTypes.contains(format.value) => + Some(Code.Field(name, Code.TypeRef(config.stringFormatTypes(format.value)), config.fieldNamesNormalization)) case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _, maxLength, minLength) => val annotations = addStringValidations(minLength, maxLength) Some(Code.Field(name, Code.Primitive.ScalaUUID, annotations, config.fieldNamesNormalization)) + case JsonSchema.String(Some(JsonSchema.StringFormat.Date), _, _, _) => + Some(Code.Field(name, Code.Primitive.ScalaLocalDate, config.fieldNamesNormalization)) + case JsonSchema.String(Some(JsonSchema.StringFormat.DateTime), _, _, _) => + Some(Code.Field(name, Code.Primitive.ScalaInstant, config.fieldNamesNormalization)) + case JsonSchema.String(Some(JsonSchema.StringFormat.Time), _, _, _) => + Some(Code.Field(name, Code.Primitive.ScalaTime, config.fieldNamesNormalization)) + case JsonSchema.String(Some(JsonSchema.StringFormat.Duration), _, _, _) => + Some(Code.Field(name, Code.Primitive.ScalaDuration, config.fieldNamesNormalization)) case JsonSchema.String(_, _, maxLength, minLength) => val annotations = addStringValidations(minLength, maxLength) Some(Code.Field(name, Code.Primitive.ScalaString, annotations, config.fieldNamesNormalization)) diff --git a/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala b/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala index 1777c4b44c..aa8827d8e5 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala @@ -217,17 +217,21 @@ object Code { sealed trait Primitive extends ScalaType object Primitive { - case object ScalaInt extends Primitive - case object ScalaLong extends Primitive - case object ScalaDouble extends Primitive - case object ScalaFloat extends Primitive - case object ScalaChar extends Primitive - case object ScalaByte extends Primitive - case object ScalaShort extends Primitive - case object ScalaBoolean extends Primitive - case object ScalaUnit extends Primitive - case object ScalaUUID extends Primitive - case object ScalaString extends Primitive + case object ScalaInt extends Primitive + case object ScalaLong extends Primitive + case object ScalaDouble extends Primitive + case object ScalaFloat extends Primitive + case object ScalaChar extends Primitive + case object ScalaByte extends Primitive + case object ScalaShort extends Primitive + case object ScalaBoolean extends Primitive + case object ScalaUnit extends Primitive + case object ScalaUUID extends Primitive + case object ScalaLocalDate extends Primitive + case object ScalaInstant extends Primitive + case object ScalaTime extends Primitive + case object ScalaDuration extends Primitive + case object ScalaString extends Primitive } final case class EndpointCode( @@ -253,6 +257,10 @@ object Code { case object Long extends CodecType case object String extends CodecType case object UUID extends CodecType + case object LocalDate extends CodecType + case object LocalTime extends CodecType + case object Duration extends CodecType + case object Instant extends CodecType case class Aliased(underlying: CodecType, newtypeName: String) extends CodecType } final case class QueryParamCode(name: String, queryType: CodecType) diff --git a/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala index 4a797b83f7..6b582ca7d9 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala @@ -4,6 +4,8 @@ import java.nio.charset.StandardCharsets import java.nio.file.StandardOpenOption._ import java.nio.file._ +import scala.util.matching.Regex + object CodeGen { private val EndpointImports = @@ -230,18 +232,21 @@ object CodeGen { val multipleAnnotationsAboveContent = if (annotationValues.size > 1) "\n" + content else content allImports -> annotationValues.mkString("", "\n", multipleAnnotationsAboveContent) - case Code.Primitive.ScalaBoolean => Nil -> "Boolean" - case Code.Primitive.ScalaByte => Nil -> "Byte" - case Code.Primitive.ScalaChar => Nil -> "Char" - case Code.Primitive.ScalaDouble => Nil -> "Double" - case Code.Primitive.ScalaFloat => Nil -> "Float" - case Code.Primitive.ScalaInt => Nil -> "Int" - case Code.Primitive.ScalaLong => Nil -> "Long" - case Code.Primitive.ScalaShort => Nil -> "Short" - case Code.Primitive.ScalaString => Nil -> "String" - case Code.Primitive.ScalaUnit => Nil -> "Unit" - case Code.Primitive.ScalaUUID => List(Code.Import("java.util.UUID")) -> "UUID" - case Code.ScalaType.Inferred => Nil -> "" + case Code.Primitive.ScalaBoolean => Nil -> "Boolean" + case Code.Primitive.ScalaByte => Nil -> "Byte" + case Code.Primitive.ScalaChar => Nil -> "Char" + case Code.Primitive.ScalaDouble => Nil -> "Double" + case Code.Primitive.ScalaFloat => Nil -> "Float" + case Code.Primitive.ScalaInt => Nil -> "Int" + case Code.Primitive.ScalaLong => Nil -> "Long" + case Code.Primitive.ScalaShort => Nil -> "Short" + case Code.Primitive.ScalaString => Nil -> "String" + case Code.Primitive.ScalaUnit => Nil -> "Unit" + case Code.Primitive.ScalaUUID => List(Code.Import("java.util.UUID")) -> "UUID" + case Code.Primitive.ScalaLocalDate => List(Code.Import("java.time.LocalDate")) -> "LocalDate" + case Code.Primitive.ScalaInstant => List(Code.Import("java.time.Instant")) -> "Instant" + case Code.Primitive.ScalaTime => List(Code.Import("java.time.LocalTime")) -> "LocalTime" + case Code.ScalaType.Inferred => Nil -> "" case Code.EndpointCode(method, pathPatternCode, queryParamsCode, headersCode, inCode, outCodes, errorsCode) => val (queryImports, queryContent) = queryParamsCode.map(renderQueryCode).unzip @@ -266,12 +271,16 @@ object CodeGen { def renderSegmentType(name: String, segmentType: Code.CodecType): (String, List[Code.Import]) = segmentType match { - case Code.CodecType.Boolean => s"""bool("$name")""" -> Nil - case Code.CodecType.Int => s"""int("$name")""" -> Nil - case Code.CodecType.Long => s"""long("$name")""" -> Nil - case Code.CodecType.String => s"""string("$name")""" -> Nil - case Code.CodecType.UUID => s"""uuid("$name")""" -> Nil - case Code.CodecType.Literal => s""""$name"""" -> Nil + case Code.CodecType.Boolean => s"""bool("$name")""" -> Nil + case Code.CodecType.Int => s"""int("$name")""" -> Nil + case Code.CodecType.Long => s"""long("$name")""" -> Nil + case Code.CodecType.String => s"""string("$name")""" -> Nil + case Code.CodecType.UUID => s"""uuid("$name")""" -> Nil + case Code.CodecType.LocalDate => s"""date("$name")""" -> Nil + case Code.CodecType.LocalTime => s"""time("$name")""" -> Nil + case Code.CodecType.Instant => s"""date-time("$name")""" -> Nil + case Code.CodecType.Duration => s"""duration("$name")""" -> Nil + case Code.CodecType.Literal => s""""$name"""" -> Nil case Code.CodecType.Aliased(underlying, newtypeName) => val sb = new StringBuilder() val (code, imports) = renderSegmentType(name, underlying) @@ -379,12 +388,16 @@ object CodeGen { def renderQueryCode(queryCode: Code.QueryParamCode): (List[Code.Import], String) = queryCode match { case Code.QueryParamCode(name, queryType) => val (imports, tpe) = queryType match { - case Code.CodecType.Boolean => Nil -> "Boolean" - case Code.CodecType.Int => Nil -> "Int" - case Code.CodecType.Long => Nil -> "Long" - case Code.CodecType.String => Nil -> "String" - case Code.CodecType.UUID => List(Code.Import("java.util.UUID")) -> "UUID" - case Code.CodecType.Literal => throw new Exception("Literal query params are not supported") + case Code.CodecType.Boolean => Nil -> "Boolean" + case Code.CodecType.Int => Nil -> "Int" + case Code.CodecType.Long => Nil -> "Long" + case Code.CodecType.String => Nil -> "String" + case Code.CodecType.UUID => List(Code.Import("java.util.UUID")) -> "UUID" + case Code.CodecType.LocalDate => List(Code.Import("java.time.LocalDate")) -> "LocalDate" + case Code.CodecType.LocalTime => List(Code.Import("java.time.LocalTime")) -> "LocalTime" + case Code.CodecType.Instant => List(Code.Import("java.time.Instant")) -> "Instant" + case Code.CodecType.Duration => List(Code.Import("java.time.Duration")) -> "Duration" + case Code.CodecType.Literal => throw new Exception("Literal query params are not supported") case Code.CodecType.Aliased(underlying, newtypeName) => val (imports, _) = renderQueryCode(Code.QueryParamCode(name, underlying)) (Code.Import.FromBase(s"components.$newtypeName") :: imports) -> (newtypeName + ".Type") diff --git a/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInlineMinMaxLength.scala b/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInlineMinMaxLength.scala index 1808fd4a0a..2d69fa2a4a 100644 --- a/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInlineMinMaxLength.scala +++ b/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInlineMinMaxLength.scala @@ -14,6 +14,10 @@ object Entries { object POST { import zio.schema.annotation.validate import zio.schema.validation.Validation + import java.util.UUID + import java.time.Instant + import java.time.LocalTime + import java.time.LocalDate case class RequestBody( id: Int, @@ -23,8 +27,12 @@ object Entries { implicit val codec: Schema[RequestBody] = DeriveSchema.gen[RequestBody] } case class ResponseBody( - id: Int, @validate[String](Validation.maxLength(255) && Validation.minLength(1)) name: String, + uuid: Option[UUID], + deadline: Option[Instant], + id: Int, + time: Option[LocalTime], + day: LocalDate, ) object ResponseBody { implicit val codec: Schema[ResponseBody] = DeriveSchema.gen[ResponseBody] diff --git a/zio-http-gen/src/test/resources/inline_schema_minmaxlength.json b/zio-http-gen/src/test/resources/inline_schema_minmaxlength.json index e8844cd804..9f59708189 100644 --- a/zio-http-gen/src/test/resources/inline_schema_minmaxlength.json +++ b/zio-http-gen/src/test/resources/inline_schema_minmaxlength.json @@ -61,13 +61,32 @@ "string", "minLength" : 1, "maxLength" : 255 + }, + "day" : { + "type" : + "string", + "format": "date" + }, + "deadline": { + "type": "string", + "format": "date-time" + }, + "time": { + "type": "string", + "format": "time" + }, + "uuid" : { + "type" : + "string", + "format" : "uuid" } }, "additionalProperties" : true, "required" : [ "id", - "name" + "name", + "day" ] }