diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 9236feae67..ab02dace77 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @amitksingh1490 @jdegoes @vigoo @adamgfraser +* @amitksingh1490 @jdegoes @vigoo @adamgfraser @987Nabil diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9ce1391fd5..ed135607c2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - scala: [2.12.18, 2.13.10, 3.3.0] + scala: [2.12.18, 2.13.12, 3.3.1] java: [graal_graalvm@17, temurin@8] runs-on: ${{ matrix.os }} timeout-minutes: 60 @@ -54,8 +54,8 @@ jobs: cache: sbt - name: Check formatting - if: matrix.scala == '2.13.10' - run: sbt ++2.13.10 fmtCheck + if: matrix.scala == '2.13.12' + run: sbt ++2.13.12 fmtCheck - name: Check that workflows are up to date run: sbt '++ ${{ matrix.scala }}' githubWorkflowCheck @@ -65,16 +65,16 @@ jobs: - name: Check doc generation if: ${{ github.event_name == 'pull_request' }} - run: sbt ++2.13.10 doc + run: sbt ++2.13.12 doc - name: zio-http-shaded Tests - if: matrix.scala == '2.13.10' + if: matrix.scala == '2.13.12' env: PUBLISH_SHADED: true run: sbt '++ ${{ matrix.scala }}' zioHttpShadedTests/test - name: Compress target directories - run: tar cf targets.tar zio-http-cli/target target zio-http/target zio-http-docs/target zio-http-benchmarks/target zio-http-example/target zio-http-testkit/target project/target + run: tar cf targets.tar zio-http-cli/target target zio-http/target zio-http-docs/target zio-http-gen/target zio-http-benchmarks/target zio-http-example/target zio-http-testkit/target project/target - name: Upload target directories uses: actions/upload-artifact@v3 @@ -89,7 +89,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - scala: [2.13.10] + scala: [2.13.12] java: [graal_graalvm@17] runs-on: ${{ matrix.os }} steps: @@ -126,22 +126,22 @@ jobs: tar xf targets.tar rm targets.tar - - name: Download target directories (2.13.10) + - name: Download target directories (2.13.12) uses: actions/download-artifact@v3 with: - name: target-${{ matrix.os }}-2.13.10-${{ matrix.java }} + name: target-${{ matrix.os }}-2.13.12-${{ matrix.java }} - - name: Inflate target directories (2.13.10) + - name: Inflate target directories (2.13.12) run: | tar xf targets.tar rm targets.tar - - name: Download target directories (3.3.0) + - name: Download target directories (3.3.1) uses: actions/download-artifact@v3 with: - name: target-${{ matrix.os }}-3.3.0-${{ matrix.java }} + name: target-${{ matrix.os }}-3.3.1-${{ matrix.java }} - - name: Inflate target directories (3.3.0) + - name: Inflate target directories (3.3.1) run: | tar xf targets.tar rm targets.tar @@ -182,7 +182,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - scala: [2.13.10] + scala: [2.13.12] java: [temurin@8] runs-on: ${{ matrix.os }} steps: @@ -193,7 +193,7 @@ jobs: - name: Add Scoverage id: add_plugin - run: sed -i -e '$aaddSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.9.3")' project/plugins.sbt + run: sed -i -e '$aaddSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.9")' project/plugins.sbt - name: Update Build Definition id: update_build_definition @@ -215,7 +215,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - scala: [2.13.10] + scala: [2.13.12] java: [temurin@8] runs-on: ${{ matrix.os }} steps: diff --git a/.scalafmt.conf b/.scalafmt.conf index 792e8620f8..022e17afb1 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,4 +1,4 @@ -version = 3.7.14 +version = 3.7.17 maxColumn = 120 align.preset = more diff --git a/README.md b/README.md index cd884f1d60..1727665fb9 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ ZIO HTTP is a scala library for building http apps. It is powered by ZIO and [Ne Setup via `build.sbt`: ```scala -libraryDependencies += "dev.zio" %% "zio-http" % "3.0.0-RC3" +libraryDependencies += "dev.zio" %% "zio-http" % "3.0.0-RC4" ``` **NOTES ON VERSIONING:** diff --git a/build.sbt b/build.sbt index bfdd29ff55..d4ac37d356 100644 --- a/build.sbt +++ b/build.sbt @@ -1,5 +1,5 @@ import BuildHelper._ -import Dependencies._ +import Dependencies.{scalafmt, _} import sbt.librarymanagement.ScalaArtifacts.isScala3 import scala.concurrent.duration._ @@ -12,7 +12,10 @@ ThisBuild / resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots" // CI Configuration -ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.graalvm(Graalvm.Distribution("graalvm"), "17"), JavaSpec.temurin("8")) +ThisBuild / githubWorkflowJavaVersions := Seq( + JavaSpec.graalvm(Graalvm.Distribution("graalvm"), "17"), + JavaSpec.temurin("8"), +) ThisBuild / githubWorkflowPREventTypes := Seq( PREventType.Opened, PREventType.Synchronize, @@ -104,17 +107,28 @@ inThisBuild( ThisBuild / githubWorkflowBuildTimeout := Some(60.minutes) +lazy val aggregatedProjects: Seq[ProjectReference] = + if (Shading.shadingEnabled) { + Seq( + zioHttp, + zioHttpTestkit, + ) + } else { + Seq( + zioHttp, + zioHttpBenchmarks, + zioHttpCli, + zioHttpGen, + zioHttpExample, + zioHttpTestkit, + docs, + ) + } + lazy val root = (project in file(".")) .settings(stdSettings("zio-http-root")) .settings(publishSetting(false)) - .aggregate( - zioHttp, - zioHttpBenchmarks, - zioHttpCli, - zioHttpExample, - zioHttpTestkit, - docs, - ) + .aggregate(aggregatedProjects: _*) lazy val zioHttp = (project in file("zio-http")) .enablePlugins(Shading.plugins(): _*) @@ -196,7 +210,7 @@ lazy val zioHttpBenchmarks = (project in file("zio-http-benchmarks")) // "com.softwaremill.sttp.tapir" %% "tapir-akka-http-server" % "1.1.0", "com.softwaremill.sttp.tapir" %% "tapir-http4s-server" % "1.5.1", "com.softwaremill.sttp.tapir" %% "tapir-json-circe" % "1.5.1", - "com.softwaremill.sttp.client3" %% "core" % "3.9.0", + "com.softwaremill.sttp.client3" %% "core" % "3.9.1", // "dev.zio" %% "zio-interop-cats" % "3.3.0", "org.slf4j" % "slf4j-api" % "2.0.9", "org.slf4j" % "slf4j-simple" % "2.0.9", @@ -224,6 +238,19 @@ lazy val zioHttpExample = (project in file("zio-http-example")) .settings(libraryDependencies ++= Seq(`jwt-core`)) .dependsOn(zioHttp, zioHttpCli) +lazy val zioHttpGen = (project in file("zio-http-gen")) + .settings(stdSettings("zio-http-gen")) + .settings(publishSetting(true)) + .settings( + libraryDependencies ++= Seq( + `zio`, + `zio-test`, + `zio-test-sbt`, + scalafmt.cross(CrossVersion.for3Use2_13), + ), + ) + .dependsOn(zioHttp) + lazy val zioHttpTestkit = (project in file("zio-http-testkit")) .enablePlugins(Shading.plugins(): _*) .settings(stdSettings("zio-http-testkit")) diff --git a/docs/examples/advanced/static-files.md b/docs/examples/advanced/static-files.md new file mode 100644 index 0000000000..8108509554 --- /dev/null +++ b/docs/examples/advanced/static-files.md @@ -0,0 +1,23 @@ +--- +id: static-files +title: "Serving Static Files" +sidebar_label: "Static Files" +--- + +```scala mdoc:silent +import zio._ +import zio.http._ + +object StaticFiles extends ZIOAppDefault { + + /** + * Creates an HTTP app that only serves static files from resources via + * "/static". For paths other than the resources directory, see + * [[Middleware.serveDirectory]]. + */ + val app = Routes.empty.toHttpApp @@ Middleware.serveResources(Path.empty / "static") + + override def run = Server.serve(app).provide(Server.default) +} + +``` \ No newline at end of file diff --git a/project/BuildHelper.scala b/project/BuildHelper.scala index 039dcaf6ed..cdd39c839e 100644 --- a/project/BuildHelper.scala +++ b/project/BuildHelper.scala @@ -6,11 +6,11 @@ import de.heikoseeberger.sbtheader.HeaderPlugin.autoImport.{headerLicense, Heade object BuildHelper extends ScalaSettings { val Scala212 = "2.12.18" - val Scala213 = "2.13.10" - val Scala3 = "3.3.0" - val ScoverageVersion = "1.9.3" + val Scala213 = "2.13.12" + val Scala3 = "3.3.1" + val ScoverageVersion = "2.0.9" val JmhVersion = "0.4.3" - val SilencerVersion = "1.17.13" + val SilencerVersion = "1.7.14" private val stdOptions = Seq( "-deprecation", diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 9e7761376a..ec159d6f95 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -3,17 +3,19 @@ import sbt.Keys.scalaVersion object Dependencies { val JwtCoreVersion = "9.1.1" - val NettyVersion = "4.1.100.Final" - val NettyIncubatorVersion = "0.0.20.Final" + val NettyVersion = "4.1.101.Final" + val NettyIncubatorVersion = "0.0.24.Final" val ScalaCompactCollectionVersion = "2.11.0" - val ZioVersion = "2.0.18" + val ZioVersion = "2.0.19" val ZioCliVersion = "0.5.0" - val ZioSchemaVersion = "0.4.14" + val ZioSchemaVersion = "0.4.16" val SttpVersion = "3.3.18" val `jwt-core` = "com.github.jwt-scala" %% "jwt-core" % JwtCoreVersion val `scala-compact-collection` = "org.scala-lang.modules" %% "scala-collection-compat" % ScalaCompactCollectionVersion + val scalafmt = "org.scalameta" %% "scalafmt-dynamic" % "3.7.17" + val netty = Seq( "io.netty" % "netty-codec-http" % NettyVersion, diff --git a/project/build.properties b/project/build.properties index 27430827bc..e8a1e246e8 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.9.6 +sbt.version=1.9.7 diff --git a/project/plugins.sbt b/project/plugins.sbt index 989c03b398..56a2650f15 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,9 +3,9 @@ addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.0") addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.6") addSbtPlugin("com.timushev.sbt" % "sbt-updates" % "0.6.3") addSbtPlugin("io.spray" % "sbt-revolver" % "0.10.0") -addSbtPlugin("com.github.sbt" % "sbt-github-actions" % "0.18.0") +addSbtPlugin("com.github.sbt" % "sbt-github-actions" % "0.19.0") addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.12") addSbtPlugin("dev.zio" % "zio-sbt-website" % "0.3.10") addSbtPlugin("de.heikoseeberger" % "sbt-header" % "5.10.0") -addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.8") +addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.9") addSbtPlugin("io.get-coursier" % "sbt-shading" % "2.1.3") diff --git a/zio-http-benchmarks/src/main/scala-2.13/zio/http/benchmarks/EndpointBenchmark.scala b/zio-http-benchmarks/src/main/scala-2.13/zio/http/benchmarks/EndpointBenchmark.scala index 8a3b9cd9ca..8e23412e81 100644 --- a/zio-http-benchmarks/src/main/scala-2.13/zio/http/benchmarks/EndpointBenchmark.scala +++ b/zio-http-benchmarks/src/main/scala-2.13/zio/http/benchmarks/EndpointBenchmark.scala @@ -11,8 +11,8 @@ import zio.{Scope => _, _} import zio.schema.{DeriveSchema, Schema} import zio.http._ -import zio.http.codec.{PathCodec, QueryCodec} -import zio.http.endpoint._ +import zio.http.codec.QueryCodec +import zio.http.endpoint.Endpoint import cats.effect.unsafe.implicits.global import cats.effect.{IO => CIO} diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala index e7f393b456..2d855e3c17 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala @@ -1,7 +1,6 @@ package zio.http.endpoint.cli -import java.nio.file.Path - +import scala.annotation.tailrec import scala.language.implicitConversions import scala.util.Try @@ -9,6 +8,7 @@ import zio.cli._ import zio.json.ast._ import zio.schema._ +import zio.schema.annotation.description import zio.http._ import zio.http.codec._ @@ -241,8 +241,8 @@ private[cli] object HttpOptions { self => override val name = pathCodec.segments.map { - case SegmentCodec.Literal(value, _) => value - case _ => "" + case SegmentCodec.Literal(value) => value + case _ => "" } .filter(_ != "") .mkString("-") @@ -300,7 +300,7 @@ private[cli] object HttpOptions { Try(java.util.UUID.fromString(str)).toEither.left.map { error => ValidationError( ValidationErrorType.InvalidValue, - HelpDoc.p(HelpDoc.Span.code(error.getMessage())), + HelpDoc.p(HelpDoc.Span.code(error.getMessage)), ) }, ) @@ -312,27 +312,29 @@ private[cli] object HttpOptions { } private[cli] def optionsFromSegment(segment: SegmentCodec[_]): Options[String] = { + @tailrec def fromSegment[A](segment: SegmentCodec[A]): Options[String] = segment match { - case SegmentCodec.UUID(name, doc) => + case SegmentCodec.UUID(name) => Options .text(name) .mapOrFail(str => Try(java.util.UUID.fromString(str)).toEither.left.map { error => ValidationError( ValidationErrorType.InvalidValue, - HelpDoc.p(HelpDoc.Span.code(error.getMessage())), + HelpDoc.p(HelpDoc.Span.code(error.getMessage)), ) }, ) .map(_.toString) - case SegmentCodec.Text(name, doc) => Options.text(name) - case SegmentCodec.IntSeg(name, doc) => Options.integer(name).map(_.toInt).map(_.toString) - case SegmentCodec.LongSeg(name, doc) => Options.integer(name).map(_.toInt).map(_.toString) - case SegmentCodec.BoolSeg(name, doc) => Options.boolean(name).map(_.toString) - case SegmentCodec.Literal(value, doc) => Options.Empty.map(_ => value) - case SegmentCodec.Trailing(doc) => Options.none.map(_.toString) - case SegmentCodec.Empty(_) => Options.none.map(_.toString) + case SegmentCodec.Text(name) => Options.text(name) + case SegmentCodec.IntSeg(name) => Options.integer(name).map(_.toInt).map(_.toString) + case SegmentCodec.LongSeg(name) => Options.integer(name).map(_.toInt).map(_.toString) + case SegmentCodec.BoolSeg(name) => Options.boolean(name).map(_.toString) + case SegmentCodec.Literal(value) => Options.Empty.map(_ => value) + case SegmentCodec.Trailing => Options.none.map(_.toString) + case SegmentCodec.Empty => Options.none.map(_.toString) + case SegmentCodec.Annotated(codec, _) => fromSegment(codec) } fromSegment(segment) diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/description.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/description.scala deleted file mode 100644 index d36c380e2a..0000000000 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/description.scala +++ /dev/null @@ -1,5 +0,0 @@ -package zio.http.endpoint.cli - -import scala.annotation.StaticAnnotation - -class description(val text: String) extends StaticAnnotation // TODO move this to zio-schema diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala index 4b61db919e..53334db6c0 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala @@ -1,6 +1,7 @@ package zio.http.endpoint.cli -import zio.ZNothing +import scala.annotation.tailrec + import zio.cli._ import zio.test._ @@ -9,7 +10,6 @@ import zio.schema._ import zio.http._ import zio.http.codec._ import zio.http.endpoint._ -import zio.http.endpoint.cli.AuxGen._ import zio.http.endpoint.cli.CliRepr.HelpRepr import zio.http.endpoint.cli.EndpointGen._ @@ -20,17 +20,20 @@ import zio.http.endpoint.cli.EndpointGen._ object CommandGen { def getSegment(segment: SegmentCodec[_]): (String, String) = { + @tailrec def fromSegment[A](segment: SegmentCodec[A]): (String, String) = segment match { - case SegmentCodec.UUID(name, doc) => (name, "text") - case SegmentCodec.Text(name, doc) => (name, "text") - case SegmentCodec.IntSeg(name, doc) => (name, "integer") - case SegmentCodec.LongSeg(name, doc) => (name, "integer") - case SegmentCodec.BoolSeg(name, doc) => (name, "boolean") - case SegmentCodec.Literal(value, doc) => ("", "") - case SegmentCodec.Trailing(doc) => ("", "") - case SegmentCodec.Empty(_) => ("", "") + case SegmentCodec.UUID(name) => (name, "text") + case SegmentCodec.Text(name) => (name, "text") + case SegmentCodec.IntSeg(name) => (name, "integer") + case SegmentCodec.LongSeg(name) => (name, "integer") + case SegmentCodec.BoolSeg(name) => (name, "boolean") + case SegmentCodec.Literal(_) => ("", "") + case SegmentCodec.Trailing => ("", "") + case SegmentCodec.Empty => ("", "") + case SegmentCodec.Annotated(codec, _) => fromSegment(codec) } + fromSegment(segment) } diff --git a/zio-http-example/src/main/scala/example/CliExamples.scala b/zio-http-example/src/main/scala/example/CliExamples.scala index 3b5032b3fa..0ee4c0d759 100644 --- a/zio-http-example/src/main/scala/example/CliExamples.scala +++ b/zio-http-example/src/main/scala/example/CliExamples.scala @@ -4,12 +4,13 @@ import zio._ import zio.cli._ import zio.schema._ +import zio.schema.annotation.description import zio.http.Header.Location import zio.http._ import zio.http.codec._ -import zio.http.endpoint._ import zio.http.endpoint.cli._ +import zio.http.endpoint.{Endpoint, EndpointExecutor} trait TestCliEndpoints { import zio.http.codec.PathCodec._ diff --git a/zio-http-example/src/main/scala/example/EndpointExamples.scala b/zio-http-example/src/main/scala/example/EndpointExamples.scala index 1cbe14077f..0d40c12427 100644 --- a/zio-http-example/src/main/scala/example/EndpointExamples.scala +++ b/zio-http-example/src/main/scala/example/EndpointExamples.scala @@ -3,12 +3,13 @@ package example import zio._ import zio.http.Header.Authorization +import zio.http._ import zio.http.codec.{HttpCodec, PathCodec} -import zio.http.endpoint._ -import zio.http.{int => _, _} +import zio.http.endpoint.openapi.{OpenAPIGen, SwaggerUI} +import zio.http.endpoint.{Endpoint, EndpointExecutor, EndpointLocator, EndpointMiddleware} object EndpointExamples extends ZIOAppDefault { - import HttpCodec._ + import HttpCodec.query import PathCodec._ val auth = EndpointMiddleware.auth @@ -36,7 +37,9 @@ object EndpointExamples extends ZIOAppDefault { } } - val routes = Routes(getUserRoute, getUserPostsRoute) + val openAPI = OpenAPIGen.fromEndpoints(title = "Endpoint Example", version = "1.0", getUser, getUserPosts) + + val routes = Routes(getUserRoute, getUserPostsRoute) ++ SwaggerUI.routes("docs" / "openapi", openAPI) val app = routes.toHttpApp // (auth.implement(_ => ZIO.unit)(_ => ZIO.unit)) diff --git a/zio-http-example/src/main/scala/example/StaticFiles.scala b/zio-http-example/src/main/scala/example/StaticFiles.scala new file mode 100644 index 0000000000..f04f33c91f --- /dev/null +++ b/zio-http-example/src/main/scala/example/StaticFiles.scala @@ -0,0 +1,17 @@ +package example + +import zio._ + +import zio.http._ + +object StaticFiles extends ZIOAppDefault { + + /** + * Creates an HTTP app that only serves static files from resources via + * "/static". For paths other than the resources directory, see + * [[Middleware.serveDirectory]]. + */ + val app = Routes.empty.toHttpApp @@ Middleware.serveResources(Path.empty / "static") + + override def run = Server.serve(app).provide(Server.default) +} diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala new file mode 100644 index 0000000000..972b30190a --- /dev/null +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala @@ -0,0 +1,773 @@ +package zio.http.gen.openapi + +import scala.annotation.tailrec + +import zio.Chunk + +import zio.http.Method +import zio.http.endpoint.openapi.OpenAPI.ReferenceOr +import zio.http.endpoint.openapi.{JsonSchema, OpenAPI} +import zio.http.gen.scala.Code +import zio.http.gen.scala.Code.ScalaType + +object EndpointGen { + + private object Inline { + val RequestBodyType = "RequestBody" + val ResponseBodyType = "ResponseBody" + val Null = "Unit" + } + + private val DataImports = + List( + Code.Import("zio.schema._"), + ) + + private val RequestBodyRef = "#/components/requestBodies/(.*)".r + private val ParameterRef = "#/components/parameters/(.*)".r + private val SchemaRef = "#/components/schemas/(.*)".r + private val ResponseRef = "#/components/responses/(.*)".r + + def fromOpenAPI(openAPI: OpenAPI): Code.Files = + EndpointGen().fromOpenAPI(openAPI) + +} + +final case class EndpointGen() { + import EndpointGen._ + + private var anonymousTypes: Map[String, Code.Object] = Map.empty[String, Code.Object] + + def fromOpenAPI(openAPI: OpenAPI): Code.Files = + Code.Files { + openAPI.paths.map { case (path, pathItem) => + val pathSegments = path.name.tail.replace('-', '_').split('/').toList + val packageName = pathSegments.init.mkString(".").replace("{", "").replace("}", "") + val className = pathSegments.last.replace("{", "").replace("}", "").capitalize + val params = List( + pathItem.delete, + pathItem.get, + pathItem.head, + pathItem.options, + pathItem.post, + pathItem.put, + pathItem.patch, + pathItem.trace, + ).flatten + .flatMap(_.parameters) + .map { + case OpenAPI.ReferenceOr.Or(param: OpenAPI.Parameter) => + param + case OpenAPI.ReferenceOr.Reference(ParameterRef(key), _, _) => + resolveParameterRef(openAPI, key) + case other => throw new Exception(s"Unexpected parameter definition: $other") + } + .map(p => p.name -> p) + .toMap + val segments = pathSegments.map { + case s if s.startsWith("{") && s.endsWith("}") => + val name = s.tail.init + val param = params.getOrElse( + name, + throw new Exception( + s"Path parameter $name not found in parameters: ${params.keys.mkString(", ")}", + ), + ) + parameterToPathCodec(openAPI, param) + case s => Code.PathSegmentCode(s, Code.CodecType.Literal) + } + + Code.File( + packageName.split('.').toList :+ s"$className.scala", + pkgPath = packageName.split('.').toList, + imports = List(Code.Import.FromBase("component._")), + objects = List( + Code.Object( + className, + schema = false, + endpoints = List( + pathItem.delete.map(op => fieldName(op, "delete") -> endpoint(segments, op, openAPI, Method.DELETE)), + pathItem.get.map(op => fieldName(op, "get") -> endpoint(segments, op, openAPI, Method.GET)), + pathItem.head.map(op => fieldName(op, "head") -> endpoint(segments, op, openAPI, Method.HEAD)), + pathItem.options.map(op => fieldName(op, "options") -> endpoint(segments, op, openAPI, Method.OPTIONS)), + pathItem.post.map(op => fieldName(op, "post") -> endpoint(segments, op, openAPI, Method.POST)), + pathItem.put.map(op => fieldName(op, "put") -> endpoint(segments, op, openAPI, Method.PUT)), + pathItem.patch.map(op => fieldName(op, "patch") -> endpoint(segments, op, openAPI, Method.PATCH)), + pathItem.trace.map(op => fieldName(op, "trace") -> endpoint(segments, op, openAPI, Method.TRACE)), + ).flatten.toMap, + objects = anonymousTypes.values.toList, + caseClasses = Nil, + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ) + }.toList ++ + openAPI.components.toList.flatMap { components => + components.schemas.flatMap { case (OpenAPI.Key(name), refOrSchema) => + var annotations: Chunk[JsonSchema.MetaData] = Chunk.empty + val schema = refOrSchema match { + case ReferenceOr.Or(schema: JsonSchema) => + annotations = schema.annotations + schema.withoutAnnotations + case ReferenceOr.Reference(ref, _, _) => + val schema = resolveSchemaRef(openAPI, ref) + annotations = schema.annotations + schema.withoutAnnotations + } + schemaToCode(schema, openAPI, name, annotations) + } + } + } + + private def fieldName(op: OpenAPI.Operation, fallback: String) = + Code.Field(op.operationId.getOrElse(fallback)) + + private def endpoint( + segments: List[Code.PathSegmentCode], + op: OpenAPI.Operation, + openAPI: OpenAPI, + method: Method, + ) = { + + val params = op.parameters.map { + case OpenAPI.ReferenceOr.Or(param: OpenAPI.Parameter) => param + case OpenAPI.ReferenceOr.Reference(ParameterRef(key), _, _) => resolveParameterRef(openAPI, key) + case other => throw new Exception(s"Unexpected parameter definition: $other") + } + // TODO: Resolve query and header parameters from components + val queryParams = params.collect { + case p if p.in == "query" => + schemaToQueryParamCodec( + p.schema.get.asInstanceOf[ReferenceOr.Or[JsonSchema]].value, + openAPI, + p.name, + ) + } + val headers = params.collect { case p if p.in == "header" => Code.HeaderCode(p.name) }.toList + val inType = + op.requestBody.flatMap { + case OpenAPI.ReferenceOr.Reference(RequestBodyRef(key), _, _) => Some(key) + case OpenAPI.ReferenceOr.Or(body: OpenAPI.RequestBody) => + body.content + .get("application/json") + .map { mt => + mt.schema match { + case ReferenceOr.Or(s) => + s.withoutAnnotations match { + case JsonSchema.Null => Inline.Null + case JsonSchema.RefSchema(SchemaRef(ref)) => ref + case schema if schema.isPrimitive => + schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString + case schema => + val code = schemaToCode(schema, openAPI, Inline.RequestBodyType, Chunk.empty) + .getOrElse( + throw new Exception(s"Could not generate code for request body $schema"), + ) + anonymousTypes += method.toString -> + Code.Object( + method.toString, + schema = false, + endpoints = Map.empty, + objects = code.objects, + caseClasses = code.caseClasses, + enums = code.enums, + ) + s"$method.${Inline.RequestBodyType}" + } + case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => ref + case other => throw new Exception(s"Unexpected request body schema: $other") + } + } + case other => throw new Exception(s"Unexpected request body definition: $other") + }.getOrElse("Unit") + + val outCodes: Iterable[Code.OutCode] = + // TODO: ignore default for now. Not sure how to handle it + op.responses.collect { + case (OpenAPI.StatusOrDefault.StatusValue(status), OpenAPI.ReferenceOr.Reference(ResponseRef(key), _, _)) => + val response = resolveResponseRef(openAPI, key) + Code.OutCode( + outType = response.content + .get("application/json") + .map { mt => + mt.schema match { + case ReferenceOr.Or(s) => + s.withoutAnnotations match { + case JsonSchema.Null => Inline.Null + case JsonSchema.RefSchema(SchemaRef(ref)) => ref + case schema if schema.isPrimitive => + schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString + case schema => + val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty) + .getOrElse( + throw new Exception(s"Could not generate code for request body $schema"), + ) + val obj = Code.Object( + method.toString, + schema = false, + endpoints = Map.empty, + objects = code.objects, + caseClasses = code.caseClasses, + enums = code.enums, + ) + anonymousTypes += method.toString -> anonymousTypes.get(method.toString).fold(obj) { obj => + obj.copy( + objects = obj.objects ++ code.objects, + caseClasses = obj.caseClasses ++ code.caseClasses, + enums = obj.enums ++ code.enums, + ) + } + s"$method.${Inline.ResponseBodyType}" + } + case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => ref + case other => throw new Exception(s"Unexpected response body schema: $other") + } + } + .getOrElse("Unit"), + status = status, + mediaType = Some("application/json"), + doc = None, + ) + case (OpenAPI.StatusOrDefault.StatusValue(status), OpenAPI.ReferenceOr.Or(response: OpenAPI.Response)) => + Code.OutCode( + outType = response.content + .get("application/json") + .map { mt => + mt.schema match { + case ReferenceOr.Or(s) => + s.withoutAnnotations match { + case JsonSchema.Null => Inline.Null + case JsonSchema.RefSchema(SchemaRef(ref)) => ref + case schema if schema.isPrimitive => + schemaToField(schema, openAPI, "unused", Chunk.empty).get.fieldType.toString + case schema => + val code = schemaToCode(schema, openAPI, Inline.ResponseBodyType, Chunk.empty) + .getOrElse( + throw new Exception(s"Could not generate code for request body $schema"), + ) + val obj = Code.Object( + method.toString, + schema = false, + endpoints = Map.empty, + objects = code.objects, + caseClasses = code.caseClasses, + enums = code.enums, + ) + anonymousTypes += method.toString -> anonymousTypes.get(method.toString).fold(obj) { obj => + obj.copy( + objects = obj.objects ++ code.objects, + caseClasses = obj.caseClasses ++ code.caseClasses, + enums = obj.enums ++ code.enums, + ) + } + s"$method.${Inline.ResponseBodyType}" + } + case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => ref + case other => throw new Exception(s"Unexpected response body schema: $other") + } + } + .getOrElse("Unit"), + status = status, + mediaType = Some("application/json"), + doc = None, + ) + } + + Code.EndpointCode( + method = method, + pathPatternCode = Code.PathPatternCode(segments), + queryParamsCode = queryParams, + headersCode = Code.HeadersCode(headers), + inCode = Code.InCode(inType, None, None), + outCodes = outCodes.filterNot(_.status.isError).toList, + errorsCode = outCodes.filter(_.status.isError).toList, + ) + + } + + private def parameterToPathCodec(openAPI: OpenAPI, param: OpenAPI.Parameter): Code.PathSegmentCode = { + param.schema match { + case Some(OpenAPI.ReferenceOr.Or(schema: JsonSchema)) => + schemaToPathCodec(schema, openAPI, param.name) + case Some(OpenAPI.ReferenceOr.Reference(ref, _, _)) => + schemaToPathCodec(resolveSchemaRef(openAPI, ref), openAPI, param.name) + case None => + // Not sure if open api allows path parameters without schema. + // But string seems a good default + schemaToPathCodec(JsonSchema.String(), openAPI, param.name) + } + } + + @tailrec + private def resolveParameterRef(openAPI: OpenAPI, key: String): OpenAPI.Parameter = + openAPI.components match { + case Some(components) => + val param = components.parameters.getOrElse( + OpenAPI.Key.fromString(key).get, + throw new Exception(s"Only references to internal parameters are supported. Not found: $key"), + ) + param match { + case ReferenceOr.Reference(ref, _, _) => resolveParameterRef(openAPI, ref) + case ReferenceOr.Or(param) => param + } + case None => + throw new Exception(s"Found reference to parameter $key, but no components section found.") + } + + @tailrec + private def resolveSchemaRef(openAPI: OpenAPI, key: String): JsonSchema = + openAPI.components match { + case Some(components) => + val schema = components.schemas.getOrElse( + OpenAPI.Key.fromString(key).get, + throw new Exception(s"Only references to internal schemas are supported. Not found: $key"), + ) + schema match { + case ReferenceOr.Reference(ref, _, _) => resolveSchemaRef(openAPI, ref) + case ReferenceOr.Or(schema) => schema + } + case None => + throw new Exception(s"Found reference to schema $key, but no components section found.") + } + + @tailrec + private def resolveRequestBodyRef(openAPI: OpenAPI, key: String): OpenAPI.RequestBody = + openAPI.components match { + case Some(components) => + val schema = components.requestBodies.getOrElse( + OpenAPI.Key.fromString(key).get, + throw new Exception(s"Only references to internal schemas are supported. Not found: $key"), + ) + schema match { + case ReferenceOr.Reference(ref, _, _) => resolveRequestBodyRef(openAPI, ref) + case ReferenceOr.Or(schema) => schema + } + case None => + throw new Exception(s"Found reference to schema $key, but no components section found.") + } + + @tailrec + private def resolveResponseRef(openAPI: OpenAPI, key: String): OpenAPI.Response = + openAPI.components match { + case Some(components) => + val schema = components.responses.getOrElse( + OpenAPI.Key.fromString(key).get, + throw new Exception(s"Only references to internal schemas are supported. Not found: $key"), + ) + schema match { + case ReferenceOr.Reference(ref, _, _) => resolveResponseRef(openAPI, ref) + case ReferenceOr.Or(schema) => schema + } + case None => + throw new Exception(s"Found reference to schema $key, but no components section found.") + } + + @tailrec + private def schemaToPathCodec(schema: JsonSchema, openAPI: OpenAPI, name: String): Code.PathSegmentCode = { + schema match { + case JsonSchema.AnnotatedSchema(s, _) => schemaToPathCodec(s, openAPI, name) + case JsonSchema.RefSchema(ref) => schemaToPathCodec(resolveSchemaRef(openAPI, ref), openAPI, name) + case JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Int) + case JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Long) + case JsonSchema.Integer(JsonSchema.IntegerFormat.Timestamp) => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Long) + case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _) => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.UUID) + case JsonSchema.String(_, _) => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.String) + case JsonSchema.Boolean => + Code.PathSegmentCode(name = name, segmentType = Code.CodecType.Boolean) + case JsonSchema.OneOfSchema(_) => throw new Exception("Alternative path variables are not supported") + case JsonSchema.AllOfSchema(_) => throw new Exception("Path variables must have exactly one schema") + case JsonSchema.AnyOfSchema(_) => throw new Exception("Path variables must have exactly one schema") + case JsonSchema.Number(_) => throw new Exception("Floating point path variables are currently not supported") + case JsonSchema.ArrayType(_) => throw new Exception("Array path variables are not supported") + case JsonSchema.Object(_, _, _) => throw new Exception("Object path variables are not supported") + case JsonSchema.Enum(_) => throw new Exception("Enum path variables are not supported") + case JsonSchema.Null => throw new Exception("Null path variables are not supported") + case JsonSchema.AnyJson => throw new Exception("AnyJson path variables are not supported") + } + } + + @tailrec + private def schemaToQueryParamCodec( + schema: JsonSchema, + openAPI: OpenAPI, + name: String, + ): Code.QueryParamCode = { + schema match { + case JsonSchema.AnnotatedSchema(s, _) => + schemaToQueryParamCodec(s, openAPI, name) + case JsonSchema.RefSchema(ref) => + schemaToQueryParamCodec(resolveSchemaRef(openAPI, ref), openAPI, name) + case JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) => + Code.QueryParamCode(name = name, queryType = Code.CodecType.Int) + case JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) => + Code.QueryParamCode(name = name, queryType = Code.CodecType.Long) + case JsonSchema.Integer(JsonSchema.IntegerFormat.Timestamp) => + Code.QueryParamCode(name = name, queryType = Code.CodecType.Long) + case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _) => + Code.QueryParamCode(name = name, queryType = Code.CodecType.UUID) + case JsonSchema.String(_, _) => + Code.QueryParamCode(name = name, queryType = Code.CodecType.String) + case JsonSchema.Boolean => + Code.QueryParamCode(name = name, queryType = Code.CodecType.Boolean) + case JsonSchema.OneOfSchema(_) => throw new Exception("Alternative query parameters are not supported") + case JsonSchema.AllOfSchema(_) => throw new Exception("Query parameters must have exactly one schema") + case JsonSchema.AnyOfSchema(_) => throw new Exception("Query parameters must have exactly one schema") + case JsonSchema.Number(_) => throw new Exception("Floating point query parameters are currently not supported") + case JsonSchema.ArrayType(_) => throw new Exception("Array query parameters are not supported") + case JsonSchema.Object(_, _, _) => throw new Exception("Object query parameters are not supported") + case JsonSchema.Enum(_) => throw new Exception("Enum query parameters are not supported") + case JsonSchema.Null => throw new Exception("Null query parameters are not supported") + case JsonSchema.AnyJson => throw new Exception("AnyJson query parameters are not supported") + } + } + + def schemaToCode( + schema: JsonSchema, + openAPI: OpenAPI, + name: String, + annotations: Chunk[JsonSchema.MetaData], + ): Option[Code.File] = { + schema match { + case JsonSchema.AnnotatedSchema(s, _) => + schemaToCode(s.withoutAnnotations, openAPI, name, schema.annotations) + case JsonSchema.RefSchema(RequestBodyRef(ref)) => + val (schemaName, schema) = resolveRequestBodyRef(openAPI, ref).content + .get("application/json") + .map { mt => + mt.schema match { + case ReferenceOr.Or(s: JsonSchema) => name -> s + case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => + ref.capitalize -> resolveSchemaRef(openAPI, ref) + case other => + throw new Exception(s"Unexpected reference schema: $other") + } + } + .getOrElse(throw new Exception(s"Could not find content type application/json for request body $ref")) + schemaToCode(schema, openAPI, schemaName, annotations) + + case JsonSchema.RefSchema(SchemaRef(ref)) => + val schema = resolveSchemaRef(openAPI, ref) + schemaToCode(schema, openAPI, ref.capitalize, annotations) + + case JsonSchema.RefSchema(ResponseRef(ref)) => + val (schemaName, schema) = resolveResponseRef(openAPI, ref).content + .get("application/json") + .map { mt => + mt.schema match { + case ReferenceOr.Or(s: JsonSchema) => name -> s + case OpenAPI.ReferenceOr.Reference(SchemaRef(ref), _, _) => + ref.capitalize -> resolveSchemaRef(openAPI, ref) + case other => + throw new Exception(s"Unexpected reference schema: $other") + } + } + .getOrElse(throw new Exception(s"Could not find content type application/json for response $ref")) + schemaToCode(schema, openAPI, schemaName, annotations) + + case JsonSchema.RefSchema(ref) => throw new Exception(s"Unexpected reference schema: $ref") + case JsonSchema.Integer(_) => None + case JsonSchema.String(_, _) => None // this could maybe be im proved to generate a string type with validation + case JsonSchema.Boolean => None + case JsonSchema.OneOfSchema(schemas) if schemas.exists(_.isPrimitive) => + throw new Exception("OneOf schemas with primitive types are not supported") + case JsonSchema.OneOfSchema(schemas) => + val discriminatorInfo = + annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator } + val discriminator: Option[String] = discriminatorInfo.map(_.propertyName) + val caseNameMapping: Map[String, String] = discriminatorInfo.map(_.mapping).getOrElse(Map.empty).map { + case (k, v) => v -> k + } + var caseNames: List[String] = Nil + val caseClasses = schemas + .map(_.withoutAnnotations) + .flatMap { + case schema @ JsonSchema.Object(properties, _, _) if singleFieldTypeTag(schema) => + val (name, schema) = properties.head + caseNames :+= name + schemaToCode(schema, openAPI, name, annotations) + .getOrElse( + throw new Exception(s"Could not generate code for field $name of object $name"), + ) + .caseClasses + case schema @ JsonSchema.RefSchema(ref @ SchemaRef(name)) => + caseNameMapping.get(ref).foreach(caseNames :+= _) + schemaToCode(schema, openAPI, name, annotations) + .getOrElse( + throw new Exception(s"Could not generate code for subtype $name of oneOf schema $schema"), + ) + .caseClasses + case schema @ JsonSchema.Object(_, _, _) => + schemaToCode(schema, openAPI, name, annotations) + .getOrElse( + throw new Exception(s"Could not generate code for subtype $name of oneOf schema $schema"), + ) + .caseClasses + case other => + throw new Exception(s"Unexpected subtype $other for oneOf schema $schema") + } + .toList + val noDiscriminator = caseNames.isEmpty + Some( + Code.File( + List("component", name.capitalize + ".scala"), + pkgPath = List("component"), + imports = DataImports ++ + (if (noDiscriminator || caseNames.nonEmpty) List(Code.Import("zio.schema.annotation._")) else Nil), + objects = Nil, + caseClasses = Nil, + enums = List( + Code.Enum( + name = name, + cases = caseClasses, + caseNames = caseNames, + discriminator = discriminator, + noDiscriminator = noDiscriminator, + schema = true, + ), + ), + ), + ) + case JsonSchema.AllOfSchema(schemas) => + val genericFieldIndex = Iterator.from(0) + val fields = schemas.map(_.withoutAnnotations).flatMap { + case schema @ JsonSchema.Object(_, _, _) => + schemaToCode(schema, openAPI, name, annotations) + .getOrElse( + throw new Exception(s"Could not generate code for field $name of object $name"), + ) + .caseClasses + .headOption + .toList + .flatMap(_.fields) + case schema @ JsonSchema.RefSchema(SchemaRef(name)) => + schemaToCode(schema, openAPI, name, annotations) + .getOrElse( + throw new Exception(s"Could not generate code for subtype $name of allOf schema $schema"), + ) + .caseClasses + .headOption + .toList + .flatMap(_.fields) + case schema if schema.isPrimitive => + val name = s"field${genericFieldIndex.next()}" + Chunk(schemaToField(schema, openAPI, name, annotations)).flatten + case other => + throw new Exception(s"Unexpected subtype $other for allOf schema $schema") + } + Some( + Code.File( + List("component", name.capitalize + ".scala"), + pkgPath = List("component"), + imports = DataImports, + objects = Nil, + caseClasses = List( + Code.CaseClass( + name, + fields.toList, + companionObject = Some(Code.Object.schemaCompanion(name)), + ), + ), + enums = Nil, + ), + ) + case JsonSchema.AnyOfSchema(schemas) if schemas.exists(_.isPrimitive) => + throw new Exception("AnyOf schemas with primitive types are not supported") + case JsonSchema.AnyOfSchema(schemas) => + val discriminatorInfo = + annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator } + val discriminator: Option[String] = discriminatorInfo.map(_.propertyName) + val caseNameMapping: Map[String, String] = discriminatorInfo.map(_.mapping).getOrElse(Map.empty).map { + case (k, v) => v -> k + } + var caseNames: List[String] = Nil + val caseClasses = schemas + .map(_.withoutAnnotations) + .flatMap { + case schema @ JsonSchema.Object(properties, _, _) if singleFieldTypeTag(schema) => + val (name, schema) = properties.head + caseNames :+= name + schemaToCode(schema, openAPI, name, annotations) + .getOrElse( + throw new Exception(s"Could not generate code for field $name of object $name"), + ) + .caseClasses + case schema @ JsonSchema.RefSchema(ref @ SchemaRef(name)) => + caseNameMapping.get(ref).foreach(caseNames :+= _) + schemaToCode(schema, openAPI, name, annotations) + .getOrElse( + throw new Exception(s"Could not generate code for subtype $name of anyOf schema $schema"), + ) + .caseClasses + case schema @ JsonSchema.Object(_, _, _) => + schemaToCode(schema, openAPI, name, annotations) + .getOrElse( + throw new Exception(s"Could not generate code for subtype $name of anyOf schema $schema"), + ) + .caseClasses + case other => + throw new Exception(s"Unexpected subtype $other for anyOf schema $schema") + } + .toList + Some( + Code.File( + List("component", name.capitalize + ".scala"), + pkgPath = List("component"), + imports = DataImports, + objects = Nil, + caseClasses = Nil, + enums = List( + Code.Enum( + name = name, + cases = caseClasses, + caseNames = caseNames, + discriminator = discriminator, + noDiscriminator = caseNames.isEmpty, + schema = true, + ), + ), + ), + ) + case JsonSchema.Number(_) => None + case JsonSchema.ArrayType(None) => None + case JsonSchema.ArrayType(Some(schema)) => schemaToCode(schema, openAPI, name, annotations) + // TODO use additionalProperties + case JsonSchema.Object(properties, additionalProperties, required) => + val fields = properties.map { case (name, schema) => + val field = schemaToField(schema, openAPI, name, annotations) + .getOrElse( + throw new Exception(s"Could not generate code for field $name of object $name"), + ) + .asInstanceOf[Code.Field] + if (required.contains(name)) field else field.copy(fieldType = field.fieldType.opt) + }.toList + val nested = properties.collect { + case (name, schema) if !schema.isInstanceOf[JsonSchema.RefSchema] && !schema.isPrimitive => + schemaToCode(schema, openAPI, name.capitalize, Chunk.empty) + .getOrElse( + throw new Exception(s"Could not generate code for field $name of object $name"), + ) + } + val nestedObjects = nested.flatMap(_.objects) + val nestedCaseClasses = nested.flatMap(_.caseClasses) + Some( + Code.File( + List("component", name.capitalize + ".scala"), + pkgPath = List("component"), + imports = DataImports, + objects = nestedObjects.toList, + caseClasses = List( + Code.CaseClass( + name, + fields, + companionObject = Some(Code.Object.schemaCompanion(name)), + ), + ) ++ nestedCaseClasses, + enums = Nil, + ), + ) + + case JsonSchema.Enum(enums) => + Some( + Code.File( + List("component", name.capitalize + ".scala"), + pkgPath = List("component"), + imports = DataImports, + objects = Nil, + caseClasses = Nil, + enums = List( + Code.Enum( + name, + enums.flatMap { + case JsonSchema.EnumValue.Str(e) => Some(Code.CaseClass(e)) + case JsonSchema.EnumValue.Null => + None // can be ignored here, but field of this type should be optional + case other => throw new Exception(s"OpenAPI Enums of value $other, are currently unsupported") + }.toList, + ), + ), + ), + ) + case JsonSchema.Null => throw new Exception("Null query parameters are not supported") + case JsonSchema.AnyJson => throw new Exception("AnyJson query parameters are not supported") + } + } + + private def singleFieldTypeTag(schema: JsonSchema.Object) = + schema.properties.size == 1 && + schema.properties.head._2.isInstanceOf[JsonSchema.RefSchema] && + schema.additionalProperties == Left(false) && + schema.required == Chunk(schema.properties.head._1) + + def schemaToField( + schema: JsonSchema, + openAPI: OpenAPI, + name: String, + annotations: Chunk[JsonSchema.MetaData], + ): Option[Code.Field] = { + schema match { + case JsonSchema.AnnotatedSchema(s, _) => + schemaToField(s.withoutAnnotations, openAPI, name, schema.annotations) + case JsonSchema.RefSchema(SchemaRef(ref)) => + Some(Code.Field(name, Code.TypeRef(ref.capitalize))) + case JsonSchema.RefSchema(ref) => + throw new Exception(s" Not found: $ref. Only references to internal schemas are supported.") + case JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) => + Some(Code.Field(name, Code.Primitive.ScalaInt)) + case JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) => + Some(Code.Field(name, Code.Primitive.ScalaLong)) + case JsonSchema.Integer(JsonSchema.IntegerFormat.Timestamp) => + Some(Code.Field(name, Code.Primitive.ScalaLong)) + case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _) => + Some(Code.Field(name, Code.Primitive.ScalaUUID)) + case JsonSchema.String(_, _) => + Some(Code.Field(name, Code.Primitive.ScalaString)) + case JsonSchema.Boolean => + Some(Code.Field(name, Code.Primitive.ScalaBoolean)) + case JsonSchema.OneOfSchema(schemas) => + val tpe = + schemas + .map(_.withoutAnnotations) + .flatMap(schemaToField(_, openAPI, "unused", annotations)) + .map(_.fieldType) + .reduceLeft(ScalaType.Or(_, _)) + Some(Code.Field(name, tpe)) + case JsonSchema.AllOfSchema(_) => + throw new Exception("Inline allOf schemas are not supported for fields") + case JsonSchema.AnyOfSchema(schemas) => + val tpe = + schemas + .map(_.withoutAnnotations) + .flatMap(schemaToField(_, openAPI, "unused", annotations)) + .map(_.fieldType) + .reduceLeft(ScalaType.Or(_, _)) + Some(Code.Field(name, tpe)) + case JsonSchema.Number(JsonSchema.NumberFormat.Double) => + Some(Code.Field(name, Code.Primitive.ScalaDouble)) + case JsonSchema.Number(JsonSchema.NumberFormat.Float) => + Some(Code.Field(name, Code.Primitive.ScalaFloat)) + case JsonSchema.ArrayType(items) => + val tpe = items + .flatMap(schemaToField(_, openAPI, name, annotations)) + .map(_.fieldType.seq) + .orElse( + Some(Code.Primitive.ScalaString.seq), + ) + tpe.map(Code.Field(name, _)) + case JsonSchema.Object(_, _, _) => + Some(Code.Field(name, Code.TypeRef(name.capitalize))) + case JsonSchema.Enum(_) => + Some(Code.Field(name, Code.TypeRef(name.capitalize))) + case JsonSchema.Null => + Some(Code.Field(name, ScalaType.Unit)) + case JsonSchema.AnyJson => + Some(Code.Field(name, ScalaType.JsonAST)) + } + } +} diff --git a/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala b/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala new file mode 100644 index 0000000000..da1f0ba796 --- /dev/null +++ b/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala @@ -0,0 +1,155 @@ +package zio.http.gen.scala + +import java.nio.file.Path + +import zio.http.{Method, Status} + +sealed trait Code extends Product with Serializable + +object Code { + sealed trait ScalaType extends Code { self => + def seq: Collection.Seq = Collection.Seq(self) + def set: Collection.Set = Collection.Set(self) + def map: Collection.Map = Collection.Map(self) + def opt: Collection.Opt = Collection.Opt(self) + } + + object ScalaType { + case object Inferred extends ScalaType + case object Unit extends ScalaType + case object JsonAST extends ScalaType + final case class Or(left: ScalaType, right: ScalaType) extends ScalaType + } + + final case class TypeRef(name: String) extends ScalaType + + final case class Files(files: List[File]) extends Code + + final case class File( + path: List[String], + pkgPath: List[String], + imports: List[Import], + objects: List[Object], + caseClasses: List[CaseClass], + enums: List[Enum], + ) extends Code + + sealed trait Import extends Code + + object Import { + def apply(name: String): Import = Absolute(name) + + final case class Absolute(path: String) extends Import + final case class FromBase(path: String) extends Import + } + + final case class Object( + name: String, + schema: Boolean, + endpoints: Map[Field, EndpointCode], + objects: List[Object], + caseClasses: List[CaseClass], + enums: List[Enum], + ) extends ScalaType + + object Object { + def schemaCompanion(str: String): Object = Object(str, schema = true, Map.empty, Nil, Nil, Nil) + + def apply(name: String, endpoints: Map[Field, EndpointCode]): Object = + Object(name, schema = false, endpoints, Nil, Nil, Nil) + } + + final case class CaseClass(name: String, fields: List[Field], companionObject: Option[Object]) extends ScalaType + + object CaseClass { + def apply(name: String): CaseClass = CaseClass(name, Nil, None) + } + + final case class Enum( + name: String, + cases: List[CaseClass], + caseNames: List[String] = Nil, + discriminator: Option[String] = None, + noDiscriminator: Boolean = false, + schema: Boolean = true, + ) extends ScalaType + + final case class Field(name: String, fieldType: ScalaType) extends Code + + object Field { + def apply(name: String): Field = Field(name, ScalaType.Inferred) + } + + sealed trait Collection extends ScalaType { + def elementType: ScalaType + } + + object Collection { + final case class Seq(elementType: ScalaType) extends Collection + final case class Set(elementType: ScalaType) extends Collection + final case class Map(elementType: ScalaType) extends Collection + final case class Opt(elementType: ScalaType) extends Collection + } + + sealed trait Primitive extends ScalaType + + object Primitive { + case object ScalaInt extends Primitive + case object ScalaLong extends Primitive + case object ScalaDouble extends Primitive + case object ScalaFloat extends Primitive + case object ScalaChar extends Primitive + case object ScalaByte extends Primitive + case object ScalaShort extends Primitive + case object ScalaBoolean extends Primitive + case object ScalaUnit extends Primitive + case object ScalaUUID extends Primitive + case object ScalaString extends Primitive + } + + final case class EndpointCode( + method: Method, + pathPatternCode: PathPatternCode, + queryParamsCode: Set[QueryParamCode], + headersCode: HeadersCode, + inCode: InCode, + outCodes: List[OutCode], + errorsCode: List[OutCode], + ) extends Code + + final case class PathPatternCode(segments: List[PathSegmentCode]) + final case class PathSegmentCode(name: String, segmentType: CodecType) + object PathSegmentCode { + def apply(name: String): PathSegmentCode = PathSegmentCode(name, CodecType.Literal) + } + sealed trait CodecType + object CodecType { + case object Boolean extends CodecType + case object Int extends CodecType + case object Literal extends CodecType + case object Long extends CodecType + case object String extends CodecType + case object UUID extends CodecType + } + final case class QueryParamCode(name: String, queryType: CodecType) + final case class HeadersCode(headers: List[HeaderCode]) + object HeadersCode { val empty: HeadersCode = HeadersCode(Nil) } + final case class HeaderCode(name: String) + final case class InCode( + inType: String, + name: Option[String], + doc: Option[String], + ) + object InCode { def apply(inType: String): InCode = InCode(inType, None, None) } + final case class OutCode( + outType: String, + status: Status, + mediaType: Option[String], + doc: Option[String], + ) + object OutCode { + def apply(outType: String, status: Status): OutCode = OutCode(outType, status, None, None) + def json(outType: String, status: Status): OutCode = OutCode(outType, status, Some("application/json"), None) + } + +} diff --git a/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala new file mode 100644 index 0000000000..c66c1853c8 --- /dev/null +++ b/zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala @@ -0,0 +1,277 @@ +package zio.http.gen.scala + +import java.nio.charset.StandardCharsets +import java.nio.file.StandardOpenOption._ +import java.nio.file._ + +object CodeGen { + + private val EndpointImports = + List( + Code.Import("zio.http._"), + Code.Import("zio.http.endpoint._"), + Code.Import("zio.http.codec._"), + ) + + def format(config: Path)(file: Path, content: String): String = { + import org.scalafmt.interfaces.Scalafmt + + val scalafmt: Scalafmt = Scalafmt.create(this.getClass.getClassLoader) + scalafmt.format(config, file, content) + } + + def writeFiles(files: Code.Files, basePath: Path, basePackage: String, scalafmtPath: Option[Path]): Unit = { + + val formatCode = scalafmtPath.map(format(_: Path) _).getOrElse((_: Path, content: String) => content) + + val rendered = renderedFiles(files, basePackage) + rendered.map { case (path, content) => path -> formatCode(Paths.get(path), content) }.foreach { + case (path, content) => + val filePath = Paths.get(basePath.toString, path) + Files.createDirectories(filePath.getParent) + Files.write(filePath, content.getBytes(StandardCharsets.UTF_8), CREATE, TRUNCATE_EXISTING) + } + } + + def renderedFiles(files: Code.Files, basePackage: String): Map[String, String] = + files.files.map { file => + val rendered = render(basePackage)(file) + file.path.mkString("/") -> rendered + }.toMap + + def render(basePackage: String)(structure: Code): String = structure match { + case Code.Files(_) => + throw new Exception("Files should be rendered separately") + + case Code.File(_, path, imports, objects, caseClasses, enums) => + s"package $basePackage.${path.mkString(".")}\n\n" + + s"${imports.map(render(basePackage)).mkString("\n")}\n\n" + + objects.map(render(basePackage)).mkString("\n") + + caseClasses.map(render(basePackage)).mkString("\n") + + enums.map(render(basePackage)).mkString("\n") + + case Code.Import.Absolute(path) => + s"import $path" + + case Code.Import.FromBase(path) => + s"import $basePackage.$path" + + case Code.Object(name, schema, endpoints, objects, caseClasses, enums) => + s"object $name {\n" + + (if (endpoints.nonEmpty) EndpointImports.map(render(basePackage)).mkString("", "\n", "\n") else "") + + endpoints.map { case (k, v) => s"${render(basePackage)(k)}=${render(basePackage)(v)}" } + .mkString("\n") + + (if (schema) s"\n\n implicit val codec: Schema[$name] = DeriveSchema.gen[$name]" else "") + + "\n" + objects.map(render(basePackage)).mkString("\n") + + "\n" + caseClasses.map(render(basePackage)).mkString("\n") + + "\n" + enums.map(render(basePackage)).mkString("\n") + + "\n}" + + case Code.CaseClass(name, fields, companionObject) => + s"case class $name(\n" + + fields.map(render(basePackage)).mkString(",\n").replace("val", "") + + "\n)" + companionObject.map(render(basePackage)).map("\n" + _).getOrElse("") + + case Code.Enum(name, cases, caseNames, discriminator, noDiscriminator, schema) => + val discriminatorAnnotation = + if (noDiscriminator) "@noDiscriminator\n" else "" + val discriminatorNameAnnotation = + if (discriminator.isDefined) s"""@discriminatorName("${discriminator.get}")\n""" else "" + discriminatorAnnotation + + discriminatorNameAnnotation + + s"sealed trait $name\n" + + s"object $name {\n" + + (if (schema) s"\n\n implicit val codec: Schema[$name] = DeriveSchema.gen[$name]\n" else "") + { + if (caseNames.nonEmpty) { + cases + .map(render(basePackage)) + .zipWithIndex + .map { case (c, i) => s"""@caseName("${caseNames(i)}")\n$c""" } + .mkString("\n") + } else { + cases.map(render(basePackage)).mkString("\n") + } + } + + "\n}" + + case col: Code.Collection => + col match { + case Code.Collection.Seq(elementType) => + s"Seq[${render(basePackage)(elementType)}]" + case Code.Collection.Set(elementType) => + s"Set[${render(basePackage)(elementType)}]" + case Code.Collection.Map(elementType) => + s"Map[String, ${render(basePackage)(elementType)}]" + case Code.Collection.Opt(elementType) => + s"Option[${render(basePackage)(elementType)}]" + } + + case Code.Field(name, fieldType) => + val tpe = render(basePackage)(fieldType) + if (tpe.isEmpty) s"val $name" else s"val $name: $tpe" + + case Code.Primitive.ScalaInt => "Int" + case Code.Primitive.ScalaLong => "Long" + case Code.Primitive.ScalaDouble => "Double" + case Code.Primitive.ScalaFloat => "Float" + case Code.Primitive.ScalaChar => "Char" + case Code.Primitive.ScalaByte => "Byte" + case Code.Primitive.ScalaShort => "Short" + case Code.Primitive.ScalaBoolean => "Boolean" + case Code.Primitive.ScalaUnit => "Unit" + case Code.Primitive.ScalaString => "String" + case Code.ScalaType.Inferred => "" + + case Code.EndpointCode(method, pathPatternCode, queryParamsCode, headersCode, inCode, outCodes, errorsCode) => + s"""Endpoint(Method.$method / ${pathPatternCode.segments.map(renderSegment).mkString(" / ")}) + | ${queryParamsCode.map(renderQueryCode).mkString("\n")} + | ${headersCode.headers.map(renderHeader).mkString("\n")} + | ${renderInCode(inCode)} + | ${outCodes.map(renderOutCode).mkString("\n")} + | ${errorsCode.map(renderOutErrorCode).mkString("\n")} + |""".stripMargin + + case Code.TypeRef(name) => + name + + case scalaType => + println(s"Unknown ScalaType: $scalaType") + throw new Exception(s"Unknown ScalaType: $scalaType") + } + + def renderSegment(segment: Code.PathSegmentCode): String = segment match { + case Code.PathSegmentCode(name, segmentType) => + segmentType match { + case Code.CodecType.Boolean => s"""bool("$name")""" + case Code.CodecType.Int => s"""int("$name")""" + case Code.CodecType.Long => s"""long("$name")""" + case Code.CodecType.String => s"""string("$name")""" + case Code.CodecType.UUID => s"""uuid("$name")""" + case Code.CodecType.Literal => s""""$name"""" + } + + } + + // currently, we do not support schemas + def renderHeader(header: Code.HeaderCode): String = { + val headerSelector = header.name.toLowerCase match { + case "accept" => "HeaderCodec.accept" + case "accept-encoding" => "HeaderCodec.acceptEncoding" + case "accept-language" => "HeaderCodec.acceptLanguage" + case "accept-ranges" => "HeaderCodec.acceptRanges" + case "accept-patch" => "HeaderCodec.acceptPatch" + case "access-control-allow-credentials" => "HeaderCodec.accessControlAllowCredentials" + case "access-control-allow-headers" => "HeaderCodec.accessControlAllowHeaders" + case "access-control-allow-methods" => "HeaderCodec.accessControlAllowMethods" + case "access-control-allow-origin" => "HeaderCodec.accessControlAllowOrigin" + case "access-control-expose-headers" => "HeaderCodec.accessControlExposeHeaders" + case "access-control-max-age" => "HeaderCodec.accessControlMaxAge" + case "access-control-request-headers" => "HeaderCodec.accessControlRequestHeaders" + case "access-control-request-method" => "HeaderCodec.accessControlRequestMethod" + case "age" => "HeaderCodec.age" + case "allow" => "HeaderCodec.allow" + case "authorization" => "HeaderCodec.authorization" + case "cache-control" => "HeaderCodec.cacheControl" + case "connection" => "HeaderCodec.connection" + case "content-base" => "HeaderCodec.contentBase" + case "content-encoding" => "HeaderCodec.contentEncoding" + case "content-language" => "HeaderCodec.contentLanguage" + case "content-length" => "HeaderCodec.contentLength" + case "content-location" => "HeaderCodec.contentLocation" + case "content-transfer-encoding" => "HeaderCodec.contentTransferEncoding" + case "content-disposition" => "HeaderCodec.contentDisposition" + case "content-md5" => "HeaderCodec.contentMd5" + case "content-range" => "HeaderCodec.contentRange" + case "content-security-policy" => "HeaderCodec.contentSecurityPolicy" + case "content-type" => "HeaderCodec.contentType" + case "cookie" => "HeaderCodec.cookie" + case "date" => "HeaderCodec.date" + case "dnt" => "HeaderCodec.dnt" + case "etag" => "HeaderCodec.etag" + case "expect" => "HeaderCodec.expect" + case "expires" => "HeaderCodec.expires" + case "from" => "HeaderCodec.from" + case "host" => "HeaderCodec.host" + case "if-match" => "HeaderCodec.ifMatch" + case "if-modified-since" => "HeaderCodec.ifModifiedSince" + case "if-none-match" => "HeaderCodec.ifNoneMatch" + case "if-range" => "HeaderCodec.ifRange" + case "if-unmodified-since" => "HeaderCodec.ifUnmodifiedSince" + case "last-modified" => "HeaderCodec.lastModified" + case "location" => "HeaderCodec.location" + case "max-forwards" => "HeaderCodec.maxForwards" + case "origin" => "HeaderCodec.origin" + case "pragma" => "HeaderCodec.pragma" + case "proxy-authenticate" => "HeaderCodec.proxyAuthenticate" + case "proxy-authorization" => "HeaderCodec.proxyAuthorization" + case "range" => "HeaderCodec.range" + case "referer" => "HeaderCodec.referer" + case "retry-after" => "HeaderCodec.retryAfter" + case "sec-websocket-location" => "HeaderCodec.secWebSocketLocation" + case "sec-websocket-origin" => "HeaderCodec.secWebSocketOrigin" + case "sec-websocket-protocol" => "HeaderCodec.secWebSocketProtocol" + case "sec-websocket-version" => "HeaderCodec.secWebSocketVersion" + case "sec-websocket-key" => "HeaderCodec.secWebSocketKey" + case "sec-websocket-accept" => "HeaderCodec.secWebSocketAccept" + case "sec-websocket-extensions" => "HeaderCodec.secWebSocketExtensions" + case "server" => "HeaderCodec.server" + case "set-cookie" => "HeaderCodec.setCookie" + case "te" => "HeaderCodec.te" + case "trailer" => "HeaderCodec.trailer" + case "transfer-encoding" => "HeaderCodec.transferEncoding" + case "upgrade" => "HeaderCodec.upgrade" + case "upgrade-insecure-requests" => "HeaderCodec.upgradeInsecureRequests" + case "user-agent" => "HeaderCodec.userAgent" + case "vary" => "HeaderCodec.vary" + case "via" => "HeaderCodec.via" + case "warning" => "HeaderCodec.warning" + case "web-socket-location" => "HeaderCodec.webSocketLocation" + case "web-socket-origin" => "HeaderCodec.webSocketOrigin" + case "web-socket-protocol" => "HeaderCodec.webSocketProtocol" + case "www-authenticate" => "HeaderCodec.wwwAuthenticate" + case "x-frame-options" => "HeaderCodec.xFrameOptions" + case "x-requested-with" => "HeaderCodec.xRequestedWith" + case name => s"HeaderCodec.name[String]($name)" + } + s""".header($headerSelector)""" + } + + def renderQueryCode(queryCode: Code.QueryParamCode): String = queryCode match { + case Code.QueryParamCode(name, queryType) => + val tpe = queryType match { + case Code.CodecType.Boolean => "Boolean" + case Code.CodecType.Int => "Int" + case Code.CodecType.Long => "Long" + case Code.CodecType.String => "String" + case Code.CodecType.UUID => "UUID" + case Code.CodecType.Literal => throw new Exception("Literal query params are not supported") + } + s""".query(QueryCodec.queryAs[$tpe]("$name"))""" + } + + def renderInCode(inCode: Code.InCode): String = inCode match { + case Code.InCode(inType, Some(name), Some(doc)) => + s""".in[$inType](name = "$name", doc = md""\"$doc"\"")""" + case Code.InCode(inType, Some(name), None) => + s""".in[$inType](name = "$name")""" + case Code.InCode(inType, None, Some(doc)) => + s""".in[$inType](doc = md""\"$doc"\"")""" + case Code.InCode(inType, None, None) => + s".in[$inType]" + } + + def renderOutCode(outCode: Code.OutCode): String = outCode match { + case Code.OutCode(outType, status, _, Some(doc)) => + s""".out[$outType](status = Status.$status, doc = md""\"$doc"\"")""" + case Code.OutCode(outType, status, _, None) => + s""".out[$outType](status = Status.$status)""" + } + + def renderOutErrorCode(errOutCode: Code.OutCode): String = errOutCode match { + case Code.OutCode(outType, status, _, Some(doc)) => + s""".outError[$outType](status = Status.$status, doc = md""\"$doc"\"")""" + case Code.OutCode(outType, status, _, None) => + s""".outError[$outType](status = Status.$status)""" + } + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithEnumInput.scala b/zio-http-gen/src/test/resources/EndpointWithEnumInput.scala new file mode 100644 index 0000000000..3660a153e7 --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithEnumInput.scala @@ -0,0 +1,12 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val post = Endpoint(Method.POST / "api" / "v1" / "users") + .in[Payment] + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithEnumInputNamedDiscriminator.scala b/zio-http-gen/src/test/resources/EndpointWithEnumInputNamedDiscriminator.scala new file mode 100644 index 0000000000..800e80da15 --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithEnumInputNamedDiscriminator.scala @@ -0,0 +1,12 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val post = Endpoint(Method.POST / "api" / "v1" / "users") + .in[PaymentNamedDiscriminator] + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithEnumInputNoDiscriminator.scala b/zio-http-gen/src/test/resources/EndpointWithEnumInputNoDiscriminator.scala new file mode 100644 index 0000000000..0d4d621234 --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithEnumInputNoDiscriminator.scala @@ -0,0 +1,12 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val post = Endpoint(Method.POST / "api" / "v1" / "users") + .in[PaymentNoDiscriminator] + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithHeaders.scala b/zio-http-gen/src/test/resources/EndpointWithHeaders.scala new file mode 100644 index 0000000000..677c72cf24 --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithHeaders.scala @@ -0,0 +1,14 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val get = Endpoint(Method.GET / "api" / "v1" / "users") + .header(HeaderCodec.accept) + .header(HeaderCodec.contentType) + .in[Unit] + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithQueryParams.scala b/zio-http-gen/src/test/resources/EndpointWithQueryParams.scala new file mode 100644 index 0000000000..a264407f89 --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithQueryParams.scala @@ -0,0 +1,14 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val get = Endpoint(Method.GET / "api" / "v1" / "users") + .query(QueryCodec.queryAs[Int]("limit")) + .query(QueryCodec.queryAs[String]("name")) + .in[Unit] + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithRequestBody.scala b/zio-http-gen/src/test/resources/EndpointWithRequestBody.scala new file mode 100644 index 0000000000..cf0817ae0c --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithRequestBody.scala @@ -0,0 +1,12 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val post = Endpoint(Method.POST / "api" / "v1" / "users") + .in[User] + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithRequestResponseBody.scala b/zio-http-gen/src/test/resources/EndpointWithRequestResponseBody.scala new file mode 100644 index 0000000000..274169f005 --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithRequestResponseBody.scala @@ -0,0 +1,13 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val post = Endpoint(Method.POST / "api" / "v1" / "users") + .in[User] + .out[User](status = Status.Ok) + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInline.scala b/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInline.scala new file mode 100644 index 0000000000..5ef216a88e --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInline.scala @@ -0,0 +1,36 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val post = Endpoint(Method.POST / "api" / "v1" / "users") + .in[POST.RequestBody] + .out[POST.ResponseBody](status = Status.Ok) + + object POST { + + case class RequestBody( + id: Int, + name: String, + ) + object RequestBody { + + implicit val codec: Schema[RequestBody] = DeriveSchema.gen[RequestBody] + + } + case class ResponseBody( + id: Int, + name: String, + ) + object ResponseBody { + + implicit val codec: Schema[ResponseBody] = DeriveSchema.gen[ResponseBody] + + } + + } + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInlineNested.scala b/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInlineNested.scala new file mode 100644 index 0000000000..c503a0f509 --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithRequestResponseBodyInlineNested.scala @@ -0,0 +1,46 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val post = Endpoint(Method.POST / "api" / "v1" / "users") + .in[POST.RequestBody] + .out[POST.ResponseBody](status = Status.Ok) + + object POST { + + case class RequestBody( + id: Int, + name: String, + address: Option[Address], + ) + object RequestBody { + + implicit val codec: Schema[RequestBody] = DeriveSchema.gen[RequestBody] + + } + case class Address( + number: Option[Int], + street: Option[String], + ) + object Address { + + implicit val codec: Schema[Address] = DeriveSchema.gen[Address] + + } + case class ResponseBody( + id: Int, + name: String, + ) + object ResponseBody { + + implicit val codec: Schema[ResponseBody] = DeriveSchema.gen[ResponseBody] + + } + + } + +} diff --git a/zio-http-gen/src/test/resources/EndpointWithResponseBody.scala b/zio-http-gen/src/test/resources/EndpointWithResponseBody.scala new file mode 100644 index 0000000000..88ea38d618 --- /dev/null +++ b/zio-http-gen/src/test/resources/EndpointWithResponseBody.scala @@ -0,0 +1,13 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val post = Endpoint(Method.POST / "api" / "v1" / "users") + .in[Unit] + .out[User](status = Status.Ok) + +} diff --git a/zio-http-gen/src/test/resources/GeneratedPayment.scala b/zio-http-gen/src/test/resources/GeneratedPayment.scala new file mode 100644 index 0000000000..dc1324a71c --- /dev/null +++ b/zio-http-gen/src/test/resources/GeneratedPayment.scala @@ -0,0 +1,29 @@ +package test.component + +import zio.schema._ +import zio.schema.annotation._ + +sealed trait Payment +object Payment { + + implicit val codec: Schema[Payment] = DeriveSchema.gen[Payment] + @caseName("Card") + case class Card( + number: String, + cvv: String, + ) + object Card { + + implicit val codec: Schema[Card] = DeriveSchema.gen[Card] + + } + @caseName("cash") + case class Cash( + amount: Int, + ) + object Cash { + + implicit val codec: Schema[Cash] = DeriveSchema.gen[Cash] + + } +} diff --git a/zio-http-gen/src/test/resources/GeneratedPaymentNamedDiscriminator.scala b/zio-http-gen/src/test/resources/GeneratedPaymentNamedDiscriminator.scala new file mode 100644 index 0000000000..843175202d --- /dev/null +++ b/zio-http-gen/src/test/resources/GeneratedPaymentNamedDiscriminator.scala @@ -0,0 +1,30 @@ +package test.component + +import zio.schema._ +import zio.schema.annotation._ + +@discriminatorName("type") +sealed trait PaymentNamedDiscriminator +object PaymentNamedDiscriminator { + + implicit val codec: Schema[PaymentNamedDiscriminator] = DeriveSchema.gen[PaymentNamedDiscriminator] + @caseName("Card") + case class Card( + number: String, + cvv: String, + ) + object Card { + + implicit val codec: Schema[Card] = DeriveSchema.gen[Card] + + } + @caseName("cash") + case class Cash( + amount: Int, + ) + object Cash { + + implicit val codec: Schema[Cash] = DeriveSchema.gen[Cash] + + } +} diff --git a/zio-http-gen/src/test/resources/GeneratedPaymentNoDiscriminator.scala b/zio-http-gen/src/test/resources/GeneratedPaymentNoDiscriminator.scala new file mode 100644 index 0000000000..64c0b250f7 --- /dev/null +++ b/zio-http-gen/src/test/resources/GeneratedPaymentNoDiscriminator.scala @@ -0,0 +1,28 @@ +package test.component + +import zio.schema._ +import zio.schema.annotation._ + +@noDiscriminator +sealed trait PaymentNoDiscriminator +object PaymentNoDiscriminator { + + implicit val codec: Schema[PaymentNoDiscriminator] = DeriveSchema.gen[PaymentNoDiscriminator] + case class Card( + number: String, + cvv: String, + ) + object Card { + + implicit val codec: Schema[Card] = DeriveSchema.gen[Card] + + } + case class Cash( + amount: Int, + ) + object Cash { + + implicit val codec: Schema[Cash] = DeriveSchema.gen[Cash] + + } +} diff --git a/zio-http-gen/src/test/resources/GeneratedUser.scala b/zio-http-gen/src/test/resources/GeneratedUser.scala new file mode 100644 index 0000000000..22d95ad94f --- /dev/null +++ b/zio-http-gen/src/test/resources/GeneratedUser.scala @@ -0,0 +1,13 @@ +package test.component + +import zio.schema._ + +case class User( + id: Int, + name: String, +) +object User { + + implicit val codec: Schema[User] = DeriveSchema.gen[User] + +} diff --git a/zio-http-gen/src/test/resources/UserIdUnitInOut.scala b/zio-http-gen/src/test/resources/UserIdUnitInOut.scala new file mode 100644 index 0000000000..dedd852523 --- /dev/null +++ b/zio-http-gen/src/test/resources/UserIdUnitInOut.scala @@ -0,0 +1,12 @@ +package test.api.v1.users + +import test.component._ + +object UserId { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val get = Endpoint(Method.GET / "api" / "v1" / "users" / int("userId")) + .in[Unit] + +} diff --git a/zio-http-gen/src/test/resources/UsersUnitInOut.scala b/zio-http-gen/src/test/resources/UsersUnitInOut.scala new file mode 100644 index 0000000000..8eda9a0013 --- /dev/null +++ b/zio-http-gen/src/test/resources/UsersUnitInOut.scala @@ -0,0 +1,12 @@ +package test.api.v1 + +import test.component._ + +object Users { + import zio.http._ + import zio.http.endpoint._ + import zio.http.codec._ + val get = Endpoint(Method.GET / "api" / "v1" / "users") + .in[Unit] + +} diff --git a/zio-http-gen/src/test/resources/inline_schema.json b/zio-http-gen/src/test/resources/inline_schema.json new file mode 100644 index 0000000000..31c6025d42 --- /dev/null +++ b/zio-http-gen/src/test/resources/inline_schema.json @@ -0,0 +1,78 @@ +{ + "openapi" : "3.1.0", + "info" : { + "title" : "", + "version" : "" + }, + "paths" : { + "/api/v1/users" : { + "post" : { + "requestBody" : + { + "content" : { + "application/json" : { + "schema" : + { + "type" : + "object", + "properties" : { + "id" : { + "type" : + "integer", + "format" : "int32" + }, + "name" : { + "type" : + "string" + } + }, + "additionalProperties" : + true, + "required" : [ + "id", + "name" + ] + } + + } + }, + "required" : true + }, + "responses" : { + "200" : + { + "description" : "", + "content" : { + "application/json" : { + "schema" : + { + "type" : + "object", + "properties" : { + "id" : { + "type" : + "integer", + "format" : "int32" + }, + "name" : { + "type" : + "string" + } + }, + "additionalProperties" : + true, + "required" : [ + "id", + "name" + ] + } + + } + } + } + }, + "deprecated" : false + } + } + } +} diff --git a/zio-http-gen/src/test/resources/inline_schema_nested.json b/zio-http-gen/src/test/resources/inline_schema_nested.json new file mode 100644 index 0000000000..a61ed80b51 --- /dev/null +++ b/zio-http-gen/src/test/resources/inline_schema_nested.json @@ -0,0 +1,76 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "", + "version": "" + }, + "paths": { + "/api/v1/users": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int32" + }, + "name": { + "type": "string" + }, + "address": { + "type": "object", + "properties": { + "number": { + "type": "integer", + "format": "int32" + }, + "street": { + "type": "string" + } + } + } + }, + "additionalProperties": true, + "required": [ + "id", + "name" + ] + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int32" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": true, + "required": [ + "id", + "name" + ] + } + } + } + } + }, + "deprecated": false + } + } + } +} diff --git a/zio-http-gen/src/test/resources/scalafmt.conf b/zio-http-gen/src/test/resources/scalafmt.conf new file mode 100644 index 0000000000..022e17afb1 --- /dev/null +++ b/zio-http-gen/src/test/resources/scalafmt.conf @@ -0,0 +1,27 @@ +version = 3.7.17 +maxColumn = 120 + +align.preset = more +align.multiline = true +align.stripMargin = true + +continuationIndent.defnSite = 2 +assumeStandardLibraryStripMargin = true +danglingParentheses.preset = true +docstrings = JavaDoc +lineEndings = preserve +includeCurlyBraceInSelectChains = false +spaces.inImportCurlyBraces = false +optIn.annotationNewlines = true + +rewrite.rules = [Imports, RedundantBraces, SortModifiers] +rewrite.imports.sort = original + +docstrings.wrap = yes +docstrings.style = Asterisk + +newlines.afterInfix = keep +rewrite.rules = [RedundantParens] +trailingCommas = "always" +runner.dialect = Scala213Source3 +docstrings.wrapMaxColumn = 80 diff --git a/zio-http-gen/src/test/scala/zio/http/gen/model/Direction.scala b/zio-http-gen/src/test/scala/zio/http/gen/model/Direction.scala new file mode 100644 index 0000000000..2be4ba7c7c --- /dev/null +++ b/zio-http-gen/src/test/scala/zio/http/gen/model/Direction.scala @@ -0,0 +1,13 @@ +package zio.http.gen.model + +import zio.schema._ + +sealed trait Direction +object Direction { + case object North extends Direction + case object South extends Direction + case object East extends Direction + case object West extends Direction + + implicit val codec: Schema[Direction] = DeriveSchema.gen[Direction] +} diff --git a/zio-http-gen/src/test/scala/zio/http/gen/model/Payment.scala b/zio-http-gen/src/test/scala/zio/http/gen/model/Payment.scala new file mode 100644 index 0000000000..b769670d17 --- /dev/null +++ b/zio-http-gen/src/test/scala/zio/http/gen/model/Payment.scala @@ -0,0 +1,13 @@ +package zio.http.gen.model + +import zio.schema._ +import zio.schema.annotation._ + +sealed trait Payment +object Payment { + case class Card(number: String, cvv: String) extends Payment + @caseName("cash") + case class Cash(amount: Int) extends Payment + + implicit val codec: Schema[Payment] = DeriveSchema.gen[Payment] +} diff --git a/zio-http-gen/src/test/scala/zio/http/gen/model/PaymentNamedDiscriminator.scala b/zio-http-gen/src/test/scala/zio/http/gen/model/PaymentNamedDiscriminator.scala new file mode 100644 index 0000000000..35c0772176 --- /dev/null +++ b/zio-http-gen/src/test/scala/zio/http/gen/model/PaymentNamedDiscriminator.scala @@ -0,0 +1,14 @@ +package zio.http.gen.model + +import zio.schema._ +import zio.schema.annotation._ + +@discriminatorName("type") +sealed trait PaymentNamedDiscriminator +object PaymentNamedDiscriminator { + case class Card(number: String, cvv: String) extends PaymentNamedDiscriminator + @caseName("cash") + case class Cash(amount: Int) extends PaymentNamedDiscriminator + + implicit val codec: Schema[PaymentNamedDiscriminator] = DeriveSchema.gen[PaymentNamedDiscriminator] +} diff --git a/zio-http-gen/src/test/scala/zio/http/gen/model/PaymentNoDiscriminator.scala b/zio-http-gen/src/test/scala/zio/http/gen/model/PaymentNoDiscriminator.scala new file mode 100644 index 0000000000..ba4b8c7243 --- /dev/null +++ b/zio-http-gen/src/test/scala/zio/http/gen/model/PaymentNoDiscriminator.scala @@ -0,0 +1,14 @@ +package zio.http.gen.model + +import zio.schema._ +import zio.schema.annotation._ + +@noDiscriminator +sealed trait PaymentNoDiscriminator +object PaymentNoDiscriminator { + case class Card(number: String, cvv: String) extends PaymentNoDiscriminator + @caseName("cash") + case class Cash(amount: Int) extends PaymentNoDiscriminator + + implicit val codec: Schema[PaymentNoDiscriminator] = DeriveSchema.gen[PaymentNoDiscriminator] +} diff --git a/zio-http-gen/src/test/scala/zio/http/gen/model/User.scala b/zio-http-gen/src/test/scala/zio/http/gen/model/User.scala new file mode 100644 index 0000000000..782fe12a83 --- /dev/null +++ b/zio-http-gen/src/test/scala/zio/http/gen/model/User.scala @@ -0,0 +1,9 @@ +package zio.http.gen.model + +import zio.schema._ +import zio.schema.annotation._ + +case class User(id: Int, name: String) +object User { + implicit val codec: Schema[User] = DeriveSchema.gen[User] +} diff --git a/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala new file mode 100644 index 0000000000..126bfbc6e1 --- /dev/null +++ b/zio-http-gen/src/test/scala/zio/http/gen/openapi/EndpointGenSpec.scala @@ -0,0 +1,935 @@ +package zio.http.gen.openapi + +import java.nio.file._ + +import zio._ +import zio.test._ + +import zio.http._ +import zio.http.codec.HeaderCodec +import zio.http.codec.HttpCodec.{query, queryInt} +import zio.http.endpoint._ +import zio.http.endpoint.openapi.JsonSchema.SchemaStyle.Inline +import zio.http.endpoint.openapi.{OpenAPI, OpenAPIGen} +import zio.http.gen.model._ +import zio.http.gen.scala.Code + +object EndpointGenSpec extends ZIOSpecDefault { + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("EndpointGenSpec")( + suite("file gen spec")( + test("right package and file name") { + val openAPI = OpenAPI.empty.path( + OpenAPI.Path.fromString("/api/v1/users").get, + OpenAPI.PathItem.empty.addGet( + OpenAPI.Operation( + summary = None, + externalDocs = None, + operationId = None, + requestBody = None, + description = None, + ), + ), + ) + val scala = EndpointGen.fromOpenAPI(openAPI) + val filePath = Paths.get("/api/v1", "Users.scala") + val pkgPath = List("api", "v1") + val firstFile = scala.files.head + assertTrue(firstFile.pkgPath == pkgPath, firstFile.path.mkString("/", "/", "") == filePath.toString) + }, + test("right package and file name with path parameters") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users" / int("userId")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val filePath = Paths.get("/api/v1/users", "UserId.scala") + val pkgPath = List("api", "v1", "users") + val firstFile = scala.files.head + assertTrue(firstFile.pkgPath == pkgPath, firstFile.path.mkString("/", "/", "") == filePath.toString) + }, + ), + suite("endpoint gen spec")( + test("empty request and response") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users") + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + Nil, + Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("empty request and response with int path parameter") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users" / int("userId")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "users", "UserId.scala"), + pkgPath = List("api", "v1", "users"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "UserId", + Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List( + Code.PathSegmentCode("api"), + Code.PathSegmentCode("v1"), + Code.PathSegmentCode("users"), + Code.PathSegmentCode("userId", Code.CodecType.Int), + ), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("empty request and response with string path parameter") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users" / string("userId")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "users", "UserId.scala"), + pkgPath = List("api", "v1", "users"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "UserId", + Map( + Code.Field("get") -> + Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List( + Code.PathSegmentCode("api"), + Code.PathSegmentCode("v1"), + Code.PathSegmentCode("users"), + Code.PathSegmentCode("userId", Code.CodecType.String), + ), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("empty request and response with long path parameter") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users" / long("userId")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "users", "UserId.scala"), + pkgPath = List("api", "v1", "users"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "UserId", + Map( + Code.Field("get") -> + Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List( + Code.PathSegmentCode("api"), + Code.PathSegmentCode("v1"), + Code.PathSegmentCode("users"), + Code.PathSegmentCode("userId", Code.CodecType.Long), + ), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("empty request and response with uuid path parameter") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users" / uuid("userId")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "users", "UserId.scala"), + pkgPath = List("api", "v1", "users"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "UserId", + Map( + Code.Field("get") -> + Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List( + Code.PathSegmentCode("api"), + Code.PathSegmentCode("v1"), + Code.PathSegmentCode("users"), + Code.PathSegmentCode("userId", Code.CodecType.UUID), + ), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("empty request and response with boolean path parameter") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users" / boolean("userId")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "users", "UserId.scala"), + pkgPath = List("api", "v1", "users"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "UserId", + Map( + Code.Field("get") -> + Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List( + Code.PathSegmentCode("api"), + Code.PathSegmentCode("v1"), + Code.PathSegmentCode("users"), + Code.PathSegmentCode("userId", Code.CodecType.Boolean), + ), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("empty request and response with accept header") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").header(HeaderCodec.accept) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode(List(Code.HeaderCode("accept"))), + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("empty request and response with accept and content-type headers") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users") + .header(HeaderCodec.accept) + .header(HeaderCodec.contentType) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode(List(Code.HeaderCode("accept"), Code.HeaderCode("content-type"))), + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("empty request and response with accept and content-type headers and query parameters") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users") + .header(HeaderCodec.accept) + .header(HeaderCodec.contentType) + .query(queryInt("limit")) + .query(query("name")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set( + Code.QueryParamCode("limit", Code.CodecType.Int), + Code.QueryParamCode("name", Code.CodecType.String), + ), + headersCode = Code.HeadersCode(List(Code.HeaderCode("accept"), Code.HeaderCode("content-type"))), + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test( + "empty request and response with accept and content-type headers and query parameters and path parameters", + ) { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users" / int("userId")) + .header(HeaderCodec.accept) + .header(HeaderCodec.contentType) + .query(queryInt("limit")) + .query(query("name")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "users", "UserId.scala"), + pkgPath = List("api", "v1", "users"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "UserId", + Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List( + Code.PathSegmentCode("api"), + Code.PathSegmentCode("v1"), + Code.PathSegmentCode("users"), + Code.PathSegmentCode("userId", Code.CodecType.Int), + ), + ), + queryParamsCode = Set( + Code.QueryParamCode("limit", Code.CodecType.Int), + Code.QueryParamCode("name", Code.CodecType.String), + ), + headersCode = Code.HeadersCode(List(Code.HeaderCode("accept"), Code.HeaderCode("content-type"))), + inCode = Code.InCode("Unit"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("request body and empty response") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[User] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + Map( + Code.Field("post") -> Code.EndpointCode( + Method.POST, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("User"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("request body and empty response with int path parameter") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users" / int("userId")).in[User] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "users", "UserId.scala"), + pkgPath = List("api", "v1", "users"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "UserId", + Map( + Code.Field("post") -> Code.EndpointCode( + Method.POST, + Code.PathPatternCode(segments = + List( + Code.PathSegmentCode("api"), + Code.PathSegmentCode("v1"), + Code.PathSegmentCode("users"), + Code.PathSegmentCode("userId", Code.CodecType.Int), + ), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("User"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("request body and empty response with path parameter and query parameters") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users" / int("userId")) + .in[User] + .query(queryInt("limit")) + .query(query("name")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "users", "UserId.scala"), + pkgPath = List("api", "v1", "users"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "UserId", + Map( + Code.Field("post") -> Code.EndpointCode( + Method.POST, + Code.PathPatternCode(segments = + List( + Code.PathSegmentCode("api"), + Code.PathSegmentCode("v1"), + Code.PathSegmentCode("users"), + Code.PathSegmentCode("userId", Code.CodecType.Int), + ), + ), + queryParamsCode = Set( + Code.QueryParamCode("limit", Code.CodecType.Int), + Code.QueryParamCode("name", Code.CodecType.String), + ), + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("User"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("request body and empty response with path parameter and query parameters and headers") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users" / int("userId")) + .in[User] + .query(queryInt("limit")) + .query(query("name")) + .header(HeaderCodec.accept) + .header(HeaderCodec.contentType) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "users", "UserId.scala"), + pkgPath = List("api", "v1", "users"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "UserId", + Map( + Code.Field("post") -> Code.EndpointCode( + Method.POST, + Code.PathPatternCode(segments = + List( + Code.PathSegmentCode("api"), + Code.PathSegmentCode("v1"), + Code.PathSegmentCode("users"), + Code.PathSegmentCode("userId", Code.CodecType.Int), + ), + ), + queryParamsCode = Set( + Code.QueryParamCode("limit", Code.CodecType.Int), + Code.QueryParamCode("name", Code.CodecType.String), + ), + headersCode = Code.HeadersCode(List(Code.HeaderCode("accept"), Code.HeaderCode("content-type"))), + inCode = Code.InCode("User"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + test("response and empty request") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").out[User] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Unit"), + outCodes = List(Code.OutCode.json("User", Status.Ok)), + errorsCode = Nil, + ), + ), + ), + ), + caseClasses = Nil, + enums = Nil, + ) + assertTrue(scala.files.head == expected) + }, + ), + suite("data gen spec")( + test("generates case class, companion object and schema") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").out[User] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("component", "User.scala"), + pkgPath = List("component"), + imports = List(Code.Import(name = "zio.schema._")), + objects = List.empty, + caseClasses = List( + Code.CaseClass( + "User", + fields = List( + Code.Field("id", Code.Primitive.ScalaInt), + Code.Field("name", Code.Primitive.ScalaString), + ), + companionObject = Some(Code.Object.schemaCompanion("User")), + ), + ), + Nil, + ) + assertTrue(scala.files.tail.head == expected) + }, + test("generates simple enum and schema") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").out[Direction] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("component", "Direction.scala"), + pkgPath = List("component"), + imports = List(Code.Import(name = "zio.schema._")), + objects = List.empty, + caseClasses = List.empty, + enums = List( + Code.Enum( + "Direction", + List( + Code.CaseClass("North"), + Code.CaseClass("South"), + Code.CaseClass("East"), + Code.CaseClass("West"), + ), + schema = true, + ), + ), + ) + assertTrue(scala.files.tail.head == expected) + }, + test("generates enum with values and schema") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").out[Payment] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("component", "Payment.scala"), + pkgPath = List("component"), + imports = List(Code.Import(name = "zio.schema._"), Code.Import(name = "zio.schema.annotation._")), + objects = List.empty, + caseClasses = List.empty, + enums = List( + Code.Enum( + name = "Payment", + cases = List( + Code.CaseClass( + "Card", + fields = List( + Code.Field("number", Code.Primitive.ScalaString), + Code.Field("cvv", Code.Primitive.ScalaString), + ), + companionObject = Some(Code.Object.schemaCompanion("Card")), + ), + Code.CaseClass( + "Cash", + fields = List( + Code.Field("amount", Code.Primitive.ScalaInt), + ), + companionObject = Some(Code.Object.schemaCompanion("Cash")), + ), + ), + caseNames = List("Card", "cash"), + schema = true, + ), + ), + ) + assertTrue(scala.files.last == expected) + }, + test("generates enum with values and schema with named discriminator") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").out[PaymentNamedDiscriminator] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("component", "PaymentNamedDiscriminator.scala"), + pkgPath = List("component"), + imports = List(Code.Import(name = "zio.schema._"), Code.Import(name = "zio.schema.annotation._")), + objects = List.empty, + caseClasses = List.empty, + enums = List( + Code.Enum( + name = "PaymentNamedDiscriminator", + cases = List( + Code.CaseClass( + "Card", + fields = List( + Code.Field("number", Code.Primitive.ScalaString), + Code.Field("cvv", Code.Primitive.ScalaString), + ), + companionObject = Some(Code.Object.schemaCompanion("Card")), + ), + Code.CaseClass( + "Cash", + fields = List( + Code.Field("amount", Code.Primitive.ScalaInt), + ), + companionObject = Some(Code.Object.schemaCompanion("Cash")), + ), + ), + caseNames = List("Card", "cash"), + discriminator = Some("type"), + noDiscriminator = false, + schema = true, + ), + ), + ) + assertTrue(scala.files.last == expected) + }, + test("generates enum with values and schema with no discriminator") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").out[PaymentNoDiscriminator] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val scala = EndpointGen.fromOpenAPI(openAPI) + val expected = Code.File( + List("component", "PaymentNoDiscriminator.scala"), + pkgPath = List("component"), + imports = List(Code.Import(name = "zio.schema._"), Code.Import(name = "zio.schema.annotation._")), + objects = List.empty, + caseClasses = List.empty, + enums = List( + Code.Enum( + name = "PaymentNoDiscriminator", + cases = List( + Code.CaseClass( + "Card", + fields = List( + Code.Field("number", Code.Primitive.ScalaString), + Code.Field("cvv", Code.Primitive.ScalaString), + ), + companionObject = Some(Code.Object.schemaCompanion("Card")), + ), + Code.CaseClass( + "Cash", + fields = List( + Code.Field("amount", Code.Primitive.ScalaInt), + ), + companionObject = Some(Code.Object.schemaCompanion("Cash")), + ), + ), + caseNames = Nil, + discriminator = None, + noDiscriminator = true, + schema = true, + ), + ), + ) + assertTrue(scala.files.last == expected) + }, + test("generates case class for request with inlined schema") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[User] + val openAPI = OpenAPIGen.fromEndpoints("", "", Inline, endpoint).copy(components = None) + val scala = EndpointGen.fromOpenAPI(openAPI) + val fields = List( + Code.Field("id", Code.Primitive.ScalaInt), + Code.Field("name", Code.Primitive.ScalaString), + ) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + schema = false, + endpoints = Map( + Code.Field("post") -> Code.EndpointCode( + Method.POST, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("POST.RequestBody"), + outCodes = Nil, + errorsCode = Nil, + ), + ), + objects = List( + Code.Object( + "POST", + schema = false, + endpoints = Map.empty, + objects = Nil, + caseClasses = List( + Code.CaseClass( + "RequestBody", + fields = fields, + companionObject = Some(Code.Object.schemaCompanion("RequestBody")), + ), + ), + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ) + + assertTrue(scala.files.head == expected) + }, + test("generates case class for response with inlined schema") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users").out[User] + val openAPI = OpenAPIGen.fromEndpoints("", "", Inline, endpoint).copy(components = None) + val scala = EndpointGen.fromOpenAPI(openAPI) + val fields = List( + Code.Field("id", Code.Primitive.ScalaInt), + Code.Field("name", Code.Primitive.ScalaString), + ) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + schema = false, + endpoints = Map( + Code.Field("get") -> Code.EndpointCode( + Method.GET, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("Unit"), + outCodes = List(Code.OutCode.json("GET.ResponseBody", Status.Ok)), + errorsCode = Nil, + ), + ), + objects = List( + Code.Object( + "GET", + schema = false, + endpoints = Map.empty, + objects = Nil, + caseClasses = List( + Code.CaseClass( + "ResponseBody", + fields = fields, + companionObject = Some(Code.Object.schemaCompanion("ResponseBody")), + ), + ), + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ) + + assertTrue(scala.files.head == expected) + }, + test("generates case class for request and response with inlined schema") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[User].out[User] + val openAPI = OpenAPIGen.fromEndpoints("", "", Inline, endpoint).copy(components = None) + val scala = EndpointGen.fromOpenAPI(openAPI) + val fields = List( + Code.Field("id", Code.Primitive.ScalaInt), + Code.Field("name", Code.Primitive.ScalaString), + ) + val expected = Code.File( + List("api", "v1", "Users.scala"), + pkgPath = List("api", "v1"), + imports = List(Code.Import.FromBase(path = "component._")), + objects = List( + Code.Object( + "Users", + schema = false, + endpoints = Map( + Code.Field("post") -> Code.EndpointCode( + Method.POST, + Code.PathPatternCode(segments = + List(Code.PathSegmentCode("api"), Code.PathSegmentCode("v1"), Code.PathSegmentCode("users")), + ), + queryParamsCode = Set.empty, + headersCode = Code.HeadersCode.empty, + inCode = Code.InCode("POST.RequestBody"), + outCodes = List(Code.OutCode.json("POST.ResponseBody", Status.Ok)), + errorsCode = Nil, + ), + ), + objects = List( + Code.Object( + "POST", + schema = false, + endpoints = Map.empty, + objects = Nil, + caseClasses = List( + Code.CaseClass( + "RequestBody", + fields = fields, + companionObject = Some(Code.Object.schemaCompanion("RequestBody")), + ), + Code.CaseClass( + "ResponseBody", + fields = fields, + companionObject = Some(Code.Object.schemaCompanion("ResponseBody")), + ), + ), + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ), + ), + caseClasses = Nil, + enums = Nil, + ) + + assertTrue(scala.files.head == expected) + }, + ), + ) + +} diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala new file mode 100644 index 0000000000..c109ba5584 --- /dev/null +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -0,0 +1,203 @@ +package zio.http.gen.scala + +import java.io.File +import java.nio.file._ + +import scala.jdk.CollectionConverters._ + +import zio.Scope +import zio.test._ + +import zio.http._ +import zio.http.codec._ +import zio.http.endpoint.Endpoint +import zio.http.endpoint.openapi.{OpenAPI, OpenAPIGen} +import zio.http.gen.model._ +import zio.http.gen.openapi.EndpointGen + +object CodeGenSpec extends ZIOSpecDefault { + + private def fileShouldBe(dir: java.nio.file.Path, subPath: String, expectedFile: String): TestResult = { + val filePath = dir.resolve(Paths.get(subPath)) + val generated = Files.readAllLines(filePath).asScala.mkString("\n") + val url = getClass.getResource(expectedFile) + val expected = java.nio.file.Paths.get(url.toURI.getPath) + val expectedLines = Files.readAllLines(expected).asScala.mkString("\n") + assertTrue(generated == expectedLines) + } + + private val java11OrNewer = { + val version = System.getProperty("java.version") + if (version.takeWhile(_ != '.').toInt >= 11) TestAspect.identity else TestAspect.ignore + } + + private val scalaFmtPath = java.nio.file.Paths.get(getClass.getResource("/scalafmt.conf").toURI) + + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("CodeGenSpec")( + test("Simple endpoint without data structures") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users") + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe(tempDir, "test/api/v1/Users.scala", "/UsersUnitInOut.scala") + }, + test("Endpoint with path parameters") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users" / int("userId")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe(tempDir, "test/api/v1/users/UserId.scala", "/UserIdUnitInOut.scala") + }, + test("Endpoint with query parameters") { + val endpoint = Endpoint(Method.GET / "api" / "v1" / "users") + .query(QueryCodec.queryInt("limit")) + .query(QueryCodec.query("name")) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe(tempDir, "test/api/v1/Users.scala", "/EndpointWithQueryParams.scala") + }, + test("Endpoint with headers") { + val endpoint = + Endpoint(Method.GET / "api" / "v1" / "users").header(HeaderCodec.accept).header(HeaderCodec.contentType) + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe(tempDir, "test/api/v1/Users.scala", "/EndpointWithHeaders.scala") + }, + test("Endpoint with request body") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[User] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe(tempDir, "test/api/v1/Users.scala", "/EndpointWithRequestBody.scala") && + fileShouldBe(tempDir, "test/component/User.scala", "/GeneratedUser.scala") + }, + test("Endpoint with response body") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").out[User] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe(tempDir, "test/api/v1/Users.scala", "/EndpointWithResponseBody.scala") && + fileShouldBe(tempDir, "test/component/User.scala", "/GeneratedUser.scala") + }, + test("Endpoint with request and response body") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[User].out[User] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe(tempDir, "test/api/v1/Users.scala", "/EndpointWithRequestResponseBody.scala") && + fileShouldBe(tempDir, "test/component/User.scala", "/GeneratedUser.scala") + }, + test("OpenAPI spec with inline schema request and response body") { + val openAPIString = + Files.readAllLines(Paths.get(getClass.getResource("/inline_schema.json").toURI)).asScala.mkString("\n") + val openAPI = OpenAPI.fromJson(openAPIString).getOrElse(OpenAPI.empty) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe(tempDir, "test/api/v1/Users.scala", "/EndpointWithRequestResponseBodyInline.scala") + } @@ TestAspect.exceptScala3, // for some reason, the temp dir is empty in Scala 3 + test("OpenAPI spec with inline schema request and response body, with nested object schema") { + val openAPIString = + Files.readAllLines(Paths.get(getClass.getResource("/inline_schema_nested.json").toURI)).asScala.mkString("\n") + val openAPI = OpenAPI.fromJson(openAPIString).getOrElse(OpenAPI.empty) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe(tempDir, "test/api/v1/Users.scala", "/EndpointWithRequestResponseBodyInlineNested.scala") + } @@ TestAspect.exceptScala3, // for some reason, the temp dir is empty in Scala 3 + test("Endpoint with enum input") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[Payment] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe( + tempDir, + "test/api/v1/Users.scala", + "/EndpointWithEnumInput.scala", + ) && + fileShouldBe( + tempDir, + "test/component/Payment.scala", + "/GeneratedPayment.scala", + ) + }, + test("Endpoint with enum input with named discriminator") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[PaymentNamedDiscriminator] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe( + tempDir, + "test/api/v1/Users.scala", + "/EndpointWithEnumInputNamedDiscriminator.scala", + ) && + fileShouldBe( + tempDir, + "test/component/PaymentNamedDiscriminator.scala", + "/GeneratedPaymentNamedDiscriminator.scala", + ) + }, + test("Endpoint with enum input no discriminator") { + val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[PaymentNoDiscriminator] + val openAPI = OpenAPIGen.fromEndpoints(endpoint) + val code = EndpointGen.fromOpenAPI(openAPI) + + val tempDir = Files.createTempDirectory("codegen") + CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) + + fileShouldBe( + tempDir, + "test/api/v1/Users.scala", + "/EndpointWithEnumInputNoDiscriminator.scala", + ) && + fileShouldBe( + tempDir, + "test/component/PaymentNoDiscriminator.scala", + "/GeneratedPaymentNoDiscriminator.scala", + ) + }, + ) @@ java11OrNewer +} diff --git a/zio-http/src/main/scala/zio/http/ConnectionPool.scala b/zio-http/src/main/scala/zio/http/ConnectionPool.scala index 6e9886f740..a3eed38b0e 100644 --- a/zio-http/src/main/scala/zio/http/ConnectionPool.scala +++ b/zio-http/src/main/scala/zio/http/ConnectionPool.scala @@ -26,6 +26,7 @@ trait ConnectionPool[Connection] { location: URL.Location.Absolute, proxy: Option[Proxy], sslOptions: ClientSSLConfig, + maxInitialLineLength: Int, maxHeaderSize: Int, decompression: Decompression, idleTimeout: Option[Duration], diff --git a/zio-http/src/main/scala/zio/http/Flash.scala b/zio-http/src/main/scala/zio/http/Flash.scala new file mode 100644 index 0000000000..57224e607a --- /dev/null +++ b/zio-http/src/main/scala/zio/http/Flash.scala @@ -0,0 +1,397 @@ +package zio.http + +import java.net.{URLDecoder, URLEncoder} +import java.nio.charset.StandardCharsets +import java.util.UUID + +import zio._ + +import zio.schema.Schema +import zio.schema.codec.JsonCodec + +import zio.http.template._ + +/** + * `Flash` represents a flash value that one can retrieve from the flash scope. + * + * The flash scope consists of a serialized and url-encoded json object built + * with `zio-schema`. + */ +sealed trait Flash[+A] { self => + + final def flatMap[B](f: A => Flash[B]): Flash[B] = Flash.FlatMap(self, f) + + final def map[B](f: A => B): Flash[B] = self.flatMap(a => Flash.succeed(f(a))) + + final def orElse[B >: A](that: => Flash[B]): Flash[B] = Flash.OrElse(self, that) + + /** + * Operator alias for `orElse`. + */ + final def <>[B >: A](that: => Flash[B]): Flash[B] = self.orElse(that) + + final def zip[B](that: => Flash[B]): Flash[(A, B)] = self.zipWith(that)((a, b) => a -> b) + + /** + * Operator alias for `zip`. + */ + final def <*>[B](that: => Flash[B]): Flash[(A, B)] = self.zip(that) + + final def zipWith[B, C](that: => Flash[B])(f: (A, B) => C): Flash[C] = + self.flatMap(a => that.map(b => f(a, b))) + + final def optional: Flash[Option[A]] = self.map(Option(_)) <> Flash.succeed(None) + + final def foldHtml[A1 >: A, B](f: Html => B, g: Html => B)(h: (B, B) => B)(implicit + ev: A1 =:= Flash.Message[Html, Html], + ): Flash[B] = + self.map(a => a.asInstanceOf[A1].fold(f, g)(h)) + + final def toHtml[A1 >: A](implicit ev: A1 =:= String): Flash[Html] = + self.map(Html.fromString(_)) + +} + +object Flash { + + /** + * A fash message can represent a notice, an alert or both - it's some kind of + * a specialized `zio.prelude.These`. + * + * Using a flash message allows one to categorize those into notice or alert + * and by that wrap both messages with a different ui design. + */ + sealed trait Message[+A, +B] { self => + + /** + * Folds a notice with `f` into `C`, an alert with `g` into `C` and both + * with `h` into another `C`. + */ + def fold[C](f: A => C, g: B => C)(h: (C, C) => C): C = this match { + case Message.Notice(a) => f(a) + case Message.Alert(b) => g(b) + case Message.Both(Message.Notice(a), Message.Alert(b)) => h(f(a), g(b)) + } + + /** + * Returns true if this `Message` represents both, a notice and and alert. + */ + def isBoth: Boolean = this match { + case Message.Both(_, _) => true + case _ => false + } + + /** + * Returns true if this `Message` represents a notice only. + */ + def isNotice = this match { + case Message.Notice(_) => true + case _ => false + } + + /** + * Returns true if this `Message` represents an alert only. + */ + def isAlert = this match { + case Message.Alert(_) => true + case _ => false + } + } + private[http] object Message { + case class Notice[+A](a: A) extends Message[A, Nothing] + private[http] object Notice { + val name = "notice" + } + case class Alert[+B](b: B) extends Message[Nothing, B] + private[http] object Alert { + val name = "alert" + } + private[http] case class Both[+A, +B](notice: Notice[A], alert: Alert[B]) extends Message[A, B] + } + + /** + * `Flash.Backend` represents a flash-scope that is not cookie-based but + * instead uses an internal structure. + * + * Semantically it is identical to the cookie-based flash-scope (valid for a + * single request) but by using `Flash.Backend` we're not limited in size of + * the payload as in the cookie-based flash-scope. Still, the `Flash.Backend` + * uses a cookie but does not transport the payload with it but only an + * internal identifier. + */ + trait Backend { self => + + /** + * Gets an `A` from the backend-based flash-scope or fails with a + * `Throwable`. + */ + def flash[A](request: Request, flash: Flash[A]): IO[Throwable, A] + + /** + * Gets an `A` from the backend-based flash-scope and provides a fallback. + */ + final def flashOrElse[A](request: Request, flash: Flash[A])(orElse: => A): UIO[A] = + self.flash(request, flash) <> ZIO.succeed(orElse) + + /** + * Adds flash values to the backend-based flash-scope and returns a workflow + * with an updated `Response`. + */ + def addFlash[A](response: Response, setter: Flash.Setter[A]): UIO[Response] + + /** + * Optionally adds flash values to the backend-based flash-scope and returns + * a workflow with an updated `Response`. + */ + final def addFlash[A](response: Response, setterOpt: Option[Flash.Setter[A]]): UIO[Response] = + setterOpt.fold(ZIO.succeed(response))(self.addFlash(response, _)) + } + + object Backend { + + private case class Impl(ref: Ref[Map[UUID, Map[String, String]]]) extends Backend { + override final def flash[A](request: Request, flash: Flash[A]): IO[Throwable, A] = + for { + flashId <- ZIO.from(Flash.run(Flash.getUUID(flashIdName), request)) + a <- ref.modify { map => + Flash.run(flash, map.get(flashId).getOrElse(Map.empty)) match { + case value @ Right(_) => value -> (map - flashId) + case value @ Left(_) => value -> map + } + }.flatMap(ZIO.from(_)) + } yield a + + override final def addFlash[A](response: Response, setter: Setter[A]): UIO[Response] = { + val map = Flash.Setter.run(setter, Map.empty) + for { + flashId <- zio.Random.nextUUID + setterFlashId <- ref.update(in => in + (flashId -> map)).as(Flash.setValue(flashIdName, flashId)) + } yield response.addFlash(setterFlashId) + } + } + + /** + * Provides a `Flash.Backend` based on a `Ref` in-memory. + */ + val inMemory: ULayer[Backend] = ZLayer(Ref.make(Map.empty[UUID, Map[String, String]]).map(Impl.apply)) + + private val flashIdName = "flashId" + + } + + sealed trait Setter[A] { self => + + /** + * Combines setting this flash value with another setter `that`. + */ + final def ++[B](that: => Setter[B]): Setter[(A, B)] = Setter.Concat(self, that) + } + + private[http] object Setter { + + case object Empty extends Flash.Setter[Unit] + + case class SetValue[A](schema: Schema[A], key: String, a: A) extends Flash.Setter[A] + + case class Concat[A, B](left: Setter[A], right: Setter[B]) extends Flash.Setter[(A, B)] + + def run[A](setter: Setter[A]): Cookie.Response = + Cookie.Response( + Flash.COOKIE_NAME, + URLEncoder.encode( + JsonCodec.jsonEncoder(Schema[Map[String, String]]).encodeJson(run(setter, Map.empty)).toString, + StandardCharsets.UTF_8.toString.toLowerCase, + ), + ) + + def run[A](setter: Setter[A], map: Map[String, String]): Map[String, String] = { + def loop[B](setter: Setter[B], map: Map[String, String]): Map[String, String] = + setter match { + case SetValue(schema, key, a) => + map.updated(key, JsonCodec.jsonEncoder(schema).encodeJson(a).toString) + case Concat(left, right) => + loop(right, loop(left, map)) + case Empty => map + } + loop(setter, map) + } + } + + /** + * Sets a flash value of type `A` with the given key `key`. + */ + def setValue[A: Schema](key: String, a: A): Setter[A] = Setter.SetValue(Schema[A], key, a) + + /** + * Sets a flash value of type `A` with the key for a notice. + */ + def setNotice[A: Schema](a: A): Setter[A] = Setter.SetValue(Schema[A], Message.Notice.name, a) + + /** + * Sets a flash value of type `A` with the key for an alert. + */ + def setAlert[A: Schema](a: A): Setter[A] = Setter.SetValue(Schema[A], Message.Alert.name, a) + + def setEmpty: Setter[Unit] = Setter.Empty + + private[http] val COOKIE_NAME = "zio-http-flash" + + private case class Get[A](schema: Schema[A], key: String) extends Flash[A] + + private case class FlatMap[A, B](self: Flash[A], f: A => Flash[B]) extends Flash[B] + + private case class OrElse[A, B >: A](self: Flash[A], that: Flash[B]) extends Flash[B] + + private case class WithInput[A](f: Map[String, String] => Flash[A]) extends Flash[A] + + private case class Succeed[A](a: A) extends Flash[A] + + private case class Fail(message: String) extends Flash[Nothing] + + def succeed[A](a: A): Flash[A] = Succeed(a) + + def fail(message: String): Flash[Nothing] = Fail(message) + + private def withInput[A](f: Map[String, String] => Flash[A]): Flash[A] = WithInput(f) + + private def getMessage[A: Schema, B: Schema]: Flash[Message[A, B]] = + getMessage(Flash.get[A](Message.Notice.name), Flash.get[B](Message.Alert.name)) + + /** + * Creates a `Flash.Message` from two other values `flashNotice` and + * `flashAlert`. + * + * Uses `flashNotice` to create a `Flash.Message` representing a notice. + * + * Uses `flashAlert` to create a `Flash.Message` representing an alert. + * + * If `flashNotice` and `flashAlert` are both available in the flash scope the + * resulting `Flash.Message` will represent both. + */ + def getMessage[A, B](flashNotice: Flash[A], flashAlert: Flash[B]): Flash[Message[A, B]] = + (flashNotice.optional <*> flashAlert.optional).flatMap { + case (Some(a), Some(b)) => Flash.succeed(Flash.Message.Both(Flash.Message.Notice(a), Flash.Message.Alert(b))) + case (Some(a), _) => Flash.succeed(Flash.Message.Notice(a)) + case (_, Some(b)) => Flash.succeed(Flash.Message.Alert(b)) + case _ => Flash.fail(s"neither '${Message.Notice.name}' nor '${Message.Alert.name}' do exist in the flash-scope") + } + + private def getMessageHtml[A: Schema, B: Schema](f: A => Html, g: B => Html): Flash[Message[Html, Html]] = + getMessage[A, B].map { + case Message.Notice(a) => Message.Notice(f(a)) + case Message.Alert(b) => Message.Alert(g(b)) + case Message.Both(Message.Notice(a), Message.Alert(b)) => Message.Both(Message.Notice(f(a)), Message.Alert(g(b))) + } + + /** + * Creates a `Flash.Message` using the default keys for notice and alert. + * + * Additionally the values must be of type `String` so they can be transformed + * to `Html`. + * + * Usage e.g.: `Flash.getMessageHtml.foldHtml(showNotice, showAlert)(_ ++ _)` + */ + def getMessageHtml: Flash[Message[Html, Html]] = + getMessageHtml[String, String](a => Dom.text(a), b => Dom.text(b)) + + /** + * Gets any flash value of type `A` with the given key `key`. + */ + def get[A: Schema](key: String): Flash[A] = Flash.Get(Schema[A], key) + + /** + * Gets a flash value of type `String` with the given key `key`. + */ + def getString(key: String): Flash[String] = get[String](key) + + /** + * Gets a flash value of type `A` associated with the notice key. + */ + def getNotice[A: Schema]: Flash[A] = get[A](Message.Notice.name) + + /** + * Gets a flash value of type `A` associated with the alert key. + */ + def getAlert[A: Schema]: Flash[A] = get[A](Message.Alert.name) + + /** + * Gets a flash value of type `Float` with the given key `key`. + */ + def getFloat(key: String): Flash[Float] = get[Float](key) + + /** + * Gets a flash value of type `Double` with the given key `key`. + */ + def getDouble(key: String): Flash[Double] = get[Double](key) + + /** + * Gets a flash value of type `Int` with the given key `key`. + */ + def getInt(key: String): Flash[Int] = get[Int](key) + + /** + * Gets a flash value of type `Long` with the given key `key`. + */ + def getLong(key: String): Flash[Long] = get[Long](key) + + /** + * Gets a flash value of type `UUID` with the given key `key`. + */ + def getUUID(key: String): Flash[UUID] = get[UUID](key) + + /** + * Gets a flash value of type `Boolean` with the given key `key`. + */ + def getBoolean(key: String): Flash[Boolean] = get[Boolean](key) + + /** + * Gets the first flash value of type `A` regardless of any key. + */ + def get[A: Schema]: Flash[A] = withInput { map => + map.keys.map(a => Flash.get(a)(Schema[A])).reduce(_ <> _) + } + + private[http] def run[A](flash: Flash[A], sourceRequest: Request): Either[Throwable, A] = + sourceRequest + .cookie(COOKIE_NAME) + .toRight(new RuntimeException("flash cookie doesn't exist")) + .flatMap { cookie => + try { + val content = + URLDecoder.decode(cookie.content, StandardCharsets.UTF_8.toString.toLowerCase) + JsonCodec.jsonDecoder(Schema.map[String, String]).decodeJson(content).left.map(e => new RuntimeException(e)) + } catch { + case e: Exception => Left(e) + } + } + .flatMap(in => run(flash, in)) + + private[http] def run[A](flash: Flash[A], sourceMap: Map[String, String]): Either[Throwable, A] = { + def loop[A](flash: Flash[A], map: Map[String, String]): Either[Throwable, A] = + flash match { + case Get(schema, key) => + map + .get(key) + .toRight(new RuntimeException(s"""flash key doesn't exist: "${key}" (existing flash keys: "${map.keys + .mkString(", ")}")""")) + .flatMap { value => + JsonCodec.jsonDecoder(schema).decodeJson(value).left.map(e => new RuntimeException(e)) + } + case WithInput(f) => + loop(f(map), map) + case OrElse(self, that) => + loop(self, map) match { + case Left(_) => loop(that, map) + case r @ Right(_) => r.asInstanceOf[Either[Throwable, A]] + } + case FlatMap(self, f) => + loop(self, map) match { + case Right(value) => loop(f(value), map) + case l @ Left(_) => l.asInstanceOf[Either[Throwable, A]] + } + case Succeed(a) => Right(a) + case Fail(message) => Left(new RuntimeException(message)) + } + loop(flash, sourceMap) + } +} diff --git a/zio-http/src/main/scala/zio/http/FormField.scala b/zio-http/src/main/scala/zio/http/FormField.scala index 55e5fc2b7e..b4699c2849 100644 --- a/zio-http/src/main/scala/zio/http/FormField.scala +++ b/zio-http/src/main/scala/zio/http/FormField.scala @@ -188,7 +188,7 @@ object FormField { private[http] def getContentType(ast: Chunk[FormAST]): MediaType = ast.collectFirst { - case header: FormAST.Header if header.name == "Content-Type" => + case header: FormAST.Header if header.name.equalsIgnoreCase("Content-Type") => MediaType .forContentType(header.value) .getOrElse(MediaType.application.`octet-stream`) // Unknown content type defaults to binary @@ -200,13 +200,13 @@ object FormField { )(implicit trace: Trace): ZIO[Any, FormDecodingError, FormField] = { val extract = ast.foldLeft((Option.empty[FormAST.Header], Option.empty[FormAST.Header], Option.empty[FormAST.Header])) { - case (accum, header: FormAST.Header) if header.name == "Content-Disposition" => + case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Disposition") => (Some(header), accum._2, accum._3) - case (accum, header: FormAST.Header) if header.name == "Content-Type" => + case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Type") => (accum._1, Some(header), accum._3) - case (accum, header: FormAST.Header) if header.name == "Content-Transfer-Encoding" => + case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Transfer-Encoding") => (accum._1, accum._2, Some(header)) - case (accum, _) => accum + case (accum, _) => accum } for { diff --git a/zio-http/src/main/scala/zio/http/Handler.scala b/zio-http/src/main/scala/zio/http/Handler.scala index 52bb0ea89a..851a8be2af 100644 --- a/zio-http/src/main/scala/zio/http/Handler.scala +++ b/zio-http/src/main/scala/zio/http/Handler.scala @@ -367,6 +367,19 @@ sealed trait Handler[-R, +Err, -In, +Out] { self => ): Handler[R1, Err1, In, Out1] = self.foldHandler(err => Handler.fromZIO(f(err)), Handler.succeed(_)) + /** + * Transforms all failures of the handler effectfully except pure + * interruption. + */ + final def mapErrorCauseZIO[R1 <: R, Err1, Out1 >: Out]( + f: Cause[Err] => ZIO[R1, Err1, Out1], + )(implicit trace: Trace): Handler[R1, Err1, In, Out1] = + self.foldCauseHandler( + err => + if (err.isInterruptedOnly) Handler.failCause(err.asInstanceOf[Cause[Nothing]]) else Handler.fromZIO(f(err)), + Handler.succeed(_), + ) + /** * Returns a new handler where the error channel has been merged into the * success channel to their common combined type. diff --git a/zio-http/src/main/scala/zio/http/Header.scala b/zio-http/src/main/scala/zio/http/Header.scala index e47b3476a2..5be397f714 100644 --- a/zio-http/src/main/scala/zio/http/Header.scala +++ b/zio-http/src/main/scala/zio/http/Header.scala @@ -30,6 +30,7 @@ import scala.util.{Either, Failure, Success, Try} import zio._ import zio.http.codec.RichTextCodec +import zio.http.endpoint.openapi.OpenAPI.SecurityScheme.Http import zio.http.internal.DateEncoding sealed trait Header { @@ -2480,16 +2481,12 @@ object Header { private val codec: RichTextCodec[ContentType] = { // char `.` according to BNF not allowed as `token`, but here tolerated - val token = RichTextCodec.filter(_ => true).validate("not a token") { - case ' ' | '(' | ')' | '<' | '>' | '@' | ',' | ';' | ':' | '\\' | '"' | '/' | '[' | ']' | '?' | '=' => false - case _ => true - } - val tokenQuoted = RichTextCodec.filter(_ => true).validate("not a quoted token") { - case ' ' | '"' => false - case _ => true - } + val token = RichTextCodec.charsNot(' ', '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=') + + val tokenQuoted = RichTextCodec.charsNot(' ', '"') + val type1 = RichTextCodec.string.collectOrFail("unsupported main type") { - case value if MediaType.mainTypeMap.get(value).isDefined => value + case value if MediaType.mainTypeMap.contains(value) => value } val type1x = (RichTextCodec.literalCI("x-") ~ token.repeat.string).transform[String](in => s"${in._1}${in._2}")(in => ("x-", s"${in.substring(2)}")) val codecType1 = (type1 | type1x).transform[String](_.merge) { @@ -4176,8 +4173,9 @@ object Header { 2xx warn-codes describe some aspect of the representation that is not rectified by a validation and will not be deleted by a cache after validation unless a full response is sent. */ - val warnCode: Int = Try { - Integer.parseInt(warningString.split(" ")(0)) + val warnCodeString = warningString.split(" ")(0) + val warnCode: Int = Try { + Integer.parseInt(warnCodeString) }.getOrElse(-1) /* @@ -4190,11 +4188,11 @@ object Header { An advisory text describing the error. */ - val descriptionStartIndex = warningString.indexOf('\"') - val descriptionEndIndex = warningString.indexOf("\"", warningString.indexOf("\"") + 1) + val descriptionStartIndex = warningString.indexOf('\"', warnCodeString.length + warnAgent.length) + 1 + val descriptionEndIndex = warningString.indexOf("\"", descriptionStartIndex) val description = Try { - warningString.substring(descriptionStartIndex, descriptionEndIndex + 1) + warningString.substring(descriptionStartIndex, descriptionEndIndex) }.getOrElse("") /* @@ -4252,17 +4250,16 @@ object Header { def render(warning: Warning): String = warning match { - case Warning(code, agent, text, date) => { + case Warning(code, agent, text, date) => val formattedDate = date match { case Some(value) => DateEncoding.default.encodeDate(value) case None => "" } if (formattedDate.isEmpty) { - code.toString + " " + agent + " " + text + code.toString + " " + agent + " " + '"' + text + '"' } else { - code.toString + " " + agent + " " + text + " " + '"' + formattedDate + '"' + code.toString + " " + agent + " " + '"' + text + '"' + " " + '"' + formattedDate + '"' } - } } } diff --git a/zio-http/src/main/scala/zio/http/Middleware.scala b/zio-http/src/main/scala/zio/http/Middleware.scala index 5f5a9212e7..4dfb35da6b 100644 --- a/zio-http/src/main/scala/zio/http/Middleware.scala +++ b/zio-http/src/main/scala/zio/http/Middleware.scala @@ -16,12 +16,13 @@ package zio.http import java.io.File +import java.net.URLEncoder import zio._ import zio.metrics._ -import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.http.codec.{PathCodec, SegmentCodec} +import zio.http.endpoint.openapi.OpenAPI trait Middleware[-UpperEnv] { self => def apply[Env1 <: UpperEnv, Err]( @@ -170,6 +171,68 @@ object Middleware extends HandlerAspects { } } + def logAnnotate(key: => String, value: => String)(implicit trace: Trace): Middleware[Any] = + logAnnotate(LogAnnotation(key, value)) + + def logAnnotate(logAnnotation: => LogAnnotation, logAnnotations: LogAnnotation*)(implicit + trace: Trace, + ): Middleware[Any] = + logAnnotate((logAnnotation +: logAnnotations).toSet) + + def logAnnotate(logAnnotations: => Set[LogAnnotation])(implicit trace: Trace): Middleware[Any] = + new Middleware[Any] { + def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = + routes.transform[Env1] { h => + handler((req: Request) => ZIO.logAnnotate(logAnnotations)(h(req))) + } + } + + /** + * Creates a middleware that will annotate log messages that are logged while + * a request is handled with log annotations derived from the request. + */ + def logAnnotate(fromRequest: Request => Set[LogAnnotation])(implicit trace: Trace): Middleware[Any] = + new Middleware[Any] { + def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = + routes.transform[Env1] { h => + handler((req: Request) => ZIO.logAnnotate(fromRequest(req))(h(req))) + } + } + + /** + * Creates a middleware that will annotate log messages that are logged while + * a request is handled with the names and the values of the specified + * headers. + */ + def logAnnotateHeaders(headerName: String, headerNames: String*)(implicit trace: Trace): Middleware[Any] = + new Middleware[Any] { + def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = { + val headers = headerName +: headerNames + routes.transform[Env1] { h => + handler((req: Request) => { + val annotations = Set.newBuilder[LogAnnotation] + annotations.sizeHint(headers.length) + var i = 0 + while (i < headers.length) { + val name = headers(i) + annotations += LogAnnotation(name, req.headers.get(name).mkString) + i += 1 + } + ZIO.logAnnotate(annotations.result())(h(req)) + }) + } + } + } + + /** + * Creates middleware that will annotate log messages that are logged while a + * request is handled with the names and the values of the specified headers. + */ + def logAnnotateHeaders(header: Header.HeaderType, headers: Header.HeaderType*)(implicit + trace: Trace, + ): Middleware[Any] = + logAnnotateHeaders(header.name, headers.map(_.name): _*) + def timeout(duration: Duration)(implicit trace: Trace): Middleware[Any] = new Middleware[Any] { def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = @@ -284,7 +347,8 @@ object Middleware extends HandlerAspects { } override def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = { - val mountpoint = Method.GET / path.segments.map(PathCodec.literal).reduceLeft(_ / _) + val mountpoint = + Method.GET / path.segments.map(PathCodec.literal).reduceLeftOption(_ / _).getOrElse(PathCodec.empty) val pattern = mountpoint / trailing val other = Routes( pattern -> Handler @@ -294,7 +358,7 @@ object Middleware extends HandlerAspects { if (isFishy) { Handler.fromZIO(ZIO.logWarning(s"fishy request detected: ${request.path.encode}")) *> Handler.badRequest } else { - val segs = pattern.pathCodec.segments.collect { case SegmentCodec.Literal(v, _) => + val segs = pattern.pathCodec.segments.collect { case SegmentCodec.Literal(v) => v } val unnest = segs.foldLeft(Path.empty)(_ / _).addLeadingSlash @@ -339,7 +403,7 @@ object Middleware extends HandlerAspects { * Creates a middleware for managing the flash scope. */ def flashScopeHandling: HandlerAspect[Any, Unit] = Middleware.intercept { (req, resp) => - req.cookie("zio-http-flash").fold(resp)(flash => resp.addCookie(Cookie.clear(flash.name))) + req.cookie(Flash.COOKIE_NAME).fold(resp)(flash => resp.addCookie(Cookie.clear(flash.name))) } } diff --git a/zio-http/src/main/scala/zio/http/Path.scala b/zio-http/src/main/scala/zio/http/Path.scala index 81c7f4f68a..c859a6b61c 100644 --- a/zio-http/src/main/scala/zio/http/Path.scala +++ b/zio-http/src/main/scala/zio/http/Path.scala @@ -16,8 +16,6 @@ package zio.http -import scala.collection.mutable - import zio.{Chunk, ChunkBuilder} /** @@ -184,6 +182,58 @@ final case class Path private[http] (flags: Path.Flags, segments: Chunk[String]) else Path.empty } else self + /** + * RFC 3986 § 5.2.4 Remove Dot Segments + * @return + * the Path with `.` and `..` resolved and removed + */ + def removeDotSegments: Path = { + // See https://www.rfc-editor.org/rfc/rfc3986#section-5.2.4 + val segments = new Array[String](self.segments.length) + var segmentCount = 0 + // leading/trailing slashes may change but is unlikely + var flags = self.flags + + var i = 0 + val max = self.segments.length + + if (!Flag.LeadingSlash.check(flags)) { + // § 5.2.4.2.A/D no leading slash, so skip all initial `./` and `../` + while (i < max && (self.segments(i) == "." | self.segments(i) == "..")) { + i += 1 + } + // if the entire input was consumed, there is no more trailing slash + if (i == max) flags = Flag.TrailingSlash.remove(flags) + } + + var loop = i < max + while (loop) { + val segment = self.segments(i) + + i += 1 + loop = i < max + + if (segment == "..") { + segmentCount = (segmentCount - 1).max(0) + // § 5.2.4.2.C resolving `/..` and `/../` removes preceding slashes and is itself replaced by a slash + // so if we popped the first one we definitely have a leading slash + if (segmentCount == 0) flags = Flag.LeadingSlash.add(flags) + // § 5.2.4.2.C resolving `/..` and `/../` are both as-if replaced by a `/` + // so if this is the last segment, then we have a trailing slash + if (i == max) flags = Flag.TrailingSlash.add(flags) + } else if (segment == ".") { + // § 5.2.4.2.B resolving `/.` and `/./` are both as-if replaced by a `/` + // so if this is the last segment, then we have a trailing slash + if (i == max) flags = Flag.TrailingSlash.add(flags) + } else { + segments(segmentCount) = segment + segmentCount += 1 + } + } + + Path(flags, Chunk.fromArray(segments.take(segmentCount))) + } + /** * Creates a new path from this one with it's segments reversed. */ diff --git a/zio-http/src/main/scala/zio/http/QueryParams.scala b/zio-http/src/main/scala/zio/http/QueryParams.scala index cd9367c941..ada32977b4 100644 --- a/zio-http/src/main/scala/zio/http/QueryParams.scala +++ b/zio-http/src/main/scala/zio/http/QueryParams.scala @@ -18,9 +18,9 @@ package zio.http import java.nio.charset.Charset -import zio.Chunk +import zio.{Chunk, IO, NonEmptyChunk, ZIO} -import zio.http.Charsets +import zio.http.codec.TextCodec import zio.http.internal.QueryParamEncoding /** @@ -88,11 +88,44 @@ final case class QueryParams(map: Map[String, Chunk[String]]) { */ def getAll(key: String): Option[Chunk[String]] = map.get(key) + /** + * Retrieves all typed query parameter values having the specified name. + */ + def getAllAs[A](key: String)(implicit codec: TextCodec[A]): Either[QueryParamsError, Chunk[A]] = for { + params <- map.get(key).toRight(QueryParamsError.Missing(key)) + (failed, typed) = params.partitionMap(p => codec.decode(p).toRight(p)) + result <- NonEmptyChunk + .fromChunk(failed) + .map(fails => QueryParamsError.Malformed(key, codec, fails)) + .toLeft(typed) + } yield result + + /** + * Retrieves all typed query parameter values having the specified name as + * ZIO. + */ + def getAllAsZIO[A](key: String)(implicit codec: TextCodec[A]): IO[QueryParamsError, Chunk[A]] = + ZIO.fromEither(getAllAs[A](key)) + /** * Retrieves the first query parameter value having the specified name. */ def get(key: String): Option[String] = getAll(key).flatMap(_.headOption) + /** + * Retrieves the first typed query parameter value having the specified name. + */ + def getAs[A](key: String)(implicit codec: TextCodec[A]): Either[QueryParamsError, A] = for { + param <- get(key).toRight(QueryParamsError.Missing(key)) + typedParam <- codec.decode(param).toRight(QueryParamsError.Malformed(key, codec, NonEmptyChunk(param))) + } yield typedParam + + /** + * Retrieves the first typed query parameter value having the specified name + * as ZIO. + */ + def getAsZIO[A](key: String)(implicit codec: TextCodec[A]): IO[QueryParamsError, A] = ZIO.fromEither(getAs[A](key)) + /** * Retrieves all query parameter values having the specified name, or else * uses the default iterable. @@ -100,6 +133,13 @@ final case class QueryParams(map: Map[String, Chunk[String]]) { def getAllOrElse(key: String, default: => Iterable[String]): Chunk[String] = getAll(key).getOrElse(Chunk.fromIterable(default)) + /** + * Retrieves all query parameter values having the specified name, or else + * uses the default iterable. + */ + def getAllAsOrElse[A](key: String, default: => Iterable[A])(implicit codec: TextCodec[A]): Chunk[A] = + getAllAs[A](key).getOrElse(Chunk.fromIterable(default)) + /** * Retrieves the first query parameter value having the specified name, or * else uses the default value. @@ -107,6 +147,13 @@ final case class QueryParams(map: Map[String, Chunk[String]]) { def getOrElse(key: String, default: => String): String = get(key).getOrElse(default) + /** + * Retrieves the first typed query parameter value having the specified name, + * or else uses the default value. + */ + def getAsOrElse[A](key: String, default: => A)(implicit codec: TextCodec[A]): A = + getAs[A](key).getOrElse(default) + override def hashCode: Int = normalize.map.hashCode /** diff --git a/zio-http/src/main/scala/zio/http/QueryParamsError.scala b/zio-http/src/main/scala/zio/http/QueryParamsError.scala new file mode 100644 index 0000000000..c74e02bdce --- /dev/null +++ b/zio-http/src/main/scala/zio/http/QueryParamsError.scala @@ -0,0 +1,42 @@ +/* + * Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.http + +import java.nio.charset.Charset + +import scala.util.control.NoStackTrace + +import zio.{Chunk, NonEmptyChunk} + +import zio.http.codec.TextCodec +import zio.http.internal.QueryParamEncoding + +sealed trait QueryParamsError extends Exception with NoStackTrace { + override def getMessage(): String = message + def message: String +} +object QueryParamsError { + final case class Missing(name: String) extends QueryParamsError { + def message = s"Missing query parameter with name $name" + } + + final case class Malformed(name: String, codec: TextCodec[_], values: NonEmptyChunk[String]) + extends QueryParamsError { + def message: String = + s"Unable to decode query parameter $name with values [ ${values.mkString(", ")} ] using ${codec.describe} codec" + } +} diff --git a/zio-http/src/main/scala/zio/http/Request.scala b/zio-http/src/main/scala/zio/http/Request.scala index b258a4d180..1cec9fbbfc 100644 --- a/zio-http/src/main/scala/zio/http/Request.scala +++ b/zio-http/src/main/scala/zio/http/Request.scala @@ -18,8 +18,7 @@ package zio.http import java.net.InetAddress -import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.{Chunk, Trace, ZIO} +import zio._ import zio.http.internal.HeaderOps @@ -148,8 +147,11 @@ final case class Request( def cookies: Chunk[Cookie] = header(Header.Cookie).fold(Chunk.empty[Cookie])(_.value.toChunk) - def flashMessage: Option[String] = - cookie("zio-http-flash").map(_.content) + /** + * Returns an `A` if it exists from the cookie-based flash-scope. + */ + def flash[A](flash: Flash[A]): Option[A] = + Flash.run(flash, self).toOption } diff --git a/zio-http/src/main/scala/zio/http/Response.scala b/zio-http/src/main/scala/zio/http/Response.scala index 4a630f4d54..bba283a2f8 100644 --- a/zio-http/src/main/scala/zio/http/Response.scala +++ b/zio-http/src/main/scala/zio/http/Response.scala @@ -20,8 +20,7 @@ import java.nio.file.{AccessDeniedException, NotDirectoryException} import scala.annotation.tailrec -import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.{Cause, Task, Trace, ZIO} +import zio._ import zio.stream.ZStream @@ -40,8 +39,11 @@ final case class Response( def addCookie(cookie: Cookie.Response): Response = self.copy(headers = self.headers ++ Headers(Header.SetCookie(cookie))) - def addFlashMessage(message: String): Response = - addCookie(Cookie.Response("zio-http-flash", message)) + /** + * Adds flash values to the cookie-based flash-scope. + */ + def addFlash[A](setter: Flash.Setter[A]): Response = + self.addCookie(Flash.Setter.run(setter).copy(path = Some(Path.root))) /** * Collects the potentially streaming body of the response into a single @@ -136,7 +138,7 @@ object Response { val message2 = OutputEncoder.encodeHtml(if (message == null) status.text else message) - Response(status = status, headers = Headers(Header.Warning(status.code, "ZIO HTTP", message2))) + Response(status = status, headers = Headers(Header.Warning(199, "ZIO HTTP", message2))) } def error(status: Status.Error): Response = diff --git a/zio-http/src/main/scala/zio/http/Route.scala b/zio-http/src/main/scala/zio/http/Route.scala index 0e0ec8334e..6e3fae680a 100644 --- a/zio-http/src/main/scala/zio/http/Route.scala +++ b/zio-http/src/main/scala/zio/http/Route.scala @@ -16,9 +16,8 @@ package zio.http import zio._ -import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.http.Route.Provided +import zio.http.codec.PathCodec /* * Represents a single route, which has either handled its errors by converting @@ -47,14 +46,16 @@ sealed trait Route[-Env, +Err] { self => def asErrorType[Err2](implicit ev: Err <:< Err2): Route[Env, Err2] = self.asInstanceOf[Route[Env, Err2]] /** - * Handles the error of the route. This method can be used to convert a route - * that does not handle its errors into one that does handle its errors. + * Handles all typed errors in the route by converting them into responses. + * This method can be used to convert a route that does not handle its errors + * into one that does handle its errors. */ final def handleError(f: Err => Response)(implicit trace: Trace): Route[Env, Nothing] = self.handleErrorCause(Response.fromCauseWith(_)(f)) /** - * Handles the error of the route. This method can be used to convert a route + * Handles all typed errors, as well as all non-recoverable errors, by + * converting them into responses. This method can be used to convert a route * that does not handle its errors into one that does handle its errors. */ final def handleErrorCause(f: Cause[Err] => Response)(implicit trace: Trace): Route[Env, Nothing] = @@ -83,6 +84,120 @@ sealed trait Route[-Env, +Err] { self => Handled(rpm.routePattern, handler2, location) } + /** + * Handles all typed errors, as well as all non-recoverable errors, by + * converting them into a ZIO effect that produces the response. This method + * can be used to convert a route that does not handle its errors into one + * that does handle its errors. + */ + final def handleErrorCauseZIO( + f: Cause[Err] => ZIO[Any, Nothing, Response], + )(implicit trace: Trace): Route[Env, Nothing] = + self match { + case Provided(route, env) => Provided(route.handleErrorCauseZIO(f), env) + case Augmented(route, aspect) => Augmented(route.handleErrorCauseZIO(f), aspect) + case Handled(routePattern, handler, location) => Handled(routePattern, handler, location) + + case Unhandled(rpm, handler, zippable, location) => + val handler2: Handler[Env, Response, Request, Response] = { + val paramHandler = + Handler.fromFunctionZIO[(rpm.Context, Request)] { case (ctx, request) => + rpm.routePattern.decode(request.method, request.path) match { + case Left(error) => ZIO.dieMessage(error) + case Right(value) => + val params = rpm.zippable.zip(value, ctx) + + handler(zippable.zip(params, request)) + } + } + rpm.aspect.applyHandlerContext(paramHandler.mapErrorCauseZIO(f)) + } + + Handled(rpm.routePattern, handler2, location) + } + + /** + * Handles all typed errors in the route by converting them into responses, + * taking into account the request that caused the error. This method can be + * used to convert a route that does not handle its errors into one that does + * handle its errors. + */ + final def handleErrorRequest(f: (Err, Request) => Response)(implicit trace: Trace): Route[Env, Nothing] = + self.handleErrorRequestCause((request, cause) => Response.fromCauseWith(cause)(f(_, request))) + + /** + * Handles all typed errors, as well as all non-recoverable errors, by + * converting them into responses, taking into account the request that caused + * the error. This method can be used to convert a route that does not handle + * its errors into one that does handle its errors. + */ + final def handleErrorRequestCause(f: (Request, Cause[Err]) => Response)(implicit trace: Trace): Route[Env, Nothing] = + self match { + case Provided(route, env) => Provided(route.handleErrorRequestCause(f), env) + case Augmented(route, aspect) => Augmented(route.handleErrorRequestCause(f), aspect) + case Handled(routePattern, handler, location) => Handled(routePattern, handler, location) + + case Unhandled(rpm, handler, zippable, location) => + val handler2: Handler[Env, Response, Request, Response] = { + val paramHandler = + Handler.fromFunctionZIO[(rpm.Context, Request)] { case (ctx, request) => + rpm.routePattern.decode(request.method, request.path) match { + case Left(error) => ZIO.dieMessage(error) + case Right(value) => + val params = rpm.zippable.zip(value, ctx) + + handler(zippable.zip(params, request)) + } + } + + // Sandbox before applying aspect: + rpm.aspect.applyHandlerContext( + Handler.fromFunctionHandler[(rpm.Context, Request)] { case (_, req) => + paramHandler.mapErrorCause(f(req, _)) + }, + ) + } + + Handled(rpm.routePattern, handler2, location) + } + + /** + * Handles all typed errors, as well as all non-recoverable errors, by + * converting them into a ZIO effect that produces the response, taking into + * account the request that caused the error. This method can be used to + * convert a route that does not handle its errors into one that does handle + * its errors. + */ + final def handleErrorRequestCauseZIO( + f: (Request, Cause[Err]) => ZIO[Any, Nothing, Response], + )(implicit trace: Trace): Route[Env, Nothing] = + self match { + case Provided(route, env) => Provided(route.handleErrorRequestCauseZIO(f), env) + case Augmented(route, aspect) => Augmented(route.handleErrorRequestCauseZIO(f), aspect) + case Handled(routePattern, handler, location) => Handled(routePattern, handler, location) + + case Unhandled(rpm, handler, zippable, location) => + val handler2: Handler[Env, Response, Request, Response] = { + val paramHandler = + Handler.fromFunctionZIO[(rpm.Context, Request)] { case (ctx, request) => + rpm.routePattern.decode(request.method, request.path) match { + case Left(error) => ZIO.dieMessage(error) + case Right(value) => + val params = rpm.zippable.zip(value, ctx) + + handler(zippable.zip(params, request)) + } + } + rpm.aspect.applyHandlerContext( + Handler.fromFunctionHandler[(rpm.Context, Request)] { case (_, req) => + paramHandler.mapErrorCauseZIO(f(req, _)) + }, + ) + } + + Handled(rpm.routePattern, handler2, location) + } + /** * Determines if the route is defined for the specified request. */ @@ -94,6 +209,16 @@ sealed trait Route[-Env, +Err] { self => */ def location: Trace + def nest(prefix: PathCodec[Unit])(implicit ev: Err <:< Response): Route[Env, Err] = + self match { + case Provided(route, env) => Provided(route.nest(prefix), env) + case Augmented(route, aspect) => Augmented(route.nest(prefix), aspect) + case Handled(routePattern, handler, location) => Handled(routePattern.nest(prefix), handler, location) + + case Unhandled(rpm, handler, zippable, location) => + Unhandled(rpm.prefix(prefix), handler, zippable, location) + } + final def provideEnvironment(env: ZEnvironment[Env]): Route[Any, Err] = Route.Provided(self, env) @@ -103,6 +228,13 @@ sealed trait Route[-Env, +Err] { self => */ def routePattern: RoutePattern[_] + /** + * Applies the route to the specified request. The route must be defined for + * the request, or else this method will fail fatally. + */ + final def run(request: Request)(implicit trace: Trace): ZIO[Env, Either[Err, Response], Response] = + Routes(self).run(request) + /** * Returns a route that automatically translates all failures into responses, * using best-effort heuristics to determine the appropriate HTTP status code, @@ -196,6 +328,16 @@ object Route { Route.route[A, Env1](self)(handler) } + def prefix(path: PathCodec[Unit]): Builder[Env, A] = + new Builder[Env, A] { + type PathInput = self.PathInput + type Context = self.Context + + def routePattern: RoutePattern[PathInput] = self.routePattern.nest(path) + def aspect: HandlerAspect[Env, Context] = self.aspect + def zippable: Zippable.Out[PathInput, Context, A] = self.zippable + } + def provideEnvironment(env: ZEnvironment[Env]): Route.Builder[Any, A] = { implicit val z = zippable diff --git a/zio-http/src/main/scala/zio/http/RoutePattern.scala b/zio-http/src/main/scala/zio/http/RoutePattern.scala index 2dc33c00f9..f54cc50c98 100644 --- a/zio-http/src/main/scala/zio/http/RoutePattern.scala +++ b/zio-http/src/main/scala/zio/http/RoutePattern.scala @@ -122,6 +122,9 @@ final case class RoutePattern[A](method: Method, pathCodec: PathCodec[A]) { self */ def matches(method: Method, path: Path): Boolean = decode(method, path).isRight + def nest(prefix: PathCodec[Unit]): RoutePattern[A] = + copy(pathCodec = prefix ++ pathCodec) + /** * Renders the route pattern as a string. */ @@ -222,16 +225,16 @@ object RoutePattern { */ val any: RoutePattern[Path] = RoutePattern(Method.ANY, PathCodec.trailing) + def apply(method: Method, path: Path): RoutePattern[Unit] = + path.segments.foldLeft[RoutePattern[Unit]](fromMethod(method)) { (pathSpec, segment) => + pathSpec./[Unit](PathCodec.Segment(SegmentCodec.literal(segment))) + } + /** * Constructs a route pattern from a method and a path literal. To match * against any method, use [[zio.http.Method.ANY]]. The specified string may * contain path segments, which are separated by slashes. */ - def apply(method: Method, value: String): RoutePattern[Unit] = { - val path = Path(value) - - path.segments.foldLeft[RoutePattern[Unit]](fromMethod(method)) { (pathSpec, segment) => - pathSpec./[Unit](PathCodec.Segment(SegmentCodec.literal(segment))) - } - } + def apply(method: Method, pathString: String): RoutePattern[Unit] = + apply(method, Path(pathString)) } diff --git a/zio-http/src/main/scala/zio/http/Routes.scala b/zio-http/src/main/scala/zio/http/Routes.scala index bc8845d1d1..b8b953c871 100644 --- a/zio-http/src/main/scala/zio/http/Routes.scala +++ b/zio-http/src/main/scala/zio/http/Routes.scala @@ -16,7 +16,8 @@ package zio.http import zio._ -import zio.stacktracer.TracingImplicits.disableAutoTrace + +import zio.http.codec.PathCodec /** * Represents a collection of routes, each of which is defined by a pattern and @@ -57,6 +58,9 @@ final class Routes[-Env, +Err] private (val routes: Chunk[zio.http.Route[Env, Er def @@[Env1 <: Env](aspect: Middleware[Env1]): Routes[Env1, Err] = aspect(self) + def apply(request: Request)(implicit ev: Err <:< Response, trace: Trace): ZIO[Env, Response, Response] = + self.toHttpApp.apply(request) + def asEnvType[Env2](implicit ev: Env2 <:< Env): Routes[Env2, Err] = self.asInstanceOf[Routes[Env2, Err]] @@ -65,17 +69,62 @@ final class Routes[-Env, +Err] private (val routes: Chunk[zio.http.Route[Env, Er /** * Handles all typed errors in the routes by converting them into responses. + * This method can be used to convert routes that do not handle their errors + * into ones that do handle their errors. */ def handleError(f: Err => Response)(implicit trace: Trace): Routes[Env, Nothing] = new Routes(routes.map(_.handleError(f))) /** * Handles all typed errors, as well as all non-recoverable errors, by - * converting them into responses. + * converting them into responses. This method can be used to convert routes + * that do not handle their errors into ones that do handle their errors. */ def handleErrorCause(f: Cause[Err] => Response)(implicit trace: Trace): Routes[Env, Nothing] = new Routes(routes.map(_.handleErrorCause(f))) + /** + * Handles all typed errors, as well as all non-recoverable errors, by + * converting them into a ZIO effect that produces the response. This method + * can be used to convert routes that do not handle their errors into ones + * that do handle their errors. + */ + def handleErrorCauseZIO(f: Cause[Err] => ZIO[Any, Nothing, Response])(implicit trace: Trace): Routes[Env, Nothing] = + new Routes(routes.map(_.handleErrorCauseZIO(f))) + + def nest(prefix: PathCodec[Unit])(implicit trace: Trace, ev: Err <:< Response): Routes[Env, Err] = + new Routes(self.routes.map(_.nest(prefix))) + + /** + * Handles all typed errors in the routes by converting them into responses, + * taking into account the request that caused the error. This method can be + * used to convert routes that do not handle their errors into ones that do + * handle their errors. + */ + def handleErrorRequest(f: (Err, Request) => Response)(implicit trace: Trace): Routes[Env, Nothing] = + new Routes(routes.map(_.handleErrorRequest(f))) + + /** + * Handles all typed errors in the routes by converting them into responses, + * taking into account the request that caused the error. This method can be + * used to convert routes that do not handle their errors into ones that do + * handle their errors. + */ + def handleErrorRequestCause(f: (Request, Cause[Err]) => Response)(implicit trace: Trace): Routes[Env, Nothing] = + new Routes(routes.map(_.handleErrorRequestCause(f))) + + /** + * Handles all typed errors, as well as all non-recoverable errors, by + * converting them into a ZIO effect that produces the response, taking into + * account the request that caused the error. This method can be used to + * convert routes that do not handle their errors into ones that do handle + * their errors. + */ + def handleErrorRequestCauseZIO(f: (Request, Cause[Err]) => ZIO[Any, Nothing, Response])(implicit + trace: Trace, + ): Routes[Env, Nothing] = + new Routes(routes.map(_.handleErrorRequestCauseZIO(f))) + /** * Returns new routes that have each been provided the specified environment, * thus eliminating their requirement for any specific environment. @@ -83,6 +132,40 @@ final class Routes[-Env, +Err] private (val routes: Chunk[zio.http.Route[Env, Er def provideEnvironment(env: ZEnvironment[Env]): Routes[Any, Err] = new Routes(routes.map(_.provideEnvironment(env))) + def run(request: Request)(implicit trace: Trace): ZIO[Env, Either[Err, Response], Response] = { + + class RouteFailure[+Err](val err: Cause[Err]) extends Throwable(null, null, true, false) { + override def getMessage: String = err.unified.headOption.fold("")(_.message) + + override def getStackTrace(): Array[StackTraceElement] = + err.unified.headOption.fold[Chunk[StackTraceElement]](Chunk.empty)(_.trace).toArray + + override def getCause(): Throwable = + err.find { case Cause.Die(throwable, _) => throwable } + .orElse(err.find { case Cause.Fail(value: Throwable, _) => value }) + .orNull + + def fillSuppressed()(implicit unsafe: Unsafe): Unit = + if (getSuppressed().length == 0) { + err.unified.iterator.drop(1).foreach(unified => addSuppressed(unified.toThrowable)) + } + + override def toString = + err.prettyPrint + } + var routeFailure: RouteFailure[Err] = null + + handleErrorCauseZIO { cause => + routeFailure = new RouteFailure(cause) + ZIO.refailCause(Cause.die(routeFailure)) + } + .apply(request) + .mapErrorCause { + case Cause.Die(value: RouteFailure[_], _) if value == routeFailure => routeFailure.err.map(Left(_)) + case cause => cause.map(Right(_)) + } + } + /** * Returns new routes that automatically translate all failures into * responses, using best-effort heuristics to determine the appropriate HTTP diff --git a/zio-http/src/main/scala/zio/http/SSLConfig.scala b/zio-http/src/main/scala/zio/http/SSLConfig.scala index 876b59eef8..d681ae5d01 100644 --- a/zio-http/src/main/scala/zio/http/SSLConfig.scala +++ b/zio-http/src/main/scala/zio/http/SSLConfig.scala @@ -17,16 +17,32 @@ package zio.http import zio.Config -import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.http.SSLConfig._ -final case class SSLConfig(behaviour: HttpBehaviour, data: Data, provider: Provider) +sealed trait ClientAuth + +object ClientAuth { + case object Required extends ClientAuth + case object NoneClientAuth extends ClientAuth + case object Optional extends ClientAuth + +} + +final case class SSLConfig( + behaviour: HttpBehaviour, + data: Data, + provider: Provider, + clientAuth: Option[ClientAuth] = None, +) object SSLConfig { def apply(data: Data): SSLConfig = - new SSLConfig(HttpBehaviour.Redirect, data, Provider.JDK) + new SSLConfig(HttpBehaviour.Redirect, data, Provider.JDK, None) + + def apply(data: Data, clientAuth: ClientAuth): SSLConfig = + new SSLConfig(HttpBehaviour.Redirect, data, Provider.JDK, Some(clientAuth)) val config: Config[SSLConfig] = ( @@ -38,22 +54,41 @@ object SSLConfig { } def fromFile(certPath: String, keyPath: String): SSLConfig = - new SSLConfig(HttpBehaviour.Redirect, Data.FromFile(certPath, keyPath), Provider.JDK) + fromFile(HttpBehaviour.Redirect, certPath, keyPath) + + def fromFile(certPath: String, keyPath: String, clientAuth: ClientAuth): SSLConfig = + fromFile(HttpBehaviour.Redirect, certPath, keyPath, Some(clientAuth)) - def fromFile(behaviour: HttpBehaviour, certPath: String, keyPath: String): SSLConfig = - new SSLConfig(behaviour, Data.FromFile(certPath, keyPath), Provider.JDK) + def fromFile( + behaviour: HttpBehaviour, + certPath: String, + keyPath: String, + clientAuth: Option[ClientAuth] = None, + ): SSLConfig = + new SSLConfig(behaviour, Data.FromFile(certPath, keyPath), Provider.JDK, clientAuth) def fromResource(certPath: String, keyPath: String): SSLConfig = - new SSLConfig(HttpBehaviour.Redirect, Data.FromResource(certPath, keyPath), Provider.JDK) + fromResource(HttpBehaviour.Redirect, certPath, keyPath, None) - def fromResource(behaviour: HttpBehaviour, certPath: String, keyPath: String): SSLConfig = - new SSLConfig(behaviour, Data.FromResource(certPath, keyPath), Provider.JDK) + def fromResource(certPath: String, keyPath: String, clientAuth: ClientAuth): SSLConfig = + fromResource(HttpBehaviour.Redirect, certPath, keyPath, Some(clientAuth)) + + def fromResource( + behaviour: HttpBehaviour, + certPath: String, + keyPath: String, + clientAuth: Option[ClientAuth] = None, + ): SSLConfig = + new SSLConfig(behaviour, Data.FromResource(certPath, keyPath), Provider.JDK, clientAuth) def generate: SSLConfig = - new SSLConfig(HttpBehaviour.Redirect, Data.Generate, Provider.JDK) + generate(HttpBehaviour.Redirect, None) + + def generate(clientAuth: ClientAuth): SSLConfig = + generate(HttpBehaviour.Redirect, Some(clientAuth)) - def generate(behaviour: HttpBehaviour): SSLConfig = - new SSLConfig(behaviour, Data.Generate, Provider.JDK) + def generate(behaviour: HttpBehaviour, clientAuth: Option[ClientAuth] = None): SSLConfig = + new SSLConfig(behaviour, Data.Generate, Provider.JDK, clientAuth) sealed trait HttpBehaviour object HttpBehaviour { diff --git a/zio-http/src/main/scala/zio/http/Scheme.scala b/zio-http/src/main/scala/zio/http/Scheme.scala index bbb39803ca..fe3cde5b7c 100644 --- a/zio-http/src/main/scala/zio/http/Scheme.scala +++ b/zio-http/src/main/scala/zio/http/Scheme.scala @@ -21,13 +21,17 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace sealed trait Scheme { self => def encode: String = self match { - case Scheme.HTTP => "http" - case Scheme.HTTPS => "https" - case Scheme.WS => "ws" - case Scheme.WSS => "wss" + case Scheme.HTTP => "http" + case Scheme.HTTPS => "https" + case Scheme.WS => "ws" + case Scheme.WSS => "wss" + case Scheme.Custom(scheme) => scheme } - def isHttp: Boolean = !isWebSocket + def isHttp: Boolean = self match { + case Scheme.HTTP | Scheme.HTTPS => true + case _ => false + } def isWebSocket: Boolean = self match { case Scheme.WS => true @@ -35,38 +39,45 @@ sealed trait Scheme { self => case _ => false } - def isSecure: Boolean = self match { - case Scheme.HTTPS => true - case Scheme.WSS => true - case _ => false + def isSecure: Option[Boolean] = self match { + case Scheme.HTTPS | Scheme.WSS => Some(true) + case Scheme.HTTP | Scheme.WS => Some(false) + case _ => None } - def defaultPort: Int = self match { - case Scheme.HTTP => 80 - case Scheme.HTTPS => 443 - case Scheme.WS => 80 - case Scheme.WSS => 443 + /** default ports is only define for the Schemes: http, https, ws, wss */ + def defaultPort: Option[Int] = self match { + case Scheme.HTTP => Some(Scheme.defaultPortForHTTP) + case Scheme.HTTPS => Some(Scheme.defaultPortForHTTPS) + case Scheme.WS => Some(Scheme.defaultPortForWS) + case Scheme.WSS => Some(Scheme.defaultPortForWSS) + case Scheme.Custom(_) => None } + } -object Scheme { + +object Scheme { /** * Decodes a string to an Option of Scheme. Returns None in case of * null/non-valid Scheme + * + * The should be lowercase and follow this syntax: + * - Scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) */ def decode(scheme: String): Option[Scheme] = Option(unsafe.decode(scheme)(Unsafe.unsafe)) private[zio] object unsafe { def decode(scheme: String)(implicit unsafe: Unsafe): Scheme = { - if (scheme == null) null + if (scheme == null || scheme.isEmpty) null else - scheme.length match { - case 5 => Scheme.HTTPS - case 4 => Scheme.HTTP - case 3 => Scheme.WSS - case 2 => Scheme.WS - case _ => null + scheme match { + case "http" => HTTP + case "https" => HTTPS + case "ws" => WS + case "wss" => WSS + case custom => new Custom(custom.toLowerCase) {} } } } @@ -78,4 +89,15 @@ object Scheme { case object WS extends Scheme case object WSS extends Scheme + + /** + * @param scheme + * value MUST not be "http" "https" "ws" "wss" + */ + sealed abstract case class Custom private[http] (scheme: String) extends Scheme + + def defaultPortForHTTP = 80 + def defaultPortForHTTPS = 443 + def defaultPortForWS = 80 + def defaultPortForWSS = 443 } diff --git a/zio-http/src/main/scala/zio/http/Server.scala b/zio-http/src/main/scala/zio/http/Server.scala index ff540b68ce..4f86d66fa5 100644 --- a/zio-http/src/main/scala/zio/http/Server.scala +++ b/zio-http/src/main/scala/zio/http/Server.scala @@ -54,6 +54,7 @@ object Server { requestDecompression: Decompression, responseCompression: Option[ResponseCompressionConfig], requestStreaming: RequestStreaming, + maxInitialLineLength: Int, maxHeaderSize: Int, logWarningOnFatalError: Boolean, gracefulShutdownTimeout: Duration, @@ -112,6 +113,8 @@ object Server { */ def logWarningOnFatalError(enable: Boolean): Config = self.copy(logWarningOnFatalError = enable) + def maxInitialLineLength(initialLineLength: Int): Config = self.copy(maxInitialLineLength = initialLineLength) + /** * Configure the server to use `maxHeaderSize` value when encode/decode * headers. @@ -169,6 +172,7 @@ object Server { Decompression.config.nested("request-decompression").withDefault(Config.default.requestDecompression) ++ ResponseCompressionConfig.config.nested("response-compression").optional ++ RequestStreaming.config.nested("request-streaming").withDefault(Config.default.requestStreaming) ++ + zio.Config.int("max-initial-line-length").withDefault(Config.default.maxInitialLineLength) ++ zio.Config.int("max-header-size").withDefault(Config.default.maxHeaderSize) ++ zio.Config.boolean("log-warning-on-fatal-error").withDefault(Config.default.logWarningOnFatalError) ++ zio.Config.duration("graceful-shutdown-timeout").withDefault(Config.default.gracefulShutdownTimeout) ++ @@ -183,6 +187,7 @@ object Server { requestDecompression, responseCompression, requestStreaming, + maxInitialLineLength, maxHeaderSize, logWarningOnFatalError, gracefulShutdownTimeout, @@ -196,6 +201,7 @@ object Server { requestDecompression = requestDecompression, responseCompression = responseCompression, requestStreaming = requestStreaming, + maxInitialLineLength = maxInitialLineLength, maxHeaderSize = maxHeaderSize, logWarningOnFatalError = logWarningOnFatalError, gracefulShutdownTimeout = gracefulShutdownTimeout, @@ -211,6 +217,7 @@ object Server { requestDecompression = Decompression.No, responseCompression = None, requestStreaming = RequestStreaming.Disabled(1024 * 100), + maxInitialLineLength = 4096, maxHeaderSize = 8192, logWarningOnFatalError = true, gracefulShutdownTimeout = 10.seconds, diff --git a/zio-http/src/main/scala/zio/http/Status.scala b/zio-http/src/main/scala/zio/http/Status.scala index 28fb39c2d1..96fd397a9a 100644 --- a/zio-http/src/main/scala/zio/http/Status.scala +++ b/zio-http/src/main/scala/zio/http/Status.scala @@ -16,6 +16,8 @@ package zio.http +import scala.util.Try + import zio.Trace import zio.stacktracer.TracingImplicits.disableAutoTrace @@ -170,73 +172,68 @@ object Status { final case class Custom(override val code: Int) extends Status - def fromInt(code: Int): Option[Status] = { - - if (code < 100 || code > 599) { - None - } else { - - val status = code match { - case 100 => Status.Continue - case 101 => Status.SwitchingProtocols - case 102 => Status.Processing - case 200 => Status.Ok - case 201 => Status.Created - case 202 => Status.Accepted - case 203 => Status.NonAuthoritativeInformation - case 204 => Status.NoContent - case 205 => Status.ResetContent - case 206 => Status.PartialContent - case 207 => Status.MultiStatus - case 300 => Status.MultipleChoices - case 301 => Status.MovedPermanently - case 302 => Status.Found - case 303 => Status.SeeOther - case 304 => Status.NotModified - case 305 => Status.UseProxy - case 307 => Status.TemporaryRedirect - case 308 => Status.PermanentRedirect - case 400 => Status.BadRequest - case 401 => Status.Unauthorized - case 402 => Status.PaymentRequired - case 403 => Status.Forbidden - case 404 => Status.NotFound - case 405 => Status.MethodNotAllowed - case 406 => Status.NotAcceptable - case 407 => Status.ProxyAuthenticationRequired - case 408 => Status.RequestTimeout - case 409 => Status.Conflict - case 410 => Status.Gone - case 411 => Status.LengthRequired - case 412 => Status.PreconditionFailed - case 413 => Status.RequestEntityTooLarge - case 414 => Status.RequestUriTooLong - case 415 => Status.UnsupportedMediaType - case 416 => Status.RequestedRangeNotSatisfiable - case 417 => Status.ExpectationFailed - case 421 => Status.MisdirectedRequest - case 422 => Status.UnprocessableEntity - case 423 => Status.Locked - case 424 => Status.FailedDependency - case 425 => Status.UnorderedCollection - case 426 => Status.UpgradeRequired - case 428 => Status.PreconditionRequired - case 429 => Status.TooManyRequests - case 431 => Status.RequestHeaderFieldsTooLarge - case 500 => Status.InternalServerError - case 501 => Status.NotImplemented - case 502 => Status.BadGateway - case 503 => Status.ServiceUnavailable - case 504 => Status.GatewayTimeout - case 505 => Status.HttpVersionNotSupported - case 506 => Status.VariantAlsoNegotiates - case 507 => Status.InsufficientStorage - case 510 => Status.NotExtended - case 511 => Status.NetworkAuthenticationRequired - case _ => Status.Custom(code) - - } - Some(status) + def fromString(code: String): Option[Status] = + Try(code.toInt).toOption.map(fromInt) + + def fromInt(code: Int): Status = { + code match { + case 100 => Status.Continue + case 101 => Status.SwitchingProtocols + case 102 => Status.Processing + case 200 => Status.Ok + case 201 => Status.Created + case 202 => Status.Accepted + case 203 => Status.NonAuthoritativeInformation + case 204 => Status.NoContent + case 205 => Status.ResetContent + case 206 => Status.PartialContent + case 207 => Status.MultiStatus + case 300 => Status.MultipleChoices + case 301 => Status.MovedPermanently + case 302 => Status.Found + case 303 => Status.SeeOther + case 304 => Status.NotModified + case 305 => Status.UseProxy + case 307 => Status.TemporaryRedirect + case 308 => Status.PermanentRedirect + case 400 => Status.BadRequest + case 401 => Status.Unauthorized + case 402 => Status.PaymentRequired + case 403 => Status.Forbidden + case 404 => Status.NotFound + case 405 => Status.MethodNotAllowed + case 406 => Status.NotAcceptable + case 407 => Status.ProxyAuthenticationRequired + case 408 => Status.RequestTimeout + case 409 => Status.Conflict + case 410 => Status.Gone + case 411 => Status.LengthRequired + case 412 => Status.PreconditionFailed + case 413 => Status.RequestEntityTooLarge + case 414 => Status.RequestUriTooLong + case 415 => Status.UnsupportedMediaType + case 416 => Status.RequestedRangeNotSatisfiable + case 417 => Status.ExpectationFailed + case 421 => Status.MisdirectedRequest + case 422 => Status.UnprocessableEntity + case 423 => Status.Locked + case 424 => Status.FailedDependency + case 425 => Status.UnorderedCollection + case 426 => Status.UpgradeRequired + case 428 => Status.PreconditionRequired + case 429 => Status.TooManyRequests + case 431 => Status.RequestHeaderFieldsTooLarge + case 500 => Status.InternalServerError + case 501 => Status.NotImplemented + case 502 => Status.BadGateway + case 503 => Status.ServiceUnavailable + case 504 => Status.GatewayTimeout + case 505 => Status.HttpVersionNotSupported + case 506 => Status.VariantAlsoNegotiates + case 507 => Status.InsufficientStorage + case 510 => Status.NotExtended + case 511 => Status.NetworkAuthenticationRequired + case _ => Status.Custom(code) } } } diff --git a/zio-http/src/main/scala/zio/http/URL.scala b/zio-http/src/main/scala/zio/http/URL.scala index fe311dfe7a..281f842e1d 100644 --- a/zio-http/src/main/scala/zio/http/URL.scala +++ b/zio-http/src/main/scala/zio/http/URL.scala @@ -16,13 +16,13 @@ package zio.http -import java.net.{MalformedURLException, URI, URISyntaxException} +import java.net.{MalformedURLException, URI} import scala.util.Try import zio.Chunk -import zio.http.URL.{Fragment, Location, portFromScheme} +import zio.http.URL.{Fragment, Location} import zio.http.internal.QueryParamEncoding final case class URL( @@ -48,10 +48,10 @@ final case class URL( def /(segment: String): URL = self.copy(path = self.path / segment) def absolute(host: String): URL = - self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP))) + self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, None)) def absolute(scheme: Scheme, host: String, port: Int): URL = - self.copy(kind = URL.Location.Absolute(scheme, host, port)) + self.copy(kind = URL.Location.Absolute(scheme, host, Some(port))) def addLeadingSlash: URL = self.copy(path = path.addLeadingSlash) @@ -101,20 +101,25 @@ final case class URL( def host(host: String): URL = { val location = kind match { - case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP)) + case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, None) case abs: URL.Location.Absolute => abs.copy(host = host) } copy(kind = location) } + /** + * @return + * the location, the host name and the port. The port part is omitted if is + * the default port for the protocol. + */ def hostPort: Option[String] = kind match { - case URL.Location.Relative => None - case URL.Location.Absolute(scheme, host, port) => - Some( - if (port == portFromScheme(scheme)) host - else s"$host:$port", - ) + case URL.Location.Relative => None + case abs: URL.Location.Absolute => + abs.portIfNotDefault match { + case None => Some(abs.host) + case Some(customPort) => Some(s"${abs.host}:$customPort") + } } def isAbsolute: Boolean = self.kind match { @@ -140,8 +145,8 @@ final case class URL( def port(port: Int): URL = { val location = kind match { - case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", port) - case abs: URL.Location.Absolute => abs.copy(port = port) + case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", Some(port)) + case abs: URL.Location.Absolute => abs.copy(originalPort = Some(port)) } copy(kind = location) @@ -149,16 +154,17 @@ final case class URL( def port: Option[Int] = kind match { case URL.Location.Relative => None - case abs: URL.Location.Absolute => Option(abs.port) + case abs: URL.Location.Absolute => abs.originalPort } - def portOrDefault: Int = port.getOrElse(portFromScheme(scheme.getOrElse(Scheme.HTTP))) + def portOrDefault: Option[Int] = kind match { + case URL.Location.Relative => None + case abs: URL.Location.Absolute => abs.portOrDefault + } def portIfNotDefault: Option[Int] = kind match { - case URL.Location.Relative => - None - case abs: URL.Location.Absolute => - if (abs.port == portFromScheme(abs.scheme)) None else Some(abs.port) + case URL.Location.Relative => None + case abs: URL.Location.Absolute => abs.portIfNotDefault } def queryParams(queryParams: QueryParams): URL = @@ -178,6 +184,84 @@ final case class URL( case _ => self.copy(kind = URL.Location.Relative) } + /** + * RFC 3986 § 5.2 Relative Resolution + * @param reference + * the URL to resolve relative to ``this`` base URL + * @return + * the target URL + */ + def resolve(reference: URL): Either[String, URL] = { + // See https://www.rfc-editor.org/rfc/rfc3986#section-5.2 + // § 5.2.1 - `self` is the base and already pre-parsed into components + // § 5.2.2 - strict parsing does not ignore the reference URL scheme, so we use it directly, instead of un-setting it + + if (reference.kind.isRelative) { + // § 5.2.2 - reference scheme is undefined, i.e. it is relative + self.kind match { + // § 5.2.1 - `self` is the base and is required to have a scheme, therefore it must be absolute + case Location.Relative => Left("cannot resolve against relative url") + + case location: Location.Absolute => + var path: Path = null + var query: QueryParams = null + + if (reference.path.isEmpty) { + // § 5.2.2 - empty reference path keeps base path unmodified + path = self.path + // § 5.2.2 - given an empty reference path, use non-empty reference query params, + // while empty reference query params keeps base query params + // NOTE: strictly, if the reference defines a query it should be used, even if that query is empty + // but currently no-query is not differentiated from empty-query + if (reference.queryParams.isEmpty) { + query = self.queryParams + } else { + query = reference.queryParams + } + } else { + // § 5.2.2 - non-empty reference path always keeps reference query params + query = reference.queryParams + + if (reference.path.hasLeadingSlash) { + // § 5.2.2 - reference path starts from root, keep reference path without dot segments + path = reference.path.removeDotSegments + } else { + // § 5.2.2 - merge base and reference paths, then collapse dot segments + // § 5.2.3 - if base has an authority AND an empty path, use the reference path, ensuring a leading slash + // the authority is the [user]@host[:port], which is always present on `self`, + // so we only need to check for an empty path + if (self.path.isEmpty) { + path = reference.path.addLeadingSlash + } else { + // § 5.2.3 - otherwise (base has no authority OR a non-empty path), drop the very last portion of the base path, + // and append all the reference path components + path = Path( + Path.Flags.concat(self.path.flags, reference.path.flags), + self.path.segments.dropRight(1) ++ reference.path.segments, + ) + } + + path = path.removeDotSegments + } + } + + val url = URL(path, location, query, reference.fragment) + + Right(url) + + } + } else { + // § 5.2.2 - if the reference scheme is defined, i.e. the reference is absolute, + // the target components are the reference components but with dot segments removed + + // § 5.2.2 - if the reference scheme is undefined and authority is defined, keep the base scheme + // and take everything else from the reference, removing dot segments from the path + // NOTE: URL currently does not track authority separate from scheme to implement this + // so having an authority is the same as having a scheme and they are treated the same + Right(reference.copy(path = reference.path.removeDotSegments)) + } + } + def scheme: Option[Scheme] = kind match { case Location.Absolute(scheme, _, _) => Some(scheme) case Location.Relative => None @@ -185,7 +269,7 @@ final case class URL( def scheme(scheme: Scheme): URL = { val location = kind match { - case URL.Location.Relative => URL.Location.Absolute(scheme, "", URL.portFromScheme(scheme)) + case URL.Location.Relative => URL.Location.Absolute(scheme, "", None) case abs: URL.Location.Absolute => abs.copy(scheme = scheme) } @@ -239,7 +323,11 @@ object URL { } object Location { - final case class Absolute(scheme: Scheme, host: String, port: Int) extends Location + final case class Absolute(scheme: Scheme, host: String, originalPort: Option[Int]) extends Location { + def portOrDefault: Option[Int] = originalPort.orElse(scheme.defaultPort) + def portIfNotDefault: Option[Int] = originalPort.filter(p => scheme.defaultPort.exists(_ != p)) + def port: Int = originalPort.orElse(scheme.defaultPort).getOrElse(Scheme.defaultPortForHTTP) + } case object Relative extends Location } @@ -262,13 +350,13 @@ object URL { ) + url.fragment.fold("")(f => "#" + f.raw) url.kind match { - case Location.Relative => - path(true) - case Location.Absolute(scheme, host, port) => + case Location.Relative => path(true) + case abs: Location.Absolute => val path2 = path(false) - - if (port == portFromScheme(scheme)) s"${scheme.encode}://$host$path2" - else s"${scheme.encode}://$host:$port$path2" + abs.portIfNotDefault match { + case None => s"${abs.scheme.encode}://${abs.host}$path2" + case Some(customPort) => s"${abs.scheme.encode}://${abs.host}:$customPort$path2" + } } } @@ -277,7 +365,7 @@ object URL { scheme <- Scheme.decode(uri.getScheme) host <- Option(uri.getHost) path <- Option(uri.getRawPath) - port = Option(uri.getPort).filter(_ != -1).getOrElse(portFromScheme(scheme)) + port = Option(uri.getPort).filter(_ != -1).orElse(scheme.defaultPort) // FIXME REMOVE defaultPort connection = URL.Location.Absolute(scheme, host, port) path2 = Path.decode(path) path3 = if (path.nonEmpty) path2.addLeadingSlash else path2 @@ -288,9 +376,4 @@ object URL { path <- Option(uri.getRawPath) } yield URL(Path.decode(path), Location.Relative, QueryParams.decode(uri.getRawQuery), Fragment.fromURI(uri)) - private def portFromScheme(scheme: Scheme): Int = scheme match { - case Scheme.HTTP | Scheme.WS => 80 - case Scheme.HTTPS | Scheme.WSS => 443 - } - } diff --git a/zio-http/src/main/scala/zio/http/ZClient.scala b/zio-http/src/main/scala/zio/http/ZClient.scala index 477ee85cc9..0114a0813e 100644 --- a/zio-http/src/main/scala/zio/http/ZClient.scala +++ b/zio-http/src/main/scala/zio/http/ZClient.scala @@ -535,6 +535,7 @@ object ZClient { ssl: Option[ClientSSLConfig], proxy: Option[zio.http.Proxy], connectionPool: ConnectionPoolConfig, + maxInitialLineLength: Int, maxHeaderSize: Int, requestDecompression: Decompression, localAddress: Option[InetSocketAddress], @@ -557,6 +558,8 @@ object ZClient { def disabledConnectionPool: Config = self.copy(connectionPool = ConnectionPoolConfig.Disabled) + def maxInitialLineLength(initialLineLength: Int): Config = self.copy(maxInitialLineLength = initialLineLength) + /** * Configure the client to use `maxHeaderSize` value when encode/decode * headers. @@ -590,6 +593,7 @@ object ZClient { ClientSSLConfig.config.nested("ssl").optional.withDefault(Config.default.ssl) ++ zio.http.Proxy.config.nested("proxy").optional.withDefault(Config.default.proxy) ++ ConnectionPoolConfig.config.nested("connection-pool").withDefault(Config.default.connectionPool) ++ + zio.Config.int("max-initial-line-length").withDefault(Config.default.maxInitialLineLength) ++ zio.Config.int("max-header-size").withDefault(Config.default.maxHeaderSize) ++ Decompression.config.nested("request-decompression").withDefault(Config.default.requestDecompression) ++ zio.Config.boolean("add-user-agent-header").withDefault(Config.default.addUserAgentHeader) ++ @@ -600,6 +604,7 @@ object ZClient { ssl, proxy, connectionPool, + maxInitialLineLength, maxHeaderSize, requestDecompression, addUserAgentHeader, @@ -610,6 +615,7 @@ object ZClient { ssl = ssl, proxy = proxy, connectionPool = connectionPool, + maxInitialLineLength = maxInitialLineLength, maxHeaderSize = maxHeaderSize, requestDecompression = requestDecompression, addUserAgentHeader = addUserAgentHeader, @@ -622,6 +628,7 @@ object ZClient { ssl = None, proxy = None, connectionPool = ConnectionPoolConfig.Fixed(10), + maxInitialLineLength = 4096, maxHeaderSize = 8192, requestDecompression = Decompression.No, localAddress = None, @@ -668,18 +675,14 @@ object ZClient { app: WebSocketApp[Env1], )(implicit trace: Trace): ZIO[Env1 & Scope, Throwable, Response] = for { - env <- ZIO.environment[Env1] - webSocketUrl = url.scheme( - url.scheme match { - case Some(Scheme.HTTP) => Scheme.WS - case Some(Scheme.HTTPS) => Scheme.WSS - case Some(Scheme.WS) => Scheme.WS - case Some(Scheme.WSS) => Scheme.WSS - case None => Scheme.WS - }, - ) - scope <- ZIO.scope - res <- requestAsync( + env <- ZIO.environment[Env1] + webSocketUrl <- url.scheme match { + case Some(Scheme.HTTP) | Some(Scheme.WS) | None => ZIO.succeed(url.scheme(Scheme.WS)) + case Some(Scheme.WSS) | Some(Scheme.HTTPS) => ZIO.succeed(url.scheme(Scheme.WSS)) + case _ => ZIO.fail(throw new IllegalArgumentException("URL's scheme MUST be WS(S) or HTTP(S)")) + } + scope <- ZIO.scope + res <- requestAsync( Request(version = version, method = Method.GET, url = webSocketUrl, headers = headers), config, () => app.provideEnvironment(env), @@ -712,6 +715,7 @@ object ZClient { location, clientConfig.proxy, clientConfig.ssl.getOrElse(ClientSSLConfig.Default), + clientConfig.maxInitialLineLength, clientConfig.maxHeaderSize, clientConfig.requestDecompression, clientConfig.idleTimeout, diff --git a/zio-http/src/main/scala/zio/http/ZClientAspect.scala b/zio-http/src/main/scala/zio/http/ZClientAspect.scala index 82bf529893..bf7c011f2e 100644 --- a/zio-http/src/main/scala/zio/http/ZClientAspect.scala +++ b/zio-http/src/main/scala/zio/http/ZClientAspect.scala @@ -222,11 +222,12 @@ object ZClientAspect { case (duration, Exit.Success(response)) => ZIO .logLevel(level(response.status)) { - def requestHeaders = + def requestHeaders = headers.collect { case header: Header if loggedRequestHeaderNames.contains(header.headerName.toLowerCase) => LogAnnotation(header.headerName, header.renderedValue) }.toSet + def responseHeaders = response.headers.collect { case header: Header if loggedResponseHeaderNames.contains(header.headerName.toLowerCase) => @@ -318,4 +319,121 @@ object ZClientAspect { } } } + + final def followRedirects[R, E](max: Int)( + onRedirectError: (Response, String) => ZIO[R, E, Response], + )(implicit trace: Trace): ZClientAspect[Nothing, R, Nothing, Body, E, Any, Nothing, Response] = { + new ZClientAspect[Nothing, R, Nothing, Body, E, Any, Nothing, Response] { + override def apply[ + Env >: Nothing <: R, + In >: Nothing <: Body, + Err >: E <: Any, + Out >: Nothing <: Response, + ](client: ZClient[Env, In, Err, Out]): ZClient[Env, In, Err, Out] = { + val oldDriver = client.driver + + val newDriver = new ZClient.Driver[Env, Err] { + def scopedRedirectErr(resp: Response, message: String) = + ZIO.scopeWith(_ => onRedirectError(resp, message)) + + override def request( + version: Version, + method: Method, + url: URL, + headers: Headers, + body: Body, + sslConfig: Option[ClientSSLConfig], + proxy: Option[Proxy], + )(implicit trace: Trace): ZIO[Env & Scope, Err, Response] = { + def req( + attempt: Int, + version: Version, + method: Method, + url: URL, + headers: Headers, + body: Body, + sslConfig: Option[ClientSSLConfig], + proxy: Option[Proxy], + ): ZIO[Env & Scope, Err, Response] = { + oldDriver.request(version, method, url, headers, body, sslConfig, proxy).flatMap { resp => + if (resp.status.isRedirection) { + if (attempt < max) { + resp.headerOrFail(Header.Location) match { + case Some(locOrError) => + locOrError match { + case Left(locHeaderErr) => + scopedRedirectErr(resp, locHeaderErr) + + case Right(loc) => + url.resolve(loc.url) match { + case Left(relativeResolveErr) => + scopedRedirectErr(resp, relativeResolveErr) + + case Right(resolved) => + req(attempt + 1, version, method, resolved, headers, body, sslConfig, proxy) + } + } + case None => + scopedRedirectErr(resp, "no location header to resolve redirect") + } + } else { + scopedRedirectErr(resp, "followed maximum redirects") + } + } else { + ZIO.succeed(resp) + } + } + } + + req(0, version, method, url, headers, body, sslConfig, proxy) + } + + override def socket[Env1 <: Env](version: Version, url: URL, headers: Headers, app: WebSocketApp[Env1])( + implicit trace: Trace, + ): ZIO[Env1 & Scope, Err, Response] = { + def sock( + attempt: Int, + version: Version, + url: URL, + headers: Headers, + app: WebSocketApp[Env1], + ): ZIO[Env1 & Scope, Err, Response] = { + oldDriver.socket(version, url, headers, app).flatMap { resp => + if (resp.status.isRedirection) { + if (attempt < max) { + resp.headerOrFail(Header.Location) match { + case Some(locOrError) => + locOrError match { + case Left(locHeaderErr) => + scopedRedirectErr(resp, locHeaderErr) + + case Right(loc) => + url.resolve(loc.url) match { + case Left(relativeResolveErr) => + scopedRedirectErr(resp, relativeResolveErr) + + case Right(resolved) => + sock(attempt + 1, version, resolved, headers, app) + } + } + case None => + scopedRedirectErr(resp, "no location header to resolve redirect") + } + } else { + scopedRedirectErr(resp, "followed maximum redirects") + } + } else { + ZIO.succeed(resp) + } + } + } + + sock(0, version, url, headers, app) + } + } + + client.transform(client.bodyEncoder, client.bodyDecoder, newDriver) + } + } + } } diff --git a/zio-http/src/main/scala/zio/http/codec/Doc.scala b/zio-http/src/main/scala/zio/http/codec/Doc.scala index 4b3cf5a01c..af25f719bd 100644 --- a/zio-http/src/main/scala/zio/http/codec/Doc.scala +++ b/zio-http/src/main/scala/zio/http/codec/Doc.scala @@ -16,6 +16,11 @@ package zio.http.codec +import zio.Chunk +import zio.stacktracer.TracingImplicits.disableAutoTrace + +import zio.schema.Schema + import zio.http.codec.Doc.Span.CodeStyle import zio.http.template @@ -42,6 +47,13 @@ sealed trait Doc { self => case _ => false } + private[zio] def flattened: Chunk[Doc] = + self match { + case Doc.Empty => Chunk.empty + case Doc.Sequence(left, right) => left.flattened ++ right.flattened + case x => Chunk(x) + } + def toCommonMark: String = { val writer = new StringBuilder @@ -315,6 +327,12 @@ sealed trait Doc { self => } object Doc { + implicit val schemaDocSchema: Schema[Doc] = + Schema[String].transform( + fromCommonMark, + _.toCommonMark, + ) + def fromCommonMark(commonMark: String): Doc = Doc.Raw(commonMark, RawDocType.CommonMark) diff --git a/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala index 6e49a16009..a5e2147743 100644 --- a/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala @@ -18,6 +18,7 @@ package zio.http.codec import java.util.concurrent.ConcurrentHashMap +import scala.annotation.tailrec import scala.language.implicitConversions import scala.reflect.ClassTag @@ -192,6 +193,18 @@ sealed trait HttpCodec[-AtomTypes, Value] { ): Task[Value] = encoderDecoder(Chunk.empty).decode(url, status, method, headers, body) + def doc: Option[Doc] = { + @tailrec + def loop(codec: HttpCodec[_, _]): Option[Doc] = + codec match { + case Annotated(_, Metadata.Documented(doc)) => Some(doc) + case Annotated(codec, _) => loop(codec) + case _ => None + } + + loop(self) + } + /** * Uses this codec to encode the Scala value into a request. */ @@ -630,6 +643,8 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with final case class Examples[A](examples: Map[String, A]) extends Metadata[A] final case class Documented[A](doc: Doc) extends Metadata[A] + + final case class Deprecated[A](doc: Doc) extends Metadata[A] } private[http] final case class TransformOrFail[AtomType, X, A]( diff --git a/zio-http/src/main/scala/zio/http/codec/PathCodec.scala b/zio-http/src/main/scala/zio/http/codec/PathCodec.scala index ea56ec1905..dbb3fb470d 100644 --- a/zio-http/src/main/scala/zio/http/codec/PathCodec.scala +++ b/zio-http/src/main/scala/zio/http/codec/PathCodec.scala @@ -16,14 +16,12 @@ package zio.http.codec -import scala.annotation.tailrec import scala.collection.immutable.ListMap import scala.language.implicitConversions -import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.{Chunk, NonEmptyChunk} +import zio._ -import zio.http.Path +import zio.http._ /** * A codec for paths, which consists of segments, where each segment may be a @@ -50,11 +48,10 @@ sealed trait PathCodec[A] { self => final def /[B](that: PathCodec[B])(implicit combiner: Combiner[A, B]): PathCodec[combiner.Out] = self ++ that - /** - * Returns a new pattern that is extended with the specified segment pattern. - */ - final def /[B](segment: SegmentCodec[B])(implicit combiner: Combiner[A, B]): PathCodec[combiner.Out] = - self ++ Segment[B](segment) + final def /[Env](routes: Routes[Env, Response])(implicit + ev: PathCodec[A] <:< PathCodec[Unit], + ): Routes[Env, Response] = + routes.nest(ev(self)) final def asType[B](implicit ev: A =:= B): PathCodec[B] = self.asInstanceOf[PathCodec[B]] @@ -239,7 +236,7 @@ sealed trait PathCodec[A] { self => rightPath <- loop(right, rightValue) } yield leftPath ++ rightPath - case PathCodec.Segment(segment, _) => + case PathCodec.Segment(segment) => Right(segment.format(value.asInstanceOf[segment.Type])) case PathCodec.TransformOrFail(api, _, g) => @@ -264,16 +261,17 @@ sealed trait PathCodec[A] { self => private[http] def optimize: Array[Opt] = { def loop(pattern: PathCodec[_]): Chunk[Opt] = pattern match { - case PathCodec.Segment(segment, _) => + case PathCodec.Segment(segment) => Chunk(segment.asInstanceOf[SegmentCodec[_]] match { - case SegmentCodec.Empty(_) => Opt.Unit - case SegmentCodec.Literal(value, _) => Opt.Match(value) - case SegmentCodec.IntSeg(_, _) => Opt.IntOpt - case SegmentCodec.LongSeg(_, _) => Opt.LongOpt - case SegmentCodec.Text(_, _) => Opt.StringOpt - case SegmentCodec.UUID(_, _) => Opt.UUIDOpt - case SegmentCodec.BoolSeg(_, _) => Opt.BoolOpt - case SegmentCodec.Trailing(_) => Opt.TrailingOpt + case SegmentCodec.Empty => Opt.Unit + case SegmentCodec.Literal(value) => Opt.Match(value) + case SegmentCodec.IntSeg(_) => Opt.IntOpt + case SegmentCodec.LongSeg(_) => Opt.LongOpt + case SegmentCodec.Text(_) => Opt.StringOpt + case SegmentCodec.UUID(_) => Opt.UUIDOpt + case SegmentCodec.BoolSeg(_) => Opt.BoolOpt + case SegmentCodec.Trailing => Opt.TrailingOpt + case SegmentCodec.Annotated(codec, _) => loop(PathCodec.Segment(codec)).head }) case Concat(left, right, combiner, _) => @@ -296,7 +294,7 @@ sealed trait PathCodec[A] { self => case PathCodec.Concat(left, right, _, _) => loop(left) + loop(right) - case PathCodec.Segment(segment, _) => segment.render + case PathCodec.Segment(segment) => segment.render case PathCodec.TransformOrFail(api, _, _) => loop(api) @@ -305,12 +303,27 @@ sealed trait PathCodec[A] { self => loop(self) } + private[zio] def renderIgnoreTrailing: String = { + def loop(path: PathCodec[_]): String = path match { + case PathCodec.Concat(left, right, _, _) => + loop(left) + loop(right) + + case PathCodec.Segment(SegmentCodec.Trailing) => "" + + case PathCodec.Segment(segment) => segment.render + + case PathCodec.TransformOrFail(api, _, _) => loop(api) + } + + loop(self) + } + /** * Returns the segments of the path codec. */ def segments: Chunk[SegmentCodec[_]] = { def loop(path: PathCodec[_]): Chunk[SegmentCodec[_]] = path match { - case PathCodec.Segment(segment, _) => Chunk(segment) + case PathCodec.Segment(segment) => Chunk(segment) case PathCodec.Concat(left, right, _, _) => loop(left) ++ loop(right) @@ -344,9 +357,14 @@ object PathCodec { def apply(value: String): PathCodec[Unit] = { val path = Path(value) - path.segments.foldLeft[PathCodec[Unit]](PathCodec.empty) { (pathSpec, segment) => - pathSpec./[Unit](SegmentCodec.literal(segment)) + path.segments match { + case Chunk() => PathCodec.empty + case Chunk(first, rest @ _*) => + rest.foldLeft[PathCodec[Unit]](Segment(SegmentCodec.literal(first))) { (pathSpec, segment) => + pathSpec / Segment(SegmentCodec.literal(segment)) + } } + } def bool(name: String): PathCodec[Boolean] = Segment(SegmentCodec.bool(name)) @@ -354,7 +372,7 @@ object PathCodec { /** * The empty / root path codec. */ - def empty: PathCodec[Unit] = Segment[Unit](SegmentCodec.Empty()) + def empty: PathCodec[Unit] = Segment[Unit](SegmentCodec.Empty) def int(name: String): PathCodec[Int] = Segment(SegmentCodec.int(name)) @@ -366,12 +384,13 @@ object PathCodec { def string(name: String): PathCodec[String] = Segment(SegmentCodec.string(name)) - def trailing: PathCodec[Path] = Segment(SegmentCodec.Trailing()) + def trailing: PathCodec[Path] = Segment(SegmentCodec.Trailing) def uuid(name: String): PathCodec[java.util.UUID] = Segment(SegmentCodec.uuid(name)) - private[http] final case class Segment[A](segment: SegmentCodec[A], doc: Doc = Doc.empty) extends PathCodec[A] { - def ??(doc: Doc): Segment[A] = copy(doc = this.doc + doc) + private[http] final case class Segment[A](segment: SegmentCodec[A]) extends PathCodec[A] { + def ??(doc: Doc): Segment[A] = copy(segment ?? doc) + def doc: Doc = segment.doc } private[http] final case class Concat[A, B, C]( left: PathCodec[A], @@ -502,14 +521,14 @@ object PathCodec { .foldRight[SegmentSubtree[A]](SegmentSubtree(ListMap(), ListMap(), Chunk(value))) { case (segment, subtree) => val literals = segment match { - case SegmentCodec.Literal(value, _) => ListMap(value -> subtree) - case _ => ListMap.empty[String, SegmentSubtree[A]] + case SegmentCodec.Literal(value) => ListMap(value -> subtree) + case _ => ListMap.empty[String, SegmentSubtree[A]] } val others = ListMap[SegmentCodec[_], SegmentSubtree[A]]((segment match { - case SegmentCodec.Literal(_, _) => Chunk.empty - case _ => Chunk((segment, subtree)) + case SegmentCodec.Literal(_) => Chunk.empty + case _ => Chunk((segment, subtree)) }): _*) SegmentSubtree(literals, others, Chunk.empty) diff --git a/zio-http/src/main/scala/zio/http/codec/RichTextCodec.scala b/zio-http/src/main/scala/zio/http/codec/RichTextCodec.scala index aedc47f8f3..df2d0b53b1 100644 --- a/zio-http/src/main/scala/zio/http/codec/RichTextCodec.scala +++ b/zio-http/src/main/scala/zio/http/codec/RichTextCodec.scala @@ -21,7 +21,6 @@ import java.lang.Integer.parseInt import scala.annotation.tailrec import scala.collection.immutable.BitSet -import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.{Chunk, NonEmptyChunk} /** @@ -108,6 +107,14 @@ sealed trait RichTextCodec[A] { self => */ final def encode(value: A): Either[String, String] = RichTextCodec.encode(value, self) + /** + * This method is Right biased merge + */ + final def merge[B](implicit ev: A <:< Either[B, B]): RichTextCodec[B] = { + val codec = self.asInstanceOf[RichTextCodec[Either[B, B]]] + codec.transform[B](_.merge)(Right(_)) + } + final def optional(default: A): RichTextCodec[Option[A]] = self.transform[Option[A]](a => Some(a))(_.fold(default)(identity)) @@ -115,10 +122,10 @@ sealed trait RichTextCodec[A] { self => ((self ~ repeat).transform[NonEmptyChunk[A]](t => NonEmptyChunk(t._1, t._2: _*))(c => (c.head, c.tail), ) | RichTextCodec.empty.as(Chunk.empty[A])) - .transform[Chunk[A]](_ match { + .transform[Chunk[A]] { case Left(nonEmpty) => nonEmpty case Right(maybeEmpty) => maybeEmpty - })(c => c.nonEmptyOrElse[Either[NonEmptyChunk[A], Chunk[A]]](Right(c))(Left(_))) + }(c => c.nonEmptyOrElse[Either[NonEmptyChunk[A], Chunk[A]]](Right(c))(Left(_))) final def singleton: RichTextCodec[NonEmptyChunk[A]] = self.transform(a => NonEmptyChunk(a))(_.head) @@ -151,10 +158,16 @@ sealed trait RichTextCodec[A] { self => case x if p(x) => x } + final def withError(errorMessage: String): RichTextCodec[A] = + (self | RichTextCodec.fail[A](errorMessage)).merge + } object RichTextCodec { private[codec] case object Empty extends RichTextCodec[Unit] - private[codec] final case class CharIn(set: BitSet) extends RichTextCodec[Char] + private[codec] final case class CharIn(set: BitSet) extends RichTextCodec[Char] { + val errorMessage: Left[String, Nothing] = + Left(s"Expected, but did not find: ${this.describe}") + } private[codec] final case class TransformOrFail[A, B]( codec: RichTextCodec[A], to: A => Either[String, B], @@ -162,7 +175,7 @@ object RichTextCodec { ) extends RichTextCodec[B] private[codec] final case class Alt[A, B](left: RichTextCodec[A], right: RichTextCodec[B]) extends RichTextCodec[Either[A, B]] - private[codec] final case class Lazy[A](codec0: () => RichTextCodec[A]) extends RichTextCodec[A] { + private[codec] final case class Lazy[A](codec0: () => RichTextCodec[A]) extends RichTextCodec[A] { lazy val codec: RichTextCodec[A] = codec0() } private[codec] final case class Zip[A, B, C]( @@ -188,6 +201,12 @@ object RichTextCodec { */ def char(c: Char): RichTextCodec[Char] = CharIn(BitSet(c.toInt)) + def chars(cs: Char*): RichTextCodec[Char] = + CharIn(BitSet(cs.map(_.toInt): _*)) + + def charsNot(cs: Char*): RichTextCodec[Char] = + filter(c => !cs.contains(c)) + /** * A codec that describes a digit character. */ @@ -200,6 +219,9 @@ object RichTextCodec { */ val empty: RichTextCodec[Unit] = Empty + def fail[A](message: String): RichTextCodec[A] = + empty.transformOrFail(_ => Left(message))(_ => Left(message)) + /** * Defines a new codec for a single character based on the specified * predicate. @@ -207,6 +229,9 @@ object RichTextCodec { def filter(pred: Char => Boolean): RichTextCodec[Char] = CharIn(BitSet((Char.MinValue to Char.MaxValue).filter(pred).map(_.toInt): _*)) + def filterOrFail(pred: Char => Boolean)(failure: String): RichTextCodec[Char] = + filter(pred).collectOrFail(failure) { case c => c } + /** * A codec that describes a letter character. */ @@ -528,9 +553,9 @@ object RichTextCodec { case Empty => Right((value, ())) - case CharIn(bitset) => + case self @ CharIn(bitset) => if (value.length == 0 || !bitset.contains(value.charAt(0).toInt)) - Left(s"Not found: ${bitset.toArray.map(_.toChar).mkString}") + self.errorMessage else Right((value.subSequence(1, value.length), value.charAt(0))) diff --git a/zio-http/src/main/scala/zio/http/codec/SegmentCodec.scala b/zio-http/src/main/scala/zio/http/codec/SegmentCodec.scala index 24137c0ced..ece07e6ba3 100644 --- a/zio-http/src/main/scala/zio/http/codec/SegmentCodec.scala +++ b/zio-http/src/main/scala/zio/http/codec/SegmentCodec.scala @@ -18,9 +18,9 @@ package zio.http.codec import scala.language.implicitConversions import zio.Chunk -import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.http.Path +import zio.http.codec.SegmentCodec.{Annotated, MetaData} sealed trait SegmentCodec[A] { self => private var _hashCode: Int = 0 @@ -28,9 +28,24 @@ sealed trait SegmentCodec[A] { self => final type Type = A - def ??(doc: Doc): SegmentCodec[A] + def ??(doc: Doc): SegmentCodec[A] = + SegmentCodec.Annotated(self, Chunk(MetaData.Documented(doc))) + + def example(name: String, example: A): SegmentCodec[A] = + SegmentCodec.Annotated(self, Chunk(MetaData.Examples(Map(name -> example)))) + + def examples(examples: (String, A)*): SegmentCodec[A] = + SegmentCodec.Annotated(self, Chunk(MetaData.Examples(examples.toMap))) + + lazy val doc: Doc = self.asInstanceOf[SegmentCodec[_]] match { + case SegmentCodec.Annotated(_, annotations) => + annotations.collectFirst { case MetaData.Documented(doc) => doc }.getOrElse(Doc.Empty) + case _ => + Doc.Empty + } override def equals(that: Any): Boolean = that match { + case Annotated(codec, _) => codec == this case that: SegmentCodec[_] => (this.getClass == that.getClass) && (this.render == that.render) case _ => false } @@ -43,8 +58,8 @@ sealed trait SegmentCodec[A] { self => } final def isEmpty: Boolean = self.asInstanceOf[SegmentCodec[_]] match { - case SegmentCodec.Empty(_) => true - case _ => false + case SegmentCodec.Empty => true + case _ => false } // Returns number of segments matched, or -1 if not matched: @@ -54,14 +69,15 @@ sealed trait SegmentCodec[A] { self => final def render: String = { if (_render == "") _render = self.asInstanceOf[SegmentCodec[_]] match { - case SegmentCodec.Empty(_) => s"" - case SegmentCodec.Literal(value, _) => s"/$value" - case SegmentCodec.IntSeg(name, _) => s"/{$name}" - case SegmentCodec.LongSeg(name, _) => s"/{$name}" - case SegmentCodec.Text(name, _) => s"/{$name}" - case SegmentCodec.BoolSeg(name, _) => s"/{$name}" - case SegmentCodec.UUID(name, _) => s"/{$name}" - case SegmentCodec.Trailing(_) => s"/..." + case _: SegmentCodec.Empty.type => s"" + case SegmentCodec.Literal(value) => s"/$value" + case SegmentCodec.IntSeg(name) => s"/{$name}" + case SegmentCodec.LongSeg(name) => s"/{$name}" + case SegmentCodec.Text(name) => s"/{$name}" + case SegmentCodec.BoolSeg(name) => s"/{$name}" + case SegmentCodec.UUID(name) => s"/{$name}" + case _: SegmentCodec.Trailing.type => s"/..." + case SegmentCodec.Annotated(codec, _) => codec.render } _render } @@ -81,7 +97,7 @@ sealed trait SegmentCodec[A] { self => object SegmentCodec { def bool(name: String): SegmentCodec[Boolean] = SegmentCodec.BoolSeg(name) - val empty: SegmentCodec[Unit] = SegmentCodec.Empty() + val empty: SegmentCodec[Unit] = SegmentCodec.Empty def int(name: String): SegmentCodec[Int] = SegmentCodec.IntSeg(name) @@ -92,20 +108,43 @@ object SegmentCodec { def string(name: String): SegmentCodec[String] = SegmentCodec.Text(name) - def trailing: SegmentCodec[Path] = SegmentCodec.Trailing() + def trailing: SegmentCodec[Path] = SegmentCodec.Trailing def uuid(name: String): SegmentCodec[java.util.UUID] = SegmentCodec.UUID(name) - private[http] final case class Empty(doc: Doc = Doc.empty) extends SegmentCodec[Unit] { - def ??(doc: Doc): Empty = copy(doc = this.doc + doc) + final case class Annotated[A](codec: SegmentCodec[A], annotations: Chunk[MetaData[A]]) extends SegmentCodec[A] { + + override def equals(that: Any): Boolean = + codec.equals(that) + override def ??(doc: Doc): Annotated[A] = + copy(annotations = annotations :+ MetaData.Documented(doc)) + + override def example(name: String, example: A): Annotated[A] = + copy(annotations = annotations :+ MetaData.Examples(Map(name -> example))) + + override def examples(examples: (String, A)*): Annotated[A] = + copy(annotations = annotations :+ MetaData.Examples(examples.toMap)) + + def format(value: A): Path = codec.format(value) + + def matches(segments: Chunk[String], index: Int): Int = codec.matches(segments, index) + } + + sealed trait MetaData[A] extends Product with Serializable + + object MetaData { + final case class Documented[A](value: Doc) extends MetaData[A] + final case class Examples[A](examples: Map[String, A]) extends MetaData[A] + } + + private[http] case object Empty extends SegmentCodec[Unit] { self => def format(unit: Unit): Path = Path(s"") def matches(segments: Chunk[String], index: Int): Int = 0 } - private[http] final case class Literal(value: String, doc: Doc = Doc.empty) extends SegmentCodec[Unit] { - def ??(doc: Doc): Literal = copy(doc = this.doc + doc) + private[http] final case class Literal(value: String) extends SegmentCodec[Unit] { def format(unit: Unit): Path = Path(s"/$value") @@ -115,8 +154,7 @@ object SegmentCodec { else -1 } } - private[http] final case class BoolSeg(name: String, doc: Doc = Doc.empty) extends SegmentCodec[Boolean] { - def ??(doc: Doc): BoolSeg = copy(doc = this.doc + doc) + private[http] final case class BoolSeg(name: String) extends SegmentCodec[Boolean] { def format(value: Boolean): Path = Path(s"/$value") @@ -128,8 +166,7 @@ object SegmentCodec { if (segment == "true" || segment == "false") 1 else -1 } } - private[http] final case class IntSeg(name: String, doc: Doc = Doc.empty) extends SegmentCodec[Int] { - def ??(doc: Doc): IntSeg = copy(doc = this.doc + doc) + private[http] final case class IntSeg(name: String) extends SegmentCodec[Int] { def format(value: Int): Path = Path(s"/$value") @@ -151,8 +188,7 @@ object SegmentCodec { } } } - private[http] final case class LongSeg(name: String, doc: Doc = Doc.empty) extends SegmentCodec[Long] { - def ??(doc: Doc): LongSeg = copy(doc = this.doc + doc) + private[http] final case class LongSeg(name: String) extends SegmentCodec[Long] { def format(value: Long): Path = Path(s"/$value") @@ -174,8 +210,7 @@ object SegmentCodec { } } } - private[http] final case class Text(name: String, doc: Doc = Doc.empty) extends SegmentCodec[String] { - def ??(doc: Doc): Text = copy(doc = this.doc + doc) + private[http] final case class Text(name: String) extends SegmentCodec[String] { def format(value: String): Path = Path(s"/$value") @@ -183,8 +218,7 @@ object SegmentCodec { if (index < 0 || index >= segments.length) -1 else 1 } - private[http] final case class UUID(name: String, doc: Doc = Doc.empty) extends SegmentCodec[java.util.UUID] { - def ??(doc: Doc): UUID = copy(doc = this.doc + doc) + private[http] final case class UUID(name: String) extends SegmentCodec[java.util.UUID] { def format(value: java.util.UUID): Path = Path(s"/$value") @@ -221,9 +255,7 @@ object SegmentCodec { } } - final case class Trailing(doc: Doc = Doc.empty) extends SegmentCodec[Path] { self => - def ??(doc: Doc): SegmentCodec[Path] = copy(doc = this.doc + doc) - + case object Trailing extends SegmentCodec[Path] { self => def format(value: Path): Path = value def matches(segments: Chunk[String], index: Int): Int = diff --git a/zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala b/zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala index e60f55a60c..a03465f2f3 100644 --- a/zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala +++ b/zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala @@ -33,7 +33,7 @@ import zio.http.{Body, MediaType} * A BodyCodec encapsulates the logic necessary to both encode and decode bodies * for a media type, using ZIO Schema codecs and schemas. */ -private[internal] sealed trait BodyCodec[A] { self => +private[http] sealed trait BodyCodec[A] { self => /** * The element type, described by the schema. This could be the type of the @@ -88,7 +88,7 @@ private[internal] sealed trait BodyCodec[A] { self => */ def name: Option[String] } -private[internal] object BodyCodec { +private[http] object BodyCodec { case object Empty extends BodyCodec[Unit] { type Element = Unit diff --git a/zio-http/src/main/scala/zio/http/endpoint/Endpoint.scala b/zio-http/src/main/scala/zio/http/endpoint/Endpoint.scala index b1f6fcf3a4..ffbd9aea88 100644 --- a/zio-http/src/main/scala/zio/http/endpoint/Endpoint.scala +++ b/zio-http/src/main/scala/zio/http/endpoint/Endpoint.scala @@ -19,7 +19,6 @@ package zio.http.endpoint import scala.reflect.ClassTag import zio._ -import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.stream.ZStream @@ -27,7 +26,7 @@ import zio.schema._ import zio.http.Header.Accept.MediaTypeWithQFactor import zio.http._ -import zio.http.codec.{HttpCodec, _} +import zio.http.codec._ import zio.http.endpoint.Endpoint.{OutErrors, defaultMediaTypes} /** @@ -235,6 +234,30 @@ final case class Endpoint[PathInput, Input, Err, Output, Middleware <: EndpointM ): Endpoint[PathInput, combiner.Out, Err, Output, Middleware] = copy(input = input ++ (HttpCodec.content(name)(schema) ?? doc)) + def in[Input2](mediaType: MediaType)(implicit + schema: Schema[Input2], + combiner: Combiner[Input, Input2], + ): Endpoint[PathInput, combiner.Out, Err, Output, Middleware] = + copy(input = input ++ HttpCodec.content(mediaType)(schema)) + + def in[Input2](mediaType: MediaType, doc: Doc)(implicit + schema: Schema[Input2], + combiner: Combiner[Input, Input2], + ): Endpoint[PathInput, combiner.Out, Err, Output, Middleware] = + copy(input = input ++ (HttpCodec.content(mediaType) ?? doc)) + + def in[Input2](mediaType: MediaType, name: String)(implicit + schema: Schema[Input2], + combiner: Combiner[Input, Input2], + ): Endpoint[PathInput, combiner.Out, Err, Output, Middleware] = + copy(input = input ++ HttpCodec.content(name, mediaType)) + + def in[Input2](mediaType: MediaType, name: String, doc: Doc)(implicit + schema: Schema[Input2], + combiner: Combiner[Input, Input2], + ): Endpoint[PathInput, combiner.Out, Err, Output, Middleware] = + copy(input = input ++ (HttpCodec.content(name, mediaType) ?? doc)) + /** * Returns a new endpoint derived from this one, whose request must satisfy * the specified codec. @@ -334,7 +357,7 @@ final case class Endpoint[PathInput, Input, Err, Output, Middleware <: EndpointM Endpoint( route, input, - output = (self.output | HttpCodec.content(implicitly[Schema[Output2]])) ++ StatusCodec.status(Status.Ok), + output = self.output | (HttpCodec.content(implicitly[Schema[Output2]]) ++ StatusCodec.status(Status.Ok)), error, doc, mw, @@ -387,7 +410,7 @@ final case class Endpoint[PathInput, Input, Err, Output, Middleware <: EndpointM input, output = self.output | ((HttpCodec.content(implicitly[Schema[Output2]]) ++ StatusCodec.status(status)) ?? doc), error, - doc, + Doc.empty, mw, ) @@ -402,7 +425,7 @@ final case class Endpoint[PathInput, Input, Err, Output, Middleware <: EndpointM Endpoint( route, input, - output = self.output | (HttpCodec.content(mediaType)(implicitly[Schema[Output2]]) ?? doc), + output = self.output | (HttpCodec.content(mediaType)(implicitly[Schema[Output2]]) ++ StatusCodec.Ok ?? doc), error, doc, mw, @@ -583,6 +606,41 @@ final case class Endpoint[PathInput, Input, Err, Output, Middleware <: EndpointM combiner: Combiner[Input, A], ): Endpoint[PathInput, combiner.Out, Err, Output, Middleware] = copy(input = self.input ++ codec) + + /** + * Transforms the input of this endpoint using the specified functions. This + * is useful to build from different http inputs a domain specific input. + * + * For example + * {{{ + * case class ChangeUserName(userId: UUID, name: String) + * val endpoint = + * Endpoint(Method.POST / "user" / uuid("userId") / "changeName").in[String] + * .transformIn { case (userId, name) => ChangeUserName(userId, name) } { + * case ChangeUserName(userId, name) => (userId, name) + * } + * }}} + */ + def transformIn[Input1](f: Input => Input1)( + g: Input1 => Input, + ): Endpoint[PathInput, Input1, Err, Output, Middleware] = + copy(input = self.input.transform(f)(g)) + + /** + * Transforms the output of this endpoint using the specified functions. + */ + def transformOut[Output1](f: Output => Output1)( + g: Output1 => Output, + ): Endpoint[PathInput, Input, Err, Output1, Middleware] = + copy(output = self.output.transform(f)(g)) + + /** + * Transforms the error of this endpoint using the specified functions. + */ + def transformError[Err1](f: Err => Err1)( + g: Err1 => Err, + ): Endpoint[PathInput, Input, Err1, Output, Middleware] = + copy(error = self.error.transform(f)(g)) } object Endpoint { diff --git a/zio-http/src/main/scala/zio/http/endpoint/openapi/JsonRenderer.scala b/zio-http/src/main/scala/zio/http/endpoint/openapi/JsonRenderer.scala deleted file mode 100644 index 194c768f39..0000000000 --- a/zio-http/src/main/scala/zio/http/endpoint/openapi/JsonRenderer.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package zio.http.endpoint.openapi - -import zio.NonEmptyChunk -import zio.http.codec.Doc -import zio.http.endpoint.openapi.OpenAPI.LiteralOrExpression -import zio.http.Status -import zio.stacktracer.TracingImplicits.disableAutoTrace - -import java.net.URI -import java.util.Base64 -import scala.language.implicitConversions // scalafix:ok; - -private[openapi] object JsonRenderer { - sealed trait Renderable[A] { - def render(a: A): String - } - - implicit class Renderer[T](t: T)(implicit val renderable: Renderable[T]) { - def render: String = renderable.render(t) - - val skip: Boolean = - t.asInstanceOf[Any] match { - case None => true - case _ => false - } - } - - def renderFields(fieldsIt: (String, Renderer[_])*): String = { - if (fieldsIt.map(_._1).toSet.size != fieldsIt.size) { - throw new IllegalArgumentException("Duplicate field names") - } else { - val fields = fieldsIt - .filterNot(_._2.skip) - .map { case (name, value) => s""""$name":${value.render}""" } - s"{${fields.mkString(",")}}" - } - } - - private def renderKey[K](k: K)(implicit renderable: Renderable[K]) = - if (renderable.render(k).startsWith("\"") && renderable.render(k).endsWith("\"")) renderable.render(k) - else s""""${renderable.render(k)}"""" - - implicit def stringRenderable[T <: String]: Renderable[T] = new Renderable[T] { - def render(a: T): String = s""""$a"""" - } - - implicit def intRenderable[T <: Int]: Renderable[T] = new Renderable[T] { - def render(a: T): String = a.toString - } - - implicit def Renderable[T <: Long]: Renderable[T] = new Renderable[T] { - def render(a: T): String = a.toString - } - - implicit def floatRenderable[T <: Float]: Renderable[T] = new Renderable[T] { - def render(a: T): String = a.toString - } - - implicit def doubleRenderable[T <: Double]: Renderable[T] = new Renderable[T] { - def render(a: T): String = a.toString - } - - implicit def booleanRenderable[T <: Boolean]: Renderable[T] = new Renderable[T] { - def render(a: T): String = a.toString - } - - implicit val uriRenderable: Renderable[URI] = new Renderable[URI] { - def render(a: URI): String = s""""${a.toString}"""" - } - - implicit val statusRenderable: Renderable[Status] = new Renderable[Status] { - def render(a: Status): String = a.code.toString - } - - implicit val docRenderable: Renderable[Doc] = new Renderable[Doc] { - def render(a: Doc): String = s""""${Base64.getEncoder.encodeToString(a.toCommonMark.getBytes)}"""" - } - - implicit def openapiBaseRenderable[T <: OpenAPIBase]: Renderable[T] = new Renderable[T] { - def render(a: T): String = a.toJson - } - - implicit def optionRenderable[A](implicit renderable: Renderable[A]): Renderable[Option[A]] = - new Renderable[Option[A]] { - def render(a: Option[A]): String = a match { - case Some(value) => renderable.render(value) - case None => "null" - } - } - - implicit def nonEmptyChunkRenderable[A](implicit renderable: Renderable[A]): Renderable[NonEmptyChunk[A]] = - new Renderable[NonEmptyChunk[A]] { - def render(a: NonEmptyChunk[A]): String = s"[${a.map(renderable.render).mkString(",")}]" - } - - implicit def setRenderable[A](implicit renderable: Renderable[A]): Renderable[Set[A]] = - new Renderable[Set[A]] { - def render(a: Set[A]): String = s"[${a.map(renderable.render).mkString(",")}]" - } - - implicit def listRenderable[A](implicit renderable: Renderable[A]): Renderable[List[A]] = - new Renderable[List[A]] { - def render(a: List[A]): String = s"[${a.map(renderable.render).mkString(",")}]" - } - - implicit def mapRenderable[K, V](implicit rK: Renderable[K], rV: Renderable[V]): Renderable[Map[K, V]] = - new Renderable[Map[K, V]] { - def render(a: Map[K, V]): String = - s"{${a.map { case (k, v) => s"${renderKey(k)}:${rV.render(v)}" }.mkString(",")}}" - } - - implicit def tupleRenderable[A, B](implicit rA: Renderable[A], rB: Renderable[B]): Renderable[(A, B)] = - new Renderable[(A, B)] { - def render(a: (A, B)): String = s"{${renderKey(a._1)}:${rB.render(a._2)}}" - } - - implicit def literalOrExpressionRenderable: Renderable[LiteralOrExpression] = - new Renderable[LiteralOrExpression] { - def render(a: LiteralOrExpression): String = a match { - case LiteralOrExpression.NumberLiteral(value) => implicitly[Renderable[Long]].render(value) - case LiteralOrExpression.DecimalLiteral(value) => implicitly[Renderable[Double]].render(value) - case LiteralOrExpression.StringLiteral(value) => implicitly[Renderable[String]].render(value) - case LiteralOrExpression.BooleanLiteral(value) => implicitly[Renderable[Boolean]].render(value) - case LiteralOrExpression.Expression(value) => implicitly[Renderable[String]].render(value) - } - } -} diff --git a/zio-http/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala b/zio-http/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala new file mode 100644 index 0000000000..abb6642845 --- /dev/null +++ b/zio-http/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala @@ -0,0 +1,994 @@ +package zio.http.endpoint.openapi + +import zio._ +import zio.json.ast.Json + +import zio.schema.Schema.CaseClass0 +import zio.schema._ +import zio.schema.annotation._ +import zio.schema.codec._ +import zio.schema.codec.json._ + +import zio.http.codec.{SegmentCodec, TextCodec} + +private[openapi] case class SerializableJsonSchema( + @fieldName("$ref") ref: Option[String] = None, + @fieldName("type") schemaType: Option[TypeOrTypes] = None, + format: Option[String] = None, + oneOf: Option[Chunk[SerializableJsonSchema]] = None, + allOf: Option[Chunk[SerializableJsonSchema]] = None, + anyOf: Option[Chunk[SerializableJsonSchema]] = None, + enumValues: Option[Chunk[Json]] = None, + properties: Option[Map[String, SerializableJsonSchema]] = None, + additionalProperties: Option[BoolOrSchema] = None, + required: Option[Chunk[String]] = None, + items: Option[SerializableJsonSchema] = None, + nullable: Option[Boolean] = None, + description: Option[String] = None, + example: Option[Json] = None, + examples: Option[Chunk[Json]] = None, + discriminator: Option[OpenAPI.Discriminator] = None, + deprecated: Option[Boolean] = None, + contentEncoding: Option[String] = None, + contentMediaType: Option[String] = None, + default: Option[Json] = None, + pattern: Option[String] = None, +) { + def asNullableType(nullable: Boolean): SerializableJsonSchema = + if (nullable && schemaType.isDefined) + copy(schemaType = Some(schemaType.get.add("null"))) + else if (nullable && oneOf.isDefined) + copy(oneOf = Some(oneOf.get :+ SerializableJsonSchema(schemaType = Some(TypeOrTypes.Type("null"))))) + else if (nullable && allOf.isDefined) + SerializableJsonSchema(oneOf = + Some(Chunk(this, SerializableJsonSchema(schemaType = Some(TypeOrTypes.Type("null"))))), + ) + else if (nullable && anyOf.isDefined) + copy(anyOf = Some(anyOf.get :+ SerializableJsonSchema(schemaType = Some(TypeOrTypes.Type("null"))))) + else + this + +} + +private[openapi] object SerializableJsonSchema { + implicit val schema: Schema[SerializableJsonSchema] = DeriveSchema.gen[SerializableJsonSchema] + + val binaryCodec: BinaryCodec[SerializableJsonSchema] = + JsonCodec.schemaBasedBinaryCodec[SerializableJsonSchema](JsonCodec.Config(ignoreEmptyCollections = true))( + Schema[SerializableJsonSchema], + ) +} + +@noDiscriminator +private[openapi] sealed trait BoolOrSchema + +private[openapi] object BoolOrSchema { + implicit val schema: Schema[BoolOrSchema] = DeriveSchema.gen[BoolOrSchema] + + final case class SchemaWrapper(schema: SerializableJsonSchema) extends BoolOrSchema + + object SchemaWrapper { + implicit val schema: Schema[SchemaWrapper] = + Schema[SerializableJsonSchema].transform(SchemaWrapper(_), _.schema) + } + + final case class BooleanWrapper(value: Boolean) extends BoolOrSchema + + object BooleanWrapper { + implicit val schema: Schema[BooleanWrapper] = + Schema[Boolean].transform[BooleanWrapper](b => BooleanWrapper(b), _.value) + } +} + +@noDiscriminator +private[openapi] sealed trait TypeOrTypes { self => + def add(value: String): TypeOrTypes = + self match { + case TypeOrTypes.Type(string) => TypeOrTypes.Types(Chunk(string, value)) + case TypeOrTypes.Types(chunk) => TypeOrTypes.Types(chunk :+ value) + } +} + +private[openapi] object TypeOrTypes { + implicit val schema: Schema[TypeOrTypes] = DeriveSchema.gen[TypeOrTypes] + + final case class Type(value: String) extends TypeOrTypes + + object Type { + implicit val schema: Schema[Type] = + Schema[String].transform[Type](s => Type(s), _.value) + } + + final case class Types(value: Chunk[String]) extends TypeOrTypes + + object Types { + implicit val schema: Schema[Types] = + Schema.chunk[String].transform[Types](s => Types(s), _.value) + } +} + +final case class JsonSchemas( + root: JsonSchema, + rootRef: Option[String], + children: Map[String, JsonSchema], +) + +sealed trait JsonSchema extends Product with Serializable { self => + + lazy val toJsonBytes: Chunk[Byte] = JsonCodec.schemaBasedBinaryCodec[JsonSchema].encode(self) + + lazy val toJson: String = toJsonBytes.asString + + protected[openapi] def toSerializableSchema: SerializableJsonSchema + def annotate(annotations: Chunk[JsonSchema.MetaData]): JsonSchema = + annotations.foldLeft(self) { case (schema, annotation) => schema.annotate(annotation) } + def annotate(annotation: JsonSchema.MetaData): JsonSchema = + JsonSchema.AnnotatedSchema(self, annotation) + + def annotations: Chunk[JsonSchema.MetaData] = self match { + case JsonSchema.AnnotatedSchema(schema, annotation) => schema.annotations :+ annotation + case _ => Chunk.empty + } + + def withoutAnnotations: JsonSchema = self match { + case JsonSchema.AnnotatedSchema(schema, _) => schema.withoutAnnotations + case _ => self + } + + def examples(examples: Chunk[Json]): JsonSchema = + JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Examples(examples)) + + def default(default: Option[Json]): JsonSchema = + default match { + case Some(value) => JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Default(value)) + case None => self + } + + def default(default: Json): JsonSchema = + JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Default(default)) + + def description(description: String): JsonSchema = + JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Description(description)) + + def description(description: Option[String]): JsonSchema = + description match { + case Some(value) => JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Description(value)) + case None => self + } + + def description: Option[String] = self.toSerializableSchema.description + + def nullable(nullable: Boolean): JsonSchema = + JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Nullable(nullable)) + + def discriminator(discriminator: OpenAPI.Discriminator): JsonSchema = + JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Discriminator(discriminator)) + + def discriminator(discriminator: Option[OpenAPI.Discriminator]): JsonSchema = + discriminator match { + case Some(discriminator) => + JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Discriminator(discriminator)) + case None => + self + } + + def deprecated(deprecated: Boolean): JsonSchema = + if (deprecated) JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Deprecated) + else self + + def contentEncoding(encoding: JsonSchema.ContentEncoding): JsonSchema = + JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.ContentEncoding(encoding)) + + def contentMediaType(mediaType: String): JsonSchema = + JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.ContentMediaType(mediaType)) + + def isPrimitive: Boolean = self match { + case _: JsonSchema.Number => true + case _: JsonSchema.Integer => true + case _: JsonSchema.String => true + case JsonSchema.Boolean => true + case JsonSchema.Null => true + case _ => false + } + +} + +object JsonSchema { + + implicit val schema: Schema[JsonSchema] = + SerializableJsonSchema.schema.transform[JsonSchema](JsonSchema.fromSerializableSchema, _.toSerializableSchema) + + private[openapi] val codec = JsonCodec.schemaBasedBinaryCodec[JsonSchema] + + private def toJsonAst(schema: Schema[_], v: Any): Json = + JsonCodec + .jsonEncoder(schema.asInstanceOf[Schema[Any]]) + .toJsonAST(v) + .toOption + .get + + private def fromSerializableSchema(schema: SerializableJsonSchema): JsonSchema = { + val additionalProperties = schema.additionalProperties match { + case Some(BoolOrSchema.BooleanWrapper(false)) => Left(false) + case Some(BoolOrSchema.BooleanWrapper(true)) => Left(true) + case Some(BoolOrSchema.SchemaWrapper(schema)) => Right(fromSerializableSchema(schema)) + case None => Left(true) + } + + var jsonSchema: JsonSchema = schema match { + case schema if schema.ref.isDefined => + RefSchema(schema.ref.get) + case schema if schema.schemaType.contains(TypeOrTypes.Type("number")) => + JsonSchema.Number(NumberFormat.fromString(schema.format.getOrElse("double"))) + case schema if schema.schemaType.contains(TypeOrTypes.Type("integer")) => + JsonSchema.Integer(IntegerFormat.fromString(schema.format.getOrElse("int64"))) + case schema if schema.schemaType.contains(TypeOrTypes.Type("string")) && schema.enumValues.isEmpty => + JsonSchema.String(schema.format.map(StringFormat.fromString), schema.pattern.map(Pattern.apply)) + case schema if schema.schemaType.contains(TypeOrTypes.Type("boolean")) => + JsonSchema.Boolean + case schema if schema.schemaType.contains(TypeOrTypes.Type("array")) => + JsonSchema.ArrayType(schema.items.map(fromSerializableSchema)) + case schema if schema.schemaType.contains(TypeOrTypes.Type("object")) || schema.schemaType.isEmpty => + JsonSchema.Object( + schema.properties + .map(_.map { case (name, schema) => name -> fromSerializableSchema(schema) }) + .getOrElse(Map.empty), + additionalProperties, + schema.required.getOrElse(Chunk.empty), + ) + case schema if schema.enumValues.isDefined => + JsonSchema.Enum(schema.enumValues.get.map(EnumValue.fromJson)) + case schema if schema.oneOf.isDefined => + OneOfSchema(schema.oneOf.get.map(fromSerializableSchema)) + case schema if schema.allOf.isDefined => + AllOfSchema(schema.allOf.get.map(fromSerializableSchema)) + case schema if schema.anyOf.isDefined => + AnyOfSchema(schema.anyOf.get.map(fromSerializableSchema)) + case schema if schema.schemaType.contains(TypeOrTypes.Type("null")) => + JsonSchema.Null + case _ => + throw new IllegalArgumentException(s"Can't convert $schema") + } + + val examples = Chunk.fromIterable(schema.example) ++ schema.examples.getOrElse(Chunk.empty) + if (examples.nonEmpty) jsonSchema = jsonSchema.examples(examples) + + schema.description match { + case Some(value) => jsonSchema = jsonSchema.description(value) + case None => () + } + + schema.nullable match { + case Some(value) => jsonSchema = jsonSchema.nullable(value) + case None => () + } + + schema.discriminator match { + case Some(value) => jsonSchema = jsonSchema.discriminator(value) + case None => () + } + + schema.contentEncoding.flatMap(ContentEncoding.fromString) match { + case Some(value) => jsonSchema = jsonSchema.contentEncoding(value) + case None => () + } + + schema.contentMediaType match { + case Some(value) => jsonSchema = jsonSchema.contentMediaType(value) + case None => () + } + + jsonSchema = jsonSchema.default(schema.default) + + jsonSchema = jsonSchema.deprecated(schema.deprecated.getOrElse(false)) + + jsonSchema + } + + def fromTextCodec(codec: TextCodec[_]): JsonSchema = + codec match { + case TextCodec.Constant(string) => JsonSchema.Enum(Chunk(EnumValue.Str(string))) + case TextCodec.StringCodec => JsonSchema.String() + case TextCodec.IntCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) + case TextCodec.LongCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) + case TextCodec.BooleanCodec => JsonSchema.Boolean + case TextCodec.UUIDCodec => JsonSchema.String(JsonSchema.StringFormat.UUID) + } + + def fromSegmentCodec(codec: SegmentCodec[_]): JsonSchema = + codec match { + case SegmentCodec.BoolSeg(_) => JsonSchema.Boolean + case SegmentCodec.IntSeg(_) => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) + case SegmentCodec.LongSeg(_) => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) + case SegmentCodec.Text(_) => JsonSchema.String() + case SegmentCodec.UUID(_) => JsonSchema.String(JsonSchema.StringFormat.UUID) + case SegmentCodec.Annotated(codec, annotations) => + fromSegmentCodec(codec).description(segmentDoc(annotations)).examples(segmentExamples(codec, annotations)) + case SegmentCodec.Literal(_) => throw new IllegalArgumentException("Literal segment is not supported.") + case SegmentCodec.Empty => throw new IllegalArgumentException("Empty segment is not supported.") + case SegmentCodec.Trailing => throw new IllegalArgumentException("Trailing segment is not supported.") + } + + private def segmentDoc(annotations: Chunk[SegmentCodec.MetaData[_]]) = + annotations.collect { case SegmentCodec.MetaData.Documented(doc) => doc }.reduceOption(_ + _).map(_.toCommonMark) + + private def segmentExamples(codec: SegmentCodec[_], annotations: Chunk[SegmentCodec.MetaData[_]]) = + Chunk.fromIterable( + annotations.collect { case SegmentCodec.MetaData.Examples(example) => example.values }.flatten.map { value => + codec match { + case SegmentCodec.Empty => throw new IllegalArgumentException("Empty segment is not supported.") + case SegmentCodec.Literal(_) => throw new IllegalArgumentException("Literal segment is not supported.") + case SegmentCodec.BoolSeg(_) => Json.Bool(value.asInstanceOf[Boolean]) + case SegmentCodec.IntSeg(_) => Json.Num(value.asInstanceOf[Int]) + case SegmentCodec.LongSeg(_) => Json.Num(value.asInstanceOf[Long]) + case SegmentCodec.Text(_) => Json.Str(value.asInstanceOf[java.lang.String]) + case SegmentCodec.UUID(_) => Json.Str(value.asInstanceOf[java.util.UUID].toString) + case SegmentCodec.Trailing => + throw new IllegalArgumentException("Trailing segment is not supported.") + case SegmentCodec.Annotated(_, _) => + throw new IllegalStateException("Annotated SegmentCodec should never be nested.") + } + }, + ) + + def fromZSchemaMulti(schema: Schema[_], refType: SchemaStyle = SchemaStyle.Inline): JsonSchemas = { + val ref = nominal(schema, refType) + schema match { + case enum0: Schema.Enum[_] if enum0.cases.forall(_.schema.isInstanceOf[CaseClass0[_]]) => + JsonSchemas(fromZSchema(enum0, SchemaStyle.Inline), ref, Map.empty) + case enum0: Schema.Enum[_] => + JsonSchemas( + fromZSchema(enum0, SchemaStyle.Inline), + ref, + enum0.cases + .filterNot(_.annotations.exists(_.isInstanceOf[transientCase])) + .flatMap { c => + val key = + nominal(c.schema, refType) + .orElse(nominal(c.schema, SchemaStyle.Compact)) + .getOrElse(throw new Exception(s"Unsupported enum case schema: ${c.schema}")) + val nested = fromZSchemaMulti( + c.schema, + refType, + ) + nested.children + (key -> nested.root) + } + .toMap, + ) + case record: Schema.Record[_] => + val children = record.fields + .filterNot(_.annotations.exists(_.isInstanceOf[transientField])) + .flatMap { field => + val nested = fromZSchemaMulti( + field.schema, + refType, + ) + nested.rootRef.map(k => nested.children + (k -> nested.root)).getOrElse(nested.children) + } + .toMap + JsonSchemas(fromZSchema(record, SchemaStyle.Inline), ref, children) + case collection: Schema.Collection[_, _] => + collection match { + case Schema.Sequence(elementSchema, _, _, _, _) => + arraySchemaMulti(refType, ref, elementSchema) + case Schema.Map(_, valueSchema, _) => + val nested = fromZSchemaMulti(valueSchema, refType) + if (valueSchema.isInstanceOf[Schema.Primitive[_]]) { + JsonSchemas( + JsonSchema.Object( + Map.empty, + Right(nested.root), + Chunk.empty, + ), + ref, + nested.children, + ) + } else { + JsonSchemas( + JsonSchema.Object( + Map.empty, + Right(nested.root), + Chunk.empty, + ), + ref, + nested.children + (nested.rootRef.get -> nested.root), + ) + } + case Schema.Set(elementSchema, _) => + arraySchemaMulti(refType, ref, elementSchema) + } + case Schema.Transform(schema, _, _, _, _) => + fromZSchemaMulti(schema, refType) + case Schema.Primitive(_, _) => + JsonSchemas(fromZSchema(schema, SchemaStyle.Inline), ref, Map.empty) + case Schema.Optional(schema, _) => + fromZSchemaMulti(schema, refType) + case Schema.Fail(_, _) => + throw new IllegalArgumentException("Fail schema is not supported.") + case Schema.Tuple2(left, right, _) => + val leftSchema = fromZSchemaMulti(left, refType) + val rightSchema = fromZSchemaMulti(right, refType) + JsonSchemas( + AllOfSchema(Chunk(leftSchema.root, rightSchema.root)), + ref, + leftSchema.children ++ rightSchema.children, + ) + case Schema.Either(left, right, _) => + val leftSchema = fromZSchemaMulti(left, refType) + val rightSchema = fromZSchemaMulti(right, refType) + JsonSchemas( + OneOfSchema(Chunk(leftSchema.root, rightSchema.root)), + ref, + leftSchema.children ++ rightSchema.children, + ) + case Schema.Lazy(schema0) => + fromZSchemaMulti(schema0(), refType) + case Schema.Dynamic(_) => + JsonSchemas(AnyJson, None, Map.empty) + } + } + + private def arraySchemaMulti( + refType: SchemaStyle, + ref: Option[java.lang.String], + elementSchema: Schema[_], + ): JsonSchemas = { + val nested = fromZSchemaMulti(elementSchema, refType) + if (elementSchema.isInstanceOf[Schema.Primitive[_]]) { + JsonSchemas( + JsonSchema.ArrayType(Some(nested.root)), + ref, + nested.children, + ) + } else { + JsonSchemas( + JsonSchema.ArrayType(Some(nested.root)), + ref, + nested.children ++ (nested.rootRef.map(_ -> nested.root)), + ) + } + } + + def fromZSchema(schema: Schema[_], refType: SchemaStyle = SchemaStyle.Inline): JsonSchema = + schema match { + case enum0: Schema.Enum[_] if refType != SchemaStyle.Inline && nominal(enum0).isDefined => + JsonSchema.RefSchema(nominal(enum0, refType).get) + case enum0: Schema.Enum[_] if enum0.cases.forall(_.schema.isInstanceOf[CaseClass0[_]]) => + JsonSchema.Enum( + enum0.cases.map(c => + EnumValue.Str(c.annotations.collectFirst { case caseName(name) => name }.getOrElse(c.id)), + ), + ) + case enum0: Schema.Enum[_] => + val noDiscriminator = enum0.annotations.exists(_.isInstanceOf[noDiscriminator]) + val discriminatorName0 = + enum0.annotations.collectFirst { case discriminatorName(name) => name } + val nonTransientCases = enum0.cases.filterNot(_.annotations.exists(_.isInstanceOf[transientCase])) + if (noDiscriminator) { + JsonSchema + .OneOfSchema(nonTransientCases.map(c => fromZSchema(c.schema, SchemaStyle.Compact))) + } else if (discriminatorName0.isDefined) { + JsonSchema + .OneOfSchema(nonTransientCases.map(c => fromZSchema(c.schema, SchemaStyle.Compact))) + .discriminator( + OpenAPI.Discriminator( + propertyName = discriminatorName0.get, + mapping = nonTransientCases.map { c => + val name = c.annotations.collectFirst { case caseName(name) => name }.getOrElse(c.id) + name -> nominal(c.schema, refType).orElse(nominal(c.schema, SchemaStyle.Compact)).get + }.toMap, + ), + ) + } else { + JsonSchema + .OneOfSchema(nonTransientCases.map { c => + val name = c.annotations.collectFirst { case caseName(name) => name }.getOrElse(c.id) + Object(Map(name -> fromZSchema(c.schema, SchemaStyle.Compact)), Left(false), Chunk(name)) + }) + } + case record: Schema.Record[_] if refType != SchemaStyle.Inline && nominal(record).isDefined => + JsonSchema.RefSchema(nominal(record, refType).get) + case record: Schema.Record[_] => + val additionalProperties = + if (record.annotations.exists(_.isInstanceOf[rejectExtraFields])) { + Left(false) + } else { + Left(true) + } + val nonTransientFields = + record.fields.filterNot(_.annotations.exists(_.isInstanceOf[transientField])) + JsonSchema + .Object( + Map.empty, + additionalProperties, + Chunk.empty, + ) + .addAll(nonTransientFields.map { field => + field.name -> + fromZSchema(field.schema, SchemaStyle.Compact) + .deprecated(deprecated(field.schema)) + .description(fieldDoc(field)) + .default(fieldDefault(field)) + }) + .required( + nonTransientFields + .filterNot(_.schema.isInstanceOf[Schema.Optional[_]]) + .filterNot(_.annotations.exists(_.isInstanceOf[fieldDefaultValue[_]])) + .filterNot(_.annotations.exists(_.isInstanceOf[optionalField])) + .map(_.name), + ) + .deprecated(deprecated(record)) + case collection: Schema.Collection[_, _] => + collection match { + case Schema.Sequence(elementSchema, _, _, _, _) => + JsonSchema.ArrayType(Some(fromZSchema(elementSchema, refType))) + case Schema.Map(_, valueSchema, _) => + JsonSchema.Object( + Map.empty, + Right(fromZSchema(valueSchema, refType)), + Chunk.empty, + ) + case Schema.Set(elementSchema, _) => + JsonSchema.ArrayType(Some(fromZSchema(elementSchema, refType))) + } + case Schema.Transform(schema, _, _, _, _) => + fromZSchema(schema, refType) + case Schema.Primitive(standardType, _) => + standardType match { + case StandardType.UnitType => JsonSchema.Null + case StandardType.StringType => JsonSchema.String() + case StandardType.BoolType => JsonSchema.Boolean + case StandardType.ByteType => JsonSchema.String() + case StandardType.ShortType => JsonSchema.Integer(IntegerFormat.Int32) + case StandardType.IntType => JsonSchema.Integer(IntegerFormat.Int32) + case StandardType.LongType => JsonSchema.Integer(IntegerFormat.Int64) + case StandardType.FloatType => JsonSchema.Number(NumberFormat.Float) + case StandardType.DoubleType => JsonSchema.Number(NumberFormat.Double) + case StandardType.BinaryType => JsonSchema.String() + case StandardType.CharType => JsonSchema.String() + case StandardType.UUIDType => JsonSchema.String(StringFormat.UUID) + case StandardType.BigDecimalType => JsonSchema.Number(NumberFormat.Double) // TODO: Is this correct? + case StandardType.BigIntegerType => JsonSchema.Integer(IntegerFormat.Int64) + case StandardType.DayOfWeekType => JsonSchema.String() + case StandardType.MonthType => JsonSchema.String() + case StandardType.MonthDayType => JsonSchema.String() + case StandardType.PeriodType => JsonSchema.String() + case StandardType.YearType => JsonSchema.String() + case StandardType.YearMonthType => JsonSchema.String() + case StandardType.ZoneIdType => JsonSchema.String() + case StandardType.ZoneOffsetType => JsonSchema.String() + case StandardType.DurationType => JsonSchema.String(StringFormat.Duration) + case StandardType.InstantType => JsonSchema.String() + case StandardType.LocalDateType => JsonSchema.String() + case StandardType.LocalTimeType => JsonSchema.String() + case StandardType.LocalDateTimeType => JsonSchema.String() + case StandardType.OffsetTimeType => JsonSchema.String() + case StandardType.OffsetDateTimeType => JsonSchema.String() + case StandardType.ZonedDateTimeType => JsonSchema.String() + } + + case Schema.Optional(schema, _) => fromZSchema(schema, refType).nullable(true) + case Schema.Fail(_, _) => throw new IllegalArgumentException("Fail schema is not supported.") + case Schema.Tuple2(left, right, _) => AllOfSchema(Chunk(fromZSchema(left, refType), fromZSchema(right, refType))) + case Schema.Either(left, right, _) => OneOfSchema(Chunk(fromZSchema(left, refType), fromZSchema(right, refType))) + case Schema.Lazy(schema0) => fromZSchema(schema0(), refType) + case Schema.Dynamic(_) => AnyJson + + } + + sealed trait SchemaStyle extends Product with Serializable + object SchemaStyle { + + /** Generates inline json schema */ + case object Inline extends SchemaStyle + + /** + * Generates references to json schemas under #/components/schemas/{schema} + * and uses the full package path to help to generate unique schema names. + * @see + * SchemaStyle.Compact for compact schema names. + */ + case object Reference extends SchemaStyle + + /** + * Generates references to json schemas under #/components/schemas/{schema} + * and uses the type name to help to generate schema names. + * @see + * SchemaStyle.Reference for full package path schema names to avoid name + * collisions. + */ + case object Compact extends SchemaStyle + } + + private def deprecated(schema: Schema[_]): Boolean = + schema.annotations.exists(_.isInstanceOf[scala.deprecated]) + + private def fieldDoc(schema: Schema.Field[_, _]): Option[java.lang.String] = { + val description0 = schema.annotations.collectFirst { case description(value) => value } + val defaultValue = schema.annotations.collectFirst { case fieldDefaultValue(value) => value }.map { _ => + s"${if (description0.isDefined) "\n" else ""}If not set, this field defaults to the value of the default annotation." + } + Some(description0.getOrElse("") + defaultValue.getOrElse("")) + .filter(_.nonEmpty) + } + + private def fieldDefault(schema: Schema.Field[_, _]): Option[Json] = + schema.annotations.collectFirst { case fieldDefaultValue(value) => value } + .map(toJsonAst(schema.schema, _)) + + private def nominal(schema: Schema[_], referenceType: SchemaStyle = SchemaStyle.Reference): Option[java.lang.String] = + schema match { + case enumSchema: Schema.Enum[_] => refForTypeId(enumSchema.id, referenceType) + case record: Schema.Record[_] => refForTypeId(record.id, referenceType) + case _ => None + } + + private def refForTypeId(id: TypeId, referenceType: SchemaStyle): Option[java.lang.String] = + id match { + case nominal: TypeId.Nominal if referenceType == SchemaStyle.Reference => + Some(s"#/components/schemas/${nominal.fullyQualified.replace(".", "_")}") + case nominal: TypeId.Nominal if referenceType == SchemaStyle.Compact => + Some(s"#/components/schemas/${nominal.typeName}") + case _ => + None + } + + def obj(properties: (java.lang.String, JsonSchema)*): JsonSchema = + JsonSchema.Object( + properties = properties.toMap, + additionalProperties = Left(false), + required = Chunk.fromIterable(properties.toMap.keys), + ) + + final case class AnnotatedSchema(schema: JsonSchema, annotation: MetaData) extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = { + annotation match { + case MetaData.Examples(chunk) => + schema.toSerializableSchema.copy(examples = Some(chunk)) + case MetaData.Discriminator(discriminator) => + schema.toSerializableSchema.copy(discriminator = Some(discriminator)) + case MetaData.Nullable(nullable) => + schema.toSerializableSchema.asNullableType(nullable) + case MetaData.Description(description) => + schema.toSerializableSchema.copy(description = Some(description)) + case MetaData.ContentEncoding(encoding) => + schema.toSerializableSchema.copy(contentEncoding = Some(encoding.productPrefix.toLowerCase)) + case MetaData.ContentMediaType(mediaType) => + schema.toSerializableSchema.copy(contentMediaType = Some(mediaType)) + case MetaData.Deprecated => + schema.toSerializableSchema.copy(deprecated = Some(true)) + case MetaData.Default(default) => + schema.toSerializableSchema.copy(default = Some(default)) + } + } + } + + sealed trait MetaData extends Product with Serializable + object MetaData { + final case class Examples(chunk: Chunk[Json]) extends MetaData + final case class Default(default: Json) extends MetaData + final case class Discriminator(discriminator: OpenAPI.Discriminator) extends MetaData + final case class Nullable(nullable: Boolean) extends MetaData + final case class Description(description: java.lang.String) extends MetaData + final case class ContentEncoding(encoding: JsonSchema.ContentEncoding) extends MetaData + final case class ContentMediaType(mediaType: java.lang.String) extends MetaData + case object Deprecated extends MetaData + } + + sealed trait ContentEncoding extends Product with Serializable + object ContentEncoding { + case object SevenBit extends ContentEncoding + case object EightBit extends ContentEncoding + case object Binary extends ContentEncoding + case object QuotedPrintable extends ContentEncoding + case object Base16 extends ContentEncoding + case object Base32 extends ContentEncoding + case object Base64 extends ContentEncoding + + def fromString(string: java.lang.String): Option[ContentEncoding] = + string.toLowerCase match { + case "7bit" => Some(SevenBit) + case "8bit" => Some(EightBit) + case "binary" => Some(Binary) + case "quoted-print" => Some(QuotedPrintable) + case "base16" => Some(Base16) + case "base32" => Some(Base32) + case "base64" => Some(Base64) + case _ => None + } + } + + final case class RefSchema(ref: java.lang.String) extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema(ref = Some(ref)) + } + + final case class OneOfSchema(oneOf: Chunk[JsonSchema]) extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema( + oneOf = Some(oneOf.map(_.toSerializableSchema)), + ) + } + + final case class AllOfSchema(allOf: Chunk[JsonSchema]) extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema( + allOf = Some(allOf.map(_.toSerializableSchema)), + ) + } + + final case class AnyOfSchema(anyOf: Chunk[JsonSchema]) extends JsonSchema { + def minify: JsonSchema = { + val (objects, others) = anyOf.distinct.span(_.withoutAnnotations.isInstanceOf[JsonSchema.Object]) + val markedForRemoval = (for { + obj <- objects + otherObj <- objects + notNullableSchemas = obj.withoutAnnotations.asInstanceOf[JsonSchema.Object].properties.collect { + case (name, schema) + if !schema.annotations.exists { case MetaData.Nullable(nullable) => nullable; case _ => false } => + name -> schema + } + if notNullableSchemas == otherObj.withoutAnnotations.asInstanceOf[JsonSchema.Object].properties + } yield otherObj).distinct + + val minified = objects.filterNot(markedForRemoval.contains).map { obj => + val annotations = obj.annotations + val asObject = obj.withoutAnnotations.asInstanceOf[JsonSchema.Object] + val notNullableSchemas = asObject.properties.collect { + case (name, schema) + if !schema.annotations.exists { case MetaData.Nullable(nullable) => nullable; case _ => false } => + name -> schema + } + asObject.required(asObject.required.filter(notNullableSchemas.contains)).annotate(annotations) + } + val newAnyOf = minified ++ others + + if (newAnyOf.size == 1) newAnyOf.head else AnyOfSchema(newAnyOf) + } + + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema( + anyOf = Some(anyOf.map(_.toSerializableSchema)), + ) + } + + final case class Number(format: NumberFormat) extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema( + schemaType = Some(TypeOrTypes.Type("number")), + format = Some(format.productPrefix.toLowerCase), + ) + } + + sealed trait NumberFormat extends Product with Serializable + object NumberFormat { + + def fromString(string: java.lang.String): NumberFormat = + string match { + case "float" => Float + case "double" => Double + case _ => throw new IllegalArgumentException(s"Unknown number format: $string") + } + case object Float extends NumberFormat + case object Double extends NumberFormat + + } + + final case class Integer(format: IntegerFormat) extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema( + schemaType = Some(TypeOrTypes.Type("integer")), + format = Some(format.productPrefix.toLowerCase), + ) + } + + sealed trait IntegerFormat extends Product with Serializable + object IntegerFormat { + + def fromString(string: java.lang.String): IntegerFormat = + string match { + case "int32" => Int32 + case "int64" => Int64 + case "timestamp" => Timestamp + case _ => throw new IllegalArgumentException(s"Unknown integer format: $string") + } + case object Int32 extends IntegerFormat + case object Int64 extends IntegerFormat + case object Timestamp extends IntegerFormat + } + + // TODO: Add string formats and patterns + final case class String(format: Option[StringFormat], pattern: Option[Pattern]) extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema( + schemaType = Some(TypeOrTypes.Type("string")), + format = format.map(_.value), + pattern = pattern.map(_.value), + ) + } + + object String { + def apply(format: StringFormat): String = String(Some(format), None) + def apply(pattern: Pattern): String = String(None, Some(pattern)) + def apply(): String = String(None, None) + } + + sealed trait StringFormat extends Product with Serializable { + def value: java.lang.String + } + + object StringFormat { + case class Custom(value: java.lang.String) extends StringFormat + case object Date extends StringFormat { val value = "date" } + case object DateTime extends StringFormat { val value = "date-time" } + case object Duration extends StringFormat { val value = "duration" } + case object Email extends StringFormat { val value = "email" } + case object Hostname extends StringFormat { val value = "hostname" } + case object IdnEmail extends StringFormat { val value = "idn-email" } + case object IdnHostname extends StringFormat { val value = "idn-hostname" } + case object IPv4 extends StringFormat { val value = "ipv4" } + case object IPv6 extends StringFormat { val value = "ipv6" } + case object IRI extends StringFormat { val value = "iri" } + case object IRIReference extends StringFormat { val value = "iri-reference" } + case object JSONPointer extends StringFormat { val value = "json-pointer" } + case object Password extends StringFormat { val value = "password" } + case object Regex extends StringFormat { val value = "regex" } + case object RelativeJSONPointer extends StringFormat { val value = "relative-json-pointer" } + case object Time extends StringFormat { val value = "time" } + case object URI extends StringFormat { val value = "uri" } + case object URIRef extends StringFormat { val value = "uri-reference" } + case object URITemplate extends StringFormat { val value = "uri-template" } + case object UUID extends StringFormat { val value = "uuid" } + + def fromString(string: java.lang.String): StringFormat = + string match { + case "date" => Date + case "date-time" => DateTime + case "duration" => Duration + case "email" => Email + case "hostname" => Hostname + case "idn-email" => IdnEmail + case "idn-hostname" => IdnHostname + case "ipv4" => IPv4 + case "ipv6" => IPv6 + case "iri" => IRI + case "iri-reference" => IRIReference + case "json-pointer" => JSONPointer + case "password" => Password + case "regex" => Regex + case "relative-json-pointer" => RelativeJSONPointer + case "time" => Time + case "uri" => URI + case "uri-reference" => URIRef + case "uri-template" => URITemplate + case "uuid" => UUID + case value => Custom(value) + } + } + + final case class Pattern(value: java.lang.String) extends Product with Serializable + + case object Boolean extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema(schemaType = Some(TypeOrTypes.Type("boolean"))) + } + + final case class ArrayType(items: Option[JsonSchema]) extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema( + schemaType = Some(TypeOrTypes.Type("array")), + items = items.map(_.toSerializableSchema), + ) + } + + final case class Object( + properties: Map[java.lang.String, JsonSchema], + additionalProperties: Either[Boolean, JsonSchema], + required: Chunk[java.lang.String], + ) extends JsonSchema { + def addAll(value: Chunk[(java.lang.String, JsonSchema)]): Object = + value.foldLeft(this) { case (obj, (name, schema)) => + schema match { + case Object(properties, additionalProperties, required) => + obj.copy( + properties = obj.properties ++ properties, + additionalProperties = combineAdditionalProperties(obj.additionalProperties, additionalProperties), + required = obj.required ++ required, + ) + case schema => obj.copy(properties = obj.properties + (name -> schema)) + } + } + + def required(required: Chunk[java.lang.String]): Object = + this.copy(required = required) + + private def combineAdditionalProperties( + left: Either[Boolean, JsonSchema], + right: Either[Boolean, JsonSchema], + ): Either[Boolean, JsonSchema] = + (left, right) match { + case (Left(false), _) => Left(false) + case (_, Left(_)) => left + case (Left(true), _) => right + case (Right(left), Right(right)) => + Right(AllOfSchema(Chunk(left, right))) + } + + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = { + val additionalProperties = this.additionalProperties match { + case Left(true) => Some(BoolOrSchema.BooleanWrapper(true)) + case Left(false) => Some(BoolOrSchema.BooleanWrapper(false)) + case Right(schema) => Some(BoolOrSchema.SchemaWrapper(schema.toSerializableSchema)) + } + SerializableJsonSchema( + schemaType = Some(TypeOrTypes.Type("object")), + properties = Some(properties.map { case (name, schema) => name -> schema.toSerializableSchema }), + additionalProperties = additionalProperties, + required = if (required.isEmpty) None else Some(required), + ) + } + } + + object Object { + val empty: JsonSchema.Object = JsonSchema.Object(Map.empty, Left(true), Chunk.empty) + } + + final case class Enum(values: Chunk[EnumValue]) extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema( + schemaType = Some(TypeOrTypes.Type("string")), + enumValues = Some(values.map(_.toJson)), + ) + } + + @noDiscriminator + sealed trait EnumValue { self => + def toJson: Json = self match { + case EnumValue.Bool(value) => Json.Bool(value) + case EnumValue.Str(value) => Json.Str(value) + case EnumValue.Num(value) => Json.Num(value) + case EnumValue.Null => Json.Null + case EnumValue.SchemaValue(value) => + Json.decoder + .decodeJson(value.toJson) + .getOrElse(throw new IllegalArgumentException(s"Can't convert $self")) + } + } + + object EnumValue { + + def fromJson(json: Json): EnumValue = + json match { + case Json.Str(value) => Str(value) + case Json.Num(value) => Num(value) + case Json.Bool(value) => Bool(value) + case Json.Null => Null + case other => + SchemaValue( + JsonSchema.codec + .decode(Chunk.fromArray(other.toString().getBytes)) + .getOrElse(throw new IllegalArgumentException(s"Can't convert $json")), + ) + } + + final case class SchemaValue(value: JsonSchema) extends EnumValue + final case class Bool(value: Boolean) extends EnumValue + final case class Str(value: java.lang.String) extends EnumValue + final case class Num(value: BigDecimal) extends EnumValue + case object Null extends EnumValue + + } + + case object Null extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema( + schemaType = Some(TypeOrTypes.Type("null")), + ) + } + + case object AnyJson extends JsonSchema { + override protected[openapi] def toSerializableSchema: SerializableJsonSchema = + SerializableJsonSchema() + } + +} diff --git a/zio-http/src/main/scala/zio/http/endpoint/openapi/OpenAPI.scala b/zio-http/src/main/scala/zio/http/endpoint/openapi/OpenAPI.scala index 213ca93d2f..d1e964f03e 100644 --- a/zio-http/src/main/scala/zio/http/endpoint/openapi/OpenAPI.scala +++ b/zio-http/src/main/scala/zio/http/endpoint/openapi/OpenAPI.scala @@ -20,78 +20,232 @@ import java.net.URI import scala.util.matching.Regex -import zio.NonEmptyChunk -import zio.stacktracer.TracingImplicits.disableAutoTrace +import zio.Chunk +import zio.json.ast._ + +import zio.schema._ +import zio.schema.annotation.{fieldName, noDiscriminator} +import zio.schema.codec.JsonCodec +import zio.schema.codec.json._ import zio.http.Status import zio.http.codec.Doc -import zio.http.endpoint.openapi -import zio.http.endpoint.openapi.JsonRenderer._ +import zio.http.endpoint.openapi.OpenAPI.SecurityScheme.SecurityRequirement + +/** + * This is the root document object of the OpenAPI document. + * + * @param openapi + * This string MUST be the semantic version number of the OpenAPI + * Specification version that the OpenAPI document uses. The openapi field + * SHOULD be used by tooling specifications and clients to interpret the + * OpenAPI document. This is not related to the API info.version string. + * @param info + * Provides metadata about the API. The metadata MAY be used by tooling as + * required. + * @param servers + * A List of Server Objects, which provide connectivity information to a + * target server. If the servers property is empty, the default value would be + * a Server Object with a url value of /. + * @param paths + * The available paths and operations for the API. + * @param components + * An element to hold various schemas for the specification. + * @param security + * A declaration of which security mechanisms can be used across the API. The + * list of values includes alternative security requirement objects that can + * be used. Only one of the security requirement objects need to be satisfied + * to authorize a request. Individual operations can override this definition. + * To make security optional, an empty security requirement ({}) can be + * included in the List. + * @param tags + * A list of tags used by the specification with additional metadata. The + * order of the tags can be used to reflect on their order by the parsing + * tools. Not all tags that are used by the Operation Object must be declared. + * The tags that are not declared MAY be organized randomly or based on the + * tools’ logic. Each tag name in the list MUST be unique. + * @param externalDocs + * Additional external documentation. + */ +final case class OpenAPI( + openapi: String, + info: OpenAPI.Info, + servers: List[OpenAPI.Server] = List.empty, + paths: Map[OpenAPI.Path, OpenAPI.PathItem] = Map.empty, + components: Option[OpenAPI.Components], + security: List[SecurityRequirement] = List.empty, + tags: List[OpenAPI.Tag] = List.empty, + externalDocs: Option[OpenAPI.ExternalDoc], +) { + def ++(other: OpenAPI): OpenAPI = OpenAPI( + openapi = openapi, + info = info, + servers = servers ++ other.servers, + paths = mergePaths(paths, other.paths), + components = (components.toSeq ++ other.components).reduceOption(_ ++ _), + security = security ++ other.security, + tags = tags ++ other.tags, + externalDocs = externalDocs, + ) + + private def mergePaths(paths: Map[OpenAPI.Path, OpenAPI.PathItem]*): Map[OpenAPI.Path, OpenAPI.PathItem] = + paths + .foldRight[Seq[(OpenAPI.Path, OpenAPI.PathItem)]](Seq.empty)((z, p) => z.toSeq ++ p) + .groupBy(_._1) + .map { case (path, pathItems) => + val pathItem = pathItems.map(_._2).reduce { (i, j) => + i.copy( + get = i.get.orElse(j.get), + put = i.put.orElse(j.put), + post = i.post.orElse(j.post), + delete = i.delete.orElse(j.delete), + options = i.options.orElse(j.options), + head = i.head.orElse(j.head), + patch = i.patch.orElse(j.patch), + trace = i.trace.orElse(j.trace), + ) + } + (path, pathItem) + } + + def path(path: OpenAPI.Path, pathItem: OpenAPI.PathItem): OpenAPI = + copy(paths = mergePaths(Map(path -> pathItem), paths)) -private[openapi] sealed trait OpenAPIBase { - self => - def toJson: String + def toJson: String = + JsonCodec + .jsonEncoder(JsonCodec.Config(ignoreEmptyCollections = true))(OpenAPI.schema) + .encodeJson(this, None) + .toString + + def toJsonPretty: String = + JsonCodec + .jsonEncoder(JsonCodec.Config(ignoreEmptyCollections = true))(OpenAPI.schema) + .encodeJson(this, Some(0)) + .toString + + def title(title: String): OpenAPI = copy(info = info.copy(title = title)) + + def version(version: String): OpenAPI = copy(info = info.copy(version = version)) } object OpenAPI { - /** - * This is the root document object of the OpenAPI document. - * - * @param openapi - * This string MUST be the semantic version number of the OpenAPI - * Specification version that the OpenAPI document uses. The openapi field - * SHOULD be used by tooling specifications and clients to interpret the - * OpenAPI document. This is not related to the API info.version string. - * @param info - * Provides metadata about the API. The metadata MAY be used by tooling as - * required. - * @param servers - * A List of Server Objects, which provide connectivity information to a - * target server. If the servers property is empty, the default value would - * be a Server Object with a url value of /. - * @param paths - * The available paths and operations for the API. - * @param components - * An element to hold various schemas for the specification. - * @param security - * A declaration of which security mechanisms can be used across the API. - * The list of values includes alternative security requirement objects that - * can be used. Only one of the security requirement objects need to be - * satisfied to authorize a request. Individual operations can override this - * definition. To make security optional, an empty security requirement ({}) - * can be included in the List. - * @param tags - * A list of tags used by the specification with additional metadata. The - * order of the tags can be used to reflect on their order by the parsing - * tools. Not all tags that are used by the Operation Object must be - * declared. The tags that are not declared MAY be organized randomly or - * based on the tools’ logic. Each tag name in the list MUST be unique. - * @param externalDocs - * Additional external documentation. - */ - final case class OpenAPI( - openapi: String, - info: Info, - servers: List[Server], - paths: Paths, - components: Option[Components], - security: List[SecurityRequirement], - tags: List[Tag], - externalDocs: Option[ExternalDoc], - ) extends OpenAPIBase { - def toJson: String = - JsonRenderer.renderFields( - "openapi" -> openapi, - "info" -> info, - "servers" -> servers, - "paths" -> paths, - "components" -> components, - "security" -> security, - "tags" -> tags, - "externalDocs" -> externalDocs, + implicit val schema: Schema[OpenAPI] = + DeriveSchema.gen[OpenAPI] + + def fromJson(json: String): Either[String, OpenAPI] = + JsonCodec + .jsonDecoder(OpenAPI.schema) + .decodeJson(json) + + def empty: OpenAPI = OpenAPI( + openapi = "3.1.0", + info = Info( + title = "", + description = None, + termsOfService = None, + contact = None, + license = None, + version = "", + ), + servers = List.empty, + paths = Map.empty, + components = None, + security = List.empty, + tags = List.empty, + externalDocs = None, + ) + + implicit def statusSchema: Schema[Status] = + zio.schema + .Schema[String] + .transform[Status]( + s => Status.fromInt(s.toInt), + p => p.text, + ) + + implicit def pathMapSchema: Schema[Map[Path, PathItem]] = + DeriveSchema + .gen[Map[String, PathItem]] + .transformOrFail( + m => { + val it = m.iterator + var transformed = Map.empty[Path, PathItem] + var error: Left[String, Map[Path, PathItem]] = null + while (it.hasNext && error == null) { + val (k, v) = it.next() + Path.fromString(k) match { + case Some(path) => transformed += path -> v + case None => error = Left(s"Invalid path: $k") + } + } + if (error != null) error + else Right(transformed) + }, + (m: Map[Path, PathItem]) => Right(m.map { case (k, v) => k.name -> v }), + ) + + implicit def keyMapSchema[T](implicit + schema: Schema[T], + ): Schema[Map[Key, T]] = + Schema + .map[String, T] + .transformOrFail( + m => { + val it = m.iterator + var transformed = Map.empty[Key, T] + var error: Left[String, Map[Key, T]] = null + while (it.hasNext && error == null) { + val (k, v) = it.next() + Key.fromString(k) match { + case Some(key) => transformed += key -> v + case None => error = Left(s"Invalid key: $k") + } + } + if (error != null) error + else Right(transformed) + }, + (m: Map[Key, T]) => Right(m.map { case (k, v) => k.name -> v }), + ) + + implicit def statusMapSchema[T](implicit + schema: Schema[T], + ): Schema[Map[StatusOrDefault, T]] = + Schema + .map[String, T] + .transformOrFail( + m => { + val it = m.iterator + var transformed = Map.empty[StatusOrDefault, T] + var error: Left[String, Map[StatusOrDefault, T]] = null + while (it.hasNext && error == null) { + val (k, v) = it.next() + if (k == "default") transformed += StatusOrDefault.Default -> v + else { + zio.http.Status.fromString(k) match { + case Some(key) => transformed += StatusOrDefault.StatusValue(key) -> v + case None => error = Left(s"Invalid status: $k") + } + } + } + if (error != null) error + else Right(transformed) + }, + (m: Map[StatusOrDefault, T]) => Right(m.map { case (k, v) => k.text -> v }), + ) + + implicit def mediaTypeTupleSchema: Schema[(String, MediaType)] = + zio.schema + .Schema[Map[String, MediaType]] + .transformOrFail( + m => { + if (m.size == 1) { + val (k, v) = m.head + Right((k, v)) + } else Left("Invalid media type") + }, + t => Right(Map(t._1 -> t._2)), ) - } /** * Allows referencing an external resource for extended documentation. @@ -102,9 +256,7 @@ object OpenAPI { * @param url * The URL for the target documentation. */ - final case class ExternalDoc(description: Option[Doc], url: URI) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields("description" -> description, "url" -> url) - } + final case class ExternalDoc(description: Option[Doc], url: URI) /** * The object provides metadata about the API. The metadata MAY be used by the @@ -127,21 +279,12 @@ object OpenAPI { */ final case class Info( title: String, - description: Doc, - termsOfService: URI, + description: Option[Doc], + termsOfService: Option[URI], contact: Option[Contact], license: Option[License], version: String, - ) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "title" -> title, - "description" -> description, - "termsOfService" -> termsOfService, - "contact" -> contact, - "license" -> license, - "version" -> version, - ) - } + ) /** * Contact information for the exposed API. @@ -154,9 +297,7 @@ object OpenAPI { * The email address of the contact person/organization. MUST be in the * format of an email address. */ - final case class Contact(name: Option[String], url: Option[URI], email: String) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields("name" -> name, "url" -> url, "email" -> email) - } + final case class Contact(name: Option[String], url: Option[URI], email: Option[String]) /** * License information for the exposed API. @@ -166,9 +307,7 @@ object OpenAPI { * @param url * A URL to the license used for the API. */ - final case class License(name: String, url: Option[URI]) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields("name" -> name, "url" -> url) - } + final case class License(name: String, url: Option[URI]) /** * An object representing a Server. @@ -184,14 +323,11 @@ object OpenAPI { * A map between a variable name and its value. The value is used for * substitution in the server’s URL template. */ - final case class Server(url: URI, description: Doc, variables: Map[String, ServerVariable]) - extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "url" -> url, - "description" -> description, - "variables" -> variables, - ) - } + final case class Server( + url: URI, + description: Option[Doc], + variables: Map[String, ServerVariable] = Map.empty, + ) /** * An object representing a Server Variable for server URL template @@ -209,13 +345,11 @@ object OpenAPI { * @param description * A description for the server variable. */ - final case class ServerVariable(`enum`: NonEmptyChunk[String], default: String, description: Doc) - extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "enum" -> `enum`, - "default" -> default, - "description" -> description, - ) + final case class ServerVariable(`enum`: Chunk[String], default: String, description: Doc) + + object ServerVariable { + implicit val schema: Schema[ServerVariable] = + DeriveSchema.gen[ServerVariable] } /** @@ -244,39 +378,45 @@ object OpenAPI { * An object to hold reusable Callback Objects. */ final case class Components( - schemas: Map[Key, SchemaOrReference], - responses: Map[Key, ResponseOrReference], - parameters: Map[Key, ParameterOrReference], - examples: Map[Key, ExampleOrReference], - requestBodies: Map[Key, RequestBodyOrReference], - headers: Map[Key, HeaderOrReference], - securitySchemes: Map[Key, SecuritySchemeOrReference], - links: Map[Key, LinkOrReference], - callbacks: Map[Key, CallbackOrReference], - ) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "schemas" -> schemas, - "responses" -> responses, - "parameters" -> parameters, - "examples" -> examples, - "requestBodies" -> requestBodies, - "headers" -> headers, - "securitySchemes" -> securitySchemes, - "links" -> links, - "callbacks" -> callbacks, + schemas: Map[Key, ReferenceOr[JsonSchema]] = Map.empty, + responses: Map[Key, ReferenceOr[Response]] = Map.empty, + parameters: Map[Key, ReferenceOr[Parameter]] = Map.empty, + examples: Map[Key, ReferenceOr[Example]] = Map.empty, + requestBodies: Map[Key, ReferenceOr[RequestBody]] = Map.empty, + headers: Map[Key, ReferenceOr[Header]] = Map.empty, + securitySchemes: Map[Key, ReferenceOr[SecurityScheme]] = Map.empty, + links: Map[Key, ReferenceOr[Link]] = Map.empty, + callbacks: Map[Key, ReferenceOr[Callback]] = Map.empty, + ) { + def ++(other: Components): Components = Components( + schemas = schemas ++ other.schemas, + responses = responses ++ other.responses, + parameters = parameters ++ other.parameters, + examples = examples ++ other.examples, + requestBodies = requestBodies ++ other.requestBodies, + headers = headers ++ other.headers, + securitySchemes = securitySchemes ++ other.securitySchemes, + links = links ++ other.links, + callbacks = callbacks ++ other.callbacks, ) } - sealed abstract case class Key private (name: String) extends openapi.OpenAPIBase { - override def toJson: String = name - } + sealed abstract case class Key private (name: String) object Key { + implicit val schema: Schema[Key] = + zio.schema + .Schema[String] + .transformOrFail[Key]( + s => fromString(s).toRight(s"Invalid Key $s"), + p => Right(p.name), + ) + /** * All Components objects MUST use Keys that match the regular expression. */ - val validName: Regex = "^[a-zA-Z0-9.\\-_]+$.".r + val validName: Regex = "^[a-zA-Z0-9.\\-_]+$".r def fromString(name: String): Option[Key] = name match { case validName() => Some(new Key(name) {}) @@ -303,16 +443,19 @@ object OpenAPI { * @param name * The field name of the relative path MUST begin with a forward slash (/). */ - sealed abstract case class Path private (name: String) extends openapi.OpenAPIBase { - override def toJson: String = name - } + case class Path private (name: String) object Path { + implicit val schema: Schema[Path] = Schema[String].transformOrFail[Path]( + s => fromString(s).toRight(s"Invalid Path $s"), + p => Right(p.name), + ) + // todo maybe not the best regex, but the old one was not working at all - val validPath: Regex = "/[a-zA-Z0-9\\-_\\{\\}]+".r + val validPath: Regex = """/[/a-zA-Z0-9\-_{}]*""".r def fromString(name: String): Option[Path] = name match { - case validPath() => Some(new Path(name) {}) + case validPath() => Some(Path(name)) case _ => None } } @@ -359,9 +502,9 @@ object OpenAPI { * components/parameters. */ final case class PathItem( - ref: String, - summary: String = "", - description: Doc, + @fieldName("$ref") ref: Option[String], + summary: Option[String], + description: Option[Doc], get: Option[Operation], put: Option[Operation], post: Option[Operation], @@ -370,23 +513,48 @@ object OpenAPI { head: Option[Operation], patch: Option[Operation], trace: Option[Operation], - servers: List[Server], - parameters: Set[ParameterOrReference], - ) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - s"$$ref" -> ref, - "summary" -> summary, - "description" -> description, - "get" -> get, - "put" -> put, - "post" -> post, - "delete" -> delete, - "options" -> options, - "head" -> head, - "patch" -> patch, - "trace" -> trace, - "servers" -> servers, - "parameters" -> parameters, + servers: List[Server] = List.empty, + parameters: Set[ReferenceOr[Parameter]] = Set.empty, + ) { + def addGet(operation: Operation): PathItem = copy(get = Some(operation)) + def addPut(operation: Operation): PathItem = copy(put = Some(operation)) + def addPost(operation: Operation): PathItem = copy(post = Some(operation)) + def addDelete(operation: Operation): PathItem = copy(delete = Some(operation)) + def addOptions(operation: Operation): PathItem = copy(options = Some(operation)) + def addHead(operation: Operation): PathItem = copy(head = Some(operation)) + def addPatch(operation: Operation): PathItem = copy(patch = Some(operation)) + def addTrace(operation: Operation): PathItem = copy(trace = Some(operation)) + def any(operation: Operation): PathItem = + copy( + get = Some(operation), + put = Some(operation), + post = Some(operation), + delete = Some(operation), + options = Some(operation), + head = Some(operation), + patch = Some(operation), + trace = Some(operation), + ) + } + + object PathItem { + implicit val schema: Schema[PathItem] = + DeriveSchema.gen[PathItem] + + val empty: PathItem = PathItem( + ref = None, + summary = None, + description = None, + get = None, + put = None, + post = None, + delete = None, + options = None, + head = None, + patch = None, + trace = None, + servers = List.empty, + parameters = Set.empty, ) } @@ -445,98 +613,61 @@ object OpenAPI { * be overridden by this value. */ final case class Operation( - tags: List[String], - summary: String = "", - description: Doc, + tags: List[String] = List.empty, + summary: Option[String], + description: Option[Doc], externalDocs: Option[ExternalDoc], operationId: Option[String], - parameters: Set[ParameterOrReference], - requestBody: Option[RequestBodyOrReference], - responses: Responses, - callbacks: Map[String, CallbackOrReference], + parameters: Set[ReferenceOr[Parameter]] = Set.empty, + requestBody: Option[ReferenceOr[RequestBody]], + responses: Map[StatusOrDefault, ReferenceOr[Response]] = Map.empty, + callbacks: Map[String, ReferenceOr[Callback]] = Map.empty, deprecated: Boolean = false, - security: List[SecurityRequirement], - servers: List[Server], - ) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "tags" -> tags, - "summary" -> summary, - "description" -> description, - "externalDocs" -> externalDocs, - "operationId" -> operationId, - "parameters" -> parameters, - "requestBody" -> requestBody, - "responses" -> responses, - "callbacks" -> callbacks, - "deprecated" -> deprecated, - "security" -> security, - "servers" -> servers, - ) - } - - sealed trait ParameterOrReference extends openapi.OpenAPIBase + security: List[SecurityRequirement] = List.empty, + servers: List[Server] = List.empty, + ) /** * Describes a single operation parameter. */ - sealed trait Parameter extends ParameterOrReference { - def name: String - def in: String - def description: Doc - def required: Boolean - def deprecated: Boolean - def allowEmptyValue: Boolean - def definition: Parameter.Definition - def explode: Boolean - def examples: Map[String, ExampleOrReference] - - /** - * A unique parameter is defined by a combination of a name and location. - */ + final case class Parameter( + name: String, + in: String, + description: Option[Doc], + required: Boolean = false, + deprecated: Boolean = false, + schema: Option[ReferenceOr[JsonSchema]], + explode: Boolean = false, + examples: Map[String, ReferenceOr[Example]] = Map.empty, + allowReserved: Option[Boolean], + style: Option[String], + content: Option[(String, MediaType)], + ) { override def equals(obj: Any): Boolean = obj match { - case p: Parameter.QueryParameter if name == p.name && in == p.in => true - case _ => false + case p: Parameter if name == p.name && in == p.in => true + case _ => false } - - override def toJson: String = - JsonRenderer.renderFields( - "name" -> name, - "in" -> in, - "description" -> description, - "required" -> required, - "deprecated" -> deprecated, - "allowEmptyValue" -> allowEmptyValue, - "definition" -> definition, - "explode" -> explode, - "examples" -> examples, - ) } object Parameter { - sealed trait Definition extends SchemaOrReference - object Definition { - final case class Content(key: String, mediaType: String) extends Definition { - override def toJson: String = JsonRenderer.renderFields( - "key" -> key, - "mediaType" -> mediaType, - ) - } - } + implicit val schema: Schema[Parameter] = + DeriveSchema.gen[Parameter] - sealed trait PathStyle + final case class Content(key: String, mediaType: MediaType) + sealed trait PathStyle sealed trait QueryStyle - object QueryStyle { + object Style { case object Matrix extends PathStyle case object Label extends PathStyle - case object Simple extends PathStyle - case object Form extends QueryStyle + case object Simple extends PathStyle + case object SpaceDelimited extends QueryStyle case object PipeDelimited extends QueryStyle @@ -555,27 +686,35 @@ object OpenAPI { * @param deprecated * Specifies that a parameter is deprecated and SHOULD be transitioned out * of usage. - * @param allowEmptyValue - * Sets the ability to pass empty-valued parameters. This is valid only - * for query parameters and allows sending a parameter with an empty - * value. If style is used, and if behavior is n/a (cannot be serialized), - * the value of allowEmptyValue SHALL be ignored. Use of this property is - * NOT RECOMMENDED, as it is likely to be removed in a later revision. */ - final case class QueryParameter( + def queryParameter( name: String, - description: Doc, + description: Option[Doc], + schema: Option[ReferenceOr[JsonSchema]], + examples: Map[String, ReferenceOr[Example]], deprecated: Boolean = false, - allowEmptyValue: Boolean = false, - definition: Definition, - allowReserved: Boolean = false, - style: QueryStyle = QueryStyle.Form, explode: Boolean = true, - examples: Map[String, ExampleOrReference], - ) extends Parameter { - def in: String = "query" - def required: Boolean = true - } + required: Boolean = false, + allowReserved: Boolean = false, + style: QueryStyle = Style.Form, + ): Parameter = Parameter( + name, + "query", + description, + required, + deprecated, + schema, + explode, + examples, + Some(allowReserved), + style = Some(style match { + case Style.Form => "form" + case Style.SpaceDelimited => "spaceDelimited" + case Style.PipeDelimited => "pipeDelimited" + case Style.DeepObject => "deepObject" + }), + None, + ) /** * Custom headers that are expected as part of the request. Note that @@ -590,26 +729,28 @@ object OpenAPI { * @param deprecated * Specifies that a parameter is deprecated and SHOULD be transitioned out * of usage. - * @param allowEmptyValue - * Sets the ability to pass empty-valued parameters. This is valid only - * for query parameters and allows sending a parameter with an empty - * value. If style is used, and if behavior is n/a (cannot be serialized), - * the value of allowEmptyValue SHALL be ignored. Use of this property is - * NOT RECOMMENDED, as it is likely to be removed in a later revision. */ - final case class HeaderParameter( + def headerParameter( name: String, - description: Doc, + description: Option[Doc], required: Boolean, deprecated: Boolean = false, - allowEmptyValue: Boolean = false, - definition: Definition, + definition: Option[ReferenceOr[JsonSchema]] = None, explode: Boolean = false, - examples: Map[String, ExampleOrReference], - ) extends Parameter { - def in: String = "header" - def style: String = "simple" - } + examples: Map[String, ReferenceOr[Example]], + ): Parameter = Parameter( + name, + "header", + description, + required, + deprecated, + definition, + explode, + examples, + allowReserved = None, + style = Some("simple"), + None, + ) /** * Used together with Path Templating, where the parameter value is actually @@ -621,31 +762,35 @@ object OpenAPI { * The name of the parameter. Parameter names are case sensitive. * @param description * A brief description of the parameter. - * @param required - * Determines whether this parameter is mandatory. * @param deprecated * Specifies that a parameter is deprecated and SHOULD be transitioned out * of usage. - * @param allowEmptyValue - * Sets the ability to pass empty-valued parameters. This is valid only - * for query parameters and allows sending a parameter with an empty - * value. If style is used, and if behavior is n/a (cannot be serialized), - * the value of allowEmptyValue SHALL be ignored. Use of this property is - * NOT RECOMMENDED, as it is likely to be removed in a later revision. */ - final case class PathParameter( + def pathParameter( name: String, - description: Doc, - required: Boolean, + description: Option[Doc], deprecated: Boolean = false, - allowEmptyValue: Boolean = false, - definition: Definition, - style: PathStyle = QueryStyle.Simple, + definition: Option[ReferenceOr[JsonSchema]] = None, + style: PathStyle = Style.Simple, explode: Boolean = false, - examples: Map[String, ExampleOrReference], - ) extends Parameter { - def in: String = "path" - } + examples: Map[String, ReferenceOr[Example]], + ): Parameter = Parameter( + name, + "path", + description, + required = true, + deprecated, + definition, + explode, + examples, + allowReserved = None, + style = Some(style match { + case Style.Matrix => "matrix" + case Style.Label => "label" + case Style.Simple => "simple" + }), + None, + ) /** * Used to pass a specific cookie value to the API. @@ -659,47 +804,42 @@ object OpenAPI { * @param deprecated * Specifies that a parameter is deprecated and SHOULD be transitioned out * of usage. - * @param allowEmptyValue - * Sets the ability to pass empty-valued parameters. This is valid only - * for query parameters and allows sending a parameter with an empty - * value. If style is used, and if behavior is n/a (cannot be serialized), - * the value of allowEmptyValue SHALL be ignored. Use of this property is - * NOT RECOMMENDED, as it is likely to be removed in a later revision. */ - final case class CookieParameter( + def cookieParameter( name: String, - description: Doc, + description: Option[Doc], required: Boolean, deprecated: Boolean = false, - allowEmptyValue: Boolean = false, - definition: Definition, + definition: Option[ReferenceOr[JsonSchema]] = None, explode: Boolean = false, - examples: Map[String, ExampleOrReference], - ) extends Parameter { - def in: String = "cookie" - def style: String = "form" - } + examples: Map[String, ReferenceOr[Example]], + ): Parameter = Parameter( + name, + "cookie", + description, + required, + deprecated, + definition, + explode, + examples, + allowReserved = None, + style = Some("form"), + None, + ) } - sealed trait HeaderOrReference extends openapi.OpenAPIBase - final case class Header( - description: Doc, - required: Boolean, - deprecate: Boolean = false, + description: Option[Doc], + required: Boolean = false, + deprecated: Boolean = false, allowEmptyValue: Boolean = false, - content: (String, MediaType), - ) extends HeaderOrReference { - override def toJson: String = JsonRenderer.renderFields( - "description" -> description, - "required" -> required, - "deprecated" -> deprecate, - "allowEmptyValue" -> allowEmptyValue, - "content" -> content, - ) - } + schema: Option[JsonSchema], + ) - sealed trait RequestBodyOrReference extends openapi.OpenAPIBase + object Header { + implicit val schema: Schema[Header] = + DeriveSchema.gen[Header] + } /** * Describes a single request body. @@ -714,13 +854,15 @@ object OpenAPI { * @param required * Determines if the request body is required in the request. */ - final case class RequestBody(description: Doc, content: Map[String, MediaType], required: Boolean = false) - extends ResponseOrReference { - override def toJson: String = JsonRenderer.renderFields( - "description" -> description, - "content" -> content, - "required" -> required, - ) + final case class RequestBody( + description: Option[Doc] = None, + content: Map[String, MediaType] = Map.empty, + required: Boolean = false, + ) + + object RequestBody { + implicit val schema: Schema[RequestBody] = + DeriveSchema.gen[RequestBody] } /** @@ -741,15 +883,14 @@ object OpenAPI { * type is multipart or application/x-www-form-urlencoded. */ final case class MediaType( - schema: SchemaOrReference, - examples: Map[String, ExampleOrReference], - encoding: Map[String, Encoding], - ) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "schema" -> schema, - "examples" -> examples, - "encoding" -> encoding, - ) + schema: ReferenceOr[JsonSchema], + examples: Map[String, ReferenceOr[Example]] = Map.empty, + encoding: Map[String, Encoding] = Map.empty, + ) + + object MediaType { + implicit val schema: Schema[MediaType] = + DeriveSchema.gen[MediaType] } /** @@ -780,18 +921,15 @@ object OpenAPI { */ final case class Encoding( contentType: String, - headers: Map[String, HeaderOrReference], + headers: Map[String, ReferenceOr[Header]] = Map.empty, style: String = "form", explode: Boolean, allowReserved: Boolean = false, - ) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "contentType" -> contentType, - "headers" -> headers, - "style" -> style, - "explode" -> explode, - "allowReserved" -> allowReserved, - ) + ) + + object Encoding { + implicit val schema: Schema[Encoding] = + DeriveSchema.gen[Encoding] } /** @@ -800,9 +938,38 @@ object OpenAPI { * contain at least one response code, and it SHOULD be the response for a * successful operation call. */ - type Responses = Map[Status, ResponseOrReference] + type Responses = Map[StatusOrDefault, ReferenceOr[Response]] - sealed trait ResponseOrReference extends openapi.OpenAPIBase + sealed trait StatusOrDefault extends Product with Serializable { + def text: String + } + + object StatusOrDefault { + case class StatusValue(status: Status) extends StatusOrDefault { + override def text: String = status.text + } + + object StatusValue { + implicit val schema: Schema[StatusValue] = + zio.schema + .Schema[Status] + .transformOrFail[StatusValue]( + s => Right(StatusValue(s)), + p => Right(p.status), + ) + } + case object Default extends StatusOrDefault { + implicit val schema: Schema[Default.type] = + zio.schema + .Schema[String] + .transformOrFail[Default.type]( + s => if (s == "default") Right(Default) else Left("Invalid default status"), + _ => Right("default"), + ) + + override def text: String = "default" + } + } /** * Describes a single response from an API Operation, including design-time, @@ -825,21 +992,17 @@ object OpenAPI { * of the names for Component Objects. */ final case class Response( - description: Doc, - headers: Map[String, HeaderOrReference], - content: Map[String, MediaType], - links: Map[String, LinkOrReference], - ) extends ResponseOrReference { - override def toJson: String = JsonRenderer.renderFields( - "description" -> description, - "headers" -> headers, - "content" -> content, - "links" -> links, - ) + description: Doc = Doc.Empty, + headers: Map[String, ReferenceOr[Header]] = Map.empty, + content: Map[String, MediaType] = Map.empty, + links: Map[String, ReferenceOr[Link]] = Map.empty, + ) + + object Response { + implicit val schema: Schema[Response] = + DeriveSchema.gen[Response] } - sealed trait CallbackOrReference extends openapi.OpenAPIBase - /** * A map of possible out-of band callbacks related to the parent operation. * Each value in the map is a Path Item Object that describes a set of @@ -852,16 +1015,12 @@ object OpenAPI { * A Path Item Object used to define a callback request and expected * responses. */ - final case class Callback(expressions: Map[String, PathItem]) extends CallbackOrReference { - override def toJson: String = { - val toRender = expressions.foldLeft(List.empty[(String, Renderer[PathItem])]) { case (acc, (k, v)) => - (k, v: Renderer[PathItem]) :: acc - } - JsonRenderer.renderFields(toRender: _*) - } - } + final case class Callback(expressions: Map[String, PathItem] = Map.empty) - sealed trait ExampleOrReference extends openapi.OpenAPIBase + object Callback { + implicit val schema: Schema[Callback] = + DeriveSchema.gen[Callback] + } /** * In all cases, the example value is expected to be compatible with the type @@ -878,16 +1037,19 @@ object OpenAPI { * reference examples that cannot easily be included in JSON or YAML * documents. */ - final case class Example(summary: String = "", description: Doc, externalValue: URI) extends ExampleOrReference { - override def toJson: String = JsonRenderer.renderFields( - "summary" -> summary, - "description" -> description, - "externalValue" -> externalValue, - ) + // There is currently no API to set the summary, description or externalValue + final case class Example( + value: Json, + summary: Option[String] = None, + description: Option[Doc] = None, + externalValue: Option[URI] = None, + ) + + object Example { + implicit val schema: Schema[Example] = + DeriveSchema.gen[Example] } - sealed trait LinkOrReference extends openapi.OpenAPIBase - /** * The Link object represents a possible design-time link for a response. The * presence of a link does not guarantee the caller’s ability to successfully @@ -923,30 +1085,53 @@ object OpenAPI { */ final case class Link( operationRef: URI, - parameters: Map[String, LiteralOrExpression], + parameters: Map[String, LiteralOrExpression] = Map.empty, requestBody: LiteralOrExpression, - description: Doc, + description: Option[Doc], server: Option[Server], - ) extends LinkOrReference { - override def toJson: String = JsonRenderer.renderFields( - "operationRef" -> operationRef, - "parameters" -> parameters, - "requestBody" -> requestBody, - "description" -> description, - "server" -> server, - ) + ) + + object Link { + implicit val schema: Schema[Link] = + DeriveSchema.gen[Link] } sealed trait LiteralOrExpression object LiteralOrExpression { - final case class NumberLiteral(value: Long) extends LiteralOrExpression - final case class DecimalLiteral(value: Double) extends LiteralOrExpression - final case class StringLiteral(value: String) extends LiteralOrExpression - final case class BooleanLiteral(value: Boolean) extends LiteralOrExpression - sealed abstract case class Expression private (value: String) extends LiteralOrExpression + implicit val schema: Schema[LiteralOrExpression] = + DeriveSchema.gen[LiteralOrExpression] + + final case class NumberLiteral(value: Long) extends LiteralOrExpression + + object NumberLiteral { + implicit val schema: Schema[NumberLiteral] = + Schema[Long].transform[NumberLiteral](s => NumberLiteral(s), p => p.value) + } + final case class DecimalLiteral(value: Double) extends LiteralOrExpression + + object DecimalLiteral { + implicit val schema: Schema[DecimalLiteral] = + Schema[Double].transform[DecimalLiteral](s => DecimalLiteral(s), p => p.value) + } + final case class StringLiteral(value: String) extends LiteralOrExpression + + object StringLiteral { + implicit val schema: Schema[StringLiteral] = + Schema[String].transform[StringLiteral](s => StringLiteral(s), p => p.value) + } + final case class BooleanLiteral(value: Boolean) extends LiteralOrExpression + + object BooleanLiteral { + implicit val schema: Schema[BooleanLiteral] = + Schema[Boolean].transform[BooleanLiteral](s => BooleanLiteral(s), p => p.value) + } + case class Expression(value: String) extends LiteralOrExpression object Expression { - private[openapi] def create(value: String): Expression = new Expression(value) {} + implicit val schema: Schema[Expression] = + Schema[String].transform[Expression](s => Expression.create(s), p => p.value) + + private[openapi] def create(value: String): Expression = Expression(value) } // TODO: maybe one could make a regex to validate the expression. For now just accept anything @@ -972,13 +1157,7 @@ object OpenAPI { * @param externalDocs * Additional external documentation for this tag. */ - final case class Tag(name: String, description: Doc, externalDocs: URI) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "name" -> name, - "description" -> description, - "externalDocs" -> externalDocs, - ) - } + final case class Tag(name: String, description: Option[Doc], externalDocs: Option[ExternalDoc]) /** * A simple object to allow referencing other components in the specification, @@ -987,130 +1166,40 @@ object OpenAPI { * @param ref * The reference string. */ - final case class Reference(ref: String) - extends SchemaOrReference - with ResponseOrReference - with ParameterOrReference - with ExampleOrReference - with RequestBodyOrReference - with HeaderOrReference - with SecuritySchemeOrReference - with LinkOrReference - with CallbackOrReference { - override def toJson: String = JsonRenderer.renderFields(s"$$ref" -> ref) - } - sealed trait SchemaOrReference extends openapi.OpenAPIBase - - sealed trait Schema extends openapi.OpenAPIBase with SchemaOrReference { - def nullable: Boolean - def discriminator: Option[Discriminator] - def readOnly: Boolean - def writeOnly: Boolean - def xml: Option[XML] - def externalDocs: URI - def example: String - def deprecated: Boolean - - override def toJson: String = - JsonRenderer.renderFields( - "nullable" -> nullable, - "discriminator" -> discriminator, - "readOnly" -> readOnly, - "writeOnly" -> writeOnly, - "xml" -> xml, - "externalDocs" -> externalDocs, - "example" -> example, - "deprecated" -> deprecated, - ) + @noDiscriminator + sealed trait ReferenceOr[+T] { + def asJsonSchema(implicit ev: T <:< JsonSchema): JsonSchema = this match { + case ReferenceOr.Reference(ref, summary, description) => + JsonSchema + .RefSchema(ref) + .description((summary.getOrElse(Doc.empty) + description.getOrElse(Doc.empty)).toCommonMark) + case ReferenceOr.Or(value) => ev(value) + } + } - object Schema { + object ReferenceOr { + implicit def schema[T: Schema]: Schema[ReferenceOr[T]] = + DeriveSchema.gen[ReferenceOr[T]] - /** - * The Schema Object allows the definition of input and output data types. - * - * Marked as readOnly. This means that it MAY be sent as part of a response - * but SHOULD NOT be sent as part of the request. If the property is in the - * required list, the required will take effect on the response only. - * - * @param nullable - * A true value adds "null" to the allowed type specified by the type - * keyword, only if type is explicitly defined within the same Schema - * Object. Other Schema Object constraints retain their defined behavior, - * and therefore may disallow the use of null as a value. A false value - * leaves the specified or default type unmodified. - * @param discriminator - * Adds support for polymorphism. The discriminator is an object name that - * is used to differentiate between other schemas which may satisfy the - * payload description. - * @param xml - * This MAY be used only on properties schemas. It has no effect on root - * schemas. Adds additional metadata to describe the XML representation of - * this property. - * @param externalDocs - * Additional external documentation for this schema. - * @param example - * A free-form property to include an example of an instance for this - * schema. - * @param deprecated - * Specifies that a schema is deprecated and SHOULD be transitioned out of - * usage. - */ - final case class ResponseSchema( - nullable: Boolean = false, - discriminator: Option[Discriminator], - xml: Option[XML], - externalDocs: URI, - example: String, - deprecated: Boolean = false, - ) extends Schema - with Parameter.Definition { - def readOnly: Boolean = true - def writeOnly: Boolean = false + final case class Reference( + @fieldName("$ref") ref: String, + summary: Option[Doc] = None, + description: Option[Doc] = None, + ) extends ReferenceOr[Nothing] + + object Reference { + implicit val schema: Schema[Reference] = + DeriveSchema.gen[Reference] } - /** - * The Schema Object allows the definition of input and output data types. - * - * Marked as writeOnly. This means that it MAY be sent as part of a request - * but SHOULD NOT be sent as part of the response. If the property is in the - * required list, the required will take effect on the request only. - * - * @param nullable - * A true value adds "null" to the allowed type specified by the type - * keyword, only if type is explicitly defined within the same Schema - * Object. Other Schema Object constraints retain their defined behavior, - * and therefore may disallow the use of null as a value. A false value - * leaves the specified or default type unmodified. - * @param discriminator - * Adds support for polymorphism. The discriminator is an object name that - * is used to differentiate between other schemas which may satisfy the - * payload description. - * @param xml - * This MAY be used only on properties schemas. It has no effect on root - * schemas. Adds additional metadata to describe the XML representation of - * this property. - * @param externalDocs - * Additional external documentation for this schema. - * @param example - * A free-form property to include an example of an instance for this - * schema. - * @param deprecated - * Specifies that a schema is deprecated and SHOULD be transitioned out of - * usage. - */ - final case class RequestSchema( - nullable: Boolean = false, - discriminator: Option[Discriminator], - xml: Option[XML], - externalDocs: URI, - example: String, - deprecated: Boolean = false, - ) extends Schema - with Parameter.Definition { - def readOnly: Boolean = false - def writeOnly: Boolean = true + final case class Or[T](value: T) extends ReferenceOr[T] + + object Or { + implicit def schema[T: Schema]: Schema[Or[T]] = + Schema[T].transform(Or(_), _.value) + } } @@ -1131,11 +1220,14 @@ object OpenAPI { * An object to hold mappings between payload values and schema names or * references. */ - final case class Discriminator(propertyName: String, mapping: Map[String, String]) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "propertyName" -> propertyName, - "mapping" -> mapping, - ) + final case class Discriminator( + propertyName: String, + mapping: Map[String, String] = Map.empty, + ) + + object Discriminator { + implicit val schema: Schema[Discriminator] = + DeriveSchema.gen[Discriminator] } /** @@ -1164,25 +1256,17 @@ object OpenAPI { * type being array (outside the items). */ final case class XML(name: String, namespace: URI, prefix: String, attribute: Boolean = false, wrapped: Boolean) - extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "name" -> name, - "namespace" -> namespace, - "prefix" -> prefix, - "attribute" -> attribute, - "wrapped" -> wrapped, - ) - } - sealed trait SecuritySchemeOrReference extends openapi.OpenAPIBase - - sealed trait SecurityScheme extends SecuritySchemeOrReference { + sealed trait SecurityScheme { def `type`: String - def description: Doc + def description: Option[Doc] } object SecurityScheme { + implicit val schema: Schema[SecurityScheme] = + DeriveSchema.gen[SecurityScheme] + /** * Defines an HTTP security scheme that can be used by the operations. * @@ -1193,28 +1277,18 @@ object OpenAPI { * @param in * The location of the API key. */ - final case class ApiKey(description: Doc, name: String, in: ApiKey.In) extends SecurityScheme { + final case class ApiKey(description: Option[Doc], name: String, in: ApiKey.In) extends SecurityScheme { override def `type`: String = "apiKey" - - override def toJson: String = - JsonRenderer.renderFields( - "type" -> `type`, - "description" -> description, - "name" -> name, - "in" -> in, - ) } object ApiKey { - sealed trait In extends openapi.OpenAPIBase { - self: Product => - override def toJson: String = - s""""${self.productPrefix.updated(0, self.productPrefix.charAt(0).toLower)}"""" - } + sealed trait In extends Product with Serializable object In { - case object Query extends In + case object Query extends In + case object Header extends In + case object Cookie extends In } } @@ -1231,16 +1305,10 @@ object OpenAPI { * Bearer tokens are usually generated by an authorization server, so this * information is primarily for documentation purposes. */ - final case class Http(description: Doc, scheme: String, bearerFormat: Option[String]) extends SecurityScheme { + final case class Http(description: Option[Doc], scheme: String, bearerFormat: Option[String]) + extends SecurityScheme { override def `type`: String = "http" - override def toJson: String = - JsonRenderer.renderFields( - "type" -> `type`, - "description" -> description, - "scheme" -> scheme, - "bearerFormat" -> bearerFormat, - ) } /** @@ -1250,15 +1318,9 @@ object OpenAPI { * An object containing configuration information for the flow types * supported. */ - final case class OAuth2(description: Doc, flows: OAuthFlows) extends SecurityScheme { + final case class OAuth2(description: Option[Doc], flows: OAuthFlows) extends SecurityScheme { override def `type`: String = "oauth2" - override def toJson: String = - JsonRenderer.renderFields( - "type" -> `type`, - "description" -> description, - "flows" -> flows, - ) } /** @@ -1267,165 +1329,124 @@ object OpenAPI { * @param openIdConnectUrl * OpenId Connect URL to discover OAuth2 configuration values. */ - final case class OpenIdConnect(description: Doc, openIdConnectUrl: URI) extends SecurityScheme { + final case class OpenIdConnect(description: Option[Doc], openIdConnectUrl: URI) extends SecurityScheme { override def `type`: String = "openIdConnect" - override def toJson: String = - JsonRenderer.renderFields( - "type" -> `type`, - "description" -> description, - "openIdConnectUrl" -> openIdConnectUrl, - ) } - } - - /** - * Allows configuration of the supported OAuth Flows. - * - * @param `implicit` - * Configuration for the OAuth Implicit flow. - * @param password - * Configuration for the OAuth Resource Owner Password flow - * @param clientCredentials - * Configuration for the OAuth Client Credentials flow. Previously called - * application in OpenAPI 2.0. - * @param authorizationCode - * Configuration for the OAuth Authorization Code flow. Previously called - * accessCode in OpenAPI 2.0. - */ - final case class OAuthFlows( - `implicit`: Option[OAuthFlow.Implicit], - password: Option[OAuthFlow.Password], - clientCredentials: Option[OAuthFlow.ClientCredentials], - authorizationCode: Option[OAuthFlow.AuthorizationCode], - ) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "implicit" -> `implicit`, - "password" -> password, - "clientCredentials" -> clientCredentials, - "authorizationCode" -> authorizationCode, - ) - } - - sealed trait OAuthFlow extends openapi.OpenAPIBase { - def refreshUrl: Option[URI] - def scopes: Map[String, String] - } - - object OAuthFlow { /** - * Configuration for the OAuth Implicit flow. + * Allows configuration of the supported OAuth Flows. * - * @param authorizationUrl - * The authorization URL to be used for this flow. - * @param refreshUrl - * The URL to be used for obtaining refresh tokens. - * @param scopes - * The available scopes for the OAuth2 security scheme. A map between the - * scope name and a short description for it. The map MAY be empty. + * @param `implicit` + * Configuration for the OAuth Implicit flow. + * @param password + * Configuration for the OAuth Resource Owner Password flow + * @param clientCredentials + * Configuration for the OAuth Client Credentials flow. Previously called + * application in OpenAPI 2.0. + * @param authorizationCode + * Configuration for the OAuth Authorization Code flow. Previously called + * accessCode in OpenAPI 2.0. */ - final case class Implicit(authorizationUrl: URI, refreshUrl: Option[URI], scopes: Map[String, String]) - extends OAuthFlow { - override def toJson: String = JsonRenderer.renderFields( - "authorizationUrl" -> authorizationUrl, - "refreshUrl" -> refreshUrl, - "scopes" -> scopes, - ) - } + final case class OAuthFlows( + `implicit`: Option[OAuthFlow.Implicit], + password: Option[OAuthFlow.Password], + clientCredentials: Option[OAuthFlow.ClientCredentials], + authorizationCode: Option[OAuthFlow.AuthorizationCode], + ) - /** - * Configuration for the OAuth Authorization Code flow. Previously called - * accessCode in OpenAPI 2.0. - * - * @param authorizationUrl - * The authorization URL to be used for this flow. - * @param refreshUrl - * The URL to be used for obtaining refresh tokens. - * @param scopes - * The available scopes for the OAuth2 security scheme. A map between the - * scope name and a short description for it. The map MAY be empty. - * @param tokenUrl - * The token URL to be used for this flow. - */ - final case class AuthorizationCode( - authorizationUrl: URI, - refreshUrl: Option[URI], - scopes: Map[String, String], - tokenUrl: URI, - ) extends OAuthFlow { - override def toJson: String = JsonRenderer.renderFields( - "authorizationUrl" -> authorizationUrl, - "refreshUrl" -> refreshUrl, - "scopes" -> scopes, - "tokenUrl" -> tokenUrl, - ) + sealed trait OAuthFlow { + def refreshUrl: Option[URI] + + def scopes: Map[String, String] } - /** - * Configuration for the OAuth Resource Owner Password flow. - * - * @param refreshUrl - * The URL to be used for obtaining refresh tokens. - * @param scopes - * The available scopes for the OAuth2 security scheme. A map between the - * scope name and a short description for it. The map MAY be empty. - * @param tokenUrl - * The token URL to be used for this flow. - */ - final case class Password(refreshUrl: Option[URI], scopes: Map[String, String], tokenUrl: URI) extends OAuthFlow { - override def toJson: String = JsonRenderer.renderFields( - "refreshUrl" -> refreshUrl, - "scopes" -> scopes, - "tokenUrl" -> tokenUrl, - ) + object OAuthFlow { + + /** + * Configuration for the OAuth Implicit flow. + * + * @param authorizationUrl + * The authorization URL to be used for this flow. + * @param refreshUrl + * The URL to be used for obtaining refresh tokens. + * @param scopes + * The available scopes for the OAuth2 security scheme. A map between + * the scope name and a short description for it. The map MAY be empty. + */ + final case class Implicit(authorizationUrl: URI, refreshUrl: Option[URI], scopes: Map[String, String]) + extends OAuthFlow + + /** + * Configuration for the OAuth Authorization Code flow. Previously called + * accessCode in OpenAPI 2.0. + * + * @param authorizationUrl + * The authorization URL to be used for this flow. + * @param refreshUrl + * The URL to be used for obtaining refresh tokens. + * @param scopes + * The available scopes for the OAuth2 security scheme. A map between + * the scope name and a short description for it. The map MAY be empty. + * @param tokenUrl + * The token URL to be used for this flow. + */ + final case class AuthorizationCode( + authorizationUrl: URI, + refreshUrl: Option[URI], + scopes: Map[String, String], + tokenUrl: URI, + ) extends OAuthFlow + + /** + * Configuration for the OAuth Resource Owner Password flow. + * + * @param refreshUrl + * The URL to be used for obtaining refresh tokens. + * @param scopes + * The available scopes for the OAuth2 security scheme. A map between + * the scope name and a short description for it. The map MAY be empty. + * @param tokenUrl + * The token URL to be used for this flow. + */ + final case class Password(refreshUrl: Option[URI], scopes: Map[String, String], tokenUrl: URI) extends OAuthFlow + + /** + * Configuration for the OAuth Client Credentials flow. Previously called + * application in OpenAPI 2.0. + * + * @param refreshUrl + * The URL to be used for obtaining refresh tokens. + * @param scopes + * The available scopes for the OAuth2 security scheme. A map between + * the scope name and a short description for it. The map MAY be empty. + * @param tokenUrl + * The token URL to be used for this flow. + */ + final case class ClientCredentials(refreshUrl: Option[URI], scopes: Map[String, String], tokenUrl: URI) + extends OAuthFlow {} } /** - * Configuration for the OAuth Client Credentials flow. Previously called - * application in OpenAPI 2.0. + * Lists the required security schemes to execute this operation. The name + * used for each property MUST correspond to a security scheme declared in + * the Security Schemes under the Components Object. + * + * Security Requirement Objects that contain multiple schemes require that + * all schemes MUST be satisfied for a request to be authorized. This + * enables support for scenarios where multiple query parameters or HTTP + * headers are required to convey security information. + * + * When a list of Security Requirement Objects is defined on the OpenAPI + * Object or Operation Object, only one of the Security Requirement Objects + * in the list needs to be satisfied to authorize the request. * - * @param refreshUrl - * The URL to be used for obtaining refresh tokens. - * @param scopes - * The available scopes for the OAuth2 security scheme. A map between the - * scope name and a short description for it. The map MAY be empty. - * @param tokenUrl - * The token URL to be used for this flow. + * @param securitySchemes + * If the security scheme is of type "oauth2" or "openIdConnect", then the + * value is a list of scope names required for the execution, and the list + * MAY be empty if authorization does not require a specified scope. For + * other security scheme types, the List MUST be empty. */ - final case class ClientCredentials(refreshUrl: Option[URI], scopes: Map[String, String], tokenUrl: URI) - extends OAuthFlow { - override def toJson: String = JsonRenderer.renderFields( - "refreshUrl" -> refreshUrl, - "scopes" -> scopes, - "tokenUrl" -> tokenUrl, - ) - } - } - - /** - * Lists the required security schemes to execute this operation. The name - * used for each property MUST correspond to a security scheme declared in the - * Security Schemes under the Components Object. - * - * Security Requirement Objects that contain multiple schemes require that all - * schemes MUST be satisfied for a request to be authorized. This enables - * support for scenarios where multiple query parameters or HTTP headers are - * required to convey security information. - * - * When a list of Security Requirement Objects is defined on the OpenAPI - * Object or Operation Object, only one of the Security Requirement Objects in - * the list needs to be satisfied to authorize the request. - * - * @param securitySchemes - * If the security scheme is of type "oauth2" or "openIdConnect", then the - * value is a list of scope names required for the execution, and the list - * MAY be empty if authorization does not require a specified scope. For - * other security scheme types, the List MUST be empty. - */ - final case class SecurityRequirement(securitySchemes: Map[String, List[String]]) extends openapi.OpenAPIBase { - override def toJson: String = JsonRenderer.renderFields( - "securitySchemes" -> securitySchemes, - ) + final case class SecurityRequirement(securitySchemes: Map[String, List[String]]) } } diff --git a/zio-http/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala new file mode 100644 index 0000000000..c2460e942d --- /dev/null +++ b/zio-http/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -0,0 +1,819 @@ +package zio.http.endpoint.openapi + +import java.util.UUID + +import scala.annotation.tailrec +import scala.collection.{immutable, mutable} + +import zio.Chunk +import zio.json.EncoderOps +import zio.json.ast.Json + +import zio.schema.Schema.Record +import zio.schema.codec.JsonCodec +import zio.schema.{Schema, TypeId} + +import zio.http._ +import zio.http.codec.HttpCodec.Metadata +import zio.http.codec._ +import zio.http.endpoint._ +import zio.http.endpoint.openapi.JsonSchema.SchemaStyle + +object OpenAPIGen { + private val PathWildcard = "pathWildcard" + + private[openapi] def groupMap[A, K, B](chunk: Chunk[A])(key: A => K)(f: A => B): immutable.Map[K, Chunk[B]] = { + val m = mutable.Map.empty[K, mutable.Builder[B, Chunk[B]]] + for (elem <- chunk) { + val k = key(elem) + val bldr = m.getOrElseUpdate(k, Chunk.newBuilder[B]) + bldr += f(elem) + } + class Result extends runtime.AbstractFunction1[(K, mutable.Builder[B, Chunk[B]]), Unit] { + var built = immutable.Map.empty[K, Chunk[B]] + + def apply(kv: (K, mutable.Builder[B, Chunk[B]])): Unit = + built = built.updated(kv._1, kv._2.result()) + } + val result = new Result + m.foreach(result) + result.built + } + + final case class MetaCodec[T](codec: T, annotations: Chunk[HttpCodec.Metadata[Any]]) { + lazy val docs: Doc = { + val annotatedDoc = annotations.foldLeft(Doc.empty) { + case (doc, HttpCodec.Metadata.Documented(nextDoc)) => doc + nextDoc + case (doc, _) => doc + } + val trailingPathDoc = codec.asInstanceOf[Any] match { + case SegmentCodec.Trailing => + Doc.p( + Doc.Span.bold("WARNING: This is wildcard path segment. There is no official OpenAPI support for this."), + ) + + Doc.p("Tools might URL encode this segment and it might not work as expected.") + case _ => + Doc.empty + } + annotatedDoc + trailingPathDoc + } + + lazy val docsOpt: Option[Doc] = if (docs.isEmpty) None else Some(docs) + + lazy val examples: Map[String, Any] = annotations.foldLeft(Map.empty[String, Any]) { + case (examples, HttpCodec.Metadata.Examples(nextExamples)) => examples ++ nextExamples + case (examples, _) => examples + } + + def examples(schema: Schema[_]): Map[String, OpenAPI.ReferenceOr.Or[OpenAPI.Example]] = + examples.map { case (k, v) => + k -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(toJsonAst(schema, v))) + } + + def name: Option[String] = + codec match { + case value: SegmentCodec[_] => + value match { + case SegmentCodec.BoolSeg(name) => Some(name) + case SegmentCodec.IntSeg(name) => Some(name) + case SegmentCodec.LongSeg(name) => Some(name) + case SegmentCodec.Text(name) => Some(name) + case SegmentCodec.UUID(name) => Some(name) + case SegmentCodec.Trailing => Some(PathWildcard) + case _ => None + } + case _ => + findName(annotations) + } + + def required: Boolean = + !annotations.exists(_.isInstanceOf[HttpCodec.Metadata.Optional[_]]) + + def deprecated: Boolean = + annotations.exists(_.isInstanceOf[HttpCodec.Metadata.Deprecated[_]]) + } + final case class AtomizedMetaCodecs( + method: Chunk[MetaCodec[SimpleCodec[Method, _]]], + path: Chunk[MetaCodec[SegmentCodec[_]]], + query: Chunk[MetaCodec[HttpCodec.Query[_]]], + header: Chunk[MetaCodec[HttpCodec.Header[_]]], + content: Chunk[MetaCodec[HttpCodec.Atom[HttpCodecType.Content, _]]], + status: Chunk[MetaCodec[HttpCodec.Status[_]]], + ) { + def append(metaCodec: MetaCodec[_]): AtomizedMetaCodecs = metaCodec match { + case MetaCodec(codec: HttpCodec.Method[_], annotations) => + copy(method = + (method :+ MetaCodec(codec.codec, annotations)).asInstanceOf[Chunk[MetaCodec[SimpleCodec[Method, _]]]], + ) + case MetaCodec(_: SegmentCodec[_], _) => + copy(path = path :+ metaCodec.asInstanceOf[MetaCodec[SegmentCodec[_]]]) + case MetaCodec(_: HttpCodec.Query[_], _) => + copy(query = query :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Query[_]]]) + case MetaCodec(_: HttpCodec.Header[_], _) => + copy(header = header :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Header[_]]]) + case MetaCodec(_: HttpCodec.Status[_], _) => + copy(status = status :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Status[_]]]) + case MetaCodec(_: HttpCodec.Content[_], _) => + copy(content = content :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Atom[HttpCodecType.Content, _]]]) + case MetaCodec(_: HttpCodec.ContentStream[_], _) => + copy(content = content :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Atom[HttpCodecType.Content, _]]]) + case _ => this + } + + def ++(that: AtomizedMetaCodecs): AtomizedMetaCodecs = + AtomizedMetaCodecs( + method ++ that.method, + path ++ that.path, + query ++ that.query, + header ++ that.header, + content ++ that.content, + status ++ that.status, + ) + + def contentExamples: Map[String, OpenAPI.ReferenceOr.Or[OpenAPI.Example]] = + content.flatMap { + case mc @ MetaCodec(HttpCodec.Content(schema, _, _, _), _) => + mc.examples.map { case (name, value) => + name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(toJsonAst(schema, value))) + } + case mc @ MetaCodec(HttpCodec.ContentStream(schema, _, _, _), _) => + mc.examples.map { case (name, value) => + name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(toJsonAst(schema, value))) + } + case _ => + Map.empty[String, OpenAPI.ReferenceOr.Or[OpenAPI.Example]] + }.toMap + + // in case of alternatives, + // the doc to the alternation is added to all sub elements of the alternatives. + // This is not ideal. But it is the best we can do. + // To get the doc that is only for the alternation, we take the intersection of all docs, + // since only the alternation doc is added to all sub elements. + def contentDocs: Doc = + content + .flatMap(_.docsOpt) + .map(_.flattened) + .reduceOption(_ intersect _) + .flatMap(_.reduceOption(_ + _)) + .getOrElse(Doc.empty) + + def optimize: AtomizedMetaCodecs = + AtomizedMetaCodecs( + method.materialize, + path.materialize, + query.materialize, + header.materialize, + content.materialize, + status.materialize, + ) + } + + object AtomizedMetaCodecs { + def empty: AtomizedMetaCodecs = AtomizedMetaCodecs( + method = Chunk.empty, + path = Chunk.empty, + query = Chunk.empty, + header = Chunk.empty, + content = Chunk.empty, + status = Chunk.empty, + ) + + def flatten[R, A](codec: HttpCodec[R, A]): AtomizedMetaCodecs = { + val atoms = flattenedAtoms(codec) + + val flattened = atoms + .foldLeft(AtomizedMetaCodecs.empty) { case (acc, atom) => + acc.append(atom) + } + .optimize + flattened + } + + private def flattenedAtoms[R, A]( + in: HttpCodec[R, A], + annotations: Chunk[HttpCodec.Metadata[Any]] = Chunk.empty, + ): Chunk[MetaCodec[_]] = + in match { + case HttpCodec.Combine(left, right, _) => + flattenedAtoms(left, annotations) ++ flattenedAtoms(right, annotations) + case path: HttpCodec.Path[_] => Chunk.fromIterable(path.pathCodec.segments.map(metaCodecFromSegment)) + case atom: HttpCodec.Atom[_, _] => Chunk(MetaCodec(atom, annotations)) + case map: HttpCodec.TransformOrFail[_, _, _] => flattenedAtoms(map.api, annotations) + case HttpCodec.Empty => Chunk.empty + case HttpCodec.Halt => Chunk.empty + case _: HttpCodec.Fallback[_, _, _] => in.alternatives.map(_._1).flatMap(flattenedAtoms(_, annotations)) + case HttpCodec.Annotated(api, annotation) => + flattenedAtoms(api, annotations :+ annotation.asInstanceOf[HttpCodec.Metadata[Any]]) + } + } + + private def metaCodecFromSegment(segment: SegmentCodec[_]) = { + segment match { + case SegmentCodec.Annotated(codec, annotations) => + MetaCodec( + codec, + annotations.map { + case SegmentCodec.MetaData.Documented(value) => HttpCodec.Metadata.Documented(value) + case SegmentCodec.MetaData.Examples(examples) => HttpCodec.Metadata.Examples(examples) + }.asInstanceOf[Chunk[HttpCodec.Metadata[Any]]], + ) + case other => MetaCodec(other, Chunk.empty) + } + } + + def contentAsJsonSchema[R, A]( + codec: HttpCodec[R, A], + metadata: Chunk[HttpCodec.Metadata[_]] = Chunk.empty, + referenceType: SchemaStyle = SchemaStyle.Inline, + wrapInObject: Boolean = false, + ): JsonSchema = { + codec match { + case atom: HttpCodec.Atom[_, _] => + atom match { + case HttpCodec.Content(schema, _, maybeName, _) if wrapInObject => + val name = + findName(metadata).orElse(maybeName).getOrElse(throw new Exception("Multipart content without name")) + JsonSchema.obj( + name -> JsonSchema + .fromZSchema(schema, referenceType) + .description(description(metadata)) + .deprecated(deprecated(metadata)) + .nullable(optional(metadata)), + ) + case HttpCodec.ContentStream(schema, _, maybeName, _) if wrapInObject && schema == Schema[Byte] => + val name = + findName(metadata).orElse(maybeName).getOrElse(throw new Exception("Multipart content without name")) + JsonSchema.obj( + name -> JsonSchema + .fromZSchema(schema, referenceType) + .description(description(metadata)) + .deprecated(deprecated(metadata)) + .nullable(optional(metadata)) + // currently we have no information about the encoding. So we just assume binary + .contentEncoding(JsonSchema.ContentEncoding.Binary) + .contentMediaType(MediaType.application.`octet-stream`.fullType), + ) + case HttpCodec.ContentStream(schema, _, maybeName, _) if wrapInObject => + val name = + findName(metadata).orElse(maybeName).getOrElse(throw new Exception("Multipart content without name")) + JsonSchema.obj( + name -> JsonSchema + .fromZSchema(schema, referenceType) + .description(description(metadata)) + .deprecated(deprecated(metadata)) + .nullable(optional(metadata)), + ) + case HttpCodec.Content(schema, _, _, _) => + JsonSchema + .fromZSchema(schema, referenceType) + .description(description(metadata)) + .deprecated(deprecated(metadata)) + .nullable(optional(metadata)) + case HttpCodec.ContentStream(schema, _, _, _) => + JsonSchema + .fromZSchema(schema, referenceType) + .description(description(metadata)) + .deprecated(deprecated(metadata)) + .nullable(optional(metadata)) + case _ => JsonSchema.Null + } + case HttpCodec.Annotated(codec, data) => + contentAsJsonSchema(codec, metadata :+ data, referenceType, wrapInObject) + case HttpCodec.TransformOrFail(api, _, _) => contentAsJsonSchema(api, metadata, referenceType, wrapInObject) + case HttpCodec.Empty => JsonSchema.Null + case HttpCodec.Halt => JsonSchema.Null + case HttpCodec.Combine(left, right, _) if isMultipart(codec) => + ( + contentAsJsonSchema(left, Chunk.empty, referenceType, wrapInObject = true), + contentAsJsonSchema(right, Chunk.empty, referenceType, wrapInObject = true), + ) match { + case (left, right) => + val annotations = left.annotations ++ right.annotations + (left.withoutAnnotations, right.withoutAnnotations) match { + case (JsonSchema.Object(p1, _, r1), JsonSchema.Object(p2, _, r2)) => + // seems odd to allow additional properties for multipart. So just hardcode it to false + JsonSchema + .Object(p1 ++ p2, Left(false), r1 ++ r2) + .deprecated(deprecated(metadata)) + .nullable(optional(metadata)) + .description(description(metadata)) + .annotate(annotations) + case _ => throw new IllegalArgumentException("Multipart content without name.") + } + + } + case HttpCodec.Combine(left, right, _) => + ( + contentAsJsonSchema(left, Chunk.empty, referenceType, wrapInObject), + contentAsJsonSchema(right, Chunk.empty, referenceType, wrapInObject), + ) match { + case (JsonSchema.Null, JsonSchema.Null) => + JsonSchema.Null + case (JsonSchema.Null, schema) => + schema + .deprecated(deprecated(metadata)) + .nullable(optional(metadata)) + .description(description(metadata)) + case (schema, JsonSchema.Null) => + schema + .deprecated(deprecated(metadata)) + .nullable(optional(metadata)) + .description(description(metadata)) + case _ => + throw new IllegalStateException("A non multipart combine, should lead to at least one null schema.") + } + case HttpCodec.Fallback(_, _, _) => throw new IllegalArgumentException("Fallback not supported at this point") + } + } + + private def findName(metadata: Chunk[HttpCodec.Metadata[_]]): Option[String] = + metadata.reverse + .find(_.isInstanceOf[Metadata.Named[_]]) + .asInstanceOf[Option[Metadata.Named[Any]]] + .map(_.name) + + private def description(metadata: Chunk[HttpCodec.Metadata[_]]): Option[String] = + metadata.collect { case HttpCodec.Metadata.Documented(doc) => doc } + .reduceOption(_ + _) + .map(_.toCommonMark) + + private def deprecated(metadata: Chunk[HttpCodec.Metadata[_]]): Boolean = + metadata.exists(_.isInstanceOf[HttpCodec.Metadata.Deprecated[_]]) + + private def optional(metadata: Chunk[HttpCodec.Metadata[_]]): Boolean = + metadata.exists(_.isInstanceOf[HttpCodec.Metadata.Optional[_]]) + + def status[R, A](codec: HttpCodec[R, A]): Option[Status] = + codec match { + case HttpCodec.Status(simpleCodec, _) if simpleCodec.isInstanceOf[SimpleCodec.Specified[_]] => + Some(simpleCodec.asInstanceOf[SimpleCodec.Specified[Status]].value) + case HttpCodec.Annotated(codec, _) => + status(codec) + case HttpCodec.TransformOrFail(api, _, _) => + status(api) + case HttpCodec.Empty => + None + case HttpCodec.Halt => + None + case HttpCodec.Combine(left, right, _) => + status(left).orElse(status(right)) + case HttpCodec.Fallback(left, right, _) => + status(left).orElse(status(right)) + case _ => + None + } + + def isMultipart[R, A](codec: HttpCodec[R, A]): Boolean = + codec match { + case HttpCodec.Combine(left, right, _) => + (isContent(left) && isContent(right)) || + isMultipart(left) || isMultipart(right) + case HttpCodec.Annotated(codec, _) => isMultipart(codec) + case HttpCodec.TransformOrFail(codec, _, _) => isMultipart(codec) + case _ => false + } + + def isContent(value: HttpCodec[_, _]): Boolean = + value match { + case HttpCodec.Content(_, _, _, _) => true + case HttpCodec.ContentStream(_, _, _, _) => true + case HttpCodec.Annotated(codec, _) => isContent(codec) + case HttpCodec.TransformOrFail(codec, _, _) => isContent(codec) + case HttpCodec.Combine(left, right, _) => isContent(left) || isContent(right) + case _ => false + } + + private def toJsonAst(schema: Schema[_], v: Any): Json = + JsonCodec + .jsonEncoder(schema.asInstanceOf[Schema[Any]]) + .toJsonAST(v) + .toOption + .get + + def fromEndpoints( + endpoint1: Endpoint[_, _, _, _, _], + endpoints: Endpoint[_, _, _, _, _]*, + ): OpenAPI = fromEndpoints(endpoint1 +: endpoints) + + def fromEndpoints( + title: String, + version: String, + endpoint1: Endpoint[_, _, _, _, _], + endpoints: Endpoint[_, _, _, _, _]*, + ): OpenAPI = fromEndpoints(title, version, endpoint1 +: endpoints) + + def fromEndpoints( + title: String, + version: String, + referenceType: SchemaStyle, + endpoint1: Endpoint[_, _, _, _, _], + endpoints: Endpoint[_, _, _, _, _]*, + ): OpenAPI = fromEndpoints(title, version, referenceType, endpoint1 +: endpoints) + + def fromEndpoints( + referenceType: SchemaStyle, + endpoints: Iterable[Endpoint[_, _, _, _, _]], + ): OpenAPI = if (endpoints.isEmpty) OpenAPI.empty else endpoints.map(gen(_, referenceType)).reduce(_ ++ _) + + def fromEndpoints( + endpoints: Iterable[Endpoint[_, _, _, _, _]], + ): OpenAPI = if (endpoints.isEmpty) OpenAPI.empty else endpoints.map(gen(_, SchemaStyle.Compact)).reduce(_ ++ _) + + def fromEndpoints( + title: String, + version: String, + endpoints: Iterable[Endpoint[_, _, _, _, _]], + ): OpenAPI = fromEndpoints(endpoints).title(title).version(version) + + def fromEndpoints( + title: String, + version: String, + referenceType: SchemaStyle, + endpoints: Iterable[Endpoint[_, _, _, _, _]], + ): OpenAPI = fromEndpoints(referenceType, endpoints).title(title).version(version) + + def gen( + endpoint: Endpoint[_, _, _, _, _], + referenceType: SchemaStyle = SchemaStyle.Compact, + ): OpenAPI = { + val inAtoms = AtomizedMetaCodecs.flatten(endpoint.input) + val outs: Map[OpenAPI.StatusOrDefault, Map[MediaType, (JsonSchema, AtomizedMetaCodecs)]] = + schemaByStatusAndMediaType( + endpoint.output.alternatives.map(_._1) ++ endpoint.error.alternatives.map(_._1), + referenceType, + ) + // there is no status for inputs. So we just take the first one (default) + val ins = schemaByStatusAndMediaType(endpoint.input.alternatives.map(_._1), referenceType).values.headOption + + def path: OpenAPI.Paths = { + val path = buildPath(endpoint.input) + val method0 = method(inAtoms.method) + // Endpoint has only one doc. But open api has a summery and a description + val pathItem = OpenAPI.PathItem.empty + .copy(description = Some(endpoint.doc + endpoint.input.doc.getOrElse(Doc.empty)).filter(!_.isEmpty)) + val pathItemWithOp = method0 match { + case Method.OPTIONS => pathItem.addOptions(operation(endpoint)) + case Method.GET => pathItem.addGet(operation(endpoint)) + case Method.HEAD => pathItem.addHead(operation(endpoint)) + case Method.POST => pathItem.addPost(operation(endpoint)) + case Method.PUT => pathItem.addPut(operation(endpoint)) + case Method.PATCH => pathItem.addPatch(operation(endpoint)) + case Method.DELETE => pathItem.addDelete(operation(endpoint)) + case Method.TRACE => pathItem.addTrace(operation(endpoint)) + case Method.ANY => pathItem.any(operation(endpoint)) + case method => throw new IllegalArgumentException(s"OpenAPI does not support method $method") + } + Map(path -> pathItemWithOp) + } + + def buildPath(in: HttpCodec[_, _]): OpenAPI.Path = { + + def pathCodec(in1: HttpCodec[_, _]): Option[HttpCodec.Path[_]] = in1 match { + case atom: HttpCodec.Atom[_, _] => + atom match { + case codec @ HttpCodec.Path(_, _) => Some(codec) + case _ => None + } + case HttpCodec.Annotated(in, _) => pathCodec(in) + case HttpCodec.TransformOrFail(api, _, _) => pathCodec(api) + case HttpCodec.Empty => None + case HttpCodec.Halt => None + case HttpCodec.Combine(left, right, _) => pathCodec(left).orElse(pathCodec(right)) + case HttpCodec.Fallback(left, right, _) => pathCodec(left).orElse(pathCodec(right)) + } + + val pathString = { + val codec = pathCodec(in).getOrElse(throw new Exception("No path found.")).pathCodec + if (codec.render.endsWith(SegmentCodec.Trailing.render)) + codec.renderIgnoreTrailing + s"{$PathWildcard}" + else codec.render + } + OpenAPI.Path.fromString(pathString).getOrElse(throw new Exception(s"Invalid path: $pathString")) + } + + def method(in: Chunk[MetaCodec[SimpleCodec[Method, _]]]): Method = { + if (in.size > 1) throw new Exception("Multiple methods not supported") + in.collectFirst { case MetaCodec(SimpleCodec.Specified(method: Method), _) => method } + .getOrElse(throw new Exception("No method specified")) + } + + def operation(endpoint: Endpoint[_, _, _, _, _]): OpenAPI.Operation = + OpenAPI.Operation( + tags = Nil, + summary = None, + description = Some(endpoint.doc + pathDoc).filter(!_.isEmpty), + externalDocs = None, + operationId = None, + parameters = parameters, + requestBody = requestBody, + responses = responses, + callbacks = Map.empty, + security = Nil, + servers = Nil, + ) + + def pathDoc: Doc = { + def loop(codec: PathCodec[_]): Doc = codec match { + case PathCodec.Segment(_) => + // segment docs are used in path parameters + Doc.empty + case PathCodec.Concat(left, right, _, _) => + loop(left) + loop(right) + case PathCodec.TransformOrFail(api, _, _) => + loop(api) + } + loop(endpoint.route.pathCodec) + } + + def requestBody: Option[OpenAPI.ReferenceOr[OpenAPI.RequestBody]] = + ins.map { mediaTypes => + val combinedAtomizedCodecs = mediaTypes.map { case (_, (_, atomized)) => atomized }.reduce(_ ++ _) + val mediaTypeResponses = mediaTypes.map { case (mediaType, (schema, atomized)) => + mediaType.fullType -> OpenAPI.MediaType( + schema = OpenAPI.ReferenceOr.Or(schema), + examples = atomized.contentExamples, + encoding = Map.empty, + ) + } + OpenAPI.ReferenceOr.Or( + OpenAPI.RequestBody( + content = mediaTypeResponses, + required = combinedAtomizedCodecs.content.exists(_.required), + ), + ) + } + + def responses: OpenAPI.Responses = + responsesForAlternatives(outs) + + def parameters: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = + queryParams ++ pathParams ++ headerParams + + def queryParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = { + inAtoms.query.collect { case mc @ MetaCodec(HttpCodec.Query(name, codec, _), _) => + OpenAPI.ReferenceOr.Or( + OpenAPI.Parameter.queryParameter( + name = name, + description = mc.docsOpt, + schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromTextCodec(codec))), + deprecated = mc.deprecated, + style = OpenAPI.Parameter.Style.Form, + explode = false, + allowReserved = false, + examples = mc.examples.map { case (name, value) => + name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(value = Json.Str(value.toString))) + }, + required = mc.required, + ), + ) + } + }.toSet + + def pathParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = + inAtoms.path.collect { + case mc @ MetaCodec(codec, _) if codec != SegmentCodec.Empty && !codec.isInstanceOf[SegmentCodec.Literal] => + OpenAPI.ReferenceOr.Or( + OpenAPI.Parameter.pathParameter( + name = mc.name.getOrElse(throw new Exception("Path parameter must have a name")), + description = mc.docsOpt, + definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromSegmentCodec(codec))), + deprecated = mc.deprecated, + style = OpenAPI.Parameter.Style.Simple, + examples = mc.examples.map { case (name, value) => + name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(segmentToJson(codec, value))) + }, + ), + ) + }.toSet + + def headerParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = + inAtoms.header + .asInstanceOf[Chunk[MetaCodec[HttpCodec.Header[Any]]]] + .map { case mc @ MetaCodec(codec, _) => + OpenAPI.ReferenceOr.Or( + OpenAPI.Parameter.headerParameter( + name = mc.name.getOrElse(codec.name), + description = mc.docsOpt, + definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromTextCodec(codec.textCodec))), + deprecated = mc.deprecated, + examples = mc.examples.map { case (name, value) => + name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.textCodec.encode(value).toJsonAST.toOption.get)) + }, + required = mc.required, + ), + ) + } + .toSet + + def genDiscriminator(schema: Schema[_]): Option[OpenAPI.Discriminator] = { + schema match { + case enumSchema: Schema.Enum[_] => + val discriminatorName = + enumSchema.annotations.collectFirst { case zio.schema.annotation.discriminatorName(name) => name } + val noDiscriminator = enumSchema.annotations.contains(zio.schema.annotation.noDiscriminator()) + val typeMapping = enumSchema.cases.map { case_ => + val caseName = + case_.annotations.collectFirst { case zio.schema.annotation.caseName(name) => name }.getOrElse(case_.id) + // There should be no enums with cases that are not records with a nominal id + // TODO: not true. Since one could build a schema with a enum with a case that is a primitive + val typeId = + case_.schema + .asInstanceOf[Schema.Record[_]] + .id + .asInstanceOf[TypeId.Nominal] + caseName -> schemaReferencePath(typeId, referenceType) + } + + if (noDiscriminator) None + else discriminatorName.map(name => OpenAPI.Discriminator(name, typeMapping.toMap)) + + case _ => None + } + } + + def components = OpenAPI.Components( + schemas = componentSchemas, + responses = Map.empty, + parameters = Map.empty, + examples = Map.empty, + requestBodies = Map.empty, + headers = Map.empty, + securitySchemes = Map.empty, + links = Map.empty, + callbacks = Map.empty, + ) + + @tailrec + def segmentToJson(codec: SegmentCodec[_], value: Any): Json = { + codec match { + case SegmentCodec.Empty => throw new Exception("Empty segment not allowed") + case SegmentCodec.Literal(_) => throw new Exception("Literal segment not allowed") + case SegmentCodec.BoolSeg(_) => Json.Bool(value.asInstanceOf[Boolean]) + case SegmentCodec.IntSeg(_) => Json.Num(value.asInstanceOf[Int]) + case SegmentCodec.LongSeg(_) => Json.Num(value.asInstanceOf[Long]) + case SegmentCodec.Text(_) => Json.Str(value.asInstanceOf[String]) + case SegmentCodec.UUID(_) => Json.Str(value.asInstanceOf[UUID].toString) + case SegmentCodec.Annotated(codec, _) => segmentToJson(codec, value) + case SegmentCodec.Trailing => throw new Exception("Trailing segment not allowed") + } + } + + def componentSchemas: Map[OpenAPI.Key, OpenAPI.ReferenceOr[JsonSchema]] = + (endpoint.input.alternatives.map(_._1).map(AtomizedMetaCodecs.flatten(_)).flatMap(_.content) + ++ endpoint.error.alternatives.map(_._1).map(AtomizedMetaCodecs.flatten(_)).flatMap(_.content) + ++ endpoint.output.alternatives.map(_._1).map(AtomizedMetaCodecs.flatten(_)).flatMap(_.content)).collect { + case MetaCodec(HttpCodec.Content(schema, _, _, _), _) if nominal(schema, referenceType).isDefined => + val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType) + schemas.children.map { case (key, schema) => + OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema) + } + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get -> + OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema)))) + case MetaCodec(HttpCodec.ContentStream(schema, _, _, _), _) if nominal(schema, referenceType).isDefined => + val schemas = JsonSchema.fromZSchemaMulti(schema, referenceType) + schemas.children.map { case (key, schema) => + OpenAPI.Key.fromString(key.replace("#/components/schemas/", "")).get -> OpenAPI.ReferenceOr.Or(schema) + } + (OpenAPI.Key.fromString(nominal(schema, referenceType).get).get -> + OpenAPI.ReferenceOr.Or(schemas.root.discriminator(genDiscriminator(schema)))) + }.flatten.toMap + + OpenAPI( + "3.1.0", + info = OpenAPI.Info( + title = "", + description = None, + termsOfService = None, + contact = None, + license = None, + version = "", + ), + servers = Nil, + paths = path, + components = Some(components), + security = Nil, + tags = Nil, + externalDocs = None, + ) + } + + private def schemaByStatusAndMediaType( + alternatives: Chunk[HttpCodec[_, _]], + referenceType: SchemaStyle, + ): Map[OpenAPI.StatusOrDefault, Map[MediaType, (JsonSchema, AtomizedMetaCodecs)]] = { + val statusAndCodec = + alternatives.map { codec => + val statusOrDefault = + status(codec).map(OpenAPI.StatusOrDefault.StatusValue(_)).getOrElse(OpenAPI.StatusOrDefault.Default) + statusOrDefault -> (AtomizedMetaCodecs + .flatten(codec), contentAsJsonSchema(codec, referenceType = referenceType)) + } + + groupMap(statusAndCodec) { case (status, _) => status } { case (_, atomizedAndSchema) => + atomizedAndSchema + }.map { case (status, values) => + val mapped = values + .foldLeft(Chunk.empty[(MediaType, (AtomizedMetaCodecs, JsonSchema))]) { case (acc, (atomized, schema)) => + if (atomized.content.size > 1) { + acc :+ (MediaType.multipart.`form-data` -> (atomized, schema)) + } else { + val mediaType = atomized.content.headOption match { + case Some(MetaCodec(HttpCodec.Content(_, Some(mediaType), _, _), _)) => + mediaType + case Some(MetaCodec(HttpCodec.ContentStream(_, Some(mediaType), _, _), _)) => + mediaType + case Some(MetaCodec(HttpCodec.ContentStream(schema, None, _, _), _)) => + if (schema == Schema[Byte]) MediaType.application.`octet-stream` + else MediaType.application.`json` + case _ => + MediaType.application.`json` + } + acc :+ (mediaType -> (atomized, schema)) + } + } + status -> groupMap(mapped) { case (mediaType, _) => mediaType } { case (_, atomizedAndSchema) => + atomizedAndSchema + }.map { + case (mediaType, Chunk((atomized, schema))) if values.size == 1 => + mediaType -> (schema, atomized) + case (mediaType, values) => + val combinedAtomized: AtomizedMetaCodecs = values.map(_._1).reduce(_ ++ _) + val combinedContentDoc = combinedAtomized.contentDocs.toCommonMark + val alternativesSchema = { + JsonSchema + .AnyOfSchema(values.map { case (_, schema) => + schema.description match { + case Some(value) => schema.description(value.replace(combinedContentDoc, "")) + case None => schema + } + }) + .minify + .description(combinedContentDoc) + } + mediaType -> (alternativesSchema, combinedAtomized) + } + } + } + + def nominal(schema: Schema[_], referenceType: SchemaStyle): Option[String] = + schema match { + case enumSchema: Schema.Enum[_] => + enumSchema.id match { + case TypeId.Structural => + None + case nominal: TypeId.Nominal if referenceType == SchemaStyle.Compact => + Some(nominal.typeName) + case nominal: TypeId.Nominal => + Some(nominal.fullyQualified.replace(".", "_")) + } + case record: Record[_] => + record.id match { + case TypeId.Structural => + None + case nominal: TypeId.Nominal if referenceType == SchemaStyle.Compact => + Some(nominal.typeName) + case nominal: TypeId.Nominal => + Some(nominal.fullyQualified.replace(".", "_")) + } + case _ => None + } + + private def responsesForAlternatives( + codecs: Map[OpenAPI.StatusOrDefault, Map[MediaType, (JsonSchema, AtomizedMetaCodecs)]], + ): Map[OpenAPI.StatusOrDefault, OpenAPI.ReferenceOr[OpenAPI.Response]] = + codecs.map { case (status, mediaTypes) => + val combinedAtomizedCodecs = mediaTypes.map { case (_, (_, atomized)) => atomized }.reduce(_ ++ _) + val mediaTypeResponses = mediaTypes.map { case (mediaType, (schema, atomized)) => + mediaType.fullType -> OpenAPI.MediaType( + schema = OpenAPI.ReferenceOr.Or(schema), + examples = atomized.contentExamples, + encoding = Map.empty, + ) + } + status -> OpenAPI.ReferenceOr.Or( + OpenAPI.Response( + headers = headersFrom(combinedAtomizedCodecs), + content = mediaTypeResponses, + links = Map.empty, + ), + ) + } + + private def headersFrom(codec: AtomizedMetaCodecs) = { + codec.header.map { case mc @ MetaCodec(codec, _) => + codec.name -> OpenAPI.ReferenceOr.Or( + OpenAPI.Header( + description = mc.docsOpt, + required = true, + deprecated = mc.deprecated, + allowEmptyValue = false, + schema = Some(JsonSchema.fromTextCodec(codec.textCodec)), + ), + ) + }.toMap + } + private def schemaReferencePath(nominal: TypeId.Nominal, referenceType: SchemaStyle): String = { + referenceType match { + case SchemaStyle.Compact => s"#/components/schemas/${nominal.typeName}}" + case _ => s"#/components/schemas/${nominal.fullyQualified.replace(".", "_")}}" + } + } +} diff --git a/zio-http/src/main/scala/zio/http/endpoint/openapi/SwaggerUI.scala b/zio-http/src/main/scala/zio/http/endpoint/openapi/SwaggerUI.scala new file mode 100644 index 0000000000..215b1786d8 --- /dev/null +++ b/zio-http/src/main/scala/zio/http/endpoint/openapi/SwaggerUI.scala @@ -0,0 +1,104 @@ +package zio.http.endpoint.openapi + +import java.net.URLEncoder + +import zio.http._ +import zio.http.codec.PathCodec + +object SwaggerUI { + + val DefaultSwaggerUIVersion: String = "5.10.3" + + //format: off + /** + * Creates routes for serving the Swagger UI at the given path. + * + * Example: + * {{{ + * val routes: Routes[Any, Response] = ??? + * val openAPIv1: OpenAPI = ??? + * val openAPIv2: OpenAPI = ??? + * val swaggerUIRoutes = SwaggerUI.routes("docs" / "openapi", openAPIv1, openAPIv2) + * val routesWithSwagger = routes ++ swaggerUIRoutes + * }}} + * + * With this middleware in place, a request to `https://www.domain.com/[path]` + * would serve the Swagger UI. The different OpenAPI specifications are served + * at `https://www.domain.com/[path]/[title].json`. Where `title` is the title + * of the OpenAPI specification and is url encoded. + */ + //format: on + def routes(path: PathCodec[Unit], api: OpenAPI, apis: OpenAPI*): Routes[Any, Response] = { + routes(path, DefaultSwaggerUIVersion, api, apis: _*) + } + + //format: off + /** + * Creates a middleware for serving the Swagger UI at the given path and with + * the given swagger ui version. + * + * Example: + * {{{ + * val routes: Routes[Any, Response] = ??? + * val openAPIv1: OpenAPI = ??? + * val openAPIv2: OpenAPI = ??? + * val swaggerUIRoutes = SwaggerUI.routes("docs" / "openapi", openAPIv1, openAPIv2) + * val routesWithSwagger = routes ++ swaggerUIRoutes + * }}} + * + * With this middleware in place, a request to `https://www.domain.com/[path]` + * would serve the Swagger UI. The different OpenAPI specifications are served + * at `https://www.domain.com/[path]/[title].json`. Where `title` is the title + * of the OpenAPI specification and is url encoded. + */ + //format: on + def routes(path: PathCodec[Unit], version: String, api: OpenAPI, apis: OpenAPI*): Routes[Any, Response] = { + import zio.http.template._ + val basePath = Method.GET / path + val jsonRoutes = (api +: apis).map { api => + basePath / s"${URLEncoder.encode(api.info.title, Charsets.Utf8.name())}.json" -> handler { (_: Request) => + Response.json(api.toJson) + } + } + val jsonPaths = jsonRoutes.map(_.routePattern.pathCodec.render) + val jsonTitles = (api +: apis).map(_.info.title) + val jsonUrls = jsonTitles.zip(jsonPaths).map { case (title, path) => s"""{url: "$path", name: "$title"}""" } + val uiRoute = basePath -> handler { (_: Request) => + Response.html( + html( + head( + meta(charsetAttr := "utf-8"), + meta(nameAttr := "viewport", contentAttr := "width=device-width, initial-scale=1"), + meta(nameAttr := "description", contentAttr := "SwaggerUI"), + title("SwaggerUI"), + link(relAttr := "stylesheet", href := s"https://unpkg.com/swagger-ui-dist@$version/swagger-ui.css"), + link( + relAttr := "icon", + typeAttr := "image/png", + href := s"https://unpkg.com/swagger-ui-dist@$version/favicon-32x32.png", + ), + ), + body( + div(id := "swagger-ui"), + script(srcAttr := s"https://unpkg.com/swagger-ui-dist@$version/swagger-ui-bundle.js"), + script(srcAttr := s"https://unpkg.com/swagger-ui-dist@$version/swagger-ui-standalone-preset.js"), + Dom.raw(s"""""".stripMargin), + ), + ), + ) + } + Routes.fromIterable(jsonRoutes) :+ uiRoute + } +} diff --git a/zio-http/src/main/scala/zio/http/netty/EventLoopGroups.scala b/zio-http/src/main/scala/zio/http/netty/EventLoopGroups.scala index c1ea30143e..9904cd3c09 100644 --- a/zio-http/src/main/scala/zio/http/netty/EventLoopGroups.scala +++ b/zio-http/src/main/scala/zio/http/netty/EventLoopGroups.scala @@ -51,11 +51,12 @@ object EventLoopGroups { def make(config: Config, eventLoopGroup: UIO[EventLoopGroup])(implicit trace: Trace, ): ZIO[Scope, Nothing, EventLoopGroup] = - ZIO.acquireRelease(eventLoopGroup)(ev => + ZIO.acquireRelease(eventLoopGroup) { ev => + val future = ev.shutdownGracefully(config.shutdownQuietPeriod, config.shutdownTimeOut, config.shutdownTimeUnit) NettyFutureExecutor - .executed(ev.shutdownGracefully(config.shutdownQuietPeriod, config.shutdownTimeOut, config.shutdownTimeUnit)) - .orDie, - ) + .executed(future) + .orDie + } def epoll(config: Config)(implicit trace: Trace): ZIO[Scope, Nothing, EventLoopGroup] = make(config, ZIO.succeed(new EpollEventLoopGroup(config.nThreads))) diff --git a/zio-http/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala b/zio-http/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala index 222a6b8b54..ea539fa1f1 100644 --- a/zio-http/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala +++ b/zio-http/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala @@ -47,6 +47,7 @@ object NettyConnectionPool { location: URL.Location.Absolute, proxy: Option[Proxy], sslOptions: ClientSSLConfig, + maxInitialLineLength: Int, maxHeaderSize: Int, decompression: Decompression, idleTimeout: Option[Duration], @@ -70,7 +71,7 @@ object NettyConnectionPool { case None => } - if (location.scheme.isSecure) { + if (location.scheme.isSecure.getOrElse(false)) { pipeline.addLast( Names.SSLHandler, ClientSSLConverter @@ -92,7 +93,7 @@ object NettyConnectionPool { // This way, if the server closes the connection before the whole response has been sent, // we get an error. (We can also handle the channelInactive callback, but since for now // we always buffer the whole HTTP response we can letty Netty take care of this) - pipeline.addLast(Names.HttpClientCodec, new HttpClientCodec(4096, maxHeaderSize, 8192, true)) + pipeline.addLast(Names.HttpClientCodec, new HttpClientCodec(maxInitialLineLength, maxHeaderSize, 8192, true)) // HttpContentDecompressor if (decompression.enabled) @@ -135,6 +136,7 @@ object NettyConnectionPool { location: Location.Absolute, proxy: Option[Proxy], sslOptions: ClientSSLConfig, + maxInitialLineLength: Int, maxHeaderSize: Int, decompression: Decompression, idleTimeout: Option[Duration], @@ -147,6 +149,7 @@ object NettyConnectionPool { location, proxy, sslOptions, + maxInitialLineLength, maxHeaderSize, decompression, idleTimeout, @@ -166,6 +169,7 @@ object NettyConnectionPool { location: Location.Absolute, proxy: Option[Proxy], sslOptions: ClientSSLConfig, + maxInitialLineLength: Int, maxHeaderSize: Int, decompression: Decompression, idleTimeout: Option[Duration], @@ -179,6 +183,7 @@ object NettyConnectionPool { location: Location.Absolute, proxy: Option[Proxy], sslOptions: ClientSSLConfig, + maxInitialLineLength: Int, maxHeaderSize: Int, decompression: Decompression, idleTimeout: Option[Duration], @@ -186,7 +191,18 @@ object NettyConnectionPool { localAddress: Option[InetSocketAddress] = None, )(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] = pool - .get(PoolKey(location, proxy, sslOptions, maxHeaderSize, decompression, idleTimeout, connectionTimeout)) + .get( + PoolKey( + location, + proxy, + sslOptions, + maxInitialLineLength, + maxHeaderSize, + decompression, + idleTimeout, + connectionTimeout, + ), + ) override def invalidate(channel: JChannel)(implicit trace: Trace): ZIO[Any, Nothing, Unit] = pool.invalidate(channel) @@ -243,6 +259,7 @@ object NettyConnectionPool { key.location, key.proxy, key.sslOptions, + key.maxInitialLineLength, key.maxHeaderSize, key.decompression, key.idleTimeout, @@ -287,6 +304,7 @@ object NettyConnectionPool { key.location, key.proxy, key.sslOptions, + key.maxInitialLineLength, key.maxHeaderSize, key.decompression, key.idleTimeout, diff --git a/zio-http/src/main/scala/zio/http/netty/server/ServerChannelInitializer.scala b/zio-http/src/main/scala/zio/http/netty/server/ServerChannelInitializer.scala index 418626c99f..31186ad54c 100644 --- a/zio-http/src/main/scala/zio/http/netty/server/ServerChannelInitializer.scala +++ b/zio-http/src/main/scala/zio/http/netty/server/ServerChannelInitializer.scala @@ -60,7 +60,7 @@ private[zio] final case class ServerChannelInitializer( // Instead of ServerCodec, we should use Decoder and Encoder separately to have more granular control over performance. pipeline.addLast( Names.HttpRequestDecoder, - new HttpRequestDecoder(DEFAULT_MAX_INITIAL_LINE_LENGTH, cfg.maxHeaderSize, DEFAULT_MAX_CHUNK_SIZE, false), + new HttpRequestDecoder(cfg.maxInitialLineLength, cfg.maxHeaderSize, DEFAULT_MAX_CHUNK_SIZE, false), ) pipeline.addLast(Names.HttpResponseEncoder, new HttpResponseEncoder()) diff --git a/zio-http/src/main/scala/zio/http/netty/server/ServerSSLDecoder.scala b/zio-http/src/main/scala/zio/http/netty/server/ServerSSLDecoder.scala index 9ffeed426d..10b093e63c 100644 --- a/zio-http/src/main/scala/zio/http/netty/server/ServerSSLDecoder.scala +++ b/zio-http/src/main/scala/zio/http/netty/server/ServerSSLDecoder.scala @@ -19,11 +19,9 @@ package zio.http.netty.server import java.io.FileInputStream import java.util -import zio.stacktracer.TracingImplicits.disableAutoTrace - import zio.http.SSLConfig.{HttpBehaviour, Provider} import zio.http.netty.Names -import zio.http.{SSLConfig, Server} +import zio.http.{ClientAuth, SSLConfig, Server} import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext @@ -33,27 +31,38 @@ import io.netty.handler.ssl.ApplicationProtocolConfig.{ SelectedListenerFailureBehavior, SelectorFailureBehavior, } +import io.netty.handler.ssl._ import io.netty.handler.ssl.util.SelfSignedCertificate -import io.netty.handler.ssl.{SslContext, SslHandler, _} +import io.netty.handler.ssl.{ClientAuth => NettyClientAuth} object SSLUtil { + def getClientAuth(clientAuth: ClientAuth): NettyClientAuth = clientAuth match { + case ClientAuth.Required => NettyClientAuth.REQUIRE + case ClientAuth.Optional => NettyClientAuth.OPTIONAL + case _ => NettyClientAuth.NONE + } + implicit class SslContextBuilderOps(self: SslContextBuilder) { def toNettyProvider(sslProvider: Provider): SslProvider = sslProvider match { case Provider.OpenSSL => SslProvider.OPENSSL case Provider.JDK => SslProvider.JDK } - def buildWithDefaultOptions(sslConfig: SSLConfig): SslContext = self - .sslProvider(toNettyProvider(sslConfig.provider)) - .applicationProtocolConfig( - new ApplicationProtocolConfig( - Protocol.ALPN, - SelectorFailureBehavior.NO_ADVERTISE, - SelectedListenerFailureBehavior.ACCEPT, - ApplicationProtocolNames.HTTP_1_1, - ), - ) - .build() + def buildWithDefaultOptions(sslConfig: SSLConfig): SslContext = { + val clientAuthConfig: Option[ClientAuth] = sslConfig.clientAuth + clientAuthConfig.foreach(ca => self.clientAuth(getClientAuth(ca))) + self + .sslProvider(toNettyProvider(sslConfig.provider)) + .applicationProtocolConfig( + new ApplicationProtocolConfig( + Protocol.ALPN, + SelectorFailureBehavior.NO_ADVERTISE, + SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_1_1, + ), + ) + .build() + } } def sslConfigToSslContext(sslConfig: SSLConfig): SslContext = sslConfig.data match { diff --git a/zio-http/src/main/scala/zio/http/template/Dom.scala b/zio-http/src/main/scala/zio/http/template/Dom.scala index 0ad20889ba..b2e3cf9260 100644 --- a/zio-http/src/main/scala/zio/http/template/Dom.scala +++ b/zio-http/src/main/scala/zio/http/template/Dom.scala @@ -34,8 +34,9 @@ sealed trait Dom { self => def encode(spaces: Int): CharSequence = encode(EncodingState.Indentation(0, spaces)) - private[template] def encode(state: EncodingState): CharSequence = self match { + private[template] def encode(state: EncodingState, encodeHtml: Boolean = true): CharSequence = self match { case Dom.Element(name, children) => + val encode = if (name == "script" || name == "style") false else encodeHtml val attributes = children.collect { case self: Dom.Attribute => self.encode } val innerState = state.inner @@ -51,9 +52,9 @@ sealed trait Dom { self => def inner: CharSequence = elements match { - case Seq(singleText: Dom.Text) => singleText.encode(innerState) + case Seq(singleText: Dom.Text) => singleText.encode(innerState, encode) case _ => - s"${innerState.nextElemSeparator}${elements.map(_.encode(innerState)).mkString(innerState.nextElemSeparator)}${state.nextElemSeparator}" + s"${innerState.nextElemSeparator}${elements.map(_.encode(innerState, encode)).mkString(innerState.nextElemSeparator)}${state.nextElemSeparator}" } if (noElements && noAttributes && isVoid) s"<$name/>" @@ -64,11 +65,11 @@ sealed trait Dom { self => else s"<$name ${attributes.mkString(" ")}>$inner" - case Dom.Text(data) => OutputEncoder.encodeHtml(data.toString) - case Dom.Attribute(name, value) => - s"""$name="${OutputEncoder.encodeHtml(value.toString)}"""" - case Dom.Empty => "" - case Dom.Raw(raw) => raw + case Dom.Text(data) if encodeHtml => OutputEncoder.encodeHtml(data.toString) + case Dom.Text(data) => data + case Dom.Attribute(name, value) => s"""$name="${OutputEncoder.encodeHtml(value.toString)}"""" + case Dom.Empty => "" + case Dom.Raw(raw) => raw } } diff --git a/zio-http/src/test/resources/endpoint/openapi/multiple-methods-on-same-path.json b/zio-http/src/test/resources/endpoint/openapi/multiple-methods-on-same-path.json new file mode 100644 index 0000000000..1fc50f4375 --- /dev/null +++ b/zio-http/src/test/resources/endpoint/openapi/multiple-methods-on-same-path.json @@ -0,0 +1,62 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "Multiple Methods on Same Path", + "version": "1.0" + }, + "paths": { + "/test": { + "get": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "null" + } + } + }, + "required": false + }, + "responses": { + "200": { + "description": "", + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + } + } + }, + "deprecated": false + }, + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "string" + } + } + }, + "required": true + }, + "responses": { + "201": { + "description": "", + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + } + } + }, + "deprecated": false + } + } + }, + "components": {} +} diff --git a/zio-http/src/test/scala/zio/http/ClientHttpsSpec.scala b/zio-http/src/test/scala/zio/http/ClientHttpsSpec.scala index ad93ec2dbd..1a43a40153 100644 --- a/zio-http/src/test/scala/zio/http/ClientHttpsSpec.scala +++ b/zio-http/src/test/scala/zio/http/ClientHttpsSpec.scala @@ -77,5 +77,5 @@ object ClientHttpsSpec extends ZIOHttpSpec { DnsResolver.default, ZLayer.succeed(NettyConfig.default), Scope.default, - ) + ) @@ ignore } diff --git a/zio-http/src/test/scala/zio/http/FlashSpec.scala b/zio-http/src/test/scala/zio/http/FlashSpec.scala new file mode 100644 index 0000000000..3e81753966 --- /dev/null +++ b/zio-http/src/test/scala/zio/http/FlashSpec.scala @@ -0,0 +1,139 @@ +package zio.http + +import zio._ +import zio.test._ + +import zio.schema.{DeriveSchema, Schema} + +object FlashSpec extends ZIOHttpSpec { + + case class Article(name: String, price: Double) + object Article { + implicit val schema: Schema[Article] = DeriveSchema.gen + } + + case class Articles(list: List[Article]) + + object Articles { + implicit val schema: Schema[Articles] = DeriveSchema.gen + } + + override def spec = + suite("flash")( + test("set and get") { + + val flash1 = Flash.setValue("articles", Articles(List(Article("m\"i`l'k", 2.99), Article("choco", 4.99)))) + val flash2 = Flash.setValue("dataMap", Map("a" -> "A", "b" -> "B", "c" -> "CCC\"CCC\"CCCCC")) + val flash3 = Flash.setValue("dataList", List("a", "b", "c")) + val flash4 = Flash.setValue("articlesTuple", Article("a", 1.00) -> Article("b", 2.00)) + val cookie1 = Flash.Setter.run(flash1 ++ flash2 ++ flash3 ++ flash4) + + val cookie2 = Cookie.Request(Flash.COOKIE_NAME, cookie1.content) + val request = Request(headers = Headers(Header.Cookie(NonEmptyChunk(cookie2)))) + + assertTrue(request.flash(Flash.get[Articles]("does-not-exist") <> Flash.get[Articles]("articles")).isDefined) && + assertTrue(request.flash(Flash.get[Map[String, String]]).isDefined) && + assertTrue( + request + .flash(Flash.get[Articles]("articles").zip(Flash.get[Map[String, String]]).map { case (a, b) => + s" ---> $a @@@ $b <---- " + }) + .isDefined, + ) && + assertTrue( + request.flash(Flash.getString("articles").optional.zip(Flash.getDouble("bbb").optional)).isDefined, + ) && + assertTrue(request.flash(Flash.get[List[String]]("dataList")).isDefined) && + assertTrue(request.flash(Flash.get[List[String]]).isDefined) && + assertTrue(request.flash(Flash.get[List[Int]]).isEmpty) && + assertTrue(request.flash(Flash.get[(Article, Article)]("articlesTuple")).isDefined) + }, + test("flash message") { + val flashMessageDefaultBoth = Flash.setNotice[String]("notice") ++ Flash.setAlert[String]("alert") + val flashMessageCustomBoth = + Flash.setValue("custom-notice", Article("custom-notice", 10)) ++ Flash.setValue( + "custom-alert", + List("custom", "alert"), + ) + val flashMessageCustomOnlyNotice = Flash.setValue("custom-notice-only", "custom-notice-only-value") + val flashMessageCustomOnlyAlert = Flash.setValue("custom-alert-only", "custom-alert-only") + + val cookie1 = Flash.Setter.run( + flashMessageCustomBoth ++ flashMessageDefaultBoth ++ flashMessageCustomOnlyNotice ++ flashMessageCustomOnlyAlert, + ) + val cookie2 = Cookie.Request(Flash.COOKIE_NAME, cookie1.content) + val request = Request(headers = Headers(Header.Cookie(NonEmptyChunk(cookie2)))) + + assertTrue(request.flash(Flash.getMessageHtml).get.isBoth) && + assertTrue( + request + .flash(Flash.getMessage(Flash.get[Article]("custom-notice"), Flash.get[List[String]]("custom-alert"))) + .get + .isBoth, + ) && + assertTrue(request.flash(Flash.getMessage(Flash.get[Article], Flash.get[List[String]])).get.isBoth) && + assertTrue( + request + .flash(Flash.getMessage(Flash.getString("custom-notice-only"), Flash.getInt("does-not-exist"))) + .get + .isNotice, + ) && + assertTrue( + request + .flash(Flash.getMessage(Flash.getInt("does-not-exist"), Flash.getString("custom-alert-only"))) + .get + .isAlert, + ) + }, + test("flash backend") { + + import zio.http.template._ + + object ui { + def flashEmpty = Html.fromString("no-flash") + def flashBoth(notice: Html, alert: Html): Html = notice ++ alert + def flashNotice(html: Html): Html = div(styleAttr := Seq("background" -> "green"), html) + def flashAlert(html: Html): Html = div(styleAttr := Seq("background" -> "red"), html) + } + + val routeUserSavePath = Method.POST / "users" / "save" + val routeUserSave = routeUserSavePath -> handler { + for { + flashBackend <- ZIO.service[Flash.Backend] + respose <- flashBackend.addFlash( + Response.seeOther(URL.empty / "users"), + Flash.setNotice("user saved successfully"), + ) + } yield respose + } + + val routeConfirmPath = Method.GET / "users" + val routeConfirm = routeConfirmPath -> handler { (req: Request) => + for { + flashBackend <- ZIO.service[Flash.Backend] + html <- flashBackend.flashOrElse( + req, + Flash.getMessageHtml.foldHtml(ui.flashNotice, ui.flashAlert)(ui.flashBoth), + )(ui.flashEmpty) + } yield Response.html(html) + } + + val app = Routes(routeUserSave, routeConfirm).toHttpApp + + for { + response1 <- app.runZIO(Request.post(URL(routeUserSavePath.format(()).toOption.get), Body.empty)) + flashString = response1.header(Header.SetCookie).get.value.content + cookie = Cookie.Request(Flash.COOKIE_NAME, flashString) + response2 <- app.runZIO( + Request( + method = Method.GET, + url = URL(routeConfirmPath.format(()).toOption.get), + headers = Headers(Header.Cookie(NonEmptyChunk(cookie))), + ), + ) + bodyString <- response2.body.asString + } yield assertTrue(bodyString.contains("successfully") && bodyString.contains("green")) + }.provideLayer(Flash.Backend.inMemory), + ) + +} diff --git a/zio-http/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala b/zio-http/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala new file mode 100644 index 0000000000..f9f58b5a1f --- /dev/null +++ b/zio-http/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala @@ -0,0 +1,72 @@ +package zio.http + +import zio._ +import zio.test._ + +object LogAnnotationMiddlewareSpec extends ZIOSpecDefault { + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("LogAnnotationMiddlewareSpec")( + test("add static log annotation") { + val response = Routes + .singleton( + handler(ZIO.logWarning("Oh!") *> ZIO.succeed(Response.text("Hey logging!"))), + ) + .@@(Middleware.logAnnotate("label", "value")) + .toHttpApp + .runZIO(Request.get("/")) + + for { + _ <- response + logs <- ZTestLogger.logOutput + log = logs.filter(_.message() == "Oh!").head + } yield assertTrue(log.annotations.get("label").contains("value")) + + }, + test("add request method and path as annotation") { + val response = Routes + .singleton( + handler(ZIO.logWarning("Oh!") *> ZIO.succeed(Response.text("Hey logging!"))), + ) + .@@( + Middleware.logAnnotate(req => + Set(LogAnnotation("method", req.method.name), LogAnnotation("path", req.path.encode)), + ), + ) + .toHttpApp + .runZIO(Request.get("/")) + + for { + _ <- response + logs <- ZTestLogger.logOutput + log = logs.filter(_.message() == "Oh!").head + } yield assertTrue( + log.annotations.get("method").contains("GET"), + log.annotations.get("path").contains("/"), + ) + }, + test("add headers as annotation") { + val response = Routes + .singleton( + handler(ZIO.logWarning("Oh!") *> ZIO.succeed(Response.text("Hey logging!"))), + ) + .@@(Middleware.logAnnotateHeaders("header")) + .@@(Middleware.logAnnotateHeaders(Header.UserAgent.name)) + .toHttpApp + .runZIO { + Request + .get("/") + .addHeader("header", "value") + .addHeader(Header.UserAgent.Product("zio-http", Some("3.0.0"))) + } + + for { + _ <- response + logs <- ZTestLogger.logOutput + log = logs.filter(_.message() == "Oh!").head + } yield assertTrue( + log.annotations.get("header").contains("value"), + log.annotations.get(Header.UserAgent.name).contains("zio-http/3.0.0"), + ) + }, + ) +} diff --git a/zio-http/src/test/scala/zio/http/NettyMaxInitialLineLengthSpec.scala b/zio-http/src/test/scala/zio/http/NettyMaxInitialLineLengthSpec.scala new file mode 100644 index 0000000000..dc10ec208e --- /dev/null +++ b/zio-http/src/test/scala/zio/http/NettyMaxInitialLineLengthSpec.scala @@ -0,0 +1,59 @@ +/* + * Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.http + +import zio.test.TestAspect.withLiveClock +import zio.test._ +import zio.{Scope, ZLayer} + +object NettyMaxInitialLineLength extends ZIOHttpSpec { + val minimalInitialLineLength: Int = "GET / HTTP/1.1".getBytes.length + + def extractStatus(response: Response): Status = response.status + + private val serverConfig: Server.Config = + Server.Config.default.onAnyOpenPort.copy(maxInitialLineLength = minimalInitialLineLength) + + override def spec: Spec[TestEnvironment with Scope, Any] = + test("should get a failure instead of an empty body") { + val app = Handler + .fromFunctionZIO[Request] { request => + request.body.asString.map { body => + val responseBody = if (body.isEmpty) "" else body + Response.text(responseBody) + } // this should not be run, as the request is invalid + } + .sandbox + .toHttpApp + for { + port <- Server.install(app) + url = URL + .decode(s"http://localhost:$port/a%20looooooooooooooooooooooooooooong%20query%20parameter") + .toOption + .get + headers = Headers.empty + + res <- Client.request(Request(url = url, headers = headers, body = Body.fromString("some-body"))) + data <- res.body.asString + } yield assertTrue(extractStatus(res) == Status.InternalServerError, data == "") + }.provide( + Client.default, + Server.live, + ZLayer.succeed(serverConfig), + Scope.default, + ) @@ withLiveClock +} diff --git a/zio-http/src/test/scala/zio/http/PathSpec.scala b/zio-http/src/test/scala/zio/http/PathSpec.scala index 6c0b0dbb96..c3dc1a2338 100644 --- a/zio-http/src/test/scala/zio/http/PathSpec.scala +++ b/zio-http/src/test/scala/zio/http/PathSpec.scala @@ -413,5 +413,56 @@ object PathSpec extends ZIOHttpSpec with ExitAssertion { } }, ), + suite("removeDotSegments")( + test("only leading slash and dots") { + val path = Path.decode("/./../") + val result = path.removeDotSegments + val expected = Path.root + + assertTrue(result == expected) + }, + test("only leading dots") { + val path = Path.decode("./../") + val result = path.removeDotSegments + val expected = Path.empty + + assertTrue(result == expected) + }, + test("leading slash and dots") { + val path = Path.decode("/./../path") + val result = path.removeDotSegments + val expected = Path.decode("/path") + + assertTrue(result == expected) + }, + test("leading dots and path") { + val path = Path.decode("./../path") + val result = path.removeDotSegments + val expected = Path.decode("path") + + assertTrue(result == expected) + }, + test("double dot to top") { + val path = Path.decode("path/../subpath") + val result = path.removeDotSegments + val expected = Path.decode("/subpath") + + assertTrue(result == expected) + }, + test("trailing double dots") { + val path = Path.decode("path/ignored/..") + val result = path.removeDotSegments + val expected = Path.decode("path/") + + assertTrue(result == expected) + }, + test("path traversal") { + val path = Path.decode("/start/ignored/./../path/also/ignored/../../end/.") + val result = path.removeDotSegments + val expected = Path.decode("/start/path/end/") + + assertTrue(result == expected) + }, + ), ) } diff --git a/zio-http/src/test/scala/zio/http/QueryParamsSpec.scala b/zio-http/src/test/scala/zio/http/QueryParamsSpec.scala index 277ace21ed..7d7f551427 100644 --- a/zio-http/src/test/scala/zio/http/QueryParamsSpec.scala +++ b/zio-http/src/test/scala/zio/http/QueryParamsSpec.scala @@ -16,7 +16,7 @@ package zio.http -import zio.test.Assertion.equalTo +import zio.test.Assertion.{anything, equalTo, fails, hasSize} import zio.test._ import zio.{Chunk, ZIO} @@ -236,10 +236,45 @@ object QueryParamsSpec extends ZIOHttpSpec { val default = "default" val unknown = "non-existent" val queryParams = QueryParams(name -> "a", name -> "b") - assertTrue(queryParams.get(name).get == "a") && - assertTrue(queryParams.getOrElse(unknown, default) == default) && - assertTrue(queryParams.getAll(name).get.length == 2) && - assertTrue(queryParams.getAllOrElse(unknown, Chunk(default)).length == 1) + assertTrue( + queryParams.get(name).get == "a", + queryParams.get(unknown).isEmpty, + queryParams.getOrElse(name, default) == "a", + queryParams.getOrElse(unknown, default) == default, + queryParams.getAll(name).get.length == 2, + queryParams.getAll(unknown).isEmpty, + queryParams.getAllOrElse(name, Chunk(default)).length == 2, + queryParams.getAllOrElse(unknown, Chunk(default)).length == 1, + ) + }, + ), + suite("getAs - getAllAs")( + test("success") { + val typed = "typed" + val default = 3 + val invalidTyped = "invalidTyped" + val unknown = "non-existent" + val queryParams = QueryParams(typed -> "1", typed -> "2", invalidTyped -> "str") + assertTrue( + queryParams.getAs[Int](typed) == Right(1), + queryParams.getAs[Int](invalidTyped).isLeft, + queryParams.getAs[Int](unknown).isLeft, + queryParams.getAsOrElse[Int](typed, default) == 1, + queryParams.getAsOrElse[Int](invalidTyped, default) == default, + queryParams.getAsOrElse[Int](unknown, default) == default, + queryParams.getAllAs[Int](typed).map(_.length) == Right(2), + queryParams.getAllAs[Int](invalidTyped).isLeft, + queryParams.getAllAs[Int](unknown).isLeft, + queryParams.getAllAsOrElse[Int](typed, Chunk(default)).length == 2, + queryParams.getAllAsOrElse[Int](invalidTyped, Chunk(default)).length == 1, + queryParams.getAllAsOrElse[Int](unknown, Chunk(default)).length == 1, + ) + assertZIO(queryParams.getAsZIO[Int](typed))(equalTo(1)) && + assertZIO(queryParams.getAsZIO[Int](invalidTyped).exit)(fails(anything)) && + assertZIO(queryParams.getAsZIO[Int](unknown).exit)(fails(anything)) && + assertZIO(queryParams.getAllAsZIO[Int](typed))(hasSize(equalTo(2))) && + assertZIO(queryParams.getAllAsZIO[Int](invalidTyped).exit)(fails(anything)) && + assertZIO(queryParams.getAllAsZIO[Int](unknown).exit)(fails(anything)) }, ), suite("encode - decode")( diff --git a/zio-http/src/test/scala/zio/http/ResponseCompressionSpec.scala b/zio-http/src/test/scala/zio/http/ResponseCompressionSpec.scala index 511b6e50e3..5868d7a381 100644 --- a/zio-http/src/test/scala/zio/http/ResponseCompressionSpec.scala +++ b/zio-http/src/test/scala/zio/http/ResponseCompressionSpec.scala @@ -66,7 +66,7 @@ object ResponseCompressionSpec extends ZIOHttpSpec { response <- client.request( Request( method = Method.GET, - url = URL(Root / "text", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", server.port)), + url = URL(Root / "text", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", Some(server.port))), ) .addHeader(Header.AcceptEncoding(Header.AcceptEncoding.GZip(), Header.AcceptEncoding.Deflate())), ) @@ -82,7 +82,7 @@ object ResponseCompressionSpec extends ZIOHttpSpec { response <- client.request( Request( method = Method.GET, - url = URL(Root / "stream", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", server.port)), + url = URL(Root / "stream", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", Some(server.port))), ) .addHeader(Header.AcceptEncoding(Header.AcceptEncoding.GZip(), Header.AcceptEncoding.Deflate())), ) diff --git a/zio-http/src/test/scala/zio/http/RouteSpec.scala b/zio-http/src/test/scala/zio/http/RouteSpec.scala index 8555051a53..ab24e57f13 100644 --- a/zio-http/src/test/scala/zio/http/RouteSpec.scala +++ b/zio-http/src/test/scala/zio/http/RouteSpec.scala @@ -16,8 +16,6 @@ package zio.http -import scala.collection.Seq - import zio._ import zio.test._ @@ -25,6 +23,16 @@ object RouteSpec extends ZIOHttpSpec { def extractStatus(response: Response): Status = response.status def spec = suite("RouteSpec")( + suite("Route#prefix")( + test("prefix should add a prefix to the route") { + val route = + Method.GET / "foo" -> handler(Response.ok) + + val prefixed = route.nest("bar") + + assertTrue(prefixed.isDefinedAt(Request.get(url"/bar/foo"))) + }, + ), suite("Route#sandbox")( test("infallible route does not change under sandbox") { val route = @@ -64,5 +72,58 @@ object RouteSpec extends ZIOHttpSpec { } yield assertTrue(cnt == 2) }, ), + suite("error handle")( + test("handleErrorCauseZIO should execute a ZIO effect") { + val route = Method.GET / "endpoint" -> handler { (_: Request) => ZIO.fail(new Exception("hmm...")) } + for { + p <- zio.Promise.make[Exception, String] + + errorHandled = route + .handleErrorCauseZIO(c => p.failCause(c).as(Response.internalServerError)) + + request = Request.get(URL.decode("/endpoint").toOption.get) + response <- errorHandled.toHttpApp.runZIO(request) + result <- p.await.catchAllCause(c => ZIO.succeed(c.prettyPrint)) + + } yield assertTrue(extractStatus(response) == Status.InternalServerError, result.contains("hmm...")) + }, + test("handleErrorCauseRequestZIO should produce an error based on the request") { + val route = Method.GET / "endpoint" -> handler { (_: Request) => ZIO.fail(new Exception("hmm...")) } + for { + p <- zio.Promise.make[Exception, String] + + errorHandled = route + .handleErrorRequestCauseZIO((req, c) => + p.failCause(c).as(Response.internalServerError(s"error accessing ${req.path.encode}")), + ) + + request = Request.get(URL.decode("/endpoint").toOption.get) + response <- errorHandled.toHttpApp.runZIO(request) + result <- p.await.catchAllCause(c => ZIO.succeed(c.prettyPrint)) + resultWarning <- ZIO.fromOption(response.headers.get(Header.Warning).map(_.text)) + + } yield assertTrue( + extractStatus(response) == Status.InternalServerError, + resultWarning == "error accessing /endpoint", + result.contains("hmm..."), + ) + }, + test("handleErrorCauseRequest should produce an error based on the request") { + val route = Method.GET / "endpoint" -> handler { (_: Request) => ZIO.fail(new Exception("hmm...")) } + val errorHandled = + route.handleErrorRequest((e, req) => + Response.internalServerError(s"error accessing ${req.path.encode}: ${e.getMessage}"), + ) + val request = Request.get(URL.decode("/endpoint").toOption.get) + for { + response <- errorHandled.toHttpApp.runZIO(request) + resultWarning <- ZIO.fromOption(response.headers.get(Header.Warning).map(_.text)) + + } yield assertTrue( + extractStatus(response) == Status.InternalServerError, + resultWarning == "error accessing /endpoint: hmm...", + ) + }, + ), ) } diff --git a/zio-http/src/test/scala/zio/http/SchemeSpec.scala b/zio-http/src/test/scala/zio/http/SchemeSpec.scala index 46e9026126..dcbaaaddb2 100644 --- a/zio-http/src/test/scala/zio/http/SchemeSpec.scala +++ b/zio-http/src/test/scala/zio/http/SchemeSpec.scala @@ -31,5 +31,8 @@ object SchemeSpec extends ZIOHttpSpec { test("null string decode") { assert(Scheme.decode(null))(isNone) }, + test("decode chrome-extension") { + assertTrue(Scheme.decode("chrome-extension").isDefined) + }, ) } diff --git a/zio-http/src/test/scala/zio/http/URLSpec.scala b/zio-http/src/test/scala/zio/http/URLSpec.scala index 0d0ab387b8..4d3c7f2280 100644 --- a/zio-http/src/test/scala/zio/http/URLSpec.scala +++ b/zio-http/src/test/scala/zio/http/URLSpec.scala @@ -49,14 +49,14 @@ object URLSpec extends ZIOHttpSpec { ), suite("normalize")( test("adds leading slash") { - val url = URL(Path("a/b/c"), URL.Location.Absolute(Scheme.HTTP, "abc.com", 80), QueryParams.empty, None) + val url = URL(Path("a/b/c"), URL.Location.Absolute(Scheme.HTTP, "abc.com", Some(80)), QueryParams.empty, None) val url2 = url.normalize assertTrue(extractPath(url2) == Path("/a/b/c")) }, test("deletes leading slash if there are no path segments") { - val url = URL(Path.root, URL.Location.Absolute(Scheme.HTTP, "abc.com", 80), QueryParams.empty, None) + val url = URL(Path.root, URL.Location.Absolute(Scheme.HTTP, "abc.com", Some(80)), QueryParams.empty, None) val url2 = url.normalize assertTrue(extractPath(url2) == Path.empty) @@ -239,5 +239,87 @@ object URLSpec extends ZIOHttpSpec { assertZIO(result)(isLeft) }, ), + suite("relative resolution")( + // next ones are edge cases + test("absolute reference with relative base") { + val base = url"base/relative#basefrag" + val reference = url"https://reference/ignored/.././absolute#reffrag" + + // uses reference without dot segments + val expected = url"https://reference/absolute#reffrag" + + val result = base.resolve(reference) + assertTrue(result.contains(expected)) + }, + test("absolute reference with absolute base") { + val base = url"https://base#basefrag" + val reference = url"https://reference/ignored/.././absolute#reffrag" + + // uses reference without dot segments + val expected = url"https://reference/absolute#reffrag" + + val result = base.resolve(reference) + assertTrue(result.contains(expected)) + }, + test("relative reference with relative base") { + val base = url"base/relative" + val reference = url"reference/relative" + + val result = base.resolve(reference) + assertTrue(result.isLeft) + }, + + // remainder are main resolution logic - absolute base, relative reference + test("empty reference path without query params") { + val base = url"https://base/./ignored/../absolute?param=base#basefrag" + val reference = url"#reffrag" + + // uses unmodified base path and base query params + val expected = url"https://base/./ignored/../absolute?param=base#reffrag" + + val result = base.resolve(reference) + assertTrue(result.contains(expected)) + }, + test("empty reference path with query params") { + val base = url"https://base/./ignored/../absolute?param=base#basefrag" + val reference = url"?param=reference#reffrag" + + // uses unmodified base path and reference query params + val expected = url"https://base/./ignored/../absolute?param=reference#reffrag" + + val result = base.resolve(reference) + assertTrue(result.contains(expected)) + }, + test("non-empty reference path with a leading slash") { + val base = url"https://base/./ignored/../first/second?param=base#basefrag" + val reference = url"/reference/./ignored/../last?param=reference#reffrag" + + // uses reference path without dot segments and reference query params + val expected = url"https://base/reference/last?param=reference#reffrag" + + val result = base.resolve(reference) + assertTrue(result.contains(expected)) + }, + test("non-empty reference path without a leading slash") { + val base = url"https://base/./ignored/../first/..?param=base#basefrag" + val reference = url"reference/./ignored/../last?param=reference#reffrag" + + // uses base path without last segment, reference segments appended, without dot segments, and reference query params + val expected = url"https://base/first/reference/last?param=reference#reffrag" + + val result = base.resolve(reference) + assertTrue(result.contains(expected)) + }, + test("non-empty reference path without a leading slash and empty base path") { + val base = url"https://base?param=base#basefrag" + val reference = url"reference/./ignored/../last?param=reference#reffrag" + + // uses reference path without dot segments and a leading slash + val expected = url"https://base/reference/last?param=reference#reffrag" + + val result = base.resolve(reference) + assertTrue(result.contains(expected)) + }, + ), ) } diff --git a/zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala b/zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala index ac453659d1..394e92b461 100644 --- a/zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala +++ b/zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala @@ -25,7 +25,13 @@ import zio.http.URL.Location object ZClientAspectSpec extends ZIOHttpSpec { def extractStatus(response: Response): Status = response.status - val app: HttpApp[Any] = Handler.fromFunction[Request] { _ => Response.text("hello") }.toHttpApp + val app: HttpApp[Any] = { + Route.handled(Method.GET / "hello")(Handler.response(Response.text("hello"))) + }.toHttpApp + + val redir: HttpApp[Any] = { + Route.handled(Method.GET / "redirect")(Handler.response(Response.redirect(URL.empty / "hello"))) + }.toHttpApp override def spec: Spec[TestEnvironment with Scope, Any] = suite("ZClientAspect")( @@ -34,7 +40,7 @@ object ZClientAspectSpec extends ZIOHttpSpec { port <- Server.install(app) baseClient <- ZIO.service[Client] client = baseClient.url( - URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", port)), + URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", Some(port))), ) @@ ZClientAspect.debug response <- client.request(Request.get(URL.empty / "hello")) output <- TestConsole.output @@ -51,7 +57,7 @@ object ZClientAspectSpec extends ZIOHttpSpec { baseClient <- ZIO.service[Client] client = baseClient .url( - URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", port)), + URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", Some(port))), ) .disableStreaming @@ ZClientAspect.requestLogging( loggedRequestHeaders = Set(Header.UserAgent), @@ -78,6 +84,20 @@ object ZClientAspectSpec extends ZIOHttpSpec { annotations.head.contains("duration_ms"), ), ), + test("followRedirects")( + for { + port <- Server.install(redir ++ app) + baseClient <- ZIO.service[Client] + client = baseClient + .url( + URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", Some(port))), + ) + .disableStreaming @@ ZClientAspect.followRedirects(2)((resp, message) => ZIO.logInfo(message).as(resp)) + response <- client.request(Request.get(URL.empty / "redirect")) + } yield assertTrue( + extractStatus(response) == Status.Ok, + ), + ), ).provide( ZLayer.succeed(Server.Config.default.onAnyOpenPort), Server.live, diff --git a/zio-http/src/test/scala/zio/http/codec/PathCodecSpec.scala b/zio-http/src/test/scala/zio/http/codec/PathCodecSpec.scala index 8773bf109a..68216f7d82 100644 --- a/zio-http/src/test/scala/zio/http/codec/PathCodecSpec.scala +++ b/zio-http/src/test/scala/zio/http/codec/PathCodecSpec.scala @@ -41,28 +41,28 @@ object PathCodecSpec extends ZIOHttpSpec { test("/users") { val codec = PathCodec.path("/users") - assertTrue(codec.segments.length == 2) + assertTrue(codec.segments.length == 1) }, test("/users/{user-id}/posts/{post-id}") { val codec = - PathCodec.path("/users") / SegmentCodec.int("user-id") / SegmentCodec.literal("posts") / SegmentCodec + PathCodec.path("/users") / PathCodec.int("user-id") / PathCodec.literal("posts") / PathCodec .string( "post-id", ) - assertTrue(codec.segments.length == 5) + assertTrue(codec.segments.length == 4) }, test("transformed") { val codec = PathCodec.path("/users") / - SegmentCodec.int("user-id").transform(UserId.apply)(_.value) / - SegmentCodec.literal("posts") / - SegmentCodec + PathCodec.int("user-id").transform(UserId.apply)(_.value) / + PathCodec.literal("posts") / + PathCodec .string("post-id") .transformOrFailLeft(s => Try(s.toInt).toEither.left.map(_ => "Not a number").map(n => PostId(n.toString)), )(_.value) - assertTrue(codec.segments.length == 5) + assertTrue(codec.segments.length == 4) }, ), suite("decoding")( @@ -86,14 +86,14 @@ object PathCodecSpec extends ZIOHttpSpec { assertTrue(codec.decode(Path("/users")) == Right(Path("/users"))) }, test("/users") { - val codec = PathCodec.empty / SegmentCodec.literal("users") + val codec = PathCodec.empty / PathCodec.literal("users") assertTrue(codec.decode(Path("/users")) == Right(())) && assertTrue(codec.decode(Path("/users/")) == Right(())) }, test("concat") { - val codec1 = PathCodec.empty / SegmentCodec.literal("users") / SegmentCodec.int("user-id") - val codec2 = PathCodec.empty / SegmentCodec.literal("posts") / SegmentCodec.string("post-id") + val codec1 = PathCodec.empty / PathCodec.literal("users") / PathCodec.int("user-id") + val codec2 = PathCodec.empty / PathCodec.literal("posts") / PathCodec.string("post-id") val codec = codec1 ++ codec2 @@ -102,9 +102,9 @@ object PathCodecSpec extends ZIOHttpSpec { test("transformed") { val codec = PathCodec.path("/users") / - SegmentCodec.int("user-id").transform(UserId.apply)(_.value) / - SegmentCodec.literal("posts") / - SegmentCodec + PathCodec.int("user-id").transform(UserId.apply)(_.value) / + PathCodec.literal("posts") / + PathCodec .string("post-id") .transformOrFailLeft(s => Try(s.toInt).toEither.left.map(_ => "Not a number").map(n => PostId(n.toString)), @@ -122,7 +122,7 @@ object PathCodecSpec extends ZIOHttpSpec { assertTrue(codec.segments == Chunk(SegmentCodec.empty)) }, test("/users") { - val codec = PathCodec.empty / SegmentCodec.literal("users") + val codec = PathCodec.empty / PathCodec.literal("users") assertTrue( codec.segments == @@ -137,24 +137,24 @@ object PathCodecSpec extends ZIOHttpSpec { assertTrue(codec.render == "") }, test("/users") { - val codec = PathCodec.empty / SegmentCodec.literal("users") + val codec = PathCodec.empty / PathCodec.literal("users") assertTrue(codec.render == "/users") }, test("/users/{user-id}/posts/{post-id}") { val codec = - PathCodec.empty / SegmentCodec.literal("users") / SegmentCodec.int("user-id") / SegmentCodec.literal( + PathCodec.empty / PathCodec.literal("users") / PathCodec.int("user-id") / PathCodec.literal( "posts", - ) / SegmentCodec.string("post-id") + ) / PathCodec.string("post-id") assertTrue(codec.render == "/users/{user-id}/posts/{post-id}") }, test("transformed") { val codec = PathCodec.path("/users") / - SegmentCodec.int("user-id").transform(UserId.apply)(_.value) / - SegmentCodec.literal("posts") / - SegmentCodec + PathCodec.int("user-id").transform(UserId.apply)(_.value) / + PathCodec.literal("posts") / + PathCodec .string("post-id") .transformOrFailLeft(s => Try(s.toInt).toEither.left.map(_ => "Not a number").map(n => PostId(n.toString)), diff --git a/zio-http/src/test/scala/zio/http/codec/RichTextCodecSpec.scala b/zio-http/src/test/scala/zio/http/codec/RichTextCodecSpec.scala index 50ed2a74d6..5ac325ae9f 100644 --- a/zio-http/src/test/scala/zio/http/codec/RichTextCodecSpec.scala +++ b/zio-http/src/test/scala/zio/http/codec/RichTextCodecSpec.scala @@ -250,6 +250,10 @@ object RichTextCodecSpec extends ZIOHttpSpec { assertTrue(success(123) == codec.decode("123--")) && assertTrue(codec.decode("4123").isLeft) }, + test("With error message") { + val codec = RichTextCodec.literal("123").withError("Not 123") + assertTrue(codec.decode("678") == Left("(Expected, but did not find: Paragraph(Code(“1”,Inline)), Not 123)")) + }, ), ) } diff --git a/zio-http/src/test/scala/zio/http/endpoint/openapi/JsonRendererSpec.scala b/zio-http/src/test/scala/zio/http/endpoint/openapi/JsonRendererSpec.scala deleted file mode 100644 index c5dc49f1b9..0000000000 --- a/zio-http/src/test/scala/zio/http/endpoint/openapi/JsonRendererSpec.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package zio.http.endpoint.openapi - -import java.net.URI - -import scala.util.Try - -import zio.test._ - -import zio.http.codec.Doc -import zio.http.endpoint.openapi.OpenAPI.Parameter.{Definition, QueryParameter} -import zio.http.endpoint.openapi.OpenAPI.Schema.ResponseSchema -import zio.http.endpoint.openapi.OpenAPI.SecurityScheme.ApiKey -import zio.http.endpoint.openapi.OpenAPI.{Info, Operation, PathItem} -import zio.http.{Status, ZIOHttpSpec} - -object JsonRendererSpec extends ZIOHttpSpec { - case object Html - override def spec = - suite("JsonRenderer")( - test("render numbers") { - val rendered = - JsonRenderer.renderFields("int" -> 1, "double" -> 1.0d, "float" -> 1.0f, "long" -> 1L) - val expected = """{"int":1,"double":1.0,"float":1.0,"long":1}""" - assertTrue(rendered == expected) - }, - test("render strings") { - val rendered = JsonRenderer.renderFields("string" -> "string") - val expected = """{"string":"string"}""" - assertTrue(rendered == expected) - }, - test("render booleans") { - val rendered = JsonRenderer.renderFields("boolean" -> true) - val expected = """{"boolean":true}""" - assertTrue(rendered == expected) - }, - test("render tuples") { - val rendered = JsonRenderer.renderFields(("tuple", (1, "string"))) - val expected = """{"tuple":{"1":"string"}}""" - assertTrue(rendered == expected) - }, - test("render list") { - val rendered = JsonRenderer.renderFields("array" -> List(1, 2, 3)) - val expected = """{"array":[1,2,3]}""" - assertTrue(rendered == expected) - }, - test("render map") { - val rendered = - JsonRenderer.renderFields("map" -> Map("key" -> "value"), "otherMap" -> Map(1 -> "value")) - val expected = """{"map":{"key":"value"},"otherMap":{"1":"value"}}""" - assertTrue(rendered == expected) - }, - test("render In") { - val rendered = JsonRenderer.renderFields("type" -> ApiKey.In.Query) - val expected = """{"type":"query"}""" - assertTrue(rendered == expected) - }, - test("render empty doc") { - val rendered = JsonRenderer.renderFields("doc" -> Doc.empty) - val expected = """{"doc":""}""" - assertTrue(rendered == expected) - }, - test("render doc") { - val rendered = JsonRenderer.renderFields("doc" -> Doc.p(Doc.Span.link(new URI("https://google.com")))) - val expected = """{"doc":"W2h0dHBzOi8vZ29vZ2xlLmNvbV0oaHR0cHM6Ly9nb29nbGUuY29tKQoK"}""" - assertTrue(rendered == expected) - }, - test("render LiteralOrExpression") { - val rendered = JsonRenderer.renderFields( - "string" -> (OpenAPI.LiteralOrExpression.StringLiteral("string"): OpenAPI.LiteralOrExpression), - "number" -> (OpenAPI.LiteralOrExpression.NumberLiteral(1): OpenAPI.LiteralOrExpression), - "decimal" -> (OpenAPI.LiteralOrExpression.DecimalLiteral(1.0): OpenAPI.LiteralOrExpression), - "boolean" -> (OpenAPI.LiteralOrExpression.BooleanLiteral(true): OpenAPI.LiteralOrExpression), - "expression" -> OpenAPI.LiteralOrExpression.expression("expression"), - ) - val expected = """{"string":"string","number":1,"decimal":1.0,"boolean":true,"expression":"expression"}""" - assertTrue(rendered == expected) - }, - test("throw exception for duplicate keys") { - val rendered = Try(JsonRenderer.renderFields("key" -> 1, "key" -> 2)) - assertTrue(rendered.failed.toOption.exists(_.isInstanceOf[IllegalArgumentException])) - }, - test("render OpenAPI") { - val rendered = - OpenAPI - .OpenAPI( - info = Info( - title = "title", - version = "version", - description = Doc.p("description"), - termsOfService = new URI("https://google.com"), - contact = None, - license = None, - ), - servers = List(OpenAPI.Server(new URI("https://google.com"), Doc.p("description"), Map.empty)), - paths = Map( - OpenAPI.Path.fromString("/test").get -> PathItem( - get = Some( - Operation( - responses = Map( - Status.Ok -> OpenAPI.Response( - description = Doc.p(Doc.Span.text("description")), - content = Map( - "application/json" -> OpenAPI.MediaType( - schema = ResponseSchema( - discriminator = None, - xml = None, - externalDocs = new URI("https://google.com"), - example = "Example", - ), - examples = Map.empty, - encoding = Map.empty, - ), - ), - headers = Map.empty, - links = Map.empty, - ), - ), - tags = List("tag"), - summary = "summary", - description = Doc.p("description"), - externalDocs = Some(OpenAPI.ExternalDoc(None, new URI("https://google.com"))), - operationId = Some("operationId"), - parameters = Set( - QueryParameter( - "name", - Doc.p("description"), - definition = Definition.Content("key", "mediaType"), - examples = Map.empty, - ), - ), - servers = List(OpenAPI.Server(new URI("https://google.com"), Doc.p("description"), Map.empty)), - requestBody = None, - callbacks = Map.empty, - security = List.empty, - ), - ), - ref = "ref", - description = Doc.p("description"), - put = None, - post = None, - delete = None, - options = None, - head = None, - patch = None, - trace = None, - servers = List.empty, - parameters = Set.empty, - ), - ), - components = Some( - OpenAPI.Components( - schemas = Map.empty, - responses = Map.empty, - parameters = Map.empty, - examples = Map.empty, - requestBodies = Map.empty, - headers = Map.empty, - securitySchemes = Map.empty, - links = Map.empty, - callbacks = Map.empty, - ), - ), - security = List.empty, - tags = List.empty, - externalDocs = Some(OpenAPI.ExternalDoc(None, new URI("https://google.com"))), - openapi = "3.0.0", - ) - .toJson - - val expected = - """{"openapi":"3.0.0","info":{"title":"title","description":"ZGVzY3JpcHRpb24KCg==","termsOfService":"https://google.com","version":"version"},"servers":[{"url":"https://google.com","description":"ZGVzY3JpcHRpb24KCg==","variables":{}}],"paths":{"/test":{"$ref":"ref","summary":"","description":"ZGVzY3JpcHRpb24KCg==","get":{"tags":["tag"],"summary":"summary","description":"ZGVzY3JpcHRpb24KCg==","externalDocs":{"url":"https://google.com"},"operationId":"operationId","parameters":[{"name":"name","in":"query","description":"ZGVzY3JpcHRpb24KCg==","required":true,"deprecated":false,"allowEmptyValue":false,"definition":{"key":"key","mediaType":"mediaType"},"explode":true,"examples":{}}],"responses":{"200":{"description":"ZGVzY3JpcHRpb24KCg==","headers":{},"content":{"application/json":{"schema":{"nullable":false,"readOnly":true,"writeOnly":false,"externalDocs":"https://google.com","example":"Example","deprecated":false},"examples":{},"encoding":{}}},"links":{}}},"callbacks":{},"deprecated":false,"security":[],"servers":[{"url":"https://google.com","description":"ZGVzY3JpcHRpb24KCg==","variables":{}}]},"servers":[],"parameters":[]}},"components":{"schemas":{},"responses":{},"parameters":{},"examples":{},"requestBodies":{},"headers":{},"securitySchemes":{},"links":{},"callbacks":{}},"security":[],"tags":[],"externalDocs":{"url":"https://google.com"}}""" - assertTrue(rendered == expected) - }, - ) -} diff --git a/zio-http/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala b/zio-http/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala new file mode 100644 index 0000000000..6ed10807e4 --- /dev/null +++ b/zio-http/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala @@ -0,0 +1,2362 @@ +package zio.http.endpoint.openapi + +import zio.json.ast.Json +import zio.json.{EncoderOps, JsonEncoder} +import zio.test._ +import zio.{Scope, ZIO} + +import zio.schema.annotation.{caseName, discriminatorName, noDiscriminator, optionalField, transientField} +import zio.schema.codec.JsonCodec +import zio.schema.{DeriveSchema, Schema} + +import zio.http.Method.{GET, POST} +import zio.http._ +import zio.http.codec.{Doc, HttpCodec, QueryCodec} +import zio.http.endpoint._ + +object OpenAPIGenSpec extends ZIOSpecDefault { + + final case class SimpleInputBody(name: String, age: Int) + implicit val simpleInputBodySchema: Schema[SimpleInputBody] = + DeriveSchema.gen[SimpleInputBody] + final case class OtherSimpleInputBody(fullName: String, shoeSize: Int) + implicit val otherSimpleInputBodySchema: Schema[OtherSimpleInputBody] = + DeriveSchema.gen[OtherSimpleInputBody] + final case class SimpleOutputBody(userName: String, score: Int) + implicit val simpleOutputBodySchema: Schema[SimpleOutputBody] = + DeriveSchema.gen[SimpleOutputBody] + final case class NotFoundError(message: String) + implicit val notFoundErrorSchema: Schema[NotFoundError] = + DeriveSchema.gen[NotFoundError] + final case class ImageMetadata(name: String, size: Int) + implicit val imageMetadataSchema: Schema[ImageMetadata] = + DeriveSchema.gen[ImageMetadata] + + final case class WithTransientField(name: String, @transientField age: Int) + implicit val withTransientFieldSchema: Schema[WithTransientField] = + DeriveSchema.gen[WithTransientField] + + final case class WithDefaultValue(age: Int = 42) + implicit val withDefaultValueSchema: Schema[WithDefaultValue] = + DeriveSchema.gen[WithDefaultValue] + final case class WithComplexDefaultValue(data: ImageMetadata = ImageMetadata("default", 42)) + implicit val withDefaultComplexValueSchema: Schema[WithComplexDefaultValue] = + DeriveSchema.gen[WithComplexDefaultValue] + + final case class WithOptionalField(name: String, @optionalField age: Int) + implicit val withOptionalFieldSchema: Schema[WithOptionalField] = + DeriveSchema.gen[WithOptionalField] + + final case class NestedProduct(imageMetadata: ImageMetadata, withOptionalField: WithOptionalField) + implicit val nestedProductSchema: Schema[NestedProduct] = + DeriveSchema.gen[NestedProduct] + + sealed trait SimpleEnum + object SimpleEnum { + implicit val schema: Schema[SimpleEnum] = DeriveSchema.gen[SimpleEnum] + case object One extends SimpleEnum + case object Two extends SimpleEnum + case object Three extends SimpleEnum + } + + sealed trait SealedTraitDefaultDiscriminator + + object SealedTraitDefaultDiscriminator { + implicit val schema: Schema[SealedTraitDefaultDiscriminator] = + DeriveSchema.gen[SealedTraitDefaultDiscriminator] + + case object One extends SealedTraitDefaultDiscriminator + + case class Two(name: String) extends SealedTraitDefaultDiscriminator + + @caseName("three") + case class Three(name: String) extends SealedTraitDefaultDiscriminator + } + + @discriminatorName("type") + sealed trait SealedTraitCustomDiscriminator + + object SealedTraitCustomDiscriminator { + implicit val schema: Schema[SealedTraitCustomDiscriminator] = DeriveSchema.gen[SealedTraitCustomDiscriminator] + + case object One extends SealedTraitCustomDiscriminator + + case class Two(name: String) extends SealedTraitCustomDiscriminator + + @caseName("three") + case class Three(name: String) extends SealedTraitCustomDiscriminator + } + + @noDiscriminator + sealed trait SealedTraitNoDiscriminator + + object SealedTraitNoDiscriminator { + implicit val schema: Schema[SealedTraitNoDiscriminator] = DeriveSchema.gen[SealedTraitNoDiscriminator] + + case object One extends SealedTraitNoDiscriminator + + case class Two(name: String) extends SealedTraitNoDiscriminator + + @caseName("three") + case class Three(name: String) extends SealedTraitNoDiscriminator + } + + @noDiscriminator + sealed trait SimpleNestedSealedTrait + + object SimpleNestedSealedTrait { + implicit val schema: Schema[SimpleNestedSealedTrait] = DeriveSchema.gen[SimpleNestedSealedTrait] + + case object NestedOne extends SimpleNestedSealedTrait + + case class NestedTwo(name: SealedTraitNoDiscriminator) extends SimpleNestedSealedTrait + + case class NestedThree(name: String) extends SimpleNestedSealedTrait + } + + private val simpleEndpoint = + Endpoint( + (GET / "static" / int("id") / uuid("uuid") ?? Doc.p("user id") / string("name")) ?? Doc.p("get path"), + ) + .in[SimpleInputBody](Doc.p("input body")) + .out[SimpleOutputBody](Doc.p("output body")) + .outError[NotFoundError](Status.NotFound, Doc.p("not found")) + + private val queryParamEndpoint = + Endpoint(GET / "withQuery") + .in[SimpleInputBody] + .query(QueryCodec.paramStr("query")) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + + private val alternativeInputEndpoint = + Endpoint(GET / "inputAlternative") + .inCodec( + (HttpCodec.content[OtherSimpleInputBody] ?? Doc.p("other input") | HttpCodec + .content[SimpleInputBody] ?? Doc.p("simple input")) ?? Doc.p("takes either of the two input bodies"), + ) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + + def toJsonAst(str: String): Json = + Json.decoder.decodeJson(str).toOption.get + + def toJsonAst(api: OpenAPI): Json = + toJsonAst(api.toJson) + + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("OpenAPIGenSpec")( + test("simple endpoint to OpenAPI") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", simpleEndpoint) + val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static/{id}/{uuid}/{name}" : { + | "get" : { + | "parameters" : [ + | + | { + | "name" : "id", + | "in" : "path", + | "required" : true, + | "deprecated" : false, + | "schema" : + | { + | "type" : + | "integer", + | "format" : "int32" + | }, + | "explode" : false, + | "style" : "simple" + | }, + | + | { + | "name" : "uuid", + | "in" : "path", + | "description" : "user id\n\n", + | "required" : true, + | "deprecated" : false, + | "schema" : + | { + | "type" : "string", + | "format" : "uuid" + | }, + | "explode" : false, + | "style" : "simple" + | }, + | + | { + | "name" : "name", + | "in" : "path", + | "required" : true, + | "deprecated" : false, + | "schema" : + | { + | "type" : + | "string" + | }, + | "explode" : false, + | "style" : "simple" + | } + | ], + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref": "#/components/schemas/SimpleInputBody", + | "description" : "input body\n\n" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref": "#/components/schemas/SimpleOutputBody", + | "description" : "output body\n\n" + | } + | } + | } + | }, + | "404" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref": "#/components/schemas/NotFoundError", + | "description" : "not found\n\n" + | } + | } + | } + | } + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "SimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "NotFoundError" : + | { + | "type" : + | "object", + | "properties" : { + | "message" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "message" + | ] + | }, + | "SimpleOutputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "userName" : { + | "type" : + | "string" + | }, + | "score" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "userName", + | "score" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("with query parameter") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", queryParamEndpoint) + val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/withQuery" : { + | "get" : { + | "parameters" : [ + | { + | "name" : "query", + | "in" : "query", + | "required" : true, + | "deprecated" : false, + | "schema" : + | { + | "type" : + | "string" + | }, + | "explode" : false, + | "allowReserved" : false, + | "style" : "form" + | } + | ], + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : {"$ref": "#/components/schemas/SimpleInputBody"} + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : {"$ref": "#/components/schemas/SimpleOutputBody"} + | } + | } + | }, + | "404" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : {"$ref": "#/components/schemas/NotFoundError"} + | } + | } + | } + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "SimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "NotFoundError" : + | { + | "type" : + | "object", + | "properties" : { + | "message" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "message" + | ] + | }, + | "SimpleOutputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "userName" : { + | "type" : + | "string" + | }, + | "score" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "userName", + | "score" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("alternative input") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", alternativeInputEndpoint) + val json = toJsonAst(generated) + val expectedJson = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/inputAlternative" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : { + | "anyOf" : [ + | { + | "$ref": "#/components/schemas/OtherSimpleInputBody", + | "description" : "other input\n\n" + | }, + | { + | "$ref": "#/components/schemas/SimpleInputBody", + | "description" : "simple input\n\n" + | } + | ], + | "description" : "takes either of the two input bodies\n\n" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : {"$ref": "#/components/schemas/SimpleOutputBody"} + | } + | } + | }, + | "404" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : {"$ref": "#/components/schemas/NotFoundError"} + | } + | } + | } + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "OtherSimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "fullName" : { + | "type" : + | "string" + | }, + | "shoeSize" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "fullName", + | "shoeSize" + | ] + | }, + | "SimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "NotFoundError" : + | { + | "type" : + | "object", + | "properties" : { + | "message" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "message" + | ] + | }, + | "SimpleOutputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "userName" : { + | "type" : + | "string" + | }, + | "score" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "userName", + | "score" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("alternative output") { + val endpoint = + Endpoint(GET / "static") + .in[SimpleInputBody] + .outCodec( + (HttpCodec.content[SimpleOutputBody] ?? Doc.p("simple output") | HttpCodec + .content[NotFoundError] ?? Doc.p("not found")) ?? Doc.p("alternative outputs"), + ) + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expectedJson = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : {"$ref": "#/components/schemas/SimpleInputBody"} + | } + | }, + | "required" : true + | }, + | "responses" : { + | "default" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : { "anyOf" : [ + | { + | "$ref": "#/components/schemas/SimpleOutputBody", + | "description" : "simple output\n\n" + | }, + | { + | "$ref": "#/components/schemas/NotFoundError", + | "description" : "not found\n\n" + | } + | ], + | "description" : "alternative outputs\n\n" + | } + | } + | } + | } + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "SimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "userName" : { + | "type" : + | "string" + | }, + | "score" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "userName", + | "score" + | ] + | }, + | "NotFoundError" : + | { + | "type" : + | "object", + | "properties" : { + | "message" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "message" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("with examples") { + val endpoint = + Endpoint(GET / "static") + .inCodec( + HttpCodec + .content[SimpleInputBody] + .examples("john" -> SimpleInputBody("John", 42), "jane" -> SimpleInputBody("Jane", 43)), + ) + .outCodec( + HttpCodec + .content[SimpleOutputBody] + .examples("john" -> SimpleOutputBody("John", 42), "jane" -> SimpleOutputBody("Jane", 43)) | + HttpCodec + .content[NotFoundError] + .examples("not found" -> NotFoundError("not found")), + ) + + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expectedJson = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleInputBody" + | }, + | "examples" : { + | "john" : + | { + | "value" : { + | "name" : "John", + | "age" : 42 + | } + | }, + | "jane" : + | { + | "value" : { + | "name" : "Jane", + | "age" : 43 + | } + | } + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "default" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : + | { + | "anyOf" : [ + | { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | }, + | { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | ], + | "description" : "" + | }, + | "examples" : { + | "john" : + | { + | "value" : { + | "userName" : "John", + | "score" : 42 + | } + | }, + | "jane" : + | { + | "value" : { + | "userName" : "Jane", + | "score" : 43 + | } + | }, + | "not found" : + | { + | "value" : { + | "message" : "not found" + | } + | } + | } + | } + | } + | } + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "SimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "userName" : { + | "type" : + | "string" + | }, + | "score" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "userName", + | "score" + | ] + | }, + | "NotFoundError" : + | { + | "type" : + | "object", + | "properties" : { + | "message" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "message" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("with query parameter, alternative input, alternative output and examples") { + val endpoint = + Endpoint(GET / "static") + .inCodec( + HttpCodec + .content[OtherSimpleInputBody] ?? Doc.p("other input") | + HttpCodec + .content[SimpleInputBody] ?? Doc.p("simple input"), + ) + .query(QueryCodec.paramStr("query")) + .outCodec( + HttpCodec + .content[SimpleOutputBody] ?? Doc.p("simple output") | + HttpCodec + .content[NotFoundError] ?? Doc.p("not found"), + ) + + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expectedJson = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "parameters" : [ + | + | { + | "name" : "query", + | "in" : "query", + | "required" : true, + | "deprecated" : false, + | "schema" : + | { + | "type" : + | "string" + | }, + | "explode" : false, + | "allowReserved" : false, + | "style" : "form" + | } + | ], + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "anyOf" : [ + | { + | "$ref" : "#/components/schemas/OtherSimpleInputBody", + | "description" : "other input\n\n" + | }, + | { + | "$ref" : "#/components/schemas/SimpleInputBody", + | "description" : "simple input\n\n" + | } + | ], + | "description" : "" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "default" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : + | { + | "anyOf" : [ + | { + | "$ref" : "#/components/schemas/SimpleOutputBody", + | "description" : "simple output\n\n" + | }, + | { + | "$ref" : "#/components/schemas/NotFoundError", + | "description" : "not found\n\n" + | } + | ], + | "description" : "" + | } + | } + | } + | } + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "OtherSimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "fullName" : { + | "type" : + | "string" + | }, + | "shoeSize" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "fullName", + | "shoeSize" + | ] + | }, + | "SimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "userName" : { + | "type" : + | "string" + | }, + | "score" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "userName", + | "score" + | ] + | }, + | "NotFoundError" : + | { + | "type" : + | "object", + | "properties" : { + | "message" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "message" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("multipart") { + val endpoint = Endpoint(GET / "test-form") + .outCodec( + (HttpCodec.contentStream[Byte]("image", MediaType.image.png) ++ + HttpCodec.content[String]("title").optional) ?? Doc.p("Test doc") ++ + HttpCodec.content[Int]("width") ++ + HttpCodec.content[Int]("height") ++ + HttpCodec.content[ImageMetadata]("metadata"), + ) + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expected = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/test-form" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "type" : + | "null" + | } + | } + | }, + | "required" : false + | }, + | "responses" : { + | "default" : + | { + | "description" : "", + | "content" : { + | "multipart/form-data" : { + | "schema" : + | { + | "type" : + | "object", + | "properties" : { + | "image" : { + | "type" : + | "string", + | "contentEncoding" : "binary", + | "contentMediaType" : "application/octet-stream" + | }, + | "height" : { + | "type" : + | "integer", + | "format" : "int32" + | }, + | "metadata" : { + | "$ref" : "#/components/schemas/ImageMetadata" + | }, + | "title" : { + | "type" : + | [ + | "string", + | "null" + | ] + | }, + | "width" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | false, + | "required" : [ + | "image", + | "width", + | "height", + | "metadata" + | ], + | "description" : "Test doc\n\n" + | } + | } + | } + | } + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "ImageMetadata" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "size" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "size" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expected)) + }, + test("multiple endpoint definitions") { + val generated = + OpenAPIGen.fromEndpoints( + "Simple Endpoint", + "1.0", + simpleEndpoint, + queryParamEndpoint, + alternativeInputEndpoint, + ) + val json = toJsonAst(generated) + val expected = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static/{id}/{uuid}/{name}" : { + | "get" : { + | "parameters" : [ + | + | { + | "name" : "id", + | "in" : "path", + | "required" : true, + | "deprecated" : false, + | "schema" : + | { + | "type" : + | "integer", + | "format" : "int32" + | }, + | "explode" : false, + | "style" : "simple" + | }, + | + | { + | "name" : "uuid", + | "in" : "path", + | "description" : "user id\n\n", + | "required" : true, + | "deprecated" : false, + | "schema" : + | { + | "type" : "string", + | "format" : "uuid" + | }, + | "explode" : false, + | "style" : "simple" + | }, + | + | { + | "name" : "name", + | "in" : "path", + | "required" : true, + | "deprecated" : false, + | "schema" : + | { + | "type" : + | "string" + | }, + | "explode" : false, + | "style" : "simple" + | } + | ], + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleInputBody", + | "description" : "input body\n\n" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleOutputBody", + | "description" : "output body\n\n" + | } + | } + | } + | }, + | "404" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/NotFoundError", + | "description" : "not found\n\n" + | } + | } + | } + | } + | }, + | "deprecated" : false + | } + | }, + | "/withQuery" : { + | "get" : { + | "parameters" : [ + | + | { + | "name" : "query", + | "in" : "query", + | "required" : true, + | "deprecated" : false, + | "schema" : + | { + | "type" : + | "string" + | }, + | "explode" : false, + | "allowReserved" : false, + | "style" : "form" + | } + | ], + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleInputBody" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | }, + | "deprecated" : false + | } + | }, + | "/inputAlternative" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "anyOf" : [ + | { + | "$ref" : "#/components/schemas/OtherSimpleInputBody", + | "description" : "other input\n\n" + | }, + | { + | "$ref" : "#/components/schemas/SimpleInputBody", + | "description" : "simple input\n\n" + | } + | ], + | "description" : "takes either of the two input bodies\n\n" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : + | { + | "description" : "", + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "SimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "NotFoundError" : + | { + | "type" : + | "object", + | "properties" : { + | "message" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "message" + | ] + | }, + | "SimpleOutputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "userName" : { + | "type" : + | "string" + | }, + | "score" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "userName", + | "score" + | ] + | }, + | "OtherSimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "fullName" : { + | "type" : + | "string" + | }, + | "shoeSize" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "fullName", + | "shoeSize" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expected)) + }, + test("transient field") { + val endpoint = Endpoint(GET / "static").in[WithTransientField] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expected = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/WithTransientField" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "WithTransientField" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expected)) + }, + test("primitive default value") { + val endpoint = Endpoint(GET / "static").in[WithDefaultValue] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expected = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/WithDefaultValue" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "WithDefaultValue" : + | { + | "type" : + | "object", + | "properties" : { + | "age" : { + | "type" : + | "integer", + | "format" : "int32", + | "description" : "If not set, this field defaults to the value of the default annotation.", + | "default" : 42 + | } + | }, + | "additionalProperties" : + | true + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expected)) + }, + test("complex default value") { + val endpoint = Endpoint(GET / "static").in[WithComplexDefaultValue] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expected = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/WithComplexDefaultValue" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "ImageMetadata" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "size" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "size" + | ] + | }, + | "WithComplexDefaultValue" : + | { + | "type" : + | "object", + | "properties" : { + | "data" : { + | "$ref" : "#/components/schemas/ImageMetadata", + | "description" : "If not set, this field defaults to the value of the default annotation.", + | "default" : { + | "name" : "default", + | "size" : 42 + | } + | } + | }, + | "additionalProperties" : + | true + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expected)) + }, + test("optional field") { + val endpoint = Endpoint(GET / "static").in[WithOptionalField] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expected = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/WithOptionalField" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "WithOptionalField" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expected)) + }, + test("nested product") { + val endpoint = Endpoint(GET / "static").in[NestedProduct] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expected = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/NestedProduct" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "ImageMetadata" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "size" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name", + | "size" + | ] + | }, + | "WithOptionalField" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "NestedProduct" : + | { + | "type" : + | "object", + | "properties" : { + | "imageMetadata" : { + | "$ref" : "#/components/schemas/ImageMetadata" + | }, + | "withOptionalField" : { + | "$ref" : "#/components/schemas/WithOptionalField" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "imageMetadata", + | "withOptionalField" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expected)) + }, + test("enum") { + val endpoint = Endpoint(GET / "static").in[SimpleEnum] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expected = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleEnum" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "SimpleEnum" : + | { + | "type" : + | "string", + | "enumValues" : [ + | "One", + | "Two", + | "Three" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expected)) + }, + test("sealed trait default discriminator") { + val endpoint = Endpoint(GET / "static").in[SealedTraitDefaultDiscriminator] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expectedJson = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SealedTraitDefaultDiscriminator" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "One" : + | { + | "type" : + | "object", + | "properties" : {}, + | "additionalProperties" : + | true + | }, + | "Two" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "Three" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "SealedTraitDefaultDiscriminator" : + | { + | "oneOf" : [ + | { + | "type" : + | "object", + | "properties" : { + | "One" : { + | "$ref" : "#/components/schemas/One" + | } + | }, + | "additionalProperties" : + | false, + | "required" : [ + | "One" + | ] + | }, + | { + | "type" : + | "object", + | "properties" : { + | "Two" : { + | "$ref" : "#/components/schemas/Two" + | } + | }, + | "additionalProperties" : + | false, + | "required" : [ + | "Two" + | ] + | }, + | { + | "type" : + | "object", + | "properties" : { + | "three" : { + | "$ref" : "#/components/schemas/Three" + | } + | }, + | "additionalProperties" : + | false, + | "required" : [ + | "three" + | ] + | } + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("sealed trait custom discriminator") { + val endpoint = Endpoint(GET / "static").in[SealedTraitCustomDiscriminator] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expectedJson = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SealedTraitCustomDiscriminator" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "One" : + | { + | "type" : + | "object", + | "properties" : {}, + | "additionalProperties" : + | true + | }, + | "Two" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "Three" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "SealedTraitCustomDiscriminator" : + | { + | "oneOf" : [ + | { + | "$ref" : "#/components/schemas/One" + | }, + | { + | "$ref" : "#/components/schemas/Two" + | }, + | { + | "$ref" : "#/components/schemas/Three" + | } + | ], + | "discriminator" : { + | "propertyName" : "type", + | "mapping" : { + | "One" : "#/components/schemas/One}", + | "Two" : "#/components/schemas/Two}", + | "three" : "#/components/schemas/Three}" + | } + | } + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("sealed trait no discriminator") { + val endpoint = Endpoint(GET / "static").in[SealedTraitNoDiscriminator] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expected = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SealedTraitNoDiscriminator" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "One" : + | { + | "type" : + | "object", + | "properties" : {}, + | "additionalProperties" : + | true + | }, + | "Two" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "Three" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "SealedTraitNoDiscriminator" : + | { + | "oneOf" : [ + | { + | "$ref" : "#/components/schemas/One" + | }, + | { + | "$ref" : "#/components/schemas/Two" + | }, + | { + | "$ref" : "#/components/schemas/Three" + | } + | ] + | } + | } + | } + |} + |""".stripMargin + assertTrue(json == toJsonAst(expected)) + }, + test("sealed trait with nested sealed trait") { + val endpoint = Endpoint(GET / "static").in[SimpleNestedSealedTrait] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expectedJson = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleNestedSealedTrait" + | } + | } + | }, + | "required" : true + | }, + | "deprecated" : false + | } + | } + | }, + | "components" : { + | "schemas" : { + | "SealedTraitNoDiscriminator" : + | { + | "oneOf" : [ + | { + | "$ref" : "#/components/schemas/One" + | }, + | { + | "$ref" : "#/components/schemas/Two" + | }, + | { + | "$ref" : "#/components/schemas/Three" + | } + | ] + | }, + | "NestedOne" : + | { + | "type" : + | "object", + | "properties" : {}, + | "additionalProperties" : + | true + | }, + | "NestedThree" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "NestedTwo" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "$ref" : "#/components/schemas/SealedTraitNoDiscriminator" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "Two" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "Three" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "additionalProperties" : + | true, + | "required" : [ + | "name" + | ] + | }, + | "One" : + | { + | "type" : + | "object", + | "properties" : {}, + | "additionalProperties" : + | true + | }, + | "SimpleNestedSealedTrait" : + | { + | "oneOf" : [ + | { + | "$ref" : "#/components/schemas/NestedOne" + | }, + | { + | "$ref" : "#/components/schemas/NestedTwo" + | }, + | { + | "$ref" : "#/components/schemas/NestedThree" + | } + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("multiple methods on same path") { + val getEndpoint = Endpoint(GET / "test") + .out[String](MediaType.text.`plain`) + val postEndpoint = Endpoint(POST / "test") + .in[String] + .out[String](Status.Created, MediaType.text.`plain`) + val generated = OpenAPIGen.fromEndpoints( + "Multiple Methods on Same Path", + "1.0", + getEndpoint, + postEndpoint, + ) + val json = toJsonAst(generated) + for { + expectedJson <- ZIO.acquireReleaseWith( + ZIO.attemptBlockingIO(scala.io.Source.fromResource("endpoint/openapi/multiple-methods-on-same-path.json")), + )(buf => ZIO.attemptBlockingIO(buf.close()).orDie)(buf => ZIO.attemptBlockingIO(buf.mkString)) + } yield assertTrue(json == toJsonAst(expectedJson)) + }, + ) + +} diff --git a/zio-http/src/test/scala/zio/http/endpoint/openapi/SwaggerUISpec.scala b/zio-http/src/test/scala/zio/http/endpoint/openapi/SwaggerUISpec.scala new file mode 100644 index 0000000000..a7bfbb7f61 --- /dev/null +++ b/zio-http/src/test/scala/zio/http/endpoint/openapi/SwaggerUISpec.scala @@ -0,0 +1,67 @@ +package zio.http.endpoint.openapi + +import zio._ +import zio.test._ + +import zio.http._ +import zio.http.codec.HttpCodec.query +import zio.http.codec.PathCodec.path +import zio.http.endpoint.Endpoint + +object SwaggerUISpec extends ZIOSpecDefault { + + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("SwaggerUI")( + test("should return the swagger ui page") { + val getUser = Endpoint(Method.GET / "users" / int("userId")).out[Int] + + val getUserRoute = getUser.implement { Handler.fromFunction[Int] { id => id } } + + val getUserPosts = + Endpoint(Method.GET / "users" / int("userId") / "posts" / int("postId")) + .query(query("name")) + .out[List[String]] + + val getUserPostsRoute = + getUserPosts.implement[Any] { + Handler.fromFunctionZIO[(Int, Int, String)] { case (id1: Int, id2: Int, query: String) => + ZIO.succeed(List(s"API2 RESULT parsed: users/$id1/posts/$id2?name=$query")) + } + } + + val openAPIv1 = OpenAPIGen.fromEndpoints(title = "Endpoint Example", version = "1.0", getUser, getUserPosts) + val openAPIv2 = + OpenAPIGen.fromEndpoints(title = "Another Endpoint Example", version = "2.0", getUser, getUserPosts) + + val routes = + Routes(getUserRoute, getUserPostsRoute) ++ SwaggerUI.routes("docs" / "openapi", openAPIv1, openAPIv2) + + val response = routes.apply(Request(method = Method.GET, url = url"/docs/openapi")) + + val expectedHtml = + """SwaggerUI
""".stripMargin + + for { + res <- response + body <- res.body.asString + } yield { + assertTrue(body == expectedHtml) + } + }, + ) +} diff --git a/zio-http/src/test/scala/zio/http/headers/OriginSpec.scala b/zio-http/src/test/scala/zio/http/headers/OriginSpec.scala index 6727dd6b2c..935aff670b 100644 --- a/zio-http/src/test/scala/zio/http/headers/OriginSpec.scala +++ b/zio-http/src/test/scala/zio/http/headers/OriginSpec.scala @@ -43,6 +43,7 @@ object OriginSpec extends ZIOHttpSpec { assertTrue( Origin.parse("http://domain") == Right(Value("http", "domain", None)), Origin.parse("https://domain") == Right(Value("https", "domain", None)), + Origin.parse("chrome-extension://appid") == Right(Value("chrome-extension", "appid", None)), ) }, test("parsing of valid Origin values") { diff --git a/zio-http/src/test/scala/zio/http/headers/WarningSpec.scala b/zio-http/src/test/scala/zio/http/headers/WarningSpec.scala index a74850dd7a..ba0caec502 100644 --- a/zio-http/src/test/scala/zio/http/headers/WarningSpec.scala +++ b/zio-http/src/test/scala/zio/http/headers/WarningSpec.scala @@ -58,11 +58,11 @@ object WarningSpec extends ZIOHttpSpec { }, test("Accepts Valid Warning with Date") { assertTrue( - Warning.parse(validWarningWithDate) == Right(Warning(112, "-", "\"cache down\"", Some(stubDate))), + Warning.parse(validWarningWithDate) == Right(Warning(112, "-", "cache down", Some(stubDate))), ) }, test("Accepts Valid Warning without Date") { - assertTrue(Warning.parse(validWarning) == Right(Warning(110, "anderson/1.3.37", "\"Response is stale\""))) + assertTrue(Warning.parse(validWarning) == Right(Warning(110, "anderson/1.3.37", "Response is stale"))) }, test("parsing and encoding is symmetrical for warning with Date") { val encodedWarningwithDate = Warning.render(Warning.parse(validWarningWithDate).toOption.get) diff --git a/zio-http/src/test/scala/zio/http/internal/HttpGen.scala b/zio-http/src/test/scala/zio/http/internal/HttpGen.scala index 44079d075b..1ba2149702 100644 --- a/zio-http/src/test/scala/zio/http/internal/HttpGen.scala +++ b/zio-http/src/test/scala/zio/http/internal/HttpGen.scala @@ -70,7 +70,7 @@ object HttpGen { scheme <- Gen.fromIterable(List(Scheme.HTTP, Scheme.HTTPS)) host <- Gen.alphaNumericStringBounded(1, 5) port <- Gen.oneOf(Gen.const(80), Gen.const(443), Gen.int(0, 65536)) - } yield URL.Location.Absolute(scheme, host, port) + } yield URL.Location.Absolute(scheme, host, Some(port)) def genRelativeURL: Gen[Any, URL] = for { path <- HttpGen.anyPath diff --git a/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala b/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala index 817bc8e395..dc80560107 100644 --- a/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala +++ b/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala @@ -49,7 +49,7 @@ abstract class HttpRunnableSpec extends ZIOHttpSpec { self => client( params .addHeader(DynamicServer.APP_ID, id) - .copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", port))), + .copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", Some(port)))), ) .flatMap(_.collect) } @@ -80,7 +80,7 @@ abstract class HttpRunnableSpec extends ZIOHttpSpec { self => client( params .addHeader(DynamicServer.APP_ID, id) - .copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", port))), + .copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", Some(port)))), ) } } yield response diff --git a/zio-http/src/test/scala/zio/http/template/HtmlSpec.scala b/zio-http/src/test/scala/zio/http/template/HtmlSpec.scala index 26d49da018..b6338a3721 100644 --- a/zio-http/src/test/scala/zio/http/template/HtmlSpec.scala +++ b/zio-http/src/test/scala/zio/http/template/HtmlSpec.scala @@ -75,6 +75,11 @@ case object HtmlSpec extends ZIOHttpSpec { assert(none.encode)(equalTo("")) }, ), + test("explicitly constructed script tag is not escaped") { + val view = script("alert('Hello!');") + val expected = """""" + assert(view.encode)(equalTo(expected.stripMargin)) + }, ) } }