diff --git a/examples/basic_example.yaml b/examples/basic_example.yaml index 5ff0114b..77bfdb42 100644 --- a/examples/basic_example.yaml +++ b/examples/basic_example.yaml @@ -40,6 +40,10 @@ pages: content: - container: TablePlotter csv_file: ./example_data.csv + filter_cols: + - Well + - Segment + - Average permeability (D) contact_person: name: Ola Nordmann phone: +47 12345678 diff --git a/tests/test_table_plotter.py b/tests/test_table_plotter.py index b0636811..75862621 100644 --- a/tests/test_table_plotter.py +++ b/tests/test_table_plotter.py @@ -20,6 +20,8 @@ def test_table_plotter(dash_duo): # Checking that no plot options are defined assert page.plot_options == {} + # Check that filter is not active + assert page.use_filter == False # Checking that the selectors are not hidden selector_row = dash_duo.find_element(f'#{page.selector_row}') @@ -43,6 +45,47 @@ def test_table_plotter(dash_duo): plot_option_dd = dash_duo.find_element(f'#{page.plot_option_id}-{option}') assert plot_option_dd.text == 'Well' +def test_table_plotter_filter(dash_duo): + + app = dash.Dash(__name__) + app.config.suppress_callback_exceptions = True + cache.init_app(app.server) + + csv_file = './tests/data/example_data.csv' + page = _table_plotter.TablePlotter(app, csv_file, filter_cols=['Well']) + app.layout = page.layout + dash_duo.start_server(app) + + # Wait for the app to render(there is probably a better way...) + time.sleep(5) + + # Checking that no plot options are defined + assert page.plot_options == {} + # Check that filter is active + assert page.use_filter == True + assert page.filter_cols == ['Well'] + # Checking that the selectors are not hidden + selector_row = dash_duo.find_element(f'#{page.selector_row}') + assert selector_row.get_attribute('style') == '' + + # Checking that the correct plot type is initialized + plot_dd = dash_duo.find_element(f'#{page.plot_option_id}-plottype') + assert plot_dd.text == 'scatter' + + # Checking that only the relevant options are shown + for plot_option in page.plot_args.keys(): + plot_option_dd = dash_duo.find_element( + f'#{page.plot_option_id}-div-{plot_option}') + if plot_option in page.plots['scatter']: + assert plot_option_dd.get_attribute('style') == 'display: grid;' + else: + assert plot_option_dd.get_attribute('style') == 'display: none;' + + + # Checking that options are initialized correctly + for option in ['x', 'y']: + plot_option_dd = dash_duo.find_element(f'#{page.plot_option_id}-{option}') + assert plot_option_dd.text == 'Well' def test_initialized_table_plotter(dash_duo): diff --git a/webviz_config/containers/_table_plotter.py b/webviz_config/containers/_table_plotter.py index 0472a994..6a3d25e0 100644 --- a/webviz_config/containers/_table_plotter.py +++ b/webviz_config/containers/_table_plotter.py @@ -27,13 +27,14 @@ class TablePlotter(WebvizContainer): ''' def __init__(self, app, csv_file: Path, plot_options: dict = None, - lock: bool = False): + filter_cols: list = None, lock: bool = False): self.plot_options = plot_options if plot_options else {} self.graph_id = f'graph-id{uuid4()}' self.lock = lock self.csv_file = csv_file self.data = get_data(self.csv_file) + self.set_filters(filter_cols) self.columns = list(self.data.columns) self.numeric_columns = list( self.data.select_dtypes(include=[np.number]).columns) @@ -41,6 +42,20 @@ def __init__(self, app, csv_file: Path, plot_options: dict = None, self.plot_option_id = f'plot-option{uuid4()}' self.set_callbacks(app) + def set_filters(self, filter_cols): + self.filter_cols = [] + self.filter_ids = {} + self.use_filter = False + if filter_cols: + for col in filter_cols: + if col in self.data.columns: + if self.data[col].nunique() != 1: + self.filter_cols.append(col) + if self.filter_cols: + self.use_filter = True + self.filter_ids = {col: f'{col}-{str(uuid4())}' + for col in self.filter_cols} + def add_webvizstore(self): return [(get_data, [{'csv_file': self.csv_file}])] @@ -132,6 +147,50 @@ def plot_args(self): }) + def filter_layout(self): + '''Makes dropdowns for each dataframe column used for filtering.''' + if not self.use_filter: + return None + df = self.data + dropdowns = [html.H4('Set filters')] + for col in self.filter_cols: + if(df[col].dtype == np.float64 or df[col].dtype == np.int64): + min_val = df[col].min() + max_val = df[col].max() + mean_val = df[col].mean() + dropdowns.append( + html.Div(children=[ + html.Details(open=True, children=[ + html.Summary(col.lower().capitalize()), + dcc.RangeSlider( + id=self.filter_ids[col], + min=min_val, + max=max_val, + step=(max_val-min_val)/10, + marks={min_val: f'{min_val:.2f}', + mean_val: f'{mean_val:.2f}', + max_val: f'{max_val:.2f}'}, + value=[min_val, max_val]) + ]) + ]) + ) + else: + elements = list(self.data[col].unique()) + dropdowns.append( + html.Div(children=[ + html.Details(open=True, children=[ + html.Summary(col.lower().capitalize()), + dcc.Dropdown( + id=self.filter_ids[col], + options=[{'label': i, 'value': i} + for i in elements], + value=elements, + multi=True) + ]) + ]) + ) + return dropdowns + def plot_option_layout(self): '''Renders a dropdown widget for each plot option''' divs = [] @@ -140,6 +199,7 @@ def plot_option_layout(self): html.Div( style=self.style_options_div, children=[ + html.H4('Set plot options'), html.P('Plot type'), dcc.Dropdown( id=f'{self.plot_option_id}-plottype', @@ -188,11 +248,14 @@ def style_options_div_hidden(self): @property def style_page_layout(self): '''Simple grid layout for the page''' - return {} if self.lock else { + if self.lock: + return {} + return { 'display': 'grid', 'align-content': 'space-around', 'justify-content': 'space-between', - 'grid-template-columns': '1fr 5fr' + 'grid-template-columns': + '1fr 5fr 1fr' if self.use_filter else '1fr 5fr' } @property @@ -210,7 +273,8 @@ def layout(self): html.Div(style={'height': '100%'}, children=dcc.Graph(id=self.graph_id, config={ 'responsive': 'true'}) - ) + ), + html.Div(children=self.filter_layout()) ]) ]) @@ -239,6 +303,8 @@ def plot_input_callbacks(self): Input(f'{self.plot_option_id}-plottype', 'value')) for plot_arg in self.plot_args.keys(): inputs.append(Input(f'{self.plot_option_id}-{plot_arg}', 'value')) + for filtcol in self.filter_cols: + inputs.append(Input(self.filter_ids[filtcol], 'value')) return inputs def set_callbacks(self, app): @@ -259,16 +325,41 @@ def _update_output(*args): plotfunc = getattr(px._chart_types, plot_type) plotargs = {} div_style = [] - for name, plot_arg in zip(self.plot_args.keys(), args[1:]): + data = self.data + # Filter dataframe if filter columns are available + if self.use_filter: + plot_inputs = args[1:-len(self.filter_cols)] + filter_inputs = args[-len(self.filter_cols):] + data = filter_dataframe(data, self.filter_cols, filter_inputs) + else: + plot_inputs = args[1:] + for name, plot_arg in zip(self.plot_args.keys(), plot_inputs): if name in self.plots[plot_type]: plotargs[name] = plot_arg div_style.append(self.style_options_div) else: div_style.append(self.style_options_div_hidden) - return (plotfunc(self.data, **plotargs), *div_style) + + return (plotfunc(data, **plotargs), *div_style) @cache.memoize(timeout=cache.TIMEOUT) @webvizstore def get_data(csv_file) -> pd.DataFrame: return pd.read_csv(csv_file, index_col=None) + + +@cache.memoize(timeout=cache.TIMEOUT) +def filter_dataframe(dframe, columns, column_values): + df = dframe.copy() + if not isinstance(columns, list): + columns = [columns] + for filt, col in zip(column_values, columns): + if isinstance(filt, list): + if (df[col].dtype == np.float64 or df[col].dtype == np.int64): + df = df.loc[df[col].between(filt[0], filt[1])] + else: + df = df.loc[df[col].isin(filt)] + else: + df = df.loc[df[col] == filt] + return df