Skip to content

Commit

Permalink
Raise UserWarnings for Unused Numerical Distributions when using `GaussianCopulaSynthesizer` (#2301)
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer authored Nov 18, 2024
1 parent 5741ee5 commit 18cd2e5
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 27 deletions.
7 changes: 2 additions & 5 deletions sdv/single_table/copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.ctgan import CTGANSynthesizer
from sdv.single_table.utils import (
log_numerical_distributions_error,
validate_numerical_distributions,
warn_missing_numerical_distributions,
)

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -204,10 +204,7 @@ def _fit(self, processed_data):
processed_data (pandas.DataFrame):
Data to be learned.
"""
log_numerical_distributions_error(
self.numerical_distributions, processed_data.columns, LOGGER
)

warn_missing_numerical_distributions(self.numerical_distributions, processed_data.columns)
gaussian_normalizer_config = self._create_gaussian_normalizer_config(processed_data)
self._gaussian_normalizer_hyper_transformer = rdt.HyperTransformer()
self._gaussian_normalizer_hyper_transformer.set_config(gaussian_normalizer_config)
Expand Down
6 changes: 2 additions & 4 deletions sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from sdv.single_table.base import BaseSingleTableSynthesizer
from sdv.single_table.utils import (
flatten_dict,
log_numerical_distributions_error,
unflatten_dict,
validate_numerical_distributions,
warn_missing_numerical_distributions,
)

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -132,9 +132,7 @@ def _fit(self, processed_data):
processed_data (pandas.DataFrame):
Data to be learned.
"""
log_numerical_distributions_error(
self.numerical_distributions, processed_data.columns, LOGGER
)
warn_missing_numerical_distributions(self.numerical_distributions, processed_data.columns)
self._num_rows = self._learn_num_rows(processed_data)
numerical_distributions = self._get_numerical_distributions(processed_data)
self._model = self._initialize_model(numerical_distributions)
Expand Down
12 changes: 6 additions & 6 deletions sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,12 @@ def validate_numerical_distributions(numerical_distributions, metadata_columns):
)


def warn_missing_numerical_distributions(numerical_distributions, processed_data_columns):
    """Raise a `UserWarning` for each numerical distribution whose column is not modeled.

    Every key of ``numerical_distributions`` that is absent from
    ``processed_data_columns`` triggers one warning naming the column and the
    distribution that was requested for it.

    Args:
        numerical_distributions (dict):
            Mapping of column name to the requested distribution. ``None`` or an
            empty mapping results in no warnings.
        processed_data_columns (iterable):
            Column names present in the processed data.
    """
    if not numerical_distributions:
        return

    modeled_columns = set(processed_data_columns)

    # Walk the mapping itself rather than a set difference so the warnings are
    # emitted in the mapping's insertion order instead of arbitrary set order.
    for column, distribution in numerical_distributions.items():
        if column in modeled_columns:
            continue

        # Keep the call as ``warn(message, UserWarning)`` with no extra arguments:
        # the synthesizer unit tests assert this exact call on a mocked ``warnings``.
        warnings.warn(
            f"Cannot use distribution '{distribution}' for column "
            f"'{column}' because the column is not statistically modeled.",
            UserWarning,
        )
22 changes: 22 additions & 0 deletions tests/integration/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,25 @@ def test_support_nullable_pandas_dtypes():
assert (synthetic_data.dtypes == data.dtypes).all()
assert (synthetic_data['Float32'] == synthetic_data['Float32'].round(1)).all(skipna=True)
assert (synthetic_data['Float64'] == synthetic_data['Float64'].round(3)).all(skipna=True)


def test_user_warning_for_unused_numerical_distribution():
    """Ensure that a `UserWarning` is raised when a numerical distribution is not applied.

    This test verifies that the synthesizer warns the user if a specified numerical
    distribution is not used because the corresponding column does not exist or is not
    modeled after preprocessing.
    """
    # Setup
    data, metadata = download_demo('single_table', 'fake_hotel_guests')
    synthesizer = GaussianCopulaSynthesizer(
        metadata, numerical_distributions={'credit_card_number': 'beta'}
    )
    expected_message = (
        "Cannot use distribution 'beta' for column 'credit_card_number' because the column is not "
        'statistically modeled.'
    )

    # Run
    with pytest.warns(UserWarning) as captured_warnings:
        synthesizer.fit(data)

    # Assert
    # Compare the text exactly rather than through ``match=``, which interprets the
    # expected message as a regular expression (so ``.`` would match any character).
    # ``fit`` may emit unrelated warnings too, so only require that ours is present.
    assert any(
        issubclass(warning.category, UserWarning) and str(warning.message) == expected_message
        for warning in captured_warnings
    )
