Skip to content

Commit

Permalink
Fix issue
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Nov 25, 2024
1 parent ad7579f commit 36c958b
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,18 +341,14 @@ def update_transformers(self, table_name, column_name_to_transformer):
def _store_and_convert_original_cols(self, data):
    """Record each table's original column labels and stringify integer ones.

    For every table, the current column index is stored in
    ``self._original_table_columns`` under the table name. If at least one
    column label is an ``int``, all of that table's column labels are cast
    to ``str`` in place on the DataFrame that was passed in.

    Args:
        data (dict):
            Maps each table name to its ``pandas.DataFrame``. DataFrames
            that have an ``int`` column label are modified in place.

    Returns:
        list:
            Names of the tables whose column labels were converted, in the
            iteration order of ``data``.
    """
    list_of_changed_tables = []
    for table, dataframe in data.items():
        # Keep the untouched labels; they are what gets assigned back to
        # ``.columns`` when the original names need to be restored.
        self._original_table_columns[table] = dataframe.columns

        # One integer label is enough to require converting the whole index.
        if any(isinstance(column, int) for column in dataframe.columns):
            dataframe.columns = dataframe.columns.astype(str)
            list_of_changed_tables.append(table)

    return list_of_changed_tables

def _transform_helper(self, data):
Expand Down Expand Up @@ -396,7 +392,7 @@ def preprocess(self, data):
raise e

for table in list_of_changed_tables:
data[table] = data[table].rename(columns=self._original_table_columns[table])
data[table].columns = self._original_table_columns[table]

return processed_data

Expand Down Expand Up @@ -481,6 +477,19 @@ def fit(self, data):
processed_data = self.preprocess(data)
self._print(text='\n', end='')
self.fit_processed_data(processed_data)
from rdt.transformers import AnonymizedFaker
for table, synthesizer in self._table_synthesizers.items():
for column, transformer in synthesizer._data_processor._hyper_transformer.field_transformers.items():
if isinstance(transformer, AnonymizedFaker):
new_seed = self._set_faker_seed(column, table)
self._table_synthesizers[table]._data_processor._hyper_transformer.field_transformers[column].faker.seed_instance(new_seed)

def _set_faker_seed(self, column_name, table):
    """Derive a deterministic Faker seed from a column name and a table name.

    The column name and table name are concatenated, hashed with SHA-256,
    and the digest is reduced modulo ``2**32 - 1``. The same pair of names
    therefore always produces the same seed.

    Args:
        column_name (str):
            Name of the column the seed is for. Non-string labels are
            converted with ``str`` before hashing.
        table (str):
            Name of the table that contains the column.

    Returns:
        int:
            A seed in the range ``[0, 2**32 - 2]``.
    """
    import hashlib

    # Coerce to str so a non-string label cannot raise TypeError on
    # concatenation or ``.encode``; string inputs hash exactly as before.
    seed_source = str(column_name) + str(table)
    digest = hashlib.sha256(seed_source.encode('utf-8')).hexdigest()
    return int(digest, 16) % ((2**32) - 1)  # keep the seed within 32 bits

def reset_sampling(self):
"""Reset the sampling to the state that was left right after fitting."""
Expand Down Expand Up @@ -528,16 +537,9 @@ def sample(self, scale=1.0):
total_columns += len(table.columns)

table_columns = getattr(self, '_original_table_columns', {})

for table in sampled_data:
table_data = sampled_data[table][self.get_metadata().get_column_names(table)]
if table in table_columns:
if isinstance(table_columns[table], dict):
table_data = table_data.rename(columns=table_columns[table])
else:
table_data.columns = table_columns[table]

sampled_data[table] = table_data
sampled_data[table].columns = table_columns[table]

SYNTHESIZER_LOGGER.info({
'EVENT': 'Sample',
Expand Down

0 comments on commit 36c958b

Please sign in to comment.