Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow channels to perform custom handover on rejoin #5959

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions assets/js/phoenix/channel.js
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,16 @@ export default class Channel {
/**
* Join the channel
* @param {integer} timeout
* @param {boolean} handover - When true, the client won't send a leave message to the existing channel when rejoining
* @returns {Push}
*/
join(timeout = this.timeout){
join(timeout = this.timeout, handover = false){
if(this.joinedOnce){
throw new Error("tried to join multiple times. 'join' can only be called a single time per channel instance")
} else {
this.timeout = timeout
this.joinedOnce = true
this.rejoin()
this.rejoin(timeout, handover)
return this.joinPush
}
}
Expand Down Expand Up @@ -257,9 +258,9 @@ export default class Channel {
/**
* @private
*/
rejoin(timeout = this.timeout){
rejoin(timeout = this.timeout, handover = false){
if(this.isLeaving()){ return }
this.socket.leaveOpenTopic(this.topic)
if(!handover){ this.socket.leaveOpenTopic(this.topic) }
this.state = CHANNEL_STATES.joining
this.joinPush.resend(timeout)
}
Expand Down
3 changes: 2 additions & 1 deletion lib/phoenix/channel/server.ex
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ defmodule Phoenix.Channel.Server do

starter = opts[:starter] || &PoolSupervisor.start_child/3
assigns = Map.merge(socket.assigns, Keyword.get(opts, :assigns, %{}))
socket = %{socket | topic: topic, channel: channel, join_ref: join_ref || ref, assigns: assigns}
handover_pid = Keyword.get(opts, :handover_pid, nil)
socket = %{socket | topic: topic, channel: channel, join_ref: join_ref || ref, assigns: assigns, handover_pid: handover_pid}
ref = make_ref()
from = {self(), ref}
child_spec = channel.child_spec({socket.endpoint, from})
Expand Down
120 changes: 103 additions & 17 deletions lib/phoenix/socket.ex
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ defmodule Phoenix.Socket do
serializer: nil,
topic: nil,
transport: nil,
transport_pid: nil
transport_pid: nil,
handover_pid: nil

@type t :: %Socket{
assigns: map,
Expand Down Expand Up @@ -371,6 +372,7 @@ defmodule Phoenix.Socket do
## Options

* `:assigns` - the map of socket assigns to merge into the socket on join
* `:handover_on_rejoin` - a boolean to indicate if the channel allows a handover when a duplicate join is detected

## Examples

Expand All @@ -385,6 +387,14 @@ defmodule Phoenix.Socket do
allow more versatile topic scoping.

See `Phoenix.Channel` for more information

## Handover

If a channel is joined multiple times, the existing channel will be terminated by default.
This can be disabled by setting the `:handover_on_rejoin` option to `true`. A custom channel
implementation can then perform a handover by exchanging messages with the old channel
process that is available under the `:handover_pid` key in the socket struct.

"""
defmacro channel(topic_pattern, module, opts \\ []) do
module = expand_alias(module, __CALLER__)
Expand Down Expand Up @@ -526,6 +536,13 @@ defmodule Phoenix.Socket do
handle_in(Map.get(state.channels, topic), message, state, socket)
end

def __info__({:DOWN, _ref, _, _pid, {:shutdown, :handover}}, state) do
# in the special case where the channel is being handed over,
# we don't want to send a phx_error to the socket, as the exit
# is expected
{:ok, state}
end

def __info__({:DOWN, ref, _, pid, reason}, {state, socket}) do
case state.channels_inverse do
%{^pid => {topic, join_ref}} ->
Expand Down Expand Up @@ -559,6 +576,41 @@ defmodule Phoenix.Socket do
{:ok, state}
end

def __info__({:handover, payload, handover_pid, topic, join_ref}, {state, socket}) do
{channel, opts} = socket.handler.__channel__(topic)
opts = Keyword.put(opts, :handover_pid, handover_pid)
join_message = %Message{topic: topic, payload: payload, ref: join_ref, join_ref: join_ref}

case Phoenix.Channel.Server.join(socket, channel, join_message, opts) do
{:ok, reply, pid} ->
reply = %Message{
join_ref: join_ref,
ref: nil,
topic: topic,
event: "phx_handover",
payload: reply
}

shutdown_duplicate_channel(handover_pid, :handover)

state = put_channel(state, pid, topic, join_ref)
{:reply, :ok, encode_reply(socket, reply), {state, socket}}

{:error, reply} ->
reply = %Reply{
join_ref: join_ref,
ref: nil,
topic: topic,
status: :error,
payload: reply
}

shutdown_duplicate_channel(handover_pid, :handover)

{:reply, :error, encode_reply(socket, reply), {state, socket}}
end
end

def __info__(_, state) do
{:ok, state}
end
Expand Down Expand Up @@ -670,6 +722,17 @@ defmodule Phoenix.Socket do
) do
case socket.handler.__channel__(topic) do
{channel, opts} ->
handover? = Keyword.get(opts, :handover_on_rejoin, false)

handover_pid = if handover? do
case state.channels[topic] do
{pid, _, _} -> pid
_ -> nil
end
end

opts = Keyword.put(opts, :handover_pid, handover_pid)

case Phoenix.Channel.Server.join(socket, channel, message, opts) do
{:ok, reply, pid} ->
reply = %Reply{
Expand All @@ -680,6 +743,10 @@ defmodule Phoenix.Socket do
payload: reply
}

if handover_pid do
shutdown_duplicate_channel(handover_pid, :handover)
end

state = put_channel(state, pid, topic, join_ref)
{:reply, :ok, encode_reply(socket, reply), {state, socket}}

Expand All @@ -692,6 +759,10 @@ defmodule Phoenix.Socket do
payload: reply
}

if handover_pid do
shutdown_duplicate_channel(handover_pid, :handover)
end

{:reply, :error, encode_reply(socket, reply), {state, socket}}
end

Expand All @@ -702,22 +773,37 @@ defmodule Phoenix.Socket do
end

defp handle_in({pid, _ref, status}, %{event: "phx_join", topic: topic} = message, state, socket) do
receive do
{:socket_close, ^pid, _reason} -> :ok
after
0 ->
if status != :leaving do
Logger.debug(fn ->
"Duplicate channel join for topic \"#{topic}\" in #{inspect(socket.handler)}. " <>
"Closing existing channel for new join."
end)
end
handover? = case socket.handler.__channel__(topic) do
{_channel, opts} ->
Keyword.get(opts, :handover_on_rejoin, false)

_ -> false
end

:ok = shutdown_duplicate_channel(pid)
{:push, {opcode, payload}, {new_state, new_socket}} = socket_close(pid, {state, socket})
send(self(), {:socket_push, opcode, payload})
handle_in(nil, message, new_state, new_socket)
if handover? do
# the channel wants to handover duplicate joins,
# therefore we don't exit the existing channel process (yet);
# instead, the old pid will be terminated after the new one
# joined successfully
handle_in(nil, message, state, socket)
else
receive do
{:socket_close, ^pid, _reason} -> :ok
after
0 ->
if status != :leaving do
Logger.debug(fn ->
"Duplicate channel join for topic \"#{topic}\" in #{inspect(socket.handler)}. " <>
"Closing existing channel for new join."
end)
end
end

:ok = shutdown_duplicate_channel(pid)
{:push, {opcode, payload}, {new_state, new_socket}} = socket_close(pid, {state, socket})
send(self(), {:socket_push, opcode, payload})
handle_in(nil, message, new_state, new_socket)
end
end

defp handle_in({pid, _ref, _status}, %{event: "phx_leave"} = msg, state, socket) do
Expand Down Expand Up @@ -812,9 +898,9 @@ defmodule Phoenix.Socket do
encode_reply(socket, message)
end

defp shutdown_duplicate_channel(pid) do
defp shutdown_duplicate_channel(pid, reason \\ :duplicate_join) do
ref = Process.monitor(pid)
Process.exit(pid, {:shutdown, :duplicate_join})
Process.exit(pid, {:shutdown, reason})

receive do
{:DOWN, ^ref, _, _, _} -> :ok
Expand Down
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ defmodule Phoenix.MixProject do
end
end

@version "1.7.14"
@version "1.8.0-dev"
@scm_url "https://github.com/phoenixframework/phoenix"

# If the elixir requirement is updated, we need to make the installer
Expand Down
11 changes: 7 additions & 4 deletions priv/static/phoenix.cjs.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions priv/static/phoenix.cjs.js.map

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions priv/static/phoenix.js
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,16 @@ var Phoenix = (() => {
/**
* Join the channel
* @param {integer} timeout
* @param {boolean} handover - When true, the client won't send a leave message to the existing channel when rejoining
* @returns {Push}
*/
join(timeout = this.timeout) {
join(timeout = this.timeout, handover = false) {
if (this.joinedOnce) {
throw new Error("tried to join multiple times. 'join' can only be called a single time per channel instance");
} else {
this.timeout = timeout;
this.joinedOnce = true;
this.rejoin();
this.rejoin(timeout, handover);
return this.joinPush;
}
}
Expand Down Expand Up @@ -467,11 +468,13 @@ var Phoenix = (() => {
/**
* @private
*/
rejoin(timeout = this.timeout) {
rejoin(timeout = this.timeout, handover = false) {
if (this.isLeaving()) {
return;
}
this.socket.leaveOpenTopic(this.topic);
if (!handover) {
this.socket.leaveOpenTopic(this.topic);
}
this.state = CHANNEL_STATES.joining;
this.joinPush.resend(timeout);
}
Expand Down
2 changes: 1 addition & 1 deletion priv/static/phoenix.min.js

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions priv/static/phoenix.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,16 @@ var Channel = class {
/**
* Join the channel
* @param {integer} timeout
* @param {boolean} handover - When true, the client won't send a leave message to the existing channel when rejoining
* @returns {Push}
*/
join(timeout = this.timeout) {
join(timeout = this.timeout, handover = false) {
if (this.joinedOnce) {
throw new Error("tried to join multiple times. 'join' can only be called a single time per channel instance");
} else {
this.timeout = timeout;
this.joinedOnce = true;
this.rejoin();
this.rejoin(timeout, handover);
return this.joinPush;
}
}
Expand Down Expand Up @@ -438,11 +439,13 @@ var Channel = class {
/**
* @private
*/
rejoin(timeout = this.timeout) {
rejoin(timeout = this.timeout, handover = false) {
if (this.isLeaving()) {
return;
}
this.socket.leaveOpenTopic(this.topic);
if (!handover) {
this.socket.leaveOpenTopic(this.topic);
}
this.state = CHANNEL_STATES.joining;
this.joinPush.resend(timeout);
}
Expand Down
4 changes: 2 additions & 2 deletions priv/static/phoenix.mjs.map

Large diffs are not rendered by default.