Skip to content

Commit

Permalink
Minor code readability improvements and micro-optimizations (#2289)
Browse files Browse the repository at this point in the history
* Minor adapter optimizations and code readability improvements

* Fix scala 2 warning

* Use a ListBuffer in `toResponseValue`
  • Loading branch information
kyri-petrou authored Jun 17, 2024
1 parent 68ee2a4 commit c76e2f5
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 43 deletions.
30 changes: 19 additions & 11 deletions adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ final private class QuickRequestHandler[R](
def handleHttpRequest(request: Request)(implicit trace: Trace): URIO[R, Response] =
transformHttpRequest(request)
.flatMap(executeRequest(request.method, _))
.map(transformResponse(request, _))
.merge
.foldZIO(
Exit.succeed,
resp => Exit.succeed(transformResponse(request, resp))
)

def handleUploadRequest(request: Request)(implicit trace: Trace): URIO[R, Response] =
transformUploadRequest(request).flatMap { case (req, fileHandle) =>
Expand Down Expand Up @@ -93,10 +95,10 @@ final private class QuickRequestHandler[R](

def decodeJson(): ZIO[Any, Response, GraphQLRequest] =
body.asArray.foldZIO(
_ => ZIO.fail(BodyDecodeErrorResponse),
_ => Exit.fail(BodyDecodeErrorResponse),
arr =>
try checkNonEmptyRequest(readFromArray[GraphQLRequest](arr))
catch { case NonFatal(_) => ZIO.fail(BodyDecodeErrorResponse) }
catch { case NonFatal(_) => Exit.fail(BodyDecodeErrorResponse) }
)

val isApplicationGql =
Expand All @@ -111,7 +113,7 @@ final private class QuickRequestHandler[R](
val queryParams = httpReq.url.queryParams

if ((httpReq.method eq Method.GET) || queryParams.hasQueryParam("query")) {
decodeQueryParams(queryParams).fold(ZIO.fail(_), checkNonEmptyRequest)
decodeQueryParams(queryParams).fold(Exit.fail, checkNonEmptyRequest)
} else {
val req = decodeBody(httpReq.body)
if (isFtv1Request(httpReq)) req.map(_.withFederatedTracing)
Expand Down Expand Up @@ -168,11 +170,14 @@ final private class QuickRequestHandler[R](
}

private def responseHeaders(headers: Headers, cacheDirective: Option[String]): Headers =
cacheDirective.fold(headers)(headers.addHeader(Header.CacheControl.name, _))
cacheDirective match {
case None => headers
case Some(h) => headers.addHeader(Header.CacheControl.name, h)
}

private def transformResponse(httpReq: Request, resp: GraphQLResponse[Any])(implicit trace: Trace): Response = {
val accepts = new HttpUtils.AcceptsGqlEncodings(httpReq.headers.get(Header.Accept.name))
val cacheDirective = HttpUtils.computeCacheDirective(resp.extensions)
val cacheDirective = resp.extensions.flatMap(HttpUtils.computeCacheDirective)

resp match {
case resp @ GraphQLResponse(StreamValue(stream), _, _, _) =>
Expand All @@ -184,19 +189,20 @@ final private class QuickRequestHandler[R](
case resp if accepts.serverSentEvents =>
Response.fromServerSentEvents(encodeTextEventStream(resp))
case resp if accepts.graphQLJson =>
val isBadRequest = resp.errors.collectFirst {
val isBadRequest = resp.errors.exists {
case _: CalibanError.ParsingError | _: CalibanError.ValidationError => true
}.getOrElse(false)
case _ => false
}
Response(
status = if (isBadRequest) Status.BadRequest else Status.Ok,
headers = responseHeaders(ContentTypeGql, cacheDirective),
body =
encodeSingleResponse(resp, keepDataOnErrors = !isBadRequest, hasCacheDirective = cacheDirective.isDefined)
)
case resp =>
val isBadRequest = resp.errors.contains(HttpRequestMethod.MutationOverGetError)
Response(
status = resp.errors.collectFirst { case HttpRequestMethod.MutationOverGetError => Status.BadRequest }
.getOrElse(Status.Ok),
status = if (isBadRequest) Status.BadRequest else Status.Ok,
headers = responseHeaders(ContentTypeJson, cacheDirective),
body = encodeSingleResponse(resp, keepDataOnErrors = true, hasCacheDirective = cacheDirective.isDefined)
)
Expand Down Expand Up @@ -298,5 +304,7 @@ object QuickRequestHandler {
null.asInstanceOf[InputValue.ObjectValue]
}

private implicit val responseCodec: JsonValueCodec[ResponseValue] = ValueJsoniter.responseValueCodec

private implicit val stringListCodec: JsonValueCodec[Map[String, List[String]]] = JsonCodecMaker.make
}
42 changes: 26 additions & 16 deletions core/src/main/scala/caliban/GraphQLResponse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import caliban.interop.play.{ IsPlayJsonReads, IsPlayJsonWrites }
import caliban.interop.tapir.IsTapirSchema
import caliban.interop.zio.{ IsZIOJsonCodec, IsZIOJsonDecoder, IsZIOJsonEncoder }

import scala.collection.mutable.ListBuffer

/**
* Represents the result of a GraphQL query, containing a data object and a list of errors.
*/
Expand All @@ -20,22 +22,30 @@ case class GraphQLResponse[+E](
def toResponseValue: ResponseValue = toResponseValue(keepDataOnErrors = true)

def toResponseValue(keepDataOnErrors: Boolean, excludeExtensions: Option[Set[String]] = None): ResponseValue = {
val hasErrors = errors.nonEmpty
ObjectValue(
List(
"data" -> (if (!hasErrors || keepDataOnErrors) Some(data) else None),
"errors" -> (if (hasErrors)
Some(ListValue(errors.map {
case e: CalibanError => e.toResponseValue
case e => ObjectValue(List("message" -> StringValue(e.toString)))
}))
else None),
"extensions" -> excludeExtensions.fold(extensions)(excl =>
extensions.map(obj => ObjectValue(obj.fields.filterNot(f => excl.contains(f._1))))
),
"hasNext" -> hasNext.map(BooleanValue.apply)
).collect { case (name, Some(v)) => name -> v }
)
val builder = new ListBuffer[(String, ResponseValue)]
val hasErrors = errors.nonEmpty
val extensions0 = excludeExtensions match {
case None => extensions
case Some(excl) =>
extensions.flatMap { obj =>
val newFields = obj.fields.filterNot(f => excl.contains(f._1))
if (newFields.nonEmpty) Some(ObjectValue(newFields)) else None
}
}

if (!hasErrors || keepDataOnErrors)
builder += "data" -> data
if (hasErrors)
builder += "errors" -> ListValue(errors.map {
case e: CalibanError => e.toResponseValue
case e => ObjectValue(List("message" -> StringValue(e.toString)))
})
if (extensions0.nonEmpty)
builder += "extensions" -> extensions0.get
if (hasNext.nonEmpty)
builder += "hasNext" -> BooleanValue(hasNext.get)

ObjectValue(builder.result())
}

def withExtension(key: String, value: ResponseValue): GraphQLResponse[E] =
Expand Down
9 changes: 4 additions & 5 deletions core/src/main/scala/caliban/HttpUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,10 @@ private[caliban] object HttpUtils {
}).map(v => toSse(v.toResponseValue)) ++ ZStream.succeed(done)
}

def computeCacheDirective(extensions: Option[ResponseValue.ObjectValue]): Option[String] =
extensions
.flatMap(_.fields.collectFirst { case (Caching.DirectiveName, ResponseValue.ObjectValue(fields)) =>
fields.collectFirst { case ("httpHeader", Value.StringValue(cacheHeader)) => cacheHeader }
}.flatten)
def computeCacheDirective(extensions: ResponseValue.ObjectValue): Option[String] =
extensions.fields.collectFirst { case (Caching.DirectiveName, ResponseValue.ObjectValue(fields)) =>
fields.collectFirst { case ("httpHeader", Value.StringValue(cacheHeader)) => cacheHeader }
}.flatten

final class AcceptsGqlEncodings(header0: Option[String]) {
private val isEmpty = header0.isEmpty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ object TapirAdapter {
streamConstructor: StreamConstructor[BS],
responseCodec: JsonCodec[ResponseValue]
): (MediaType, StatusCode, Option[String], CalibanBody[BS]) = {
val accepts = new HttpUtils.AcceptsGqlEncodings(request.header(HeaderNames.Accept))
val accepts = new HttpUtils.AcceptsGqlEncodings(request.header(HeaderNames.Accept))
val cacheDirective = response.extensions.flatMap(HttpUtils.computeCacheDirective)

response match {
case resp @ GraphQLResponse(StreamValue(stream), _, _, _) =>
Expand All @@ -116,15 +117,14 @@ object TapirAdapter {
encodeMultipartMixedResponse(resp, stream)
)
case resp if accepts.graphQLJson =>
val isBadRequest = response.errors.collectFirst {
val isBadRequest = response.errors.exists {
case _: CalibanError.ParsingError | _: CalibanError.ValidationError => true
}.getOrElse(false)
val code = if (isBadRequest) StatusCode.BadRequest else StatusCode.Ok
val cacheDirective = HttpUtils.computeCacheDirective(response.extensions)
case _ => false
}
(
GraphqlResponseJson.mediaType,
code,
HttpUtils.computeCacheDirective(response.extensions),
if (isBadRequest) StatusCode.BadRequest else StatusCode.Ok,
cacheDirective,
encodeSingleResponse(
resp,
keepDataOnErrors = !isBadRequest,
Expand All @@ -139,12 +139,10 @@ object TapirAdapter {
encodeTextEventStreamResponse(resp)
)
case resp =>
val code = response.errors.collectFirst { case HttpRequestMethod.MutationOverGetError => StatusCode.BadRequest }
.getOrElse(StatusCode.Ok)
val cacheDirective = HttpUtils.computeCacheDirective(response.extensions)
val isBadRequest = response.errors.contains(HttpRequestMethod.MutationOverGetError: Any)
(
MediaType.ApplicationJson,
code,
if (isBadRequest) StatusCode.BadRequest else StatusCode.Ok,
cacheDirective,
encodeSingleResponse(
resp,
Expand Down

0 comments on commit c76e2f5

Please sign in to comment.