Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First take at reactive implementation #94

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 93 additions & 18 deletions dash_extensions/enrich.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import dash_html_components as html
import dash.dependencies as dd
import plotly

from dash.dependencies import Input, Output, State, MATCH, ALL, ALLSMALLER, _Wildcard, ClientsideFunction
from dash.dependencies import MATCH, ALL, ALLSMALLER, _Wildcard, ClientsideFunction
from dash.development.base_component import Component
from flask import session
from flask_caching.backends import FileSystemCache, RedisCache
Expand All @@ -29,9 +28,11 @@ class DashProxy(dash.Dash):
def __init__(self, *args, transforms=None, **kwargs):
super().__init__(*args, **kwargs)
self.callbacks = []
self.reactive_variables = []
self.clientside_callbacks = []
self.arg_types = [dd.Output, dd.Input, dd.State]
self.transforms = transforms if transforms is not None else []
self.layout_extension = LayoutExtension()
# Do the transform initialization.
for transform in self.transforms:
transform.init(self)
Expand Down Expand Up @@ -84,6 +85,38 @@ def clientside_callback(self, clientside_function, *args, **kwargs):
callback["f"] = clientside_function
self.clientside_callbacks.append(callback)

def _collect_reactive(self, name):
self.reactive_variables.append(name)
self.layout_extension.components.append(dcc.Store(name, "data"))

def reactive(self, *args, serverside=None, output=None, **kwargs):
# If the output is not specified, create it. Per default, use serverside if available.
serverside = True if serverside is None else serverside
if output is None:
if serverside and any([isinstance(t, ServersideOutputTransform) for t in self.transforms]):
output = ServersideOutput(None, None)
else:
output = Output(None)
# Collect the callback, delay binding of output id.
callback = self._collect_callback(output, *args, **kwargs)
self.callbacks.append(callback)

def wrapper(f):
component_id = f.__name__
output.component_id = component_id
output.component_property = "data"
self._collect_reactive(component_id)
callback["f"] = f

return wrapper

def clientside_reactive(self, name, clientside_function, *args, **kwargs):
output = Output(name, "data")
self._collect_reactive(name)
callback = self._collect_callback(output, *args, **kwargs)
callback["f"] = clientside_function
self.clientside_callbacks.append(callback)

def _register_callbacks(self, app=None):
callbacks, clientside_callbacks = self._resolve_callbacks()
app = super() if app is None else app
Expand All @@ -98,7 +131,7 @@ def _layout_value(self):
layout = self._layout() if self._layout_is_function else self._layout
for transform in self.transforms:
layout = transform.layout(layout, self._layout_is_function)
return layout
return self.layout_extension.layout(layout, self._layout_is_function)

