diff --git a/zio-http/src/main/scala/zio/http/Scheme.scala b/zio-http/src/main/scala/zio/http/Scheme.scala index bbb39803ca..0336b6eb62 100644 --- a/zio-http/src/main/scala/zio/http/Scheme.scala +++ b/zio-http/src/main/scala/zio/http/Scheme.scala @@ -21,13 +21,17 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace sealed trait Scheme { self => def encode: String = self match { - case Scheme.HTTP => "http" - case Scheme.HTTPS => "https" - case Scheme.WS => "ws" - case Scheme.WSS => "wss" + case Scheme.HTTP => "http" + case Scheme.HTTPS => "https" + case Scheme.WS => "ws" + case Scheme.WSS => "wss" + case Scheme.Custom(scheme) => scheme } - def isHttp: Boolean = !isWebSocket + def isHttp: Boolean = self match { + case Scheme.HTTP | Scheme.HTTPS => true + case _ => false + } def isWebSocket: Boolean = self match { case Scheme.WS => true @@ -35,38 +39,45 @@ sealed trait Scheme { self => case _ => false } - def isSecure: Boolean = self match { - case Scheme.HTTPS => true - case Scheme.WSS => true - case _ => false + def isSecure: Option[Boolean] = self match { + case Scheme.HTTPS | Scheme.WSS => Some(true) + case Scheme.HTTP | Scheme.WS => Some(false) + case _ => None } - def defaultPort: Int = self match { - case Scheme.HTTP => 80 - case Scheme.HTTPS => 443 - case Scheme.WS => 80 - case Scheme.WSS => 443 + /** default ports is only define for the Schemes: http, https, ws, wss */ + def defaultPort: Option[Int] = self match { + case Scheme.HTTP => Some(Scheme.defaultPortForHTTP) + case Scheme.HTTPS => Some(Scheme.defaultPortForHTTPS) + case Scheme.WS => Some(Scheme.defaultPortForWS) + case Scheme.WSS => Some(Scheme.defaultPortForWSS) + case Scheme.Custom(_) => None } + } -object Scheme { + +object Scheme { /** * Decodes a string to an Option of Scheme. Returns None in case of * null/non-valid Scheme + * + * The should be lowercase and follow this syntax: + * - Scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) */ def decode(scheme: String): Option[Scheme] = Option(unsafe.decode(scheme)(Unsafe.unsafe)) private[zio] object unsafe { def decode(scheme: String)(implicit unsafe: Unsafe): Scheme = { - if (scheme == null) null + if (scheme == null || scheme.isEmpty) null else - scheme.length match { - case 5 => Scheme.HTTPS - case 4 => Scheme.HTTP - case 3 => Scheme.WSS - case 2 => Scheme.WS - case _ => null + scheme match { + case "http" => HTTP + case "https" => HTTPS + case "ws" => WS + case "wss" => WSS + case custom => Custom(custom.toLowerCase) } } } @@ -78,4 +89,15 @@ object Scheme { case object WS extends Scheme case object WSS extends Scheme + + /** + * @param scheme + * value MUST not be "http" "https" "ws" "wss" + */ + final case class Custom(scheme: String) extends Scheme + + def defaultPortForHTTP = 80 + def defaultPortForHTTPS = 443 + def defaultPortForWS = 80 + def defaultPortForWSS = 443 } diff --git a/zio-http/src/main/scala/zio/http/URL.scala b/zio-http/src/main/scala/zio/http/URL.scala index fe311dfe7a..6ae6507487 100644 --- a/zio-http/src/main/scala/zio/http/URL.scala +++ b/zio-http/src/main/scala/zio/http/URL.scala @@ -22,7 +22,7 @@ import scala.util.Try import zio.Chunk -import zio.http.URL.{Fragment, Location, portFromScheme} +import zio.http.URL.{Fragment, Location} import zio.http.internal.QueryParamEncoding final case class URL( @@ -48,10 +48,10 @@ final case class URL( def /(segment: String): URL = self.copy(path = self.path / segment) def absolute(host: String): URL = - self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP))) + self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, None)) def absolute(scheme: Scheme, host: String, port: Int): URL = - self.copy(kind = URL.Location.Absolute(scheme, host, port)) + self.copy(kind = URL.Location.Absolute(scheme, host, Some(port))) def addLeadingSlash: URL = self.copy(path = path.addLeadingSlash) @@ -101,20 +101,25 @@ final case class URL( def host(host: String): URL = { val location = kind match { - case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP)) + case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, None) case abs: URL.Location.Absolute => abs.copy(host = host) } copy(kind = location) } + /** + * @return + * the location, the host name and the port. The port part is omitted if is + * the default port for the protocol. + */ def hostPort: Option[String] = kind match { - case URL.Location.Relative => None - case URL.Location.Absolute(scheme, host, port) => - Some( - if (port == portFromScheme(scheme)) host - else s"$host:$port", - ) + case URL.Location.Relative => None + case abs: URL.Location.Absolute => + abs.portIfNotDefault match { + case None => Some(abs.host) + case Some(customPort) => Some(s"${abs.host}:$customPort") + } } def isAbsolute: Boolean = self.kind match { @@ -140,8 +145,8 @@ final case class URL( def port(port: Int): URL = { val location = kind match { - case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", port) - case abs: URL.Location.Absolute => abs.copy(port = port) + case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", Some(port)) + case abs: URL.Location.Absolute => abs.copy(originalPort = Some(port)) } copy(kind = location) @@ -149,16 +154,17 @@ final case class URL( def port: Option[Int] = kind match { case URL.Location.Relative => None - case abs: URL.Location.Absolute => Option(abs.port) + case abs: URL.Location.Absolute => abs.originalPort } - def portOrDefault: Int = port.getOrElse(portFromScheme(scheme.getOrElse(Scheme.HTTP))) + def portOrDefault: Option[Int] = kind match { + case URL.Location.Relative => None + case abs: URL.Location.Absolute => abs.portOrDefault + } def portIfNotDefault: Option[Int] = kind match { - case URL.Location.Relative => - None - case abs: URL.Location.Absolute => - if (abs.port == portFromScheme(abs.scheme)) None else Some(abs.port) + case URL.Location.Relative => None + case abs: URL.Location.Absolute => abs.portIfNotDefault } def queryParams(queryParams: QueryParams): URL = @@ -185,7 +191,7 @@ final case class URL( def scheme(scheme: Scheme): URL = { val location = kind match { - case URL.Location.Relative => URL.Location.Absolute(scheme, "", URL.portFromScheme(scheme)) + case URL.Location.Relative => URL.Location.Absolute(scheme, "", None) case abs: URL.Location.Absolute => abs.copy(scheme = scheme) } @@ -239,7 +245,11 @@ object URL { } object Location { - final case class Absolute(scheme: Scheme, host: String, port: Int) extends Location + final case class Absolute(scheme: Scheme, host: String, originalPort: Option[Int]) extends Location { + def portOrDefault: Option[Int] = originalPort.orElse(scheme.defaultPort) + def portIfNotDefault: Option[Int] = originalPort.filter(p => scheme.defaultPort.exists(_ != p)) + def port: Int = originalPort.orElse(scheme.defaultPort).getOrElse(Scheme.defaultPortForHTTP) + } case object Relative extends Location } @@ -262,13 +272,13 @@ object URL { ) + url.fragment.fold("")(f => "#" + f.raw) url.kind match { - case Location.Relative => - path(true) - case Location.Absolute(scheme, host, port) => + case Location.Relative => path(true) + case abs: Location.Absolute => val path2 = path(false) - - if (port == portFromScheme(scheme)) s"${scheme.encode}://$host$path2" - else s"${scheme.encode}://$host:$port$path2" + abs.portIfNotDefault match { + case None => s"${abs.scheme.encode}://${abs.host}$path2" + case Some(customPort) => s"${abs.scheme.encode}://${abs.host}:$customPort$path2" + } } } @@ -277,7 +287,7 @@ object URL { scheme <- Scheme.decode(uri.getScheme) host <- Option(uri.getHost) path <- Option(uri.getRawPath) - port = Option(uri.getPort).filter(_ != -1).getOrElse(portFromScheme(scheme)) + port = Option(uri.getPort).filter(_ != -1).orElse(scheme.defaultPort) // FIXME REMOVE defaultPort connection = URL.Location.Absolute(scheme, host, port) path2 = Path.decode(path) path3 = if (path.nonEmpty) path2.addLeadingSlash else path2 @@ -288,9 +298,4 @@ object URL { path <- Option(uri.getRawPath) } yield URL(Path.decode(path), Location.Relative, QueryParams.decode(uri.getRawQuery), Fragment.fromURI(uri)) - private def portFromScheme(scheme: Scheme): Int = scheme match { - case Scheme.HTTP | Scheme.WS => 80 - case Scheme.HTTPS | Scheme.WSS => 443 - } - } diff --git a/zio-http/src/main/scala/zio/http/ZClient.scala b/zio-http/src/main/scala/zio/http/ZClient.scala index 477ee85cc9..5563b31d2c 100644 --- a/zio-http/src/main/scala/zio/http/ZClient.scala +++ b/zio-http/src/main/scala/zio/http/ZClient.scala @@ -668,18 +668,14 @@ object ZClient { app: WebSocketApp[Env1], )(implicit trace: Trace): ZIO[Env1 & Scope, Throwable, Response] = for { - env <- ZIO.environment[Env1] - webSocketUrl = url.scheme( - url.scheme match { - case Some(Scheme.HTTP) => Scheme.WS - case Some(Scheme.HTTPS) => Scheme.WSS - case Some(Scheme.WS) => Scheme.WS - case Some(Scheme.WSS) => Scheme.WSS - case None => Scheme.WS - }, - ) - scope <- ZIO.scope - res <- requestAsync( + env <- ZIO.environment[Env1] + webSocketUrl <- url.scheme match { + case Some(Scheme.HTTP) | Some(Scheme.WS) | None => ZIO.succeed(url.scheme(Scheme.WS)) + case Some(Scheme.WSS) | Some(Scheme.HTTPS) => ZIO.succeed(url.scheme(Scheme.WSS)) + case _ => ZIO.fail(throw new IllegalArgumentException("URL's scheme MUST be WS(S) or HTTP(S)")) + } + scope <- ZIO.scope + res <- requestAsync( Request(version = version, method = Method.GET, url = webSocketUrl, headers = headers), config, () => app.provideEnvironment(env), diff --git a/zio-http/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala b/zio-http/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala index 222a6b8b54..1834142af7 100644 --- a/zio-http/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala +++ b/zio-http/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala @@ -70,7 +70,7 @@ object NettyConnectionPool { case None => } - if (location.scheme.isSecure) { + if (location.scheme.isSecure.getOrElse(false)) { pipeline.addLast( Names.SSLHandler, ClientSSLConverter diff --git a/zio-http/src/test/scala/zio/http/ResponseCompressionSpec.scala b/zio-http/src/test/scala/zio/http/ResponseCompressionSpec.scala index 511b6e50e3..5868d7a381 100644 --- a/zio-http/src/test/scala/zio/http/ResponseCompressionSpec.scala +++ b/zio-http/src/test/scala/zio/http/ResponseCompressionSpec.scala @@ -66,7 +66,7 @@ object ResponseCompressionSpec extends ZIOHttpSpec { response <- client.request( Request( method = Method.GET, - url = URL(Root / "text", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", server.port)), + url = URL(Root / "text", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", Some(server.port))), ) .addHeader(Header.AcceptEncoding(Header.AcceptEncoding.GZip(), Header.AcceptEncoding.Deflate())), ) @@ -82,7 +82,7 @@ object ResponseCompressionSpec extends ZIOHttpSpec { response <- client.request( Request( method = Method.GET, - url = URL(Root / "stream", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", server.port)), + url = URL(Root / "stream", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", Some(server.port))), ) .addHeader(Header.AcceptEncoding(Header.AcceptEncoding.GZip(), Header.AcceptEncoding.Deflate())), ) diff --git a/zio-http/src/test/scala/zio/http/SchemeSpec.scala b/zio-http/src/test/scala/zio/http/SchemeSpec.scala index 46e9026126..dcbaaaddb2 100644 --- a/zio-http/src/test/scala/zio/http/SchemeSpec.scala +++ b/zio-http/src/test/scala/zio/http/SchemeSpec.scala @@ -31,5 +31,8 @@ object SchemeSpec extends ZIOHttpSpec { test("null string decode") { assert(Scheme.decode(null))(isNone) }, + test("decode chrome-extension") { + assertTrue(Scheme.decode("chrome-extension").isDefined) + }, ) } diff --git a/zio-http/src/test/scala/zio/http/URLSpec.scala b/zio-http/src/test/scala/zio/http/URLSpec.scala index 0d0ab387b8..e6521f90be 100644 --- a/zio-http/src/test/scala/zio/http/URLSpec.scala +++ b/zio-http/src/test/scala/zio/http/URLSpec.scala @@ -49,14 +49,14 @@ object URLSpec extends ZIOHttpSpec { ), suite("normalize")( test("adds leading slash") { - val url = URL(Path("a/b/c"), URL.Location.Absolute(Scheme.HTTP, "abc.com", 80), QueryParams.empty, None) + val url = URL(Path("a/b/c"), URL.Location.Absolute(Scheme.HTTP, "abc.com", Some(80)), QueryParams.empty, None) val url2 = url.normalize assertTrue(extractPath(url2) == Path("/a/b/c")) }, test("deletes leading slash if there are no path segments") { - val url = URL(Path.root, URL.Location.Absolute(Scheme.HTTP, "abc.com", 80), QueryParams.empty, None) + val url = URL(Path.root, URL.Location.Absolute(Scheme.HTTP, "abc.com", Some(80)), QueryParams.empty, None) val url2 = url.normalize assertTrue(extractPath(url2) == Path.empty) diff --git a/zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala b/zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala index ac453659d1..beaa59edfa 100644 --- a/zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala +++ b/zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala @@ -34,7 +34,7 @@ object ZClientAspectSpec extends ZIOHttpSpec { port <- Server.install(app) baseClient <- ZIO.service[Client] client = baseClient.url( - URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", port)), + URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", Some(port))), ) @@ ZClientAspect.debug response <- client.request(Request.get(URL.empty / "hello")) output <- TestConsole.output @@ -51,7 +51,7 @@ object ZClientAspectSpec extends ZIOHttpSpec { baseClient <- ZIO.service[Client] client = baseClient .url( - URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", port)), + URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", Some(port))), ) .disableStreaming @@ ZClientAspect.requestLogging( loggedRequestHeaders = Set(Header.UserAgent), diff --git a/zio-http/src/test/scala/zio/http/headers/OriginSpec.scala b/zio-http/src/test/scala/zio/http/headers/OriginSpec.scala index 6727dd6b2c..935aff670b 100644 --- a/zio-http/src/test/scala/zio/http/headers/OriginSpec.scala +++ b/zio-http/src/test/scala/zio/http/headers/OriginSpec.scala @@ -43,6 +43,7 @@ object OriginSpec extends ZIOHttpSpec { assertTrue( Origin.parse("http://domain") == Right(Value("http", "domain", None)), Origin.parse("https://domain") == Right(Value("https", "domain", None)), + Origin.parse("chrome-extension://appid") == Right(Value("chrome-extension", "appid", None)), ) }, test("parsing of valid Origin values") { diff --git a/zio-http/src/test/scala/zio/http/internal/HttpGen.scala b/zio-http/src/test/scala/zio/http/internal/HttpGen.scala index 44079d075b..1ba2149702 100644 --- a/zio-http/src/test/scala/zio/http/internal/HttpGen.scala +++ b/zio-http/src/test/scala/zio/http/internal/HttpGen.scala @@ -70,7 +70,7 @@ object HttpGen { scheme <- Gen.fromIterable(List(Scheme.HTTP, Scheme.HTTPS)) host <- Gen.alphaNumericStringBounded(1, 5) port <- Gen.oneOf(Gen.const(80), Gen.const(443), Gen.int(0, 65536)) - } yield URL.Location.Absolute(scheme, host, port) + } yield URL.Location.Absolute(scheme, host, Some(port)) def genRelativeURL: Gen[Any, URL] = for { path <- HttpGen.anyPath diff --git a/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala b/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala index 817bc8e395..dc80560107 100644 --- a/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala +++ b/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala @@ -49,7 +49,7 @@ abstract class HttpRunnableSpec extends ZIOHttpSpec { self => client( params .addHeader(DynamicServer.APP_ID, id) - .copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", port))), + .copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", Some(port)))), ) .flatMap(_.collect) } @@ -80,7 +80,7 @@ abstract class HttpRunnableSpec extends ZIOHttpSpec { self => client( params .addHeader(DynamicServer.APP_ID, id) - .copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", port))), + .copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", Some(port)))), ) } } yield response