From 84191305f91a4b21e2804952e39e12b97849aa3b Mon Sep 17 00:00:00 2001 From: egmcbride Date: Wed, 9 Oct 2024 11:19:13 -0700 Subject: [PATCH] add option to input trials, units, session_info directly --- .../decoding_utils.py | 48 +++++++++++-------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/src/dynamic_routing_analysis/decoding_utils.py b/src/dynamic_routing_analysis/decoding_utils.py index 8cefaf9..d80651f 100644 --- a/src/dynamic_routing_analysis/decoding_utils.py +++ b/src/dynamic_routing_analysis/decoding_utils.py @@ -822,7 +822,7 @@ def decode_context_from_units_all_timebins(session,params): # incorporate additional parameters # add option to decode from timebins # add option to use inputs with top decoding weights (use_coefs) -def decode_context_with_linear_shift(session,params): +def decode_context_with_linear_shift(session=None,params=None,trials=None,units=None,session_info=None): decoder_results={} @@ -857,32 +857,39 @@ def decode_context_with_linear_shift(session,params): decoder_type=params['decoder_type'] # use_coefs=params['use_coefs'] # generate_labels=params['generate_labels'] - session_id=str(session.id) + + + if session_info is not None: + session_info=npc_lims.get_session_info(session) - session_info=npc_lims.get_session_info(session) + session_id=str(session_info.id) - ##TODO: change data loading to use helper functions - #load trials and units - try: - trials=pd.read_parquet( - npc_lims.get_cache_path('trials',session_id) - ) - except: - print('no cached trials table, using npc_sessions') - trials = session.trials[:] + ##Option to input session or trials/units/session_info directly + ##note: inputting session may not work with Code Ocean + + if session is not None: + try: + trials=pd.read_parquet( + npc_lims.get_cache_path('trials',session_id) + ) + except: + print('no cached trials table, using npc_sessions') + trials = session.trials[:] if exclude_cue_trials: trials=trials.query('is_reward_scheduled==False').reset_index() if input_data_type=='spikes': #make data array - try: - units=pd.read_parquet( - npc_lims.get_cache_path('units',session_id) - ) - except: - print('no cached units table, using npc_sessions') - units = session.units[:] + if session is not None: + try: + units=pd.read_parquet( + npc_lims.get_cache_path('units',session_id) + ) + except: + print('no cached units table, using npc_sessions') + units = session.units[:] + #add probe to structure name structure_probe=spike_utils.get_structure_probe(units) for uu, unit in units.iterrows(): @@ -891,6 +898,7 @@ def decode_context_with_linear_shift(session,params): #make trial data array for baseline activity trial_da = spike_utils.make_neuron_time_trials_tensor(units, trials, spikes_time_before, spikes_time_after, spikes_binsize) + ### TODO: update to work with code ocean elif input_data_type=='facemap': # mean_trial_behav_SVD,mean_trial_behav_motion = load_facemap_data(session,session_info,trials,vid_angle) mean_trial_behav_SVD = data_utils.load_facemap_data(session,session_info,trials,vid_angle_facemotion,keep_n_SVDs) @@ -1091,7 +1099,7 @@ def decode_context_with_linear_shift(session,params): print(f'finished {session_id} {aa}') #save results - (upath.UPath(savepath) / f"{session.id}_{filename}").write_bytes( + (upath.UPath(savepath) / f"{session_id}_{filename}").write_bytes( pickle.dumps(decoder_results, protocol=pickle.HIGHEST_PROTOCOL) ) print(f'finished {session_id}')