Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add check_csrf option to socket transport options #5952

Merged
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/phoenix/endpoint.ex
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,10 @@ defmodule Phoenix.Endpoint do
The MFA is invoked with the request `%URI{}` as the first argument,
followed by arguments in the MFA list, and must return a boolean.

* `:check_csrf` - if the transport should perform CSRF check. Note that disabling
both CSRF and origin checks at the same time is not allowed and will raise.
Defaults to `true`
josevalim marked this conversation as resolved.
Show resolved Hide resolved

* `:code_reloader` - enable or disable the code reloader. Defaults to your
endpoint configuration

Expand Down
29 changes: 26 additions & 3 deletions lib/phoenix/endpoint/supervisor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ defmodule Phoenix.Endpoint.Supervisor do
children =
config_children(mod, secret_conf, default_conf) ++
pubsub_children(mod, conf) ++
socket_children(mod, :child_spec) ++
socket_children(otp_app, mod, :child_spec) ++
tanguilp marked this conversation as resolved.
Show resolved Hide resolved
server_children(mod, conf, server?) ++
socket_children(mod, :drainer_spec) ++
socket_children(otp_app, mod, :drainer_spec) ++
watcher_children(mod, conf, server?)

Supervisor.init(children, strategy: :one_for_one)
Expand Down Expand Up @@ -118,10 +118,12 @@ defmodule Phoenix.Endpoint.Supervisor do
end
end

defp socket_children(endpoint, fun) do
defp socket_children(otp_app, endpoint, fun) do
for {_, socket, opts} <- Enum.uniq_by(endpoint.__sockets__(), &elem(&1, 1)),
spec = apply_or_ignore(socket, fun, [[endpoint: endpoint] ++ opts]),
josevalim marked this conversation as resolved.
Show resolved Hide resolved
spec != :ignore do
check_origin_or_csrf_checked!(otp_app, endpoint, opts)

spec
end
end
Expand All @@ -135,6 +137,27 @@ defmodule Phoenix.Endpoint.Supervisor do
end
end

defp check_origin_or_csrf_checked!(otp_app, endpoint, socket_opts) do
endpoint_check_origin = config(otp_app, endpoint)[:check_origin]

for {transport, transport_opts} <- socket_opts do
check_origin =
if is_boolean(transport_opts[:check_origin]) do
transport_opts[:check_origin]
else
endpoint_check_origin
end
tanguilp marked this conversation as resolved.
Show resolved Hide resolved

check_csrf = transport_opts[:check_csrf]

if check_origin == false and check_csrf == false do
raise ArgumentError,
"One of :check_origin and :check_csrf must be set to non-false value for " <>
tanguilp marked this conversation as resolved.
Show resolved Hide resolved
"transport #{inspect(transport)}"
end
josevalim marked this conversation as resolved.
Show resolved Hide resolved
end
end

defp config_children(mod, conf, default_conf) do
args = {mod, conf, default_conf, name: Module.concat(mod, "Config")}
[{Phoenix.Config, args}]
Expand Down
27 changes: 17 additions & 10 deletions lib/phoenix/socket/transport.ex
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,9 @@ defmodule Phoenix.Socket.Transport do

* `:user_agent` - the value of the "user-agent" request header

