-
Notifications
You must be signed in to change notification settings - Fork 296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DataTransformer init parameters #146
base: main
Are you sure you want to change the base?
Conversation
a4c4d5b
to
96a6321
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general I think this looks good. @pvk-developer @amontanez24 what do you think?
@@ -184,3 +184,14 @@ def test_wrong_sampling_conditions(): | |||
|
|||
with pytest.raises(ValueError): | |||
ctgan.sample(1, 'discrete', "d") | |||
|
|||
|
|||
def test_ctgan_data_transformer_params(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you should also add a performance test, something simple just to make sure that our results are not worse than before because of this change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about this one, do you think about a performance test of the gaussian mixture model or CTGAN ? In terms of speed or accuracy ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accuracy for CTGAN. Basically, just a test to make sure the changes don't break the code. So something like changing your continuous
column to be a normal distribution, instead of random, then sample from the model (after you fit) and make sure the samples loosely follow a normal distribution.
ctgan/synthesizers/ctgan.py
Outdated
@@ -267,7 +267,8 @@ def _validate_discrete_columns(self, train_data, discrete_columns): | |||
if invalid_columns: | |||
raise ValueError('Invalid columns found: {}'.format(invalid_columns)) | |||
|
|||
def fit(self, train_data, discrete_columns=tuple(), epochs=None): | |||
def fit(self, train_data, discrete_columns=tuple(), epochs=None, | |||
data_transformer_params={}): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The data_transformer_params
should be moved to the __init__
and be asigned as self.data_transformer_params
. (Use deepcopy
if needed).
ctgan/data_transformer.py
Outdated
|
||
def _fit_continuous(self, column_name, raw_column_data): | ||
"""Train Bayesian GMM for continuous column.""" | ||
if self._max_gm_samples <= raw_column_data.shape[0]: | ||
raw_column_data = np.random.choice(raw_column_data, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that when it comes to this kind of line breaking this indentation is better:
raw_column_data = np.random.choice(
raw_column_data,
size=self._max_gm_samples,
replace=False
)
@fealho @pvk-developer |
@npatki not sure what you want to do with this? |
Meanwhile the library code has changed so the PR should be updated. For example, the Also, I wonder if ClusterBasedNormalizer could not be optionally replaced by a power transform, which might be faster (although it might impact the quality of the generated data), see sdv-dev/RDT#613 |
This PR solve issue #7, it allows two things:
DataTransformer
throughCTGANSynthesizer.fit
(and so to change other parameters asmax_clusters
).