diff --git a/src/iblphotometry/io.py b/src/iblphotometry/io.py index 0fe2f42..c98ba83 100644 --- a/src/iblphotometry/io.py +++ b/src/iblphotometry/io.py @@ -102,6 +102,7 @@ def from_ibl_dataframe( channel_column: str = 'name', channel_names: list[str] | None = None, rename: dict | None = None, + validate: bool = True, ) -> dict: """main function to convert to analysis ready format @@ -134,8 +135,8 @@ def from_ibl_dataframe( channel_names = ibl_df[channel_column].unique() # drop empty acquisition channels - to_drop = ['None', ''] - channel_names = [ch for ch in channel_names if ch not in to_drop] + if validate: + ibl_df = validate_ibl_dataframe(ibl_df) dfs = {} for channel in channel_names: @@ -222,7 +223,42 @@ def from_raw_neurophotometrics_file( """ -def validate_ibl_dataframe(df: pd.DataFrame) -> pd.DataFrame: ... +def validate_ibl_dataframe(ibl_df: pd.DataFrame) -> pd.DataFrame: + # for now, check if number of frames are equal and drop the longer one + # to be expanded into a full panderas check + + # 1) drop first frame if invalid + first_frame_name = ibl_df.iloc[0]['name'] + if '+' in first_frame_name or first_frame_name == '': + ibl_df = ibl_df.drop(index=0) + + # 2) if unequal number of frames per acquisition channel, drop excess frames + frame_counts = ibl_df.groupby('name')['times'].count() + if not np.all(frame_counts.values == frame_counts.values[0]): + # find shortest + dfs = [] + min_frames = frame_counts.iloc[np.argmin(frame_counts)] + for name, group in ibl_df.groupby('name'): + dfs.append(group.iloc[:min_frames]) + n_dropped = group.shape[0] - min_frames + warnings.warn(f'dropping {n_dropped} frames for channel {name}') + + ibl_df = pd.concat(dfs).sort_index() + + # 3) panderas validation + data_columns = infer_data_columns(ibl_df) + schema_ibl_df = pandera.DataFrameSchema( + columns=dict( + times=pandera.Column(pandera.Float64), + # valid=pandera.Column(pandera.Bool), # NOTE as of now, it seems like valid is an optional column found in alejandro but not in carolina + wavelength=pandera.Column(pandera.Float64), + name=pandera.Column(pandera.String), + color=pandera.Column(pandera.String), + **{k: pandera.Column(pandera.Float64) for k in data_columns}, + ) + ) + ibl_df = schema_ibl_df.validate(ibl_df) + return ibl_df def _validate_neurophotometrics_df(