From 6a4c2a1b9e0bbe0521a09bdc842cc84f64c42a76 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 27 Nov 2024 13:58:12 -0800 Subject: [PATCH] Avoid double JSON encode/decode for socket.io socket.io (python and js) already has a built in mechanism for JSON encoding and decoding messages over the websocket. To use it, we pass a custom `json` namespace which uses `format.json_dumps` (leveraging reflex serializers) to encode the messages. This avoids sending a JSON-encoded string of JSON over the wire, and reduces the number of serialization/deserialization passes over the message data. The side benefit is that debugging websocket messages in browser tools displays the parsed JSON hierarchy and is much easier to work with. --- reflex/.templates/web/utils/state.js | 6 +- reflex/app.py | 9 ++- reflex/utils/format.py | 8 ++- tests/units/test_state.py | 92 ++++++++++++++++------------ 4 files changed, 71 insertions(+), 44 deletions(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index e14c669f5f..bf6705f9c1 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -300,7 +300,7 @@ export const applyEvent = async (event, socket) => { if (socket) { socket.emit( "event", - JSON.stringify(event, (k, v) => (v === undefined ? null : v)) + event, ); return true; } @@ -407,6 +407,8 @@ export const connect = async ( transports: transports, autoUnref: false, }); + // Ensure undefined fields in events are sent as null instead of removed + socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v) function checkVisibility() { if (document.visibilityState === "visible") { @@ -444,7 +446,7 @@ export const connect = async ( // On each received message, queue the updates and events. socket.current.on("event", async (message) => { - const update = JSON5.parse(message); + const update = message; for (const substate in update.delta) { dispatch[substate](update.delta[substate]); } diff --git a/reflex/app.py b/reflex/app.py index fc8efb4201..1153066d07 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -17,6 +17,7 @@ import traceback from datetime import datetime from pathlib import Path +from types import SimpleNamespace from typing import ( TYPE_CHECKING, Any, @@ -362,6 +363,10 @@ def _setup_state(self) -> None: max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, ping_interval=constants.Ping.INTERVAL, ping_timeout=constants.Ping.TIMEOUT, + json=SimpleNamespace( + dumps=staticmethod(format.json_dumps), + loads=staticmethod(json.loads), + ), ) elif getattr(self.sio, "async_mode", "") != "asgi": raise RuntimeError( @@ -1507,7 +1512,7 @@ async def emit_update(self, update: StateUpdate, sid: str) -> None: """ # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( - self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) + self.emit(str(constants.SocketEvent.EVENT), update, to=sid) ) async def on_event(self, sid, data): @@ -1520,7 +1525,7 @@ async def on_event(self, sid, data): sid: The Socket.IO session id. data: The event data. """ - fields = json.loads(data) + fields = data # Get the event. event = Event( **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")} diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 1b3d1740fe..b006f0927a 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -664,18 +664,22 @@ def format_library_name(library_fullname: str): return lib -def json_dumps(obj: Any) -> str: +def json_dumps(obj: Any, **kwargs) -> str: """Takes an object and returns a jsonified string. Args: obj: The object to be serialized. + kwargs: Additional keyword arguments to pass to json.dumps. Returns: A string """ from reflex.utils import serializers - return json.dumps(obj, ensure_ascii=False, default=serializers.serialize) + kwargs.setdefault("ensure_ascii", False) + kwargs.setdefault("default", serializers.serialize) + + return json.dumps(obj, **kwargs) def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 45c021bd82..37b6a29c00 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1837,6 +1837,24 @@ async def _coro_waiter(): assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 +class CopyingAsyncMock(AsyncMock): + """An AsyncMock, but deepcopy the args and kwargs first.""" + + def __call__(self, *args, **kwargs): + """Call the mock. + + Args: + args: the arguments passed to the mock + kwargs: the keyword arguments passed to the mock + + Returns: + The result of the mock call + """ + args = copy.deepcopy(args) + kwargs = copy.deepcopy(kwargs) + return super().__call__(*args, **kwargs) + + @pytest.fixture(scope="function") def mock_app_simple(monkeypatch) -> rx.App: """Simple Mock app fixture. @@ -1853,7 +1871,7 @@ def mock_app_simple(monkeypatch) -> rx.App: setattr(app_module, CompileVars.APP, app) app.state = TestState - app.event_namespace.emit = AsyncMock() # type: ignore + app.event_namespace.emit = CopyingAsyncMock() # type: ignore def _mock_get_app(*args, **kwargs): return app_module @@ -1957,21 +1975,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): mock_app.event_namespace.emit.assert_called_once() mcall = mock_app.event_namespace.emit.mock_calls[0] assert mcall.args[0] == str(SocketEvent.EVENT) - assert json.loads(mcall.args[1]) == dataclasses.asdict( - StateUpdate( - delta={ - parent_state.get_full_name(): { - "upper": "", - "sum": 3.14, - }, - grandchild_state.get_full_name(): { - "value2": "42", - }, - GrandchildState3.get_full_name(): { - "computed": "", - }, - } - ) + assert mcall.args[1] == StateUpdate( + delta={ + parent_state.get_full_name(): { + "upper": "", + "sum": 3.14, + }, + grandchild_state.get_full_name(): { + "value2": "42", + }, + GrandchildState3.get_full_name(): { + "computed": "", + }, + } ) assert mcall.kwargs["to"] == grandchild_state.router.session.session_id @@ -2149,51 +2165,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): assert mock_app.event_namespace is not None emit_mock = mock_app.event_namespace.emit - first_ws_message = json.loads(emit_mock.mock_calls[0].args[1]) + first_ws_message = emit_mock.mock_calls[0].args[1] assert ( - first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router") + first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router") is not None ) - assert first_ws_message == { - "delta": { + assert first_ws_message == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "order": ["background_task:start"], "computed_order": ["background_task:start"], } }, - "events": [], - "final": True, - } + events=[], + final=True, + ) for call in emit_mock.mock_calls[1:5]: - assert json.loads(call.args[1]) == { - "delta": { + assert call.args[1] == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "computed_order": ["background_task:start"], } }, - "events": [], - "final": True, - } - assert json.loads(emit_mock.mock_calls[-2].args[1]) == { - "delta": { + events=[], + final=True, + ) + assert emit_mock.mock_calls[-2].args[1] == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "order": exp_order, "computed_order": exp_order, "dict_list": {}, } }, - "events": [], - "final": True, - } - assert json.loads(emit_mock.mock_calls[-1].args[1]) == { - "delta": { + events=[], + final=True, + ) + assert emit_mock.mock_calls[-1].args[1] == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "computed_order": exp_order, }, }, - "events": [], - "final": True, - } + events=[], + final=True, + ) @pytest.mark.asyncio