def _setup_server(self):
"""
Expand All @@ -112,11 +145,25 @@ def _setup_server(self):
if not self.server.secret_key:
self.server.secret_key = secrets.token_urlsafe(16)

def _resolve_reactive_variables(self, callbacks):
for callback in callbacks:
for item in callback[dd.Input]:
if item.component_id in self.reactive_variables:
item.component_property = "data"
for item in callback[dd.State]:
if item.component_id in self.reactive_variables:
item.component_property = "data"
return callbacks

def _resolve_callbacks(self):
"""
This method resolves the callbacks, i.e. it applies the callback injections.
"""
callbacks, clientside_callbacks = self.callbacks, self.clientside_callbacks
# Resolve reactive variables.
callbacks = self._resolve_reactive_variables(callbacks)
clientside_callbacks = self._resolve_reactive_variables(clientside_callbacks)
# Apply transforms.
for transform in self.transforms:
callbacks, clientside_callbacks = transform.apply(callbacks, clientside_callbacks)
return callbacks, clientside_callbacks
Expand Down Expand Up @@ -193,6 +240,39 @@ def layout(self, layout, layout_is_function):
return layout



class LayoutExtension:

def __init__(self):
self.initialized = False
self.components = []

def layout(self, layout, layout_is_function):
if layout_is_function or not self.initialized:
children = _as_list(layout.children) + self.components
layout.children = children
self.initialized = True
return layout

# endregion

# region Default component property values

class Input(dd.Input):
def __init__(self, component_id, component_property=None):
component_property = "value" if component_property is None else component_property
super().__init__(component_id, component_property)

class State(dd.State):
def __init__(self, component_id, component_property=None):
component_property = "value" if component_property is None else component_property
super().__init__(component_id, component_property)

class Output(dd.Output):
def __init__(self, component_id, component_property=None):
component_property = "children" if component_property is None else component_property
super().__init__(component_id, component_property)

# endregion

# region Prefix ID transform
Expand Down Expand Up @@ -398,7 +478,7 @@ def apply(self, callbacks, clientside_callbacks):
# Group by output.
output_map = defaultdict(list)
for callback in all_callbacks:
for output in callback[Output]:
for output in callback[dd.Output]:
output_map[output].append(callback)
# Apply multiplexer where needed.
for output in output_map:
Expand All @@ -418,7 +498,7 @@ def _apply_multiplexer(self, output, callbacks):
# Create proxy element.
proxies.append(_mp_element(mp_id_escaped))
# Assign proxy element as output.
callback[Output][callback[Output].index(output)] = Output(mp_id_escaped, _mp_prop())
callback[dd.Output][callback[dd.Output].index(output)] = Output(mp_id_escaped, _mp_prop())
# Create proxy input.
inputs.append(Input(mp_id, _mp_prop()))
# Collect proxy elements to add to layout.
Expand Down Expand Up @@ -548,7 +628,7 @@ def decorated_function(*args):
# Figure out if an update is necessary.
unique_ids = []
update_needed = False
for i, output in enumerate(callback[Output]):
for i, output in enumerate(callback[dd.Output]):
# Filter out Triggers (a little ugly to do here, should ideally be handled elsewhere).
is_trigger = trigger_filter(callback["sorted_args"])
filtered_args = [arg for i, arg in enumerate(args) if not is_trigger[i]]
Expand All @@ -560,20 +640,20 @@ def decorated_function(*args):
break
# If not update is needed, just return the ids (or values, if not serverside output).
if not update_needed:
results = [uid if isinstance(callback[Output][i], ServersideOutput) else
callback[Output][i].backend.get(uid) for i, uid in enumerate(unique_ids)]
results = [uid if isinstance(callback[dd.Output][i], ServersideOutput) else
callback[dd.Output][i].backend.get(uid) for i, uid in enumerate(unique_ids)]
return results if multi_output else results[0]
# Do the update.
data = f(*args)
data = list(data) if multi_output else [data]
if callable(memoize):
data = memoize(data)
for i, output in enumerate(callback[Output]):
for i, output in enumerate(callback[dd.Output]):
# Skip no_update updates.
if isinstance(data[i], type(dash.no_update)):
continue
# Replace only for server side outputs.
serverside_output = isinstance(callback[Output][i], ServersideOutput)
serverside_output = isinstance(callback[dd.Output][i], ServersideOutput)
if serverside_output or memoize:
# Filter out Triggers (a little ugly to do here, should ideally be handled elsewhere).
is_trigger = trigger_filter(callback["sorted_args"])
Expand Down Expand Up @@ -662,23 +742,18 @@ def get(self, key, ignore_expired=False):
class NoOutputTransform(DashTransform):

def __init__(self):
self.initialized = False
self.hidden_divs = []
self.layout_extension = LayoutExtension()

def layout(self, layout, layout_is_function):
if layout_is_function or not self.initialized:
children = _as_list(layout.children) + self.hidden_divs
layout.children = children
self.initialized = True
return layout
return self.layout_extension.layout(layout, layout_is_function)

def _apply(self, callbacks):
for callback in callbacks:
if len(callback[dd.Output]) == 0:
output_id = _get_output_id(callback)
hidden_div = html.Div(id=output_id, style={"display": "none"})
callback[dd.Output] = [dd.Output(output_id, "children")]
self.hidden_divs.append(hidden_div)
self.layout_extension.components.append(hidden_div)
return callbacks

def apply_serverside(self, callbacks):
Expand Down
26 changes: 26 additions & 0 deletions reactive_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import dash_core_components as dcc
import dash_html_components as html
import plotly.graph_objects as go
from dash_extensions.enrich import Dash, Output, Input

app = Dash()
app.layout = html.Div([dcc.Input(value=1, id='x', type='number'),
dcc.Input(value=1, id='power', type='number'),
html.Div(id='result'), dcc.Graph(id='graph')])

@app.reactive(Input('x'), Input('power'))
def z(x, y):
return x ** y if (x and y) else None

#app.clientside_reactive("z", "function(x,y){return x**y}", Input('x'), Input('power')) ??

@app.callback(Output('result'), Input('x'), Input('power'), Input('z'))
def display_result(x, y, z):
return f"{x}^{y} is {z}"

@app.callback(Output('graph', 'figure'), Input('x'), Input('power'), Input('z'))
def plot_result(x, y, z):
return go.Figure([go.Bar(x=['x', 'y', 'x**y'], y=[x, y, z])])

if __name__ == "__main__":
app.run_server()
30 changes: 30 additions & 0 deletions reactive_example2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import dash_core_components as dcc
import dash_html_components as html
import plotly.express as px
import dash_table
from dash_extensions.enrich import DashProxy, Output, Input, ServersideOutputTransform

# Read the full, complex dataset here (for the sake of simplicify, a small px dataset is used). ??
df_all = px.data.gapminder()
# Example app demonstrating how to share state between callbacks via a reactive variable.
app = DashProxy(transforms=[ServersideOutputTransform()])
app.layout = html.Div([
dcc.Dropdown(options=[dict(value=x, label=x) for x in df_all.country.unique()], id="country", value="Denmark"),
dcc.Graph(id='graph'),
dash_table.DataTable(id='table', columns=[{"name": i, "id": i} for i in df_all.columns])
])

@app.reactive(Input('country')) # default prop for input/state is "value"
def df_filtered(country): # reactive variable name = function name
return df_all[df_all.country == country] # defaults to serverside output, i.e. json serialization is not needed

@app.callback(Output('table', 'data'), Input('df_filtered')) # access reactive variable via it's ID
def update_table(df):
return df.to_dict('records') # the reactive variable was stored serverside, i.e. deserialize is not needed

@app.callback(Output('graph', 'figure'), Input('df_filtered'))
def update_graph(df):
return px.bar(df, x='year', y='pop', color='gdpPercap')

if __name__ == "__main__":
app.run_server()