diff --git a/copulas/visualization.py b/copulas/visualization.py
index ca754793..6df130f0 100644
--- a/copulas/visualization.py
+++ b/copulas/visualization.py
@@ -2,135 +2,294 @@
 
 import pandas as pd
 
-try:
-    import matplotlib.pyplot as plt
-except RuntimeError as e:
-    if 'Python is not installed as a framework.' in e.message:
-        import matplotlib
-        matplotlib.use('PS')   # Avoid crash on macos
-        import matplotlib.pyplot as plt
+from copulas.utils2 import PlotConfig
+import plotly.express as px
+from pandas.api.types import is_datetime64_dtype
 
 
-def scatter_3d(data, columns=None, fig=None, title=None, position=None):
-    """Plot 3 dimensional data in a scatter plot."""
-    fig = fig or plt.figure()
-    position = position or 111
-
-    ax = fig.add_subplot(position, projection='3d')
-    ax.scatter(*(
-        data[column]
-        for column in columns or data.columns
-    ))
-    if title:
-        ax.set_title(title)
-        ax.title.set_position([.5, 1.05])
-
-    return ax
-
-
-def scatter_2d(data, columns=None, fig=None, title=None, position=None):
-    """Plot 2 dimensional data in a scatter plot."""
-    fig = fig or plt.figure()
-    position = position or 111
-
-    ax = fig.add_subplot(position)
-    columns = columns or data.columns
-    if len(columns) != 2:
-        raise ValueError('Only 2 columns can be plotted')
-
-    x, y = columns
-
-    ax.scatter(data[x], data[y])
-    plt.xlabel(x)
-    plt.ylabel(y)
-
-    if title:
-        ax.set_title(title)
-        ax.title.set_position([.5, 1.05])
-
-    return ax
+def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs=None):
+    """Generate a bar plot of the real and synthetic data.
 
+    Args:
+        real_data (pandas.DataFrame):
+            The real data, with ``values`` and ``Data`` columns.
+        synthetic_data (pandas.DataFrame):
+            The synthetic data, with ``values`` and ``Data`` columns.
+        plot_kwargs (dict, optional):
+            Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
+            provided this way will overwrite defaults.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
+    histogram_kwargs = {
+        'x': 'values',
+        'barmode': 'group',
+        'color_discrete_sequence': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN],
+        'pattern_shape': 'Data',
+        'pattern_shape_sequence': ['', '/'],
+        'histnorm': 'probability density',
+    }
+    histogram_kwargs.update(plot_kwargs or {})
+    fig = px.histogram(
+        all_data,
+        **histogram_kwargs
+    )
+
+    return fig
+
+def _generate_scatter_plot(all_data, columns):
+    """Generate a scatter plot for column pair plot.
 
-def hist_1d(data, fig=None, title=None, position=None, bins=20, label=None):
-    """Plot 1 dimensional data in a histogram."""
-    fig = fig or plt.figure()
-    position = position or 111
+    Args:
+        all_data (pandas.DataFrame):
+            The real and synthetic data for the desired column pair containing a
+            ``Data`` column that indicates whether each row is real or synthetic.
+        columns (list):
+            A list of the columns being plotted.
 
-    ax = fig.add_subplot(position)
-    ax.hist(data, density=True, bins=bins, alpha=0.8, label=label)
+    Returns:
+        plotly.graph_objects._figure.Figure
+
+    Raises:
+        ValueError:
+            If ``columns`` does not contain exactly two column names.
+    """
+    if len(columns) != 2:
+        raise ValueError('Only 2 columns can be plotted')
+
+    fig = px.scatter(
+        all_data,
+        x=columns[0],
+        y=columns[1],
+        color='Data',
+        color_discrete_map={
+            'Real': PlotConfig.DATACEBO_DARK,
+            'Synthetic': PlotConfig.DATACEBO_GREEN
+        },
+        symbol='Data'
+    )
+
+    fig.update_layout(
+        title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'",
+        plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+        font={'size': PlotConfig.FONT_SIZE},
+    )
+
+    return fig
+
+def _generate_column_plot(real_column,
+                          synthetic_column,
+                          plot_kwargs=None,
+                          plot_title=None,
+                          x_label=None):
+    """Generate a plot of the real and synthetic data.
 
