diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py
index 9f373ac9..e4606ccc 100644
--- a/ctgan/data_transformer.py
+++ b/ctgan/data_transformer.py
@@ -19,7 +19,7 @@ class DataTransformer(object):
     Discrete columns are encoded using a scikit-learn OneHotEncoder.
     """
 
-    def __init__(self, max_clusters=10, weight_threshold=0.005):
+    def __init__(self, max_clusters=10, weight_threshold=0.005, max_gm_samples=None):
         """Create a data transformer.
 
         Args:
@@ -27,12 +27,25 @@ def __init__(self, max_clusters=10, weight_threshold=0.005):
                 Maximum number of Gaussian distributions in Bayesian GMM.
             weight_threshold (float):
                 Weight threshold for a Gaussian distribution to be kept.
+            max_gm_samples (int):
+                Maximum number of samples to use during GMM fit. If the column has
+                more rows, a random subset of this size is drawn without replacement.
+                Defaults to ``None``, which uses every row.
         """
         self._max_clusters = max_clusters
         self._weight_threshold = weight_threshold
+        self._max_gm_samples = np.inf if max_gm_samples is None else max_gm_samples
 
     def _fit_continuous(self, column_name, raw_column_data):
         """Train Bayesian GMM for continuous column."""
+        # Subsample only when the column is strictly larger than the cap; at
+        # equal size the draw would merely shuffle the same rows.
+        if self._max_gm_samples < raw_column_data.shape[0]:
+            raw_column_data = np.random.choice(
+                raw_column_data,
+                size=self._max_gm_samples,
+                replace=False
+            )
         gm = BayesianGaussianMixture(
             self._max_clusters,
             weight_concentration_prior_type='dirichlet_process',
diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py
index d280e72f..646bc79a 100644
--- a/ctgan/synthesizers/ctgan.py
+++ b/ctgan/synthesizers/ctgan.py
@@ -1,3 +1,4 @@
+import copy
 import warnings
 
 import numpy as np
@@ -129,12 +130,16 @@ class CTGANSynthesizer(BaseSynthesizer):
             Whether to attempt to use cuda for GPU computation.
             If this is False or CUDA is not available, CPU will be used.
             Defaults to ``True``.
+        data_transformer_params (dict):
+            Dictionary of parameters for ``DataTransformer`` initialization.
+            Defaults to ``None``, which uses the ``DataTransformer`` defaults.
     """
 
     def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
                  generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
                  discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
-                 log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True):
+                 log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True,
+                 data_transformer_params=None):
 
         assert batch_size % 2 == 0
 
@@ -163,6 +168,7 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
 
         self._device = torch.device(device)
 
+        self._data_transformer_params = copy.deepcopy(data_transformer_params or {})
         self._transformer = None
         self._data_sampler = None
         self._generator = None
@@ -290,7 +296,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
                 DeprecationWarning
             )
 
-        self._transformer = DataTransformer()
+        self._transformer = DataTransformer(**self._data_transformer_params)
         self._transformer.fit(train_data, discrete_columns)
 
         train_data = self._transformer.transform(train_data)
diff --git a/tests/integration/test_ctgan.py b/tests/integration/test_ctgan.py
index d84ffdcf..8180f762 100644
--- a/tests/integration/test_ctgan.py
+++ b/tests/integration/test_ctgan.py
@@ -184,3 +184,15 @@ def test_wrong_sampling_conditions():
 
     with pytest.raises(ValueError):
         ctgan.sample(1, 'discrete', "d")
+
+
+def test_ctgan_data_transformer_params():
+    """Check that ``data_transformer_params`` is forwarded to the ``DataTransformer``."""
+    data = pd.DataFrame({
+        'continuous': np.random.random(1000)
+    })
+
+    ctgan = CTGANSynthesizer(epochs=1, data_transformer_params={'max_gm_samples': 100})
+    ctgan.fit(data, [])
+
+    assert ctgan._transformer._max_gm_samples == 100