diff --git a/copulas/sdmetrics.py b/copulas/sdmetrics.py new file mode 100644 index 00000000..bafe1d52 --- /dev/null +++ b/copulas/sdmetrics.py @@ -0,0 +1,480 @@ +"""Visualization methods for SDMetrics.""" + +import pandas as pd +import plotly.express as px +import plotly.figure_factory as ff +from pandas.api.types import is_datetime64_dtype + +from copulas.utils import get_missing_percentage, is_datetime +from copulas.utils2 import PlotConfig + + +def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): + """Generate a bar plot of the real and synthetic data. + + Args: + real_column (pandas.Series): + The real data for the desired column. + synthetic_column (pandas.Series): + The synthetic data for the desired column. + plot_kwargs (dict, optional): + Dictionary of keyword arguments to pass to px.histogram. Keyword arguments + provided this way will overwrite defaults. + + Returns: + plotly.graph_objects._figure.Figure + """ + all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) + histogram_kwargs = { + 'x': 'values', + # 'color': 'Data', + 'barmode': 'group', + 'color_discrete_sequence': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN], + 'pattern_shape': 'Data', + 'pattern_shape_sequence': ['', '/'], + 'histnorm': 'probability density', + } + histogram_kwargs.update(plot_kwargs) + fig = px.histogram( + all_data, + **histogram_kwargs + ) + + return fig + + +def _generate_heatmap_plot(all_data, columns): + """Generate heatmap plot for discrete data. + + Args: + all_data (pandas.DataFrame): + The real and synthetic data for the desired column pair containing a + ``Data`` column that indicates whether is real or synthetic. + columns (list): + A list of the columns being plotted. + + Returns: + plotly.graph_objects._figure.Figure + """ + fig = px.density_heatmap( + all_data, + x=columns[0], + y=columns[1], + facet_col='Data', + histnorm='probability' + ) + + fig.update_layout( + title_text=f"Real vs Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]}, + font={'size': PlotConfig.FONT_SIZE}, + ) + + fig.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1] + ' Data')) + + return fig + + +def _generate_box_plot(all_data, columns): + """Generate a box plot for mixed discrete and continuous column data. + + Args: + all_data (pandas.DataFrame): + The real and synthetic data for the desired column pair containing a + ``Data`` column that indicates whether is real or synthetic. + columns (list): + A list of the columns being plotted. + + Returns: + plotly.graph_objects._figure.Figure + """ + fig = px.box( + all_data, + x=columns[0], + y=columns[1], + color='Data', + color_discrete_map={ + 'Real': PlotConfig.DATACEBO_DARK, + 'Synthetic': PlotConfig.DATACEBO_GREEN + }, + ) + + fig.update_layout( + title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + font={'size': PlotConfig.FONT_SIZE}, + ) + + return fig + + +def _generate_scatter_plot(all_data, columns): + """Generate a scatter plot for column pair plot. + + Args: + all_data (pandas.DataFrame): + The real and synthetic data for the desired column pair containing a + ``Data`` column that indicates whether is real or synthetic. + columns (list): + A list of the columns being plotted. + + Returns: + plotly.graph_objects._figure.Figure + """ + fig = px.scatter( + all_data, + x=columns[0], + y=columns[1], + color='Data', + color_discrete_map={ + 'Real': PlotConfig.DATACEBO_DARK, + 'Synthetic': PlotConfig.DATACEBO_GREEN + }, + symbol='Data' + ) + + fig.update_layout( + title=f"Real vs. Synthetic Data for columns '{columns[0]}' and '{columns[1]}'", + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + font={'size': PlotConfig.FONT_SIZE}, + ) + + return fig + + +def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}): + """Plot the real and synthetic data as a distplot. + + Args: + real_data (pandas.DataFrame): + The real data for the desired column. + synthetic_data (pandas.DataFrame): + The synthetic data for the desired column. + plot_kwargs (dict, optional): + Dictionary of keyword arguments to pass to px.histogram. Keyword arguments + provided this way will overwrite defaults. + + Returns: + plotly.graph_objects._figure.Figure + """ + default_distplot_kwargs = { + 'show_hist': False, + 'show_rug': False, + 'colors': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN] + } + + fig = ff.create_distplot( + [real_data['values'], synthetic_data['values']], + ['Real', 'Synthetic'], + **{**default_distplot_kwargs, **plot_kwargs} + ) + + return fig + + +def _generate_column_plot(real_column, + synthetic_column, + plot_type, + plot_kwargs={}, + plot_title=None, + x_label=None): + """Generate a plot of the real and synthetic data. + + Args: + real_column (pandas.Series): + The real data for the desired column. + synthetic_column (pandas.Series): + The synthetic data for the desired column. + plot_type (str): + The type of plot to use. Must be one of 'bar' or 'distplot'. + hist_kwargs (dict, optional): + Dictionary of keyword arguments to pass to px.histogram. Keyword arguments + provided this way will overwrite defaults. + plot_title (str, optional): + Title to use for the plot. Defaults to 'Real vs. Synthetic Data for column {column}' + x_label (str, optional): + Label to use for x-axis. Defaults to 'Category'. + + Returns: + plotly.graph_objects._figure.Figure + """ + if plot_type not in ['bar', 'distplot']: + raise ValueError( + "Unrecognized plot_type '{plot_type}'. Pleas use one of 'bar' or 'distplot'" + ) + + column_name = real_column.name if hasattr(real_column, 'name') else '' + + missing_data_real = get_missing_percentage(real_column) + missing_data_synthetic = get_missing_percentage(synthetic_column) + + real_data = pd.DataFrame({'values': real_column.copy().dropna()}) + real_data['Data'] = 'Real' + synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()}) + synthetic_data['Data'] = 'Synthetic' + + is_datetime_sdtype = False + if is_datetime64_dtype(real_column.dtype): + is_datetime_sdtype = True + real_data['values'] = real_data['values'].astype('int64') + synthetic_data['values'] = synthetic_data['values'].astype('int64') + + trace_args = {} + + if plot_type == 'bar': + fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs) + elif plot_type == 'distplot': + x_label = x_label or 'Value' + fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs) + trace_args = {'fill': 'tozeroy'} + + for i, name in enumerate(['Real', 'Synthetic']): + fig.update_traces( + x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x, + hovertemplate=f'{name}
Frequency: %{{y}}', + selector={'name': name}, + **trace_args + ) + + show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0 + annotations = [] if not show_missing_values else [ + { + 'xref': 'paper', + 'yref': 'paper', + 'x': 1.0, + 'y': 1.05, + 'showarrow': False, + 'text': ( + f'*Missing Values: Real Data ({missing_data_real}%), ' + f'Synthetic Data ({missing_data_synthetic}%)' + ), + }, + ] + + if not plot_title: + plot_title = f"Real vs. Synthetic Data for column '{column_name}'" + + if not x_label: + x_label = 'Category' + + fig.update_layout( + title=plot_title, + xaxis_title=x_label, + yaxis_title='Frequency', + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + annotations=annotations, + font={'size': PlotConfig.FONT_SIZE}, + ) + return fig + + +def _generate_cardinality_plot(real_data, + synthetic_data, + parent_primary_key, + child_foreign_key, + plot_type='bar'): + plot_title = ( + f"Relationship (child foreign key='{child_foreign_key}' and parent " + f"primary key='{parent_primary_key}')" + ) + x_label = '# of Children (per Parent)' + + plot_kwargs = {} + if plot_type == 'bar': + max_cardinality = max(max(real_data), max(synthetic_data)) + min_cardinality = min(min(real_data), min(synthetic_data)) + plot_kwargs = { + 'nbins': max_cardinality - min_cardinality + 1 + } + + return _generate_column_plot(real_data, synthetic_data, plot_type, + plot_kwargs, plot_title, x_label) + + +def _get_cardinality(parent_table, child_table, parent_primary_key, child_foreign_key): + """Return the cardinality of the parent-child relationship. + + Args: + parent_table (pandas.DataFrame): + The parent table. + child_table (pandas.DataFrame): + The child table. + parent_primary_key (string): + The name of the primary key column in the parent table. + child_foreign_key (string): + The name of the foreign key column in the child table. + + Returns: + pandas.DataFrame + """ + child_counts = child_table[child_foreign_key].value_counts().rename('# children') + cardinalities = child_counts.reindex(parent_table[parent_primary_key], fill_value=0).to_frame() + + return cardinalities.sort_values('# children')['# children'] + + +def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_table_name, + child_foreign_key, parent_primary_key, plot_type='bar'): + """Return a plot of the cardinality of the parent-child relationship. + + Args: + real_data (dict): + The real data. + synthetic_data (dict): + The synthetic data. + child_table_name (string): + The name of the child table. + parent_table_name (string): + The name of the parent table. + child_foreign_key (string): + The name of the foreign key column in the child table. + parent_primary_key (string): + The name of the primary key column in the parent table. + plot_type (string, optional): + The plot type to use to plot the cardinality. Must be either 'bar' or 'distplot'. + Defaults to 'bar'. + + Returns: + plotly.graph_objects._figure.Figure + """ + if plot_type not in ['bar', 'distplot']: + raise ValueError( + f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].") + + real_cardinality = _get_cardinality( + real_data[parent_table_name], real_data[child_table_name], + parent_primary_key, child_foreign_key + ) + synth_cardinality = _get_cardinality( + synthetic_data[parent_table_name], + synthetic_data[child_table_name], + parent_primary_key, child_foreign_key + ) + + fig = _generate_cardinality_plot( + real_cardinality, + synth_cardinality, + parent_primary_key, + child_foreign_key, + plot_type=plot_type + ) + + return fig + + +def get_column_plot(real_data, synthetic_data, column_name, plot_type=None): + """Return a plot of the real and synthetic data for a given column. + + Args: + real_data (pandas.DataFrame): + The real table data. + synthetic_data (pandas.DataFrame): + The synthetic table data. + column_name (str): + The name of the column. + plot_type (str or None): + The plot to be used. Can choose between ``distplot``, ``bar`` or ``None``. If ``None` + select between ``distplot`` or ``bar`` depending on the data that the column contains, + ``distplot`` for datetime and numerical values and ``bar`` for categorical. + Defaults to ``None``. + + Returns: + plotly.graph_objects._figure.Figure + """ + if plot_type not in ['bar', 'distplot', None]: + raise ValueError( + f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]." + ) + + if column_name not in real_data.columns: + raise ValueError(f"Column '{column_name}' not found in real table data.") + if column_name not in synthetic_data.columns: + raise ValueError(f"Column '{column_name}' not found in synthetic table data.") + + real_column = real_data[column_name] + if plot_type is None: + column_is_datetime = is_datetime(real_data[column_name]) + dtype = real_column.dropna().infer_objects().dtype.kind + if column_is_datetime or dtype in ('i', 'f'): + plot_type = 'distplot' + else: + plot_type = 'bar' + + real_column = real_data[column_name] + synthetic_column = synthetic_data[column_name] + + fig = _generate_column_plot(real_column, synthetic_column, plot_type) + + return fig + + +def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None): + """Return a plot of the real and synthetic data for a given column pair. + + Args: + real_data (pandas.DataFrame): + The real table data. + synthetic_column (pandas.Dataframe): + The synthetic table data. + column_names (list[string]): + The names of the two columns to plot. + plot_type (str or None): + The plot to be used. Can choose between ``box``, ``heatmap``, ``scatter`` or ``None``. + If ``None` select between ``box``, ``heatmap`` or ``scatter`` depending on the data + that the column contains, ``scatter`` used for datetime and numerical values, + ``heatmap`` for categorical and ``box`` for a mix of both. Defaults to ``None``. + + Returns: + plotly.graph_objects._figure.Figure + """ + if len(column_names) != 2: + raise ValueError('Must provide exactly two column names.') + + if not set(column_names).issubset(real_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(real_data.columns)} in real data.' + ) + + if not set(column_names).issubset(synthetic_data.columns): + raise ValueError( + f'Missing column(s) {set(column_names) - set(synthetic_data.columns)} ' + 'in synthetic data.' + ) + + if plot_type not in ['box', 'heatmap', 'scatter', None]: + raise ValueError( + f"Invalid plot_type '{plot_type}'. Please use one of " + "['box', 'heatmap', 'scatter', None]." + ) + + real_data = real_data[column_names] + synthetic_data = synthetic_data[column_names] + if plot_type is None: + plot_type = [] + for column_name in column_names: + column = real_data[column_name] + dtype = column.dropna().infer_objects().dtype.kind + if dtype in ('i', 'f') or is_datetime(column): + plot_type.append('scatter') + else: + plot_type.append('heatmap') + + if len(set(plot_type)) > 1: + plot_type = 'box' + else: + plot_type = plot_type.pop() + + # Merge the real and synthetic data and add a flag ``Data`` to indicate each one. + columns = list(real_data.columns) + real_data = real_data.copy() + real_data['Data'] = 'Real' + synthetic_data = synthetic_data.copy() + synthetic_data['Data'] = 'Synthetic' + all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) + + if plot_type == 'scatter': + return _generate_scatter_plot(all_data, columns) + elif plot_type == 'heatmap': + return _generate_heatmap_plot(all_data, columns) + + return _generate_box_plot(all_data, columns) diff --git a/copulas/utils.py b/copulas/utils.py new file mode 100644 index 00000000..30e56ebe --- /dev/null +++ b/copulas/utils.py @@ -0,0 +1,298 @@ +"""SDMetrics utils to be used across all the project.""" + +from collections import Counter +from datetime import datetime + +import numpy as np +import pandas as pd +from sklearn.preprocessing import OneHotEncoder + + +def nested_attrs_meta(nested): + """Metaclass factory that defines a Metaclass with a dynamic attribute name.""" + + class Metaclass(type): + """Metaclass which pulls the attributes from a nested object using properties.""" + + def __getattr__(cls, attr): + """If cls does not have the attribute, try to get it from the nested object.""" + nested_obj = getattr(cls, nested) + if hasattr(nested_obj, attr): + return getattr(nested_obj, attr) + + raise AttributeError(f"type object '{cls.__name__}' has no attribute '{attr}'") + + @property + def name(cls): + return getattr(cls, nested).name + + @property + def goal(cls): + return getattr(cls, nested).goal + + @property + def max_value(cls): + return getattr(cls, nested).max_value + + @property + def min_value(cls): + return getattr(cls, nested).min_value + + return Metaclass + + +def get_frequencies(real, synthetic): + """Get percentual frequencies for each possible real categorical value. + + Given two iterators containing categorical data, this transforms it into + observed/expected frequencies which can be used for statistical tests. It + adds a regularization term to handle cases where the synthetic data contains + values that don't exist in the real data. + + Args: + real (list): + A list of hashable objects. + synthetic (list): + A list of hashable objects. + + Yields: + tuble[list, list]: + The observed and expected frequencies (as a percent). + """ + f_obs, f_exp = [], [] + real, synthetic = Counter(real), Counter(synthetic) + for value in synthetic: + if value not in real: + real[value] += 1e-6 # Regularization to prevent NaN. + + for value in real: + f_obs.append(synthetic[value] / sum(synthetic.values())) # noqa: PD011 + f_exp.append(real[value] / sum(real.values())) # noqa: PD011 + + return f_obs, f_exp + + +def get_missing_percentage(data_column): + """Compute the missing value percentage of a column. + + Args: + data_column (pandas.Series): + The data of the desired column. + + Returns: + pandas.Series: + Percentage of missing values inside the column. + """ + return round((data_column.isna().sum() / len(data_column)) * 100, 2) + + +def get_cardinality_distribution(parent_column, child_column): + """Compute the cardinality distribution of the (parent, child) pairing. + + Args: + parent_column (pandas.Series): + The parent column. + child_column (pandas.Series): + The child column. + + Returns: + pandas.Series: + The cardinality distribution. + """ + child_df = pd.DataFrame({'child_counts': child_column.value_counts()}) + cardinality_df = pd.DataFrame({'parent': parent_column}).join( + child_df, on='parent').fillna(0) + + return cardinality_df['child_counts'] + + +def is_datetime(data): + """Determine if the input is a datetime type or not. + + Args: + data (pandas.DataFrame, int or datetime): + Input to evaluate. + + Returns: + bool: + True if the input is a datetime type, False if not. + """ + return ( + pd.api.types.is_datetime64_any_dtype(data) + or isinstance(data, pd.Timestamp) + or isinstance(data, datetime) + ) + + +class HyperTransformer(): + """HyperTransformer class. + + The ``HyperTransformer`` class contains a set of transforms to transform one or + more columns based on each column's data type. + """ + + column_transforms = {} + column_kind = {} + + def fit(self, data): + """Fit the HyperTransformer to the given data. + + Args: + data (pandas.DataFrame): + The data to transform. + """ + if not isinstance(data, pd.DataFrame): + data = pd.DataFrame(data) + + for field in data: + kind = data[field].dropna().infer_objects().dtype.kind + self.column_kind[field] = kind + + if kind == 'i' or kind == 'f': + # Numerical column. + self.column_transforms[field] = {'mean': data[field].mean()} + elif kind == 'b': + # Boolean column. + numeric = pd.to_numeric(data[field], errors='coerce').astype(float) + self.column_transforms[field] = {'mode': numeric.mode().iloc[0]} + elif kind == 'O': + # Categorical column. + col_data = pd.DataFrame({'field': data[field]}) + enc = OneHotEncoder() + enc.fit(col_data) + self.column_transforms[field] = {'one_hot_encoder': enc} + elif kind == 'M': + # Datetime column. + nulls = data[field].isna() + integers = pd.to_numeric( + data[field], errors='coerce').to_numpy().astype(np.float64) + integers[nulls] = np.nan + self.column_transforms[field] = {'mean': pd.Series(integers).mean()} + + def transform(self, data): + """Transform the given data based on the data type of each column. + + Args: + data (pandas.DataFrame): + The data to transform. + + Returns: + pandas.DataFrame: + The transformed data. + """ + if not isinstance(data, pd.DataFrame): + data = pd.DataFrame(data) + + for field in data: + transform_info = self.column_transforms[field] + + kind = self.column_kind[field] + if kind == 'i' or kind == 'f': + # Numerical column. + data[field] = data[field].fillna(transform_info['mean']) + elif kind == 'b': + # Boolean column. + data[field] = pd.to_numeric(data[field], errors='coerce').astype(float) + data[field] = data[field].fillna(transform_info['mode']) + elif kind == 'O': + # Categorical column. + col_data = pd.DataFrame({'field': data[field]}) + out = transform_info['one_hot_encoder'].transform(col_data).toarray() + transformed = pd.DataFrame( + out, columns=[f'value{i}' for i in range(np.shape(out)[1])]) + data = data.drop(columns=[field]) + data = pd.concat([data, transformed.set_index(data.index)], axis=1) + elif kind == 'M': + # Datetime column. + nulls = data[field].isna() + integers = pd.to_numeric( + data[field], errors='coerce').to_numpy().astype(np.float64) + integers[nulls] = np.nan + data[field] = pd.Series(integers) + data[field] = data[field].fillna(transform_info['mean']) + + return data + + def fit_transform(self, data): + """Fit and transform the given data based on the data type of each column. + + Args: + data (pandas.DataFrame): + The data to transform. + + Returns: + pandas.DataFrame: + The transformed data. + """ + self.fit(data) + return self.transform(data) + + +def get_columns_from_metadata(metadata): + """Get the column info from a metadata dict. + + Args: + metadata (dict): + The metadata dict. + + Returns: + dict: + The columns metadata. + """ + return metadata.get('columns', {}) + + +def get_type_from_column_meta(column_metadata): + """Get the type of a given column from the column metadata. + + Args: + column_metadata (dict): + The column metadata. + + Returns: + string: + The column type. + """ + return column_metadata.get('sdtype', '') + + +def get_alternate_keys(metadata): + """Get the alternate keys from a metadata dict. + + Args: + metadata (dict): + The metadata dict. + + Returns: + list: + The list of alternate keys. + """ + alternate_keys = [] + for alternate_key in metadata.get('alternate_keys', []): + if isinstance(alternate_key, list): + alternate_keys.extend(alternate_key) + else: + alternate_keys.append(alternate_key) + + return alternate_keys + + +def strip_characters(list_character, a_string): + """Strip characters from a column name. + + Args: + list_character (list): + The list of characters to strip. + a_string (string): + The string to be stripped. + + Returns: + string: + The string with the characters stripped. + """ + result = a_string + for character in list_character: + if character in result: + result = result.replace(character, '') + + return result diff --git a/copulas/utils2.py b/copulas/utils2.py new file mode 100644 index 00000000..560afcf6 --- /dev/null +++ b/copulas/utils2.py @@ -0,0 +1,233 @@ +"""Report utility methods.""" + +import copy +import itertools +import warnings + +import numpy as np +import pandas as pd +from pandas.core.tools.datetimes import _guess_datetime_format_for_array + +from copulas.utils import ( + get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta, is_datetime) + +CONTINUOUS_SDTYPES = ['numerical', 'datetime'] +DISCRETE_SDTYPES = ['categorical', 'boolean'] + + +class PlotConfig: + """Custom plot settings for visualizations.""" + + GREEN = '#36B37E' + RED = '#FF0000' + ORANGE = '#F16141' + DATACEBO_DARK = '#000036' + DATACEBO_GREEN = '#01E0C9' + DATACEBO_BLUE = '#03AFF1' + BACKGROUND_COLOR = '#F5F5F8' + FONT_SIZE = 18 + + +def convert_to_datetime(column_data, datetime_format=None): + """Convert a column data to pandas datetime. + + Args: + column_data (pandas.Series): + The column data + format (str): + Optional string format of datetime. If ``None``, will attempt to infer the datetime + format from the column data. Defaults to ``None``. + + Returns: + pandas.Series: + The converted column data. + """ + if is_datetime(column_data): + return column_data + + if datetime_format is None: + datetime_format = _guess_datetime_format_for_array(column_data.astype(str).to_numpy()) + + return pd.to_datetime(column_data, format=datetime_format) + + +def convert_datetime_columns(real_column, synthetic_column, col_metadata): + """Convert a real and a synthetic column to pandas datetime. + + Args: + real_data (pandas.Series): + The real column data + synthetic_column (pandas.Series): + The synthetic column data + col_metadata: + The metadata associated with the column + + Returns: + (pandas.Series, pandas.Series): + The converted real and synthetic column data. + """ + datetime_format = col_metadata.get('format') or col_metadata.get('datetime_format') + return (convert_to_datetime(real_column, datetime_format), + convert_to_datetime(synthetic_column, datetime_format)) + + +def discretize_table_data(real_data, synthetic_data, metadata): + """Create a copy of the real and synthetic data with discretized data. + + Convert numerical and datetime columns to discrete values, and label them + as categorical. + + Args: + real_data (pandas.DataFrame): + The real data. + synthetic_data (pandas.DataFrame): + The synthetic data. + metadata (dict) + The metadata. + + Returns: + (pandas.DataFrame, pandas.DataFrame, dict): + The binned real and synthetic data, and the updated metadata. + """ + binned_real = real_data.copy() + binned_synthetic = synthetic_data.copy() + binned_metadata = copy.deepcopy(metadata) + + for column_name, column_meta in get_columns_from_metadata(metadata).items(): + sdtype = get_type_from_column_meta(column_meta) + + if sdtype in ('numerical', 'datetime'): + real_col = real_data[column_name] + synthetic_col = synthetic_data[column_name] + if sdtype == 'datetime': + datetime_format = column_meta.get('format') or column_meta.get('datetime_format') + if real_col.dtype == 'O' and datetime_format: + real_col = pd.to_datetime(real_col, format=datetime_format) + synthetic_col = pd.to_datetime(synthetic_col, format=datetime_format) + + real_col = pd.to_numeric(real_col) + synthetic_col = pd.to_numeric(synthetic_col) + + bin_edges = np.histogram_bin_edges(real_col.dropna()) + binned_real_col = np.digitize(real_col, bins=bin_edges) + binned_synthetic_col = np.digitize(synthetic_col, bins=bin_edges) + + binned_real[column_name] = binned_real_col + binned_synthetic[column_name] = binned_synthetic_col + get_columns_from_metadata(binned_metadata)[column_name] = {'sdtype': 'categorical'} + + return binned_real, binned_synthetic, binned_metadata + + +def _get_non_id_columns(metadata, binned_metadata): + valid_sdtypes = ['numerical', 'categorical', 'boolean', 'datetime'] + alternate_keys = get_alternate_keys(metadata) + non_id_columns = [] + for column, column_meta in get_columns_from_metadata(binned_metadata).items(): + is_key = column == metadata.get('primary_key', '') or column in alternate_keys + if get_type_from_column_meta(column_meta) in valid_sdtypes and not is_key: + non_id_columns.append(column) + + return non_id_columns + + +def discretize_and_apply_metric(real_data, synthetic_data, metadata, metric, keys_to_skip=[]): + """Discretize the data and apply the given metric. + + Args: + real_data (pandas.DataFrame): + The real data. + synthetic_data (pandas.DataFrame): + The synthetic data. + metadata (dict) + The metadata. + metric (sdmetrics.single_table.MultiColumnPairMetric): + The column pair metric to apply. + keys_to_skip (list[tuple(str)] or None): + A list of keys for which to skip computing the metric. + + Returns: + dict: + The metric results. + """ + metric_results = {} + + binned_real, binned_synthetic, binned_metadata = discretize_table_data( + real_data, synthetic_data, metadata) + + non_id_cols = _get_non_id_columns(metadata, binned_metadata) + for columns in itertools.combinations(non_id_cols, r=2): + sorted_columns = tuple(sorted(columns)) + if ( + sorted_columns not in keys_to_skip and + (sorted_columns[1], sorted_columns[0]) not in keys_to_skip + ): + result = metric.column_pairs_metric.compute_breakdown( + binned_real[list(sorted_columns)], + binned_synthetic[list(sorted_columns)], + ) + metric_results[sorted_columns] = result + metric_results[sorted_columns] = result + + return metric_results + + +def aggregate_metric_results(metric_results): + """Aggregate the scores and errors in a metric results mapping. + + Args: + metric_results (dict): + The metric results to aggregate. + + Returns: + (float, int): + The average of the metric scores, and the number of errors. + """ + if len(metric_results) == 0: + return np.nan, 0 + + metric_scores = [] + num_errors = 0 + + for _, breakdown in metric_results.items(): + metric_score = breakdown.get('score', np.nan) + if not np.isnan(metric_score): + metric_scores.append(metric_score) + if 'error' in breakdown: + num_errors += 1 + + return np.mean(metric_scores), num_errors + + +def _validate_categorical_values(real_data, synthetic_data, metadata, table=None): + """Get categorical values found in synthetic data but not real data for all columns. + + Args: + real_data (pd.DataFrame): + The real data. + synthetic_data (pd.DataFrame): + The synthetic data. + metadata (dict): + The metadata. + table (str, optional): + The name of the current table, if one exists + """ + if table: + warning_format = ('Unexpected values ({values}) in column "{column}" ' + f'and table "{table}"') + else: + warning_format = 'Unexpected values ({values}) in column "{column}"' + + columns = get_columns_from_metadata(metadata) + for column, column_meta in columns.items(): + column_type = get_type_from_column_meta(column_meta) + if column_type == 'categorical': + extra_categories = [ + value for value in synthetic_data[column].unique() + if value not in real_data[column].unique() + ] + if extra_categories: + value_list = '", "'.join(str(value) for value in extra_categories[:5]) + values = f'"{value_list}" + more' if len( + extra_categories) > 5 else f'"{value_list}"' + warnings.warn(warning_format.format(values=values, column=column)) diff --git a/copulas/visualization.py b/copulas/visualization.py index 6df130f0..9eea3d5a 100644 --- a/copulas/visualization.py +++ b/copulas/visualization.py @@ -1,11 +1,11 @@ """Visualization utilities for the Copulas library.""" import pandas as pd - -from copulas.utils2 import PlotConfig import plotly.express as px from pandas.api.types import is_datetime64_dtype +from copulas.utils2 import PlotConfig + def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): """Generate a bar plot of the real and synthetic data. @@ -39,6 +39,7 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): return fig + def _generate_scatter_plot(all_data, columns): """Generate a scatter plot for column pair plot. @@ -72,6 +73,7 @@ def _generate_scatter_plot(all_data, columns): return fig + def _generate_column_plot(real_column, synthetic_column, plot_kwargs={}, @@ -136,9 +138,10 @@ def _generate_column_plot(real_column, ) return fig + def hist_1d(data, title=None, bins=20, label=None): """Plot 1 dimensional data in a histogram. - + Args: data (pd.DataFrame): The table data. @@ -148,7 +151,7 @@ def hist_1d(data, title=None, bins=20, label=None): The number of bins to use for the histogram. label (str): The label of the plot. - + Returns: plotly.graph_objects._figure.Figure """ @@ -174,6 +177,7 @@ def hist_1d(data, title=None, bins=20, label=None): return fig + def compare_1d(real, synth): """Return a plot of the real and synthetic data for a given column. @@ -193,6 +197,7 @@ def compare_1d(real, synth): return _generate_column_plot(real, synth, plot_type='bar') + def scatter_2d(data, columns=None): """Plot 2 dimensional data in a scatter plot. @@ -211,6 +216,7 @@ def scatter_2d(data, columns=None): return _generate_scatter_plot(data, columns) + def compare_2d_(real, synth, columns=None): """Return a plot of the real and synthetic data for a given column pair. @@ -234,6 +240,7 @@ def compare_2d_(real, synth, columns=None): return _generate_scatter_plot(all_data, columns) + def scatter_3d_plotly(data, columns=None): """Return a 3D scatter plot of the data. @@ -268,6 +275,7 @@ def scatter_3d_plotly(data, columns=None): return fig + def compare_3d(real, synth, columns=None): """Generate a 3d scatter plot comparing real/synthetic data.