Skip to content

Commit

Permalink
Add check_csrf option to socket transport options (#5952)
Browse files Browse the repository at this point in the history
One of check_origin or check_csrf must be enabled.
  • Loading branch information
tanguilp authored Oct 31, 2024
1 parent a14841b commit c486cdf
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 17 deletions.
7 changes: 7 additions & 0 deletions lib/phoenix/endpoint.ex
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,13 @@ defmodule Phoenix.Endpoint do
The MFA is invoked with the request `%URI{}` as the first argument,
followed by arguments in the MFA list, and must return a boolean.
* `:check_csrf` - if the transport should perform CSRF check. To avoid
"Cross-Site WebSocket Hijacking", you must have at least one of
`check_origin` and `check_csrf` enabled. If you set both to `false`,
Phoenix will raise, but it is still possible to disable both by passing
a custom MFA to `check_origin`. In such cases, it is your responsibility
to ensure at least one of them is enabled. Defaults to `true`
* `:code_reloader` - enable or disable the code reloader. Defaults to your
endpoint configuration
Expand Down
24 changes: 20 additions & 4 deletions lib/phoenix/endpoint/supervisor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,10 @@ defmodule Phoenix.Endpoint.Supervisor do
children =
config_children(mod, secret_conf, default_conf) ++
pubsub_children(mod, conf) ++
socket_children(mod, :child_spec) ++
socket_children(mod, conf, :child_spec) ++
server_children(mod, conf, server?) ++
socket_children(mod, :drainer_spec) ++
socket_children(mod, conf, :drainer_spec) ++
watcher_children(mod, conf, server?)

Supervisor.init(children, strategy: :one_for_one)
end

Expand Down Expand Up @@ -118,8 +117,9 @@ defmodule Phoenix.Endpoint.Supervisor do
end
end

defp socket_children(endpoint, fun) do
defp socket_children(endpoint, conf, fun) do
for {_, socket, opts} <- Enum.uniq_by(endpoint.__sockets__(), &elem(&1, 1)),
_ = check_origin_or_csrf_checked!(conf, opts),
spec = apply_or_ignore(socket, fun, [[endpoint: endpoint] ++ opts]),
spec != :ignore do
spec
Expand All @@ -135,6 +135,22 @@ defmodule Phoenix.Endpoint.Supervisor do
end
end

defp check_origin_or_csrf_checked!(endpoint_conf, socket_opts) do
check_origin = endpoint_conf[:check_origin]

for {transport, transport_opts} <- socket_opts, is_list(transport_opts) do
check_origin = Keyword.get(transport_opts, :check_origin, check_origin)

check_csrf = transport_opts[:check_csrf]

if check_origin == false and check_csrf == false do
raise ArgumentError,
"one of :check_origin and :check_csrf must be set to non-false value for " <>
"transport #{inspect(transport)}"
end
end
end

defp config_children(mod, conf, default_conf) do
args = {mod, conf, default_conf, name: Module.concat(mod, "Config")}
[{Phoenix.Config, args}]
Expand Down
27 changes: 17 additions & 10 deletions lib/phoenix/socket/transport.ex
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,9 @@ defmodule Phoenix.Socket.Transport do
* `:user_agent` - the value of the "user-agent" request header
The CSRF check can be disabled by setting the `:check_csrf` option to `false`.
"""
def connect_info(conn, endpoint, keys) do
def connect_info(conn, endpoint, keys, opts \\ []) do
for key <- keys, into: %{} do
case key do
:peer_data ->
Expand All @@ -478,34 +479,32 @@ defmodule Phoenix.Socket.Transport do
{:user_agent, fetch_user_agent(conn)}

{:session, session} ->
{:session, connect_session(conn, endpoint, session)}
{:session, connect_session(conn, endpoint, session, opts)}

{key, val} ->
{key, val}
end
end
end

defp connect_session(conn, endpoint, {key, store, {csrf_token_key, init}}) do
defp connect_session(conn, endpoint, {key, store, {csrf_token_key, init}}, opts) do
conn = Plug.Conn.fetch_cookies(conn)
check_csrf = Keyword.get(opts, :check_csrf, true)

with csrf_token when is_binary(csrf_token) <- conn.params["_csrf_token"],
cookie when is_binary(cookie) <- conn.cookies[key],
with cookie when is_binary(cookie) <- conn.cookies[key],
conn = put_in(conn.secret_key_base, endpoint.config(:secret_key_base)),
{_, session} <- store.get(conn, cookie, init),
csrf_state when is_binary(csrf_state) <-
Plug.CSRFProtection.dump_state_from_session(session[csrf_token_key]),
true <- Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token) do
true <- not check_csrf or csrf_token_valid?(conn, session, csrf_token_key) do
session
else
_ -> nil
end
end

defp connect_session(conn, endpoint, {:mfa, {module, function, args}}) do
defp connect_session(conn, endpoint, {:mfa, {module, function, args}}, opts) do
case apply(module, function, args) do
session_config when is_list(session_config) ->
connect_session(conn, endpoint, init_session(session_config))
connect_session(conn, endpoint, init_session(session_config), opts)

other ->
raise ArgumentError,
Expand Down Expand Up @@ -542,6 +541,14 @@ defmodule Phoenix.Socket.Transport do
end
end

defp csrf_token_valid?(conn, session, csrf_token_key) do
with csrf_token when is_binary(csrf_token) <- conn.params["_csrf_token"],
csrf_state when is_binary(csrf_state) <-
Plug.CSRFProtection.dump_state_from_session(session[csrf_token_key]) do
Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token)
end
end

defp check_origin_config(handler, endpoint, opts) do
Phoenix.Config.cache(endpoint, {:check_origin, handler}, fn _ ->
check_origin =
Expand Down
6 changes: 5 additions & 1 deletion lib/phoenix/transports/long_poll.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defmodule Phoenix.Transports.LongPoll do

# 10MB
@max_base64_size 10_000_000
@connect_info_opts [:check_csrf]

import Plug.Conn
alias Phoenix.Socket.{V1, V2, Transport}
Expand Down Expand Up @@ -136,7 +137,10 @@ defmodule Phoenix.Transports.LongPoll do
(System.system_time(:millisecond) |> Integer.to_string())

keys = Keyword.get(opts, :connect_info, [])
connect_info = Transport.connect_info(conn, endpoint, keys)

connect_info =
Transport.connect_info(conn, endpoint, keys, Keyword.take(opts, @connect_info_opts))

arg = {endpoint, handler, opts, conn.params, priv_topic, connect_info}
spec = {Phoenix.Transports.LongPoll.Server, arg}

Expand Down
6 changes: 5 additions & 1 deletion lib/phoenix/transports/websocket.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ defmodule Phoenix.Transports.WebSocket do
#
@behaviour Plug

@connect_info_opts [:check_csrf]

import Plug.Conn

alias Phoenix.Socket.{V1, V2, Transport}
Expand Down Expand Up @@ -45,7 +47,9 @@ defmodule Phoenix.Transports.WebSocket do

%{params: params} = conn ->
keys = Keyword.get(opts, :connect_info, [])
connect_info = Transport.connect_info(conn, endpoint, keys)

connect_info =
Transport.connect_info(conn, endpoint, keys, Keyword.take(opts, @connect_info_opts))

config = %{
endpoint: endpoint,
Expand Down
40 changes: 40 additions & 0 deletions test/phoenix/endpoint/supervisor_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,44 @@ defmodule Phoenix.Endpoint.SupervisorTest do
end)
end
end

describe "origin & CSRF checks config" do
defmodule TestSocket do
@behaviour Phoenix.Socket.Transport
def child_spec(_), do: :ignore
def connect(_), do: {:ok, []}
def init(state), do: {:ok, state}
def handle_in(_, state), do: {:ok, state}
def handle_info(_, state), do: {:ok, state}
def terminate(_, _), do: :ok
end

defmodule SocketEndpoint do
use Phoenix.Endpoint, otp_app: :phoenix

socket "/ws", TestSocket, websocket: [check_csrf: false, check_origin: false]
end

Application.put_env(:phoenix, SocketEndpoint, [])

test "fails when CSRF and origin checks both disabled in transport" do
assert_raise ArgumentError, ~r/one of :check_origin and :check_csrf must be set/, fn ->
Supervisor.init({:phoenix, SocketEndpoint, []})
end
end

defmodule SocketEndpointOriginCheckDisabled do
use Phoenix.Endpoint, otp_app: :phoenix

socket "/ws", TestSocket, websocket: [check_csrf: false]
end

Application.put_env(:phoenix, SocketEndpointOriginCheckDisabled, check_origin: false)

test "fails when origin is disabled in endpoint config and CSRF disabled in transport" do
assert_raise ArgumentError, ~r/one of :check_origin and :check_csrf must be set/, fn ->
Supervisor.init({:phoenix, SocketEndpointOriginCheckDisabled, []})
end
end
end
end
28 changes: 27 additions & 1 deletion test/phoenix/socket/transport_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ defmodule Phoenix.Socket.TransportTest do
end
end

describe "connect_info/3" do
describe "connect_info/4" do
defp load_connect_info(connect_info) do
[connect_info: connect_info] = Transport.load_config(connect_info: connect_info)
connect_info
Expand Down Expand Up @@ -330,5 +330,31 @@ defmodule Phoenix.Socket.TransportTest do
|> Transport.connect_info(Endpoint, connect_info)
end

test "loads the session when CSRF is disabled despite CSRF token not being provided" do
conn = conn(:get, "https://foo.com/") |> Endpoint.call([])
session_cookie = conn.cookies["_hello_key"]

connect_info = load_connect_info(session: {Endpoint, :session_config, []})

assert %{session: %{"from_session" => "123"}} =
conn(:get, "https://foo.com/")
|> put_req_cookie("_hello_key", session_cookie)
|> fetch_query_params()
|> Transport.connect_info(Endpoint, connect_info, check_csrf: false)
end

test "doesn't load session when an invalid CSRF token is provided" do
conn = conn(:get, "https://foo.com/") |> Endpoint.call([])
invalid_csrf_token = "some invalid CSRF token"
session_cookie = conn.cookies["_hello_key"]

connect_info = load_connect_info(session: {Endpoint, :session_config, []})

assert %{session: nil} =
conn(:get, "https://foo.com/", _csrf_token: invalid_csrf_token)
|> put_req_cookie("_hello_key", session_cookie)
|> fetch_query_params()
|> Transport.connect_info(Endpoint, connect_info)
end
end
end

0 comments on commit c486cdf

Please sign in to comment.