Skip to content

Commit

Permalink
Make Scheme follow the specs (zio#2490)
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioPinheiro authored Nov 24, 2023
1 parent 89b8977 commit 33512bc
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 76 deletions.
66 changes: 44 additions & 22 deletions zio-http/src/main/scala/zio/http/Scheme.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,52 +21,63 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace

sealed trait Scheme { self =>
def encode: String = self match {
case Scheme.HTTP => "http"
case Scheme.HTTPS => "https"
case Scheme.WS => "ws"
case Scheme.WSS => "wss"
case Scheme.HTTP => "http"
case Scheme.HTTPS => "https"
case Scheme.WS => "ws"
case Scheme.WSS => "wss"
case Scheme.Custom(scheme) => scheme
}

def isHttp: Boolean = !isWebSocket
def isHttp: Boolean = self match {
case Scheme.HTTP | Scheme.HTTPS => true
case _ => false
}

def isWebSocket: Boolean = self match {
case Scheme.WS => true
case Scheme.WSS => true
case _ => false
}

def isSecure: Boolean = self match {
case Scheme.HTTPS => true
case Scheme.WSS => true
case _ => false
def isSecure: Option[Boolean] = self match {
case Scheme.HTTPS | Scheme.WSS => Some(true)
case Scheme.HTTP | Scheme.WS => Some(false)
case _ => None
}

def defaultPort: Int = self match {
case Scheme.HTTP => 80
case Scheme.HTTPS => 443
case Scheme.WS => 80
case Scheme.WSS => 443
/** default ports is only define for the Schemes: http, https, ws, wss */
def defaultPort: Option[Int] = self match {
case Scheme.HTTP => Some(Scheme.defaultPortForHTTP)
case Scheme.HTTPS => Some(Scheme.defaultPortForHTTPS)
case Scheme.WS => Some(Scheme.defaultPortForWS)
case Scheme.WSS => Some(Scheme.defaultPortForWSS)
case Scheme.Custom(_) => None
}

}
object Scheme {

object Scheme {

/**
* Decodes a string to an Option of Scheme. Returns None in case of
* null/non-valid Scheme
*
* The should be lowercase and follow this syntax:
* - Scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
*/
def decode(scheme: String): Option[Scheme] =
Option(unsafe.decode(scheme)(Unsafe.unsafe))

private[zio] object unsafe {
def decode(scheme: String)(implicit unsafe: Unsafe): Scheme = {
if (scheme == null) null
if (scheme == null || scheme.isEmpty) null
else
scheme.length match {
case 5 => Scheme.HTTPS
case 4 => Scheme.HTTP
case 3 => Scheme.WSS
case 2 => Scheme.WS
case _ => null
scheme match {
case "http" => HTTP
case "https" => HTTPS
case "ws" => WS
case "wss" => WSS
case custom => Custom(custom.toLowerCase)
}
}
}
Expand All @@ -78,4 +89,15 @@ object Scheme {
case object WS extends Scheme

case object WSS extends Scheme

/**
* @param scheme
* value MUST not be "http" "https" "ws" "wss"
*/
final case class Custom(scheme: String) extends Scheme

def defaultPortForHTTP = 80
def defaultPortForHTTPS = 443
def defaultPortForWS = 80
def defaultPortForWSS = 443
}
69 changes: 37 additions & 32 deletions zio-http/src/main/scala/zio/http/URL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.util.Try

import zio.Chunk

import zio.http.URL.{Fragment, Location, portFromScheme}
import zio.http.URL.{Fragment, Location}
import zio.http.internal.QueryParamEncoding

final case class URL(
Expand All @@ -48,10 +48,10 @@ final case class URL(
def /(segment: String): URL = self.copy(path = self.path / segment)

def absolute(host: String): URL =
self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP)))
self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, None))

def absolute(scheme: Scheme, host: String, port: Int): URL =
self.copy(kind = URL.Location.Absolute(scheme, host, port))
self.copy(kind = URL.Location.Absolute(scheme, host, Some(port)))

def addLeadingSlash: URL = self.copy(path = path.addLeadingSlash)

Expand Down Expand Up @@ -101,20 +101,25 @@ final case class URL(

def host(host: String): URL = {
val location = kind match {
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP))
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, None)
case abs: URL.Location.Absolute => abs.copy(host = host)
}
copy(kind = location)
}

/**
* @return
* the location, the host name and the port. The port part is omitted if is
* the default port for the protocol.
*/
def hostPort: Option[String] =
kind match {
case URL.Location.Relative => None
case URL.Location.Absolute(scheme, host, port) =>
Some(
if (port == portFromScheme(scheme)) host
else s"$host:$port",
)
case URL.Location.Relative => None
case abs: URL.Location.Absolute =>
abs.portIfNotDefault match {
case None => Some(abs.host)
case Some(customPort) => Some(s"${abs.host}:$customPort")
}
}

def isAbsolute: Boolean = self.kind match {
Expand All @@ -140,25 +145,26 @@ final case class URL(

def port(port: Int): URL = {
val location = kind match {
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", port)
case abs: URL.Location.Absolute => abs.copy(port = port)
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", Some(port))
case abs: URL.Location.Absolute => abs.copy(originalPort = Some(port))
}

copy(kind = location)
}

def port: Option[Int] = kind match {
case URL.Location.Relative => None
case abs: URL.Location.Absolute => Option(abs.port)
case abs: URL.Location.Absolute => abs.originalPort
}

def portOrDefault: Int = port.getOrElse(portFromScheme(scheme.getOrElse(Scheme.HTTP)))
def portOrDefault: Option[Int] = kind match {
case URL.Location.Relative => None
case abs: URL.Location.Absolute => abs.portOrDefault
}

def portIfNotDefault: Option[Int] = kind match {
case URL.Location.Relative =>
None
case abs: URL.Location.Absolute =>
if (abs.port == portFromScheme(abs.scheme)) None else Some(abs.port)
case URL.Location.Relative => None
case abs: URL.Location.Absolute => abs.portIfNotDefault
}

def queryParams(queryParams: QueryParams): URL =
Expand All @@ -185,7 +191,7 @@ final case class URL(

def scheme(scheme: Scheme): URL = {
val location = kind match {
case URL.Location.Relative => URL.Location.Absolute(scheme, "", URL.portFromScheme(scheme))
case URL.Location.Relative => URL.Location.Absolute(scheme, "", None)
case abs: URL.Location.Absolute => abs.copy(scheme = scheme)
}

Expand Down Expand Up @@ -239,7 +245,11 @@ object URL {
}

object Location {
final case class Absolute(scheme: Scheme, host: String, port: Int) extends Location
final case class Absolute(scheme: Scheme, host: String, originalPort: Option[Int]) extends Location {
def portOrDefault: Option[Int] = originalPort.orElse(scheme.defaultPort)
def portIfNotDefault: Option[Int] = originalPort.filter(p => scheme.defaultPort.exists(_ != p))
def port: Int = originalPort.orElse(scheme.defaultPort).getOrElse(Scheme.defaultPortForHTTP)
}

case object Relative extends Location
}
Expand All @@ -262,13 +272,13 @@ object URL {
) + url.fragment.fold("")(f => "#" + f.raw)

url.kind match {
case Location.Relative =>
path(true)
case Location.Absolute(scheme, host, port) =>
case Location.Relative => path(true)
case abs: Location.Absolute =>
val path2 = path(false)

if (port == portFromScheme(scheme)) s"${scheme.encode}://$host$path2"
else s"${scheme.encode}://$host:$port$path2"
abs.portIfNotDefault match {
case None => s"${abs.scheme.encode}://${abs.host}$path2"
case Some(customPort) => s"${abs.scheme.encode}://${abs.host}:$customPort$path2"
}
}
}

Expand All @@ -277,7 +287,7 @@ object URL {
scheme <- Scheme.decode(uri.getScheme)
host <- Option(uri.getHost)
path <- Option(uri.getRawPath)
port = Option(uri.getPort).filter(_ != -1).getOrElse(portFromScheme(scheme))
port = Option(uri.getPort).filter(_ != -1).orElse(scheme.defaultPort) // FIXME REMOVE defaultPort
connection = URL.Location.Absolute(scheme, host, port)
path2 = Path.decode(path)
path3 = if (path.nonEmpty) path2.addLeadingSlash else path2
Expand All @@ -288,9 +298,4 @@ object URL {
path <- Option(uri.getRawPath)
} yield URL(Path.decode(path), Location.Relative, QueryParams.decode(uri.getRawQuery), Fragment.fromURI(uri))

private def portFromScheme(scheme: Scheme): Int = scheme match {
case Scheme.HTTP | Scheme.WS => 80
case Scheme.HTTPS | Scheme.WSS => 443
}

}
20 changes: 8 additions & 12 deletions zio-http/src/main/scala/zio/http/ZClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -668,18 +668,14 @@ object ZClient {
app: WebSocketApp[Env1],
)(implicit trace: Trace): ZIO[Env1 & Scope, Throwable, Response] =
for {
env <- ZIO.environment[Env1]
webSocketUrl = url.scheme(
url.scheme match {
case Some(Scheme.HTTP) => Scheme.WS
case Some(Scheme.HTTPS) => Scheme.WSS
case Some(Scheme.WS) => Scheme.WS
case Some(Scheme.WSS) => Scheme.WSS
case None => Scheme.WS
},
)
scope <- ZIO.scope
res <- requestAsync(
env <- ZIO.environment[Env1]
webSocketUrl <- url.scheme match {
case Some(Scheme.HTTP) | Some(Scheme.WS) | None => ZIO.succeed(url.scheme(Scheme.WS))
case Some(Scheme.WSS) | Some(Scheme.HTTPS) => ZIO.succeed(url.scheme(Scheme.WSS))
case _ => ZIO.fail(throw new IllegalArgumentException("URL's scheme MUST be WS(S) or HTTP(S)"))
}
scope <- ZIO.scope
res <- requestAsync(
Request(version = version, method = Method.GET, url = webSocketUrl, headers = headers),
config,
() => app.provideEnvironment(env),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object NettyConnectionPool {
case None =>
}

if (location.scheme.isSecure) {
if (location.scheme.isSecure.getOrElse(false)) {
pipeline.addLast(
Names.SSLHandler,
ClientSSLConverter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object ResponseCompressionSpec extends ZIOHttpSpec {
response <- client.request(
Request(
method = Method.GET,
url = URL(Root / "text", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", server.port)),
url = URL(Root / "text", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", Some(server.port))),
)
.addHeader(Header.AcceptEncoding(Header.AcceptEncoding.GZip(), Header.AcceptEncoding.Deflate())),
)
Expand All @@ -82,7 +82,7 @@ object ResponseCompressionSpec extends ZIOHttpSpec {
response <- client.request(
Request(
method = Method.GET,
url = URL(Root / "stream", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", server.port)),
url = URL(Root / "stream", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", Some(server.port))),
)
.addHeader(Header.AcceptEncoding(Header.AcceptEncoding.GZip(), Header.AcceptEncoding.Deflate())),
)
Expand Down
3 changes: 3 additions & 0 deletions zio-http/src/test/scala/zio/http/SchemeSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,8 @@ object SchemeSpec extends ZIOHttpSpec {
test("null string decode") {
assert(Scheme.decode(null))(isNone)
},
test("decode chrome-extension") {
assertTrue(Scheme.decode("chrome-extension").isDefined)
},
)
}
4 changes: 2 additions & 2 deletions zio-http/src/test/scala/zio/http/URLSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ object URLSpec extends ZIOHttpSpec {
),
suite("normalize")(
test("adds leading slash") {
val url = URL(Path("a/b/c"), URL.Location.Absolute(Scheme.HTTP, "abc.com", 80), QueryParams.empty, None)
val url = URL(Path("a/b/c"), URL.Location.Absolute(Scheme.HTTP, "abc.com", Some(80)), QueryParams.empty, None)

val url2 = url.normalize

assertTrue(extractPath(url2) == Path("/a/b/c"))
},
test("deletes leading slash if there are no path segments") {
val url = URL(Path.root, URL.Location.Absolute(Scheme.HTTP, "abc.com", 80), QueryParams.empty, None)
val url = URL(Path.root, URL.Location.Absolute(Scheme.HTTP, "abc.com", Some(80)), QueryParams.empty, None)
val url2 = url.normalize

assertTrue(extractPath(url2) == Path.empty)
Expand Down
4 changes: 2 additions & 2 deletions zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object ZClientAspectSpec extends ZIOHttpSpec {
port <- Server.install(app)
baseClient <- ZIO.service[Client]
client = baseClient.url(
URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", port)),
URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", Some(port))),
) @@ ZClientAspect.debug
response <- client.request(Request.get(URL.empty / "hello"))
output <- TestConsole.output
Expand All @@ -51,7 +51,7 @@ object ZClientAspectSpec extends ZIOHttpSpec {
baseClient <- ZIO.service[Client]
client = baseClient
.url(
URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", port)),
URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", Some(port))),
)
.disableStreaming @@ ZClientAspect.requestLogging(
loggedRequestHeaders = Set(Header.UserAgent),
Expand Down
1 change: 1 addition & 0 deletions zio-http/src/test/scala/zio/http/headers/OriginSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ object OriginSpec extends ZIOHttpSpec {
assertTrue(
Origin.parse("http://domain") == Right(Value("http", "domain", None)),
Origin.parse("https://domain") == Right(Value("https", "domain", None)),
Origin.parse("chrome-extension://appid") == Right(Value("chrome-extension", "appid", None)),
)
},
test("parsing of valid Origin values") {
Expand Down
2 changes: 1 addition & 1 deletion zio-http/src/test/scala/zio/http/internal/HttpGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object HttpGen {
scheme <- Gen.fromIterable(List(Scheme.HTTP, Scheme.HTTPS))
host <- Gen.alphaNumericStringBounded(1, 5)
port <- Gen.oneOf(Gen.const(80), Gen.const(443), Gen.int(0, 65536))
} yield URL.Location.Absolute(scheme, host, port)
} yield URL.Location.Absolute(scheme, host, Some(port))

def genRelativeURL: Gen[Any, URL] = for {
path <- HttpGen.anyPath
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ abstract class HttpRunnableSpec extends ZIOHttpSpec { self =>
client(
params
.addHeader(DynamicServer.APP_ID, id)
.copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", port))),
.copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", Some(port)))),
)
.flatMap(_.collect)
}
Expand Down Expand Up @@ -80,7 +80,7 @@ abstract class HttpRunnableSpec extends ZIOHttpSpec { self =>
client(
params
.addHeader(DynamicServer.APP_ID, id)
.copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", port))),
.copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", Some(port)))),
)
}
} yield response
Expand Down

0 comments on commit 33512bc

Please sign in to comment.