Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataTransformer init parameters #146

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion ctgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,29 @@ class DataTransformer(object):
Discrete columns are encoded using a scikit-learn OneHotEncoder.
"""

def __init__(self, max_clusters=10, weight_threshold=0.005):
def __init__(self, max_clusters=10, weight_threshold=0.005, max_gm_samples=None):
"""Create a data transformer.

Args:
max_clusters (int):
Maximum number of Gaussian distributions in Bayesian GMM.
weight_threshold (float):
Weight threshold for a Gaussian distribution to be kept.
max_gm_samples (int):
Maximum number of samples to use during GMM fit.
"""
self._max_clusters = max_clusters
self._weight_threshold = weight_threshold
self._max_gm_samples = np.inf if max_gm_samples is None else max_gm_samples

def _fit_continuous(self, column_name, raw_column_data):
"""Train Bayesian GMM for continuous column."""
if self._max_gm_samples <= raw_column_data.shape[0]:
raw_column_data = np.random.choice(
raw_column_data,
size=self._max_gm_samples,
replace=False
)
gm = BayesianGaussianMixture(
self._max_clusters,
weight_concentration_prior_type='dirichlet_process',
Expand Down
9 changes: 7 additions & 2 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import warnings

import numpy as np
Expand Down Expand Up @@ -129,12 +130,15 @@ class CTGANSynthesizer(BaseSynthesizer):
Whether to attempt to use cuda for GPU computation.
If this is False or CUDA is not available, CPU will be used.
Defaults to ``True``.
data_transformer_params (dict):
Dictionary of parameters for ``DataTransformer`` initialization.
"""

def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True):
log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True,
data_transformer_params={}):

assert batch_size % 2 == 0

Expand Down Expand Up @@ -163,6 +167,7 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di

self._device = torch.device(device)

self._data_transformer_params = copy.deepcopy(data_transformer_params)
self._transformer = None
self._data_sampler = None
self._generator = None
Expand Down Expand Up @@ -290,7 +295,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
DeprecationWarning
)

self._transformer = DataTransformer()
self._transformer = DataTransformer(**self._data_transformer_params)
self._transformer.fit(train_data, discrete_columns)

train_data = self._transformer.transform(train_data)
Expand Down
11 changes: 11 additions & 0 deletions tests/integration/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,14 @@ def test_wrong_sampling_conditions():

with pytest.raises(ValueError):
ctgan.sample(1, 'discrete', "d")


def test_ctgan_data_transformer_params():
    """``data_transformer_params`` must be forwarded to the ``DataTransformer``.

    Fits a one-epoch synthesizer on a single continuous column and checks
    that the ``max_gm_samples`` override reached the fitted transformer.
    """
    continuous_values = np.random.random(1000)
    data = pd.DataFrame({'continuous': continuous_values})

    model = CTGANSynthesizer(
        epochs=1,
        data_transformer_params={'max_gm_samples': 100},
    )
    model.fit(data, [])

    assert model._transformer._max_gm_samples == 100