Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make type Scheme follow the specs #2490

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions zio-http/src/main/scala/zio/http/Scheme.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,52 +21,63 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace

sealed trait Scheme { self =>
def encode: String = self match {
case Scheme.HTTP => "http"
case Scheme.HTTPS => "https"
case Scheme.WS => "ws"
case Scheme.WSS => "wss"
case Scheme.HTTP => "http"
case Scheme.HTTPS => "https"
case Scheme.WS => "ws"
case Scheme.WSS => "wss"
case Scheme.Custom(scheme) => scheme
}

def isHttp: Boolean = !isWebSocket
def isHttp: Boolean = self match {
case Scheme.HTTP | Scheme.HTTPS => true
case _ => false
}

def isWebSocket: Boolean = self match {
case Scheme.WS => true
case Scheme.WSS => true
case _ => false
}

def isSecure: Boolean = self match {
case Scheme.HTTPS => true
case Scheme.WSS => true
case _ => false
def isSecure: Option[Boolean] = self match {
case Scheme.HTTPS | Scheme.WSS => Some(true)
case Scheme.HTTP | Scheme.WS => Some(false)
case _ => None
}

def defaultPort: Int = self match {
case Scheme.HTTP => 80
case Scheme.HTTPS => 443
case Scheme.WS => 80
case Scheme.WSS => 443
/** default ports is only define for the Schemes: http, https, ws, wss */
def defaultPort: Option[Int] = self match {
case Scheme.HTTP => Some(Scheme.defaultPortForHTTP)
case Scheme.HTTPS => Some(Scheme.defaultPortForHTTPS)
case Scheme.WS => Some(Scheme.defaultPortForWS)
case Scheme.WSS => Some(Scheme.defaultPortForWSS)
case Scheme.Custom(_) => None
}

}
object Scheme {

object Scheme {

/**
* Decodes a string to an Option of Scheme. Returns None in case of
* null/non-valid Scheme
*
* The should be lowercase and follow this syntax:
* - Scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
*/
def decode(scheme: String): Option[Scheme] =
Option(unsafe.decode(scheme)(Unsafe.unsafe))

private[zio] object unsafe {
def decode(scheme: String)(implicit unsafe: Unsafe): Scheme = {
if (scheme == null) null
if (scheme == null || scheme.isEmpty) null
else
scheme.length match {
case 5 => Scheme.HTTPS
case 4 => Scheme.HTTP
case 3 => Scheme.WSS
case 2 => Scheme.WS
case _ => null
scheme match {
case "http" => HTTP
case "https" => HTTPS
case "ws" => WS
case "wss" => WSS
case custom => Custom(custom.toLowerCase)
}
}
}
Expand All @@ -78,4 +89,15 @@ object Scheme {
case object WS extends Scheme

case object WSS extends Scheme

/**
* @param scheme
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would actually happen, if this is http(s)/ws(s)?

Copy link
Contributor Author

@FabioPinheiro FabioPinheiro Nov 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a question of pattern matching.

We are allowing two representations for the same value. (Scheme.HTTP and Scheme.Custom("http"))
Since the construction method is public.
But I see this pattern is being used everywhere in zio-http so it's a more general problem

One solution would be to make it private and the only way to create Custom is with the Scheme's method def decode(scheme: String).
This would make it a bit more type-safe. Instead of being on the documentation

* value MUST not be "http" "https" "ws" "wss"
*/
final case class Custom(scheme: String) extends Scheme

def defaultPortForHTTP = 80
def defaultPortForHTTPS = 443
def defaultPortForWS = 80
def defaultPortForWSS = 443
}
69 changes: 37 additions & 32 deletions zio-http/src/main/scala/zio/http/URL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.util.Try

import zio.Chunk

import zio.http.URL.{Fragment, Location, portFromScheme}
import zio.http.URL.{Fragment, Location}
import zio.http.internal.QueryParamEncoding

final case class URL(
Expand All @@ -48,10 +48,10 @@ final case class URL(
def /(segment: String): URL = self.copy(path = self.path / segment)

def absolute(host: String): URL =
self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP)))
self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, None))

def absolute(scheme: Scheme, host: String, port: Int): URL =
self.copy(kind = URL.Location.Absolute(scheme, host, port))
self.copy(kind = URL.Location.Absolute(scheme, host, Some(port)))

def addLeadingSlash: URL = self.copy(path = path.addLeadingSlash)

Expand Down Expand Up @@ -101,20 +101,25 @@ final case class URL(

def host(host: String): URL = {
val location = kind match {
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP))
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, None)
case abs: URL.Location.Absolute => abs.copy(host = host)
}
copy(kind = location)
}

