Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid double JSON encode/decode for socket.io #4449

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ export const applyEvent = async (event, socket) => {
if (socket) {
socket.emit(
"event",
JSON.stringify(event, (k, v) => (v === undefined ? null : v))
event,
);
return true;
}
Expand Down Expand Up @@ -407,6 +407,8 @@ export const connect = async (
transports: transports,
autoUnref: false,
});
// Ensure undefined fields in events are sent as null instead of removed
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)

function checkVisibility() {
if (document.visibilityState === "visible") {
Expand Down Expand Up @@ -444,7 +446,7 @@ export const connect = async (

// On each received message, queue the updates and events.
socket.current.on("event", async (message) => {
const update = JSON5.parse(message);
const update = message;
for (const substate in update.delta) {
dispatch[substate](update.delta[substate]);
}
Expand Down
9 changes: 7 additions & 2 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import traceback
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -362,6 +363,10 @@ def _setup_state(self) -> None:
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
ping_interval=constants.Ping.INTERVAL,
ping_timeout=constants.Ping.TIMEOUT,
json=SimpleNamespace(
dumps=staticmethod(format.json_dumps),
loads=staticmethod(json.loads),
),
)
elif getattr(self.sio, "async_mode", "") != "asgi":
raise RuntimeError(
Expand Down Expand Up @@ -1507,7 +1512,7 @@ async def emit_update(self, update: StateUpdate, sid: str) -> None:
"""
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
)

async def on_event(self, sid, data):
Expand All @@ -1520,7 +1525,7 @@ async def on_event(self, sid, data):
sid: The Socket.IO session id.
data: The event data.
"""
fields = json.loads(data)
fields = data
# Get the event.
event = Event(
**{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}
Expand Down
8 changes: 6 additions & 2 deletions reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,18 +664,22 @@ def format_library_name(library_fullname: str):
return lib


def json_dumps(obj: Any) -> str:
def json_dumps(obj: Any, **kwargs) -> str:
"""Takes an object and returns a jsonified string.

Args:
obj: The object to be serialized.
kwargs: Additional keyword arguments to pass to json.dumps.

Returns:
A string
"""
from reflex.utils import serializers

return json.dumps(obj, ensure_ascii=False, default=serializers.serialize)
kwargs.setdefault("ensure_ascii", False)
kwargs.setdefault("default", serializers.serialize)

return json.dumps(obj, **kwargs)


def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:
Expand Down
92 changes: 54 additions & 38 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1837,6 +1837,24 @@ async def _coro_waiter():
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1


class CopyingAsyncMock(AsyncMock):
"""An AsyncMock, but deepcopy the args and kwargs first."""

def __call__(self, *args, **kwargs):
"""Call the mock.

Args:
args: the arguments passed to the mock
kwargs: the keyword arguments passed to the mock

Returns:
The result of the mock call
"""
args = copy.deepcopy(args)
kwargs = copy.deepcopy(kwargs)
return super().__call__(*args, **kwargs)


@pytest.fixture(scope="function")
def mock_app_simple(monkeypatch) -> rx.App:
"""Simple Mock app fixture.
Expand All @@ -1853,7 +1871,7 @@ def mock_app_simple(monkeypatch) -> rx.App:

setattr(app_module, CompileVars.APP, app)
app.state = TestState
app.event_namespace.emit = AsyncMock() # type: ignore
app.event_namespace.emit = CopyingAsyncMock() # type: ignore

def _mock_get_app(*args, **kwargs):
return app_module
Expand Down Expand Up @@ -1957,21 +1975,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
mock_app.event_namespace.emit.assert_called_once()
mcall = mock_app.event_namespace.emit.mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT)
assert json.loads(mcall.args[1]) == dataclasses.asdict(
StateUpdate(
delta={
parent_state.get_full_name(): {
"upper": "",
"sum": 3.14,
},
grandchild_state.get_full_name(): {
"value2": "42",
},
GrandchildState3.get_full_name(): {
"computed": "",
},
}
)
assert mcall.args[1] == StateUpdate(
delta={
parent_state.get_full_name(): {
"upper": "",
"sum": 3.14,
},
grandchild_state.get_full_name(): {
"value2": "42",
},
GrandchildState3.get_full_name(): {
"computed": "",
},
}
)
assert mcall.kwargs["to"] == grandchild_state.router.session.session_id

Expand Down Expand Up @@ -2149,51 +2165,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit

first_ws_message = json.loads(emit_mock.mock_calls[0].args[1])
first_ws_message = emit_mock.mock_calls[0].args[1]
assert (
first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router")
first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
is not None
)
assert first_ws_message == {
"delta": {
assert first_ws_message == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"order": ["background_task:start"],
"computed_order": ["background_task:start"],
}
},
"events": [],
"final": True,
}
events=[],
final=True,
)
for call in emit_mock.mock_calls[1:5]:
assert json.loads(call.args[1]) == {
"delta": {
assert call.args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"computed_order": ["background_task:start"],
}
},
"events": [],
"final": True,
}
assert json.loads(emit_mock.mock_calls[-2].args[1]) == {
"delta": {
events=[],
final=True,
)
assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"order": exp_order,
"computed_order": exp_order,
"dict_list": {},
}
},
"events": [],
"final": True,
}
assert json.loads(emit_mock.mock_calls[-1].args[1]) == {
"delta": {
events=[],
final=True,
)
assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"computed_order": exp_order,
},
},
"events": [],
"final": True,
}
events=[],
final=True,
)


@pytest.mark.asyncio
Expand Down
Loading