From fda028aa4812285d78a521a60ded6831fac6f2c8 Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Sat, 25 Nov 2023 13:18:35 -0500 Subject: [PATCH] subset frac --- zoobot/shared/load_predictions.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/zoobot/shared/load_predictions.py b/zoobot/shared/load_predictions.py index db9876cf..874ae9f9 100644 --- a/zoobot/shared/load_predictions.py +++ b/zoobot/shared/load_predictions.py @@ -192,7 +192,7 @@ def convert_halfprecision_cols(df): return df -def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False): +def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False, subset_frac=None): """ Load predictions (or representations) saved as hdf5 into pd.DataFrame with id_str and label_cols columns @@ -208,9 +208,11 @@ def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False): _type_: _description_ """ galaxy_id_df, predictions, label_cols = load_hdf5s(hdf5_locs) + logging.info('HDF5s loaded.') predictions = predictions.squeeze() - + + if len(predictions.shape) > 2: if drop_extra_dims: predictions = predictions[:, :, 0] @@ -221,6 +223,14 @@ def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False): I suggest using load_hdf5s directly to work with np.arrays, not with DataFrame - see docstring' ) prediction_df = pd.DataFrame(data=predictions, columns=label_cols) + + if subset_frac is not None: + logging.warning('Selecting a random subset: {}'.format(subset_frac)) + prediction_df = prediction_df.sample(frac=subset_frac, random_state=42) + + + del predictions + logging.info('Saving') # copy over metadata (indices will align) prediction_df['id_str'] = galaxy_id_df['id_str'] prediction_df['hdf5_loc'] = galaxy_id_df['hdf5_loc']