Skip to content

Commit

Permalink
subset frac
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Nov 25, 2023
1 parent d0b7ddd commit fda028a
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions zoobot/shared/load_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def convert_halfprecision_cols(df):
return df


def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False):
def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False, subset_frac=None):
"""
Load predictions (or representations) saved as hdf5 into pd.DataFrame with id_str and label_cols columns
Expand All @@ -208,9 +208,11 @@ def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False):
_type_: _description_
"""
galaxy_id_df, predictions, label_cols = load_hdf5s(hdf5_locs)
logging.info('HDF5s loaded.')

predictions = predictions.squeeze()



if len(predictions.shape) > 2:
if drop_extra_dims:
predictions = predictions[:, :, 0]
Expand All @@ -221,6 +223,14 @@ def single_forward_pass_hdf5s_to_df(hdf5_locs: List, drop_extra_dims=False):
I suggest using load_hdf5s directly to work with np.arrays, not with DataFrame - see docstring'
)
prediction_df = pd.DataFrame(data=predictions, columns=label_cols)

if subset_frac is not None:
logging.warning('Selecting a random subset: {}'.format(subset_frac))
prediction_df = prediction_df.sample(frac=subset_frac, random_state=42)


del predictions
logging.info('Saving')
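# copy over metadata: pandas matches the assigned Series to rows by index label, so a sampled subset still receives its own rows' values (assumes galaxy_id_df shares prediction_df's original index - TODO confirm)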
prediction_df['id_str'] = galaxy_id_df['id_str']
prediction_df['hdf5_loc'] = galaxy_id_df['hdf5_loc']
Expand Down

0 comments on commit fda028a

Please sign in to comment.