diff --git a/gnn/preprocessing.py b/gnn/preprocessing.py index aef7b881..e4114611 100644 --- a/gnn/preprocessing.py +++ b/gnn/preprocessing.py @@ -136,6 +136,12 @@ def main(): df_nodes_test = get_nodes(df_bank_test) df_edges_test = get_edges(df_bank_test, aggregated=True, directional=False) + df_nodes_train.reset_index(inplace=True) + node_to_index = pd.Series(df_nodes_train.index, index=df_nodes_train['account']).to_dict() + df_edges_train['src'] = df_edges_train['src'].map(node_to_index) + df_edges_train['dst'] = df_edges_train['dst'].map(node_to_index) + df_nodes_train.drop(columns=['account'], inplace=True) + df_nodes_test.reset_index(inplace=True) node_to_index = pd.Series(df_nodes_test.index, index=df_nodes_test['account']).to_dict() df_edges_test['src'] = df_edges_test['src'].map(node_to_index)