11 changes: 6 additions & 5 deletions tests/unit/single_table/test_copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,10 @@ def test__create_gaussian_normalizer_config(self, mock_rdt):
assert config == expected_config
assert mock_rdt.transformers.GaussianNormalizer.call_args_list == expected_calls

@patch('sdv.single_table.copulagan.LOGGER')
@patch('sdv.single_table.utils.warnings')
@patch('sdv.single_table.copulagan.CTGANSynthesizer._fit')
@patch('sdv.single_table.copulagan.rdt')
def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger):
def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_warnings):
"""Test a message is logged.
A message should be logged if the columns passed in ``numerical_distributions``
Expand All @@ -284,10 +284,11 @@ def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger):
instance._fit(processed_data)

# Assert
mock_logger.info.assert_called_once_with(
"Requested distribution 'gamma' cannot be applied to column 'col' "
'because it no longer exists after preprocessing.'
warning_message = (
"Cannot use distribution 'gamma' for column 'col' because the column is not "
'statistically modeled.'
)
mock_warnings.warn.assert_called_once_with(warning_message, UserWarning)

@patch('sdv.single_table.copulagan.CTGANSynthesizer._fit')
@patch('sdv.single_table.copulagan.rdt')
Expand Down
15 changes: 8 additions & 7 deletions tests/unit/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ def test_get_parameters(self):
'default_distribution': 'beta',
}

@patch('sdv.single_table.copulas.LOGGER')
def test__fit_logging(self, mock_logger):
"""Test a message is logged.
@patch('sdv.single_table.utils.warnings')
def test__fit_warning_numerical_distributions(self, mock_warnings):
"""Test that a warning is shown when fitting numerical distributions on a dropped column.
A message should be logged if the columns passed in ``numerical_distributions``
A warning message should be printed if the columns passed in ``numerical_distributions``
were renamed/dropped during preprocessing.
"""
# Setup
Expand All @@ -180,10 +180,11 @@ def test__fit_logging(self, mock_logger):
instance._fit(processed_data)

# Assert
mock_logger.info.assert_called_once_with(
"Requested distribution 'gamma' cannot be applied to column 'col' "
'because it no longer exists after preprocessing.'
warning_message = (
"Cannot use distribution 'gamma' for column 'col' because the column is not "
'statistically modeled.'
)
mock_warnings.warn.assert_called_once_with(warning_message, UserWarning)

@patch('sdv.single_table.copulas.warnings')
@patch('sdv.single_table.copulas.multivariate')
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/single_table/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
handle_sampling_error,
unflatten_dict,
validate_file_path,
warn_missing_numerical_distributions,
)


Expand Down Expand Up @@ -328,3 +329,18 @@ def test_validate_file_path(mock_open):
assert output_path in result
assert none_result is None
mock_open.assert_called_once_with(result, 'w+')


def test_warn_missing_numerical_distributions():
    """Test that only columns missing from the processed data trigger a warning.

    ``age`` is not among the processed columns, so it must produce a warning;
    ``height`` is present, so it must not.
    """
    # Setup
    numerical_distributions = {'age': 'beta', 'height': 'uniform'}
    processed_data_columns = ['height', 'weight']
    expected_message = (
        "Cannot use distribution 'beta' for column 'age' because the column is not "
        'statistically modeled.'
    )

    # Run
    with pytest.warns(UserWarning) as captured_warnings:
        warn_missing_numerical_distributions(numerical_distributions, processed_data_columns)

    # Assert
    # Exactly one warning: the present column ('height') must stay silent.
    assert len(captured_warnings) == 1
    # Exact text comparison; ``match=`` would treat the message as a regex.
    assert str(captured_warnings[0].message) == expected_message

0 comments on commit 18cd2e5

Please sign in to comment.