diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2 index 2428cfa9d2..b2c64cfd08 100644 --- a/reflex/.templates/jinja/web/utils/context.js.jinja2 +++ b/reflex/.templates/jinja/web/utils/context.js.jinja2 @@ -23,6 +23,8 @@ export const clientStorage = {{ client_storage|json_dumps }} export const clientStorage = {} {% endif %} +export const main_state_name = "{{const.main_state_name}}" +export const update_vars_internal = "{{const.update_vars_internal}}" {% if state_name %} export const state_name = "{{state_name}}" diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index e14c669f5f..518b2ce397 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -12,6 +12,8 @@ import { onLoadInternalEvent, state_name, exception_state_name, + main_state_name, + update_vars_internal, } from "$/utils/context.js"; import debounce from "$/utils/helpers/debounce"; import throttle from "$/utils/helpers/throttle"; @@ -117,7 +119,7 @@ export const isStateful = () => { if (event_queue.length === 0) { return false; } - return event_queue.some((event) => event.name.startsWith("reflex___state")); + return event_queue.some(event => event.name.startsWith(main_state_name)); }; /** @@ -822,7 +824,7 @@ export const useEventLoop = ( const vars = {}; vars[storage_to_state_map[e.key]] = e.newValue; const event = Event( - `${state_name}.reflex___state____update_vars_internal_state.update_vars_internal`, + `${state_name}.${update_vars_internal}`, { vars: vars } ); addEvents([event], e); @@ -836,7 +838,7 @@ export const useEventLoop = ( // Route after the initial page hydration. useEffect(() => { const change_start = () => { - const main_state_dispatch = dispatch["reflex___state____state"]; + const main_state_dispatch = dispatch[main_state_name]; if (main_state_dispatch !== undefined) { main_state_dispatch({ is_hydrated: false }); } diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 9f81f319d4..4122a0938a 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -34,7 +34,7 @@ def _compile_document_root(root: Component) -> str: Returns: The compiled document root. """ - return templates.DOCUMENT_ROOT.render( + return templates.document_root().render( imports=utils.compile_imports(root._get_all_imports()), document=root.render(), ) @@ -72,7 +72,7 @@ def _compile_app(app_root: Component) -> str: ("utils_state", f"$/{constants.Dirs.UTILS}/state"), ] - return templates.APP_ROOT.render( + return templates.app_root().render( imports=utils.compile_imports(app_root._get_all_imports()), custom_codes=app_root._get_all_custom_code(), hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()}, @@ -90,7 +90,7 @@ def _compile_theme(theme: str) -> str: Returns: The compiled theme. """ - return templates.THEME.render(theme=theme) + return templates.theme().render(theme=theme) def _compile_contexts(state: Optional[Type[BaseState]], theme: Component | None) -> str: @@ -109,7 +109,7 @@ def _compile_contexts(state: Optional[Type[BaseState]], theme: Component | None) last_compiled_time = str(datetime.now()) return ( - templates.CONTEXT.render( + templates.context().render( initial_state=utils.compile_state(state), state_name=state.get_name(), client_storage=utils.compile_client_storage(state), @@ -118,7 +118,7 @@ def _compile_contexts(state: Optional[Type[BaseState]], theme: Component | None) default_color_mode=appearance, ) if state - else templates.CONTEXT.render( + else templates.context().render( is_dev_mode=not is_prod_mode(), default_color_mode=appearance, last_compiled_time=last_compiled_time, @@ -145,7 +145,7 @@ def _compile_page( # Compile the code to render the component. kwargs = {"state_name": state.get_name()} if state is not None else {} - return templates.PAGE.render( + return templates.page().render( imports=imports, dynamic_imports=component._get_all_dynamic_imports(), custom_codes=component._get_all_custom_code(), @@ -201,7 +201,7 @@ def _compile_root_stylesheet(stylesheets: list[str]) -> str: ) stylesheet = f"../{constants.Dirs.PUBLIC}/{stylesheet.strip('/')}" sheets.append(stylesheet) if stylesheet not in sheets else None - return templates.STYLE.render(stylesheets=sheets) + return templates.style().render(stylesheets=sheets) def _compile_component(component: Component | StatefulComponent) -> str: @@ -213,7 +213,7 @@ def _compile_component(component: Component | StatefulComponent) -> str: Returns: The compiled component. """ - return templates.COMPONENT.render(component=component) + return templates.component().render(component=component) def _compile_components( @@ -241,7 +241,7 @@ def _compile_components( # Compile the components page. return ( - templates.COMPONENTS.render( + templates.components().render( imports=utils.compile_imports(imports), components=component_renders, ), @@ -319,7 +319,7 @@ def get_shared_components_recursive(component: BaseComponent): f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None ) - return templates.STATEFUL_COMPONENTS.render( + return templates.stateful_components().render( imports=utils.compile_imports(all_imports), memoized_code="\n".join(rendered_components), ) @@ -336,7 +336,7 @@ def _compile_tailwind( Returns: The compiled Tailwind config. """ - return templates.TAILWIND_CONFIG.render( + return templates.tailwind_config().render( **config, ) diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index c868a0cbb7..886e4c1c95 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -11,6 +11,13 @@ class ReflexJinjaEnvironment(Environment): def __init__(self) -> None: """Set default environment.""" + from reflex.state import ( + FrontendEventExceptionState, + OnLoadInternalState, + State, + UpdateVarsInternalState, + ) + extensions = ["jinja2.ext.debug"] super().__init__( extensions=extensions, @@ -42,9 +49,10 @@ def __init__(self) -> None: "set_color_mode": constants.ColorMode.SET, "use_color_mode": constants.ColorMode.USE, "hydrate": constants.CompileVars.HYDRATE, - "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL, - "update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL, - "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL, + "main_state_name": State.get_name(), + "on_load_internal": f"{OnLoadInternalState.get_name()}.on_load_internal", + "update_vars_internal": f"{UpdateVarsInternalState.get_name()}.update_vars_internal", + "frontend_exception_state": FrontendEventExceptionState.get_full_name(), } @@ -60,61 +68,172 @@ def get_template(name: str) -> Template: return ReflexJinjaEnvironment().get_template(name=name) -# Template for the Reflex config file. -RXCONFIG = get_template("app/rxconfig.py.jinja2") +def rxconfig(): + """Template for the Reflex config file. + + Returns: + Template: The template for the Reflex config file. + """ + return get_template("app/rxconfig.py.jinja2") + + +def document_root(): + """Code to render a NextJS Document root. + + Returns: + Template: The template for the NextJS Document root. + """ + return get_template("web/pages/_document.js.jinja2") -# Code to render a NextJS Document root. -DOCUMENT_ROOT = get_template("web/pages/_document.js.jinja2") -# Code to render NextJS App root. -APP_ROOT = get_template("web/pages/_app.js.jinja2") +def app_root(): + """Code to render NextJS App root. -# Template for the theme file. -THEME = get_template("web/utils/theme.js.jinja2") + Returns: + Template: The template for the NextJS App root. + """ + return get_template("web/pages/_app.js.jinja2") -# Template for the context file. -CONTEXT = get_template("web/utils/context.js.jinja2") -# Template for Tailwind config. -TAILWIND_CONFIG = get_template("web/tailwind.config.js.jinja2") +def theme(): + """Template for the theme file. -# Template to render a component tag. -COMPONENT = get_template("web/pages/component.js.jinja2") + Returns: + Template: The template for the theme file. + """ + return get_template("web/utils/theme.js.jinja2") -# Code to render a single NextJS page. -PAGE = get_template("web/pages/index.js.jinja2") -# Code to render the custom components page. -COMPONENTS = get_template("web/pages/custom_component.js.jinja2") +def context(): + """Template for the context file. -# Code to render Component instances as part of StatefulComponent -STATEFUL_COMPONENT = get_template("web/pages/stateful_component.js.jinja2") + Returns: + Template: The template for the context file. + """ + return get_template("web/utils/context.js.jinja2") -# Code to render StatefulComponent to an external file to be shared -STATEFUL_COMPONENTS = get_template("web/pages/stateful_components.js.jinja2") -# Sitemap config file. -SITEMAP_CONFIG = "module.exports = {config}".format +def tailwind_config(): + """Template for Tailwind config. -# Code to render the root stylesheet. -STYLE = get_template("web/styles/styles.css.jinja2") + Returns: + Template: The template for the Tailwind config + """ + return get_template("web/tailwind.config.js.jinja2") -# Code that generate the package json file -PACKAGE_JSON = get_template("web/package.json.jinja2") -# Code that generate the pyproject.toml file for custom components. -CUSTOM_COMPONENTS_PYPROJECT_TOML = get_template( - "custom_components/pyproject.toml.jinja2" -) +def component(): + """Template to render a component tag. -# Code that generates the README file for custom components. -CUSTOM_COMPONENTS_README = get_template("custom_components/README.md.jinja2") + Returns: + Template: The template for the component tag. + """ + return get_template("web/pages/component.js.jinja2") -# Code that generates the source file for custom components. -CUSTOM_COMPONENTS_SOURCE = get_template("custom_components/src.py.jinja2") -# Code that generates the init file for custom components. -CUSTOM_COMPONENTS_INIT_FILE = get_template("custom_components/__init__.py.jinja2") +def page(): + """Code to render a single NextJS page. + + Returns: + Template: The template for the NextJS page. + """ + return get_template("web/pages/index.js.jinja2") + + +def components(): + """Code to render the custom components page. + + Returns: + Template: The template for the custom components page. + """ + return get_template("web/pages/custom_component.js.jinja2") + + +def stateful_component(): + """Code to render Component instances as part of StatefulComponent. + + Returns: + Template: The template for the StatefulComponent. + """ + return get_template("web/pages/stateful_component.js.jinja2") + + +def stateful_components(): + """Code to render StatefulComponent to an external file to be shared. + + Returns: + Template: The template for the StatefulComponent. + """ + return get_template("web/pages/stateful_components.js.jinja2") + -# Code that generates the demo app main py file for testing custom components. -CUSTOM_COMPONENTS_DEMO_APP = get_template("custom_components/demo_app.py.jinja2") +def sitemap_config(): + """Sitemap config file. + + Returns: + Template: The template for the sitemap config file. + """ + return "module.exports = {config}".format + + +def style(): + """Code to render the root stylesheet. + + Returns: + Template: The template for the root stylesheet + """ + return get_template("web/styles/styles.css.jinja2") + + +def package_json(): + """Code that generate the package json file. + + Returns: + Template: The template for the package json file + """ + return get_template("web/package.json.jinja2") + + +def custom_components_pyproject_toml(): + """Code that generate the pyproject.toml file for custom components. + + Returns: + Template: The template for the pyproject.toml file + """ + return get_template("custom_components/pyproject.toml.jinja2") + + +def custom_components_readme(): + """Code that generates the README file for custom components. + + Returns: + Template: The template for the README file + """ + return get_template("custom_components/README.md.jinja2") + + +def custom_components_source(): + """Code that generates the source file for custom components. + + Returns: + Template: The template for the source file + """ + return get_template("custom_components/src.py.jinja2") + + +def custom_components_init(): + """Code that generates the init file for custom components. + + Returns: + Template: The template for the init file + """ + return get_template("custom_components/__init__.py.jinja2") + + +def custom_components_demo_app(): + """Code that generates the demo app main py file for testing custom components. + + Returns: + Template: The template for the demo app main py file + """ + return get_template("custom_components/demo_app.py.jinja2") diff --git a/reflex/components/component.py b/reflex/components/component.py index 18dedbf0e5..bf066df66e 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -25,7 +25,7 @@ import reflex.state from reflex.base import Base -from reflex.compiler.templates import STATEFUL_COMPONENT +from reflex.compiler.templates import stateful_component from reflex.components.core.breakpoints import Breakpoints from reflex.components.dynamic import load_dynamic_serializer from reflex.components.tags import Tag @@ -2162,7 +2162,7 @@ def _render_stateful_code( component.event_triggers[event_trigger] = memo_trigger # Render the code for this component and hooks. - return STATEFUL_COMPONENT.render( + return stateful_component().render( tag_name=tag_name, memo_trigger_hooks=memo_trigger_hooks, component=component, diff --git a/reflex/components/dynamic.py b/reflex/components/dynamic.py index ce59c3f301..e78cf08804 100644 --- a/reflex/components/dynamic.py +++ b/reflex/components/dynamic.py @@ -80,7 +80,7 @@ def make_component(component: Component) -> str: ) rendered_components[ - templates.STATEFUL_COMPONENT.render( + templates.stateful_component().render( tag_name="MySSRComponent", memo_trigger_hooks=[], component=component, @@ -101,10 +101,14 @@ def make_component(component: Component) -> str: else: imports[lib] = names - module_code_lines = templates.STATEFUL_COMPONENTS.render( - imports=utils.compile_imports(imports), - memoized_code="\n".join(rendered_components), - ).splitlines()[1:] + module_code_lines = ( + templates.stateful_components() + .render( + imports=utils.compile_imports(imports), + memoized_code="\n".join(rendered_components), + ) + .splitlines()[1:] + ) # Rewrite imports from `/` to destructure from window for ix, line in enumerate(module_code_lines[:]): diff --git a/reflex/config.py b/reflex/config.py index 88230cefec..a882ce9d21 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -2,7 +2,6 @@ from __future__ import annotations -import dataclasses import enum import importlib import inspect @@ -15,6 +14,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, Generic, List, @@ -149,28 +149,6 @@ def get_url(self) -> str: return f"{self.engine}://{path}/{self.database}" -def get_default_value_for_field(field: dataclasses.Field) -> Any: - """Get the default value for a field. - - Args: - field: The field. - - Returns: - The default value. - - Raises: - ValueError: If no default value is found. - """ - if field.default != dataclasses.MISSING: - return field.default - elif field.default_factory != dataclasses.MISSING: - return field.default_factory() - else: - raise ValueError( - f"Missing value for environment variable {field.name} and no default value found" - ) - - # TODO: Change all interpret_.* signatures to value: str, field: dataclasses.Field once we migrate rx.Config to dataclasses def interpret_boolean_env(value: str, field_name: str) -> bool: """Interpret a boolean environment variable value. @@ -314,26 +292,47 @@ def interpret_env_var_value( T = TypeVar("T") +ENV_VAR_DEFAULT_FACTORY = Callable[[], T] + class EnvVar(Generic[T]): """Environment variable.""" name: str default: Any + default_factory: Optional[ENV_VAR_DEFAULT_FACTORY] type_: T - def __init__(self, name: str, default: Any, type_: T) -> None: + def __init__( + self, + name: str, + default: Any, + default_factory: Optional[ENV_VAR_DEFAULT_FACTORY], + type_: T, + ) -> None: """Initialize the environment variable. Args: name: The environment variable name. default: The default value. + default_factory: The default factory. type_: The type of the value. """ self.name = name self.default = default + self.default_factory = default_factory self.type_ = type_ + def get_default(self) -> T: + """Get the default value. + + Returns: + The default value. + """ + if self.default_factory is not None: + return self.default_factory() + return self.default + def interpret(self, value: str) -> T: """Interpret the environment variable value. @@ -373,7 +372,7 @@ def get(self) -> T: env_value = self.getenv() if env_value is not None: return env_value - return self.default + return self.get_default() def set(self, value: T | None) -> None: """Set the environment variable. None unsets the variable. @@ -394,16 +393,24 @@ class env_var: # type: ignore name: str default: Any + default_factory: Optional[ENV_VAR_DEFAULT_FACTORY] internal: bool = False - def __init__(self, default: Any, internal: bool = False) -> None: + def __init__( + self, + default: Any = None, + default_factory: Optional[ENV_VAR_DEFAULT_FACTORY] = None, + internal: bool = False, + ) -> None: """Initialize the descriptor. Args: default: The default value. + default_factory: The default factory. internal: Whether the environment variable is reflex internal. """ self.default = default + self.default_factory = default_factory self.internal = internal def __set_name__(self, owner, name): @@ -429,22 +436,30 @@ def __get__(self, instance, owner): env_name = self.name if self.internal: env_name = f"__{env_name}" - return EnvVar(name=env_name, default=self.default, type_=type_) + return EnvVar( + name=env_name, + default=self.default, + type_=type_, + default_factory=self.default_factory, + ) + if TYPE_CHECKING: -if TYPE_CHECKING: + def __new__( + cls, + default: Optional[T] = None, + default_factory: Optional[ENV_VAR_DEFAULT_FACTORY[T]] = None, + internal: bool = False, + ) -> EnvVar[T]: + """Create a new EnvVar instance. - def env_var(default, internal=False) -> EnvVar: - """Typing helper for the env_var descriptor. - - Args: - default: The default value. - internal: Whether the environment variable is reflex internal. - - Returns: - The EnvVar instance. - """ - return default + Args: + cls: The class. + default: The default value. + default_factory: The default factory. + internal: Whether the environment variable is reflex internal. + """ + ... class PathExistsFlag: @@ -465,6 +480,16 @@ class PerformanceMode(enum.Enum): class EnvironmentVariables: """Environment variables class to instantiate environment variables.""" + def __init__(self): + """Initialize the environment variables. + + Raises: + NotImplementedError: Always. + """ + raise NotImplementedError( + f"{type(self).__name__} is a class singleton and not meant to be instantiated." + ) + # Whether to use npm over bun to install frontend packages. REFLEX_USE_NPM: EnvVar[bool] = env_var(False) @@ -564,8 +589,13 @@ class EnvironmentVariables: # The maximum size of the reflex state in kilobytes. REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000) + # Whether to minify state names. Default to true in prod mode and false otherwise. + REFLEX_MINIFY_STATES: EnvVar[Optional[bool]] = env_var( + default_factory=lambda: environment.REFLEX_ENV_MODE.get() == constants.Env.PROD + ) + -environment = EnvironmentVariables() +environment = EnvironmentVariables class Config(Base): diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index b7ffef1613..d2966cd21b 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -61,18 +61,6 @@ class CompileVars(SimpleNamespace): CONNECT_ERROR = "connectErrors" # The name of the function for converting a dict to an event. TO_EVENT = "Event" - # The name of the internal on_load event. - ON_LOAD_INTERNAL = "reflex___state____on_load_internal_state.on_load_internal" - # The name of the internal event to update generic state vars. - UPDATE_VARS_INTERNAL = ( - "reflex___state____update_vars_internal_state.update_vars_internal" - ) - # The name of the frontend event exception state - FRONTEND_EXCEPTION_STATE = "reflex___state____frontend_event_exception_state" - # The full name of the frontend exception state - FRONTEND_EXCEPTION_STATE_FULL = ( - f"reflex___state____state.{FRONTEND_EXCEPTION_STATE}" - ) class PageNames(SimpleNamespace): diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py index 6be64ae2d3..1406e37bc1 100644 --- a/reflex/custom_components/custom_components.py +++ b/reflex/custom_components/custom_components.py @@ -65,7 +65,7 @@ def _create_package_config(module_name: str, package_name: str): pyproject = Path(CustomComponents.PYPROJECT_TOML) pyproject.write_text( - templates.CUSTOM_COMPONENTS_PYPROJECT_TOML.render( + templates.custom_components_pyproject_toml().render( module_name=module_name, package_name=package_name, reflex_version=constants.Reflex.VERSION, @@ -106,7 +106,7 @@ def _create_readme(module_name: str, package_name: str): readme = Path(CustomComponents.PACKAGE_README) readme.write_text( - templates.CUSTOM_COMPONENTS_README.render( + templates.custom_components_readme().render( module_name=module_name, package_name=package_name, ) @@ -129,14 +129,14 @@ def _write_source_and_init_py( module_path = custom_component_src_dir / f"{module_name}.py" module_path.write_text( - templates.CUSTOM_COMPONENTS_SOURCE.render( + templates.custom_components_source().render( component_class_name=component_class_name, module_name=module_name ) ) init_path = custom_component_src_dir / CustomComponents.INIT_FILE init_path.write_text( - templates.CUSTOM_COMPONENTS_INIT_FILE.render(module_name=module_name) + templates.custom_components_init.render(module_name=module_name) ) @@ -164,7 +164,7 @@ def _populate_demo_app(name_variants: NameVariants): # This source file is rendered using jinja template file. with open(f"{demo_app_name}/{demo_app_name}.py", "w") as f: f.write( - templates.CUSTOM_COMPONENTS_DEMO_APP.render( + templates.custom_components_demo_app().render( custom_component_module_dir=name_variants.custom_component_module_dir, module_name=name_variants.module_name, ) diff --git a/reflex/reflex.py b/reflex/reflex.py index 9781c393ad..9c3a96c89d 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -2,6 +2,7 @@ from __future__ import annotations +# WARNING: do not import any modules that contain rx.State subclasses here import atexit import os from pathlib import Path @@ -16,7 +17,6 @@ from reflex import constants from reflex.config import environment, get_config from reflex.custom_components.custom_components import custom_components_cli -from reflex.state import reset_disk_state_manager from reflex.utils import console, telemetry # Disable typer+rich integration for help panels @@ -134,14 +134,16 @@ def _run( loglevel: constants.LogLevel = config.loglevel, ): """Run the app in the given directory.""" + # Set env mode in the environment + # This must be set before importing modules that contain rx.State subclasses + environment.REFLEX_ENV_MODE.set(env) + + from reflex.state import reset_disk_state_manager from reflex.utils import build, exec, prerequisites, processes # Set the log level. console.set_log_level(loglevel) - # Set env mode in the environment - environment.REFLEX_ENV_MODE.set(env) - # Show system info exec.output_system_info() diff --git a/reflex/state.py b/reflex/state.py index 95f7f64f68..c1e423f9ea 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -301,6 +301,61 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): ) +# Keep track of all state instances to calculate minified state names +state_count: int = 0 + +minified_state_names: Dict[str, str] = {} + + +def next_minified_state_name() -> str: + """Get the next minified state name. + + Returns: + The next minified state name. + """ + global state_count + num = state_count + + # All possible chars for minified state name + chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ$_" + base = len(chars) + state_name = "" + + if num == 0: + state_name = chars[0] + + while num > 0: + state_name = chars[num % base] + state_name + num = num // base + + state_count += 1 + + return state_name + + +def get_minified_state_name(state_name: str) -> str: + """Generate a minified state name. + + Args: + state_name: The state name to minify. + + Returns: + The minified state name. + + Raises: + ValueError: If no more minified state names are available + """ + if state_name in minified_state_names: + return minified_state_names[state_name] + + while name := next_minified_state_name(): + if name in minified_state_names.values(): + continue + minified_state_names[state_name] = name + return name + raise ValueError("No more minified state names available") + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -896,7 +951,10 @@ def get_name(cls) -> str: The name of the state. """ module = cls.__module__.replace(".", "___") - return format.to_snake_case(f"{module}___{cls.__name__}") + state_name = format.to_snake_case(f"{module}___{cls.__name__}") + if environment.REFLEX_MINIFY_STATES.get(): + return get_minified_state_name(state_name) + return state_name @classmethod @functools.lru_cache() diff --git a/reflex/testing.py b/reflex/testing.py index 9ddb03504a..ac0b563fde 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -46,12 +46,15 @@ from reflex.config import environment from reflex.state import ( BaseState, + State, StateManager, StateManagerDisk, StateManagerMemory, StateManagerRedis, + minified_state_names, reload_state_module, ) +from reflex.utils.types import override try: from selenium import webdriver # pyright: ignore [reportMissingImports] @@ -141,7 +144,7 @@ def create( Callable[[], None] | types.ModuleType | str | functools.partial[Any] ] = None, app_name: Optional[str] = None, - ) -> "AppHarness": + ) -> AppHarness: """Create an AppHarness instance at root. Args: @@ -192,9 +195,12 @@ def get_state_name(self, state_cls_name: str) -> str: Returns: The state name """ - return reflex.utils.format.to_snake_case( + state_name = reflex.utils.format.to_snake_case( f"{self.app_name}___{self.app_name}___" + state_cls_name ) + if environment.REFLEX_MINIFY_STATES.get(): + return minified_state_names.get(state_name, state_name) + return state_name def get_full_state_name(self, path: List[str]) -> str: """Get the full state name for the given state class name. @@ -207,7 +213,7 @@ def get_full_state_name(self, path: List[str]) -> str: """ # NOTE: using State.get_name() somehow causes trouble here # path = [State.get_name()] + [self.get_state_name(p) for p in path] - path = ["reflex___state____state"] + [self.get_state_name(p) for p in path] + path = [State.get_name()] + [self.get_state_name(p) for p in path] return ".".join(path) def _get_globals_from_signature(self, func: Any) -> dict[str, Any]: @@ -412,7 +418,7 @@ def consume_frontend_output(): self.frontend_output_thread = threading.Thread(target=consume_frontend_output) self.frontend_output_thread.start() - def start(self) -> "AppHarness": + def start(self) -> AppHarness: """Start the backend in a new thread and dev frontend as a separate process. Returns: @@ -442,7 +448,7 @@ def get_app_global_source(key, value): return f"{key} = {value!r}" return inspect.getsource(value) - def __enter__(self) -> "AppHarness": + def __enter__(self) -> AppHarness: """Contextmanager protocol for `start()`. Returns: @@ -921,6 +927,7 @@ def _run_frontend(self): ) self.frontend_server.serve_forever() + @override def _start_frontend(self): # Set up the frontend. with chdir(self.app_path): @@ -932,17 +939,19 @@ def _start_frontend(self): zipping=False, frontend=True, backend=False, - loglevel=reflex.constants.LogLevel.INFO, + loglevel=reflex.constants.base.LogLevel.INFO, ) self.frontend_thread = threading.Thread(target=self._run_frontend) self.frontend_thread.start() + @override def _wait_frontend(self): - self._poll_for(lambda: self.frontend_server is not None) + _ = self._poll_for(lambda: self.frontend_server is not None) if self.frontend_server is None or not self.frontend_server.socket.fileno(): raise RuntimeError("Frontend did not start") + @override def _start_backend(self): if self.app_instance is None: raise RuntimeError("App was not initialized.") @@ -959,12 +968,14 @@ def _start_backend(self): self.backend_thread = threading.Thread(target=self.backend.run) self.backend_thread.start() + @override def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket: try: return super()._poll_for_servers(timeout) finally: environment.REFLEX_SKIP_COMPILE.set(None) + @override def stop(self): """Stop the frontend python webserver.""" super().stop() diff --git a/reflex/utils/build.py b/reflex/utils/build.py index 14709d99ce..8fae83c45a 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -44,7 +44,7 @@ def generate_sitemap_config(deploy_url: str, export=False): config = json.dumps(config) sitemap = prerequisites.get_web_dir() / constants.Next.SITEMAP_CONFIG_FILE - sitemap.write_text(templates.SITEMAP_CONFIG(config=config)) + sitemap.write_text(templates.sitemap_config()(config=config)) def _zip( diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index ec79b3297b..35eb3234fe 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -441,7 +441,7 @@ def create_config(app_name: str): config_name = f"{re.sub(r'[^a-zA-Z]', '', app_name).capitalize()}Config" with open(constants.Config.FILE, "w") as f: console.debug(f"Creating {constants.Config.FILE}") - f.write(templates.RXCONFIG.render(app_name=app_name, config_name=config_name)) + f.write(templates.rxconfig().render(app_name=app_name, config_name=config_name)) def initialize_gitignore( @@ -611,7 +611,7 @@ def initialize_web_directory(): def _compile_package_json(): - return templates.PACKAGE_JSON.render( + return templates.package_json().render( scripts={ "dev": constants.PackageJson.Commands.DEV, "export": constants.PackageJson.Commands.EXPORT, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f7b825f162..75dbd7f77c 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -3,10 +3,13 @@ import os import re from pathlib import Path +from typing import Generator, Type import pytest +import reflex.constants from reflex.config import environment +from reflex.constants.base import Env from reflex.testing import AppHarness, AppHarnessProd DISPLAY = None @@ -64,15 +67,30 @@ def pytest_exception_interact(node, call, report): @pytest.fixture( - scope="session", params=[AppHarness, AppHarnessProd], ids=["dev", "prod"] + scope="session", + params=[ + AppHarness, + AppHarnessProd, + ], + ids=[ + reflex.constants.Env.DEV.value, + reflex.constants.Env.PROD.value, + ], ) -def app_harness_env(request): +def app_harness_env( + request: pytest.FixtureRequest, +) -> Generator[Type[AppHarness], None, None]: """Parametrize the AppHarness class to use for the test, either dev or prod. Args: request: The pytest fixture request object. - Returns: + Yields: The AppHarness class to use for the test. """ - return request.param + harness: Type[AppHarness] = request.param + if issubclass(harness, AppHarnessProd): + environment.REFLEX_ENV_MODE.set(Env.PROD) + yield harness + if isinstance(harness, AppHarnessProd): + environment.REFLEX_ENV_MODE.set(None) diff --git a/tests/integration/test_computed_vars.py b/tests/integration/test_computed_vars.py index a414581736..03aaf18b42 100644 --- a/tests/integration/test_computed_vars.py +++ b/tests/integration/test_computed_vars.py @@ -106,7 +106,6 @@ def index() -> rx.Component: ), ) - # raise Exception(State.count3._deps(objclass=State)) app = rx.App() app.add_page(index) diff --git a/tests/integration/test_minified_states.py b/tests/integration/test_minified_states.py new file mode 100644 index 0000000000..da63b6f7c4 --- /dev/null +++ b/tests/integration/test_minified_states.py @@ -0,0 +1,161 @@ +"""Integration tests for minified state names.""" + +from __future__ import annotations + +from functools import partial +from typing import Generator, Optional, Type + +import pytest +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver + +from reflex.config import environment +from reflex.testing import AppHarness, AppHarnessProd + + +def MinifiedStatesApp(minify: bool | None) -> None: + """A test app for minified state names. + + Args: + minify: whether to minify state names + """ + import reflex as rx + + class MinifiedState(rx.State): + """State for the MinifiedStatesApp app.""" + + pass + + app = rx.App() + + def index(): + return rx.vstack( + rx.input( + value=MinifiedState.router.session.client_token, + is_read_only=True, + id="token", + ), + rx.text(f"minify: {minify}", id="minify"), + rx.text(MinifiedState.get_name(), id="state_name"), + rx.text(MinifiedState.get_full_name(), id="state_full_name"), + ) + + app.add_page(index) + + +@pytest.fixture( + params=[ + pytest.param(False), + pytest.param(True), + pytest.param(None), + ], +) +def minify_state_env( + request: pytest.FixtureRequest, +) -> Generator[Optional[bool], None, None]: + """Set the environment variable to minify state names. + + Args: + request: pytest fixture request + + Yields: + minify_states: whether to minify state names + """ + minify_states: Optional[bool] = request.param + environment.REFLEX_MINIFY_STATES.set(minify_states) + yield minify_states + environment.REFLEX_MINIFY_STATES.set(None) + + +@pytest.fixture +def test_app( + app_harness_env: Type[AppHarness], + tmp_path_factory: pytest.TempPathFactory, + minify_state_env: Optional[bool], +) -> Generator[AppHarness, None, None]: + """Start MinifiedStatesApp app at tmp_path via AppHarness. + + Args: + app_harness_env: either AppHarness (dev) or AppHarnessProd (prod) + tmp_path_factory: pytest tmp_path_factory fixture + minify_state_env: need to request this fixture to set env before the app starts + + Yields: + running AppHarness instance + + """ + name = f"testminifiedstates_{app_harness_env.__name__.lower()}" + with app_harness_env.create( + root=tmp_path_factory.mktemp(name), + app_name=name, + app_source=partial(MinifiedStatesApp, minify=minify_state_env), # type: ignore + ) as harness: + yield harness + + +@pytest.fixture +def driver(test_app: AppHarness) -> Generator[WebDriver, None, None]: + """Get an instance of the browser open to the test_app app. + + Args: + test_app: harness for MinifiedStatesApp app + + Yields: + WebDriver instance. + + """ + assert test_app.app_instance is not None, "app is not running" + driver = test_app.frontend() + try: + yield driver + finally: + driver.quit() + + +def test_minified_states( + test_app: AppHarness, + driver: WebDriver, + minify_state_env: Optional[bool], +) -> None: + """Test minified state names. + + Args: + test_app: harness for MinifiedStatesApp + driver: WebDriver instance. + minify_state_env: whether state minification is enabled by env var. + + """ + assert test_app.app_instance is not None, "app is not running" + + is_prod = isinstance(test_app, AppHarnessProd) + + # default to minifying in production + should_minify: bool = is_prod + + # env overrides default + if minify_state_env is not None: + should_minify = minify_state_env + + # get a reference to the connected client + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = test_app.poll_for_value(token_input) + assert token + + state_name_text = driver.find_element(By.ID, "state_name") + assert state_name_text + state_name = state_name_text.text + + state_full_name_text = driver.find_element(By.ID, "state_full_name") + assert state_full_name_text + _ = state_full_name_text.text + + assert test_app.app_module + module_state_prefix = test_app.app_module.__name__.replace(".", "___") + + if should_minify: + assert len(state_name) == 1 + else: + assert state_name == f"{module_state_prefix}____minified_state" diff --git a/tests/test_minify_state.py b/tests/test_minify_state.py new file mode 100644 index 0000000000..1e49a227e6 --- /dev/null +++ b/tests/test_minify_state.py @@ -0,0 +1,14 @@ +from typing import Set + +from reflex.state import next_minified_state_name + + +def test_next_minified_state_name(): + """Test that the next_minified_state_name function returns unique state names.""" + state_names: Set[str] = set() + gen = 10000 + for _ in range(gen): + state_name = next_minified_state_name() + assert state_name not in state_names + state_names.add(state_name) + assert len(state_names) == gen diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 5d3aee6c7a..df19345ed1 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1032,7 +1032,7 @@ def _dynamic_state_event(name, val, **kwargs): prev_exp_val = "" for exp_index, exp_val in enumerate(exp_vals): on_load_internal = _event( - name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}", + name=f"{state.get_full_name()}.on_load_internal", val=exp_val, ) exp_router_data = { diff --git a/tests/units/test_config.py b/tests/units/test_config.py index e5d4622bd6..308c839461 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -253,6 +253,11 @@ class TestEnv: INTERNAL: EnvVar[str] = env_var("default", internal=True) BOOLEAN: EnvVar[bool] = env_var(False) + # default_factory with other env_var as fallback + BLUBB_OR_BLA: EnvVar[str] = env_var( + default_factory=lambda: TestEnv.BLUBB.getenv() or "bla" + ) + assert TestEnv.BLUBB.get() == "default" assert TestEnv.BLUBB.name == "BLUBB" TestEnv.BLUBB.set("new") @@ -280,3 +285,15 @@ class TestEnv: assert TestEnv.BOOLEAN.get() is False TestEnv.BOOLEAN.set(None) assert "BOOLEAN" not in os.environ + + assert TestEnv.BLUBB_OR_BLA.get() == "bla" + TestEnv.BLUBB.set("new") + assert TestEnv.BLUBB_OR_BLA.get() == "new" + TestEnv.BLUBB.set(None) + assert TestEnv.BLUBB_OR_BLA.get() == "bla" + TestEnv.BLUBB_OR_BLA.set("test") + assert TestEnv.BLUBB_OR_BLA.get() == "test" + TestEnv.BLUBB.set("other") + assert TestEnv.BLUBB_OR_BLA.get() == "test" + TestEnv.BLUBB_OR_BLA.set(None) + TestEnv.BLUBB.set(None) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 45c021bd82..e79878c984 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -66,6 +66,7 @@ LOCK_EXPIRATION = 2000 if CI else 300 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4 +ON_LOAD_INTERNAL = f"{OnLoadInternalState.get_name()}.on_load_internal" formatted_router = { "session": {"client_token": "", "client_ip": "", "session_id": ""}, @@ -2818,7 +2819,7 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker): app=app, event=Event( token=token, - name=f"{state.get_name()}.{CompileVars.ON_LOAD_INTERNAL}", + name=f"{state.get_name()}.{ON_LOAD_INTERNAL}", router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, ), sid="sid", @@ -2865,7 +2866,7 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker): app=app, event=Event( token=token, - name=f"{state.get_full_name()}.{CompileVars.ON_LOAD_INTERNAL}", + name=f"{state.get_full_name()}.{ON_LOAD_INTERNAL}", router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, ), sid="sid", diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index dd1a3b3ef7..e3dae7a2be 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -271,7 +271,7 @@ def test_unsupported_literals(cls: type): ], ) def test_create_config(app_name, expected_config_name, mocker): - """Test templates.RXCONFIG is formatted with correct app name and config class name. + """Test templates.rxconfig is formatted with correct app name and config class name. Args: app_name: App name. @@ -279,9 +279,9 @@ def test_create_config(app_name, expected_config_name, mocker): mocker: Mocker object. """ mocker.patch("builtins.open") - tmpl_mock = mocker.patch("reflex.compiler.templates.RXCONFIG") + tmpl_mock = mocker.patch("reflex.compiler.templates.rxconfig") prerequisites.create_config(app_name) - tmpl_mock.render.assert_called_with( + tmpl_mock().render.assert_called_with( app_name=app_name, config_name=expected_config_name ) @@ -592,8 +592,23 @@ def test_style_prop_with_event_handler_value(callable): ) -def test_is_prod_mode() -> None: - """Test that the prod mode is correctly determined.""" +@pytest.fixture +def cleanup_reflex_env_mode(): + """Cleanup the reflex env mode. + + Yields: + None + """ + yield + environment.REFLEX_ENV_MODE.set(None) + + +def test_is_prod_mode(cleanup_reflex_env_mode: None) -> None: + """Test that the prod mode is correctly determined. + + Args: + cleanup_reflex_env_mode: Fixture to cleanup the reflex env mode. + """ environment.REFLEX_ENV_MODE.set(constants.Env.PROD) assert utils_exec.is_prod_mode() environment.REFLEX_ENV_MODE.set(None)