-    if label:
-        ax.legend()
+    Args:
+        real_column (pandas.Series):
+            The real data for the desired column.
+        synthetic_column (pandas.Series):
+            The synthetic data for the desired column.
+        plot_kwargs (dict, optional):
+            Dictionary of keyword arguments to pass to px.histogram. Keyword arguments
+            provided this way will overwrite defaults.
+        plot_title (str, optional):
+            Title to use for the plot. Defaults to 'Real vs. Synthetic Data for column {column}'
+        x_label (str, optional):
+            Label to use for x-axis. Defaults to 'Category'.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    column_name = real_column.name if hasattr(real_column, 'name') else ''
+
+    real_data = pd.DataFrame({'values': real_column.copy().dropna()})
+    real_data['Data'] = 'Real'
+    synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()})
+    synthetic_data['Data'] = 'Synthetic'
+
+    is_datetime_sdtype = False
+    if is_datetime64_dtype(real_column.dtype):
+        is_datetime_sdtype = True
+        real_data['values'] = real_data['values'].astype('int64')
+        synthetic_data['values'] = synthetic_data['values'].astype('int64')
+
+    trace_args = {}
+
+    fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs)
+
+    for i, name in enumerate(['Real', 'Synthetic']):
+        fig.update_traces(
+            x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x,
+            hovertemplate=f'{name}<br>Frequency: %{{y}}',
+            selector={'name': name},
+            **trace_args
+        )
+
+    if not plot_title:
+        plot_title = f"Real vs. Synthetic Data for column '{column_name}'"
+
+    if not x_label:
+        x_label = 'Category'
+
+    fig.update_layout(
+        title=plot_title,
+        xaxis_title=x_label,
+        yaxis_title='Frequency',
+        plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+        annotations=[],
+        font={'size': PlotConfig.FONT_SIZE},
+    )
+    return fig
+
+def hist_1d(data, title=None, bins=20, label=None):
+    """Plot 1 dimensional data in a histogram.
+
+    Args:
+        data (pd.DataFrame):
+            The table data.
+        title (str):
+            The title of the plot.
+        bins (int):
+            The number of bins to use for the histogram.
+        label (str):
+            The label of the plot.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    if not isinstance(data, pd.DataFrame):
+        data = pd.DataFrame(data)
+    if len(data.columns) > 1:
+        raise ValueError('Only 1 column can be plotted')
+
+    fig = px.histogram(
+        data_frame=data,
+        barmode='group',
+        color_discrete_sequence=[PlotConfig.DATACEBO_DARK],
+        histnorm='probability density',
+        title=title,
+        nbins=bins,
+    )
+    fig.update_layout(
+        xaxis_title='',
+        yaxis_title='',
+        legend_title=label,
+        showlegend=bool(label)
+    )
+
+    return fig
+
+def compare_1d(real, synth):
+    """Return a plot of the real and synthetic data for a given column.
 
-    if title:
-        ax.set_title(title)
-        ax.title.set_position([.5, 1.05])
+    Args:
+        real (pandas.Series or array-like):
+            The real column data.
+        synth (pandas.Series or array-like):
+            The synthetic column data.
 
-    return ax
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    if not isinstance(real, pd.Series):
+        real = pd.Series(real)
+    if not isinstance(synth, pd.Series):
+        synth = pd.Series(synth)
 
+    return _generate_column_plot(real, synth)
 
-def side_by_side(plotting_func, arrays):
-    """Plot side-by-side figures.
+def scatter_2d(data, columns=None):
+    """Plot 2 dimensional data in a scatter plot.
 
     Args:
-        plotting_func (callable):
-            A matplotlib function which takes in the standard plot kwargs.
-        arrays (dict[str, np.ndarray]):
-            A mapping from the name of the subplot to the values.
-    """
-    fig = plt.figure(figsize=(10, 4))
+        data (pandas.DataFrame):
+            The table data.
+        columns (list[string]):
+            The names of the two columns to plot.
 
-    position_base = f'1{len(arrays)}'
-    for index, (title, array) in enumerate(arrays.items()):
-        position = int(position_base + str(index + 1))
-        plotting_func(array, fig=fig, title=title, position=position)
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    columns = list(data.columns if columns is None else columns)
+    data = data[columns].copy()
+    data['Data'] = 'Real'
 
-    plt.tight_layout()
+    return _generate_scatter_plot(data, columns)
 
-
-def compare_3d(real, synth, columns=None, figsize=(10, 4)):
-    """Generate a 3d scatter plot comparing real/synthetic data.
+def compare_2d_(real, synth, columns=None):
+    """Return a plot of the real and synthetic data for a given column pair.
 
     Args:
-        real (pd.DataFrame):
-            The real data.
-        synth (pd.DataFrame):
-            The synthetic data.
-        columns (list):
-            The name of the columns to plot.
-        figsize:
-            Figure size, passed to matplotlib.
+        real (pandas.DataFrame):
+            The real table data.
+        synth (pandas.DataFrame):
+            The synthetic table data.
+        columns (list[string]):
+            The names of the two columns to plot.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
     """
-    columns = columns or real.columns
-    fig = plt.figure(figsize=figsize)
-
-    scatter_3d(real[columns], fig=fig, title='Real Data', position=121)
-    scatter_3d(synth[columns], fig=fig, title='Synthetic Data', position=122)
+    columns = list(real.columns if columns is None else columns)
+    real_data = real[columns].copy()
+    synthetic_data = synth[columns].copy()
+    real_data['Data'] = 'Real'
+    synthetic_data['Data'] = 'Synthetic'
+    all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
 
