Commit

add option to save zarr files
egmcbride committed Oct 14, 2024
1 parent e41fedf commit dd6f9aa
Showing 1 changed file with 81 additions and 26 deletions.
107 changes: 81 additions & 26 deletions src/dynamic_routing_analysis/decoding_utils.py
@@ -7,13 +7,24 @@
import pandas as pd
import upath
import xarray as xr
import zarr
from sklearn.metrics import balanced_accuracy_score, classification_report
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import RobustScaler, StandardScaler

import dynamic_routing_analysis as dra
from dynamic_routing_analysis import data_utils, spike_utils


# Recursively dump a (possibly nested) dictionary into a Zarr group:
# nested dicts become subgroups; each leaf value is stored under its key.
def dump_dict_to_zarr(group, data):
    for key, value in data.items():
        if isinstance(value, dict):
            subgroup = group.create_group(key)
            dump_dict_to_zarr(subgroup, value)
        else:
            # leaf values must be array-like for zarr to store them
            group[key] = value
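
A minimal round-trip usage sketch for dump_dict_to_zarr (the path 'example.zarr' and the toy dict are illustrative; assumes the zarr v2-style group API, where assigning an array-like to group[key] creates an array):

import zarr

store = zarr.open('example.zarr', mode='w')  # illustrative output path
dump_dict_to_zarr(store, {'sessionA': {'balanced_accuracy': [0.9, 0.8], 'n_units': 20}})
readback = zarr.open('example.zarr', mode='r')
readback['sessionA']['balanced_accuracy'][:]  # -> array([0.9, 0.8])

Leaf values must be array-like (scalars, lists, numpy arrays); richer objects such as DataFrames need to be flattened to plain dicts first, which is what the concat functions below do before calling this helper.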

# decoder_type options: 'linearSVC', 'LDA', or 'RandomForest'
def decoder_helper(input_data,labels,decoder_type='linearSVC',crossval='5_fold',crossval_index=None,labels_as_index=False):
    # helper function to decode labels from input data using different decoder models
@@ -822,7 +833,7 @@ def decode_context_from_units_all_timebins(session,params):
# incorporate additional parameters
# add option to decode from timebins
# add option to use inputs with top decoding weights (use_coefs)
def decode_context_with_linear_shift(session=None,params=None,trials=None,units=None,session_info=None):
def decode_context_with_linear_shift(session=None,params=None,trials=None,units=None,session_info=None,use_zarr=False):

    decoder_results={}

@@ -1089,14 +1100,21 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=

        print(f'finished {session_id} {aa}')
    #save results
    (upath.UPath(savepath) / f"{session_id}_{filename}").write_bytes(
        pickle.dumps(decoder_results, protocol=pickle.HIGHEST_PROTOCOL)
    )
    if not use_zarr:
        (upath.UPath(savepath) / f"{session_id}_{filename}.pkl").write_bytes(
            pickle.dumps(decoder_results, protocol=pickle.HIGHEST_PROTOCOL)
        )
    else:
        # create a Zarr group; the filename mirrors the pickle branch above
        zarr_file = zarr.open(str(upath.UPath(savepath) / f"{session_id}_{filename}.zarr"), mode='w')

        dump_dict_to_zarr(zarr_file, decoder_results)

    print(f'finished {session_id}')
    # print(f'time elapsed: {time.time()-start_time}')
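
For downstream analysis, a hedged sketch of loading either saved format back (the paths simply mirror the save calls above; assumes the zarr v2-style API and that the UPath stringifies to something zarr/fsspec can open):

import pickle
import upath
import zarr

results_pkl = pickle.loads(
    (upath.UPath(savepath) / f"{session_id}_{filename}.pkl").read_bytes()
)
results_zarr = zarr.open(str(upath.UPath(savepath) / f"{session_id}_{filename}.zarr"), mode='r')  # read-only group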


def concat_decoder_results(files,savepath=None,return_table=True,single_session=False):
def concat_decoder_results(files,savepath=None,return_table=True,single_session=False,use_zarr=False):

    use_half_shifts=False
    n_repeats=25
@@ -1252,15 +1270,29 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=


    linear_shift_df=pd.DataFrame(linear_shift_dict)
    if savepath is not None:
        try:
            if single_session:
                linear_shift_df.to_csv(os.path.join(savepath,session_id+'_linear_shift_decoding_results.csv'))
            else:
                linear_shift_df.to_csv(os.path.join(savepath,'all_linear_shift_decoding_results.csv'))
        except Exception as e:
            print(e)
            print('error saving linear shift df')

    if use_zarr and single_session:
        results={
            session_id:{
                'linear_shift_summary_table':linear_shift_dict,
            },
        }

        # open the session's existing store in append mode so that the
        # decoder results already saved there are preserved
        zarr_file = zarr.open(files, mode='a')

        dump_dict_to_zarr(zarr_file, results)

    elif not use_zarr:
        if savepath is not None:
            try:
                if single_session:
                    linear_shift_df.to_csv(os.path.join(savepath,session_id+'_linear_shift_decoding_results.csv'))
                else:
                    linear_shift_df.to_csv(os.path.join(savepath,'all_linear_shift_decoding_results.csv'))
            except Exception as e:
                print(e)
                print('error saving linear shift df')

    if return_table:
        return linear_shift_df
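
The zarr branch above writes into the same store that decode_context_with_linear_shift already populated, so the open mode matters; a hedged illustration (the 'session.zarr' path is made up; zarr v2 semantics assumed):

import zarr

z = zarr.open('session.zarr', mode='w')  # truncates the store: previously saved groups are lost
z = zarr.open('session.zarr', mode='a')  # read/write: existing groups are kept, and
                                         # dump_dict_to_zarr only adds the new keys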

@@ -1380,7 +1412,7 @@ def compute_significant_decoding_by_area(all_decoder_results):
    return all_frac_sig_df,all_diff_from_null_df


def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_units=None,single_session=False):
def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_units=None,single_session=False,use_zarr=False):

    # load sessions as we go

@@ -1876,6 +1908,12 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
            print('failed to load session ',session_id,': ',e)
            continue

    # keep plain-dict copies for the zarr branch: dump_dict_to_zarr stores
    # array-like leaves, so the DataFrame versions are only for the CSV/pickle branch
    decoder_confidence_versus_response_type_dict=decoder_confidence_versus_response_type.copy()
    decoder_confidence_dprime_by_block_dict=decoder_confidence_dprime_by_block.copy()
    decoder_confidence_by_switch_dict=decoder_confidence_by_switch.copy()
    decoder_confidence_versus_trials_since_rewarded_target_dict=decoder_confidence_versus_trials_since_rewarded_target.copy()
    decoder_confidence_before_after_target_dict=decoder_confidence_before_after_target.copy()

    decoder_confidence_versus_response_type=pd.DataFrame(decoder_confidence_versus_response_type)
    decoder_confidence_dprime_by_block=pd.DataFrame(decoder_confidence_dprime_by_block)
    decoder_confidence_by_switch=pd.DataFrame(decoder_confidence_by_switch)
@@ -1895,17 +1933,34 @@
    else:
        temp_session_str=''

    decoder_confidence_versus_response_type.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.csv'),index=False)
    decoder_confidence_dprime_by_block.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.csv'),index=False)
    decoder_confidence_by_switch.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.csv'),index=False)
    decoder_confidence_versus_trials_since_rewarded_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.csv'),index=False)
    decoder_confidence_before_after_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.csv'),index=False)

    decoder_confidence_versus_response_type.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.pkl'))
    decoder_confidence_dprime_by_block.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.pkl'))
    decoder_confidence_by_switch.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.pkl'))
    decoder_confidence_versus_trials_since_rewarded_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.pkl'))
    decoder_confidence_before_after_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.pkl'))
    if use_zarr and single_session:
        results={
            session_id:{
                'decoder_confidence_versus_response_type'+n_units_str:decoder_confidence_versus_response_type_dict,
                'decoder_confidence_dprime_by_block'+n_units_str:decoder_confidence_dprime_by_block_dict,
                'decoder_confidence_by_switch'+n_units_str:decoder_confidence_by_switch_dict,
                'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str:decoder_confidence_versus_trials_since_rewarded_target_dict,
                'decoder_confidence_before_after_target'+n_units_str:decoder_confidence_before_after_target_dict,
            },
        }

        # append to the session's existing store so earlier results are preserved
        zarr_file = zarr.open(files, mode='a')

        dump_dict_to_zarr(zarr_file, results)

    elif not use_zarr:

        decoder_confidence_versus_response_type.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.csv'),index=False)
        decoder_confidence_dprime_by_block.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.csv'),index=False)
        decoder_confidence_by_switch.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.csv'),index=False)
        decoder_confidence_versus_trials_since_rewarded_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.csv'),index=False)
        decoder_confidence_before_after_target.to_csv(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.csv'),index=False)

        decoder_confidence_versus_response_type.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_response_type'+n_units_str+'.pkl'))
        decoder_confidence_dprime_by_block.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_dprime_by_block'+n_units_str+'.pkl'))
        decoder_confidence_by_switch.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_by_switch'+n_units_str+'.pkl'))
        decoder_confidence_versus_trials_since_rewarded_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_versus_trials_since_rewarded_target'+n_units_str+'.pkl'))
        decoder_confidence_before_after_target.to_pickle(os.path.join(savepath,temp_session_str+'decoder_confidence_before_after_target'+n_units_str+'.pkl'))

    if return_table:
        return decoder_confidence_versus_response_type,decoder_confidence_dprime_by_block,decoder_confidence_by_switch,decoder_confidence_versus_trials_since_rewarded_target,decoder_confidence_before_after_target
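
To get one of these summary tables back out of a session store as a DataFrame, a hedged sketch (the store path and keys are illustrative; assumes each table was dumped column-by-column by dump_dict_to_zarr):

import pandas as pd
import zarr

z = zarr.open('session.zarr', mode='r')  # illustrative path
grp = z['my_session_id']['decoder_confidence_by_switch']  # illustrative keys
df = pd.DataFrame({k: grp[k][:] for k in grp.keys()})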
