Skip to content

Commit

Permalink
Check prsence of CSRF or origin check when starting the transports
Browse files Browse the repository at this point in the history
  • Loading branch information
tanguilp committed Oct 27, 2024
1 parent 5b2ba48 commit d339626
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
29 changes: 26 additions & 3 deletions lib/phoenix/endpoint/supervisor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ defmodule Phoenix.Endpoint.Supervisor do
children =
config_children(mod, secret_conf, default_conf) ++
pubsub_children(mod, conf) ++
socket_children(mod, :child_spec) ++
socket_children(otp_app, mod, :child_spec) ++
server_children(mod, conf, server?) ++
socket_children(mod, :drainer_spec) ++
socket_children(otp_app, mod, :drainer_spec) ++
watcher_children(mod, conf, server?)

Supervisor.init(children, strategy: :one_for_one)
Expand Down Expand Up @@ -118,10 +118,12 @@ defmodule Phoenix.Endpoint.Supervisor do
end
end

defp socket_children(endpoint, fun) do
defp socket_children(otp_app, endpoint, fun) do
for {_, socket, opts} <- Enum.uniq_by(endpoint.__sockets__(), &elem(&1, 1)),
spec = apply_or_ignore(socket, fun, [[endpoint: endpoint] ++ opts]),
spec != :ignore do
check_origin_or_csrf_checked!(otp_app, endpoint, opts)

spec
end
end
Expand All @@ -135,6 +137,27 @@ defmodule Phoenix.Endpoint.Supervisor do
end
end

defp check_origin_or_csrf_checked!(otp_app, endpoint, socket_opts) do
endpoint_check_origin = config(otp_app, endpoint)[:check_origin]

for {transport, transport_opts} <- socket_opts do
check_origin =
if is_boolean(transport_opts[:check_origin]) do
transport_opts[:check_origin]
else
endpoint_check_origin
end

check_csrf = transport_opts[:check_csrf]

if check_origin == false and check_csrf == false do
raise ArgumentError,
"One of :check_origin and :check_csrf must be set to non-false value for " <>
"transport #{inspect(transport)}"
end
end
end

defp config_children(mod, conf, default_conf) do
args = {mod, conf, default_conf, name: Module.concat(mod, "Config")}
[{Phoenix.Config, args}]
Expand Down
5 changes: 0 additions & 5 deletions lib/phoenix/socket/transport.ex
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,8 @@ defmodule Phoenix.Socket.Transport do
import Plug.Conn
origin = conn |> get_req_header("origin") |> List.first()
check_origin = check_origin_config(handler, endpoint, opts)
check_csrf = opts[:check_csrf]

cond do
check_origin == false and check_csrf == false ->
raise ArgumentError,
"One of :check_origin and :check_csrf must be set"

is_nil(origin) or check_origin == false ->
conn

Expand Down
7 changes: 0 additions & 7 deletions test/phoenix/socket/transport_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,6 @@ defmodule Phoenix.Socket.TransportTest do
# an allowed host
refute check_origin("https://host.com/", check_origin: mfa).halted
end

test "raises if both :check_origin and :check_csrf are set to false" do
assert_raise ArgumentError, ~r/One of :check_origin and :check_csrf must be set/, fn ->
check_origin("https://host.com/", check_origin: false, check_csrf: false)
end
end

end

## Check subprotocols
Expand Down

0 comments on commit d339626

Please sign in to comment.