diff --git a/src/dynamic_routing_analysis/decoding_utils.py b/src/dynamic_routing_analysis/decoding_utils.py index 9f787f1..f5d16ea 100644 --- a/src/dynamic_routing_analysis/decoding_utils.py +++ b/src/dynamic_routing_analysis/decoding_utils.py @@ -7,13 +7,24 @@ import pandas as pd import upath import xarray as xr +import zarr from sklearn.metrics import balanced_accuracy_score, classification_report from sklearn.model_selection import StratifiedKFold from sklearn.preprocessing import RobustScaler, StandardScaler +import dynamic_routing_analysis as dra from dynamic_routing_analysis import data_utils, spike_utils +# Dump the dictionary to the Zarr file +def dump_dict_to_zarr(group, data): + for key, value in data.items(): + if isinstance(value, dict): + subgroup = group.create_group(key) + dump_dict_to_zarr(subgroup, value) + else: + group[key] = value + # 'linearSVC' or 'LDA' or 'RandomForest' def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold',crossval_index=None,labels_as_index=False): #helper function to decode labels from input data using different decoder models @@ -822,7 +833,7 @@ def decode_context_from_units_all_timebins(session,params): # incorporate additional parameters # add option to decode from timebins # add option to use inputs with top decoding weights (use_coefs) -def decode_context_with_linear_shift(session=None,params=None,trials=None,units=None,session_info=None): +def decode_context_with_linear_shift(session=None,params=None,trials=None,units=None,session_info=None,use_zarr=False): decoder_results={} @@ -1089,14 +1100,21 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units= print(f'finished {session_id} {aa}') #save results - (upath.UPath(savepath) / f"{session_id}_{filename}").write_bytes( - pickle.dumps(decoder_results, protocol=pickle.HIGHEST_PROTOCOL) - ) + if use_zarr==False: + (upath.UPath(savepath) / f"{session_id}_{filename}.pkl").write_bytes( + pickle.dumps(decoder_results, protocol=pickle.HIGHEST_PROTOCOL) + ) + else: + # Create a Zarr group + zarr_file = zarr.open(savepath / (filename + '.zarr'), mode='w') + + dump_dict_to_zarr(zarr_file, decoder_results) + print(f'finished {session_id}') # print(f'time elapsed: {time.time()-start_time}') -def concat_decoder_results(files,savepath=None,return_table=True,single_session=False): +def concat_decoder_results(files,savepath=None,return_table=True,single_session=False,use_zarr=False): use_half_shifts=False n_repeats=25 @@ -1252,15 +1270,29 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session= linear_shift_df=pd.DataFrame(linear_shift_dict) - if savepath is not None: - try: - if single_session: - linear_shift_df.to_csv(os.path.join(savepath,session_id+'_linear_shift_decoding_results.csv')) - else: - linear_shift_df.to_csv(os.path.join(savepath,'all_linear_shift_decoding_results.csv')) - except Exception as e: - print(e) - print('error saving linear shift df') + + if use_zarr==True and single_session==True: + results={ + session_id:{ + 'linear_shift_summary_table':linear_shift_dict, + }, + } + + zarr_file = zarr.open(files, mode='w') + + dump_dict_to_zarr(zarr_file, results) + + elif use_zarr==False: + if savepath is not None: + try: + if single_session: + linear_shift_df.to_csv(os.path.join(savepath,session_id+'_linear_shift_decoding_results.csv')) + else: + linear_shift_df.to_csv(os.path.join(savepath,'all_linear_shift_decoding_results.csv')) + except Exception as e: + print(e) + print('error saving linear shift df') + if return_table: return linear_shift_df @@ -1380,7 +1412,7 @@ def compute_significant_decoding_by_area(all_decoder_results): return all_frac_sig_df,all_diff_from_null_df -def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_units=None,single_session=False): +def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_units=None,single_session=False,use_zarr=False): #load sessions as we go @@ -1876,6 +1908,12 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un print('failed to load session ',session_id,': ',e) continue + decoder_confidence_versus_response_type_dict=decoder_confidence_versus_response_type.copy() + decoder_confidence_dprime_by_block_dict=decoder_confidence_dprime_by_block.copy() + decoder_confidence_by_switch_dict=decoder_confidence_by_switch.copy() + decoder_confidence_versus_trials_since_rewarded_target_dict=decoder_confidence_versus_trials_since_rewarded_target.copy() + decoder_confidence_before_after_target_dict=decoder_confidence_before_after_target.copy() + decoder_confidence_versus_response_type=pd.DataFrame(decoder_confidence_versus_response_type) decoder_confidence_dprime_by_block=pd.DataFrame(decoder_confidence_dprime_by_block) decoder_confidence_by_switch=pd.DataFrame(decoder_confidence_by_switch) @@ -1895,17 +1933,34 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un else: temp_session_str='' - decoder_confidence_versus_response_type.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.csv'),index=False) - decoder_confidence_dprime_by_block.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.csv'),index=False) - decoder_confidence_by_switch.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.csv'),index=False) - decoder_confidence_versus_trials_since_rewarded_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.csv'),index=False) - decoder_confidence_before_after_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.csv'),index=False) - - decoder_confidence_versus_response_type.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.pkl')) - decoder_confidence_dprime_by_block.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.pkl')) - decoder_confidence_by_switch.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.pkl')) - decoder_confidence_versus_trials_since_rewarded_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.pkl')) - decoder_confidence_before_after_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.pkl')) + if use_zarr==True and single_session==True: + results={ + session_id:{ + 'decoder_confidence_versus_response_type'+n_units_str:decoder_confidence_versus_response_type_dict, + 'decoder_confidence_dprime_by_block'+n_units_str:decoder_confidence_dprime_by_block_dict, + 'decoder_confidence_by_switch'+n_units_str:decoder_confidence_by_switch_dict, + 'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str:decoder_confidence_versus_trials_since_rewarded_target_dict, + 'decoder_confidence_before_after_target'+n_units_str:decoder_confidence_before_after_target_dict, + }, + } + + zarr_file = zarr.open(files, mode='w') + + dump_dict_to_zarr(zarr_file, results) + + elif use_zarr==False: + + decoder_confidence_versus_response_type.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.csv'),index=False) + decoder_confidence_dprime_by_block.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.csv'),index=False) + decoder_confidence_by_switch.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.csv'),index=False) + decoder_confidence_versus_trials_since_rewarded_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.csv'),index=False) + decoder_confidence_before_after_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.csv'),index=False) + + decoder_confidence_versus_response_type.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.pkl')) + decoder_confidence_dprime_by_block.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.pkl')) + decoder_confidence_by_switch.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.pkl')) + decoder_confidence_versus_trials_since_rewarded_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.pkl')) + decoder_confidence_before_after_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.pkl')) if return_table: return decoder_confidence_versus_response_type,decoder_confidence_dprime_by_block,decoder_confidence_by_switch,decoder_confidence_versus_trials_since_rewarded_target,decoder_confidence_before_after_target