diff --git a/copulas/visualization.py b/copulas/visualization.py
index ca754793..6df130f0 100644
--- a/copulas/visualization.py
+++ b/copulas/visualization.py
@@ -2,135 +2,274 @@
import pandas as pd
-try:
- import matplotlib.pyplot as plt
-except RuntimeError as e:
- if 'Python is not installed as a framework.' in e.message:
- import matplotlib
- matplotlib.use('PS') # Avoid crash on macos
- import matplotlib.pyplot as plt
+from copulas.utils2 import PlotConfig
+import plotly.express as px
+from pandas.api.types import is_datetime64_dtype
-def scatter_3d(data, columns=None, fig=None, title=None, position=None):
- """Plot 3 dimensional data in a scatter plot."""
- fig = fig or plt.figure()
- position = position or 111
-
- ax = fig.add_subplot(position, projection='3d')
- ax.scatter(*(
- data[column]
- for column in columns or data.columns
- ))
- if title:
- ax.set_title(title)
- ax.title.set_position([.5, 1.05])
-
- return ax
-
-
-def scatter_2d(data, columns=None, fig=None, title=None, position=None):
- """Plot 2 dimensional data in a scatter plot."""
- fig = fig or plt.figure()
- position = position or 111
-
- ax = fig.add_subplot(position)
- columns = columns or data.columns
- if len(columns) != 2:
- raise ValueError('Only 2 columns can be plotted')
-
- x, y = columns
-
- ax.scatter(data[x], data[y])
- plt.xlabel(x)
- plt.ylabel(y)
-
- if title:
- ax.set_title(title)
- ax.title.set_position([.5, 1.05])
-
- return ax
+def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
+ """Generate a bar plot of the real and synthetic data.
+ Args:
+ real_column (pandas.Series):
+ The real data for the desired column.
+ synthetic_column (pandas.Series):
+ The synthetic data for the desired column.
+ plot_kwargs (dict, optional):
+ Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
+ provided this way will overwrite defaults.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
+ histogram_kwargs = {
+ 'x': 'values',
+ 'barmode': 'group',
+ 'color_discrete_sequence': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN],
+ 'pattern_shape': 'Data',
+ 'pattern_shape_sequence': ['', '/'],
+ 'histnorm': 'probability density',
+ }
+ histogram_kwargs.update(plot_kwargs)
+ fig = px.histogram(
+ all_data,
+ **histogram_kwargs
+ )
+
+ return fig
+
+def _generate_scatter_plot(all_data, columns):
+ """Generate a scatter plot for column pair plot.
-def hist_1d(data, fig=None, title=None, position=None, bins=20, label=None):
- """Plot 1 dimensional data in a histogram."""
- fig = fig or plt.figure()
- position = position or 111
+ Args:
+ all_data (pandas.DataFrame):
+ The real and synthetic data for the desired column pair containing a
+ ``Data`` column that indicates whether is real or synthetic.
+ columns (list):
+ A list of the columns being plotted.
- ax = fig.add_subplot(position)
- ax.hist(data, density=True, bins=bins, alpha=0.8, label=label)
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ fig = px.scatter(
+ all_data,
+ x=columns[0],
+ y=columns[1],
+ color='Data',
+ color_discrete_map={
+ 'Real': PlotConfig.DATACEBO_DARK,
+ 'Synthetic': PlotConfig.DATACEBO_GREEN
+ },
+ symbol='Data'
+ )
+
+ fig.update_layout(
+ title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
+ plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+ font={'size': PlotConfig.FONT_SIZE},
+ )
+
+ return fig
+
+def _generate_column_plot(real_column,
+ synthetic_column,
+ plot_kwargs={},
+ plot_title=None,
+ x_label=None):
+ """Generate a plot of the real and synthetic data.
- if label:
- ax.legend()
+ Args:
+ real_column (pandas.Series):
+ The real data for the desired column.
+ synthetic_column (pandas.Series):
+ The synthetic data for the desired column.
+ hist_kwargs (dict, optional):
+ Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
+ provided this way will overwrite defaults.
+ plot_title (str, optional):
+ Title to use for the plot. Defaults to 'Real vs. Synthetic Data for column {column}'
+ x_label (str, optional):
+ Label to use for x-axis. Defaults to 'Category'.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ column_name = real_column.name if hasattr(real_column, 'name') else ''
+
+ real_data = pd.DataFrame({'values': real_column.copy().dropna()})
+ real_data['Data'] = 'Real'
+ synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()})
+ synthetic_data['Data'] = 'Synthetic'
+
+ is_datetime_sdtype = False
+ if is_datetime64_dtype(real_column.dtype):
+ is_datetime_sdtype = True
+ real_data['values'] = real_data['values'].astype('int64')
+ synthetic_data['values'] = synthetic_data['values'].astype('int64')
+
+ trace_args = {}
+
+ fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs)
+
+ for i, name in enumerate(['Real', 'Synthetic']):
+ fig.update_traces(
+ x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x,
+ hovertemplate=f'{name}
Frequency: %{{y}}',
+ selector={'name': name},
+ **trace_args
+ )
+
+ if not plot_title:
+ plot_title = f"Real vs. Synthetic Data for column '{column_name}'"
+
+ if not x_label:
+ x_label = 'Category'
+
+ fig.update_layout(
+ title=plot_title,
+ xaxis_title=x_label,
+ yaxis_title='Frequency',
+ plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+ annotations=[],
+ font={'size': PlotConfig.FONT_SIZE},
+ )
+ return fig
+
+def hist_1d(data, title=None, bins=20, label=None):
+ """Plot 1 dimensional data in a histogram.
+
+ Args:
+ data (pd.DataFrame):
+ The table data.
+ title (str):
+ The title of the plot.
+ bins (int):
+ The number of bins to use for the histogram.
+ label (str):
+ The label of the plot.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ if not isinstance(data, pd.DataFrame):
+ data = pd.DataFrame(data)
+ if len(data.columns) > 1:
+ raise ValueError('Only 1 column can be plotted')
+
+ fig = px.histogram(
+ data_frame=data,
+ barmode='group',
+ color_discrete_sequence=[PlotConfig.DATACEBO_DARK],
+ histnorm='probability density',
+ title=title,
+ nbins=bins,
+ )
+ fig.update_layout(
+ xaxis_title='',
+ yaxis_title='',
+ legend_title=label,
+ showlegend=True if label else False
+ )
+
+ return fig
+
+def compare_1d(real, synth):
+ """Return a plot of the real and synthetic data for a given column.
- if title:
- ax.set_title(title)
- ax.title.set_position([.5, 1.05])
+ Args:
+ real (pandas.DataFrame):
+ The real table data.
+ synth (pandas.DataFrame):
+ The synthetic table data.
- return ax
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ if not isinstance(real, pd.Series):
+ real = pd.Series(real)
+ if not isinstance(synth, pd.Series):
+ synth = pd.Series(synth)
+ return _generate_column_plot(real, synth, plot_type='bar')
-def side_by_side(plotting_func, arrays):
- """Plot side-by-side figures.
+def scatter_2d(data, columns=None):
+ """Plot 2 dimensional data in a scatter plot.
Args:
- plotting_func (callable):
- A matplotlib function which takes in the standard plot kwargs.
- arrays (dict[str, np.ndarray]):
- A mapping from the name of the subplot to the values.
- """
- fig = plt.figure(figsize=(10, 4))
+ data (pandas.DataFrame):
+ The table data.
+ columns (list[string]):
+ The names of the two columns to plot.
- position_base = f'1{len(arrays)}'
- for index, (title, array) in enumerate(arrays.items()):
- position = int(position_base + str(index + 1))
- plotting_func(array, fig=fig, title=title, position=position)
+ Returns:
+ plotly.graph_objects._figure.Figure
+ """
+ data = data[columns]
+ columns = list(data.columns)
+ data['Data'] = 'Real'
- plt.tight_layout()
+ return _generate_scatter_plot(data, columns)
-
-def compare_3d(real, synth, columns=None, figsize=(10, 4)):
- """Generate a 3d scatter plot comparing real/synthetic data.
+def compare_2d_(real, synth, columns=None):
+ """Return a plot of the real and synthetic data for a given column pair.
Args:
- real (pd.DataFrame):
- The real data.
- synth (pd.DataFrame):
- The synthetic data.
- columns (list):
- The name of the columns to plot.
- figsize:
- Figure size, passed to matplotlib.
+ real (pandas.DataFrame):
+ The real table data.
+ synth (pandas.Dataframe):
+ The synthetic table data.
+ columns (list[string]):
+ The names of the two columns to plot.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
"""
- columns = columns or real.columns
- fig = plt.figure(figsize=figsize)
-
- scatter_3d(real[columns], fig=fig, title='Real Data', position=121)
- scatter_3d(synth[columns], fig=fig, title='Synthetic Data', position=122)
+ real_data = real_data[columns]
+ synthetic_data = synthetic_data[columns]
+ columns = list(real_data.columns)
+ real_data['Data'] = 'Real'
+ synthetic_data['Data'] = 'Synthetic'
+ all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
- plt.tight_layout()
+ return _generate_scatter_plot(all_data, columns)
-
-def compare_2d(real, synth, columns=None, figsize=None):
- """Generate a 2d scatter plot comparing real/synthetic data.
+def scatter_3d_plotly(data, columns=None):
+ """Return a 3D scatter plot of the data.
Args:
- real (pd.DataFrame):
- The real data.
- synth (pd.DataFrame):
- The synthetic data.
- columns (list):
- The name of the columns to plot.
- figsize:
- Figure size, passed to matplotlib.
+ data (pandas.DataFrame):
+ The table data. Must have at least 3 columns.
+ column_names (list[string]):
+ The names of the three columns to plot. If not passed,
+ the first three columns of the data will be used.
+
+ Returns:
+ plotly.graph_objects._figure.Figure
"""
- x, y = columns or real.columns
- ax = real.plot.scatter(x, y, color='blue', alpha=0.5, figsize=figsize)
- ax = synth.plot.scatter(x, y, ax=ax, color='orange', alpha=0.5, figsize=figsize)
- ax.legend(['Real', 'Synthetic'])
-
-
-def compare_1d(real, synth, columns=None, figsize=None):
- """Generate a 1d scatter plot comparing real/synthetic data.
+ fig = px.scatter(
+ data,
+ x=columns[0],
+ y=columns[1],
+ z=columns[2],
+ color='Data',
+ color_discrete_map={
+ 'Real': PlotConfig.DATACEBO_DARK,
+ 'Synthetic': PlotConfig.DATACEBO_GREEN
+ },
+ symbol='Data'
+ )
+
+ fig.update_layout(
+ title=f"Data for columns '{columns[0]}', '{columns[1]}' and '{columns[2]}'",
+ plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+ font={'size': PlotConfig.FONT_SIZE},
+ )
+
+ return fig
+
+def compare_3d(real, synth, columns=None):
+ """Generate a 3d scatter plot comparing real/synthetic data.
Args:
real (pd.DataFrame):
@@ -139,26 +278,10 @@ def compare_1d(real, synth, columns=None, figsize=None):
The synthetic data.
columns (list):
The name of the columns to plot.
- figsize:
- Figure size, passed to matplotlib.
"""
- if len(real.shape) == 1:
- real = pd.DataFrame({'': real})
- synth = pd.DataFrame({'': synth})
-
columns = columns or real.columns
- num_cols = len(columns)
- fig_cols = min(2, num_cols)
- fig_rows = (num_cols // fig_cols) + 1
- prefix = f'{fig_rows}{fig_cols}'
-
- figsize = figsize or (5 * fig_cols, 3 * fig_rows)
- fig = plt.figure(figsize=figsize)
-
- for idx, column in enumerate(columns):
- position = int(prefix + str(idx + 1))
- hist_1d(real[column], fig=fig, position=position, title=column, label='Real')
- hist_1d(synth[column], fig=fig, position=position, title=column, label='Synthetic')
+ fig = scatter_3d_plotly(real[columns])
+ fig = scatter_3d_plotly(synth[columns], fig=fig)
- plt.tight_layout()
+ return fig
diff --git a/setup.py b/setup.py
index 1a757695..e2897af9 100644
--- a/setup.py
+++ b/setup.py
@@ -12,8 +12,6 @@
history = history_file.read()
install_requires = [
- "matplotlib>=3.4.0,<4;python_version<'3.10'",
- "matplotlib>=3.6.0,<4;python_version>='3.10'",
"numpy>=1.20.0,<2;python_version<'3.10'",
"numpy>=1.23.3,<2;python_version>='3.10'",
"pandas>=1.1.3;python_version<'3.10'",
@@ -21,6 +19,7 @@
"pandas>=1.5.0;python_version>='3.11'",
"scipy>=1.5.4,<2;python_version<'3.10'",
"scipy>=1.9.2,<2;python_version>='3.10'",
+ 'plotly>=5.10.0,<6',
]
development_requires = [