
DataTransformer init parameters #146

Open

FlorentRamb wants to merge 6 commits into main

Conversation

FlorentRamb commented:

This PR solves issue #7. It allows two things (see the usage sketch below):

  1. the ability to fit the Gaussian mixtures on a subsample (this helps scale to big datasets while losing only a little accuracy)
  2. the ability to pass init arguments to the DataTransformer through CTGANSynthesizer.fit (and thus to change other parameters, such as max_clusters)
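A minimal usage sketch of the proposed API. The toy data is hypothetical, and the key names max_clusters and max_gm_samples are inferred from this diff; the exact keys accepted depend on DataTransformer.__init__:

    import numpy as np
    import pandas as pd

    from ctgan import CTGANSynthesizer

    # Toy data: one continuous and one discrete column.
    data = pd.DataFrame({
        'amount': np.random.normal(100, 15, size=50_000),
        'category': np.random.choice(['a', 'b', 'c'], size=50_000),
    })

    ctgan = CTGANSynthesizer(epochs=10)
    # Forward init arguments to the DataTransformer: cap the number of GMM
    # components and fit the mixtures on a 10k subsample of each column.
    ctgan.fit(
        data,
        discrete_columns=['category'],
        data_transformer_params={'max_clusters': 5, 'max_gm_samples': 10_000},
    )
    samples = ctgan.sample(1_000)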

@CLAassistant commented Apr 16, 2021:

CLA assistant check: all committers have signed the CLA.

@FlorentRamb changed the title from "Gh 7 feat gmparams" to "DataTransformer init parameters" on Apr 19, 2021.
@fealho (Member) left a comment:

In general I think this looks good. @pvk-developer @amontanez24 what do you think?

Two review comments on ctgan/data_transformer.py were marked outdated and resolved.
@@ -184,3 +184,14 @@ def test_wrong_sampling_conditions():

     with pytest.raises(ValueError):
         ctgan.sample(1, 'discrete', "d")
+
+
+def test_ctgan_data_transformer_params():
A project member commented:

I think you should also add a performance test, something simple just to make sure that our results are not worse than before because of this change.

@FlorentRamb (Author) replied:

I'm not sure about this one. Do you mean a performance test of the Gaussian mixture model, or of CTGAN? In terms of speed or accuracy?

A project member replied:

Accuracy for CTGAN. Basically, just a test to make sure the changes don't break the code. So something like changing your continuous column to be a normal distribution instead of random, then sampling from the model (after you fit) and making sure the samples loosely follow a normal distribution.
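A minimal sketch of such a test (not part of the PR; the epoch count and the loose tolerance thresholds are placeholder assumptions):

    import numpy as np
    import pandas as pd

    from ctgan import CTGANSynthesizer


    def test_ctgan_recovers_normal_column():
        """Fit on a normal continuous column; samples should loosely match it."""
        data = pd.DataFrame({'continuous': np.random.normal(0.0, 1.0, size=1_000)})

        ctgan = CTGANSynthesizer(epochs=10)
        ctgan.fit(data)

        samples = ctgan.sample(1_000)

        # Loose bounds: guard against gross regressions, not exact recovery.
        assert abs(samples['continuous'].mean()) < 0.5
        assert abs(samples['continuous'].std() - 1.0) < 0.5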

@pvk-developer self-requested a review on April 21, 2021, 16:53.
@@ -267,7 +267,8 @@ def _validate_discrete_columns(self, train_data, discrete_columns):

         if invalid_columns:
             raise ValueError('Invalid columns found: {}'.format(invalid_columns))

-    def fit(self, train_data, discrete_columns=tuple(), epochs=None):
+    def fit(self, train_data, discrete_columns=tuple(), epochs=None,
+            data_transformer_params={}):
A project member commented:

The data_transformer_params should be moved to the __init__ and be assigned as self.data_transformer_params. (Use deepcopy if needed.)
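A minimal sketch of that suggestion, assuming a simplified CTGANSynthesizer shape (all other __init__ parameters omitted; the DataTransformer.fit signature is taken from ctgan/data_transformer.py):

    from copy import deepcopy

    from ctgan.data_transformer import DataTransformer


    class CTGANSynthesizer:
        def __init__(self, data_transformer_params=None):
            # deepcopy so later mutations of the caller's dict don't leak in.
            self.data_transformer_params = deepcopy(data_transformer_params or {})

        def fit(self, train_data, discrete_columns=tuple(), epochs=None):
            transformer = DataTransformer(**self.data_transformer_params)
            transformer.fit(train_data, discrete_columns)
            # ... rest of training unchanged ...

Defaulting to None instead of {} also sidesteps Python's shared-mutable-default pitfall that the proposed fit signature above would introduce.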


    def _fit_continuous(self, column_name, raw_column_data):
        """Train Bayesian GMM for continuous column."""
        if self._max_gm_samples <= raw_column_data.shape[0]:
            raw_column_data = np.random.choice(raw_column_data,
                                               size=self._max_gm_samples,
                                               replace=False)
A project member commented:

I think that for this kind of line breaking, this indentation is better:

    raw_column_data = np.random.choice(
        raw_column_data,
        size=self._max_gm_samples,
        replace=False
    )

@candalfigomoro commented:

@fealho @pvk-developer
Can we merge this? It's basically impossible to fit CTGAN on a large dataset because the Gaussian mixture fit is a huge bottleneck (even using dozens of CPUs). This PR would allow speeding up the Gaussian mixture step. Thanks

@fealho (Member) commented on Feb 9, 2023:

@npatki not sure what you want to do with this?

@candalfigomoro commented:

Meanwhile, the library code has changed, so the PR should be updated.

For example, the _fit_continuous method now receives a pandas DataFrame, so np.random.choice() can be replaced by something like data = data.sample(self._max_gm_samples, replace=False, random_state=SEED) (a sketch follows below).

Also, I wonder whether ClusterBasedNormalizer could optionally be replaced by a power transform, which might be faster (although it might impact the quality of the generated data); see sdv-dev/RDT#613.
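A minimal, self-contained sketch of that updated subsampling (SEED and the 10,000 cap are placeholder values; inside DataTransformer the cap would be self._max_gm_samples):

    import numpy as np
    import pandas as pd

    SEED = 42                # placeholder; a fixed seed keeps the subsample reproducible
    MAX_GM_SAMPLES = 10_000  # stands in for self._max_gm_samples

    data = pd.DataFrame({'continuous': np.random.normal(size=1_000_000)})

    # Subsample before the expensive Bayesian GMM fit.
    if len(data) > MAX_GM_SAMPLES:
        data = data.sample(MAX_GM_SAMPLES, replace=False, random_state=SEED)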
