diff --git a/lib/phoenix/endpoint/supervisor.ex b/lib/phoenix/endpoint/supervisor.ex index 5ee904dcb7..fc1ce37183 100644 --- a/lib/phoenix/endpoint/supervisor.ex +++ b/lib/phoenix/endpoint/supervisor.ex @@ -84,9 +84,9 @@ defmodule Phoenix.Endpoint.Supervisor do children = config_children(mod, secret_conf, default_conf) ++ pubsub_children(mod, conf) ++ - socket_children(mod, :child_spec) ++ + socket_children(otp_app, mod, :child_spec) ++ server_children(mod, conf, server?) ++ - socket_children(mod, :drainer_spec) ++ + socket_children(otp_app, mod, :drainer_spec) ++ watcher_children(mod, conf, server?) Supervisor.init(children, strategy: :one_for_one) @@ -118,10 +118,12 @@ defmodule Phoenix.Endpoint.Supervisor do end end - defp socket_children(endpoint, fun) do + defp socket_children(otp_app, endpoint, fun) do for {_, socket, opts} <- Enum.uniq_by(endpoint.__sockets__(), &elem(&1, 1)), spec = apply_or_ignore(socket, fun, [[endpoint: endpoint] ++ opts]), spec != :ignore do + check_origin_or_csrf_checked!(otp_app, endpoint, opts) + spec end end @@ -135,6 +137,27 @@ defmodule Phoenix.Endpoint.Supervisor do end end + defp check_origin_or_csrf_checked!(otp_app, endpoint, socket_opts) do + endpoint_check_origin = config(otp_app, endpoint)[:check_origin] + + for {transport, transport_opts} <- socket_opts do + check_origin = + if is_boolean(transport_opts[:check_origin]) do + transport_opts[:check_origin] + else + endpoint_check_origin + end + + check_csrf = transport_opts[:check_csrf] + + if check_origin == false and check_csrf == false do + raise ArgumentError, + "One of :check_origin and :check_csrf must be set to non-false value for " <> + "transport #{inspect(transport)}" + end + end + end + defp config_children(mod, conf, default_conf) do args = {mod, conf, default_conf, name: Module.concat(mod, "Config")} [{Phoenix.Config, args}] diff --git a/lib/phoenix/socket/transport.ex b/lib/phoenix/socket/transport.ex index 7489949ead..c9ab5dd956 100644 --- a/lib/phoenix/socket/transport.ex +++ b/lib/phoenix/socket/transport.ex @@ -332,13 +332,8 @@ defmodule Phoenix.Socket.Transport do import Plug.Conn origin = conn |> get_req_header("origin") |> List.first() check_origin = check_origin_config(handler, endpoint, opts) - check_csrf = opts[:check_csrf] cond do - check_origin == false and check_csrf == false -> - raise ArgumentError, - "One of :check_origin and :check_csrf must be set" - is_nil(origin) or check_origin == false -> conn diff --git a/test/phoenix/socket/transport_test.exs b/test/phoenix/socket/transport_test.exs index 87caeea22a..05700342dd 100644 --- a/test/phoenix/socket/transport_test.exs +++ b/test/phoenix/socket/transport_test.exs @@ -237,13 +237,6 @@ defmodule Phoenix.Socket.TransportTest do # an allowed host refute check_origin("https://host.com/", check_origin: mfa).halted end - - test "raises if both :check_origin and :check_csrf are set to false" do - assert_raise ArgumentError, ~r/One of :check_origin and :check_csrf must be set/, fn -> - check_origin("https://host.com/", check_origin: false, check_csrf: false) - end - end - end ## Check subprotocols