-    plt.tight_layout()
+    return _generate_scatter_plot(all_data, columns)
 
-
-def compare_2d(real, synth, columns=None, figsize=None):
-    """Generate a 2d scatter plot comparing real/synthetic data.
+def scatter_3d_plotly(data, columns=None):
+    """Return a 3D scatter plot of the data.
 
     Args:
-        real (pd.DataFrame):
-            The real data.
-        synth (pd.DataFrame):
-            The synthetic data.
-        columns (list):
-            The name of the columns to plot.
-        figsize:
-            Figure size, passed to matplotlib.
+        data (pandas.DataFrame):
+            The table data. Must have at least 3 columns.
+        columns (list[string]):
+            The names of the three columns to plot. If not passed,
+            the first three columns of the data will be used.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+
+    Raises:
+        ValueError:
+            If ``columns`` does not contain exactly three column names.
     """
-    x, y = columns or real.columns
-    ax = real.plot.scatter(x, y, color='blue', alpha=0.5, figsize=figsize)
-    ax = synth.plot.scatter(x, y, ax=ax, color='orange', alpha=0.5, figsize=figsize)
-    ax.legend(['Real', 'Synthetic'])
-
-
-def compare_1d(real, synth, columns=None, figsize=None):
-    """Generate a 1d scatter plot comparing real/synthetic data.
+    if columns is None:
+        columns = [column for column in data.columns if column != 'Data'][:3]
+    if len(columns) != 3:
+        raise ValueError('Only 3 columns can be plotted')
+
+    if 'Data' not in data.columns:
+        # Unlabelled input is plotted as real data, mirroring ``scatter_2d``.
+        data = data.assign(Data='Real')
+
+    fig = px.scatter_3d(
+        data,
+        x=columns[0],
+        y=columns[1],
+        z=columns[2],
+        color='Data',
+        color_discrete_map={
+            'Real': PlotConfig.DATACEBO_DARK,
+            'Synthetic': PlotConfig.DATACEBO_GREEN
+        },
+        symbol='Data'
+    )
+
+    fig.update_layout(
+        title=f"Data for columns '{columns[0]}', '{columns[1]}' and '{columns[2]}'",
+        plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+        font={'size': PlotConfig.FONT_SIZE},
+    )
+
+    return fig
+
+def compare_3d(real, synth, columns=None):
+    """Generate a 3d scatter plot comparing real/synthetic data.
 
     Args:
         real (pd.DataFrame):
@@ -139,26 +298,14 @@ def compare_1d(real, synth, columns=None, figsize=None):
             The synthetic data.
         columns (list):
             The name of the columns to plot.
-        figsize:
-            Figure size, passed to matplotlib.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
     """
-    if len(real.shape) == 1:
-        real = pd.DataFrame({'': real})
-        synth = pd.DataFrame({'': synth})
 
-    columns = columns or real.columns
-
-    num_cols = len(columns)
-    fig_cols = min(2, num_cols)
-    fig_rows = (num_cols // fig_cols) + 1
-    prefix = f'{fig_rows}{fig_cols}'
-
-    figsize = figsize or (5 * fig_cols, 3 * fig_rows)
-    fig = plt.figure(figsize=figsize)
-
-    for idx, column in enumerate(columns):
-        position = int(prefix + str(idx + 1))
-        hist_1d(real[column], fig=fig, position=position, title=column, label='Real')
-        hist_1d(synth[column], fig=fig, position=position, title=column, label='Synthetic')
+    columns = list(real.columns[:3] if columns is None else columns)
+    real_data = real[columns].assign(Data='Real')
+    synthetic_data = synth[columns].assign(Data='Synthetic')
+    all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
 
-    plt.tight_layout()
+    return scatter_3d_plotly(all_data, columns)
diff --git a/setup.py b/setup.py
index 1a757695..e2897af9 100644
--- a/setup.py
+++ b/setup.py
@@ -12,8 +12,6 @@
     history = history_file.read()
 
 install_requires = [
-    "matplotlib>=3.4.0,<4;python_version<'3.10'",
-    "matplotlib>=3.6.0,<4;python_version>='3.10'",
     "numpy>=1.20.0,<2;python_version<'3.10'",
     "numpy>=1.23.3,<2;python_version>='3.10'",
     "pandas>=1.1.3;python_version<'3.10'",
@@ -21,6 +19,7 @@
     "pandas>=1.5.0;python_version>='3.11'",
     "scipy>=1.5.4,<2;python_version<'3.10'",
     "scipy>=1.9.2,<2;python_version>='3.10'",
+    "plotly>=5.10.0,<6",
 ]
 
 development_requires = [