From e96dfe0d9be2eb17c0c76c49584307b61522ba6a Mon Sep 17 00:00:00 2001 From: Brian Healy <42810347+bfhealy@users.noreply.github.com> Date: Tue, 9 Apr 2024 10:48:25 -0500 Subject: [PATCH] Fix dtype bug in combine-preds field lists (#577) * Fix dtype bug in combine-preds field lists * Remove code.interact * Use fields_to_do in loop instead of 'not' fields_to_list --- tools/combine_preds.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tools/combine_preds.py b/tools/combine_preds.py index 2ff4f97f..58358785 100755 --- a/tools/combine_preds.py +++ b/tools/combine_preds.py @@ -105,7 +105,7 @@ def combine_preds( os.makedirs(path_to_preds / combined_preds_dirname, exist_ok=True) done_fields = [ - str(x).split("/")[-1].split(".")[0] + int(str(x).split("/")[-1].split(".")[0].split("_")[1]) for x in (path_to_preds / combined_preds_dirname).glob("field_*.parquet") ] fields_to_list = done_fields.copy() @@ -113,9 +113,15 @@ def combine_preds( if fields_to_exclude is not None: fields_to_list.extend(fields_to_exclude) - fields_to_do = list(set(fields_dnn_dict).difference(done_fields)) + fields_to_do = list( + set([int(x.split("_")[1]) for x in fields_dnn_dict.keys()]).difference( + fields_to_list + ) + ) fields_to_list.extend(fields_to_do) + # Use set to drop duplicate fields before sorting + fields_to_list = list(set(fields_to_list)) fields_to_list.sort() if save: @@ -127,8 +133,11 @@ def combine_preds( counter = 0 print(f"Processing {len(fields_to_do)} fields/files...") + # Reformat fields in field_N format to match filenames + fields_to_do = [f"field_{x}" for x in fields_to_do] + for field in fields_dnn_dict.keys(): - if field not in done_fields: + if field in fields_to_do: if field in fields_xgb_dict.keys(): try: dnn_preds = read_parquet(