DataTransformer init parameters #146

Open — wants to merge 6 commits into base: main (changes shown from 3 commits)
9 changes: 8 additions & 1 deletion ctgan/data_transformer.py
@@ -19,20 +19,27 @@ class DataTransformer(object):
Discrete columns are encoded using a scikit-learn OneHotEncoder.
"""

- def __init__(self, max_clusters=10, weight_threshold=0.005):
+ def __init__(self, max_clusters=10, weight_threshold=0.005, max_gm_samples=None):
"""Create a data transformer.

Args:
max_clusters (int):
Maximum number of Gaussian distributions in Bayesian GMM.
weight_threshold (float):
Weight threshold for a Gaussian distribution to be kept.
max_gm_samples (int):
Maximum number of samples to use during the GMM fit.
"""
self._max_clusters = max_clusters
self._weight_threshold = weight_threshold
self._max_gm_samples = np.inf if max_gm_samples is None else max_gm_samples

def _fit_continuous(self, column_name, raw_column_data):
"""Train Bayesian GMM for continuous column."""
if self._max_gm_samples <= raw_column_data.shape[0]:
raw_column_data = np.random.choice(raw_column_data,
                                   size=self._max_gm_samples,
                                   replace=False)

Reviewer (Member): I think that when it comes to this kind of line breaking, this indentation is better:

    raw_column_data = np.random.choice(
        raw_column_data,
        size=self._max_gm_samples,
        replace=False
    )
gm = BayesianGaussianMixture(
self._max_clusters,
weight_concentration_prior_type='dirichlet_process',
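Putting the pieces of this hunk together: `_fit_continuous` now subsamples oversized columns before fitting the Bayesian GMM, and components whose weight falls below `weight_threshold` are discarded. A runnable sketch of that behavior outside the class — the concrete values and the `active` mask are illustrative assumptions, not taken from the PR:

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

# Stand-ins for the instance attributes (values assumed for illustration)
max_gm_samples = 100
max_clusters = 10
weight_threshold = 0.005

raw_column_data = np.random.default_rng(0).normal(size=1000)

# Mirror the check added in _fit_continuous: subsample without
# replacement when the column has at least max_gm_samples rows
if max_gm_samples <= raw_column_data.shape[0]:
    raw_column_data = np.random.choice(
        raw_column_data,
        size=max_gm_samples,
        replace=False,
    )

# Fit the Bayesian GMM on the (possibly subsampled) column
gm = BayesianGaussianMixture(
    n_components=max_clusters,
    weight_concentration_prior_type='dirichlet_process',
)
gm.fit(raw_column_data.reshape(-1, 1))

# Components below weight_threshold would be treated as inactive
active = gm.weights_ > weight_threshold
```

Sampling without replacement keeps the subset distribution-faithful; the trade-off is that the GMM only ever sees `max_gm_samples` points, which is exactly the speed-up this PR is after.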
7 changes: 5 additions & 2 deletions ctgan/synthesizers/ctgan.py
@@ -267,7 +267,8 @@ def _validate_discrete_columns(self, train_data, discrete_columns):
if invalid_columns:
raise ValueError('Invalid columns found: {}'.format(invalid_columns))

- def fit(self, train_data, discrete_columns=tuple(), epochs=None):
+ def fit(self, train_data, discrete_columns=tuple(), epochs=None,
+         data_transformer_params={}):
Reviewer (Member): The data_transformer_params should be moved to the __init__ and be assigned as self.data_transformer_params. (Use deepcopy if needed.)

"""Fit the CTGAN Synthesizer models to the training data.

Args:
@@ -278,6 +279,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
Vector. If ``train_data`` is a Numpy array, this list should
contain the integer indices of the columns. Otherwise, if it is
a ``pandas.DataFrame``, this list should contain the column names.
data_transformer_params (dict):
Dictionary of parameters for ``DataTransformer`` initialization.
"""
self._validate_discrete_columns(train_data, discrete_columns)

@@ -290,7 +293,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
DeprecationWarning
)

- self._transformer = DataTransformer()
+ self._transformer = DataTransformer(**data_transformer_params)
self._transformer.fit(train_data, discrete_columns)

train_data = self._transformer.transform(train_data)
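The forwarding in `fit` is plain keyword-argument unpacking into the constructor. A minimal standalone sketch of the pattern — the class below is a simplified stand-in for ctgan's `DataTransformer`, not the real implementation:

```python
import numpy as np

class DataTransformer:
    """Simplified stand-in for ctgan's DataTransformer."""

    def __init__(self, max_clusters=10, weight_threshold=0.005, max_gm_samples=None):
        self._max_clusters = max_clusters
        self._weight_threshold = weight_threshold
        # None means "no cap": np.inf makes the subsampling check never trigger
        self._max_gm_samples = np.inf if max_gm_samples is None else max_gm_samples

# fit() can forward a user-supplied dict straight into the constructor
data_transformer_params = {'max_gm_samples': 100}
transformer = DataTransformer(**data_transformer_params)
print(transformer._max_gm_samples)  # 100
```

Note the reviewer's point above: a mutable `{}` default passed per call is fragile, and storing a (deep-copied) `self.data_transformer_params` in `__init__` would be the safer shape for this API.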
11 changes: 11 additions & 0 deletions tests/integration/test_ctgan.py
@@ -184,3 +184,14 @@ def test_wrong_sampling_conditions():

with pytest.raises(ValueError):
ctgan.sample(1, 'discrete', "d")


def test_ctgan_data_transformer_params():
    data = pd.DataFrame({
        'continuous': np.random.random(1000)
    })

    ctgan = CTGANSynthesizer(epochs=1)
    ctgan.fit(data, [], data_transformer_params={'max_gm_samples': 100})

    assert ctgan._transformer._max_gm_samples == 100

Reviewer (Member): I think you should also add a performance test, something simple, just to make sure that our results are not worse than before because of this change.

Author: I'm not sure about this one. Do you mean a performance test of the Gaussian mixture model, or of CTGAN? In terms of speed or accuracy?

Reviewer (Member): Accuracy for CTGAN. Basically, just a test to make sure the changes don't break the code. So something like changing your continuous column to be a normal distribution instead of random, then sampling from the model (after you fit) and making sure the samples loosely follow a normal distribution.
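The accuracy test the reviewer describes could be sketched roughly as below. Fitting CTGAN is too heavy to demonstrate here, so the model output is stood in by actual normal draws; only the "loosely follows a normal distribution" check itself is shown, and the moment-based check with its tolerance is an assumption, not the project's convention:

```python
import numpy as np

def loosely_normal(samples, mean=0.0, std=1.0, tol=0.3):
    """Return True when the sample's first two moments fall within
    a generous tolerance of the target normal distribution."""
    return (abs(np.mean(samples) - mean) < tol
            and abs(np.std(samples) - std) < tol)

# In the suggested test, `samples` would be ctgan.sample(n)['continuous']
# after fitting on a normally distributed column with
# data_transformer_params={'max_gm_samples': ...} set.
rng = np.random.default_rng(42)
samples = rng.normal(loc=0.0, scale=1.0, size=5000)
assert loosely_normal(samples)

# A uniform column should fail the same check (mean ~0.5, std ~0.29)
assert not loosely_normal(rng.uniform(0.0, 1.0, size=5000))
```

A looser, distribution-aware alternative would be a Kolmogorov–Smirnov test, but for a regression guard against this subsampling change, matching the first two moments is usually enough.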