diff --git a/lib/phoenix/endpoint/supervisor.ex b/lib/phoenix/endpoint/supervisor.ex index 922b1e0450..ab73ac800c 100644 --- a/lib/phoenix/endpoint/supervisor.ex +++ b/lib/phoenix/endpoint/supervisor.ex @@ -84,11 +84,10 @@ defmodule Phoenix.Endpoint.Supervisor do children = config_children(mod, secret_conf, default_conf) ++ pubsub_children(mod, conf) ++ - socket_children(otp_app, mod, :child_spec) ++ + socket_children(mod, conf, :child_spec) ++ server_children(mod, conf, server?) ++ - socket_children(otp_app, mod, :drainer_spec) ++ + socket_children(mod, conf, :drainer_spec) ++ watcher_children(mod, conf, server?) - Supervisor.init(children, strategy: :one_for_one) end @@ -118,12 +117,11 @@ defmodule Phoenix.Endpoint.Supervisor do end end - defp socket_children(otp_app, endpoint, fun) do + defp socket_children(endpoint, conf, fun) do for {_, socket, opts} <- Enum.uniq_by(endpoint.__sockets__(), &elem(&1, 1)), + _ = check_origin_or_csrf_checked!(conf, opts), spec = apply_or_ignore(socket, fun, [[endpoint: endpoint] ++ opts]), spec != :ignore do - check_origin_or_csrf_checked!(otp_app, endpoint, opts) - spec end end @@ -137,12 +135,10 @@ defmodule Phoenix.Endpoint.Supervisor do end end - defp check_origin_or_csrf_checked!(otp_app, endpoint, socket_opts) do - endpoint_check_origin = config(otp_app, endpoint)[:check_origin] - - for {transport, transport_opts} <- socket_opts do + defp check_origin_or_csrf_checked!(endpoint_conf, socket_opts) do + for {transport, transport_opts} <- socket_opts, is_list(transport_opts) do check_origin = - Keyword.get(transport_opts, :check_origin, endpoint_check_origin) + Keyword.get(transport_opts, :check_origin, endpoint_conf[:check_origin]) check_csrf = transport_opts[:check_csrf] diff --git a/test/phoenix/endpoint/supervisor_test.exs b/test/phoenix/endpoint/supervisor_test.exs index 235c5c0be4..3b2b958db6 100644 --- a/test/phoenix/endpoint/supervisor_test.exs +++ b/test/phoenix/endpoint/supervisor_test.exs @@ -180,4 +180,44 @@ defmodule Phoenix.Endpoint.SupervisorTest do end) end end + + describe "origin & CSRF checks config" do + defmodule TestSocket do + @behaviour Phoenix.Socket.Transport + def child_spec(_), do: :ignore + def connect(_), do: {:ok, []} + def init(state), do: {:ok, state} + def handle_in(_, state), do: {:ok, state} + def handle_info(_, state), do: {:ok, state} + def terminate(_, _), do: :ok + end + + defmodule SocketEndpoint do + use Phoenix.Endpoint, otp_app: :phoenix + + socket "/ws", TestSocket, websocket: [check_csrf: false, check_origin: false] + end + + Application.put_env(:phoenix, SocketEndpoint, []) + + test "fails when CSRF and origin checks both disabled in transport" do + assert_raise ArgumentError, ~r/one of :check_origin and :check_csrf must be set/, fn -> + Supervisor.init({:phoenix, SocketEndpoint, []}) + end + end + + defmodule SocketEndpointOriginCheckDisabled do + use Phoenix.Endpoint, otp_app: :phoenix + + socket "/ws", TestSocket, websocket: [check_csrf: false] + end + + Application.put_env(:phoenix, SocketEndpointOriginCheckDisabled, check_origin: false) + + test "fails when origin is disabled in endpoint config and CSRF disabled in transport" do + assert_raise ArgumentError, ~r/one of :check_origin and :check_csrf must be set/, fn -> + Supervisor.init({:phoenix, SocketEndpointOriginCheckDisabled, []}) + end + end + end end