diff --git a/lib/bandit/websocket/README.md b/lib/bandit/websocket/README.md index acd6024c..ba76d6ff 100644 --- a/lib/bandit/websocket/README.md +++ b/lib/bandit/websocket/README.md @@ -17,14 +17,14 @@ The HTTP request containing the upgrade request is first passed to the user's application as a standard Plug call. After inspecting the request and deeming it a suitable upgrade candidate (via whatever policy the application dictates), the user indicates a desire to upgrade the connection to a WebSocket by calling -`Plug.Conn.upgrade_adapter/3` (this is most commonly done by calling -`WebSockAdapter.upgrade/4`, which wraps this underlying call in -a server-agnostic manner). At the conclusion of the `Plug.call/2` callback, -`Bandit.Pipeline` will then attempy to upgrade the underlying connection. As -part of this upgrade process, `Bandit.DelegatingHandler` will switch the -Handler for the connection to be `Bandit.WebSocket.Handler`. This will cause any -future communication after the upgrade process to be handled directly by -Bandit's WebSocket stack. +`WebSockAdapter.upgrade/4`, which checks that the request is a valid WebSocket +upgrade request, and then calls `Plug.Conn.upgrade_adapter/3` to signal to +Bandit that the connection should be upgraded at the conclusion of the request. +At the conclusion of the `Plug.call/2` callback, `Bandit.Pipeline` will then +attempt to upgrade the underlying connection. As part of this upgrade process, +`Bandit.DelegatingHandler` will switch the Handler for the connection to be +`Bandit.WebSocket.Handler`. This will cause any future communication after the +upgrade process to be handled directly by Bandit's WebSocket stack. ## Process model @@ -41,7 +41,7 @@ modeled by the `Bandit.WebSocket.Connection` struct and module. All data subsequently received by the underlying [Thousand Island](https://github.com/mtrudel/thousand_island) library will result in -a call to `Bandit.WebSocket.Handler.handle_data/3`, which will then attmept to +a call to `Bandit.WebSocket.Handler.handle_data/3`, which will then attempt to parse the data into one or more WebSocket frames. Once a frame has been constructed, it is them passed through to the configured `WebSock` handler by way of the underlying `Bandit.WebSocket.Connection`. diff --git a/lib/bandit/websocket/handshake.ex b/lib/bandit/websocket/handshake.ex index c9b349f0..c3f4dc55 100644 --- a/lib/bandit/websocket/handshake.ex +++ b/lib/bandit/websocket/handshake.ex @@ -6,37 +6,10 @@ defmodule Bandit.WebSocket.Handshake do @type extensions :: [{String.t(), [{String.t(), String.t() | true}]}] - @spec valid_upgrade?(Plug.Conn.t()) :: boolean() - def valid_upgrade?(%Plug.Conn{} = conn) do - validate_upgrade(conn) == :ok - end - - @spec validate_upgrade(Plug.Conn.t()) :: :ok | {:error, String.t()} - defp validate_upgrade(conn) do - # Cases from RFC6455§4.2.1 - with {:http_version, :"HTTP/1.1"} <- {:http_version, get_http_protocol(conn)}, - {:method, "GET"} <- {:method, conn.method}, - {:host_header, header} when header != [] <- {:host_header, get_req_header(conn, "host")}, - {:upgrade_header, true} <- - {:upgrade_header, header_contains(conn, "upgrade", "websocket")}, - {:connection_header, true} <- - {:connection_header, header_contains(conn, "connection", "upgrade")}, - {:sec_websocket_key_header, true} <- - {:sec_websocket_key_header, - match?([<<_::binary>>], get_req_header(conn, "sec-websocket-key"))}, - {:sec_websocket_version_header, ["13"]} <- - {:sec_websocket_version_header, get_req_header(conn, "sec-websocket-version")} do - :ok - else - {step, detail} -> - {:error, "WebSocket upgrade failed: error in #{step} check: #{inspect(detail)}"} - end - end - @spec handshake(Plug.Conn.t(), keyword(), keyword()) :: {:ok, Plug.Conn.t(), Keyword.t()} | {:error, String.t()} def handshake(%Plug.Conn{} = conn, opts, websocket_opts) do - with :ok <- validate_upgrade(conn) do + with :ok <- Bandit.WebSocket.UpgradeValidation.validate_upgrade(conn) do do_handshake(conn, opts, websocket_opts) end end @@ -126,19 +99,4 @@ defmodule Bandit.WebSocket.Handshake do put_resp_header(conn, "sec-websocket-extensions", extensions) end - - @spec header_contains(Plug.Conn.t(), field :: String.t(), value :: String.t()) :: - true | binary() - defp header_contains(conn, field, value) do - downcase_value = String.downcase(value, :ascii) - header = get_req_header(conn, field) - - header - |> Enum.flat_map(&Plug.Conn.Utils.list/1) - |> Enum.any?(&(String.downcase(&1, :ascii) == downcase_value)) - |> case do - true -> true - false -> "Did not find '#{value}' in '#{header}'" - end - end end diff --git a/lib/bandit/websocket/upgrade_validation.ex b/lib/bandit/websocket/upgrade_validation.ex new file mode 100644 index 00000000..3ba04712 --- /dev/null +++ b/lib/bandit/websocket/upgrade_validation.ex @@ -0,0 +1,65 @@ +defmodule Bandit.WebSocket.UpgradeValidation do + @moduledoc false + # Provides validation of WebSocket upgrade requests as described in RFC6455§4.2 + + # Validates that the request satisfies the requirements to issue a WebSocket upgrade response. + # Validations are performed based on the clauses laid out in RFC6455§4.2 + # + # This function does not actually perform an upgrade or change the connection in any way + # + # Returns `:ok` if the connection satisfies the requirements for a WebSocket upgrade, and + # `{:error, reason}` if not + # + @spec validate_upgrade(Plug.Conn.t()) :: :ok | {:error, String.t()} + def validate_upgrade(conn) do + case Plug.Conn.get_http_protocol(conn) do + :"HTTP/1.1" -> validate_upgrade_http1(conn) + other -> {:error, "HTTP version #{other} unsupported"} + end + end + + # Validate the conn per RFC6455§4.2.1 + defp validate_upgrade_http1(conn) do + with :ok <- assert_method(conn, "GET"), + :ok <- assert_header_nonempty(conn, "host"), + :ok <- assert_header_contains(conn, "connection", "upgrade"), + :ok <- assert_header_contains(conn, "upgrade", "websocket"), + :ok <- assert_header_nonempty(conn, "sec-websocket-key"), + :ok <- assert_header_equals(conn, "sec-websocket-version", "13") do + :ok + end + end + + defp assert_method(conn, verb) do + case conn.method do + ^verb -> :ok + other -> {:error, "HTTP method #{other} unsupported"} + end + end + + defp assert_header_nonempty(conn, header) do + case Plug.Conn.get_req_header(conn, header) do + [] -> {:error, "Header #{header} is absent"} + _ -> :ok + end + end + + defp assert_header_equals(conn, header, value) do + case Plug.Conn.get_req_header(conn, header) |> Enum.map(&String.downcase(&1, :ascii)) do + [^value] -> :ok + header_value -> {:error, "Header #{header} #{inspect(header_value)} not equal to #{value}"} + end + end + + defp assert_header_contains(conn, header, value) do + header_value = Plug.Conn.get_req_header(conn, header) + + header_value + |> Enum.flat_map(&Plug.Conn.Utils.list/1) + |> Enum.any?(&(String.downcase(&1, :ascii) == value)) + |> case do + true -> :ok + false -> {:error, "Header #{header} #{inspect(header_value)} does not contain #{value}"} + end + end +end diff --git a/test/bandit/http1/request_test.exs b/test/bandit/http1/request_test.exs index f9f23176..1acb847e 100644 --- a/test/bandit/http1/request_test.exs +++ b/test/bandit/http1/request_test.exs @@ -638,13 +638,12 @@ defmodule HTTP1RequestTest do ) assert SimpleHTTP1Client.recv_reply(client) - ~> {:ok, "400 Bad Request", list(), - "WebSocket upgrade failed: error in method check: \"POST\""} + ~> {:ok, "400 Bad Request", list(), "HTTP method POST unsupported"} Process.sleep(100) end) - assert errors =~ "WebSocket upgrade failed: error in method check: \\\"POST\\\"" + assert errors =~ "HTTP method POST unsupported" end test "returns a 400 and errors loudly in cases where an upgrade is indicated but upgrade header is incorrect", @@ -668,13 +667,12 @@ defmodule HTTP1RequestTest do assert SimpleHTTP1Client.recv_reply(client) ~> {:ok, "400 Bad Request", list(), - "WebSocket upgrade failed: error in upgrade_header check: \"Did not find 'websocket' in 'NOPE'\""} + "Header upgrade [\"NOPE\"] does not contain websocket"} Process.sleep(100) end) - assert errors =~ - "WebSocket upgrade failed: error in upgrade_header check: \\\"Did not find 'websocket' in 'NOPE'\\\"" + assert errors =~ "Header upgrade [\\\"NOPE\\\"] does not contain websocket" end test "returns a 400 and errors loudly in cases where an upgrade is indicated but connection header is incorrect", @@ -698,13 +696,12 @@ defmodule HTTP1RequestTest do assert SimpleHTTP1Client.recv_reply(client) ~> {:ok, "400 Bad Request", list(), - "WebSocket upgrade failed: error in connection_header check: \"Did not find 'upgrade' in 'NOPE'\""} + "Header connection [\"NOPE\"] does not contain upgrade"} Process.sleep(100) end) - assert errors =~ - "WebSocket upgrade failed: error in connection_header check: \\\"Did not find 'upgrade' in 'NOPE'\\\"" + assert errors =~ "Header connection [\\\"NOPE\\\"] does not contain upgrade" end test "returns a 400 and errors loudly in cases where an upgrade is indicated but key header is incorrect", @@ -726,13 +723,12 @@ defmodule HTTP1RequestTest do ) assert SimpleHTTP1Client.recv_reply(client) - ~> {:ok, "400 Bad Request", list(), - "WebSocket upgrade failed: error in sec_websocket_key_header check: false"} + ~> {:ok, "400 Bad Request", list(), "Header sec-websocket-key is absent"} Process.sleep(100) end) - assert errors =~ "WebSocket upgrade failed: error in sec_websocket_key_header check: false" + assert errors =~ "Header sec-websocket-key is absent" end test "returns a 400 and errors loudly in cases where an upgrade is indicated but version header is incorrect", @@ -756,13 +752,12 @@ defmodule HTTP1RequestTest do assert SimpleHTTP1Client.recv_reply(client) ~> {:ok, "400 Bad Request", list(), - "WebSocket upgrade failed: error in sec_websocket_version_header check: [\"99\"]"} + "Header sec-websocket-version [\"99\"] not equal to 13"} Process.sleep(100) end) - assert errors =~ - "WebSocket upgrade failed: error in sec_websocket_version_header check: [\\\"99\\\"]" + assert errors =~ "Header sec-websocket-version [\\\"99\\\"] not equal to 13" end test "returns a 400 and errors loudly if websocket support is not enabled", context do diff --git a/test/bandit/websocket/autobahn_test.exs b/test/bandit/websocket/autobahn_test.exs index 59b30162..9e7a6dd2 100644 --- a/test/bandit/websocket/autobahn_test.exs +++ b/test/bandit/websocket/autobahn_test.exs @@ -16,10 +16,7 @@ defmodule WebsocketAutobahnTest do @impl Plug def call(conn, _opts) do - case Bandit.WebSocket.Handshake.valid_upgrade?(conn) do - true -> Plug.Conn.upgrade_adapter(conn, :websocket, {EchoWebSock, :ok, compress: true}) - false -> Plug.Conn.send_resp(conn, 204, <<>>) - end + Plug.Conn.upgrade_adapter(conn, :websocket, {EchoWebSock, :ok, compress: true}) end @tag capture_log: true diff --git a/test/bandit/websocket/http1_handshake_test.exs b/test/bandit/websocket/http1_handshake_test.exs index f3d28751..60b0736a 100644 --- a/test/bandit/websocket/http1_handshake_test.exs +++ b/test/bandit/websocket/http1_handshake_test.exs @@ -1,7 +1,4 @@ defmodule WebSocketHTTP1HandshakeTest do - # This is fundamentally a test of the Plug helpers in Bandit.WebSocket.Handshake, so we define - # a simple Plug that uses these handshakes to upgrade to a no-op WebSock implementation - use ExUnit.Case, async: true use ServerHelpers @@ -12,17 +9,8 @@ defmodule WebSocketHTTP1HandshakeTest do end def call(conn, _opts) do - case Bandit.WebSocket.Handshake.valid_upgrade?(conn) do - true -> - opts = if List.first(conn.path_info) == "compress", do: [compress: true], else: [] - - conn - |> Plug.Conn.upgrade_adapter(:websocket, {MyNoopWebSock, [], opts}) - - false -> - conn - |> Plug.Conn.send_resp(204, <<>>) - end + opts = if List.first(conn.path_info) == "compress", do: [compress: true], else: [] + Plug.Conn.upgrade_adapter(conn, :websocket, {MyNoopWebSock, [], opts}) end describe "HTTP/1.1 handshake" do @@ -45,114 +33,6 @@ defmodule WebSocketHTTP1HandshakeTest do assert Keyword.get(headers, :"sec-websocket-accept") == "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" end - test "does not accept non-GET requests", context do - client = SimpleWebSocketClient.tcp_client(context) - - SimpleHTTP1Client.send(client, "POST", "/", [ - "Host: server.example.com", - "Upgrade: WeBsOcKeT", - "Connection: UpGrAdE", - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version: 13" - ]) - - assert {:ok, "204 No Content", _headers, <<>>} = SimpleHTTP1Client.recv_reply(client) - end - - test "does not accept non-HTTP/1.1 requests", context do - client = SimpleWebSocketClient.tcp_client(context) - - SimpleHTTP1Client.send( - client, - "GET", - "/", - [ - "Host: server.example.com", - "Upgrade: WeBsOcKeT", - "Connection: UpGrAdE", - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version: 13" - ], - "1.0" - ) - - assert {:ok, "204 No Content", _headers, <<>>} = SimpleHTTP1Client.recv_reply(client) - end - - test "does not accept requests without a host header", context do - client = SimpleWebSocketClient.tcp_client(context) - - SimpleHTTP1Client.send( - client, - "GET", - "/", - [ - "Upgrade: WeBsOcKeT", - "Connection: UpGrAdE", - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version: 13" - ], - "1.0" - ) - - assert {:ok, "204 No Content", _headers, <<>>} = SimpleHTTP1Client.recv_reply(client) - end - - test "does not accept non-websocket upgrade requests", context do - client = SimpleWebSocketClient.tcp_client(context) - - SimpleHTTP1Client.send(client, "GET", "/", [ - "Host: server.example.com", - "Upgrade: bogus", - "Connection: UpGrAdE", - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version: 13" - ]) - - assert {:ok, "204 No Content", _headers, <<>>} = SimpleHTTP1Client.recv_reply(client) - end - - test "does not accept non-upgrade requests", context do - client = SimpleWebSocketClient.tcp_client(context) - - SimpleHTTP1Client.send(client, "GET", "/", [ - "Host: server.example.com", - "Upgrade: WeBsOcKeT", - "Connection: close", - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version: 13" - ]) - - assert {:ok, "204 No Content", _headers, <<>>} = SimpleHTTP1Client.recv_reply(client) - end - - test "does not accept requests without a request key", context do - client = SimpleWebSocketClient.tcp_client(context) - - SimpleHTTP1Client.send(client, "GET", "/", [ - "Host: server.example.com", - "Upgrade: WeBsOcKeT", - "Connection: UpGrAdE", - "Sec-WebSocket-Version: 13" - ]) - - assert {:ok, "204 No Content", _headers, <<>>} = SimpleHTTP1Client.recv_reply(client) - end - - test "does not accept requests without a version of 13", context do - client = SimpleWebSocketClient.tcp_client(context) - - SimpleHTTP1Client.send(client, "GET", "/", [ - "Host: server.example.com", - "Upgrade: WeBsOcKeT", - "Connection: UpGrAdE", - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version: 12" - ]) - - assert {:ok, "204 No Content", _headers, <<>>} = SimpleHTTP1Client.recv_reply(client) - end - test "negotiates permessage-deflate if so configured", context do client = SimpleWebSocketClient.tcp_client(context) @@ -209,8 +89,6 @@ defmodule WebSocketHTTP1HandshakeTest do "Sec-WebSocket-Extensions: permessage-deflate;client_max_window_bits=12" ]) - V - assert {:ok, "101 Switching Protocols", headers, <<>>} = SimpleHTTP1Client.recv_reply(client) diff --git a/test/bandit/websocket/protocol_test.exs b/test/bandit/websocket/protocol_test.exs index b8519f6e..d99fee57 100644 --- a/test/bandit/websocket/protocol_test.exs +++ b/test/bandit/websocket/protocol_test.exs @@ -6,16 +6,9 @@ defmodule WebSocketProtocolTest do def call(conn, _opts) do conn = Plug.Conn.fetch_query_params(conn) - - case Bandit.WebSocket.Handshake.valid_upgrade?(conn) do - true -> - websock = conn.query_params["websock"] |> String.to_atom() - compress = conn.query_params["compress"] - Plug.Conn.upgrade_adapter(conn, :websocket, {websock, conn.params, compress: compress}) - - false -> - Plug.Conn.send_resp(conn, 204, <<>>) - end + websock = conn.query_params["websock"] |> String.to_atom() + compress = conn.query_params["compress"] + Plug.Conn.upgrade_adapter(conn, :websocket, {websock, conn.params, compress: compress}) end # These websocks are used throughout these tests, so declare them top-level diff --git a/test/bandit/websocket/sock_test.exs b/test/bandit/websocket/sock_test.exs index 25d506f7..ed7e1d26 100644 --- a/test/bandit/websocket/sock_test.exs +++ b/test/bandit/websocket/sock_test.exs @@ -10,15 +10,8 @@ defmodule WebSocketWebSockTest do def call(conn, _opts) do conn = Plug.Conn.fetch_query_params(conn) - - case Bandit.WebSocket.Handshake.valid_upgrade?(conn) do - true -> - websock = conn.query_params["websock"] |> String.to_atom() - Plug.Conn.upgrade_adapter(conn, :websocket, {websock, [], []}) - - false -> - Plug.Conn.send_resp(conn, 204, <<>>) - end + websock = conn.query_params["websock"] |> String.to_atom() + Plug.Conn.upgrade_adapter(conn, :websocket, {websock, [], []}) end describe "init" do diff --git a/test/bandit/websocket/upgrade_validation_test.exs b/test/bandit/websocket/upgrade_validation_test.exs new file mode 100644 index 00000000..d237077f --- /dev/null +++ b/test/bandit/websocket/upgrade_validation_test.exs @@ -0,0 +1,128 @@ +defmodule UpgradeValidationTest do + # Note that these tests do not actually upgrade the connection to a WebSocket; they're just a + # plug that happens to call `validate_upgrade/1` and returns the result. The fact that we use + # HTTP calls to do this is to avoid having to manually construct `Plug.Conn` structs for testing + + use ExUnit.Case, async: true + use ServerHelpers + + setup :http_server + + def validate_upgrade(conn) do + case Bandit.WebSocket.UpgradeValidation.validate_upgrade(conn) do + :ok -> Plug.Conn.send_resp(conn, 200, "ok") + {:error, reason} -> Plug.Conn.send_resp(conn, 200, reason) + end + end + + describe "HTTP/1 upgrades" do + test "accepts well formed requests", context do + client = SimpleHTTP1Client.tcp_client(context) + + SimpleHTTP1Client.send(client, "GET", "/validate_upgrade", [ + "Host: server.example.com", + "Upgrade: WeBsOcKeT", + "Connection: UpGrAdE", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13" + ]) + + assert {:ok, "200 OK", _headers, "ok"} = SimpleHTTP1Client.recv_reply(client) + end + + test "does not accept non-GET requests", context do + client = SimpleHTTP1Client.tcp_client(context) + + SimpleHTTP1Client.send(client, "POST", "/validate_upgrade", [ + "Host: server.example.com", + "Upgrade: WeBsOcKeT", + "Connection: UpGrAdE", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13" + ]) + + assert {:ok, "200 OK", _headers, "HTTP method POST unsupported"} = + SimpleHTTP1Client.recv_reply(client) + end + + test "does not accept non-HTTP/1.1 requests", context do + client = SimpleHTTP1Client.tcp_client(context) + + SimpleHTTP1Client.send( + client, + "GET", + "/validate_upgrade", + [ + "Host: server.example.com", + "Upgrade: WeBsOcKeT", + "Connection: UpGrAdE", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13" + ], + "1.0" + ) + + assert {:ok, "200 OK", _headers, "HTTP version HTTP/1.0 unsupported"} = + SimpleHTTP1Client.recv_reply(client) + end + + test "does not accept non-upgrade requests", context do + client = SimpleHTTP1Client.tcp_client(context) + + SimpleHTTP1Client.send(client, "GET", "/validate_upgrade", [ + "Host: server.example.com", + "Upgrade: WeBsOcKeT", + "Connection: close", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13" + ]) + + assert {:ok, "200 OK", _headers, "Header connection [\"close\"] does not contain upgrade"} = + SimpleHTTP1Client.recv_reply(client) + end + + test "does not accept non-websocket upgrade requests", context do + client = SimpleHTTP1Client.tcp_client(context) + + SimpleHTTP1Client.send(client, "GET", "/validate_upgrade", [ + "Host: server.example.com", + "Upgrade: bogus", + "Connection: UpGrAdE", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13" + ]) + + assert {:ok, "200 OK", _headers, "Header upgrade [\"bogus\"] does not contain websocket"} = + SimpleHTTP1Client.recv_reply(client) + end + + test "does not accept requests without a request key", context do + client = SimpleHTTP1Client.tcp_client(context) + + SimpleHTTP1Client.send(client, "GET", "/validate_upgrade", [ + "Host: server.example.com", + "Upgrade: WeBsOcKeT", + "Connection: UpGrAdE", + "Sec-WebSocket-Version: 13" + ]) + + assert {:ok, "200 OK", _headers, "Header sec-websocket-key is absent"} = + SimpleHTTP1Client.recv_reply(client) + end + + test "does not accept requests without a version of 13", context do + client = SimpleHTTP1Client.tcp_client(context) + + SimpleHTTP1Client.send(client, "GET", "/validate_upgrade", [ + "Host: server.example.com", + "Upgrade: WeBsOcKeT", + "Connection: UpGrAdE", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 12" + ]) + + assert {:ok, "200 OK", _headers, "Header sec-websocket-version [\"12\"] not equal to 13"} = + SimpleHTTP1Client.recv_reply(client) + end + end +end