The CSRF check can be disabled by setting the `:check_csrf` option to `false`.
"""
def connect_info(conn, endpoint, keys) do
def connect_info(conn, endpoint, keys, opts \\ []) do
for key <- keys, into: %{} do
case key do
:peer_data ->
Expand All @@ -478,34 +479,32 @@ defmodule Phoenix.Socket.Transport do
{:user_agent, fetch_user_agent(conn)}

{:session, session} ->
{:session, connect_session(conn, endpoint, session)}
{:session, connect_session(conn, endpoint, session, opts)}

{key, val} ->
{key, val}
end
end
end

defp connect_session(conn, endpoint, {key, store, {csrf_token_key, init}}) do
defp connect_session(conn, endpoint, {key, store, {csrf_token_key, init}}, opts) do
conn = Plug.Conn.fetch_cookies(conn)
check_csrf = Keyword.get(opts, :check_csrf, true)

with csrf_token when is_binary(csrf_token) <- conn.params["_csrf_token"],
cookie when is_binary(cookie) <- conn.cookies[key],
with cookie when is_binary(cookie) <- conn.cookies[key],
conn = put_in(conn.secret_key_base, endpoint.config(:secret_key_base)),
{_, session} <- store.get(conn, cookie, init),
csrf_state when is_binary(csrf_state) <-
Plug.CSRFProtection.dump_state_from_session(session[csrf_token_key]),
true <- Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token) do
true <- not check_csrf or csrf_token_valid?(conn, session, csrf_token_key) do
session
else
_ -> nil
end
end

defp connect_session(conn, endpoint, {:mfa, {module, function, args}}) do
defp connect_session(conn, endpoint, {:mfa, {module, function, args}}, opts) do
case apply(module, function, args) do
session_config when is_list(session_config) ->
connect_session(conn, endpoint, init_session(session_config))
connect_session(conn, endpoint, init_session(session_config), opts)

other ->
raise ArgumentError,
Expand Down Expand Up @@ -542,6 +541,14 @@ defmodule Phoenix.Socket.Transport do
end
end

defp csrf_token_valid?(conn, session, csrf_token_key) do
with csrf_token when is_binary(csrf_token) <- conn.params["_csrf_token"],
csrf_state when is_binary(csrf_state) <-
Plug.CSRFProtection.dump_state_from_session(session[csrf_token_key]) do
Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token)
end
end

defp check_origin_config(handler, endpoint, opts) do
Phoenix.Config.cache(endpoint, {:check_origin, handler}, fn _ ->
check_origin =
Expand Down
6 changes: 5 additions & 1 deletion lib/phoenix/transports/long_poll.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defmodule Phoenix.Transports.LongPoll do

# 10MB
@max_base64_size 10_000_000
@connect_info_opts [:check_csrf]

import Plug.Conn
alias Phoenix.Socket.{V1, V2, Transport}
Expand Down Expand Up @@ -136,7 +137,10 @@ defmodule Phoenix.Transports.LongPoll do
(System.system_time(:millisecond) |> Integer.to_string())

keys = Keyword.get(opts, :connect_info, [])
connect_info = Transport.connect_info(conn, endpoint, keys)

connect_info =
Transport.connect_info(conn, endpoint, keys, Keyword.take(opts, @connect_info_opts))

arg = {endpoint, handler, opts, conn.params, priv_topic, connect_info}
spec = {Phoenix.Transports.LongPoll.Server, arg}

Expand Down
6 changes: 5 additions & 1 deletion lib/phoenix/transports/websocket.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ defmodule Phoenix.Transports.WebSocket do
#
@behaviour Plug

@connect_info_opts [:check_csrf]

import Plug.Conn

alias Phoenix.Socket.{V1, V2, Transport}
Expand Down Expand Up @@ -45,7 +47,9 @@ defmodule Phoenix.Transports.WebSocket do

%{params: params} = conn ->
keys = Keyword.get(opts, :connect_info, [])
connect_info = Transport.connect_info(conn, endpoint, keys)

connect_info =
Transport.connect_info(conn, endpoint, keys, Keyword.take(opts, @connect_info_opts))

config = %{
endpoint: endpoint,
Expand Down
28 changes: 27 additions & 1 deletion test/phoenix/socket/transport_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ defmodule Phoenix.Socket.TransportTest do
end
end

describe "connect_info/3" do
describe "connect_info/4" do
defp load_connect_info(connect_info) do
[connect_info: connect_info] = Transport.load_config(connect_info: connect_info)
connect_info
Expand Down Expand Up @@ -330,5 +330,31 @@ defmodule Phoenix.Socket.TransportTest do
|> Transport.connect_info(Endpoint, connect_info)
end

test "loads the session when CSRF is disabled despite CSRF token not being provided" do
conn = conn(:get, "https://foo.com/") |> Endpoint.call([])
session_cookie = conn.cookies["_hello_key"]

connect_info = load_connect_info(session: {Endpoint, :session_config, []})

assert %{session: %{"from_session" => "123"}} =
conn(:get, "https://foo.com/")
|> put_req_cookie("_hello_key", session_cookie)
|> fetch_query_params()
|> Transport.connect_info(Endpoint, connect_info, check_csrf: false)
end

test "doesn't load session when an invalid CSRF token is provided" do
conn = conn(:get, "https://foo.com/") |> Endpoint.call([])
invalid_csrf_token = "some invalid CSRF token"
session_cookie = conn.cookies["_hello_key"]

connect_info = load_connect_info(session: {Endpoint, :session_config, []})

assert %{session: nil} =
conn(:get, "https://foo.com/", _csrf_token: invalid_csrf_token)
|> put_req_cookie("_hello_key", session_cookie)
|> fetch_query_params()
|> Transport.connect_info(Endpoint, connect_info)
end
end
end