diff --git a/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala index 4dd554ec08..2e0d063a2f 100644 --- a/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/src/main/scala/zio/http/codec/HttpCodec.scala @@ -32,6 +32,7 @@ import zio.http.Header.Accept.MediaTypeWithQFactor import zio.http._ import zio.http.codec.HttpCodec.{Annotated, Metadata} import zio.http.codec.internal.EncoderDecoder +import zio.prelude._ /** * A [[zio.http.codec.HttpCodec]] represents a codec for a part of an HTTP @@ -610,13 +611,17 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with } } - private[http] final case class MultiQuery[I](name: String, textCodec: TextCodec[I], index: Int = 0) - extends Query[Chunk[I], I] { - def index(index: Int): Query[Chunk[I], I] = copy(index = index) + private[http] final case class MultiQuery[F[+_]: ForEach, I]( + name: String, + textCodec: TextCodec[I], + cardinality: QueryCardinality[F], + index: Int = 0, + ) extends Query[F[I], I] { + def index(index: Int): Query[F[I], I] = copy(index = index) - def encode(value: Chunk[I]): Chunk[String] = value map textCodec.encode + def encode(values: F[I]): Chunk[String] = values.map(textCodec.encode).toChunk - def decode(values: Chunk[String]): Chunk[I] = values map decodeItem + def decode(values: Chunk[String]): F[I] = cardinality.decode(name, textCodec, values) } private[http] final case class Method[A](codec: SimpleCodec[zio.http.Method, A], index: Int = 0) diff --git a/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala b/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala index 822b363821..f87595fb60 100644 --- a/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala +++ b/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala @@ -57,6 +57,12 @@ object HttpCodecError { final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed query parameter $queryParamName failed to decode using $textCodec" } + + final case class WrongQueryParamCardinality(queryParamName: String, actual: Int, expected: String) + extends HttpCodecError { + def message = s"Wrong query parameter $queryParamName cardinality $actual, $expected expected" + } + final case class MalformedBody(details: String, cause: Option[Throwable] = None) extends HttpCodecError { def message = s"Malformed request body failed to decode: $details" } diff --git a/zio-http/src/main/scala/zio/http/codec/QueryCardinality.scala b/zio-http/src/main/scala/zio/http/codec/QueryCardinality.scala new file mode 100644 index 0000000000..894ea47fc4 --- /dev/null +++ b/zio-http/src/main/scala/zio/http/codec/QueryCardinality.scala @@ -0,0 +1,36 @@ +package zio.http.codec + +import zio.Chunk +import zio.http.codec.HttpCodecError +import zio.prelude._ + +sealed case class QueryCardinality[F[+_]: Covariant](extract: (String, Chunk[String]) => F[String]) { + def coerce(values: Any): F[_] = values.asInstanceOf[F[_]] + + def decode[I](name: String, codec: TextCodec[I], values: Chunk[String]): F[I] = extract(name, values) map { + (value: String) => + if (codec.isDefinedAt(value)) codec(value) else throw HttpCodecError.MalformedQueryParam(name, codec) + } +} + +object QueryCardinality { + object any extends QueryCardinality((_, values) => values) + + object oneOrMore + extends QueryCardinality((name, values) => + values + .nonEmptyOrElse(throw HttpCodecError.WrongQueryParamCardinality(name, values.length, "one or more"))(identity), + ) + + object optional + extends QueryCardinality((name, values) => + if (values.length > 1) throw HttpCodecError.WrongQueryParamCardinality(name, values.length, "one or none") + else values.headOption, + ) + + object one + extends QueryCardinality[Id.Type]((name, values) => + if (values.length == 1) Id(values.head) + else throw HttpCodecError.WrongQueryParamCardinality(name, values.length, "exactly one"), + ) +} diff --git a/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala b/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala index 4d73bb0026..cdb4d12435 100644 --- a/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala +++ b/zio-http/src/main/scala/zio/http/codec/QueryCodecs.scala @@ -15,36 +15,49 @@ */ package zio.http.codec -import zio.Chunk +import zio.{Chunk, NonEmptyChunk} +import zio.prelude.{ForEach, Id} import zio.stacktracer.TracingImplicits.disableAutoTrace private[codec] trait QueryCodecs { - def query(name: String): QueryCodec[String] = - HttpCodec.MonoQuery(name, TextCodec.string) + @inline def queryAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] = + HttpCodec.MonoQuery(name, codec) - def queryBool(name: String): QueryCodec[Boolean] = - HttpCodec.MonoQuery(name, TextCodec.boolean) + def query(name: String): QueryCodec[String] = queryAs[String](name) - def queryInt(name: String): QueryCodec[Int] = - HttpCodec.MonoQuery(name, TextCodec.int) + def queryBool(name: String): QueryCodec[Boolean] = queryAs[Boolean](name) - def queryAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] = - HttpCodec.MonoQuery(name, codec) + def queryInt(name: String): QueryCodec[Int] = queryAs[Int](name) - def queries[I](name: String)(implicit codec: TextCodec[I]): QueryCodec[Chunk[I]] = - HttpCodec.MultiQuery(name, codec) + @inline def queryAs[F[+_]: ForEach, I](name: String, cardinality: QueryCardinality[F])(implicit + codec: TextCodec[I], + ): QueryCodec[F[I]] = + HttpCodec.MultiQuery(name, codec, cardinality) - def paramStr(name: String): QueryCodec[String] = - HttpCodec.MonoQuery(name, TextCodec.string) + def queryOpt[I: TextCodec](name: String): QueryCodec[Option[I]] = queryAs(name, QueryCardinality.optional) - def paramBool(name: String): QueryCodec[Boolean] = - HttpCodec.MonoQuery(name, TextCodec.boolean) + def queryOne[I: TextCodec](name: String): QueryCodec[Id[I]] = queryAs(name, QueryCardinality.one) - def paramInt(name: String): QueryCodec[Int] = - HttpCodec.MonoQuery(name, TextCodec.int) + def queries[I: TextCodec](name: String): QueryCodec[Chunk[I]] = queryAs(name, QueryCardinality.any) - def paramAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] = - HttpCodec.MonoQuery(name, codec) + def queryOneOrMore[I: TextCodec](name: String): QueryCodec[NonEmptyChunk[I]] = + queryAs(name, QueryCardinality.oneOrMore) + + def paramAs[A](name: String)(implicit codec: TextCodec[A]): QueryCodec[A] = queryAs[A](name) + + def paramStr(name: String): QueryCodec[String] = query(name) + + def paramBool(name: String): QueryCodec[Boolean] = queryBool(name) + + def paramInt(name: String): QueryCodec[Int] = queryInt(name) + + def paramAs[F[+_]: ForEach, I: TextCodec](name: String, cardinality: QueryCardinality[F]): QueryCodec[F[I]] = + queryAs(name, cardinality) + + def paramOpt[I: TextCodec](name: String): QueryCodec[Option[I]] = queryOpt(name) + + def paramOne[I: TextCodec](name: String): QueryCodec[Id[I]] = queryOne(name) + + def params[I: TextCodec](name: String): QueryCodec[Chunk[I]] = queries(name) - def params[I](name: String)(implicit codec: TextCodec[I]): QueryCodec[Chunk[I]] = - HttpCodec.MultiQuery(name, codec) + def paramOneOrMore[I: TextCodec](name: String): QueryCodec[NonEmptyChunk[I]] = queryOneOrMore(name) } diff --git a/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala b/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala index 6d6bbbc7f6..695aa1ff8a 100644 --- a/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala +++ b/zio-http/src/test/scala/zio/http/codec/HttpCodecSpec.scala @@ -19,6 +19,7 @@ package zio.http.codec import java.util.UUID import zio._ +import zio.prelude.Id import zio.test._ import zio.http._ @@ -37,16 +38,22 @@ object HttpCodecSpec extends ZIOHttpSpec { val emptyJson = Body.fromString("{}") - val strParam = "name" - val codecStr = QueryCodec.paramStr(strParam) - val boolParam = "isAge" - val codecBool = QueryCodec.paramBool(boolParam) - val intParam = "age" - val codecInt = QueryCodec.paramInt(intParam) - val longParam = "count" - val codecLong = QueryCodec.paramAs[Long](longParam) - val seqIntParam = "integers" - val codecSeqInt = QueryCodec.params[Int](seqIntParam) + private val strParam = "name" + private val codecStr = QueryCodec.paramStr(strParam) + private val boolParam = "isAge" + private val codecBool = QueryCodec.paramBool(boolParam) + private val intParam = "age" + private val codecInt = QueryCodec.paramInt(intParam) + private val longParam = "count" + private val codecLong = QueryCodec.paramAs[Long](longParam) + private val optBoolParam = "maybe" + private val codecOptBool = QueryCodec.paramOpt[Boolean](optBoolParam) + private val oneLongParam = "lonelyLong" + private val codecOneLong = QueryCodec.paramOne[Long](oneLongParam) + private val seqIntParam = "integers" + private val codecSeqInt = QueryCodec.params[Int](seqIntParam) + private val oneOrMoreStrParam = "names" + private val codecOneOrMoreStr = QueryCodec.paramOneOrMore[String](oneOrMoreStrParam) def makeRequest(name: String, value: Any) = Request.get(googleUrl.queryParams(QueryParams(name -> value.toString))) @@ -172,22 +179,74 @@ object HttpCodecSpec extends ZIOHttpSpec { ) } }, - test("paramSeq decoding with empty chunk") { + test("paramOpt decoding empty chunk") { + assertZIO(codecOptBool.decodeRequest(makeChunkRequest(optBoolParam, Chunk.empty)))(Assertion.isNone) + }, + test("paramOpt decoding singleton chunk") { + assertZIO(codecOptBool.decodeRequest(makeChunkRequest(optBoolParam, Chunk("true"))))( + Assertion.isSome(Assertion.isTrue), + ) && + assertZIO(codecOptBool.decodeRequest(makeChunkRequest(optBoolParam, Chunk("false"))))( + Assertion.isSome(Assertion.isFalse), + ) + }, + test("paramOpt encoding empty chunk") { + assert(codecOptBool.encodeRequest(None).url.queryParams.get(optBoolParam))(Assertion.isNone) + }, + test("paramOpt encoding non-empty chunk") { + assert(codecOptBool.encodeRequest(Some(true)).url.queryParams.getAll(optBoolParam).get)( + Assertion.equalTo(Chunk("true")), + ) && + assert(codecOptBool.encodeRequest(Some(false)).url.queryParams.getAll(optBoolParam).get)( + Assertion.equalTo(Chunk("false")), + ) + }, + test("paramOne decoding singleton chunk") { + assertZIO(codecOneLong.decodeRequest(makeChunkRequest(oneLongParam, Chunk(Long.MaxValue.toString))))( + Assertion.equalTo(Id(Long.MaxValue)), + ) + }, + test("paramOne encoding non-empty chunk") { + assert(codecOneLong.encodeRequest(Id(Long.MinValue)).url.queryParams.getAll(oneLongParam).get)( + Assertion.equalTo(Chunk(Long.MinValue.toString)), + ) + }, + test("params decoding empty chunk") { assertZIO(codecSeqInt.decodeRequest(makeChunkRequest(seqIntParam, Chunk.empty)))(Assertion.isEmpty) }, - test("paramSeq decoding with non-empty chunk") { + test("params decoding non-empty chunk") { assertZIO(codecSeqInt.decodeRequest(makeChunkRequest(seqIntParam, Chunk("2023", "10", "7"))))( Assertion.equalTo(Chunk(2023, 10, 7)), ) }, - test("paramSeq encoding with empty chunk") { + test("params encoding empty chunk") { assert(codecSeqInt.encodeRequest(Chunk.empty).url.queryParams.get(seqIntParam))(Assertion.isNone) }, - test("paramSeq encoding with non-empty chunk") { + test("params encoding non-empty chunk") { assert(codecSeqInt.encodeRequest(Chunk(1974, 5, 3)).url.queryParams.getAll(seqIntParam).get)( Assertion.equalTo(Chunk("1974", "5", "3")), ) }, + test("paramOneOrMore decoding non-empty chunk") { + assertZIO(codecOneOrMoreStr.decodeRequest(makeChunkRequest(oneOrMoreStrParam, Chunk("one"))))( + Assertion.equalTo(NonEmptyChunk("one")), + ) && + assertZIO(codecOneOrMoreStr.decodeRequest(makeChunkRequest(oneOrMoreStrParam, Chunk("one", "two", "three"))))( + Assertion.equalTo(NonEmptyChunk("one", "two", "three")), + ) + }, + test("paramOneOrMore encoding non-empty chunk") { + assert( + codecOneOrMoreStr + .encodeRequest(NonEmptyChunk("for", "five", "six")) + .url + .queryParams + .getAll(oneOrMoreStrParam) + .get, + )( + Assertion.equalTo(Chunk("for", "five", "six")), + ) + }, ) + suite("Codec with examples") { test("with examples") { diff --git a/zio-http/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala b/zio-http/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala index f7cba19f65..463e265535 100644 --- a/zio-http/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala +++ b/zio-http/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala @@ -30,7 +30,7 @@ import zio.schema.{DeriveSchema, Schema} import zio.http.Header.ContentType import zio.http.Method._ import zio.http._ -import zio.http.codec.HttpCodec.{queries, query, queryAs, queryInt} +import zio.http.codec.HttpCodec.{queries, query, queryAs, queryOne, queryOneOrMore, queryOpt} import zio.http.codec._ import zio.http.endpoint.EndpointSpec.testEndpoint import zio.http.forms.Fixtures.formField @@ -105,7 +105,7 @@ object QueryParameterSpec extends ZIOHttpSpec { testRoutes(s"/users/$userId?key=$key&value=$value", s"path(users, $userId, Some($key), Some($value))") } }, - test("query parameter with multiple values") { + test("query parameter with any number of values") { check(Gen.boolean, Gen.alphaNumericString, Gen.alphaNumericString) { (isSomething, name1, name2) => val testRoutes = testEndpoint( Routes( @@ -125,5 +125,39 @@ object QueryParameterSpec extends ZIOHttpSpec { testRoutes(s"/data?isSomething=$isSomething&name=$name1&name=$name2", s"query($isSomething, $name1, $name2)") } }, + test("query parameter with one or more values") { + check(Gen.boolean, Gen.alphaNumericString, Gen.alphaNumericString) { (isSomething, name1, name2) => + val testRoutes = testEndpoint( + Routes( + Endpoint(GET / "data") + .query(queryOne[Boolean]("isSomething")) + .query(queryOneOrMore[String]("name")) + .out[String] + .implement { + Handler.fromFunction { case (isSomething, names) => + s"query($isSomething, ${names mkString ", "})" + } + }, + ), + ) _ + testRoutes(s"/data?isSomething=$isSomething&name=$name1", s"query($isSomething, $name1)") && + testRoutes(s"/data?isSomething=$isSomething&name=$name1&name=$name2", s"query($isSomething, $name1, $name2)") + } + }, + test("query parameter with optional value") { + check(Gen.alphaNumericString) { (name) => + val testRoutes = testEndpoint( + Routes( + Endpoint(GET / "data") + .query(queryOpt[String]("name")) + .out[String] + .implement { + Handler.fromFunction { name => s"query($name)" } + }, + ), + ) _ + testRoutes(s"/data", s"query(None)") && testRoutes(s"/data?name=$name", s"query(Some($name))") + } + }, ) }