diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d89fe9a..17ab658f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] - YYYY-MM-DD +### Added +- [#644](https://github.com/equinor/webviz-config/pull/644) - Added option to download tables in `DataTable` and `PivotTable`. + ## [0.5.0] - 2022-10-10 ### Changed diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 968a4856..44d25803 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -18,7 +18,7 @@ def test_data_table(dash_duo): code_file = "./tests/data/example_data.csv" with mock.patch(GET_DATA) as mock_path: mock_path.return_value = pd.read_csv(code_file) - page = _data_table.DataTable(code_file) + page = _data_table.DataTable(app, code_file) app.layout = page.layout dash_duo.start_server(app) assert dash_duo.get_logs() == [], "browser console should contain no error" @@ -35,7 +35,7 @@ def test_data_table_with_settings(dash_duo): with mock.patch(GET_DATA) as mock_path: mock_path.return_value = pd.read_csv(code_file) page = _data_table.DataTable( - csv_file=code_file, sorting=False, filtering=False, pagination=False + app, csv_file=code_file, sorting=False, filtering=False, pagination=False ) app.layout = page.layout dash_duo.start_server(app) diff --git a/webviz_config/generic_plugins/_data_table.py b/webviz_config/generic_plugins/_data_table.py index 5553ec3c..54b6b639 100644 --- a/webviz_config/generic_plugins/_data_table.py +++ b/webviz_config/generic_plugins/_data_table.py @@ -1,10 +1,11 @@ +import base64 from pathlib import Path -from typing import List +from typing import List, Optional import pandas as pd -from dash import dash_table +from dash import dash_table, Dash -from .. import WebvizPluginABC +from .. import WebvizPluginABC, EncodedFile from ..webviz_store import webvizstore from ..common_cache import CACHE @@ -27,6 +28,7 @@ class DataTable(WebvizPluginABC): def __init__( self, + app: Dash, csv_file: Path, sorting: bool = True, filtering: bool = True, @@ -41,6 +43,8 @@ def __init__( self.filtering = filtering self.pagination = pagination + self.set_callbacks(app) + def add_webvizstore(self) -> List[tuple]: return [(get_data, [{"csv_file": self.csv_file}])] @@ -54,6 +58,21 @@ def layout(self) -> dash_table.DataTable: page_action="native" if self.pagination else "none", ) + def set_callbacks(self, app: Dash) -> None: + @app.callback(self.plugin_data_output, self.plugin_data_requested) + def _user_download_data(data_requested: Optional[int]) -> Optional[EncodedFile]: + return ( + { + "filename": "data-table.csv", + "content": base64.b64encode( + get_data(self.csv_file).to_csv(index=False).encode() + ).decode("ascii"), + "mime_type": "text/csv", + } + if data_requested + else None + ) + @CACHE.memoize() @webvizstore diff --git a/webviz_config/generic_plugins/_pivot_table.py b/webviz_config/generic_plugins/_pivot_table.py index 8e6beaf0..048b1546 100644 --- a/webviz_config/generic_plugins/_pivot_table.py +++ b/webviz_config/generic_plugins/_pivot_table.py @@ -1,10 +1,12 @@ +import base64 from pathlib import Path -from typing import List +from typing import List, Optional import pandas as pd +from dash import Dash import dash_pivottable -from .. import WebvizPluginABC +from .. import WebvizPluginABC, EncodedFile from ..webviz_store import webvizstore from ..common_cache import CACHE @@ -21,13 +23,15 @@ class PivotTable(WebvizPluginABC): (https://github.com/plotly/dash-pivottable#references) for all possible options. """ - def __init__(self, csv_file: Path, options: dict = None): + def __init__(self, app: Dash, csv_file: Path, options: dict = None): super().__init__() self.csv_file = csv_file self.options = options if options is not None else {} + self.set_callbacks(app) + def add_webvizstore(self) -> List[tuple]: return [(get_data, [{"csv_file": self.csv_file}])] @@ -35,6 +39,21 @@ def add_webvizstore(self) -> List[tuple]: def layout(self) -> dash_pivottable.PivotTable: return generate_table(get_data(self.csv_file), **self.options) + def set_callbacks(self, app: Dash) -> None: + @app.callback(self.plugin_data_output, self.plugin_data_requested) + def _user_download_data(data_requested: Optional[int]) -> Optional[EncodedFile]: + return ( + { + "filename": "pivot-table.csv", + "content": base64.b64encode( + get_data(self.csv_file).to_csv(index=False).encode() + ).decode("ascii"), + "mime_type": "text/csv", + } + if data_requested + else None + ) + def generate_table(dframe: pd.DataFrame, **options: str) -> dash_pivottable.PivotTable: return dash_pivottable.PivotTable(