From 1b5359049427d51b56bf981f3127e9f8f311c297 Mon Sep 17 00:00:00 2001 From: Jules Ivanic Date: Tue, 19 Nov 2024 11:42:25 +1100 Subject: [PATCH] Optimize `zio.http.codec.internal.EncoderDecoder.Single.decodeBody` code --- .../http/codec/internal/EncoderDecoder.scala | 257 +++++++++--------- 1 file changed, 132 insertions(+), 125 deletions(-) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index 705b432cbf..a09a886cc0 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -16,21 +16,20 @@ package zio.http.codec.internal -import scala.util.Try - import zio._ - -import zio.schema.codec.{BinaryCodec, DecodeError} -import zio.schema.{Schema, StandardType} - import zio.http.Header.Accept.MediaTypeWithQFactor import zio.http._ import zio.http.codec.HttpCodec.Query.QueryType import zio.http.codec._ +import zio.schema.codec.{BinaryCodec, DecodeError} +import zio.schema.{Schema, StandardType} -private[codec] trait EncoderDecoder[-AtomTypes, Value] { self => +import scala.util.Try + +private[codec] trait EncoderDecoder[-AtomTypes, Value] { + self => def decode(config: CodecConfig, url: URL, status: Status, method: Method, headers: Headers, body: Body)(implicit - trace: Trace, + trace: Trace, ): Task[Value] def encodeWith[Z](config: CodecConfig, value: Value, outputTypes: Chunk[MediaTypeWithQFactor])( @@ -38,11 +37,12 @@ private[codec] trait EncoderDecoder[-AtomTypes, Value] { self => ): Z } + private[codec] object EncoderDecoder { def apply[AtomTypes, Value]( - httpCodec: HttpCodec[AtomTypes, Value], - ): EncoderDecoder[AtomTypes, Value] = { + httpCodec: HttpCodec[AtomTypes, Value], + ): EncoderDecoder[AtomTypes, Value] = { val flattened = httpCodec.alternatives flattened.length match { @@ -53,8 +53,8 @@ private[codec] object EncoderDecoder { } private final case class Multiple[-AtomTypes, Value]( - httpCodecs: Chunk[(HttpCodec[AtomTypes, Value], HttpCodec.Fallback.Condition)], - ) extends EncoderDecoder[AtomTypes, Value] { + httpCodecs: Chunk[(HttpCodec[AtomTypes, Value], HttpCodec.Fallback.Condition)], + ) extends EncoderDecoder[AtomTypes, Value] { val singles = httpCodecs.map { case (httpCodec, condition) => Single(httpCodec) -> condition } override def decode(config: CodecConfig, url: URL, status: Status, method: Method, headers: Headers, body: Body)( @@ -84,8 +84,8 @@ private[codec] object EncoderDecoder { override def encodeWith[Z](config: CodecConfig, value: Value, outputTypes: Chunk[MediaTypeWithQFactor])( f: (URL, Option[Status], Option[Method], Headers, Body) => Z, ): Z = { - var i = 0 - var encoded = null.asInstanceOf[Z] + var i = 0 + var encoded = null.asInstanceOf[Z] var lastError = null.asInstanceOf[Throwable] while (i < singles.length) { @@ -124,39 +124,39 @@ private[codec] object EncoderDecoder { """.stripMargin.trim() override def encodeWith[Z]( - config: CodecConfig, - value: Value, - outputTypes: Chunk[MediaTypeWithQFactor], - )(f: (zio.http.URL, Option[zio.http.Status], Option[zio.http.Method], zio.http.Headers, zio.http.Body) => Z): Z = { + config: CodecConfig, + value: Value, + outputTypes: Chunk[MediaTypeWithQFactor], + )(f: (zio.http.URL, Option[zio.http.Status], Option[zio.http.Method], zio.http.Headers, zio.http.Body) => Z): Z = { throw new IllegalStateException(encodeWithErrorMessage) } override def decode( - config: CodecConfig, - url: zio.http.URL, - status: zio.http.Status, - method: zio.http.Method, - headers: zio.http.Headers, - body: zio.http.Body, - )(implicit trace: zio.Trace): zio.Task[Value] = { + config: CodecConfig, + url: zio.http.URL, + status: zio.http.Status, + method: zio.http.Method, + headers: zio.http.Headers, + body: zio.http.Body, + )(implicit trace: zio.Trace): zio.Task[Value] = { ZIO.fail(new IllegalStateException(decodeErrorMessage)) } } private final case class Single[-AtomTypes, Value]( - httpCodec: HttpCodec[AtomTypes, Value], - ) extends EncoderDecoder[AtomTypes, Value] { - private val constructor = Mechanic.makeConstructor(httpCodec) + httpCodec: HttpCodec[AtomTypes, Value], + ) extends EncoderDecoder[AtomTypes, Value] { + private val constructor = Mechanic.makeConstructor(httpCodec) private val deconstructor = Mechanic.makeDeconstructor(httpCodec) private val flattened: AtomizedCodecs = AtomizedCodecs.flatten(httpCodec) - implicit val trace: Trace = Trace.empty + implicit val trace: Trace = Trace.empty private lazy val formBoundary = Boundary("----zio-http-boundary-D4792A5C-93E0-43B5-9A1F-48E38FDE5714") - private lazy val indexByName = flattened.content.zipWithIndex.map { case (codec, idx) => + private lazy val indexByName = flattened.content.zipWithIndex.map { case (codec, idx) => codec.name.getOrElse("field" + idx.toString) -> idx }.toMap - private lazy val nameByIndex = indexByName.map(_.swap) + private lazy val nameByIndex = indexByName.map(_.swap) override def decode(config: CodecConfig, url: URL, status: Status, method: Method, headers: Headers, body: Body)( implicit trace: Trace, @@ -176,24 +176,26 @@ private[codec] object EncoderDecoder { ): Z = { val inputs = deconstructor(value) - val path = encodePath(inputs.path) - val query = encodeQuery(config, inputs.query) - val status = encodeStatus(inputs.status) - val method = encodeMethod(inputs.method) - val headers = encodeHeaders(inputs.header) + val path = encodePath(inputs.path) + val query = encodeQuery(config, inputs.query) + val status = encodeStatus(inputs.status) + val method = encodeMethod(inputs.method) + val headers = encodeHeaders(inputs.header) + def contentTypeHeaders = encodeContentType(inputs.content, outputTypes) - val body = encodeBody(config, inputs.content, outputTypes) + + val body = encodeBody(config, inputs.content, outputTypes) val headers0 = if (headers.contains("content-type")) headers else headers ++ contentTypeHeaders f(URL(path, queryParams = query), status, method, headers0, body) } private def genericDecode[A, Codec]( - a: A, - codecs: Chunk[Codec], - inputs: Array[Any], - decode: (Codec, A) => Any, - ): Unit = { + a: A, + codecs: Chunk[Codec], + inputs: Array[Any], + decode: (Codec, A) => Any, + ): Unit = { for (i <- 0 until inputs.length) { val codec = codecs(i) inputs(i) = decode(codec, a) @@ -207,7 +209,7 @@ private[codec] object EncoderDecoder { inputs, (codec, path) => { codec.erase.decode(path) match { - case Left(error) => throw HttpCodecError.MalformedPath(path, codec, error) + case Left(error) => throw HttpCodecError.MalformedPath(path, codec, error) case Right(value) => value } }, @@ -219,22 +221,22 @@ private[codec] object EncoderDecoder { flattened.query, inputs, (codec, queryParams) => { - val query = codec.erase + val query = codec.erase val isOptional = query.isOptional query.queryType match { - case QueryType.Primitive(name, bc @ BinaryCodecWithSchema(_, schema)) => - val count = queryParams.valueCount(name) + case QueryType.Primitive(name, bc@BinaryCodecWithSchema(_, schema)) => + val count = queryParams.valueCount(name) val hasParam = queryParams.hasQueryParam(name) if (!hasParam && isOptional) None else if (!hasParam) throw HttpCodecError.MissingQueryParam(name) else if (count != 1) throw HttpCodecError.InvalidQueryParamCount(name, 1, count) else { - val decoded = bc + val decoded = bc .codec(config) .decode( Chunk.fromArray(queryParams.unsafeQueryParam(name).getBytes(Charsets.Utf8)), ) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) + case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) case Right(value) => value } val validationErrors = schema.validate(decoded)(schema) @@ -243,41 +245,41 @@ private[codec] object EncoderDecoder { Some("") else decoded } - case c @ QueryType.Collection(_, QueryType.Primitive(name, bc), optional) => + case c@QueryType.Collection(_, QueryType.Primitive(name, bc), optional) => if (!queryParams.hasQueryParam(name)) { if (!optional) c.toCollection(Chunk.empty) else None } else { - val values = queryParams.queryParams(name) - val decoded = c.toCollection { + val values = queryParams.queryParams(name) + val decoded = c.toCollection { values.map { value => bc.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) + case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) case Right(value) => value } } } - val erasedSchema = c.colSchema.asInstanceOf[Schema[Any]] + val erasedSchema = c.colSchema.asInstanceOf[Schema[Any]] val validationErrors = erasedSchema.validate(decoded)(erasedSchema) if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) if (optional) Some(decoded) else decoded } - case query @ QueryType.Record(recordSchema) => + case query@QueryType.Record(recordSchema) => val hasAllParams = query.fieldAndCodecs.forall { case (field, _) => queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined } if (!hasAllParams && recordSchema.isInstanceOf[Schema.Optional[_]]) None else if (!hasAllParams && isOptional) { recordSchema.defaultValue match { - case Left(err) => + case Left(err) => throw new IllegalStateException(s"Cannot compute default value for $recordSchema. Error was: $err") case Right(value) => value } } else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { query.fieldAndCodecs.collect { case (field, _) - if !(queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined) => + if !(queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined) => field.name } } @@ -286,24 +288,24 @@ private[codec] object EncoderDecoder { case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => if (!queryParams.hasQueryParam(field.name) && field.defaultValue.nonEmpty) field.defaultValue.get else { - val values = queryParams.queryParams(field.name) - val decoded = values.map { value => + val values = queryParams.queryParams(field.name) + val decoded = values.map { value => codec.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) + case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) case Right(value) => value } } val decodedCollection = field.schema match { - case s @ Schema.Sequence(_, fromChunk, _, _, _) => - val collection = fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) - val erasedSchema = s.asInstanceOf[Schema[Any]] + case s@Schema.Sequence(_, fromChunk, _, _, _) => + val collection = fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) + val erasedSchema = s.asInstanceOf[Schema[Any]] val validationErrors = erasedSchema.validate(collection)(erasedSchema) if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) collection - case s @ Schema.Set(_, _) => - val collection = decoded.toSet[Any] - val erasedSchema = s.asInstanceOf[Schema.Set[Any]] + case s@Schema.Set(_, _) => + val collection = decoded.toSet[Any] + val erasedSchema = s.asInstanceOf[Schema.Set[Any]] val validationErrors = erasedSchema.validate(collection)(erasedSchema) if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) collection @@ -311,13 +313,13 @@ private[codec] object EncoderDecoder { } decodedCollection } - case (field, codec) => - val value = queryParams.queryParamOrElse(field.name, null) - val decoded = { + case (field, codec) => + val value = queryParams.queryParamOrElse(field.name, null) + val decoded = { if (value == null) field.defaultValue.get else { codec.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) + case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) case Right(value) => value } } @@ -330,7 +332,7 @@ private[codec] object EncoderDecoder { val schema = recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] val constructed = schema.construct(decoded)(Unsafe.unsafe) constructed match { - case Left(value) => + case Left(value) => throw HttpCodecError.MalformedQueryParam( s"${schema.id}", DecodeError.ReadError(Cause.empty, value), @@ -338,14 +340,14 @@ private[codec] object EncoderDecoder { case Right(value) => schema.validate(value)(schema) match { case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => Some(value) + case _ => Some(value) } } } else { - val schema = recordSchema.asInstanceOf[Schema.Record[Any]] + val schema = recordSchema.asInstanceOf[Schema.Record[Any]] val constructed = schema.construct(decoded)(Unsafe.unsafe) constructed match { - case Left(value) => + case Left(value) => throw HttpCodecError.MalformedQueryParam( s"${schema.id}", DecodeError.ReadError(Cause.empty, value), @@ -353,7 +355,7 @@ private[codec] object EncoderDecoder { case Right(value) => schema.validate(value)(schema) match { case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => value + case _ => value } } } @@ -364,11 +366,11 @@ private[codec] object EncoderDecoder { private def emptyStringIsValue(schema: Schema[_]): Boolean = schema.asInstanceOf[Schema.Primitive[_]].standardType match { - case StandardType.UnitType => true + case StandardType.UnitType => true case StandardType.StringType => true case StandardType.BinaryType => true - case StandardType.CharType => true - case _ => false + case StandardType.CharType => true + case _ => false } private def decodeHeaders(headers: Headers, inputs: Array[Any]): Unit = @@ -397,8 +399,8 @@ private[codec] object EncoderDecoder { codec match { case SimpleCodec.Specified(expected) if expected != status => throw HttpCodecError.MalformedStatus(expected, status) - case _: SimpleCodec.Unspecified[_] => status - case _ => () + case _: SimpleCodec.Unspecified[_] => status + case _ => () }, ) @@ -411,26 +413,29 @@ private[codec] object EncoderDecoder { codec match { case SimpleCodec.Specified(expected) if expected != method => throw HttpCodecError.MalformedMethod(expected, method) - case _: SimpleCodec.Unspecified[_] => method - case _ => () + case _: SimpleCodec.Unspecified[_] => method + case _ => () }, ) private def decodeBody(config: CodecConfig, body: Body, inputs: Array[Any])(implicit - trace: Trace, + trace: Trace, ): Task[Unit] = { - val codecs = flattened.content + val isNonMultiPart = inputs.length < 2 + if (isNonMultiPart) { + val codecs = flattened.content - if (inputs.length < 2) { - // non multi-part - codecs.headOption.map { codec => + //noinspection SimplifyUnlessInspection + if (codecs.isEmpty) ZIO.unit + else { + val codec = codecs.head codec .decodeFromBody(body, config) .mapBoth( - { err => HttpCodecError.MalformedBody(err.getMessage(), Some(err)) }, + { err => HttpCodecError.MalformedBody(err.getMessage, Some(err)) }, result => inputs(0) = result, ) - }.getOrElse(ZIO.unit) + } } else { // multi-part decodeForm(body.asMultipartFormStream, inputs, config) *> check(inputs) @@ -438,20 +443,22 @@ private[codec] object EncoderDecoder { } private def decodeForm( - form: Task[StreamingForm], - inputs: Array[Any], - config: CodecConfig, - ): ZIO[Any, Throwable, Unit] = + form: Task[StreamingForm], + inputs: Array[Any], + config: CodecConfig, + ): ZIO[Any, Throwable, Unit] = form.flatMap(_.collectAll).flatMap { collectedForm => ZIO.foreachDiscard(collectedForm.formData) { field => val codecs = flattened.content - val i = indexByName + val i = indexByName .get(field.name) .getOrElse(throw HttpCodecError.MalformedBody(s"Unexpected multipart/form-data field: ${field.name}")) - val codec = codecs(i).erase + val codec = codecs(i).erase for { decoded <- codec.decodeFromField(field, config) - _ <- ZIO.attempt { inputs(i) = decoded } + _ <- ZIO.attempt { + inputs(i) = decoded + } } yield () } } @@ -467,11 +474,11 @@ private[codec] object EncoderDecoder { } private def genericEncode[A, Codec]( - codecs: Chunk[Codec], - inputs: Array[Any], - init: A, - encoding: (Codec, Any, A) => A, - ): A = { + codecs: Chunk[Codec], + inputs: Array[Any], + init: A, + encoding: (Codec, Any, A) => A, + ): A = { var res = init for (i <- 0 until inputs.length) { val codec = codecs(i) @@ -485,7 +492,7 @@ private[codec] object EncoderDecoder { codecs.headOption.map { codec => codec match { case _: SimpleCodec.Unspecified[_] => inputs(0).asInstanceOf[A] - case SimpleCodec.Specified(elem) => elem + case SimpleCodec.Specified(elem) => elem } } @@ -496,7 +503,7 @@ private[codec] object EncoderDecoder { Path.empty, (codec, a, acc) => { val encoded = codec.erase.encode(a) match { - case Left(error) => + case Left(error) => throw HttpCodecError.MalformedPath(acc, codec, error) case Right(value) => value } @@ -513,7 +520,7 @@ private[codec] object EncoderDecoder { val query = codec.erase query.queryType match { - case QueryType.Primitive(name, codec) => + case QueryType.Primitive(name, codec) => val schema = codec.schema if (schema.isInstanceOf[Schema.Primitive[_]]) { if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { @@ -530,12 +537,12 @@ private[codec] object EncoderDecoder { "Only primitive schema is supported for query parameters of type Primitive", ) } - case QueryType.Collection(_, QueryType.Primitive(name, codec), optional) => + case QueryType.Collection(_, QueryType.Primitive(name, codec), optional) => var in: Any = input if (optional) { in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) } - val values = input.asInstanceOf[Iterable[Any]] + val values = input.asInstanceOf[Iterable[Any]] if (values.nonEmpty) { queryParams.addQueryParams( name, @@ -546,21 +553,21 @@ private[codec] object EncoderDecoder { ), ) } else queryParams - case query @ QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => + case query@QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => input match { - case None => queryParams + case None => queryParams case Some(value) => val innerSchema = recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) - var j = 0 - var qp = queryParams + var j = 0 + var qp = queryParams while (j < fieldValues.size) { val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { + val name = field.name + val value = fieldValues(j) match { case Some(value) => value - case None => field.defaultValue + case None => field.defaultValue } value match { case values: Iterable[_] => @@ -570,7 +577,7 @@ private[codec] object EncoderDecoder { codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(v).asString }), ) - case _ => + case _ => val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString qp = qp.addQueryParam(name, encoded) } @@ -578,17 +585,17 @@ private[codec] object EncoderDecoder { } qp } - case query @ QueryType.Record(recordSchema) => + case query@QueryType.Record(recordSchema) => val innerSchema = recordSchema.asInstanceOf[Schema.Record[Any]] val fieldValues = innerSchema.deconstruct(input)(Unsafe.unsafe) - var j = 0 - var qp = queryParams + var j = 0 + var qp = queryParams while (j < fieldValues.size) { val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { + val name = field.name + val value = fieldValues(j) match { case Some(value) => value - case None => field.defaultValue + case None => field.defaultValue } value match { case values if values.isInstanceOf[Iterable[_]] => @@ -598,7 +605,7 @@ private[codec] object EncoderDecoder { codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(v).asString }), ) - case _ => + case _ => val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString qp = qp.addQueryParam(name, encoded) } @@ -636,13 +643,13 @@ private[codec] object EncoderDecoder { } private def encodeMultipartFormData( - inputs: Array[Any], - outputTypes: Chunk[MediaTypeWithQFactor], - config: CodecConfig, - ): Form = { + inputs: Array[Any], + outputTypes: Chunk[MediaTypeWithQFactor], + config: CodecConfig, + ): Form = { val formFields = flattened.content.zipWithIndex.map { case (bodyCodec, idx) => val input = inputs(idx) - val name = nameByIndex(idx) + val name = nameByIndex(idx) bodyCodec.erase.encodeToField(input, outputTypes, name, config) }