diff --git a/src/dynamic_routing_analysis/data_utils.py b/src/dynamic_routing_analysis/data_utils.py index 4e52f84..311ed9d 100644 --- a/src/dynamic_routing_analysis/data_utils.py +++ b/src/dynamic_routing_analysis/data_utils.py @@ -7,6 +7,11 @@ import pandas as pd import pynwb +vid_angle_npc_names={ + 'behavior':'side', + 'face':'front', + 'eye':'eye', +} def load_trials_or_units(session, table_name): # convenience function to load trials or units from cache if available, @@ -43,15 +48,14 @@ def load_trials_or_units(session, table_name): return table -def load_facemap_data(session,session_info,trials,vid_angle,keep_n_SVDs=500,use_s3=True): +def load_facemap_data(session,session_info=None,trials=None,vid_angle=None,keep_n_SVDs=500,use_s3=True): # function to load facemap data from s3 or local cache - vid_angle_npc_names={ - 'behavior':'side', - 'face':'front', - 'eye':'eye', - } + if not vid_angle: + raise ValueError("vid_angle must be specified") if isinstance(session, pynwb.NWBFile): + if trials is None: + trials = session.trials[:] if not any("facemap" in k for k in session.processing["behavior"].data_interfaces.keys()): raise AttributeError( f"Facemap data not found in {session.session_id} NWB file" @@ -206,8 +210,10 @@ def load_facemap_data(session,session_info,trials,vid_angle,keep_n_SVDs=500,use_ return mean_trial_behav_SVD #mean_trial_behav_motion -def load_LP_data(session, trials, vid_angle, LP_parts_to_keep=None): - +def load_LP_data(session, trials=None, vid_angle=None, LP_parts_to_keep=None): + if not vid_angle: + raise ValueError("vid_angle must be specified") + def zscore(x): return (x - np.nanmean(x)) / np.nanstd(x) @@ -237,22 +243,48 @@ def part_info(part, df, temp_error, pca_error): if LP_parts_to_keep is None: LP_parts_to_keep = ['ear_base_l', 'jaw', 'nose_tip', 'whisker_pad_l_side'] - vid_angle_npc_names = { + vid_angle_idx = { 'behavior': 0, 'face': 3, } - - df = session._LPFaceParts[vid_angle_npc_names[vid_angle]][:] - df_temp_error = session._LPFaceParts[vid_angle_npc_names[vid_angle] + 1][:] - df_pca_error = session._LPFaceParts[vid_angle_npc_names[vid_angle] + 2][:] - cam_frames = df['timestamps'].values.astype('float') - - LP_traces = [] - for part_no, part_name in enumerate(LP_parts_to_keep): - x, y = part_info(part_name, df, df_temp_error[part_name].values.astype('float'), - df_pca_error[part_name].values.astype('float')) - LP_traces.append(x) - LP_traces.append(y) + camera_idx = vid_angle_idx[vid_angle] + if isinstance(session, pynwb.NWBFile): + if trials is None: + trials = session.trials[:] + if not any( + k.startswith('lp_') + for k in session.processing["behavior"].data_interfaces.keys() + ): + raise AttributeError( + f"lightning_pose data not found in {session.session_id} NWB file" + ) + df = session.processing["behavior"][ + f"lp_{vid_angle_npc_names[vid_angle]}_camera" + ][:] + cam_frames = df.timestamps.values + + LP_traces = [] + for part_no, part_name in enumerate(LP_parts_to_keep): + if f"{part_name}_x" not in df.columns: + continue + x, y = part_info(part_name, df, df[f"{part_name}_error"].values.astype('float'), + df[f"{part_name}_temporal_norm"].values.astype('float')) + LP_traces.append(x) + LP_traces.append(y) + if not LP_traces: + raise ValueError(f"None of requested LP parts found for {vid_angle} camera: {LP_parts_to_keep}") + else: + df = session._LPFaceParts[camera_idx][:] + df_temp_error = session._LPFaceParts[camera_idx + 1][:] + df_pca_error = session._LPFaceParts[camera_idx + 2][:] + cam_frames = df['timestamps'].values.astype('float') + + LP_traces = [] + for part_no, part_name in enumerate(LP_parts_to_keep): + x, y = part_info(part_name, df, df_temp_error[part_name].values.astype('float'), + df_pca_error[part_name].values.astype('float')) + LP_traces.append(x) + LP_traces.append(y) LP_info = { 'LP_traces': np.array(LP_traces).T