diff --git a/lib/phoenix/endpoint.ex b/lib/phoenix/endpoint.ex index 465420ac29..a923be9a84 100644 --- a/lib/phoenix/endpoint.ex +++ b/lib/phoenix/endpoint.ex @@ -859,6 +859,13 @@ defmodule Phoenix.Endpoint do The MFA is invoked with the request `%URI{}` as the first argument, followed by arguments in the MFA list, and must return a boolean. + * `:check_csrf` - if the transport should perform CSRF check. To avoid + "Cross-Site WebSocket Hijacking", you must have at least one of + `check_origin` and `check_csrf` enabled. If you set both to `false`, + Phoenix will raise, but it is still possible to disable both by passing + a custom MFA to `check_origin`. In such cases, it is your responsibility + to ensure at least one of them is enabled. Defaults to `true` + * `:code_reloader` - enable or disable the code reloader. Defaults to your endpoint configuration diff --git a/lib/phoenix/endpoint/supervisor.ex b/lib/phoenix/endpoint/supervisor.ex index 5ee904dcb7..597f175d90 100644 --- a/lib/phoenix/endpoint/supervisor.ex +++ b/lib/phoenix/endpoint/supervisor.ex @@ -84,11 +84,10 @@ defmodule Phoenix.Endpoint.Supervisor do children = config_children(mod, secret_conf, default_conf) ++ pubsub_children(mod, conf) ++ - socket_children(mod, :child_spec) ++ + socket_children(mod, conf, :child_spec) ++ server_children(mod, conf, server?) ++ - socket_children(mod, :drainer_spec) ++ + socket_children(mod, conf, :drainer_spec) ++ watcher_children(mod, conf, server?) - Supervisor.init(children, strategy: :one_for_one) end @@ -118,8 +117,9 @@ defmodule Phoenix.Endpoint.Supervisor do end end - defp socket_children(endpoint, fun) do + defp socket_children(endpoint, conf, fun) do for {_, socket, opts} <- Enum.uniq_by(endpoint.__sockets__(), &elem(&1, 1)), + _ = check_origin_or_csrf_checked!(conf, opts), spec = apply_or_ignore(socket, fun, [[endpoint: endpoint] ++ opts]), spec != :ignore do spec @@ -135,6 +135,22 @@ defmodule Phoenix.Endpoint.Supervisor do end end + defp check_origin_or_csrf_checked!(endpoint_conf, socket_opts) do + check_origin = endpoint_conf[:check_origin] + + for {transport, transport_opts} <- socket_opts, is_list(transport_opts) do + check_origin = Keyword.get(transport_opts, :check_origin, check_origin) + + check_csrf = transport_opts[:check_csrf] + + if check_origin == false and check_csrf == false do + raise ArgumentError, + "one of :check_origin and :check_csrf must be set to non-false value for " <> + "transport #{inspect(transport)}" + end + end + end + defp config_children(mod, conf, default_conf) do args = {mod, conf, default_conf, name: Module.concat(mod, "Config")} [{Phoenix.Config, args}] diff --git a/lib/phoenix/socket/transport.ex b/lib/phoenix/socket/transport.ex index a1c5fd5467..c9ab5dd956 100644 --- a/lib/phoenix/socket/transport.ex +++ b/lib/phoenix/socket/transport.ex @@ -458,8 +458,9 @@ defmodule Phoenix.Socket.Transport do * `:user_agent` - the value of the "user-agent" request header + The CSRF check can be disabled by setting the `:check_csrf` option to `false`. """ - def connect_info(conn, endpoint, keys) do + def connect_info(conn, endpoint, keys, opts \\ []) do for key <- keys, into: %{} do case key do :peer_data -> @@ -478,7 +479,7 @@ defmodule Phoenix.Socket.Transport do {:user_agent, fetch_user_agent(conn)} {:session, session} -> - {:session, connect_session(conn, endpoint, session)} + {:session, connect_session(conn, endpoint, session, opts)} {key, val} -> {key, val} @@ -486,26 +487,24 @@ defmodule Phoenix.Socket.Transport do end end - defp connect_session(conn, endpoint, {key, store, {csrf_token_key, init}}) do + defp connect_session(conn, endpoint, {key, store, {csrf_token_key, init}}, opts) do conn = Plug.Conn.fetch_cookies(conn) + check_csrf = Keyword.get(opts, :check_csrf, true) - with csrf_token when is_binary(csrf_token) <- conn.params["_csrf_token"], - cookie when is_binary(cookie) <- conn.cookies[key], + with cookie when is_binary(cookie) <- conn.cookies[key], conn = put_in(conn.secret_key_base, endpoint.config(:secret_key_base)), {_, session} <- store.get(conn, cookie, init), - csrf_state when is_binary(csrf_state) <- - Plug.CSRFProtection.dump_state_from_session(session[csrf_token_key]), - true <- Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token) do + true <- not check_csrf or csrf_token_valid?(conn, session, csrf_token_key) do session else _ -> nil end end - defp connect_session(conn, endpoint, {:mfa, {module, function, args}}) do + defp connect_session(conn, endpoint, {:mfa, {module, function, args}}, opts) do case apply(module, function, args) do session_config when is_list(session_config) -> - connect_session(conn, endpoint, init_session(session_config)) + connect_session(conn, endpoint, init_session(session_config), opts) other -> raise ArgumentError, @@ -542,6 +541,14 @@ defmodule Phoenix.Socket.Transport do end end + defp csrf_token_valid?(conn, session, csrf_token_key) do + with csrf_token when is_binary(csrf_token) <- conn.params["_csrf_token"], + csrf_state when is_binary(csrf_state) <- + Plug.CSRFProtection.dump_state_from_session(session[csrf_token_key]) do + Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token) + end + end + defp check_origin_config(handler, endpoint, opts) do Phoenix.Config.cache(endpoint, {:check_origin, handler}, fn _ -> check_origin = diff --git a/lib/phoenix/transports/long_poll.ex b/lib/phoenix/transports/long_poll.ex index de3365cee9..37da2fcda0 100644 --- a/lib/phoenix/transports/long_poll.ex +++ b/lib/phoenix/transports/long_poll.ex @@ -4,6 +4,7 @@ defmodule Phoenix.Transports.LongPoll do # 10MB @max_base64_size 10_000_000 + @connect_info_opts [:check_csrf] import Plug.Conn alias Phoenix.Socket.{V1, V2, Transport} @@ -136,7 +137,10 @@ defmodule Phoenix.Transports.LongPoll do (System.system_time(:millisecond) |> Integer.to_string()) keys = Keyword.get(opts, :connect_info, []) - connect_info = Transport.connect_info(conn, endpoint, keys) + + connect_info = + Transport.connect_info(conn, endpoint, keys, Keyword.take(opts, @connect_info_opts)) + arg = {endpoint, handler, opts, conn.params, priv_topic, connect_info} spec = {Phoenix.Transports.LongPoll.Server, arg} diff --git a/lib/phoenix/transports/websocket.ex b/lib/phoenix/transports/websocket.ex index f04cc4350d..dfc7bd2508 100644 --- a/lib/phoenix/transports/websocket.ex +++ b/lib/phoenix/transports/websocket.ex @@ -15,6 +15,8 @@ defmodule Phoenix.Transports.WebSocket do # @behaviour Plug + @connect_info_opts [:check_csrf] + import Plug.Conn alias Phoenix.Socket.{V1, V2, Transport} @@ -45,7 +47,9 @@ defmodule Phoenix.Transports.WebSocket do %{params: params} = conn -> keys = Keyword.get(opts, :connect_info, []) - connect_info = Transport.connect_info(conn, endpoint, keys) + + connect_info = + Transport.connect_info(conn, endpoint, keys, Keyword.take(opts, @connect_info_opts)) config = %{ endpoint: endpoint, diff --git a/test/phoenix/endpoint/supervisor_test.exs b/test/phoenix/endpoint/supervisor_test.exs index 235c5c0be4..3b2b958db6 100644 --- a/test/phoenix/endpoint/supervisor_test.exs +++ b/test/phoenix/endpoint/supervisor_test.exs @@ -180,4 +180,44 @@ defmodule Phoenix.Endpoint.SupervisorTest do end) end end + + describe "origin & CSRF checks config" do + defmodule TestSocket do + @behaviour Phoenix.Socket.Transport + def child_spec(_), do: :ignore + def connect(_), do: {:ok, []} + def init(state), do: {:ok, state} + def handle_in(_, state), do: {:ok, state} + def handle_info(_, state), do: {:ok, state} + def terminate(_, _), do: :ok + end + + defmodule SocketEndpoint do + use Phoenix.Endpoint, otp_app: :phoenix + + socket "/ws", TestSocket, websocket: [check_csrf: false, check_origin: false] + end + + Application.put_env(:phoenix, SocketEndpoint, []) + + test "fails when CSRF and origin checks both disabled in transport" do + assert_raise ArgumentError, ~r/one of :check_origin and :check_csrf must be set/, fn -> + Supervisor.init({:phoenix, SocketEndpoint, []}) + end + end + + defmodule SocketEndpointOriginCheckDisabled do + use Phoenix.Endpoint, otp_app: :phoenix + + socket "/ws", TestSocket, websocket: [check_csrf: false] + end + + Application.put_env(:phoenix, SocketEndpointOriginCheckDisabled, check_origin: false) + + test "fails when origin is disabled in endpoint config and CSRF disabled in transport" do + assert_raise ArgumentError, ~r/one of :check_origin and :check_csrf must be set/, fn -> + Supervisor.init({:phoenix, SocketEndpointOriginCheckDisabled, []}) + end + end + end end diff --git a/test/phoenix/socket/transport_test.exs b/test/phoenix/socket/transport_test.exs index 2a3e3df9f7..05700342dd 100644 --- a/test/phoenix/socket/transport_test.exs +++ b/test/phoenix/socket/transport_test.exs @@ -276,7 +276,7 @@ defmodule Phoenix.Socket.TransportTest do end end - describe "connect_info/3" do + describe "connect_info/4" do defp load_connect_info(connect_info) do [connect_info: connect_info] = Transport.load_config(connect_info: connect_info) connect_info @@ -330,5 +330,31 @@ defmodule Phoenix.Socket.TransportTest do |> Transport.connect_info(Endpoint, connect_info) end + test "loads the session when CSRF is disabled despite CSRF token not being provided" do + conn = conn(:get, "https://foo.com/") |> Endpoint.call([]) + session_cookie = conn.cookies["_hello_key"] + + connect_info = load_connect_info(session: {Endpoint, :session_config, []}) + + assert %{session: %{"from_session" => "123"}} = + conn(:get, "https://foo.com/") + |> put_req_cookie("_hello_key", session_cookie) + |> fetch_query_params() + |> Transport.connect_info(Endpoint, connect_info, check_csrf: false) + end + + test "doesn't load session when an invalid CSRF token is provided" do + conn = conn(:get, "https://foo.com/") |> Endpoint.call([]) + invalid_csrf_token = "some invalid CSRF token" + session_cookie = conn.cookies["_hello_key"] + + connect_info = load_connect_info(session: {Endpoint, :session_config, []}) + + assert %{session: nil} = + conn(:get, "https://foo.com/", _csrf_token: invalid_csrf_token) + |> put_req_cookie("_hello_key", session_cookie) + |> fetch_query_params() + |> Transport.connect_info(Endpoint, connect_info) + end end end