diff --git a/copulas/sdmetrics.py b/copulas/sdmetrics.py
new file mode 100644
index 00000000..bafe1d52
--- /dev/null
+++ b/copulas/sdmetrics.py
@@ -0,0 +1,480 @@
+"""Visualization methods for SDMetrics."""
+
+import pandas as pd
+import plotly.express as px
+import plotly.figure_factory as ff
+from pandas.api.types import is_datetime64_dtype
+
+from copulas.utils import get_missing_percentage, is_datetime
+from copulas.utils2 import PlotConfig
+
+
+def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
+    """Draw the real and synthetic values of one column as grouped histogram bars.
+
+    Both frames are stacked into a single table before plotting.
+
+    Args:
+        real_data (pandas.DataFrame):
+            Real rows holding a ``values`` column and a ``Data`` label column.
+        synthetic_data (pandas.DataFrame):
+            Synthetic rows with the same two columns.
+        plot_kwargs (dict, optional):
+            Extra keyword arguments for ``px.histogram``; they take precedence
+            over the defaults assembled here.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    defaults = dict(
+        x='values',
+        barmode='group',
+        color_discrete_sequence=[PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN],
+        pattern_shape='Data',
+        pattern_shape_sequence=['', '/'],
+        histnorm='probability density',
+    )
+    # Caller-supplied settings come last so that they override the defaults.
+    merged_kwargs = {**defaults, **plot_kwargs}
+
+    # ``ignore_index`` gives the stacked table a fresh index.
+    stacked = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
+
+    return px.histogram(stacked, **merged_kwargs)
+
+
+def _generate_heatmap_plot(all_data, columns):
+    """Build side-by-side density heatmaps comparing two discrete columns.
+
+    Args:
+        all_data (pandas.DataFrame):
+            Real and synthetic rows for the column pair, tagged by a ``Data``
+            column that says which source each row came from.
+        columns (list):
+            The two column names; the first goes on x, the second on y.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    x_column, y_column = columns[0], columns[1]
+    heatmap = px.density_heatmap(
+        all_data, x=x_column, y=y_column, facet_col='Data', histnorm='probability'
+    )
+
+    heatmap.update_layout(
+        title_text=f"Real vs Synthetic Data for columns '{x_column}' and '{y_column}'",
+        coloraxis={'colorscale': [PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN]},
+        font={'size': PlotConfig.FONT_SIZE},
+    )
+
+    def _relabel(annotation):
+        # Keep only the text after the last '=' and append ' Data'.
+        annotation.update(text=annotation.text.split('=')[-1] + ' Data')
+
+    heatmap.for_each_annotation(_relabel)
+    return heatmap
+
+
+def _generate_box_plot(all_data, columns):
+    """Build a box plot for a column pair mixing discrete and continuous data.
+
+    Args:
+        all_data (pandas.DataFrame):
+            Real and synthetic rows for the column pair, tagged by a ``Data``
+            column that says which source each row came from.
+        columns (list):
+            The two column names; the first goes on x, the second on y.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    x_column, y_column = columns[0], columns[1]
+    # One fixed colour per value of the ``Data`` column.
+    palette = {
+        'Real': PlotConfig.DATACEBO_DARK,
+        'Synthetic': PlotConfig.DATACEBO_GREEN,
+    }
+
+    box_figure = px.box(
+        all_data, x=x_column, y=y_column, color='Data', color_discrete_map=palette
+    )
+
+    box_figure.update_layout(
+        title=f"Real vs. Synthetic Data for columns '{x_column}' and '{y_column}'",
+        plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+        font={'size': PlotConfig.FONT_SIZE},
+    )
+
+    return box_figure
+
+
+def _generate_scatter_plot(all_data, columns):
+    """Build a scatter plot comparing a pair of columns across both datasets.
+
+    Args:
+        all_data (pandas.DataFrame):
+            Real and synthetic rows for the column pair, tagged by a ``Data``
+            column that says which source each row came from.
+        columns (list):
+            The two column names; the first goes on x, the second on y.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    x_column, y_column = columns[0], columns[1]
+    # One fixed colour per value of the ``Data`` column.
+    palette = {
+        'Real': PlotConfig.DATACEBO_DARK,
+        'Synthetic': PlotConfig.DATACEBO_GREEN,
+    }
+
+    scatter_figure = px.scatter(
+        all_data, x=x_column, y=y_column,
+        color='Data', color_discrete_map=palette, symbol='Data',
+    )
+
+    scatter_figure.update_layout(
+        title=f"Real vs. Synthetic Data for columns '{x_column}' and '{y_column}'",
+        plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+        font={'size': PlotConfig.FONT_SIZE},
+    )
+
+    return scatter_figure
+
+
+def _generate_column_distplot(real_data, synthetic_data, plot_kwargs={}):
+    """Draw the real and synthetic values of one column as a distplot.
+
+    Args:
+        real_data (pandas.DataFrame):
+            Real rows; only the ``values`` column is plotted.
+        synthetic_data (pandas.DataFrame):
+            Synthetic rows; only the ``values`` column is plotted.
+        plot_kwargs (dict, optional):
+            Extra keyword arguments for ``ff.create_distplot``; they take
+            precedence over the defaults assembled here.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+    """
+    distplot_kwargs = dict(
+        show_hist=False,
+        show_rug=False,
+        colors=[PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN],
+    )
+    # Caller-supplied settings are applied last so they override the defaults.
+    distplot_kwargs.update(plot_kwargs)
+
+    # Order matters: the first series is labelled 'Real', the second 'Synthetic'.
+    series = [real_data['values'], synthetic_data['values']]
+    labels = ['Real', 'Synthetic']
+
+    return ff.create_distplot(series, labels, **distplot_kwargs)
+
+
+def _generate_column_plot(real_column, synthetic_column, plot_type, plot_kwargs={},
+                          plot_title=None, x_label=None):
+    """Generate a plot of the real and synthetic data.
+
+    Args:
+        real_column (pandas.Series):
+            The real data for the desired column.
+        synthetic_column (pandas.Series):
+            The synthetic data for the desired column.
+        plot_type (str):
+            The type of plot to use. Must be one of 'bar' or 'distplot'.
+        plot_kwargs (dict, optional):
+            Dictionary of keyword arguments to pass to the plotting helper selected by
+            ``plot_type``. Keyword arguments provided this way will overwrite defaults.
+        plot_title (str, optional):
+            Title to use for the plot. Defaults to 'Real vs. Synthetic Data for column {column}'
+        x_label (str, optional):
+            Label to use for x-axis. Defaults to 'Value' for 'distplot', else 'Category'.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+
+    Raises:
+        ValueError:
+            If ``plot_type`` is neither 'bar' nor 'distplot'.
+    """
+    if plot_type not in ['bar', 'distplot']:
+        raise ValueError(
+            f"Unrecognized plot_type '{plot_type}'. Please use one of 'bar' or 'distplot'"
+        )
+
+    column_name = real_column.name if hasattr(real_column, 'name') else ''
+
+    missing_data_real = get_missing_percentage(real_column)
+    missing_data_synthetic = get_missing_percentage(synthetic_column)
+
+    real_data = pd.DataFrame({'values': real_column.copy().dropna()})
+    real_data['Data'] = 'Real'
+    synthetic_data = pd.DataFrame({'values': synthetic_column.copy().dropna()})
+    synthetic_data['Data'] = 'Synthetic'
+
+    is_datetime_sdtype = is_datetime64_dtype(real_column.dtype)
+    if is_datetime_sdtype:
+        # Plot datetimes as integers; the traces are converted back further down.
+        real_data['values'] = real_data['values'].astype('int64')
+        synthetic_data['values'] = synthetic_data['values'].astype('int64')
+
+    trace_args = {}
+
+    if plot_type == 'bar':
+        fig = _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs)
+    elif plot_type == 'distplot':
+        x_label = x_label or 'Value'
+        fig = _generate_column_distplot(real_data, synthetic_data, plot_kwargs)
+        trace_args = {'fill': 'tozeroy'}
+
+    for i, name in enumerate(['Real', 'Synthetic']):
+        fig.update_traces(
+            x=pd.to_datetime(fig.data[i].x) if is_datetime_sdtype else fig.data[i].x,
+            hovertemplate=f'<b>{name}</b><br>Frequency: %{{y}}',
+            selector={'name': name},
+            **trace_args
+        )
+
+    show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0
+    annotations = [] if not show_missing_values else [
+        {
+            'xref': 'paper',
+            'yref': 'paper',
+            'x': 1.0,
+            'y': 1.05,
+            'showarrow': False,
+            'text': (
+                f'*Missing Values: Real Data ({missing_data_real}%), '
+                f'Synthetic Data ({missing_data_synthetic}%)'
+            ),
+        },
+    ]
+
+    if not plot_title:
+        plot_title = f"Real vs. Synthetic Data for column '{column_name}'"
+
+    if not x_label:
+        x_label = 'Category'
+
+    fig.update_layout(
+        title=plot_title,
+        xaxis_title=x_label,
+        yaxis_title='Frequency',
+        plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
+        annotations=annotations,
+        font={'size': PlotConfig.FONT_SIZE},
+    )
+    return fig
+
+
+def _generate_cardinality_plot(real_data, synthetic_data, parent_primary_key,
+                               child_foreign_key, plot_type='bar'):
+    """Plot real vs. synthetic children-per-parent counts for one relationship.
+
+    ``real_data`` and ``synthetic_data`` are the cardinality values to compare and
+    ``plot_type`` is handed to ``_generate_column_plot`` unchanged.
+    """
+    title = (
+        f"Relationship (child foreign key='{child_foreign_key}' and parent "
+        f"primary key='{parent_primary_key}')"
+    )
+
+    extra_kwargs = {}
+    if plot_type == 'bar':
+        # Ask for as many bins as there are integers from the lowest to the highest count.
+        highest = max(max(real_data), max(synthetic_data))
+        lowest = min(min(real_data), min(synthetic_data))
+        extra_kwargs['nbins'] = highest - lowest + 1
+
+    return _generate_column_plot(real_data, synthetic_data, plot_type, extra_kwargs,
+                                 title, '# of Children (per Parent)')
+
+
+def _get_cardinality(parent_table, child_table, parent_primary_key, child_foreign_key):
+    """Count, for every parent row, how many child rows reference it.
+
+    Args:
+        parent_table (pandas.DataFrame):
+            Table that owns the primary key.
+        child_table (pandas.DataFrame):
+            Table that holds the foreign key.
+        parent_primary_key (string):
+            Primary key column name in ``parent_table``.
+        child_foreign_key (string):
+            Foreign key column name in ``child_table``.
+
+    Returns:
+        pandas.Series:
+            The ``# children`` counts in ascending order; parents without children get 0.
+    """
+    per_key = child_table[child_foreign_key].value_counts().rename('# children')
+    frame = per_key.reindex(parent_table[parent_primary_key], fill_value=0).to_frame()
+    return frame.sort_values(by='# children')['# children']
+
+
+def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_table_name,
+                         child_foreign_key, parent_primary_key, plot_type='bar'):
+    """Plot how many children each parent has in the real vs. the synthetic data.
+
+    Args:
+        real_data (dict):
+            The real tables, looked up by table name.
+        synthetic_data (dict):
+            The synthetic tables, looked up by table name.
+        child_table_name (string):
+            Name of the child table.
+        parent_table_name (string):
+            Name of the parent table.
+        child_foreign_key (string):
+            Foreign key column name in the child table.
+        parent_primary_key (string):
+            Primary key column name in the parent table.
+        plot_type (string, optional):
+            Either 'bar' or 'distplot'. Defaults to 'bar'.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+
+    Raises:
+        ValueError:
+            If ``plot_type`` is neither 'bar' nor 'distplot'.
+    """
+    supported_plot_types = ('bar', 'distplot')
+    if plot_type not in supported_plot_types:
+        raise ValueError(
+            f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot'].")
+
+    cardinalities = [
+        _get_cardinality(
+            tables[parent_table_name], tables[child_table_name],
+            parent_primary_key, child_foreign_key
+        )
+        for tables in (real_data, synthetic_data)
+    ]
+
+    return _generate_cardinality_plot(
+        cardinalities[0],
+        cardinalities[1],
+        parent_primary_key,
+        child_foreign_key,
+        plot_type=plot_type
+    )
+
+
+def get_column_plot(real_data, synthetic_data, column_name, plot_type=None):
+    """Plot one column of the real table against the same column of the synthetic table.
+
+    Args:
+        real_data (pandas.DataFrame):
+            The real table data.
+        synthetic_data (pandas.DataFrame):
+            The synthetic table data.
+        column_name (str):
+            Name of the column to plot; it must exist in both tables.
+        plot_type (str or None):
+            ``'bar'``, ``'distplot'`` or ``None``. With ``None`` the type is chosen
+            from the real column: ``distplot`` for datetime, integer and float data,
+            ``bar`` for everything else. Defaults to ``None``.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+
+    Raises:
+        ValueError:
+            If ``plot_type`` is not supported or the column is missing from a table.
+    """
+    if plot_type not in ('bar', 'distplot', None):
+        raise ValueError(
+            f"Invalid plot_type '{plot_type}'. Please use one of ['bar', 'distplot', None]."
+        )
+
+    tables = (('real', real_data), ('synthetic', synthetic_data))
+    for label, table in tables:
+        if column_name not in table.columns:
+            raise ValueError(f"Column '{column_name}' not found in {label} table data.")
+
+    real_column = real_data[column_name]
+    synthetic_column = synthetic_data[column_name]
+
+    if plot_type is None:
+        # Decide from the real data only; nulls are dropped before inferring the kind.
+        inferred_kind = real_column.dropna().infer_objects().dtype.kind
+        if is_datetime(real_column) or inferred_kind in ('i', 'f'):
+            plot_type = 'distplot'
+        else:
+            plot_type = 'bar'
+
+    return _generate_column_plot(real_column, synthetic_column, plot_type)
+
+
+def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None):
+    """Plot a pair of columns from the real table against the synthetic table.
+
+    Args:
+        real_data (pandas.DataFrame):
+            The real table data.
+        synthetic_data (pandas.DataFrame):
+            The synthetic table data.
+        column_names (list[string]):
+            The names of the two columns to plot.
+        plot_type (str or None):
+            ``'box'``, ``'heatmap'``, ``'scatter'`` or ``None``. With ``None`` each real
+            column is classified as continuous (datetime, integer, float) or discrete:
+            two continuous columns give ``scatter``, two discrete ones ``heatmap`` and
+            a mix gives ``box``. Defaults to ``None``.
+
+    Returns:
+        plotly.graph_objects._figure.Figure
+
+    Raises:
+        ValueError:
+            If there are not exactly two column names, a column is missing from
+            either table, or ``plot_type`` is not supported.
+    """
+    if len(column_names) != 2:
+        raise ValueError('Must provide exactly two column names.')
+
+    requested = set(column_names)
+    if not requested.issubset(real_data.columns):
+        raise ValueError(
+            f'Missing column(s) {requested - set(real_data.columns)} in real data.'
+        )
+
+    if not requested.issubset(synthetic_data.columns):
+        raise ValueError(
+            f'Missing column(s) {requested - set(synthetic_data.columns)} '
+            'in synthetic data.'
+        )
+
+    if plot_type not in ('box', 'heatmap', 'scatter', None):
+        raise ValueError(
+            f"Invalid plot_type '{plot_type}'. Please use one of "
+            "['box', 'heatmap', 'scatter', None]."
+        )
+
+    real_data = real_data[column_names]
+    synthetic_data = synthetic_data[column_names]
+    if plot_type is None:
+        suggestions = set()
+        for column_name in column_names:
+            column = real_data[column_name]
+            kind = column.dropna().infer_objects().dtype.kind
+            is_continuous = kind in ('i', 'f') or is_datetime(column)
+            suggestions.add('scatter' if is_continuous else 'heatmap')
+
+        # Disagreeing suggestions mean one continuous and one discrete column.
+        plot_type = 'box' if len(suggestions) > 1 else suggestions.pop()
+
+    # Tag every row with its origin before stacking both tables together.
+    columns = list(real_data.columns)
+    real_data = real_data.assign(Data='Real')
+    synthetic_data = synthetic_data.assign(Data='Synthetic')
+    all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
+
+    plotters = {
+        'scatter': _generate_scatter_plot,
+        'heatmap': _generate_heatmap_plot,
+    }
+    # Anything else that passed validation is a box plot.
+    return plotters.get(plot_type, _generate_box_plot)(all_data, columns)
diff --git a/copulas/utils.py b/copulas/utils.py
new file mode 100644
index 00000000..30e56ebe
--- /dev/null
+++ b/copulas/utils.py
@@ -0,0 +1,298 @@
+"""SDMetrics utils to be used across all the project."""
+
+from collections import Counter
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+from sklearn.preprocessing import OneHotEncoder
+
+
+def nested_attrs_meta(nested):
+    """Metaclass factory: build a metaclass that proxies lookups to ``cls.<nested>``."""
+
+    class Metaclass(type):
+        """Metaclass which pulls the attributes from a nested object using properties."""
+
+        def __getattr__(cls, attr):
+            """If cls does not have the attribute, try to get it from the nested object."""
+            # ``attr == nested`` means the nested object itself is missing: do not recurse.
+            if attr != nested and hasattr(getattr(cls, nested), attr):
+                return getattr(getattr(cls, nested), attr)
+
+            raise AttributeError(f"type object '{cls.__name__}' has no attribute '{attr}'")
+
+        @property
+        def name(cls):
+            return getattr(cls, nested).name
+
+        @property
+        def goal(cls):
+            return getattr(cls, nested).goal
+
+        @property
+        def max_value(cls):
+            return getattr(cls, nested).max_value
+
+        @property
+        def min_value(cls):
+            return getattr(cls, nested).min_value
+
+    return Metaclass
+
+
+def get_frequencies(real, synthetic):
+    """Get percentual frequencies for each possible real categorical value.
+
+    Given two iterables containing categorical data, this transforms them into
+    observed/expected frequencies which can be used for statistical tests. It
+    adds a regularization term to handle cases where the synthetic data contains
+    values that don't exist in the real data.
+
+    Args:
+        real (list):
+            A list of hashable objects.
+        synthetic (list):
+            A list of hashable objects.
+
+    Returns:
+        tuple[list, list]:
+            The observed (synthetic) and expected (real) relative frequencies.
+    """
+    real, synthetic = Counter(real), Counter(synthetic)
+    for value in synthetic:
+        if value not in real:
+            real[value] += 1e-6  # Regularization to prevent NaN.
+
+    # The totals are fixed from here on, so compute them once instead of per value.
+    real_total = sum(real.values())  # noqa: PD011
+    synthetic_total = sum(synthetic.values())  # noqa: PD011
+    f_obs = [synthetic[value] / synthetic_total for value in real]
+    f_exp = [count / real_total for count in real.values()]  # noqa: PD011
+    return f_obs, f_exp
+
+
+def get_missing_percentage(data_column):
+    """Compute the percentage of missing values in a column.
+
+    Args:
+        data_column (pandas.Series):
+            The data of the desired column.
+
+    Returns:
+        float:
+            Share of null entries in the column, in percent, rounded to 2 decimals.
+    """
+    return round(100 * (data_column.isna().sum() / len(data_column)), 2)
+
+
+def get_cardinality_distribution(parent_column, child_column):
+    """Compute how many child rows reference each entry of the parent column.
+
+    Args:
+        parent_column (pandas.Series):
+            The parent column.
+        child_column (pandas.Series):
+            The child column.
+
+    Returns:
+        pandas.Series:
+            The ``child_counts`` value per parent row; parents without children get 0.
+    """
+    counts = pd.DataFrame({'child_counts': child_column.value_counts()})
+    parents = pd.DataFrame({'parent': parent_column})
+    # Parent values that never occur in the child column come out of the join as NaN.
+    joined = parents.join(counts, on='parent')
+    return joined.fillna(0)['child_counts']
+
+
+def is_datetime(data):
+    """Tell whether the input is datetime-like.
+
+    Args:
+        data (pandas.DataFrame, int or datetime):
+            Input to evaluate.
+
+    Returns:
+        bool:
+            True if the input is a datetime type, False if not.
+    """
+    # Array-likes and dtypes are judged by their dtype, scalars by their class.
+    if pd.api.types.is_datetime64_any_dtype(data):
+        return True
+
+    return isinstance(data, (pd.Timestamp, datetime))
+
+
+class HyperTransformer():
+    """HyperTransformer class.
+
+    The ``HyperTransformer`` class contains a set of transforms to transform one or
+    more columns based on each column's data type.
+    """
+
+    def __init__(self):
+        # Per-instance state: class-level dicts would be shared by every transformer.
+        self.column_transforms = {}
+        self.column_kind = {}
+
+    @staticmethod
+    def _datetime_to_float(column):
+        """Return a datetime column as a float64 array with NaN in place of nulls."""
+        floats = pd.to_numeric(column, errors='coerce').to_numpy().astype(np.float64)
+        floats[column.isna().to_numpy()] = np.nan
+        return floats
+
+    def fit(self, data):
+        """Fit the HyperTransformer to the given data.
+
+        Args:
+            data (pandas.DataFrame):
+                The data to transform.
+        """
+        data = data if isinstance(data, pd.DataFrame) else pd.DataFrame(data)
+
+        for field in data:
+            kind = data[field].dropna().infer_objects().dtype.kind
+            self.column_kind[field] = kind
+            if kind == 'i' or kind == 'f':
+                # Numerical column.
+                self.column_transforms[field] = {'mean': data[field].mean()}
+            elif kind == 'b':
+                # Boolean column.
+                numeric = pd.to_numeric(data[field], errors='coerce').astype(float)
+                self.column_transforms[field] = {'mode': numeric.mode().iloc[0]}
+            elif kind == 'O':
+                # Categorical column.
+                col_data = pd.DataFrame({'field': data[field]})
+                enc = OneHotEncoder()
+                enc.fit(col_data)
+                self.column_transforms[field] = {'one_hot_encoder': enc}
+            elif kind == 'M':
+                # Datetime column.
+                floats = self._datetime_to_float(data[field])
+                self.column_transforms[field] = {'mean': pd.Series(floats).mean()}
+
+    def transform(self, data):
+        """Transform the given data based on the data type of each column.
+
+        Args:
+            data (pandas.DataFrame):
+                The data to transform.
+
+        Returns:
+            pandas.DataFrame:
+                The transformed data.
+        """
+        # Work on a copy so the caller's frame is never modified in place.
+        data = data.copy() if isinstance(data, pd.DataFrame) else pd.DataFrame(data)
+
+        for field in data:
+            transform_info = self.column_transforms[field]
+
+            kind = self.column_kind[field]
+            if kind == 'i' or kind == 'f':
+                # Numerical column.
+                data[field] = data[field].fillna(transform_info['mean'])
+            elif kind == 'b':
+                # Boolean column.
+                data[field] = pd.to_numeric(data[field], errors='coerce').astype(float)
+                data[field] = data[field].fillna(transform_info['mode'])
+            elif kind == 'O':
+                # Categorical column.
+                col_data = pd.DataFrame({'field': data[field]})
+                out = transform_info['one_hot_encoder'].transform(col_data).toarray()
+                transformed = pd.DataFrame(
+                    out, columns=[f'value{i}' for i in range(np.shape(out)[1])])
+                data = data.drop(columns=[field])
+                data = pd.concat([data, transformed.set_index(data.index)], axis=1)
+            elif kind == 'M':
+                # Datetime column; keep the frame's own index so the values stay aligned.
+                floats = pd.Series(self._datetime_to_float(data[field]), index=data.index)
+                data[field] = floats.fillna(transform_info['mean'])
+
+        return data
+
+    def fit_transform(self, data):
+        """Fit and transform the given data based on the data type of each column.
+
+        Args:
+            data (pandas.DataFrame):
+                The data to transform.
+
+        Returns:
+            pandas.DataFrame:
+                The transformed data.
+        """
+        self.fit(data)
+        return self.transform(data)
+
+
+def get_columns_from_metadata(metadata):
+    """Return the ``columns`` entry of a metadata dict.
+
+    Args:
+        metadata (dict):
+            The metadata dict.
+
+    Returns:
+        dict:
+            The value stored under ``'columns'``, or an empty dict if the key is absent.
+    """
+    return metadata.get('columns', {})
+
+
+def get_type_from_column_meta(column_metadata):
+    """Return the ``sdtype`` entry of a single column's metadata.
+
+    Args:
+        column_metadata (dict):
+            The column metadata.
+
+    Returns:
+        string:
+            The value stored under ``'sdtype'``, or ``''`` if the key is absent.
+    """
+    return column_metadata.get('sdtype', '')
+
+
+def get_alternate_keys(metadata):
+    """Collect every column that takes part in an alternate key.
+
+    Args:
+        metadata (dict):
+            The metadata dict; its ``alternate_keys`` entry is optional.
+
+    Returns:
+        list:
+            The alternate key columns in declaration order, with list-valued
+            (composite) keys flattened into their member columns.
+    """
+    flattened = []
+    for entry in metadata.get('alternate_keys', []):
+        # A list entry contributes each of its members;
+        # anything else is kept as a single item.
+        flattened += entry if isinstance(entry, list) else [entry]
+
+    return flattened
+
+
+def strip_characters(list_character, a_string):
+    """Remove every occurrence of the given characters from a string.
+
+    Args:
+        list_character (list):
+            The characters (or substrings) to remove.
+        a_string (string):
+            The string to clean.
+
+    Returns:
+        string:
+            ``a_string`` without any of the listed characters.
+    """
+    stripped = a_string
+    for unwanted in list_character:
+        # Entries are removed one after another, in the order they are listed.
+        stripped = stripped.replace(unwanted, '') if unwanted in stripped else stripped
+
+    return stripped
diff --git a/copulas/utils2.py b/copulas/utils2.py
new file mode 100644
index 00000000..560afcf6
--- /dev/null
+++ b/copulas/utils2.py
@@ -0,0 +1,233 @@
+"""Report utility methods."""
+
+import copy
+import itertools
+import warnings
+
+import numpy as np
+import pandas as pd
+from pandas.core.tools.datetimes import _guess_datetime_format_for_array
+
+from copulas.utils import (
+ get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta, is_datetime)
+
+CONTINUOUS_SDTYPES = ['numerical', 'datetime']
+DISCRETE_SDTYPES = ['categorical', 'boolean']
+
+
+class PlotConfig:
+    """Custom plot settings (colour hex codes and font size) for visualizations."""
+
+    GREEN = '#36B37E'
+    RED = '#FF0000'
+    ORANGE = '#F16141'
+    DATACEBO_DARK = '#000036'  # used for the 'Real' series in the box/scatter plots
+    DATACEBO_GREEN = '#01E0C9'  # used for the 'Synthetic' series in the box/scatter plots
+    DATACEBO_BLUE = '#03AFF1'
+    BACKGROUND_COLOR = '#F5F5F8'  # passed to plotly as ``plot_bgcolor``
+    FONT_SIZE = 18  # passed to plotly as the layout font ``size``
+
+
+def convert_to_datetime(column_data, datetime_format=None):
+    """Convert a column to pandas datetime unless it already is one.
+
+    Args:
+        column_data (pandas.Series):
+            The column data.
+        datetime_format (str):
+            Optional datetime format string. When ``None``, the format is guessed
+            from the string form of the column values. Defaults to ``None``.
+
+    Returns:
+        pandas.Series:
+            The converted column data, or ``column_data`` itself if no conversion ran.
+    """
+    if not is_datetime(column_data):
+        if datetime_format is None:
+            as_text = column_data.astype(str).to_numpy()
+            datetime_format = _guess_datetime_format_for_array(as_text)
+        column_data = pd.to_datetime(column_data, format=datetime_format)
+
+    return column_data
+
+
+def convert_datetime_columns(real_column, synthetic_column, col_metadata):
+    """Convert a real and a synthetic column to pandas datetime with one shared format.
+
+    Args:
+        real_column (pandas.Series):
+            The real column data.
+        synthetic_column (pandas.Series):
+            The synthetic column data.
+        col_metadata (dict):
+            Column metadata; ``format`` is read first, then ``datetime_format``.
+
+    Returns:
+        (pandas.Series, pandas.Series):
+            The converted real and synthetic column data.
+    """
+    fmt = col_metadata.get('format') or col_metadata.get('datetime_format')
+    columns = (real_column, synthetic_column)
+    return tuple(convert_to_datetime(column, fmt) for column in columns)
+
+
+def discretize_table_data(real_data, synthetic_data, metadata):
+    """Return binned copies of the real and synthetic data plus matching metadata.
+
+    Numerical and datetime columns are replaced by the index of the histogram bin
+    each value falls into, and relabelled as categorical in the metadata copy.
+
+    Args:
+        real_data (pandas.DataFrame):
+            The real data.
+        synthetic_data (pandas.DataFrame):
+            The synthetic data.
+        metadata (dict):
+            The metadata.
+
+    Returns:
+        (pandas.DataFrame, pandas.DataFrame, dict):
+            The binned real and synthetic data, and the updated metadata.
+    """
+    binned_real = real_data.copy()
+    binned_synthetic = synthetic_data.copy()
+    binned_metadata = copy.deepcopy(metadata)
+    binned_columns = get_columns_from_metadata(binned_metadata)
+
+    for column_name, column_meta in get_columns_from_metadata(metadata).items():
+        sdtype = get_type_from_column_meta(column_meta)
+        if sdtype not in ('numerical', 'datetime'):
+            continue
+
+        real_col, synthetic_col = real_data[column_name], synthetic_data[column_name]
+        if sdtype == 'datetime':
+            datetime_format = column_meta.get('format') or column_meta.get('datetime_format')
+            if real_col.dtype == 'O' and datetime_format:
+                real_col = pd.to_datetime(real_col, format=datetime_format)
+                synthetic_col = pd.to_datetime(synthetic_col, format=datetime_format)
+
+            # Datetimes are binned through their numeric representation.
+            real_col = pd.to_numeric(real_col)
+            synthetic_col = pd.to_numeric(synthetic_col)
+
+        # The edges come from the real data only and are reused for the synthetic data.
+        bin_edges = np.histogram_bin_edges(real_col.dropna())
+        binned_real[column_name] = np.digitize(real_col, bins=bin_edges)
+        binned_synthetic[column_name] = np.digitize(synthetic_col, bins=bin_edges)
+        binned_columns[column_name] = {'sdtype': 'categorical'}
+
+    return binned_real, binned_synthetic, binned_metadata
+
+
+def _get_non_id_columns(metadata, binned_metadata):
+    """List the columns with a supported sdtype that are not primary/alternate keys."""
+    valid_sdtypes = ('numerical', 'categorical', 'boolean', 'datetime')
+    key_columns = [metadata.get('primary_key', ''), *get_alternate_keys(metadata)]
+
+    return [
+        column
+        for column, column_meta in get_columns_from_metadata(binned_metadata).items()
+        if get_type_from_column_meta(column_meta) in valid_sdtypes and column not in key_columns
+    ]
+
+
+def discretize_and_apply_metric(real_data, synthetic_data, metadata, metric, keys_to_skip=None):
+    """Discretize the data and apply the given metric.
+
+    Args:
+        real_data (pandas.DataFrame):
+            The real data.
+        synthetic_data (pandas.DataFrame):
+            The synthetic data.
+        metadata (dict):
+            The metadata.
+        metric (sdmetrics.single_table.MultiColumnPairMetric):
+            The column pair metric to apply.
+        keys_to_skip (list[tuple(str)] or None):
+            Column pairs, in either order, for which the metric is not computed.
+
+    Returns:
+        dict:
+            The metric breakdown for each computed pair, keyed by the sorted pair.
+    """
+    metric_results = {}
+    # ``None`` is documented as valid and means that no pair is skipped.
+    keys_to_skip = [] if keys_to_skip is None else keys_to_skip
+
+    binned_real, binned_synthetic, binned_metadata = discretize_table_data(
+        real_data, synthetic_data, metadata)
+
+    non_id_cols = _get_non_id_columns(metadata, binned_metadata)
+    for columns in itertools.combinations(non_id_cols, r=2):
+        sorted_columns = tuple(sorted(columns))
+        # A pair may be listed for skipping in either column order.
+        if sorted_columns in keys_to_skip or sorted_columns[::-1] in keys_to_skip:
+            continue
+
+        metric_results[sorted_columns] = metric.column_pairs_metric.compute_breakdown(
+            binned_real[list(sorted_columns)],
+            binned_synthetic[list(sorted_columns)],
+        )
+
+    return metric_results
+
+
+def aggregate_metric_results(metric_results):
+    """Aggregate the scores and errors in a metric results mapping.
+
+    Args:
+        metric_results (dict):
+            The metric results to aggregate.
+
+    Returns:
+        (float, int):
+            The average of the non-NaN scores (NaN if none), and the number of errors.
+    """
+    metric_scores = []
+    num_errors = 0
+    for breakdown in metric_results.values():
+        metric_score = breakdown.get('score', np.nan)
+        if not np.isnan(metric_score):
+            metric_scores.append(metric_score)
+        if 'error' in breakdown:
+            num_errors += 1
+
+    # np.mean([]) emits a RuntimeWarning; return NaN directly when nothing was scored.
+    if not metric_scores:
+        return np.nan, num_errors
+
+    return np.mean(metric_scores), num_errors
+
+
+def _validate_categorical_values(real_data, synthetic_data, metadata, table=None):
+    """Warn about categorical values found in synthetic data but not in real data.
+
+    Args:
+        real_data (pd.DataFrame):
+            The real data.
+        synthetic_data (pd.DataFrame):
+            The synthetic data.
+        metadata (dict):
+            The metadata.
+        table (str, optional):
+            The name of the current table, if one exists
+    """
+    if table:
+        warning_format = ('Unexpected values ({values}) in column "{column}" '
+                          f'and table "{table}"')
+    else:
+        warning_format = 'Unexpected values ({values}) in column "{column}"'
+
+    columns = get_columns_from_metadata(metadata)
+    for column, column_meta in columns.items():
+        column_type = get_type_from_column_meta(column_meta)
+        if column_type == 'categorical':
+            # Compute the real categories once per column, not once per synthetic value.
+            synthetic_values = synthetic_data[column].unique()
+            real_values = real_data[column].unique()
+            extra_categories = [value for value in synthetic_values if value not in real_values]
+            if extra_categories:
+                value_list = '", "'.join(str(value) for value in extra_categories[:5])
+                values = f'"{value_list}" + more' if len(
+                    extra_categories) > 5 else f'"{value_list}"'
+                warnings.warn(warning_format.format(values=values, column=column))
diff --git a/copulas/visualization.py b/copulas/visualization.py
index 6df130f0..9eea3d5a 100644
--- a/copulas/visualization.py
+++ b/copulas/visualization.py
@@ -1,11 +1,11 @@
"""Visualization utilities for the Copulas library."""
import pandas as pd
-
-from copulas.utils2 import PlotConfig
import plotly.express as px
from pandas.api.types import is_datetime64_dtype
+from copulas.utils2 import PlotConfig
+
def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
"""Generate a bar plot of the real and synthetic data.
@@ -39,6 +39,7 @@ def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}):
return fig
+
def _generate_scatter_plot(all_data, columns):
"""Generate a scatter plot for column pair plot.
@@ -72,6 +73,7 @@ def _generate_scatter_plot(all_data, columns):
return fig
+
def _generate_column_plot(real_column,
synthetic_column,
plot_kwargs={},
@@ -136,9 +138,10 @@ def _generate_column_plot(real_column,
)
return fig
+
def hist_1d(data, title=None, bins=20, label=None):
"""Plot 1 dimensional data in a histogram.
-
+
Args:
data (pd.DataFrame):
The table data.
@@ -148,7 +151,7 @@ def hist_1d(data, title=None, bins=20, label=None):
The number of bins to use for the histogram.
label (str):
The label of the plot.
-
+
Returns:
plotly.graph_objects._figure.Figure
"""
@@ -174,6 +177,7 @@ def hist_1d(data, title=None, bins=20, label=None):
return fig
+
def compare_1d(real, synth):
"""Return a plot of the real and synthetic data for a given column.
@@ -193,6 +197,7 @@ def compare_1d(real, synth):
return _generate_column_plot(real, synth, plot_type='bar')
+
def scatter_2d(data, columns=None):
"""Plot 2 dimensional data in a scatter plot.
@@ -211,6 +216,7 @@ def scatter_2d(data, columns=None):
return _generate_scatter_plot(data, columns)
+
def compare_2d_(real, synth, columns=None):
"""Return a plot of the real and synthetic data for a given column pair.
@@ -234,6 +240,7 @@ def compare_2d_(real, synth, columns=None):
return _generate_scatter_plot(all_data, columns)
+
def scatter_3d_plotly(data, columns=None):
"""Return a 3D scatter plot of the data.
@@ -268,6 +275,7 @@ def scatter_3d_plotly(data, columns=None):
return fig
+
def compare_3d(real, synth, columns=None):
"""Generate a 3d scatter plot comparing real/synthetic data.