Error when sampling if original dataset contains null values #413

Closed
srinify opened this issue Oct 29, 2024 · 1 comment
Labels
resolution:WAI The software is working as intended

srinify commented Oct 29, 2024

Environment Details

  • CTGAN version: 0.10.2 (latest)

Error Description

If you fit a CTGAN model on data that contains null values, a ValueError will be thrown during sampling.

Internal Colab Notebook to reproduce

Code

from sdv.datasets.demo import download_demo
from ctgan import CTGAN

# This dataset has null values
data, metadata = download_demo(
    modality='single_table',
    dataset_name='fake_hotel_guests'
)

ctgan2 = CTGAN(epochs=10)
ctgan2.fit(data)
ctgan2.sample(100)

Error

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-38-cd6011a2da19> in <cell line: 1>()
----> 1 synthetic_data2 = ctgan2.sample(100)

6 frames
/usr/local/lib/python3.10/dist-packages/ctgan/synthesizers/base.py in wrapper(self, *args, **kwargs)
     48     def wrapper(self, *args, **kwargs):
     49         if self.random_states is None:
---> 50             return function(self, *args, **kwargs)
     51 
     52         else:

/usr/local/lib/python3.10/dist-packages/ctgan/synthesizers/ctgan.py in sample(self, n, condition_column, condition_value)
    514         data = data[:n]
    515 
--> 516         return self._transformer.inverse_transform(data)
    517 
    518     def set_device(self, device):

/usr/local/lib/python3.10/dist-packages/ctgan/data_transformer.py in inverse_transform(self, data, sigmas)
    217             column_data = data[:, st : st + dim]
    218             if column_transform_info.column_type == 'continuous':
--> 219                 recovered_column_data = self._inverse_transform_continuous(
    220                     column_transform_info, column_data, sigmas, st
    221                 )

/usr/local/lib/python3.10/dist-packages/ctgan/data_transformer.py in _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st)
    191     def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st):
    192         gm = column_transform_info.transform
--> 193         data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes())).astype(float)
    194         data[data.columns[1]] = np.argmax(column_data[:, 1:], axis=1)
    195         if sigmas is not None:

/usr/local/lib/python3.10/dist-packages/pandas/core/frame.py in __init__(self, data, index, columns, dtype, copy)
    825                 )
    826             else:
--> 827                 mgr = ndarray_to_mgr(
    828                     data,
    829                     index,

/usr/local/lib/python3.10/dist-packages/pandas/core/internals/construction.py in ndarray_to_mgr(values, index, columns, dtype, copy, typ)
    334     )
    335 
--> 336     _check_values_indices_shape_match(values, index, columns)
    337 
    338     if typ == "array":

/usr/local/lib/python3.10/dist-packages/pandas/core/internals/construction.py in _check_values_indices_shape_match(values, index, columns)
    418         passed = values.shape
    419         implied = (len(index), len(columns))
--> 420         raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")
    421 
    422 

ValueError: Shape of passed values is (100, 2), indices imply (100, 3)
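
For readers unfamiliar with this pandas message, here is a minimal sketch of the mismatch in isolation. It does not use CTGAN at all. It only mirrors the failing call in `_inverse_transform_continuous`, where a two-column slice is paired with a list of column names. The three names below are hypothetical stand-ins for whatever `gm.get_output_sdtypes()` returned in the failing run. The assumption that the third name comes from null handling is mine and is not confirmed in this thread.

```python
import numpy as np
import pandas as pd


def build_frame(values, column_names):
    """Pair a 2-D array with column names, as the failing line in the traceback does."""
    return pd.DataFrame(values, columns=column_names)


# Two columns of values, like column_data[:, :2] in the traceback.
values = np.zeros((100, 2))

# Hypothetical names: the real ones come from the fitted transformer.
column_names = ['normalized', 'component', 'is_null']

try:
    build_frame(values, column_names)
    mismatch_message = None
except ValueError as error:
    # pandas rejects the frame because 2 value columns cannot fill 3 names.
    mismatch_message = str(error)

print(mismatch_message)
```

The point is only that the array width (2) and the number of column names (3) disagree, which is what the last line of the traceback reports.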

srinify added the labels bug (Something isn't working) and new (Label applied to new issues) on Oct 29, 2024
srinify changed the title from "ValueError when running sample()" to "Error when sampling if original dataset contains null values" on Oct 30, 2024
srinify commented Oct 30, 2024

The underlying issue is that CTGAN expects the training data to contain no null values. I've opened a new feature request to surface a more relevant error during fit: #414

Closing as working as intended.
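
Until a clearer error exists, one possible workaround is to check for nulls before calling `fit`. The snippet below is a rough sketch using plain pandas, not something prescribed by the maintainers. The helper names `null_report` and `drop_null_rows` and the toy column names are made up for illustration. Dropping rows is only one option, and imputing values may suit some datasets better.

```python
import numpy as np
import pandas as pd


def null_report(df):
    """Return the null count for each column that has at least one null."""
    counts = df.isna().sum()
    return counts[counts > 0]


def drop_null_rows(df):
    """Return a copy of df without any row that contains a null."""
    return df.dropna().reset_index(drop=True)


# Toy frame with illustrative column names (not the real demo dataset).
toy = pd.DataFrame({
    'amenities_fee': [10.0, np.nan, 25.5, 8.0],
    'room_type': ['BASIC', 'DELUXE', None, 'SUITE'],
    'room_rate': [100.0, 150.0, 200.0, 120.0],
})

report = null_report(toy)
print(report)

clean = drop_null_rows(toy)
print(clean)

# The cleaned frame, rather than the original, would then be passed to
# CTGAN.fit; that call is omitted here to keep the sketch lightweight.
```

If `null_report` returns an empty Series, the data has no nulls and can be passed to `fit` as is.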

srinify added the label resolution:WAI (The software is working as intended) and removed the labels bug and new on Oct 30, 2024
srinify closed this as completed on Oct 30, 2024