diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py
index ab4d74507..d1dca9b9f 100644
--- a/sdv/multi_table/base.py
+++ b/sdv/multi_table/base.py
@@ -341,18 +341,14 @@ def update_transformers(self, table_name, column_name_to_transformer):
     def _store_and_convert_original_cols(self, data):
         list_of_changed_tables = []
         for table, dataframe in data.items():
-            data_columns = dataframe.columns
-            col_name_mapping = {str(col): col for col in data_columns}
-            reverse_col_name_mapping = {col: str(col) for col in data_columns}
-            self._original_table_columns[table] = col_name_mapping
-            dataframe = dataframe.rename(columns=reverse_col_name_mapping)
-            for column in data_columns:
+            self._original_table_columns[table] = dataframe.columns
+            for column in dataframe.columns:
                 if isinstance(column, int):
+                    dataframe.columns = dataframe.columns.astype(str)
                     list_of_changed_tables.append(table)
                     break
 
             data[table] = dataframe
-
         return list_of_changed_tables
 
     def _transform_helper(self, data):
@@ -396,7 +392,7 @@ def preprocess(self, data):
                 raise e
 
         for table in list_of_changed_tables:
-            data[table] = data[table].rename(columns=self._original_table_columns[table])
+            data[table].columns = self._original_table_columns[table]
 
         return processed_data
 
@@ -481,6 +477,26 @@ def fit(self, data):
         processed_data = self.preprocess(data)
         self._print(text='\n', end='')
         self.fit_processed_data(processed_data)
+        from rdt.transformers import AnonymizedFaker
+        # Reseed every AnonymizedFaker transformer with a deterministic value
+        # derived from its column and table names.
+        for table, synthesizer in self._table_synthesizers.items():
+            field_transformers = synthesizer._data_processor._hyper_transformer.field_transformers
+            for column, transformer in field_transformers.items():
+                if isinstance(transformer, AnonymizedFaker):
+                    new_seed = self._set_faker_seed(column, table)
+                    transformer.faker.seed_instance(new_seed)
+
+    def _set_faker_seed(self, column_name, table):
+        """Return a deterministic seed for the given column and table names.
+
+        The seed is the SHA-256 digest of ``column_name + table`` read as an integer
+        and reduced modulo ``2**32 - 1``.
+        """
+        import hashlib
+        hash_value = column_name + table
+        hash_value = int(hashlib.sha256(hash_value.encode('utf-8')).hexdigest(), 16)
+        return hash_value % ((2**32) - 1)  # maximum value for a seed
 
     def reset_sampling(self):
         """Reset the sampling to the state that was left right after fitting."""
@@ -528,16 +544,9 @@ def sample(self, scale=1.0):
             total_columns += len(table.columns)
 
         table_columns = getattr(self, '_original_table_columns', {})
-
         for table in sampled_data:
-            table_data = sampled_data[table][self.get_metadata().get_column_names(table)]
             if table in table_columns:
-                if isinstance(table_columns[table], dict):
-                    table_data = table_data.rename(columns=table_columns[table])
-                else:
-                    table_data.columns = table_columns[table]
-
-            sampled_data[table] = table_data
+                sampled_data[table].columns = table_columns[table]
 
         SYNTHESIZER_LOGGER.info({
             'EVENT': 'Sample',