From 97dc4ae9117f5b801bf88188810a410e5d401eff Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Sat, 22 Jun 2024 18:10:47 +0200 Subject: [PATCH] Select codec based on response status for endpoint client (#2727) (#2929) --- .../zio/http/endpoint/RoundtripSpec.scala | 103 +----------------- .../main/scala/zio/http/codec/HttpCodec.scala | 26 ++++- .../zio/http/endpoint/EndpointExecutor.scala | 2 +- .../endpoint/internal/EndpointClient.scala | 26 +---- 4 files changed, 34 insertions(+), 123 deletions(-) diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala index e8429f6a7f..75339f790a 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala @@ -346,91 +346,6 @@ object RoundtripSpec extends ZIOHttpSpec { "42", ) }, - test("middleware error returned") { - - val alwaysFailingMiddleware = EndpointMiddleware( - authorization, - HttpCodec.empty, - HttpCodec.error[String](Status.Custom(900)), - ) - - val endpoint = - Endpoint(GET / "users" / int("userId")).out[Int] @@ alwaysFailingMiddleware - - val endpointRoute = - endpoint.implementHandler(Handler.identity) - - val routes = endpointRoute.toRoutes - - val app = routes @@ alwaysFailingMiddleware - .implement[Any, Unit](_ => ZIO.fail("FAIL"))(_ => ZIO.unit) - - for { - port <- Server.install(app) - executorLayer = ZLayer(ZIO.serviceWith[Client](makeExecutor(_, port, Authorization.Basic("user", "pass")))) - - out <- ZIO - .serviceWithZIO[EndpointExecutor[alwaysFailingMiddleware.In]] { executor => - executor.apply(endpoint.apply(42)) - } - .provideSome[Client & Scope](executorLayer) - .flip - } yield assert(out)(equalTo("FAIL")) - }, - test("failed middleware deserialization") { - val alwaysFailingMiddleware = EndpointMiddleware( - authorization, - HttpCodec.empty, - HttpCodec.error[String](Status.Custom(900)), - ) - - val endpoint = - Endpoint(GET / "users" / int("userId")).out[Int] @@ alwaysFailingMiddleware - - val alwaysFailingMiddlewareWithAnotherSignature = EndpointMiddleware( - authorization, - HttpCodec.empty, - HttpCodec.error[Long](Status.Custom(900)), - ) - - val endpointWithAnotherSignature = - Endpoint(GET / "users" / int("userId")).out[Int] @@ alwaysFailingMiddlewareWithAnotherSignature - - val endpointRoute = - endpoint.implementHandler(Handler.identity) - - val routes = endpointRoute.toRoutes - - val app = routes @@ alwaysFailingMiddleware.implement[Any, Unit](_ => ZIO.fail("FAIL"))(_ => ZIO.unit) - - for { - port <- Server.install(app) - executorLayer = ZLayer(ZIO.serviceWith[Client](makeExecutor(_, port, Authorization.Basic("user", "pass")))) - - cause <- ZIO - .serviceWithZIO[EndpointExecutor[alwaysFailingMiddleware.In]] { executor => - executor.apply(endpointWithAnotherSignature.apply(42)) - } - .provideSome[Client with Scope](executorLayer) - .cause - } yield assert(cause.prettyPrint)( - containsString( - "java.lang.IllegalStateException: Cannot deserialize using endpoint error codec", - ), - ) && assert(cause.prettyPrint)( - containsString( - "java.lang.IllegalStateException: Cannot deserialize using middleware error codec", - ), - ) && assert(cause.prettyPrint)( - containsString( - "Suppressed: java.lang.IllegalStateException: Trying to decode with Undefined codec.", - ), - ) && assert(cause.prettyPrint)( - containsString( - "Suppressed: zio.http.codec.HttpCodecError$MalformedBody: Malformed request body failed to decode: (expected a number, got F)", - ), - ) - }, test("Failed endpoint deserialization") { val endpoint = Endpoint(GET / "users" / int("userId")).out[Int].outError[Int](Status.Custom(999)) @@ -457,21 +372,9 @@ object RoundtripSpec extends ZIOHttpSpec { } .provideSome[Client with Scope](executorLayer) .cause - } yield assert(cause.prettyPrint)( - containsString( - "java.lang.IllegalStateException: Cannot deserialize using endpoint error codec", - ), - ) && assert(cause.prettyPrint)( - containsString( - "java.lang.IllegalStateException: Cannot deserialize using middleware error codec", - ), - ) && assert(cause.prettyPrint)( - containsString( - "Suppressed: java.lang.IllegalStateException: Trying to decode with Undefined codec.", - ), - ) && assert(cause.prettyPrint)( - containsString( - """Suppressed: zio.http.codec.HttpCodecError$MalformedBody: Malformed request body failed to decode: (expected '"' got '4')""", + } yield assertTrue( + cause.prettyPrint.contains( + """zio.http.codec.HttpCodecError$MalformedBody: Malformed request body failed to decode: (expected '"' got '4')""", ), ) }, diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala index 7669c069dd..f595e04d79 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala @@ -29,7 +29,7 @@ import zio.schema.Schema import zio.http.Header.Accept.MediaTypeWithQFactor import zio.http._ import zio.http.codec.HttpCodec.{Annotated, Metadata} -import zio.http.codec.internal.EncoderDecoder +import zio.http.codec.internal.{AtomizedCodecs, EncoderDecoder} /** * A [[zio.http.codec.HttpCodec]] represents a codec for a part of an HTTP @@ -48,6 +48,27 @@ sealed trait HttpCodec[-AtomTypes, Value] { private lazy val encoderDecoder: EncoderDecoder[AtomTypes, Value] = EncoderDecoder(self) + private def statusCodecs: Chunk[SimpleCodec[Status, _]] = + self.asInstanceOf[HttpCodec[_, _]] match { + case HttpCodec.Fallback(left, right, _, _) => left.statusCodecs ++ right.statusCodecs + case HttpCodec.Combine(left, right, _) => left.statusCodecs ++ right.statusCodecs + case HttpCodec.Annotated(codec, _) => codec.statusCodecs + case HttpCodec.TransformOrFail(codec, _, _) => codec.statusCodecs + case HttpCodec.Empty => Chunk.empty + case HttpCodec.Halt => Chunk.empty + case atom: HttpCodec.Atom[_, _] => + atom match { + case HttpCodec.Status(codec, _) => Chunk.single(codec) + case _ => Chunk.empty + } + } + + private lazy val statusCodes: Set[Status] = statusCodecs.collect { case SimpleCodec.Specified(status) => + status + }.toSet + + private lazy val matchesAnyStatus: Boolean = statusCodecs.contains(SimpleCodec.Unspecified[Status]()) + /** * Returns a new codec that is the same as this one, but has attached docs, * which will render whenever docs are generated from the codec. @@ -238,6 +259,9 @@ sealed trait HttpCodec[-AtomTypes, Value] { else Left(s"Expected ${expected} but found ${actual}"), )(_ => expected) + private[http] def matchesStatus(status: Status) = + matchesAnyStatus || statusCodes.contains(status) + def named(name: String): HttpCodec[AtomTypes, Value] = HttpCodec.Annotated(self, Metadata.Named(name)) diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/EndpointExecutor.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/EndpointExecutor.scala index 4b2f03b2ac..6e13648951 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/EndpointExecutor.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/EndpointExecutor.scala @@ -63,7 +63,7 @@ final case class EndpointExecutor[+MI]( alt: Alternator[E, invocation.middleware.Err], ev: MI <:< invocation.middleware.In, trace: Trace, - ): ZIO[Scope, alt.Out, B] = { + ): ZIO[Scope, E, B] = { middlewareInput.flatMap { mi => getClient(invocation.endpoint).orDie.flatMap { endpointClient => endpointClient.execute(client, invocation)(ev(mi)) diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/internal/EndpointClient.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/internal/EndpointClient.scala index ffaf016f07..9f30b2b7fb 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/internal/EndpointClient.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/internal/EndpointClient.scala @@ -29,7 +29,7 @@ private[endpoint] final case class EndpointClient[P, I, E, O, M <: EndpointMiddl ) { def execute(client: Client, invocation: Invocation[P, I, E, O, M])( mi: invocation.middleware.In, - )(implicit alt: Alternator[E, invocation.middleware.Err], trace: Trace): ZIO[Scope, alt.Out, O] = { + )(implicit alt: Alternator[E, invocation.middleware.Err], trace: Trace): ZIO[Scope, E, O] = { val request0 = endpoint.input.encodeRequest(invocation.input) val request = request0.copy(url = endpointRoot ++ request0.url) @@ -44,28 +44,12 @@ private[endpoint] final case class EndpointClient[P, I, E, O, M <: EndpointMiddl ) client.request(withDefaultAcceptHeader).orDie.flatMap { response => - if (response.status.isSuccess) { + if (endpoint.output.matchesStatus(response.status)) { endpoint.output.decodeResponse(response).orDie + } else if (endpoint.error.matchesStatus(response.status)) { + endpoint.error.decodeResponse(response).orDie.flip } else { - // Preferentially decode an error from the handler, before falling back - // to decoding the middleware error: - val handlerError = - endpoint.error - .decodeResponse(response) - .map(e => alt.left(e)) - .mapError(t => new IllegalStateException("Cannot deserialize using endpoint error codec", t)) - - val middlewareError = - invocation.middleware.error - .decodeResponse(response) - .map(e => alt.right(e)) - .mapError(t => new IllegalStateException("Cannot deserialize using middleware error codec", t)) - - handlerError.catchAllCause { handlerCause => - middlewareError.catchAllCause { middlewareCause => - ZIO.failCause(handlerCause ++ middlewareCause) - } - }.orDie.flip + ZIO.die(new IllegalStateException(s"Status code: ${response.status} is not defined in the endpoint")) } } }