Skip to content

Commit

Permalink
Merge branch 'main' into long-url
Browse files Browse the repository at this point in the history
  • Loading branch information
mschuwalow authored Nov 27, 2023
2 parents 658cfc8 + b43dbc4 commit 35e6b82
Show file tree
Hide file tree
Showing 18 changed files with 154 additions and 101 deletions.
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version = 3.7.14
version = 3.7.17
maxColumn = 120

align.preset = more
Expand Down
4 changes: 2 additions & 2 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import sbt.Keys.scalaVersion
object Dependencies {
val JwtCoreVersion = "9.1.1"
val NettyVersion = "4.1.101.Final"
val NettyIncubatorVersion = "0.0.20.Final"
val NettyIncubatorVersion = "0.0.24.Final"
val ScalaCompactCollectionVersion = "2.11.0"
val ZioVersion = "2.0.19"
val ZioCliVersion = "0.5.0"
val ZioSchemaVersion = "0.4.15"
val ZioSchemaVersion = "0.4.16"
val SttpVersion = "3.3.18"

val `jwt-core` = "com.github.jwt-scala" %% "jwt-core" % JwtCoreVersion
Expand Down
10 changes: 5 additions & 5 deletions zio-http/src/main/scala/zio/http/FormField.scala
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ object FormField {

private[http] def getContentType(ast: Chunk[FormAST]): MediaType =
ast.collectFirst {
case header: FormAST.Header if header.name == "Content-Type" =>
case header: FormAST.Header if header.name.equalsIgnoreCase("Content-Type") =>
MediaType
.forContentType(header.value)
.getOrElse(MediaType.application.`octet-stream`) // Unknown content type defaults to binary
Expand All @@ -200,13 +200,13 @@ object FormField {
)(implicit trace: Trace): ZIO[Any, FormDecodingError, FormField] = {
val extract =
ast.foldLeft((Option.empty[FormAST.Header], Option.empty[FormAST.Header], Option.empty[FormAST.Header])) {
case (accum, header: FormAST.Header) if header.name == "Content-Disposition" =>
case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Disposition") =>
(Some(header), accum._2, accum._3)
case (accum, header: FormAST.Header) if header.name == "Content-Type" =>
case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Type") =>
(accum._1, Some(header), accum._3)
case (accum, header: FormAST.Header) if header.name == "Content-Transfer-Encoding" =>
case (accum, header: FormAST.Header) if header.name.equalsIgnoreCase("Content-Transfer-Encoding") =>
(accum._1, accum._2, Some(header))
case (accum, _) => accum
case (accum, _) => accum
}

for {
Expand Down
15 changes: 6 additions & 9 deletions zio-http/src/main/scala/zio/http/Header.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import scala.util.{Either, Failure, Success, Try}
import zio._

import zio.http.codec.RichTextCodec
import zio.http.endpoint.openapi.OpenAPI.SecurityScheme.Http
import zio.http.internal.DateEncoding

sealed trait Header {
Expand Down Expand Up @@ -2480,16 +2481,12 @@ object Header {
private val codec: RichTextCodec[ContentType] = {

// char `.` according to BNF not allowed as `token`, but here tolerated
val token = RichTextCodec.filter(_ => true).validate("not a token") {
case ' ' | '(' | ')' | '<' | '>' | '@' | ',' | ';' | ':' | '\\' | '"' | '/' | '[' | ']' | '?' | '=' => false
case _ => true
}
val tokenQuoted = RichTextCodec.filter(_ => true).validate("not a quoted token") {
case ' ' | '"' => false
case _ => true
}
val token = RichTextCodec.charsNot(' ', '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=')

val tokenQuoted = RichTextCodec.charsNot(' ', '"')

val type1 = RichTextCodec.string.collectOrFail("unsupported main type") {
case value if MediaType.mainTypeMap.get(value).isDefined => value
case value if MediaType.mainTypeMap.contains(value) => value
}
val type1x = (RichTextCodec.literalCI("x-") ~ token.repeat.string).transform[String](in => s"${in._1}${in._2}")(in => ("x-", s"${in.substring(2)}"))
val codecType1 = (type1 | type1x).transform[String](_.merge) {
Expand Down
66 changes: 44 additions & 22 deletions zio-http/src/main/scala/zio/http/Scheme.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,52 +21,63 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace

sealed trait Scheme { self =>
def encode: String = self match {
case Scheme.HTTP => "http"
case Scheme.HTTPS => "https"
case Scheme.WS => "ws"
case Scheme.WSS => "wss"
case Scheme.HTTP => "http"
case Scheme.HTTPS => "https"
case Scheme.WS => "ws"
case Scheme.WSS => "wss"
case Scheme.Custom(scheme) => scheme
}

def isHttp: Boolean = !isWebSocket
def isHttp: Boolean = self match {
case Scheme.HTTP | Scheme.HTTPS => true
case _ => false
}

def isWebSocket: Boolean = self match {
case Scheme.WS => true
case Scheme.WSS => true
case _ => false
}

def isSecure: Boolean = self match {
case Scheme.HTTPS => true
case Scheme.WSS => true
case _ => false
def isSecure: Option[Boolean] = self match {
case Scheme.HTTPS | Scheme.WSS => Some(true)
case Scheme.HTTP | Scheme.WS => Some(false)
case _ => None
}

def defaultPort: Int = self match {
case Scheme.HTTP => 80
case Scheme.HTTPS => 443
case Scheme.WS => 80
case Scheme.WSS => 443
/** default ports is only define for the Schemes: http, https, ws, wss */
def defaultPort: Option[Int] = self match {
case Scheme.HTTP => Some(Scheme.defaultPortForHTTP)
case Scheme.HTTPS => Some(Scheme.defaultPortForHTTPS)
case Scheme.WS => Some(Scheme.defaultPortForWS)
case Scheme.WSS => Some(Scheme.defaultPortForWSS)
case Scheme.Custom(_) => None
}

}
object Scheme {

object Scheme {

/**
* Decodes a string to an Option of Scheme. Returns None in case of
* null/non-valid Scheme
*
* The should be lowercase and follow this syntax:
* - Scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
*/
def decode(scheme: String): Option[Scheme] =
Option(unsafe.decode(scheme)(Unsafe.unsafe))

private[zio] object unsafe {
def decode(scheme: String)(implicit unsafe: Unsafe): Scheme = {
if (scheme == null) null
if (scheme == null || scheme.isEmpty) null
else
scheme.length match {
case 5 => Scheme.HTTPS
case 4 => Scheme.HTTP
case 3 => Scheme.WSS
case 2 => Scheme.WS
case _ => null
scheme match {
case "http" => HTTP
case "https" => HTTPS
case "ws" => WS
case "wss" => WSS
case custom => Custom(custom.toLowerCase)
}
}
}
Expand All @@ -78,4 +89,15 @@ object Scheme {
case object WS extends Scheme

case object WSS extends Scheme

/**
* @param scheme
* value MUST not be "http" "https" "ws" "wss"
*/
final case class Custom(scheme: String) extends Scheme

def defaultPortForHTTP = 80
def defaultPortForHTTPS = 443
def defaultPortForWS = 80
def defaultPortForWSS = 443
}
69 changes: 37 additions & 32 deletions zio-http/src/main/scala/zio/http/URL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.util.Try

import zio.Chunk

import zio.http.URL.{Fragment, Location, portFromScheme}
import zio.http.URL.{Fragment, Location}
import zio.http.internal.QueryParamEncoding

final case class URL(
Expand All @@ -48,10 +48,10 @@ final case class URL(
def /(segment: String): URL = self.copy(path = self.path / segment)

def absolute(host: String): URL =
self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP)))
self.copy(kind = URL.Location.Absolute(Scheme.HTTP, host, None))

def absolute(scheme: Scheme, host: String, port: Int): URL =
self.copy(kind = URL.Location.Absolute(scheme, host, port))
self.copy(kind = URL.Location.Absolute(scheme, host, Some(port)))

def addLeadingSlash: URL = self.copy(path = path.addLeadingSlash)

Expand Down Expand Up @@ -101,20 +101,25 @@ final case class URL(

def host(host: String): URL = {
val location = kind match {
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, URL.portFromScheme(Scheme.HTTP))
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, host, None)
case abs: URL.Location.Absolute => abs.copy(host = host)
}
copy(kind = location)
}

/**
* @return
* the location, the host name and the port. The port part is omitted if is
* the default port for the protocol.
*/
def hostPort: Option[String] =
kind match {
case URL.Location.Relative => None
case URL.Location.Absolute(scheme, host, port) =>
Some(
if (port == portFromScheme(scheme)) host
else s"$host:$port",
)
case URL.Location.Relative => None
case abs: URL.Location.Absolute =>
abs.portIfNotDefault match {
case None => Some(abs.host)
case Some(customPort) => Some(s"${abs.host}:$customPort")
}
}

def isAbsolute: Boolean = self.kind match {
Expand All @@ -140,25 +145,26 @@ final case class URL(

def port(port: Int): URL = {
val location = kind match {
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", port)
case abs: URL.Location.Absolute => abs.copy(port = port)
case URL.Location.Relative => URL.Location.Absolute(Scheme.HTTP, "", Some(port))
case abs: URL.Location.Absolute => abs.copy(originalPort = Some(port))
}

copy(kind = location)
}

def port: Option[Int] = kind match {
case URL.Location.Relative => None
case abs: URL.Location.Absolute => Option(abs.port)
case abs: URL.Location.Absolute => abs.originalPort
}

def portOrDefault: Int = port.getOrElse(portFromScheme(scheme.getOrElse(Scheme.HTTP)))
def portOrDefault: Option[Int] = kind match {
case URL.Location.Relative => None
case abs: URL.Location.Absolute => abs.portOrDefault
}

def portIfNotDefault: Option[Int] = kind match {
case URL.Location.Relative =>
None
case abs: URL.Location.Absolute =>
if (abs.port == portFromScheme(abs.scheme)) None else Some(abs.port)
case URL.Location.Relative => None
case abs: URL.Location.Absolute => abs.portIfNotDefault
}

def queryParams(queryParams: QueryParams): URL =
Expand All @@ -185,7 +191,7 @@ final case class URL(

def scheme(scheme: Scheme): URL = {
val location = kind match {
case URL.Location.Relative => URL.Location.Absolute(scheme, "", URL.portFromScheme(scheme))
case URL.Location.Relative => URL.Location.Absolute(scheme, "", None)
case abs: URL.Location.Absolute => abs.copy(scheme = scheme)
}

Expand Down Expand Up @@ -239,7 +245,11 @@ object URL {
}

object Location {
final case class Absolute(scheme: Scheme, host: String, port: Int) extends Location
final case class Absolute(scheme: Scheme, host: String, originalPort: Option[Int]) extends Location {
def portOrDefault: Option[Int] = originalPort.orElse(scheme.defaultPort)
def portIfNotDefault: Option[Int] = originalPort.filter(p => scheme.defaultPort.exists(_ != p))
def port: Int = originalPort.orElse(scheme.defaultPort).getOrElse(Scheme.defaultPortForHTTP)
}

case object Relative extends Location
}
Expand All @@ -262,13 +272,13 @@ object URL {
) + url.fragment.fold("")(f => "#" + f.raw)

url.kind match {
case Location.Relative =>
path(true)
case Location.Absolute(scheme, host, port) =>
case Location.Relative => path(true)
case abs: Location.Absolute =>
val path2 = path(false)

if (port == portFromScheme(scheme)) s"${scheme.encode}://$host$path2"
else s"${scheme.encode}://$host:$port$path2"
abs.portIfNotDefault match {
case None => s"${abs.scheme.encode}://${abs.host}$path2"
case Some(customPort) => s"${abs.scheme.encode}://${abs.host}:$customPort$path2"
}
}
}

Expand All @@ -277,7 +287,7 @@ object URL {
scheme <- Scheme.decode(uri.getScheme)
host <- Option(uri.getHost)
path <- Option(uri.getRawPath)
port = Option(uri.getPort).filter(_ != -1).getOrElse(portFromScheme(scheme))
port = Option(uri.getPort).filter(_ != -1).orElse(scheme.defaultPort) // FIXME REMOVE defaultPort
connection = URL.Location.Absolute(scheme, host, port)
path2 = Path.decode(path)
path3 = if (path.nonEmpty) path2.addLeadingSlash else path2
Expand All @@ -288,9 +298,4 @@ object URL {
path <- Option(uri.getRawPath)
} yield URL(Path.decode(path), Location.Relative, QueryParams.decode(uri.getRawQuery), Fragment.fromURI(uri))

private def portFromScheme(scheme: Scheme): Int = scheme match {
case Scheme.HTTP | Scheme.WS => 80
case Scheme.HTTPS | Scheme.WSS => 443
}

}
20 changes: 8 additions & 12 deletions zio-http/src/main/scala/zio/http/ZClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -668,18 +668,14 @@ object ZClient {
app: WebSocketApp[Env1],
)(implicit trace: Trace): ZIO[Env1 & Scope, Throwable, Response] =
for {
env <- ZIO.environment[Env1]
webSocketUrl = url.scheme(
url.scheme match {
case Some(Scheme.HTTP) => Scheme.WS
case Some(Scheme.HTTPS) => Scheme.WSS
case Some(Scheme.WS) => Scheme.WS
case Some(Scheme.WSS) => Scheme.WSS
case None => Scheme.WS
},
)
scope <- ZIO.scope
res <- requestAsync(
env <- ZIO.environment[Env1]
webSocketUrl <- url.scheme match {
case Some(Scheme.HTTP) | Some(Scheme.WS) | None => ZIO.succeed(url.scheme(Scheme.WS))
case Some(Scheme.WSS) | Some(Scheme.HTTPS) => ZIO.succeed(url.scheme(Scheme.WSS))
case _ => ZIO.fail(throw new IllegalArgumentException("URL's scheme MUST be WS(S) or HTTP(S)"))
}
scope <- ZIO.scope
res <- requestAsync(
Request(version = version, method = Method.GET, url = webSocketUrl, headers = headers),
config,
() => app.provideEnvironment(env),
Expand Down
Loading

0 comments on commit 35e6b82

Please sign in to comment.