/**
* @return
* the location, the host name and the port. The port part is omitted if is
* the default port for the protocol.
*/
def hostPort: Option[String] =
kind match {
case URL.Location.Relative => None
case URL.Location.Absolute(scheme, host, port) =>
Some(
if (port == portFromScheme(scheme)) host
else s"$host:$port",
)
case URL.Location.Relative => None
case abs: URL.Location.Absolute =>
abs.portIfNotDefault match {
case None => Some(abs.host)
case Some(customPort) => Some(s"${abs.host}:$customPort")
}
}

def isAbsolute: Boolean = self.kind match {
Expand All @@ -140,25 +145,26 @@ final case class URL(

def port(port: Int): URL = {
val location = kind match {
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", port)
case abs: URL.Location.Absolute => abs.copy(port = port)
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", Some(port))
case abs: URL.Location.Absolute => abs.copy(originalPort = Some(port))
}

copy(kind = location)
}

def port: Option[Int] = kind match {
case URL.Location.Relative => None
case abs: URL.Location.Absolute => Option(abs.port)
case abs: URL.Location.Absolute => abs.originalPort
}

def portOrDefault: Int = port.getOrElse(portFromScheme(scheme.getOrElse(Scheme.HTTP)))
def portOrDefault: Option[Int] = kind match {
FabioPinheiro marked this conversation as resolved.
Show resolved Hide resolved
case URL.Location.Relative => None
case abs: URL.Location.Absolute => abs.portOrDefault
}

def portIfNotDefault: Option[Int] = kind match {
case URL.Location.Relative =>
None
case abs: URL.Location.Absolute =>
if (abs.port == portFromScheme(abs.scheme)) None else Some(abs.port)
case URL.Location.Relative => None
case abs: URL.Location.Absolute => abs.portIfNotDefault
}

def queryParams(queryParams: QueryParams): URL =
Expand All @@ -185,7 +191,7 @@ final case class URL(

def scheme(scheme: Scheme): URL = {
val location = kind match {
case URL.Location.Relative => URL.Location.Absolute(scheme, "", URL.portFromScheme(scheme))
case URL.Location.Relative => URL.Location.Absolute(scheme, "", None)
case abs: URL.Location.Absolute => abs.copy(scheme = scheme)
}

Expand Down Expand Up @@ -239,7 +245,11 @@ object URL {
}

object Location {
final case class Absolute(scheme: Scheme, host: String, port: Int) extends Location
final case class Absolute(scheme: Scheme, host: String, originalPort: Option[Int]) extends Location {
def portOrDefault: Option[Int] = originalPort.orElse(scheme.defaultPort)
def portIfNotDefault: Option[Int] = originalPort.filter(p => scheme.defaultPort.exists(_ != p))
def port: Int = originalPort.orElse(scheme.defaultPort).getOrElse(Scheme.defaultPortForHTTP)
}

case object Relative extends Location
}
Expand All @@ -262,13 +272,13 @@ object URL {
) + url.fragment.fold("")(f => "#" + f.raw)

url.kind match {
case Location.Relative =>
path(true)
case Location.Absolute(scheme, host, port) =>
case Location.Relative => path(true)
case abs: Location.Absolute =>
val path2 = path(false)

if (port == portFromScheme(scheme)) s"${scheme.encode}://$host$path2"
else s"${scheme.encode}://$host:$port$path2"
abs.portIfNotDefault match {
case None => s"${abs.scheme.encode}://${abs.host}$path2"
case Some(customPort) => s"${abs.scheme.encode}://${abs.host}:$customPort$path2"
}
}
}

Expand All @@ -277,7 +287,7 @@ object URL {
scheme <- Scheme.decode(uri.getScheme)
host <- Option(uri.getHost)
path <- Option(uri.getRawPath)
port = Option(uri.getPort).filter(_ != -1).getOrElse(portFromScheme(scheme))
port = Option(uri.getPort).filter(_ != -1).orElse(scheme.defaultPort) // FIXME REMOVE defaultPort
connection = URL.Location.Absolute(scheme, host, port)
path2 = Path.decode(path)
path3 = if (path.nonEmpty) path2.addLeadingSlash else path2
Expand All @@ -288,9 +298,4 @@ object URL {
path <- Option(uri.getRawPath)
} yield URL(Path.decode(path), Location.Relative, QueryParams.decode(uri.getRawQuery), Fragment.fromURI(uri))

private def portFromScheme(scheme: Scheme): Int = scheme match {
case Scheme.HTTP | Scheme.WS => 80
case Scheme.HTTPS | Scheme.WSS => 443
}

}
20 changes: 8 additions & 12 deletions zio-http/src/main/scala/zio/http/ZClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -668,18 +668,14 @@ object ZClient {
app: WebSocketApp[Env1],
)(implicit trace: Trace): ZIO[Env1 & Scope, Throwable, Response] =
for {
env <- ZIO.environment[Env1]
webSocketUrl = url.scheme(
url.scheme match {
case Some(Scheme.HTTP) => Scheme.WS
case Some(Scheme.HTTPS) => Scheme.WSS
case Some(Scheme.WS) => Scheme.WS
case Some(Scheme.WSS) => Scheme.WSS
case None => Scheme.WS
},
)
scope <- ZIO.scope
res <- requestAsync(
env <- ZIO.environment[Env1]
webSocketUrl <- url.scheme match {
case Some(Scheme.HTTP) | Some(Scheme.WS) | None => ZIO.succeed(url.scheme(Scheme.WS))
case Some(Scheme.WSS) | Some(Scheme.HTTPS) => ZIO.succeed(url.scheme(Scheme.WSS))
case _ => ZIO.fail(throw new IllegalArgumentException("URL's scheme MUST be WS(S) or HTTP(S)"))
}
scope <- ZIO.scope
res <- requestAsync(
Request(version = version, method = Method.GET, url = webSocketUrl, headers = headers),
config,
() => app.provideEnvironment(env),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object NettyConnectionPool {
case None =>
}

if (location.scheme.isSecure) {
if (location.scheme.isSecure.getOrElse(false)) {
pipeline.addLast(
Names.SSLHandler,
ClientSSLConverter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object ResponseCompressionSpec extends ZIOHttpSpec {
response <- client.request(
Request(
method = Method.GET,
url = URL(Root / "text", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", server.port)),
url = URL(Root / "text", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", Some(server.port))),
)
.addHeader(Header.AcceptEncoding(Header.AcceptEncoding.GZip(), Header.AcceptEncoding.Deflate())),
)
Expand All @@ -82,7 +82,7 @@ object ResponseCompressionSpec extends ZIOHttpSpec {
response <- client.request(
Request(
method = Method.GET,
url = URL(Root / "stream", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", server.port)),
url = URL(Root / "stream", kind = URL.Location.Absolute(Scheme.HTTP, "localhost", Some(server.port))),
)
.addHeader(Header.AcceptEncoding(Header.AcceptEncoding.GZip(), Header.AcceptEncoding.Deflate())),
)
Expand Down
3 changes: 3 additions & 0 deletions zio-http/src/test/scala/zio/http/SchemeSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,8 @@ object SchemeSpec extends ZIOHttpSpec {
test("null string decode") {
assert(Scheme.decode(null))(isNone)
},
test("decode chrome-extension") {
assertTrue(Scheme.decode("chrome-extension").isDefined)
},
)
}
4 changes: 2 additions & 2 deletions zio-http/src/test/scala/zio/http/URLSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ object URLSpec extends ZIOHttpSpec {
),
suite("normalize")(
test("adds leading slash") {
val url = URL(Path("a/b/c"), URL.Location.Absolute(Scheme.HTTP, "abc.com", 80), QueryParams.empty, None)
val url = URL(Path("a/b/c"), URL.Location.Absolute(Scheme.HTTP, "abc.com", Some(80)), QueryParams.empty, None)

val url2 = url.normalize

assertTrue(extractPath(url2) == Path("/a/b/c"))
},
test("deletes leading slash if there are no path segments") {
val url = URL(Path.root, URL.Location.Absolute(Scheme.HTTP, "abc.com", 80), QueryParams.empty, None)
val url = URL(Path.root, URL.Location.Absolute(Scheme.HTTP, "abc.com", Some(80)), QueryParams.empty, None)
val url2 = url.normalize

assertTrue(extractPath(url2) == Path.empty)
Expand Down
4 changes: 2 additions & 2 deletions zio-http/src/test/scala/zio/http/ZClientAspectSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object ZClientAspectSpec extends ZIOHttpSpec {
port <- Server.install(app)
baseClient <- ZIO.service[Client]
client = baseClient.url(
URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", port)),
URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", Some(port))),
) @@ ZClientAspect.debug
response <- client.request(Request.get(URL.empty / "hello"))
output <- TestConsole.output
Expand All @@ -51,7 +51,7 @@ object ZClientAspectSpec extends ZIOHttpSpec {
baseClient <- ZIO.service[Client]
client = baseClient
.url(
URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", port)),
URL(Path.empty, Location.Absolute(Scheme.HTTP, "localhost", Some(port))),
)
.disableStreaming @@ ZClientAspect.requestLogging(
loggedRequestHeaders = Set(Header.UserAgent),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ object OriginSpec extends ZIOHttpSpec {
assertTrue(
Origin.parse("http://domain") == Right(Value("http", "domain", None)),
Origin.parse("https://domain") == Right(Value("https", "domain", None)),
Origin.parse("chrome-extension://appid") == Right(Value("chrome-extension", "appid", None)),
)
},
test("parsing of valid Origin values") {
Expand Down
2 changes: 1 addition & 1 deletion zio-http/src/test/scala/zio/http/internal/HttpGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object HttpGen {
scheme <- Gen.fromIterable(List(Scheme.HTTP, Scheme.HTTPS))
host <- Gen.alphaNumericStringBounded(1, 5)
port <- Gen.oneOf(Gen.const(80), Gen.const(443), Gen.int(0, 65536))
} yield URL.Location.Absolute(scheme, host, port)
} yield URL.Location.Absolute(scheme, host, Some(port))

def genRelativeURL: Gen[Any, URL] = for {
path <- HttpGen.anyPath
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ abstract class HttpRunnableSpec extends ZIOHttpSpec { self =>
client(
params
.addHeader(DynamicServer.APP_ID, id)
.copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", port))),
.copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", Some(port)))),
)
.flatMap(_.collect)
}
Expand Down Expand Up @@ -80,7 +80,7 @@ abstract class HttpRunnableSpec extends ZIOHttpSpec { self =>
client(
params
.addHeader(DynamicServer.APP_ID, id)
.copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", port))),
.copy(url = URL(params.url.path, Location.Absolute(Scheme.HTTP, "localhost", Some(port)))),
)
}
} yield response
Expand Down
Loading