diff --git a/copulas/visualization.py b/copulas/visualization.py index 9eea3d5a..eff0ded8 100644 --- a/copulas/visualization.py +++ b/copulas/visualization.py @@ -74,11 +74,7 @@ def _generate_scatter_plot(all_data, columns): return fig -def _generate_column_plot(real_column, - synthetic_column, - plot_kwargs={}, - plot_title=None, - x_label=None): +def _generate_column_plot(real_column, synthetic_column): """Generate a plot of the real and synthetic data. Args: @@ -86,13 +82,6 @@ def _generate_column_plot(real_column, The real data for the desired column. synthetic_column (pandas.Series): The synthetic data for the desired column. - hist_kwargs (dict, optional): - Dictionary of keyword arguments to pass to px.histogram. Keyword arguments - provided this way will overwrite defaults. - plot_title (str, optional): - Title to use for the plot. Defaults to 'Real vs. Synthetic Data for column {column}' - x_label (str, optional): - Label to use for x-axis. Defaults to 'Category'. Returns: plotly.graph_objects._figure.Figure @@ -112,7 +101,7 @@ def _generate_column_plot(real_column, trace_args = {} - fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs) + fig = _generate_column_bar_plot(real_data, synthetic_data) for i, name in enumerate(['Real', 'Synthetic']): fig.update_traces( @@ -122,15 +111,10 @@ def _generate_column_plot(real_column, **trace_args ) - if not plot_title: - plot_title = f"Real vs. Synthetic Data for column '{column_name}'" - - if not x_label: - x_label = 'Category' - + plot_title = f"Real vs. Synthetic Data for column '{column_name}'" fig.update_layout( title=plot_title, - xaxis_title=x_label, + xaxis_title='Category', yaxis_title='Frequency', plot_bgcolor=PlotConfig.BACKGROUND_COLOR, annotations=[], @@ -195,7 +179,7 @@ def compare_1d(real, synth): if not isinstance(synth, pd.Series): synth = pd.Series(synth) - return _generate_column_plot(real, synth, plot_type='bar') + return _generate_column_plot(real, synth) def scatter_2d(data, columns=None): @@ -217,7 +201,7 @@ def scatter_2d(data, columns=None): return _generate_scatter_plot(data, columns) -def compare_2d_(real, synth, columns=None): +def compare_2d(real, synth, columns=None): """Return a plot of the real and synthetic data for a given column pair. Args: @@ -231,8 +215,8 @@ def compare_2d_(real, synth, columns=None): Returns: plotly.graph_objects._figure.Figure """ - real_data = real_data[columns] - synthetic_data = synthetic_data[columns] + real_data = real[columns] + synthetic_data = synth[columns] columns = list(real_data.columns) real_data['Data'] = 'Real' synthetic_data['Data'] = 'Synthetic' @@ -241,7 +225,7 @@ def compare_2d_(real, synth, columns=None): return _generate_scatter_plot(all_data, columns) -def scatter_3d_plotly(data, columns=None): +def scatter_3d(data, columns=None): """Return a 3D scatter plot of the data. Args: @@ -289,7 +273,7 @@ def compare_3d(real, synth, columns=None): """ columns = columns or real.columns - fig = scatter_3d_plotly(real[columns]) - fig = scatter_3d_plotly(synth[columns], fig=fig) + fig = scatter_3d(real[columns]) + fig = scatter_3d(synth[columns], fig=fig) return fig