From 1a844a7065cd5821c216a7ecaafeb183e0c3e0e4 Mon Sep 17 00:00:00 2001 From: Kevin M Jablonka <32935233+kjappelbaum@users.noreply.github.com> Date: Sat, 10 Aug 2024 08:53:13 -0700 Subject: [PATCH] fix: in post-processing of split rewrite file (#531) --- data/postprocess_split.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/data/postprocess_split.py b/data/postprocess_split.py index ef8999b0b..7ef160a54 100644 --- a/data/postprocess_split.py +++ b/data/postprocess_split.py @@ -4,6 +4,9 @@ It also merges files that have been created by `dask` if they are chunks of one large dataset. This script needs to be run after the splitting script. + +An independent check (that does not rewrite files is `check_smiles_split.py`; +this checks also for compliance with the predetermined files) """ import os @@ -143,7 +146,7 @@ def process_file(file: Union[str, Path], id_cols): # appear multiple times ddf = read_ddf(file) ddf = ddf.drop_duplicates(subset=id_cols) - ddf.to_csv("data_clean-{*}.csv", index=False) + ddf.to_csv(os.path.join(dir, "data_clean-{*}.csv"), index=False) merge_files(dir) else: @@ -154,7 +157,7 @@ def process_file(file: Union[str, Path], id_cols): for id in id_cols: test_smiles.extend(df[df["split"] == "test"][id].to_list()) val_smiles.extend(df[df["split"] == "valid"][id].to_list()) - + df.drop_duplicates(subset=[id], inplace=True) test_smiles = set(test_smiles) val_smiles = set(val_smiles) @@ -184,7 +187,7 @@ def process_file(file: Union[str, Path], id_cols): len(this_test_smiles.intersection(this_val_smiles)) == 0 ), f"Smiles in test and valid for {id}" - df.to_csv("data_clean.csv", index=False) + df.to_csv(os.path.join(dir, "data_clean.csv"), index=False) def process_all_files(data_dir):